diff --git a/data_management/opensearch_indexer/opensearch_indexer/index_consignment/bulk_index_consignment.py b/data_management/opensearch_indexer/opensearch_indexer/index_consignment/bulk_index_consignment.py
index 0e5655e9..82d08b9c 100644
--- a/data_management/opensearch_indexer/opensearch_indexer/index_consignment/bulk_index_consignment.py
+++ b/data_management/opensearch_indexer/opensearch_indexer/index_consignment/bulk_index_consignment.py
@@ -14,12 +14,16 @@
     get_s3_file,
     get_secret_data,
 )
-from ..text_extraction import add_text_content
+from ..text_extraction import TextExtractionStatus, add_text_content
 
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
 
 
+class ConsignmentBulkIndexError(Exception):
+    pass
+
+
 def bulk_index_consignment_from_aws(
     consignment_reference: str, secret_id: str
 ) -> None:
@@ -79,13 +83,51 @@ def bulk_index_consignment(
     """
     files = _fetch_files_in_consignment(consignment_reference, database_url)
     documents_to_index = _construct_documents(files, bucket_name)
-    bulk_index_files_in_opensearch(
-        documents_to_index,
-        open_search_host_url,
-        open_search_http_auth,
-        open_search_bulk_index_timeout,
-        open_search_ca_certs,
-    )
+
+    document_text_extraction_exceptions_message = ""
+    for doc in documents_to_index:
+        if doc["document"]["text_extraction_status"] not in [
+            TextExtractionStatus.SKIPPED.value,
+            TextExtractionStatus.SUCCEEDED.value,
+        ]:
+            if document_text_extraction_exceptions_message == "":
+                document_text_extraction_exceptions_message += (
+                    "Text extraction failed on the following documents:"
+                )
+            document_text_extraction_exceptions_message += f"\n{doc['file_id']}"
+
+    bulk_indexing_exception_message = ""
+    try:
+        bulk_index_files_in_opensearch(
+            documents_to_index,
+            open_search_host_url,
+            open_search_http_auth,
+            open_search_bulk_index_timeout,
+            open_search_ca_certs,
+        )
+    except Exception as bulk_indexing_exception:
+        bulk_indexing_exception_message = bulk_indexing_exception.text
+        logger.error("Bulk indexing of files resulted in some errors")
+
+    # Combine and raise all errors from failed attempts to extract text or index documents
+    if (
+        document_text_extraction_exceptions_message
+        or bulk_indexing_exception_message
+    ):
+        consignment_bulk_index_error_message = (
+            "The following errors occurred when attempting to "
+            f"bulk index consignment reference: {consignment_reference}"
+        )
+        if document_text_extraction_exceptions_message:
+            consignment_bulk_index_error_message += (
+                f"\n{document_text_extraction_exceptions_message}"
+            )
+        if bulk_indexing_exception_message:
+            consignment_bulk_index_error_message += (
+                f"\n{bulk_indexing_exception_message}"
+            )
+
+        raise ConsignmentBulkIndexError(consignment_bulk_index_error_message)
 
 
 def _construct_documents(files: List[Dict], bucket_name: str) -> List[Dict]:
diff --git a/data_management/opensearch_indexer/opensearch_indexer/text_extraction.py b/data_management/opensearch_indexer/opensearch_indexer/text_extraction.py
index b4d0734b..adfd22df 100644
--- a/data_management/opensearch_indexer/opensearch_indexer/text_extraction.py
+++ b/data_management/opensearch_indexer/opensearch_indexer/text_extraction.py
@@ -1,5 +1,6 @@
 import logging
 import tempfile
+from enum import Enum
 from typing import Dict
 
 import textract
@@ -7,6 +8,13 @@
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
 
+
+class TextExtractionStatus(Enum):
+    SUCCEEDED = "SUCCEEDED"
+    FAILED = "FAILED"
+    SKIPPED = "SKIPPED"
+
+
 SUPPORTED_TEXTRACT_FORMATS = [
     "csv",
     "doc",
@@ -45,18 +53,20 @@ def add_text_content(file: Dict, file_stream: bytes) -> Dict:
             f"Text extraction skipped for unsupported file type: {file_type}"
         )
         file["content"] = ""
-        file["text_extraction_status"] = "n/a"
+        file["text_extraction_status"] = TextExtractionStatus.SKIPPED.value
     else:
        try:
             file["content"] = extract_text(file_stream, file_type)
             logger.info(f"Text extraction succeeded for file {file['file_id']}")
-            file["text_extraction_status"] = "success"
+            file["text_extraction_status"] = (
+                TextExtractionStatus.SUCCEEDED.value
+            )
         except Exception as e:
             logger.error(
                 f"Text extraction failed for file {file['file_id']}: {e}"
             )
             file["content"] = ""
-            file["text_extraction_status"] = "failed"
+            file["text_extraction_status"] = TextExtractionStatus.FAILED.value
 
     return file
 
diff --git a/data_management/opensearch_indexer/requirements.txt b/data_management/opensearch_indexer/requirements.txt
index 8adac967..be748965 100644
--- a/data_management/opensearch_indexer/requirements.txt
+++ b/data_management/opensearch_indexer/requirements.txt
@@ -4,3 +4,5 @@ requests-aws4auth==1.3.1
 SQLAlchemy==2.0.32
 pg8000==1.31.2
 textract==1.6.5
+testing-postgresql==1.3.0
+psycopg2==2.9.10
diff --git a/data_management/opensearch_indexer/tests/conftest.py b/data_management/opensearch_indexer/tests/conftest.py
new file mode 100644
index 00000000..e482a2de
--- /dev/null
+++ b/data_management/opensearch_indexer/tests/conftest.py
@@ -0,0 +1,98 @@
+import tempfile
+
+import pytest
+from sqlalchemy import (
+    Boolean,
+    Column,
+    DateTime,
+    ForeignKey,
+    String,
+    Text,
+    create_engine,
+)
+from sqlalchemy.dialects.postgresql import UUID
+from sqlalchemy.orm import declarative_base, relationship
+from testing.postgresql import PostgresqlFactory
+
+Base = declarative_base()
+
+
+class Body(Base):
+    __tablename__ = "Body"
+    BodyId = Column(UUID(as_uuid=True), primary_key=True)
+    Name = Column(Text)
+    Description = Column(Text)
+
+
+class Series(Base):
+    __tablename__ = "Series"
+    SeriesId = Column(UUID(as_uuid=True), primary_key=True)
+    BodyId = Column(UUID(as_uuid=True), ForeignKey("Body.BodyId"))
+    Name = Column(Text)
+    Description = Column(Text)
+    body = relationship("Body", foreign_keys="Series.BodyId")
+
+
+class Consignment(Base):
+    __tablename__ = "Consignment"
+    ConsignmentId = Column(UUID(as_uuid=True), primary_key=True)
+    SeriesId = Column(UUID(as_uuid=True), ForeignKey("Series.SeriesId"))
+    BodyId = Column(UUID(as_uuid=True), ForeignKey("Body.BodyId"))
+    ConsignmentReference = Column(Text)
+    ConsignmentType = Column(String, nullable=False)
+    IncludeTopLevelFolder = Column(Boolean)
+    ContactName = Column(Text)
+    ContactEmail = Column(Text)
+    TransferStartDatetime = Column(DateTime)
+    TransferCompleteDatetime = Column(DateTime)
+    ExportDatetime = Column(DateTime)
+    CreatedDatetime = Column(DateTime)
+    series = relationship("Series", foreign_keys="Consignment.SeriesId")
+
+
+class File(Base):
+    __tablename__ = "File"
+    FileId = Column(UUID(as_uuid=True), primary_key=True)
+    ConsignmentId = Column(
+        UUID(as_uuid=True), ForeignKey("Consignment.ConsignmentId")
+    )
+    FileReference = Column(Text, nullable=False)
+    FileType = Column(Text, nullable=False)
+    FileName = Column(Text, nullable=False)
+    FilePath = Column(Text, nullable=False)
+    CiteableReference = Column(Text)
+    Checksum = Column(Text)
+    CreatedDatetime = Column(DateTime)
+    consignment = relationship("Consignment", foreign_keys="File.ConsignmentId")
+
+
+class FileMetadata(Base):
+    __tablename__ = "FileMetadata"
+    MetadataId = Column(UUID(as_uuid=True), primary_key=True)
+    FileId = Column(UUID(as_uuid=True), ForeignKey("File.FileId"))
+    PropertyName = Column(Text, nullable=False)
+    Value = Column(Text)
+    CreatedDatetime = Column(DateTime)
+    file = relationship("File", foreign_keys="FileMetadata.FileId")
+
+
+@pytest.fixture()
+def temp_db():
+    temp_db_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
+    temp_db_file.close()
+    database_url = f"sqlite:///{temp_db_file.name}"
+    engine = create_engine(database_url)
+    Base.metadata.create_all(engine)
+    return engine
+
+
+@pytest.fixture(scope="session")
+def database(request):
+    # Launch new PostgreSQL server
+    postgresql = PostgresqlFactory(cache_initialized_db=True)()
+    yield postgresql
+
+    # PostgreSQL server is terminated here
+    @request.addfinalizer
+    def drop_database():
+        postgresql.stop()
diff --git a/data_management/opensearch_indexer/tests/test_consignment_lambda_handler.py b/data_management/opensearch_indexer/tests/test_consignment_lambda_handler.py
new file mode 100644
index 00000000..ad6f1c39
--- /dev/null
+++ b/data_management/opensearch_indexer/tests/test_consignment_lambda_handler.py
@@ -0,0 +1,318 @@
+import json
+from unittest.mock import patch
+from uuid import uuid4
+
+import boto3
+import botocore
+from moto import mock_aws
+from opensearch_indexer.index_consignment.lambda_function import lambda_handler
+from opensearch_indexer.text_extraction import TextExtractionStatus
+from requests_aws4auth import AWS4Auth
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
+from .conftest import Body, Consignment, File, FileMetadata, Series
+
+# Original botocore _make_api_call function
+orig = botocore.client.BaseClient._make_api_call
+
+
+# Mocked botocore _make_api_call function
+def mock_make_api_call(self, operation_name, kwarg):
+    if operation_name == "AssumeRole":
+        return {
+            "Credentials": {
+                "AccessKeyId": "test_access_key",
+                "SecretAccessKey": "test_secret_key",  # pragma: allowlist secret
+                "SessionToken": "test_token",
+                "Expiration": "2024-09-18T12:00:00Z",
+            }
+        }
+    return orig(self, operation_name, kwarg)
+
+
+@mock_aws
+def test_lambda_handler_invokes_bulk_index_with_correct_file_data(
+    monkeypatch, database
+):
+    """
+    Test case for the lambda_handler function to ensure correct integration with the OpenSearch indexer.
+
+    Given:
+    - An S3 bucket containing files.
+    - A secret stored in AWS Secrets Manager containing configuration details such as database connection,
+      OpenSearch host URL, and an IAM role for OpenSearch access.
+
+    When:
+    - The lambda_handler function is invoked via an S3 event notification.
+
+    Then:
+    - The bulk_index_files_in_opensearch function is called with the correct parameters for each file:
+        - Correct file metadata and content, including the extracted text, metadata properties,
+          and associated consignment details.
+        - The OpenSearch host URL.
+        - An AWS4Auth object with credentials derived from the assumed IAM role.
+        - The timeout for the OpenSearch bulk indexing operation.
+    """
+    # Set up the database engine and session using the URL
+    engine = create_engine(database.url())
+    from data_management.opensearch_indexer.tests.conftest import Base
+
+    Base.metadata.create_all(engine)  # Create tables for the test
+
+    # Create a session and set up test data
+    Session = sessionmaker(bind=engine)
+    session = Session()
+
+    secret_name = "test_vars"  # pragma: allowlist secret
+
+    monkeypatch.setenv("SECRET_ID", secret_name)
+
+    bucket_name = "test_bucket"
+
+    opensearch_master_role_arn = (
+        "arn:aws:iam::123456789012:role/test-opensearch-role"
+    )
+    secret_string = json.dumps(
+        {
+            "DB_USER": "postgres",
+            "DB_PASSWORD": "",
+            "DB_HOST": "127.0.0.1",
+            "DB_PORT": database.settings["port"],
+            "DB_NAME": "test",
+            "AWS_REGION": "eu-west-2",
+            "RECORD_BUCKET_NAME": bucket_name,
+            "OPEN_SEARCH_HOST": "https://test-opensearch.com",
+            "OPEN_SEARCH_MASTER_ROLE_ARN": opensearch_master_role_arn,
+            "OPEN_SEARCH_BULK_INDEX_TIMEOUT": 600,
+        }
+    )
+
+    secretsmanager_client = boto3.client(
+        "secretsmanager", region_name="eu-west-2"
+    )
+
+    secretsmanager_client.create_secret(
+        Name=secret_name, SecretString=secret_string
+    )
+
+    s3_client = boto3.client("s3", region_name="us-east-1")
+
+    body_id = uuid4()
+    series_id = uuid4()
+    consignment_id = uuid4()
+
+    consignment_reference = "TDR-2024-ABCD"
+
+    file_1_id = uuid4()
+    file_2_id = uuid4()
+    file_3_id = uuid4()
+
+    session.add_all(
+        [
+            File(
+                FileId=file_1_id,
+                FileType="File",
+                FileName="test-document.txt",
+                FileReference="file-123",
+                FilePath="/path/to/file",
+                CiteableReference="cite-ref-123",
+                ConsignmentId=consignment_id,
+            ),
+            File(
+                FileId=file_2_id,
+                FileType="File",
+                FileName="test-document.txt",
+                FileReference="file-123",
+                FilePath="/path/to/file",
+                CiteableReference="cite-ref-123",
+                ConsignmentId=consignment_id,
+            ),
+            File(
+                FileId=file_3_id,
+                FileType="File",
+                FileName="test-document.txt",
+                FileReference="file-123",
+                FilePath="/path/to/file",
+                CiteableReference="cite-ref-123",
+                ConsignmentId=consignment_id,
+            ),
+            Consignment(
+                ConsignmentId=consignment_id,
+                ConsignmentType="foo",
+                ConsignmentReference=consignment_reference,
+                SeriesId=series_id,
+            ),
+            Series(SeriesId=series_id, Name="series-name", BodyId=body_id),
+            Body(
+                BodyId=body_id,
+                Name="body-name",
+                Description="transferring body description",
+            ),
+            FileMetadata(
+                MetadataId=uuid4(),
+                FileId=file_1_id,
+                PropertyName="Key1",
+                Value="Value1",
+            ),
+            FileMetadata(
+                MetadataId=uuid4(),
+                FileId=file_1_id,
+                PropertyName="Key2",
+                Value="Value2",
+            ),
+            FileMetadata(
+                MetadataId=uuid4(),
+                FileId=file_2_id,
+                PropertyName="Key3",
+                Value="Value3",
+            ),
+            FileMetadata(
+                MetadataId=uuid4(),
+                FileId=file_2_id,
+                PropertyName="Key4",
+                Value="Value4",
+            ),
+            FileMetadata(
+                MetadataId=uuid4(),
+                FileId=file_3_id,
+                PropertyName="Key5",
+                Value="Value5",
+            ),
+            FileMetadata(
+                MetadataId=uuid4(),
+                FileId=file_3_id,
+                PropertyName="Key6",
+                Value="Value6",
+            ),
+        ]
+    )
+    session.commit()
+
+    object_key_1 = f"{consignment_reference}/{file_1_id}"
+    object_key_2 = f"{consignment_reference}/{file_2_id}"
+    object_key_3 = f"{consignment_reference}/{file_3_id}"
+
+    s3_client.create_bucket(Bucket=bucket_name)
+
+    s3_client.put_object(
+        Bucket=bucket_name, Key=object_key_1, Body=b"Test file content"
+    )
+    s3_client.put_object(Bucket=bucket_name, Key=object_key_2, Body=b"")
+    s3_client.put_object(
+        Bucket=bucket_name,
+        Key=object_key_3,
+        Body=b"File content but in file we do not support text extraction for",
+    )
+
+    sns_message = {
+        "properties": {
+            "messageType": "uk.gov.nationalarchives.da.messages.ayrmetadata.loaded",
+            "function": "ddt-ayrmetadataload-process",
+        },
+        "parameters": {
+            "reference": consignment_reference,
+            "originator": "DDT",
+        },
+    }
+
+    event = {
+        "Records": [
+            {
+                "Sns": {
+                    "Message": json.dumps(sns_message),
+                },
+            }
+        ]
+    }
+
+    with patch(
+        "botocore.client.BaseClient._make_api_call", new=mock_make_api_call
+    ):
+        with patch(
+            "opensearch_indexer.index_consignment.bulk_index_consignment.bulk_index_files_in_opensearch"
+        ) as mock_bulk_index_files_in_opensearch:
+            lambda_handler(event, None)
+
+    args, _ = mock_bulk_index_files_in_opensearch.call_args
+
+    assert args[0] == [
+        {
+            "file_id": str(file_1_id),
+            "document": {
+                "file_id": str(file_1_id),
+                "file_name": "test-document.txt",
+                "file_reference": "file-123",
+                "file_path": "/path/to/file",
+                "citeable_reference": "cite-ref-123",
+                "series_id": str(series_id),
+                "series_name": "series-name",
+                "transferring_body": "body-name",
+                "transferring_body_id": str(body_id),
+                "transferring_body_description": "transferring body description",
+                "consignment_id": str(consignment_id),
+                "consignment_reference": "TDR-2024-ABCD",
+                "Key1": "Value1",
+                "Key2": "Value2",
+                "content": "Test file content",
+                "text_extraction_status": TextExtractionStatus.SUCCEEDED.value,
+            },
+        },
+        {
+            "document": {
+                "file_id": str(file_2_id),
+                "file_name": "test-document.txt",
+                "file_reference": "file-123",
+                "file_path": "/path/to/file",
+                "citeable_reference": "cite-ref-123",
+                "series_id": str(series_id),
+                "series_name": "series-name",
+                "transferring_body": "body-name",
+                "transferring_body_id": str(body_id),
+                "transferring_body_description": "transferring body description",
+                "consignment_id": str(consignment_id),
+                "consignment_reference": "TDR-2024-ABCD",
+                "Key3": "Value3",
+                "Key4": "Value4",
+                "content": "",
+                "text_extraction_status": TextExtractionStatus.SUCCEEDED.value,
+            },
+            "file_id": str(file_2_id),
+        },
+        {
+            "document": {
+                "file_id": str(file_3_id),
+                "file_name": "test-document.txt",
+                "file_reference": "file-123",
+                "file_path": "/path/to/file",
+                "citeable_reference": "cite-ref-123",
+                "series_id": str(series_id),
+                "series_name": "series-name",
+                "consignment_id": str(consignment_id),
+                "consignment_reference": "TDR-2024-ABCD",
+                "transferring_body": "body-name",
+                "transferring_body_id": str(body_id),
+                "transferring_body_description": "transferring body description",
+                "Key5": "Value5",
+                "Key6": "Value6",
+                "content": "File content but in file we do not support text extraction for",
+                "text_extraction_status": TextExtractionStatus.SUCCEEDED.value,
+            },
+            "file_id": str(file_3_id),
+        },
+    ]
+    assert args[1] == "https://test-opensearch.com"
+
+    aws_auth = args[2]
+    assert isinstance(aws_auth, AWS4Auth)
+    assert aws_auth.access_id == "test_access_key"
+    assert (
+        aws_auth.signing_key.secret_key
+        == "test_secret_key"  # pragma: allowlist secret
+    )
+    assert aws_auth.region == "eu-west-2"
+    assert aws_auth.service == "es"
+    assert aws_auth.session_token == "test_token"
+
+    assert args[3] == 600
+    assert args[4] is None
diff --git a/data_management/opensearch_indexer/tests/test_process_and_index_file.py b/data_management/opensearch_indexer/tests/test_process_and_index_file.py
index f835afbe..b22fd7d8 100644
--- a/data_management/opensearch_indexer/tests/test_process_and_index_file.py
+++ b/data_management/opensearch_indexer/tests/test_process_and_index_file.py
@@ -1,93 +1,14 @@
-import tempfile
 from unittest import mock
 from uuid import uuid4
 
-import pytest
 from opensearch_indexer.index_file_content_and_metadata_in_opensearch import (
     index_file_content_and_metadata_in_opensearch,
 )
+from opensearch_indexer.text_extraction import TextExtractionStatus
 from opensearchpy import RequestsHttpConnection
-from sqlalchemy import (
-    Boolean,
-    Column,
-    DateTime,
-    ForeignKey,
-    String,
-    Text,
-    create_engine,
-)
-from sqlalchemy.dialects.postgresql import UUID
-from sqlalchemy.orm import declarative_base, relationship, sessionmaker
-
-Base = declarative_base()
-
-
-class Body(Base):
-    __tablename__ = "Body"
-    BodyId = Column(UUID(as_uuid=True), primary_key=True)
-    Name = Column(Text)
-    Description = Column(Text)
-
-
-class Series(Base):
-    __tablename__ = "Series"
-    SeriesId = Column(UUID(as_uuid=True), primary_key=True)
-    BodyId = Column(UUID(as_uuid=True), ForeignKey("Body.BodyId"))
-    Name = Column(Text)
-    Description = Column(Text)
-    body = relationship("Body", foreign_keys="Series.BodyId")
-
-
-class Consignment(Base):
-    __tablename__ = "Consignment"
-    ConsignmentId = Column(UUID(as_uuid=True), primary_key=True)
-    SeriesId = Column(UUID(as_uuid=True), ForeignKey("Series.SeriesId"))
-    BodyId = Column(UUID(as_uuid=True), ForeignKey("Body.BodyId"))
-    ConsignmentReference = Column(Text)
-    ConsignmentType = Column(String, nullable=False)
-    IncludeTopLevelFolder = Column(Boolean)
-    ContactName = Column(Text)
-    ContactEmail = Column(Text)
-    TransferStartDatetime = Column(DateTime)
-    TransferCompleteDatetime = Column(DateTime)
-    ExportDatetime = Column(DateTime)
-    CreatedDatetime = Column(DateTime)
-    series = relationship("Series", foreign_keys="Consignment.SeriesId")
-
-
-class File(Base):
-    __tablename__ = "File"
-    FileId = Column(UUID(as_uuid=True), primary_key=True)
-    ConsignmentId = Column(
-        UUID(as_uuid=True), ForeignKey("Consignment.ConsignmentId")
-    )
-    FileReference = Column(Text, nullable=False)
-    FileType = Column(Text, nullable=False)
-    FileName = Column(Text, nullable=False)
-    FilePath = Column(Text, nullable=False)
-    CiteableReference = Column(Text)
-    Checksum = Column(Text)
-    CreatedDatetime = Column(DateTime)
-    consignment = relationship("Consignment", foreign_keys="File.ConsignmentId")
-
-
-class FileMetadata(Base):
-    __tablename__ = "FileMetadata"
-    MetadataId = Column(UUID(as_uuid=True), primary_key=True)
-    FileId = Column(UUID(as_uuid=True), ForeignKey("File.FileId"))
-    PropertyName = Column(Text, nullable=False)
-    Value = Column(Text)
-    CreatedDatetime = Column(DateTime)
-    file = relationship("File", foreign_keys="FileMetadata.FileId")
-
+from sqlalchemy.orm import sessionmaker
 
-@pytest.fixture
-def temp_db():
-    temp_db_file = tempfile.NamedTemporaryFile(suffix=".db")
-    database_url = f"sqlite:///{temp_db_file.name}"
-    engine = create_engine(database_url)
-    Base.metadata.create_all(engine)
-    yield engine
+from .conftest import Body, Consignment, File, FileMetadata, Series
 
 
 @mock.patch(
@@ -195,6 +116,6 @@ def test_index_file_content_and_metadata_in_opensearch(
             "Key1": "Value1",
             "Key2": "Value2",
             "content": "Text stream",
-            "text_extraction_status": "success",
+            "text_extraction_status": TextExtractionStatus.SUCCEEDED.value,
        },
    )
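
Usage sketch (not part of the patch above): how a caller might surface the new aggregated error. The helper name index_consignment_or_log is hypothetical, and the sketch assumes bulk_index_consignment_from_aws lets the ConsignmentBulkIndexError raised by bulk_index_consignment propagate.

import logging

from opensearch_indexer.index_consignment.bulk_index_consignment import (
    ConsignmentBulkIndexError,
    bulk_index_consignment_from_aws,
)

logger = logging.getLogger()


def index_consignment_or_log(consignment_reference: str, secret_id: str) -> None:
    # Hypothetical wrapper: one except block sees every failure for the consignment,
    # because the indexer aggregates text-extraction and bulk-indexing errors into
    # a single ConsignmentBulkIndexError message.
    try:
        bulk_index_consignment_from_aws(consignment_reference, secret_id)
    except ConsignmentBulkIndexError as e:
        # The message lists file ids that failed text extraction and any
        # error text from the failed OpenSearch bulk request.
        logger.error("Indexing failed for %s: %s", consignment_reference, e)
        raise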