Merge pull request #446 from wri/develop
Merge-up code to redact environment when logging Batch payloads
dmannarino authored Dec 5, 2023
2 parents dc0447c + 8b869f6 commit f7480cf
Showing 4 changed files with 62 additions and 11 deletions.
app/tasks/batch.py (4 changes: 3 additions & 1 deletion)
@@ -129,7 +129,7 @@ def submit_batch_job(
"command": job.command,
"vcpus": job.vcpus,
"memory": job.memory,
"environment": job.environment,
"environment": "<redacted>",
},
"retryStrategy": {
"attempts": job.attempts,
@@ -152,6 +152,8 @@ def submit_batch_job(

logger.info(f"Submitting batch job with payload: {payload}")

payload["containerOverrides"]["environment"] = job.environment

response = client.submit_job(**payload)

return UUID(response["jobId"])
tests_v2/conftest.py (20 changes: 18 additions & 2 deletions)
@@ -2,6 +2,7 @@
import os
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, Tuple
+from uuid import UUID

import pytest
import pytest_asyncio
@@ -161,7 +162,9 @@ async def generic_vector_source_version(
dataset_name, _ = generic_dataset
version_name: str = "v1"

-await create_vector_source_version(async_client, dataset_name, version_name, monkeypatch)
+await create_vector_source_version(
+    async_client, dataset_name, version_name, monkeypatch
+)

# yield version
yield dataset_name, version_name, VERSION_METADATA
@@ -228,6 +231,7 @@ async def create_vector_source_version(
response = await async_client.get(f"/dataset/{dataset_name}/{version_name}")
assert response.json()["data"]["status"] == "saved"


@pytest_asyncio.fixture
async def generic_raster_version(
async_client: AsyncClient,
@@ -293,6 +297,7 @@ async def generic_raster_version(
# clean up
await async_client.delete(f"/dataset/{dataset_name}/{version_name}")


@pytest_asyncio.fixture
async def licensed_dataset(
async_client: AsyncClient,
@@ -312,6 +317,7 @@ async def licensed_dataset(
# Clean up
await async_client.delete(f"/dataset/{dataset_name}")


@pytest_asyncio.fixture
async def licensed_version(
async_client: AsyncClient,
@@ -323,14 +329,17 @@ async def licensed_version(
dataset_name, _ = licensed_dataset
version_name: str = "v1"

-await create_vector_source_version(async_client, dataset_name, version_name, monkeypatch)
+await create_vector_source_version(
+    async_client, dataset_name, version_name, monkeypatch
+)

# yield version
yield dataset_name, version_name, VERSION_METADATA

# clean up
await async_client.delete(f"/dataset/{dataset_name}/{version_name}")


@pytest_asyncio.fixture
async def apikey(
async_client: AsyncClient, monkeypatch: MonkeyPatch
@@ -451,3 +460,10 @@ async def _create_geostore(geojson: Dict[str, Any], async_client: AsyncClient) -
assert response.status_code == 201

return response.json()["data"]["gfw_geostore_id"]


+async def mock_callback(task_id: UUID, change_log: ChangeLog):
+    async def dummy_function():
+        pass
+
+    return dummy_function
tests_v2/unit/app/tasks/test_batch.py (40 changes: 40 additions & 0 deletions)
@@ -0,0 +1,40 @@
from typing import Dict, List
from unittest.mock import MagicMock, patch

from fastapi.logger import logger

from app.tasks.batch import submit_batch_job
from app.tasks.vector_source_assets import _create_add_gfw_fields_job
from tests_v2.conftest import mock_callback

TEST_JOB_ENV: List[Dict[str, str]] = [{"name": "PASSWORD", "value": "DON'T LOG ME"}]


@patch("app.utils.aws.boto3.client")
@patch.object(logger, "info") # Patch the logger.info directly
@patch("app.tasks.batch.UUID") # Patch the UUID class
async def test_submit_batch_job(mock_uuid, mock_logging_info, mock_boto3_client):
    mock_client = MagicMock()
    mock_boto3_client.return_value = mock_client

    attempt_duration_seconds: int = 100

    job = await _create_add_gfw_fields_job(
        "some_dataset",
        "v1",
        parents=list(),
        job_env=TEST_JOB_ENV,
        callback=mock_callback,
        attempt_duration_seconds=attempt_duration_seconds,
    )

    # Call the function you want to test
    submit_batch_job(job)

    mock_boto3_client.assert_called_once_with(
        "batch", region_name="us-east-1", endpoint_url=None
    )

    # Assert that the logger.info was called with the expected log message
    assert "add_gfw_fields" in mock_logging_info.call_args.args[0]
    assert "DON'T LOG ME" not in mock_logging_info.call_args.args[0]
tests_v2/unit/app/tasks/test_vector_source_assets.py (9 changes: 1 addition & 8 deletions)
@@ -16,6 +16,7 @@
append_vector_source_asset,
vector_source_asset,
)
+from tests_v2.conftest import mock_callback

MODULE_PATH_UNDER_TEST = "app.tasks.vector_source_assets"

@@ -40,14 +41,6 @@
VECTOR_ASSET_UUID = UUID("1b368160-caf8-2bd7-819a-ad4949361f02")


-async def dummy_function():
-    pass
-
-
-async def mock_callback(task_id: UUID, change_log: ChangeLog):
-    return dummy_function


class TestVectorSourceAssetsHelpers:
@pytest.mark.asyncio
async def test__create_vector_schema_job_no_schema(self):
