Skip to content

Commit

Permalink
GTC-2576 Restrict queries on licensed WDPA dataset
Browse files Browse the repository at this point in the history
Only allow queries to versions of the licensed WDPA dataset for admin
users (i.e. a bearer token is provided that signifies an admin user).

For now, we just have a simple list of restricted datasets in the code,
rather than adding some new dataset attribute.
  • Loading branch information
danscales committed Oct 9, 2023
1 parent 84691de commit bd497d5
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 7 deletions.
21 changes: 21 additions & 0 deletions app/authentication/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ async def is_service_account(token: str = Depends(oauth2_scheme)) -> bool:
return True


# Check if the authorized user is an admin. Return True if so, raise
# an exception if not.
async def is_admin(token: str = Depends(oauth2_scheme)) -> bool:
"""Calls GFW API to authorize user.
Expand All @@ -45,6 +47,25 @@ async def is_admin(token: str = Depends(oauth2_scheme)) -> bool:
else:
return True

# Check if the authorized user is an admin. Return True if so, False if not (without
# raising an exception).
async def is_admin_no_exception(token: str = Depends(oauth2_scheme)) -> bool:
    """Report whether ``token`` belongs to an ADMIN user of the gfw app.

    Calls the GFW API through ``who_am_i`` and inspects the reply. The user
    counts as an admin only when the reply is not a 401, its ``role`` is
    ``"ADMIN"`` and ``"gfw"`` is listed in ``extraUserData.apps``.

    A failed check is reported by returning ``False`` (and logging a
    warning), never by raising. A reply whose body is not JSON or lacks the
    expected keys is treated as "not an admin" as well.

    Note: the ``Depends(oauth2_scheme)`` default is only resolved by FastAPI
    when this function is declared as a dependency. Code that awaits this
    function directly must pass the bearer token explicitly.

    Args:
        token: Bearer token identifying the user.

    Returns:
        ``True`` if the user is an ADMIN for the gfw app, ``False`` otherwise.
    """

    response: Response = await who_am_i(token)

    authorized = False
    if response.status_code != 401:
        try:
            # Parse the body once; both checks read from the same payload.
            payload = response.json()
            authorized = (
                payload["role"] == "ADMIN"
                and "gfw" in payload["extraUserData"]["apps"]
            )
        except (ValueError, KeyError, TypeError):
            # Non-JSON body or unexpected payload shape (e.g. an error reply
            # with a status other than 401): deny instead of raising.
            authorized = False

    if not authorized:
        logger.warning(f"ADMIN privileges required. Unauthorized user: {response.text}")

    return authorized


async def get_user(token: str = Depends(oauth2_scheme)) -> Tuple[str, str]:
"""Calls GFW API to authorize user.
Expand Down
10 changes: 10 additions & 0 deletions app/routes/datasets/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pydantic.tools import parse_obj_as
from sqlalchemy.sql import and_

from ...authentication.token import is_admin_no_exception
from ...application import db

# from ...authentication.api_keys import get_api_key
Expand Down Expand Up @@ -85,6 +86,10 @@
# Special suffixes to do an extra area density calculation on the raster data set.
AREA_DENSITY_RASTER_SUFFIXES = ["_ha-1", "_ha_yr-1"]

# Names of datasets that may only be queried with admin privileges. This is
# extra protection for commercial datasets, which shouldn't be downloaded in
# any way.
PROTECTED_QUERY_DATASETS = [
    "licensed_wdpa_protected_areas",
]

@router.get(
"/{dataset}/{version}/query",
response_class=RedirectResponse,
Expand Down Expand Up @@ -155,6 +160,11 @@ async def query_dataset_json(
"""

dataset, version = dataset_version
if dataset in PROTECTED_QUERY_DATASETS:
is_authorized = await is_admin_no_exception()
if not is_authorized:
raise HTTPException(status_code=401, detail="Unauthorized")

if geostore_id:
geostore: Optional[GeostoreCommon] = await get_geostore(
geostore_id, geostore_origin
Expand Down
63 changes: 56 additions & 7 deletions tests_v2/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,25 @@ async def generic_vector_source_version(
dataset_name, _ = generic_dataset
version_name: str = "v1"

await create_vector_source_version(async_client, dataset_name, version_name, monkeypatch)

# yield version
yield dataset_name, version_name, VERSION_METADATA

# clean up
await async_client.delete(f"/dataset/{dataset_name}/{version_name}")


# Create a vector version, given the name of an existing dataset, plus a new version
# name.
async def create_vector_source_version(
async_client: AsyncClient,
dataset_name: str,
version_name: str,
monkeypatch: MonkeyPatch,
):
"""Create generic vector source version."""

# patch all functions which reach out to external services
batch_job_mock = BatchJobMock()
monkeypatch.setattr(versions, "_verify_source_file_access", void_coroutine)
Expand Down Expand Up @@ -209,13 +228,6 @@ async def generic_vector_source_version(
response = await async_client.get(f"/dataset/{dataset_name}/{version_name}")
assert response.json()["data"]["status"] == "saved"

# yield version
yield dataset_name, version_name, VERSION_METADATA

# clean up
await async_client.delete(f"/dataset/{dataset_name}/{version_name}")


@pytest_asyncio.fixture
async def generic_raster_version(
async_client: AsyncClient,
Expand Down Expand Up @@ -281,6 +293,43 @@ async def generic_raster_version(
# clean up
await async_client.delete(f"/dataset/{dataset_name}/{version_name}")

@pytest_asyncio.fixture
async def licensed_dataset(
    async_client: AsyncClient,
) -> AsyncGenerator[Tuple[str, Dict[str, Any]], None]:
    """Provide a dataset named ``licensed_wdpa_protected_areas``.

    The dataset is created through the API before the test and deleted
    again afterwards. Yields the dataset name and the metadata it was
    created with.
    """

    dataset_name: str = "licensed_wdpa_protected_areas"
    dataset_url = f"/dataset/{dataset_name}"

    # Set up: register the dataset.
    await async_client.put(dataset_url, json={"metadata": DATASET_METADATA})

    yield dataset_name, DATASET_METADATA

    # Tear down: remove the dataset.
    await async_client.delete(dataset_url)

@pytest_asyncio.fixture
async def licensed_version(
    async_client: AsyncClient,
    licensed_dataset: Tuple[str, Dict[str, Any]],
    monkeypatch: MonkeyPatch,
) -> AsyncGenerator[Tuple[str, str, Dict[str, Any]], None]:
    """Create version ``v1`` of the licensed dataset and delete it afterwards.

    The version is built with ``create_vector_source_version`` on top of the
    dataset supplied by the ``licensed_dataset`` fixture.

    Yields:
        The dataset name, the version name and ``VERSION_METADATA``.
    """

    # The dataset fixture yields (name, metadata); only the name is needed.
    dataset_name, _ = licensed_dataset
    version_name: str = "v1"

    await create_vector_source_version(
        async_client, dataset_name, version_name, monkeypatch
    )

    # yield version
    yield dataset_name, version_name, VERSION_METADATA

    # clean up
    await async_client.delete(f"/dataset/{dataset_name}/{version_name}")

@pytest_asyncio.fixture
async def apikey(
Expand Down
14 changes: 14 additions & 0 deletions tests_v2/unit/app/routes/datasets/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,20 @@ async def test_query_vector_asset_disallowed_10(
"You might need to add explicit type casts."
)

@pytest.mark.asyncio
async def test_query_licensed_disallowed_11(
    licensed_version, async_client: AsyncClient
):
    """A plain query against the licensed dataset is rejected with a 401."""
    dataset, version, _ = licensed_version
    query_url = f"/dataset/{dataset}/{version}/query?sql=select(*) from mytable;"

    response = await async_client.get(query_url, follow_redirects=True)

    assert response.status_code == 401
    assert response.json()["message"] == "Unauthorized"

@pytest.mark.asyncio
@pytest.mark.skip("Temporarily skip while _get_data_environment is being cached")
Expand Down

0 comments on commit bd497d5

Please sign in to comment.