From 285b54a8adcb94113bdd79489fbdd4b7400ba850 Mon Sep 17 00:00:00 2001 From: michelle Date: Thu, 10 Aug 2023 15:09:45 -0400 Subject: [PATCH] WIP: Incremental Sync State Message Fix --- .../FirestoreIncrementalRefresh.py | 17 ++++++-- .../source_google_firestore/QueryHelpers.py | 39 ++++++++++--------- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/airbyte-integrations/connectors/source-google-firestore/source_google_firestore/FirestoreIncrementalRefresh.py b/airbyte-integrations/connectors/source-google-firestore/source_google_firestore/FirestoreIncrementalRefresh.py index 685a396a23a9..4a563cece2d0 100644 --- a/airbyte-integrations/connectors/source-google-firestore/source_google_firestore/FirestoreIncrementalRefresh.py +++ b/airbyte-integrations/connectors/source-google-firestore/source_google_firestore/FirestoreIncrementalRefresh.py @@ -5,7 +5,7 @@ from airbyte_cdk import AirbyteLogger from airbyte_cdk.sources.streams import IncrementalMixin -from airbyte_protocol.models import ConfiguredAirbyteStream, Type, AirbyteMessage +from airbyte_protocol.models import ConfiguredAirbyteStream, Type, AirbyteMessage, AirbyteStateMessage from source_google_firestore.AirbyteHelpers import AirbyteHelpers from source_google_firestore.FirestoreSource import FirestoreSource @@ -17,6 +17,7 @@ def __init__(self, firestore: FirestoreSource, logger: AirbyteLogger, config: js self.query = QueryHelpers(firestore, logger, config, airbyte_stream) self.logger = logger self.airbyte = AirbyteHelpers(airbyte_stream, config) + self.collection_name = airbyte_stream.stream.name self.cursor_field = config.get("cursor_field", "updated_at") self._cursor_value = None @@ -41,6 +42,16 @@ def chunk_time(self, last_updated_at): last_updated_at += timedelta(minutes=1) return timeframes + def get_cursor_value(self, documents: list[dict]): + max_value = None + + for doc in documents: + field_value = doc.get(self.cursor_field, None) + if field_value is not None and field_value != "": + if 
max_value is None or field_value > max_value: + max_value = field_value + return max_value + def stream(self, state): airbyte = self.airbyte query = self.query @@ -53,12 +64,12 @@ def stream(self, state): for airbyte_message in airbyte.send_airbyte_message(documents): yield airbyte_message self._cursor_value = datetime.utcnow() + yield AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(data=self.state)) else: for timeframe in timeframes: self.logger.info(f"Fetching documents from {timeframe['start_at']} to {timeframe['end_at']}") documents: list[dict] = query.fetch_records(cursor_value=timeframe) - self.logger.info(f"Finished fetching documents. Total documents: {len(documents)}") for airbyte_message in airbyte.send_airbyte_message(documents): yield airbyte_message self._cursor_value = timeframe[self.cursor_field] - yield AirbyteMessage(type=Type.STATE, state=self.state) + yield AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(data=self.state)) diff --git a/airbyte-integrations/connectors/source-google-firestore/source_google_firestore/QueryHelpers.py b/airbyte-integrations/connectors/source-google-firestore/source_google_firestore/QueryHelpers.py index d0cbff072659..2048bcf25426 100644 --- a/airbyte-integrations/connectors/source-google-firestore/source_google_firestore/QueryHelpers.py +++ b/airbyte-integrations/connectors/source-google-firestore/source_google_firestore/QueryHelpers.py @@ -22,16 +22,15 @@ class QueryHelpers: def __init__(self, firestore: FirestoreSource, logger: AirbyteLogger, config: json, airbyte_stream: ConfiguredAirbyteStream): self.firestore = firestore self.logger = logger - self.airbyte_stream = airbyte_stream + self.collection_name = airbyte_stream.stream.name self.primary_key = config.get("primary_key", "id") self.cursor_field = config.get("cursor_field", "updated_at") self.append_sub_collections = enable_append_sub_collections(config) - self.documents = [] - def get_documents_query(self, collection_name: str, document: 
dict, cursor_value): + def get_documents_query(self, document: dict, cursor_value): firestore = self.firestore cursor_field = self.cursor_field - base_query = firestore.get_documents(collection_name).limit(1000) + base_query = firestore.get_documents(self.collection_name).limit(1000) if cursor_value: start_after = FieldFilter(cursor_field, ">=", DatetimeWithNanoseconds.fromtimestamp(cursor_value["start_at"])) @@ -41,47 +40,49 @@ def get_documents_query(self, collection_name: str, document: dict, cursor_value base_query = base_query.order_by(self.primary_key) if document is not None: - start_after = {self.primary_key: document[self.primary_key]} if document else None - if start_after: + if cursor_value is None: + start_after = {self.primary_key: document[self.primary_key]} if document else None + base_query = base_query.start_after(start_after) + else: + start_after = {self.primary_key: document[self.primary_key], self.cursor_field: document.get(self.cursor_field, None)} base_query = base_query.start_after(start_after) return base_query - def get_sub_collection_documents(self, collection_name, parent_id): + def get_sub_collection_documents(self, parent_id): firestore = self.firestore sub_collections_documents = {} # Fetch documents from sub-collections - for sub_collection in firestore.get_sub_collections(collection_name, str(parent_id)): + for sub_collection in firestore.get_sub_collections(self.collection_name, str(parent_id)): sub_collection_name = sub_collection.id documents = [child_doc.to_dict() for child_doc in sub_collection.stream()] sub_collections_documents[sub_collection_name] = documents return sub_collections_documents - def handle_sub_collections(self, parent_documents: list, collection_name: str): + def handle_sub_collections(self, parent_documents: list): documents = [] for parent_doc in parent_documents: # Fetch nested sub-collections for each parent document - sub_collections_documents = self.get_sub_collection_documents(collection_name, 
parent_doc[self.primary_key]) + sub_collections_documents = self.get_sub_collection_documents(parent_doc[self.primary_key]) documents.append(parent_doc | sub_collections_documents) return documents - def fetch_records(self, start_at=None, cursor_value=None) -> list[dict]: + def fetch_records(self, start_at=None, cursor_value=None, data=None) -> list[dict]: logger = self.logger - collection_name = self.airbyte_stream.stream.name - data = self.documents - base_query = self.get_documents_query(collection_name, start_at, cursor_value) + base_query = self.get_documents_query(start_at, cursor_value) documents = [doc.to_dict() for doc in base_query.stream()] if self.append_sub_collections: - documents = self.handle_sub_collections(documents, collection_name) - data.extend(self.handle_sub_collections(list(documents), collection_name)) - logger.info(f"Fetched {len(documents)} documents. Total documents: {len(data)}") + documents = self.handle_sub_collections(documents) + + data = (data or []) + documents + next_start_at = documents[-1] if documents else None if next_start_at is not None: - logger.info(f"Fetching next batch of documents. Last document: {next_start_at[self.primary_key]}") - return self.fetch_records(next_start_at, cursor_value) + logger.info(f"Fetching next batch of documents. Last document: {next_start_at[self.primary_key]} Total documents: {len(data)}") + return self.fetch_records(next_start_at, cursor_value, data) else: return data