Linting

evekhm committed Jul 18, 2024
1 parent 00a4485 commit 8a2de1e

Showing 6 changed files with 56 additions and 53 deletions.
6 changes: 4 additions & 2 deletions classify-split-extract-workflow/classify-job/bq_mlops.py
@@ -42,14 +42,16 @@ def object_table_create(
f_uris (List[str]): List of file URIs.
document_type (str): Type of the document.
table_suffix (str, optional): Suffix for the table name. Defaults to current UTC timestamp.
retention_days (int, optional): Number of days before the table expires. Defaults to BQ_OBJECT_TABLE_RETENTION_DAYS.
retention_days (int, optional): Number of days before the table expires.
Defaults to BQ_OBJECT_TABLE_RETENTION_DAYS.
Returns:
str: The name of the created BigQuery table.
"""

uris = "', '".join(f_uris)
object_table_name = f"{BQ_PROJECT_ID}.{BQ_DATASET_ID_MLOPS}.{document_type.upper()}_DOCUMENTS_{table_suffix}"
object_table_name = f"{BQ_PROJECT_ID}.{BQ_DATASET_ID_MLOPS}." \
f"{document_type.upper()}_DOCUMENTS_{table_suffix}"
query = f"""
CREATE OR REPLACE EXTERNAL TABLE `{object_table_name}`
WITH CONNECTION `{BQ_PROJECT_ID}.{BQ_REGION}.{BQ_GCS_CONNECTION_NAME}`
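The wrapped table name above relies on implicit concatenation of adjacent f-string literals (the backslash continuation here and the parenthesized form used later in this commit are equivalent). A minimal standalone sketch of the pattern, with hypothetical project and dataset values:

project_id = "my-project"          # hypothetical values for illustration
dataset_id = "mlops"
document_type = "invoice"
table_suffix = "20240718"

# Adjacent f-string literals are joined at compile time, so the split source
# lines still produce a single table name.
object_table_name = (
    f"{project_id}.{dataset_id}."
    f"{document_type.upper()}_DOCUMENTS_{table_suffix}"
)
assert object_table_name == "my-project.mlops.INVOICE_DOCUMENTS_20240718"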
30 changes: 19 additions & 11 deletions classify-split-extract-workflow/classify-job/config.py
@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud import run_v2, storage
import json
import os
from typing import Optional, Dict

from google.cloud import run_v2, storage
import google.auth
from logging_handler import Logger

@@ -78,7 +77,8 @@
f"Settings used: CLASSIFY_INPUT_BUCKET=gs://{CLASSIFY_INPUT_BUCKET}, INPUT_FILE={INPUT_FILE}, "
f"CLASSIFY_OUTPUT_BUCKET=gs://{CLASSIFY_OUTPUT_BUCKET}, OUTPUT_FILE_JSON={OUTPUT_FILE_JSON}, "
f"OUTPUT_FILE_CSV={OUTPUT_FILE_CSV}, CALL_BACK_URL={CALL_BACK_URL}, "
f"BQ_DATASET_ID_PROCESSED_DOCS={BQ_DATASET_ID_PROCESSED_DOCS}, BQ_DATASET_ID_MLOPS={BQ_DATASET_ID_MLOPS}, "
f"BQ_DATASET_ID_PROCESSED_DOCS={BQ_DATASET_ID_PROCESSED_DOCS}, "
f"BQ_DATASET_ID_MLOPS={BQ_DATASET_ID_MLOPS}, "
f"BQ_PROJECT_ID={BQ_PROJECT_ID}, BQ_GCS_CONNECTION_NAME={BQ_GCS_CONNECTION_NAME}, "
f"DOCAI_OUTPUT_BUCKET={DOCAI_OUTPUT_BUCKET}"
)
@@ -126,11 +126,13 @@ def get_config(config_name: Optional[str] = None, element_path: str = None) -> O
if isinstance(config_data_loaded, dict):
config_data_loaded_new = config_data_loaded.get(key)
if config_data_loaded_new is None:
logger.error(f"Key '{key}' not present in the configuration {json.dumps(config_data_loaded, indent=4)}")
logger.error(f"Key '{key}' not present in the "
f"configuration {json.dumps(config_data_loaded, indent=4)}")
return None
config_data_loaded = config_data_loaded_new
else:
logger.error(f"Expected a dictionary at '{key}' but found a {type(config_data_loaded).__name__}")
logger.error(f"Expected a dictionary at '{key}' but found a "
f"{type(config_data_loaded).__name__}")
return None

return config_data_loaded
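get_config walks element_path one key at a time, logging and returning None as soon as a key is missing or the current value is not a dict. A small sketch of that traversal; the "/"-separated path syntax and the config contents here are assumptions for illustration, not taken from the repository:

config = {"docai": {"parsers": {"invoice": {"location": "us"}}}}

def lookup(data, element_path):
    # Walk nested keys; bail out with None when a key is absent or the
    # current value is not a dictionary.
    for key in element_path.split("/"):
        if not isinstance(data, dict) or key not in data:
            return None
        data = data[key]
    return data

assert lookup(config, "docai/parsers/invoice") == {"location": "us"}
assert lookup(config, "docai/missing") is None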
@@ -212,13 +214,15 @@ def get_classification_default_class() -> str:
"""

settings = get_docai_settings()
classification_default_class = settings.get("classification_default_class", CLASSIFICATION_UNDETECTABLE)
classification_default_class = settings.get("classification_default_class",
CLASSIFICATION_UNDETECTABLE)
parser = get_parser_by_doc_type(classification_default_class)
if parser:
return classification_default_class

logger.warning(
f"Classification default label {classification_default_class} is not a valid Label or missing a corresponding "
f"Classification default label {classification_default_class}"
f" is not a valid Label or missing a corresponding "
f"parser in parser_config"
)
return CLASSIFICATION_UNDETECTABLE
@@ -240,13 +244,17 @@ def get_model_name_table_name(document_type: str) -> tuple[Optional[str], Option
parser = get_parser_by_doc_type(document_type)
if parser:
parser_name = get_parser_name_by_doc_type(document_type)
model_name = f"{BQ_PROJECT_ID}.{BQ_DATASET_ID_MLOPS}.{parser.get('model_name', parser_name.upper() + '_MODEL')}"
out_table_name = f"{BQ_PROJECT_ID}.{BQ_DATASET_ID_PROCESSED_DOCS}." \
f"{parser.get('out_table_name', parser_name.upper() + '_DOCUMENTS')}"
model_name = (
f"{BQ_PROJECT_ID}.{BQ_DATASET_ID_MLOPS}."
f"{parser.get('model_name', parser_name.upper() + '_MODEL')}"
)
out_table_name = (
f"{BQ_PROJECT_ID}.{BQ_DATASET_ID_PROCESSED_DOCS}."
f"{parser.get('out_table_name', parser_name.upper() + '_DOCUMENTS')}"
)
else:
logger.warning(f"No parser found for document type {document_type}")
return None, None

logger.info(f"model_name={model_name}, out_table_name={out_table_name}")
return model_name, out_table_name
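The model and output-table names above fall back to identifiers derived from the parser name whenever the parser entry does not override them. A short sketch of that dict.get fallback; the parser entry and parser name are hypothetical:

parser = {"location": "us"}        # hypothetical entry with no overrides
parser_name = "invoice_parser"

model_name = parser.get("model_name", parser_name.upper() + "_MODEL")
out_table_name = parser.get("out_table_name", parser_name.upper() + "_DOCUMENTS")

assert model_name == "INVOICE_PARSER_MODEL"
assert out_table_name == "INVOICE_PARSER_DOCUMENTS"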

45 changes: 15 additions & 30 deletions classify-split-extract-workflow/classify-job/docai_helper.py
@@ -12,27 +12,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains DocAI helper functions"""

import re
from typing import Tuple, Optional, List
from google.cloud.documentai_toolbox import document
from typing import Tuple, Optional

from google.api_core.client_options import ClientOptions
from google.cloud import documentai_v1 as documentai

from config import get_parser_by_name
from logging_handler import Logger
from google.api_core.client_options import ClientOptions

logger = Logger.get_logger(__file__)


def get_processor_and_client(processor_name: str) -> \
Tuple[Optional[documentai.types.processor.Processor], Optional[documentai.DocumentProcessorServiceClient]]:
def get_processor_and_client(
processor_name: str,
) -> Tuple[
Optional[documentai.types.processor.Processor],
Optional[documentai.DocumentProcessorServiceClient],
]:
"""
Retrieves the Document AI processor and client based on the processor name.
Args:
processor_name (str): The name of the processor.
Returns:
Tuple[Optional[documentai.types.processor.Processor], Optional[documentai.DocumentProcessorServiceClient]]:
Tuple[Optional[documentai.types.processor.Processor],
Optional[documentai.DocumentProcessorServiceClient]]:
A tuple containing the processor and the Document Processor Service client.
"""

@@ -57,31 +64,9 @@ def get_processor_and_client(processor_name: str) -> \


def get_processor_location(processor_path: str) -> Optional[str]:
"""
Extracts the location from the processor path.
Args:
processor_path (str): The processor path.
Returns:
Optional[str]: The location extracted from the processor path.
"""
"""Extracts the location from the processor path."""

match = re.match(r'projects/(.+)/locations/(.+)/processors', processor_path)
match = re.match(r"projects/(.+)/locations/(.+)/processors", processor_path)
if match and len(match.groups()) >= 2:
return match.group(2)
return None
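get_processor_location extracts the location segment from a full processor resource name. A self-contained sketch of the same regex against a made-up processor path:

import re

processor_path = "projects/my-project/locations/us/processors/abc123"  # example path
match = re.match(r"projects/(.+)/locations/(.+)/processors", processor_path)
location = match.group(2) if match else None
assert location == "us"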


def split_pdf_document(document_path: str, pdf_path: str, output_path: str) -> List[str]:
try:
wrapped_document = document.Document.from_document_path(document_path=document_path)
output_files = wrapped_document.split_pdf(
pdf_path=pdf_path, output_path=output_path
)

logger.info(f"Document {pdf_path} successfully split into {', '.join(output_files)}")
return output_files
except Exception as e:
logger.error(f"Error splitting PDF document: {e}")
return []
13 changes: 8 additions & 5 deletions classify-split-extract-workflow/classify-job/gcs_helper.py
@@ -33,7 +33,8 @@ def download_file(gcs_uri: str, bucket_name: str = None, file_to_download: str =
gcs_uri (str): GCS URI of the object/file to download.
bucket_name (str, optional): Name of the bucket. Defaults to None.
file_to_download (str, optional): Desired filename in GCS. Defaults to None.
output_filename (str, optional): Local filename to save the downloaded file. Defaults to 'gcs.pdf'.
output_filename (str, optional): Local filename to save the downloaded file.
Defaults to 'gcs.pdf'.
Returns:
str: Local path of the downloaded file.
@@ -114,7 +115,8 @@ def get_list_of_uris(bucket_name: str, file_uri: str) -> List[str]:
logger.info(f"Skipping {file_uri} - not supported mime type: {mime_type}")
else:
# Batch Processing
logger.info(f"Starting pipeline to process documents inside bucket=[{bucket_name}] and sub-folder=[{dirs}]")
logger.info(f"Starting pipeline to process documents inside"
f" bucket=[{bucket_name}] and sub-folder=[{dirs}]")
if dirs is None or dirs == "":
blob_list = storage_client.list_blobs(bucket_name)
else:
@@ -123,8 +125,8 @@
count = 0
for blob in blob_list:
if blob.name and not blob.name.endswith('/') and \
blob.name != START_PIPELINE_FILENAME and not os.path.dirname(blob.name).endswith(
SPLITTER_OUTPUT_DIR):
blob.name != START_PIPELINE_FILENAME and \
not os.path.dirname(blob.name).endswith(SPLITTER_OUTPUT_DIR):
count += 1
f_uri = f"gs://{bucket_name}/{blob.name}"
logger.info(f"Handling {count}(th) document - {f_uri}")
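The condition above skips folder placeholder objects, the START_PIPELINE trigger file, and blobs already written to the splitter output directory. An equivalent standalone predicate; the constant values are assumptions for illustration:

import os

START_PIPELINE_FILENAME = "START_PIPELINE"   # assumed value
SPLITTER_OUTPUT_DIR = "splitter_output"      # assumed value

def keep_blob(name: str) -> bool:
    """True for regular documents that should enter the pipeline."""
    return (
        bool(name)
        and not name.endswith("/")
        and name != START_PIPELINE_FILENAME
        and not os.path.dirname(name).endswith(SPLITTER_OUTPUT_DIR)
    )

assert keep_blob("docs/invoice.pdf")
assert not keep_blob("docs/splitter_output/invoice_p1.pdf")
assert not keep_blob("docs/")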
@@ -166,7 +168,8 @@ def upload_file(bucket_name, source_file_name, destination_blob_name) -> str:
return output_gcs


def write_data_to_gcs(bucket_name: str, blob_name: str, content: str, mime_type: str = "text/plain") -> Tuple[str, str]:
def write_data_to_gcs(bucket_name: str, blob_name: str, content: str,
mime_type: str = "text/plain") -> Tuple[str, str]:
"""
Writes data to a GCS bucket in the specified format.
@@ -258,7 +258,8 @@ def check_confidence_threshold_passed(predicted_confidence: float) -> bool:


def split_pdf(gcs_uri: str, entities: List[Document.Entity]) -> Dict:
"""Splits local PDF file into multiple PDF files based on output from a Splitter/Classifier processor.
"""Splits local PDF file into multiple PDF files based on output from a
Splitter/Classifier processor.
Args:
gcs_uri (str):
@@ -279,9 +280,11 @@ def split_pdf(gcs_uri: str, entities: List[Document.Entity]) -> Dict:
})

gcs_helper.add_metadata(gcs_uri=gcs_uri, metadata=metadata)
add_predicted_document_type(metadata=metadata, input_gcs_source=gcs_uri, documents=documents)
add_predicted_document_type(metadata=metadata, input_gcs_source=gcs_uri,
documents=documents)
else:
temp_local_dir = os.path.join(os.path.dirname(__file__), "temp_files", utils.get_utc_timestamp())
temp_local_dir = os.path.join(os.path.dirname(__file__), "temp_files",
utils.get_utc_timestamp())
if not os.path.exists(temp_local_dir):
os.makedirs(temp_local_dir)

@@ -325,7 +328,8 @@ def split_pdf(gcs_uri: str, entities: List[Document.Entity]) -> Dict:
destination_blob_name=destination_blob_name)
gcs_helper.add_metadata(destination_blob_uri, metadata)

add_predicted_document_type(metadata=metadata, input_gcs_source=destination_blob_uri,
add_predicted_document_type(metadata=metadata,
input_gcs_source=destination_blob_uri,
documents=documents)

utils.delete_directory(temp_local_dir)
3 changes: 2 additions & 1 deletion classify-split-extract-workflow/classify-job/utils.py
@@ -78,7 +78,8 @@ def split_pages(file_pattern: str, bucket_name: str, output_dir: str) -> None:
pdf_writers[page_index // 15].add_page(page)

for shard_index, pdf_writer in enumerate(pdf_writers):
output_filename = f"{output_dir}/{file_path[3:-4]} - part {shard_index + 1} of {num_shards}.pdf"
output_filename = f"{output_dir}/{file_path[3:-4]} - " \
f"part {shard_index + 1} of {num_shards}.pdf"
blob = bucket.blob(output_filename)
with blob.open("wb", content_type='application/pdf') as output_file:
pdf_writer.write(output_file)
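split_pages groups pages into shards of 15, matching the page_index // 15 above. A quick sketch of the shard arithmetic, assuming num_shards is computed by ceiling division (the visible hunk does not show how it is derived):

pages_per_shard = 15
num_pages = 47                                   # hypothetical page count

num_shards = -(-num_pages // pages_per_shard)    # ceiling division -> 4
shard_for_page = [page_index // pages_per_shard for page_index in range(num_pages)]

assert num_shards == 4
assert shard_for_page[14] == 0 and shard_for_page[15] == 1 and shard_for_page[46] == 3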