async def is_gfwpro_admin(
    error_str: str, token: str = Depends(oauth2_scheme)
) -> bool:
    """Call the GFW API to authorize the user as an ADMIN of the gfw pro app.

    Args:
        error_str: Detail message of the 401 raised when authorization fails.
        token: Bearer token of the caller. The ``Depends(oauth2_scheme)``
            default is only swapped for a real token when FastAPI resolves
            this function as a dependency. A direct ``await`` of this
            function must pass the token explicitly, otherwise the call is
            rejected with a 401.

    Returns:
        True when the user is an ADMIN of the "gfw-pro" app.

    Raises:
        HTTPException: 401 with ``error_str`` as detail otherwise.
    """
    return await is_app_admin(token, "gfw-pro", error_str)


async def is_app_admin(token: str, app: str, error_str: str) -> bool:
    """Call the GFW API to authorize the user as an ADMIN of ``app``.

    Args:
        token: Bearer token identifying the user.
        app: Name of the app the user must be an ADMIN for (e.g. "gfw").
        error_str: Detail message of the 401 raised when authorization fails.

    Returns:
        True when the user has role "ADMIN" and ``app`` is listed in the
        user's ``extraUserData.apps``.

    Raises:
        HTTPException: 401 with ``error_str`` as detail in every other case,
            including a missing token or an unusable reply from the GFW API.
    """
    # A direct call that omits the token hands us the unresolved ``Depends``
    # marker instead of a string. Fail closed rather than forwarding that
    # object to the GFW API.
    if not isinstance(token, str):
        logger.warning("ADMIN privileges required. No bearer token was supplied.")
        raise HTTPException(status_code=401, detail=error_str)

    response: Response = await who_am_i(token)

    authorized = False
    if response.status_code != 401:
        # Parse the body once. A non-JSON body or missing keys must end in a
        # 401 for the caller, not in an unhandled KeyError/ValueError.
        try:
            user = response.json()
        except ValueError:
            user = None
        if isinstance(user, dict):
            extra = user.get("extraUserData")
            apps = extra.get("apps") if isinstance(extra, dict) else None
            authorized = user.get("role") == "ADMIN" and app in (apps or [])

    if not authorized:
        logger.warning(f"ADMIN privileges required. Unauthorized user: {response.text}")
        raise HTTPException(status_code=401, detail=error_str)
    return True
+PROTECTED_QUERY_DATASETS = ["wdpa_licensed_protected_areas"] + @router.get( "/{dataset}/{version}/query", response_class=RedirectResponse, @@ -155,6 +160,9 @@ async def query_dataset_json( """ dataset, version = dataset_version + if dataset in PROTECTED_QUERY_DATASETS: + await is_gfwpro_admin(error_str="Unauthorized query on a restricted dataset") + if geostore_id: geostore: Optional[GeostoreCommon] = await get_geostore( geostore_id, geostore_origin diff --git a/tests_v2/conftest.py b/tests_v2/conftest.py index facafd1fb..4340865d8 100755 --- a/tests_v2/conftest.py +++ b/tests_v2/conftest.py @@ -161,6 +161,25 @@ async def generic_vector_source_version( dataset_name, _ = generic_dataset version_name: str = "v1" + await create_vector_source_version(async_client, dataset_name, version_name, monkeypatch) + + # yield version + yield dataset_name, version_name, VERSION_METADATA + + # clean up + await async_client.delete(f"/dataset/{dataset_name}/{version_name}") + + +# Create a vector version, given the name of an existing dataset, plus a new version +# name. 
@pytest_asyncio.fixture
async def licensed_dataset(
    async_client: AsyncClient,
) -> AsyncGenerator[Tuple[str, Dict[str, Any]], None]:
    """Create the licensed dataset and delete it again on teardown.

    Yields:
        The dataset name and the metadata it was created with.
    """
    # Same name as the entry in PROTECTED_QUERY_DATASETS (app/routes/datasets/
    # queries.py), so queries against this dataset hit the admin check.
    dataset_name: str = "wdpa_licensed_protected_areas"

    await async_client.put(
        f"/dataset/{dataset_name}", json={"metadata": DATASET_METADATA}
    )

    # Yield dataset name and associated metadata
    yield dataset_name, DATASET_METADATA

    # Clean up
    await async_client.delete(f"/dataset/{dataset_name}")


@pytest_asyncio.fixture
async def licensed_version(
    async_client: AsyncClient,
    licensed_dataset: Tuple[str, Dict[str, Any]],
    monkeypatch: MonkeyPatch,
) -> AsyncGenerator[Tuple[str, str, Dict[str, Any]], None]:
    """Create version "v1" of the licensed dataset and delete it on teardown.

    Yields:
        The dataset name, the version name and VERSION_METADATA.
    """
    dataset_name, _ = licensed_dataset
    version_name: str = "v1"

    await create_vector_source_version(
        async_client, dataset_name, version_name, monkeypatch
    )

    # yield version
    yield dataset_name, version_name, VERSION_METADATA

    # clean up
    await async_client.delete(f"/dataset/{dataset_name}/{version_name}")
@pytest.mark.asyncio()
async def test_query_licensed_disallowed_11(
    licensed_version, async_client: AsyncClient
):
    """A query against the licensed dataset version is answered with a 401."""
    dataset, version, _metadata = licensed_version
    query_url = f"/dataset/{dataset}/{version}/query?sql=select(*) from mytable;"
    expected_message = "Unauthorized query on a restricted dataset"

    resp = await async_client.get(query_url, follow_redirects=True)

    # Status first, then the error text carried in the response body.
    assert resp.status_code == 401
    assert resp.json()["message"] == expected_message