diff --git a/classify-split-extract-workflow/classify-job/bq_mlops.py b/classify-split-extract-workflow/classify-job/bq_mlops.py
index 7e0f29ff1..670434b76 100644
--- a/classify-split-extract-workflow/classify-job/bq_mlops.py
+++ b/classify-split-extract-workflow/classify-job/bq_mlops.py
@@ -42,14 +42,16 @@ def object_table_create(
         f_uris (List[str]): List of file URIs.
         document_type (str): Type of the document.
         table_suffix (str, optional): Suffix for the table name. Defaults to current UTC timestamp.
-        retention_days (int, optional): Number of days before the table expires. Defaults to BQ_OBJECT_TABLE_RETENTION_DAYS.
+        retention_days (int, optional): Number of days before the table expires.
+            Defaults to BQ_OBJECT_TABLE_RETENTION_DAYS.
 
     Returns:
         str: The name of the created BigQuery table.
     """
     uris = "', '".join(f_uris)
-    object_table_name = f"{BQ_PROJECT_ID}.{BQ_DATASET_ID_MLOPS}.{document_type.upper()}_DOCUMENTS_{table_suffix}"
+    object_table_name = f"{BQ_PROJECT_ID}.{BQ_DATASET_ID_MLOPS}." \
+                        f"{document_type.upper()}_DOCUMENTS_{table_suffix}"
     query = f"""
     CREATE OR REPLACE EXTERNAL TABLE `{object_table_name}`
     WITH CONNECTION `{BQ_PROJECT_ID}.{BQ_REGION}.{BQ_GCS_CONNECTION_NAME}`
diff --git a/classify-split-extract-workflow/classify-job/config.py b/classify-split-extract-workflow/classify-job/config.py
index 6331d9f01..41bd61512 100644
--- a/classify-split-extract-workflow/classify-job/config.py
+++ b/classify-split-extract-workflow/classify-job/config.py
@@ -12,11 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from google.cloud import run_v2, storage
 import json
 import os
-from typing import Optional, Dict
-
+from typing import Optional, Dict, Any, Union
+from google.cloud import run_v2, storage
 import google.auth
 
 from logging_handler import Logger
@@ -52,57 +51,57 @@
 PDF_MIME_TYPE = "application/pdf"
 MIME_TYPES = [
     PDF_MIME_TYPE,
-    # "image/gif",  # TODO: Add/Check support for all these types
-    # "image/tiff",
-    # "image/jpeg",
-    # "image/png",
-    # "image/bmp",
-    # "image/webp"
+]
+
+OTHER_MIME_TYPES_TO_SUPPORT = [
+    "image/gif",
+    "image/tiff",
+    "image/jpeg",
+    "image/png",
+    "image/bmp",
+    "image/webp"
 ]
 
 NO_CLASSIFIER_LABEL = "No Classifier"
 METADATA_CONFIDENCE = "confidence"
 METADATA_DOCUMENT_TYPE = "type"
-
 CONFIG_JSON_DOCUMENT_TYPES_CONFIG = "document_types_config"
 
 FULL_JOB_NAME = run_v2.ExecutionsClient.job_path(PROJECT_ID, REGION, "classify-job")
 
 # Global variables
-gcs = None
-bucket = None
-last_modified_time_of_object = None
-config_data = None
+BUCKET = None
+LAST_MODIFIED_TIME_OF_CONFIG = None
+CONFIG_DATA = None
 
 logger.info(
     f"Settings used: CLASSIFY_INPUT_BUCKET=gs://{CLASSIFY_INPUT_BUCKET}, INPUT_FILE={INPUT_FILE}, "
     f"CLASSIFY_OUTPUT_BUCKET=gs://{CLASSIFY_OUTPUT_BUCKET}, OUTPUT_FILE_JSON={OUTPUT_FILE_JSON}, "
     f"OUTPUT_FILE_CSV={OUTPUT_FILE_CSV}, CALL_BACK_URL={CALL_BACK_URL}, "
-    f"BQ_DATASET_ID_PROCESSED_DOCS={BQ_DATASET_ID_PROCESSED_DOCS}, BQ_DATASET_ID_MLOPS={BQ_DATASET_ID_MLOPS}, "
+    f"BQ_DATASET_ID_PROCESSED_DOCS={BQ_DATASET_ID_PROCESSED_DOCS}, "
+    f"BQ_DATASET_ID_MLOPS={BQ_DATASET_ID_MLOPS}, "
    f"BQ_PROJECT_ID={BQ_PROJECT_ID}, BQ_GCS_CONNECTION_NAME={BQ_GCS_CONNECTION_NAME}, "
    f"DOCAI_OUTPUT_BUCKET={DOCAI_OUTPUT_BUCKET}"
 )
 
 
-def init_bucket(bucket_name: str) -> None:
+def init_bucket(bucket_name: str) -> Optional[storage.Bucket]:
     """
     Initializes the GCS bucket.
 
     Args:
         bucket_name (str): The name of the bucket.
""" - global gcs, bucket - if not gcs: - gcs = storage.Client() - - if not bucket: - bucket = gcs.bucket(bucket_name) - if not bucket.exists(): - logger.error(f"Bucket does not exist: gs://{bucket_name}") - bucket = None + storage_client = storage.Client() + bucket = storage_client.bucket(bucket_name) + if not bucket.exists(): + logger.error(f"Bucket does not exist: gs://{bucket_name}") + return None # Return None to indicate failure + return bucket -def get_config(config_name: Optional[str] = None, element_path: str = None) -> Optional[Dict]: +def get_config(config_name: Optional[str] = None, + element_path: Optional[str] = None) -> Optional[Dict[Any, Any]]: """ Retrieves the configuration data. @@ -113,12 +112,12 @@ def get_config(config_name: Optional[str] = None, element_path: str = None) -> O Returns: Dict: The configuration data. """ - global config_data, last_modified_time_of_object - if not config_data: - config_data = load_config(CONFIG_BUCKET, CONFIG_FILE_NAME) - assert config_data, "Unable to load configuration data" + global CONFIG_DATA, LAST_MODIFIED_TIME_OF_CONFIG + if not CONFIG_DATA: + CONFIG_DATA = load_config(CONFIG_BUCKET, CONFIG_FILE_NAME) + assert CONFIG_DATA, "Unable to load configuration data" - config_data_loaded = config_data.get(config_name, {}) if config_name else config_data + config_data_loaded = CONFIG_DATA.get(config_name, {}) if config_name else CONFIG_DATA if element_path: keys = element_path.split('.') @@ -126,68 +125,109 @@ def get_config(config_name: Optional[str] = None, element_path: str = None) -> O if isinstance(config_data_loaded, dict): config_data_loaded_new = config_data_loaded.get(key) if config_data_loaded_new is None: - logger.error(f"Key '{key}' not present in the configuration {json.dumps(config_data_loaded, indent=4)}") + logger.error(f"Key '{key}' not present in the " + f"configuration {json.dumps(config_data_loaded, indent=4)}") return None config_data_loaded = config_data_loaded_new else: - logger.error(f"Expected a dictionary at '{key}' but found a {type(config_data_loaded).__name__}") + logger.error(f"Expected a dictionary at '{key}' but found a " + f"{type(config_data_loaded).__name__}") return None return config_data_loaded def get_parser_name_by_doc_type(doc_type: str) -> Optional[str]: + """Retrieves the parser name based on the document type. + + Args: + doc_type (str): The document type. + + Returns: + Optional[str]: The parser name, or None if not found. + """ return get_config(CONFIG_JSON_DOCUMENT_TYPES_CONFIG, f"{doc_type}.parser") -def get_document_types_config() -> Dict: +def get_document_types_config() -> Dict[Any, Any]: + """ + Retrieves the document types configuration. + + Returns: + Dict: The document types configuration. + """ return get_config(CONFIG_JSON_DOCUMENT_TYPES_CONFIG) -def get_parser_by_doc_type(doc_type: str) -> Optional[Dict]: +def get_parser_by_doc_type(doc_type: str) -> Optional[Dict[Any, Any]]: + """ + Retrieves the parser by document type. + + Args: + doc_type (str): The document type. + + Returns: + Optional[Dict]: The parser configuration. + """ parser_name = get_parser_name_by_doc_type(doc_type) if parser_name: return get_config("parser_config", parser_name) - return None -def load_config(bucket_name: str, filename: str) -> Optional[Dict]: - global bucket, last_modified_time_of_object, config_data +def load_config(bucket_name: str, filename: str) -> Optional[Dict[Any, Any]]: + """ + Loads the configuration data from a GCS bucket or local file. + + Args: + bucket_name (str): The GCS bucket name. 
+        filename (str): The configuration file name.
+
+    Returns:
+        Optional[Dict]: The configuration data.
+    """
+    global BUCKET, LAST_MODIFIED_TIME_OF_CONFIG, CONFIG_DATA
 
-    if not bucket:
-        init_bucket(bucket_name)
+    if not BUCKET:
+        BUCKET = init_bucket(bucket_name)
 
-    if not bucket:
+    if not BUCKET:
         return None
 
-    blob = bucket.get_blob(filename)
+    blob = BUCKET.get_blob(filename)
     if not blob:
         logger.error(f"Error: file does not exist gs://{bucket_name}/{filename}")
         return None
 
     last_modified_time = blob.updated
-    if last_modified_time == last_modified_time_of_object:
-        return config_data
+    if last_modified_time == LAST_MODIFIED_TIME_OF_CONFIG:
+        return CONFIG_DATA
 
     logger.info(f"Reloading config from: {filename}")
     try:
-        config_data = json.loads(blob.download_as_text(encoding="utf-8"))
-        last_modified_time_of_object = last_modified_time
+        CONFIG_DATA = json.loads(blob.download_as_text(encoding="utf-8"))
+        LAST_MODIFIED_TIME_OF_CONFIG = last_modified_time
     except Exception as e:
         logger.error(f"Error while obtaining file from GCS gs://{bucket_name}/{filename}: {e}")
         logger.warning(f"Using local {filename}")
         try:
-            with open(os.path.join(os.path.dirname(__file__), "config", filename)) as json_file:
-                config_data = json.load(json_file)
-        except Exception as e:
-            logger.error(f"Error loading local config file {filename}: {e}")
+            with open(os.path.join(os.path.dirname(__file__), "config", filename),
+                      encoding='utf-8') as json_file:
+                CONFIG_DATA = json.load(json_file)
+        except (FileNotFoundError, json.JSONDecodeError) as exc:
+            logger.error(f"Error loading local config file {filename}: {exc}")
            return None
 
-    return config_data
+    return CONFIG_DATA
 
 
-def get_docai_settings() -> Dict:
+def get_docai_settings() -> Dict[Any, Any]:
+    """
+    Retrieves the Document AI settings configuration.
+
+    Returns:
+        Dict: The Document AI settings configuration.
+    """
     return get_config("settings_config")
 
 
@@ -212,19 +252,30 @@ def get_classification_default_class() -> str:
     """
     settings = get_docai_settings()
-    classification_default_class = settings.get("classification_default_class", CLASSIFICATION_UNDETECTABLE)
+    classification_default_class = settings.get("classification_default_class",
+                                                CLASSIFICATION_UNDETECTABLE)
     parser = get_parser_by_doc_type(classification_default_class)
     if parser:
         return classification_default_class
 
     logger.warning(
-        f"Classification default label {classification_default_class} is not a valid Label or missing a corresponding "
+        f"Classification default label {classification_default_class}"
+        f" is not a valid Label or missing a corresponding "
         f"parser in parser_config"
     )
 
     return CLASSIFICATION_UNDETECTABLE
 
 
 def get_document_class_by_classifier_label(label_name: str) -> Optional[str]:
+    """
+    Retrieves the document class by classifier label.
+
+    Args:
+        label_name (str): The classifier label name.
+
+    Returns:
+        Optional[str]: The document class.
+    """
     for k, v in get_document_types_config().items():
         if v.get("classifier_label") == label_name:
             return k
@@ -232,21 +283,45 @@
     return None
 
 
-def get_parser_by_name(parser_name: str) -> Optional[Dict]:
+def get_parser_by_name(parser_name: str) -> Optional[Dict[Any, Any]]:
+    """
+    Retrieves the parser configuration by parser name.
+
+    Args:
+        parser_name (str): The parser name.
+
+    Returns:
+        Optional[Dict]: The parser configuration.
+ """ return get_config("parser_config", parser_name) -def get_model_name_table_name(document_type: str) -> tuple[Optional[str], Optional[str]]: - parser = get_parser_by_doc_type(document_type) - if parser: - parser_name = get_parser_name_by_doc_type(document_type) - model_name = f"{BQ_PROJECT_ID}.{BQ_DATASET_ID_MLOPS}.{parser.get('model_name', parser_name.upper() + '_MODEL')}" - out_table_name = f"{BQ_PROJECT_ID}.{BQ_DATASET_ID_PROCESSED_DOCS}." \ - f"{parser.get('out_table_name', parser_name.upper() + '_DOCUMENTS')}" +def get_model_name_table_name(document_type: str) -> \ + Union[tuple[Optional[str], Optional[str]], tuple[None, None]]: + """ + Retrieves the output table name and model name by document type. + + Args: + document_type (str): The document type. + + Returns: + Union[tuple[Optional[str], Optional[str]], tuple[None, None]]: The output table name and + model name. + """ + parser_name = get_parser_name_by_doc_type(document_type) + if parser_name: + parser = get_parser_by_doc_type(document_type) + model_name = ( + f"{BQ_PROJECT_ID}.{BQ_DATASET_ID_MLOPS}." + f"{parser.get('model_name', parser_name.upper() + '_MODEL')}" + ) + out_table_name = ( + f"{BQ_PROJECT_ID}.{BQ_DATASET_ID_PROCESSED_DOCS}." + f"{parser.get('out_table_name', parser_name.upper() + '_DOCUMENTS')}" + ) else: logger.warning(f"No parser found for document type {document_type}") return None, None logger.info(f"model_name={model_name}, out_table_name={out_table_name}") return model_name, out_table_name - diff --git a/classify-split-extract-workflow/classify-job/docai_helper.py b/classify-split-extract-workflow/classify-job/docai_helper.py index e08e055ec..ac137cfba 100644 --- a/classify-split-extract-workflow/classify-job/docai_helper.py +++ b/classify-split-extract-workflow/classify-job/docai_helper.py @@ -12,19 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. """This module contains DocAI helper functions""" + import re -from typing import Tuple, Optional, List -from google.cloud.documentai_toolbox import document +from typing import Tuple, Optional + +from google.api_core.client_options import ClientOptions from google.cloud import documentai_v1 as documentai + from config import get_parser_by_name from logging_handler import Logger -from google.api_core.client_options import ClientOptions logger = Logger.get_logger(__file__) -def get_processor_and_client(processor_name: str) -> \ - Tuple[Optional[documentai.types.processor.Processor], Optional[documentai.DocumentProcessorServiceClient]]: +def get_processor_and_client( + processor_name: str, +) -> Tuple[ + Optional[documentai.types.processor.Processor], + Optional[documentai.DocumentProcessorServiceClient], +]: """ Retrieves the Document AI processor and client based on the processor name. @@ -32,7 +38,8 @@ def get_processor_and_client(processor_name: str) -> \ processor_name (str): The name of the processor. Returns: - Tuple[Optional[documentai.types.processor.Processor], Optional[documentai.DocumentProcessorServiceClient]]: + Tuple[Optional[documentai.types.processor.Processor], + Optional[documentai.DocumentProcessorServiceClient]]: A tuple containing the processor and the Document Processor Service client. """ @@ -57,31 +64,9 @@ def get_processor_and_client(processor_name: str) -> \ def get_processor_location(processor_path: str) -> Optional[str]: - """ - Extracts the location from the processor path. - - Args: - processor_path (str): The processor path. 
-
-    Returns:
-        Optional[str]: The location extracted from the processor path.
-    """
+    """Extracts the location from the processor path."""
 
-    match = re.match(r'projects/(.+)/locations/(.+)/processors', processor_path)
+    match = re.match(r"projects/(.+)/locations/(.+)/processors", processor_path)
     if match and len(match.groups()) >= 2:
         return match.group(2)
     return None
-
-
-def split_pdf_document(document_path: str, pdf_path: str, output_path: str) -> List[str]:
-    try:
-        wrapped_document = document.Document.from_document_path(document_path=document_path)
-        output_files = wrapped_document.split_pdf(
-            pdf_path=pdf_path, output_path=output_path
-        )
-
-        logger.info(f"Document {pdf_path} successfully split into {', '.join(output_files)}")
-        return output_files
-    except Exception as e:
-        logger.error(f"Error splitting PDF document: {e}")
-        return []
diff --git a/classify-split-extract-workflow/classify-job/gcs_helper.py b/classify-split-extract-workflow/classify-job/gcs_helper.py
index d69df8d31..6cfcbd981 100644
--- a/classify-split-extract-workflow/classify-job/gcs_helper.py
+++ b/classify-split-extract-workflow/classify-job/gcs_helper.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import os
-from typing import List, Tuple, Dict
+from typing import List, Tuple, Dict, Optional
 
 from google.cloud import storage
 from logging_handler import Logger
@@ -24,8 +24,12 @@
 storage_client = storage.Client()
 
 
-def download_file(gcs_uri: str, bucket_name: str = None, file_to_download: str = None,
-                  output_filename: str = None) -> str:
+def download_file(
+    gcs_uri: str,
+    bucket_name: Optional[str] = None,
+    file_to_download: Optional[str] = None,
+    output_filename: Optional[str] = "gcs.pdf",  # Provide a default value for output_filename
+) -> str:
     """
     Downloads a file from a Google Cloud Storage (GCS) bucket to the local directory.
 
@@ -33,7 +37,8 @@ def download_file(gcs_uri: str, bucket_name: str = None, file_to_download: str =
         gcs_uri (str): GCS URI of the object/file to download.
         bucket_name (str, optional): Name of the bucket. Defaults to None.
         file_to_download (str, optional): Desired filename in GCS. Defaults to None.
-        output_filename (str, optional): Local filename to save the downloaded file. Defaults to 'gcs.pdf'.
+        output_filename (str, optional): Local filename to save the downloaded file.
+            Defaults to 'gcs.pdf'.
 
     Returns:
         str: Local path of the downloaded file.
@@ -82,20 +87,13 @@ def add_metadata(gcs_uri: str, metadata: Dict[str, str]):
 
 
 def get_list_of_uris(bucket_name: str, file_uri: str) -> List[str]:
-    """
-    Retrieves a list of URIs from a GCS bucket.
-
-    Args:
-        bucket_name (str): Name of the GCS bucket.
-        file_uri (str): URI of the file or directory in the bucket.
-
-    Returns:
-        List[str]: List of URIs in the specified bucket and directory.
- """ - logger.info(f"Getting list of URIs for bucket=[{bucket_name}] and file=[{file_uri}]") - uri_list = [] + """Retrieves a list of URIs from a GCS bucket.""" + logger.info( + f"Getting list of URIs for bucket=[{bucket_name}] and file=[{file_uri}]" + ) + uri_list: List[str] = [] # Type annotation for uri_list if not file_uri: - logger.warning(f"No file URI provided") + logger.warning("No file URI provided") return uri_list dirs, filename = split_uri_2_path_filename(file_uri) @@ -114,7 +112,8 @@ def get_list_of_uris(bucket_name: str, file_uri: str) -> List[str]: logger.info(f"Skipping {file_uri} - not supported mime type: {mime_type}") else: # Batch Processing - logger.info(f"Starting pipeline to process documents inside bucket=[{bucket_name}] and sub-folder=[{dirs}]") + logger.info(f"Starting pipeline to process documents inside" + f" bucket=[{bucket_name}] and sub-folder=[{dirs}]") if dirs is None or dirs == "": blob_list = storage_client.list_blobs(bucket_name) else: @@ -123,8 +122,8 @@ def get_list_of_uris(bucket_name: str, file_uri: str) -> List[str]: count = 0 for blob in blob_list: if blob.name and not blob.name.endswith('/') and \ - blob.name != START_PIPELINE_FILENAME and not os.path.dirname(blob.name).endswith( - SPLITTER_OUTPUT_DIR): + blob.name != START_PIPELINE_FILENAME and \ + not os.path.dirname(blob.name).endswith(SPLITTER_OUTPUT_DIR): count += 1 f_uri = f"gs://{bucket_name}/{blob.name}" logger.info(f"Handling {count}(th) document - {f_uri}") @@ -166,7 +165,8 @@ def upload_file(bucket_name, source_file_name, destination_blob_name) -> str: return output_gcs -def write_data_to_gcs(bucket_name: str, blob_name: str, content: str, mime_type: str = "text/plain") -> Tuple[str, str]: +def write_data_to_gcs(bucket_name: str, blob_name: str, content: str, + mime_type: str = "text/plain") -> Tuple[str, str]: """ Writes data to a GCS bucket in the specified format. diff --git a/classify-split-extract-workflow/classify-job/split_and_classify.py b/classify-split-extract-workflow/classify-job/split_and_classify.py index 247f713cd..8abca5aca 100644 --- a/classify-split-extract-workflow/classify-job/split_and_classify.py +++ b/classify-split-extract-workflow/classify-job/split_and_classify.py @@ -258,7 +258,8 @@ def check_confidence_threshold_passed(predicted_confidence: float) -> bool: def split_pdf(gcs_uri: str, entities: List[Document.Entity]) -> Dict: - """Splits local PDF file into multiple PDF files based on output from a Splitter/Classifier processor. + """Splits local PDF file into multiple PDF files based on output from a + Splitter/Classifier processor. 
 
     Args:
         gcs_uri (str):
@@ -279,9 +280,11 @@ def split_pdf(gcs_uri: str, entities: List[Document.Entity]) -> Dict:
             })
 
         gcs_helper.add_metadata(gcs_uri=gcs_uri, metadata=metadata)
-        add_predicted_document_type(metadata=metadata, input_gcs_source=gcs_uri, documents=documents)
+        add_predicted_document_type(metadata=metadata, input_gcs_source=gcs_uri,
+                                    documents=documents)
     else:
-        temp_local_dir = os.path.join(os.path.dirname(__file__), "temp_files", utils.get_utc_timestamp())
+        temp_local_dir = os.path.join(os.path.dirname(__file__), "temp_files",
+                                      utils.get_utc_timestamp())
         if not os.path.exists(temp_local_dir):
             os.makedirs(temp_local_dir)
 
@@ -325,7 +328,8 @@ def split_pdf(gcs_uri: str, entities: List[Document.Entity]) -> Dict:
                                     destination_blob_name=destination_blob_name)
 
             gcs_helper.add_metadata(destination_blob_uri, metadata)
-            add_predicted_document_type(metadata=metadata, input_gcs_source=destination_blob_uri,
+            add_predicted_document_type(metadata=metadata,
+                                        input_gcs_source=destination_blob_uri,
                                         documents=documents)
 
         utils.delete_directory(temp_local_dir)
diff --git a/classify-split-extract-workflow/classify-job/utils.py b/classify-split-extract-workflow/classify-job/utils.py
index 41377848f..e4f4853c7 100644
--- a/classify-split-extract-workflow/classify-job/utils.py
+++ b/classify-split-extract-workflow/classify-job/utils.py
@@ -78,7 +78,8 @@ def split_pages(file_pattern: str, bucket_name: str, output_dir: str) -> None:
             pdf_writers[page_index // 15].add_page(page)
 
         for shard_index, pdf_writer in enumerate(pdf_writers):
-            output_filename = f"{output_dir}/{file_path[3:-4]} - part {shard_index + 1} of {num_shards}.pdf"
+            output_filename = f"{output_dir}/{file_path[3:-4]} - " \
+                              f"part {shard_index + 1} of {num_shards}.pdf"
             blob = bucket.blob(output_filename)
             with blob.open("wb", content_type='application/pdf') as output_file:
                 pdf_writer.write(output_file)
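
For reference, a minimal usage sketch (not part of the patch) of how the refactored config.py helpers compose; it only calls functions that exist in the updated module, and the document type values it iterates over come from the document_types_config section of the loaded configuration.

    # Sketch only: resolve the parser and BigQuery targets for each configured document type.
    import config

    doc_types = config.get_document_types_config() or {}
    for doc_type in doc_types:
        parser_cfg = config.get_parser_by_doc_type(doc_type)  # None if no parser entry
        model_name, out_table_name = config.get_model_name_table_name(doc_type)
        if model_name is None:
            continue  # get_model_name_table_name() already logged a warning
        print(doc_type, parser_cfg, model_name, out_table_name)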