diff --git a/app/routes/__init__.py b/app/routes/__init__.py
index 0be2c749b..dca6f634c 100644
--- a/app/routes/__init__.py
+++ b/app/routes/__init__.py
@@ -1,7 +1,9 @@
-from typing import Tuple, List, Sequence
+from asyncio import Task, create_task, gather
+from typing import List, Sequence, Tuple
 from urllib.parse import urlparse
 
 from fastapi import Depends, HTTPException, Path
+from fastapi.logger import logger
 from fastapi.security import OAuth2PasswordBearer
 
 from ..crud.versions import get_version
@@ -66,12 +68,12 @@ async def dataset_version_dependency(
 
 
 async def verify_source_file_access(sources: List[str]) -> None:
-
+    """For each source URI, verify that it points to an existing object
+    or to a bucket and prefix which contain one or more objects. Returns
+    nothing on success, but raises an HTTPException if one or more
+    sources are invalid."""
     # TODO:
-    # 1. Making the list functions asynchronous and using asyncio.gather
-    # to check for valid sources in a non-blocking fashion would be good.
-    # Perhaps use the aioboto3 package for aws, gcloud-aio-storage for gcs.
-    # 2. It would be nice if the acceptable file extensions were passed
+    # 1. It would be nice if the acceptable file extensions were passed
     # into this function so we could say, for example, that there must be
     # TIFFs found for a new raster tile set, but a CSV is required for a new
     # vector tile set version. Even better would be to specify whether
@@ -79,9 +81,16 @@ async def verify_source_file_access(sources: List[str]) -> None:
 
     invalid_sources: List[str] = list()
 
+    task_sources: List[str] = list()
+    tasks: List[Task] = list()
+
     for source in sources:
         url_parts = urlparse(source, allow_fragments=False)
-        list_func = source_uri_lister_constructor[url_parts.scheme.lower()]
+        try:
+            list_func = source_uri_lister_constructor[url_parts.scheme.lower()]
+        except KeyError:
+            invalid_sources.append(source)
+            continue
         bucket = url_parts.netloc
         prefix = url_parts.path.lstrip("/")
 
@@ -100,14 +109,26 @@ async def verify_source_file_access(sources: List[str]) -> None:
         ):
             new_prefix += "/"
 
-        if not await list_func(
-            bucket,
-            new_prefix,
-            limit=10,
-            exit_after_max=1,
-            extensions=SUPPORTED_FILE_EXTENSIONS,
-        ):
-            invalid_sources.append(source)
+        task_sources.append(source)
+        tasks.append(
+            create_task(
+                list_func(
+                    bucket,
+                    new_prefix,
+                    limit=10,
+                    exit_after_max=1,
+                    extensions=SUPPORTED_FILE_EXTENSIONS,
+                )
+            )
+        )
+
+    results = await gather(*tasks, return_exceptions=True)
+    for uri, result in zip(task_sources, results):
+        if isinstance(result, Exception):
+            logger.error(f"Encountered exception checking src_uri {uri}: {result}")
+            invalid_sources.append(uri)
+        elif not result:
+            invalid_sources.append(uri)
 
     if invalid_sources:
         raise HTTPException(
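
Not part of the diff above: a minimal, self-contained sketch of the asyncio.gather(..., return_exceptions=True) pattern the new check relies on. With return_exceptions=True, a task that raises returns its exception as that task's result instead of propagating, so failures and empty listings can both be mapped back to the URI that produced them. fake_lister and the example prefixes are hypothetical stand-ins for the real list functions and source URIs.

# Hedged sketch only; fake_lister and the prefixes below are made up.
import asyncio
from typing import List


async def fake_lister(prefix: str) -> List[str]:
    # Pretend to list objects under a prefix: raise for one prefix,
    # return nothing for another, and return a single object otherwise.
    if prefix.startswith("bad"):
        raise RuntimeError(f"cannot list {prefix}")
    return [] if prefix.startswith("empty") else [f"{prefix}part_0.tif"]


async def main() -> None:
    prefixes = ["ok/", "empty/", "bad/"]
    tasks = [asyncio.create_task(fake_lister(p)) for p in prefixes]
    # Results come back in the same order as the tasks; exceptions are
    # returned in place rather than raised.
    results = await asyncio.gather(*tasks, return_exceptions=True)

    invalid: List[str] = []
    for prefix, result in zip(prefixes, results):
        if isinstance(result, Exception):
            invalid.append(prefix)  # task raised; treat the source as invalid
        elif not result:
            invalid.append(prefix)  # listing succeeded but found no objects

    print(invalid)  # prints: ['empty/', 'bad/']


asyncio.run(main())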