Skip to content

Commit

Permalink
Overhaul verify_source_file_access to finally check src_uris concurrently
Browse files Browse the repository at this point in the history
  • Loading branch information
dmannarino committed Jan 22, 2024
1 parent eb43778 commit 367b548
Showing 1 changed file with 34 additions and 15 deletions.
49 changes: 34 additions & 15 deletions app/routes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Tuple, List, Sequence
from asyncio import Task, create_task, gather
from typing import List, Sequence, Tuple
from urllib.parse import urlparse

from fastapi import Depends, HTTPException, Path
from fastapi.logger import logger
from fastapi.security import OAuth2PasswordBearer

from ..crud.versions import get_version
Expand Down Expand Up @@ -66,22 +68,28 @@ async def dataset_version_dependency(


async def verify_source_file_access(sources: List[str]) -> None:

"""For each source URI, verify that it points to an existing object
or a bucket and prefix which contain one or more objects. Returns
nothing on success, but raises an HTTPException if one or more
sources are invalid"""
# TODO:
# 1. Making the list functions asynchronous and using asyncio.gather
# to check for valid sources in a non-blocking fashion would be good.
# Perhaps use the aioboto3 package for aws, gcloud-aio-storage for gcs.
# 2. It would be nice if the acceptable file extensions were passed
# 1. It would be nice if the acceptable file extensions were passed
# into this function so we could say, for example, that there must be
# TIFFs found for a new raster tile set, but a CSV is required for a new
# vector tile set version. Even better would be to specify whether
# paths to individual files or "folders" (prefixes) are allowed.

invalid_sources: List[str] = list()

tasks: List[Task] = list()

for source in sources:
url_parts = urlparse(source, allow_fragments=False)
list_func = source_uri_lister_constructor[url_parts.scheme.lower()]
try:
list_func = source_uri_lister_constructor[url_parts.scheme.lower()]
except KeyError:
invalid_sources.append(source)
continue
bucket = url_parts.netloc
prefix = url_parts.path.lstrip("/")

Expand All @@ -100,14 +108,25 @@ async def verify_source_file_access(sources: List[str]) -> None:
):
new_prefix += "/"

if not await list_func(
bucket,
new_prefix,
limit=10,
exit_after_max=1,
extensions=SUPPORTED_FILE_EXTENSIONS,
):
invalid_sources.append(source)
tasks.append(
create_task(
list_func(
bucket,
new_prefix,
limit=10,
exit_after_max=1,
extensions=SUPPORTED_FILE_EXTENSIONS,
)
)
)

results = await gather(*tasks, return_exceptions=True)
for uri, result in zip(sources, results):
if isinstance(result, Exception):
logger.error(f"Encountered exception checking src_uri {uri}: {result}")
invalid_sources.append(uri)
elif not result:
invalid_sources.append(uri)

if invalid_sources:
raise HTTPException(
Expand Down

0 comments on commit 367b548

Please sign in to comment.