diff --git a/app/tasks/batch.py b/app/tasks/batch.py index 9edb9efdd..0a9f09c08 100644 --- a/app/tasks/batch.py +++ b/app/tasks/batch.py @@ -129,7 +129,7 @@ def submit_batch_job( "command": job.command, "vcpus": job.vcpus, "memory": job.memory, - "environment": job.environment, + "environment": "", }, "retryStrategy": { "attempts": job.attempts, @@ -152,6 +152,8 @@ def submit_batch_job( logger.info(f"Submitting batch job with payload: {payload}") + payload["containerOverrides"]["environment"] = job.environment + response = client.submit_job(**payload) return UUID(response["jobId"]) diff --git a/tests_v2/conftest.py b/tests_v2/conftest.py index 4340865d8..33fbb3a53 100755 --- a/tests_v2/conftest.py +++ b/tests_v2/conftest.py @@ -2,6 +2,7 @@ import os from datetime import datetime from typing import Any, AsyncGenerator, Dict, Tuple +from uuid import UUID import pytest import pytest_asyncio @@ -161,7 +162,9 @@ async def generic_vector_source_version( dataset_name, _ = generic_dataset version_name: str = "v1" - await create_vector_source_version(async_client, dataset_name, version_name, monkeypatch) + await create_vector_source_version( + async_client, dataset_name, version_name, monkeypatch + ) # yield version yield dataset_name, version_name, VERSION_METADATA @@ -228,6 +231,7 @@ async def create_vector_source_version( response = await async_client.get(f"/dataset/{dataset_name}/{version_name}") assert response.json()["data"]["status"] == "saved" + @pytest_asyncio.fixture async def generic_raster_version( async_client: AsyncClient, @@ -293,6 +297,7 @@ async def generic_raster_version( # clean up await async_client.delete(f"/dataset/{dataset_name}/{version_name}") + @pytest_asyncio.fixture async def licensed_dataset( async_client: AsyncClient, @@ -312,6 +317,7 @@ async def licensed_dataset( # Clean up await async_client.delete(f"/dataset/{dataset_name}") + @pytest_asyncio.fixture async def licensed_version( async_client: AsyncClient, @@ -323,7 +329,9 @@ async def licensed_version( dataset_name, _ = licensed_dataset version_name: str = "v1" - await create_vector_source_version(async_client, dataset_name, version_name, monkeypatch) + await create_vector_source_version( + async_client, dataset_name, version_name, monkeypatch + ) # yield version yield dataset_name, version_name, VERSION_METADATA @@ -331,6 +339,7 @@ async def licensed_version( # clean up await async_client.delete(f"/dataset/{dataset_name}/{version_name}") + @pytest_asyncio.fixture async def apikey( async_client: AsyncClient, monkeypatch: MonkeyPatch @@ -451,3 +460,10 @@ async def _create_geostore(geojson: Dict[str, Any], async_client: AsyncClient) - assert response.status_code == 201 return response.json()["data"]["gfw_geostore_id"] + + +async def mock_callback(task_id: UUID, change_log: ChangeLog): + async def dummy_function(): + pass + + return dummy_function diff --git a/tests_v2/unit/app/tasks/test_batch.py b/tests_v2/unit/app/tasks/test_batch.py new file mode 100644 index 000000000..c09ad8b32 --- /dev/null +++ b/tests_v2/unit/app/tasks/test_batch.py @@ -0,0 +1,40 @@ +from typing import Dict, List +from unittest.mock import MagicMock, patch + +from fastapi.logger import logger + +from app.tasks.batch import submit_batch_job +from app.tasks.vector_source_assets import _create_add_gfw_fields_job +from tests_v2.conftest import mock_callback + +TEST_JOB_ENV: List[Dict[str, str]] = [{"name": "PASSWORD", "value": "DON'T LOG ME"}] + + +@patch("app.utils.aws.boto3.client") +@patch.object(logger, "info") # Patch the logger.info directly +@patch("app.tasks.batch.UUID") # Patch the UUID class +async def test_submit_batch_job(mock_uuid, mock_logging_info, mock_boto3_client): + mock_client = MagicMock() + mock_boto3_client.return_value = mock_client + + attempt_duration_seconds: int = 100 + + job = await _create_add_gfw_fields_job( + "some_dataset", + "v1", + parents=list(), + job_env=TEST_JOB_ENV, + callback=mock_callback, + attempt_duration_seconds=attempt_duration_seconds, + ) + + # Call the function you want to test + submit_batch_job(job) + + mock_boto3_client.assert_called_once_with( + "batch", region_name="us-east-1", endpoint_url=None + ) + + # Assert that the logger.info was called with the expected log message + assert "add_gfw_fields" in mock_logging_info.call_args.args[0] + assert "DON'T LOG ME" not in mock_logging_info.call_args.args[0] diff --git a/tests_v2/unit/app/tasks/test_vector_source_assets.py b/tests_v2/unit/app/tasks/test_vector_source_assets.py index 0c42ab744..6b6da42a5 100644 --- a/tests_v2/unit/app/tasks/test_vector_source_assets.py +++ b/tests_v2/unit/app/tasks/test_vector_source_assets.py @@ -16,6 +16,7 @@ append_vector_source_asset, vector_source_asset, ) +from tests_v2.conftest import mock_callback MODULE_PATH_UNDER_TEST = "app.tasks.vector_source_assets" @@ -40,14 +41,6 @@ VECTOR_ASSET_UUID = UUID("1b368160-caf8-2bd7-819a-ad4949361f02") -async def dummy_function(): - pass - - -async def mock_callback(task_id: UUID, change_log: ChangeLog): - return dummy_function - - class TestVectorSourceAssetsHelpers: @pytest.mark.asyncio async def test__create_vector_schema_job_no_schema(self):