This repository has been archived by the owner on Dec 10, 2024. It is now read-only.

Commit
WIP: Incremental Sync State Message Fix
michelledv01 committed Aug 10, 2023
1 parent 2652870 commit 285b54a
Showing 2 changed files with 34 additions and 22 deletions.
First changed file (incremental sync stream):
@@ -5,7 +5,7 @@

 from airbyte_cdk import AirbyteLogger
 from airbyte_cdk.sources.streams import IncrementalMixin
-from airbyte_protocol.models import ConfiguredAirbyteStream, Type, AirbyteMessage
+from airbyte_protocol.models import ConfiguredAirbyteStream, Type, AirbyteMessage, AirbyteStateMessage

 from source_google_firestore.AirbyteHelpers import AirbyteHelpers
 from source_google_firestore.FirestoreSource import FirestoreSource
@@ -17,6 +17,7 @@ def __init__(self, firestore: FirestoreSource, logger: AirbyteLogger, config: js
         self.query = QueryHelpers(firestore, logger, config, airbyte_stream)
         self.logger = logger
         self.airbyte = AirbyteHelpers(airbyte_stream, config)
+        self.collection_name = airbyte_stream.stream.name
         self.cursor_field = config.get("cursor_field", "updated_at")
         self._cursor_value = None

@@ -41,6 +42,16 @@ def chunk_time(self, last_updated_at):
             last_updated_at += timedelta(minutes=1)
         return timeframes

+    def get_cursor_value(self, documents: list[dict]):
+        max_value = None
+
+        for doc in documents:
+            field_value = doc.get(self.cursor_field, None)
+            if field_value is not None and field_value != "":
+                if max_value is None or field_value > max_value:
+                    max_value = field_value
+        return max_value
+
     def stream(self, state):
         airbyte = self.airbyte
         query = self.query
@@ -53,12 +64,12 @@ def stream(self, state):
             for airbyte_message in airbyte.send_airbyte_message(documents):
                 yield airbyte_message
             self._cursor_value = datetime.utcnow()
+            yield AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(data=self.state))
         else:
             for timeframe in timeframes:
                 self.logger.info(f"Fetching documents from {timeframe['start_at']} to {timeframe['end_at']}")
                 documents: list[dict] = query.fetch_records(cursor_value=timeframe)
                 self.logger.info(f"Finished fetching documents. Total documents: {len(documents)}")
                 for airbyte_message in airbyte.send_airbyte_message(documents):
                     yield airbyte_message
                 self._cursor_value = timeframe[self.cursor_field]
-        yield AirbyteMessage(type=Type.STATE, state=self.state)
+        yield AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(data=self.state))
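
The substantive fix in this file: stream() previously yielded AirbyteMessage(type=Type.STATE, state=self.state), passing a plain dict where the protocol model expects an AirbyteStateMessage; the commit wraps the dict at both emission points (the one in the first branch is newly added). The new get_cursor_value helper is not called from any shown hunk, consistent with the WIP label. Below is a minimal sketch of the corrected emission pattern, using a hypothetical emit_state helper and assuming the stream's state property returns a dict keyed by the cursor field (that property itself is not shown in the diff):

# Sketch only, not from the commit: why the wrapper matters. The protocol's
# AirbyteMessage model expects its "state" field to be an AirbyteStateMessage,
# not a bare dict, so the cursor value must be wrapped to land in the
# message's data field.
from airbyte_protocol.models import AirbyteMessage, AirbyteStateMessage, Type

def emit_state(cursor_field: str, cursor_value) -> AirbyteMessage:  # hypothetical helper
    state_blob = {cursor_field: cursor_value}  # e.g. {"updated_at": "2023-08-10T12:00:00"}
    # Before this commit: AirbyteMessage(type=Type.STATE, state=state_blob) -- a raw dict.
    # After: wrap it in the protocol model.
    return AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(data=state_blob))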
Second changed file (class QueryHelpers):
@@ -22,16 +22,15 @@ class QueryHelpers:
     def __init__(self, firestore: FirestoreSource, logger: AirbyteLogger, config: json, airbyte_stream: ConfiguredAirbyteStream):
         self.firestore = firestore
         self.logger = logger
-        self.airbyte_stream = airbyte_stream
+        self.collection_name = airbyte_stream.stream.name
         self.primary_key = config.get("primary_key", "id")
         self.cursor_field = config.get("cursor_field", "updated_at")
         self.append_sub_collections = enable_append_sub_collections(config)
-        self.documents = []

-    def get_documents_query(self, collection_name: str, document: dict, cursor_value):
+    def get_documents_query(self, document: dict, cursor_value):
         firestore = self.firestore
         cursor_field = self.cursor_field
-        base_query = firestore.get_documents(collection_name).limit(1000)
+        base_query = firestore.get_documents(self.collection_name).limit(1000)

         if cursor_value:
             start_after = FieldFilter(cursor_field, ">=", DatetimeWithNanoseconds.fromtimestamp(cursor_value["start_at"]))
@@ -41,47 +40,49 @@ def get_documents_query(self, collection_name: str, document: dict, cursor_value
         base_query = base_query.order_by(self.primary_key)

         if document is not None:
-            start_after = {self.primary_key: document[self.primary_key]} if document else None
-            if start_after:
+            if cursor_value is None:
+                start_after = {self.primary_key: document[self.primary_key]} if document else None
                 base_query = base_query.start_after(start_after)
+            else:
+                start_after = {self.primary_key: document[self.primary_key], self.cursor_field: document.get(self.cursor_field, None)}
+                base_query = base_query.start_after(start_after)

         return base_query

-    def get_sub_collection_documents(self, collection_name, parent_id):
+    def get_sub_collection_documents(self, parent_id):
         firestore = self.firestore
         sub_collections_documents = {}
         # Fetch documents from sub-collections
-        for sub_collection in firestore.get_sub_collections(collection_name, str(parent_id)):
+        for sub_collection in firestore.get_sub_collections(self.collection_name, str(parent_id)):
             sub_collection_name = sub_collection.id
             documents = [child_doc.to_dict() for child_doc in sub_collection.stream()]
             sub_collections_documents[sub_collection_name] = documents

         return sub_collections_documents
-    def handle_sub_collections(self, parent_documents: list, collection_name: str):
+    def handle_sub_collections(self, parent_documents: list):
         documents = []
         for parent_doc in parent_documents:
             # Fetch nested sub-collections for each parent document
-            sub_collections_documents = self.get_sub_collection_documents(collection_name, parent_doc[self.primary_key])
+            sub_collections_documents = self.get_sub_collection_documents(parent_doc[self.primary_key])
             documents.append(parent_doc | sub_collections_documents)
         return documents

-    def fetch_records(self, start_at=None, cursor_value=None) -> list[dict]:
+    def fetch_records(self, start_at=None, cursor_value=None, data=[]) -> list[dict]:
         logger = self.logger
-        collection_name = self.airbyte_stream.stream.name
-        data = self.documents

-        base_query = self.get_documents_query(collection_name, start_at, cursor_value)
+        base_query = self.get_documents_query(start_at, cursor_value)
         documents = [doc.to_dict() for doc in base_query.stream()]

         if self.append_sub_collections:
-            documents = self.handle_sub_collections(documents, collection_name)
-        data.extend(self.handle_sub_collections(list(documents), collection_name))
-        logger.info(f"Fetched {len(documents)} documents. Total documents: {len(data)}")
+            documents = self.handle_sub_collections(documents)
+
+        data.extend(documents)

         next_start_at = documents[-1] if documents else None

         if next_start_at is not None:
-            logger.info(f"Fetching next batch of documents. Last document: {next_start_at[self.primary_key]}")
-            return self.fetch_records(next_start_at, cursor_value)
+            logger.info(f"Fetching next batch of documents. Last document: {next_start_at[self.primary_key]} Total documents: {len(data)}")
+            return self.fetch_records(next_start_at, cursor_value, data)
         else:
             return data
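
In QueryHelpers the commit drops the instance-level accumulator (self.documents) that fetch_records previously drained into: because the list lived on the instance, every fetch_records call in the timeframe loop of stream() kept appending to the same list, so each batch re-emitted all earlier documents. The fix threads a data parameter through the recursion instead, and hoists collection_name onto self so the helpers no longer pass it around. One caveat: data=[] is a mutable default argument, evaluated once at definition time, so top-level calls that rely on the default still share a single list across calls. A defensive sketch of the same pattern, assuming a data=None signature is acceptable:

# Sketch only, not the committed code: the same recursive pagination with
# data=None, so each top-level call starts from a fresh accumulator.
def fetch_records(self, start_at=None, cursor_value=None, data=None) -> list[dict]:
    data = [] if data is None else data  # avoid the shared mutable default

    base_query = self.get_documents_query(start_at, cursor_value)
    documents = [doc.to_dict() for doc in base_query.stream()]

    if self.append_sub_collections:
        documents = self.handle_sub_collections(documents)

    data.extend(documents)

    next_start_at = documents[-1] if documents else None
    if next_start_at is not None:
        # Pass the accumulator explicitly so nothing persists on the instance.
        return self.fetch_records(next_start_at, cursor_value, data)
    return data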
