
feat(bigquery): add streaming inserts support #1123

Merged · 22 commits · Apr 4, 2024 · Changes from 4 commits
1 change: 1 addition & 0 deletions dlt/common/schema/typing.py
@@ -171,6 +171,7 @@ class TTableSchema(TypedDict, total=False):
columns: TTableSchemaColumns
resource: Optional[str]
table_format: Optional[TTableFormat]
insert_api: Optional[Literal["streaming", "default"]]


class TPartialTableSchema(TTableSchema):
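The new `insert_api` hint is a per-table setting, so it can also be attached to a resource directly through `apply_hints` with `additional_table_hints`, mirroring what `bigquery_adapter` does further down in this diff. A minimal sketch, assuming a toy resource named `events` (illustrative, not part of this PR):

```python
import dlt

@dlt.resource(name="events")
def events():
    yield [{"id": 1, "value": "a"}, {"id": 2, "value": "b"}]

# Route this table through the streaming insert path; the "insert_api" key
# matches the TTableSchema field added above.
events.apply_hints(additional_table_hints={"insert_api": "streaming"})
```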
116 changes: 88 additions & 28 deletions dlt/destinations/impl/bigquery/bigquery.py
@@ -1,13 +1,17 @@
import functools
import os
from pathlib import Path
from typing import ClassVar, Optional, Sequence, Tuple, List, cast, Dict

import google.cloud.bigquery as bigquery # noqa: I250
from google.api_core import exceptions as api_core_exceptions
from google.cloud import exceptions as gcp_exceptions
from google.api_core import retry
from google.cloud.bigquery.retry import _DEFAULT_RETRY_DEADLINE, _RETRYABLE_REASONS

from dlt.common import json, logger
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.destinations.job_impl import DestinationJsonlLoadJob, DestinationParquetLoadJob
from dlt.common.destination.reference import (
FollowupJob,
NewLoadJob,
@@ -43,6 +47,7 @@
from dlt.destinations.job_impl import NewReferenceJob
from dlt.destinations.sql_jobs import SqlMergeJob
from dlt.destinations.type_mapping import TypeMapper
from dlt.pipeline.current import destination_state


class BigQueryTypeMapper(TypeMapper):
@@ -217,32 +222,53 @@ def restore_file_load(self, file_path: str) -> LoadJob:
return job

def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
job = super().start_file_load(table, file_path, load_id)

if not job:
try:
job = BigQueryLoadJob(
FileStorage.get_file_name_from_file_path(file_path),
self._create_load_job(table, file_path),
self.config.http_timeout,
self.config.retry_deadline,
insert_api = table.get("insert_api", self.config.loading_api)
if insert_api == "streaming":
if file_path.endswith(".jsonl"):
job_cls = DestinationJsonlLoadJob
elif file_path.endswith(".parquet"):
job_cls = DestinationParquetLoadJob
else:
raise ValueError(
f"Unsupported file type for BigQuery streaming inserts: {file_path}"
)
except api_core_exceptions.GoogleAPICallError as gace:
reason = BigQuerySqlClient._get_reason_from_errors(gace)
if reason == "notFound":
# google.api_core.exceptions.NotFound: 404 – table not found
raise UnknownTableException(table["name"]) from gace
elif (
reason == "duplicate"
): # google.api_core.exceptions.Conflict: 409 PUT – already exists
return self.restore_file_load(file_path)
elif reason in BQ_TERMINAL_REASONS:
# google.api_core.exceptions.BadRequest - will not be processed ie bad job name
raise LoadJobTerminalException(
file_path, f"The server reason was: {reason}"
) from gace
else:
raise DestinationTransientException(gace) from gace

job = job_cls(
table,
file_path,
self.config,
self.schema,
destination_state(),
functools.partial(_streaming_load, self.sql_client),
[],
)
else:
job = super().start_file_load(table, file_path, load_id)

if not job:
try:
job = BigQueryLoadJob(
FileStorage.get_file_name_from_file_path(file_path),
self._create_load_job(table, file_path),
self.config.http_timeout,
self.config.retry_deadline,
)
except api_core_exceptions.GoogleAPICallError as gace:
reason = BigQuerySqlClient._get_reason_from_errors(gace)
if reason == "notFound":
# google.api_core.exceptions.NotFound: 404 – table not found
raise UnknownTableException(table["name"]) from gace
elif (
reason == "duplicate"
): # google.api_core.exceptions.Conflict: 409 PUT – already exists
return self.restore_file_load(file_path)
elif reason in BQ_TERMINAL_REASONS:
# google.api_core.exceptions.BadRequest - will not be processed ie bad job name
raise LoadJobTerminalException(
file_path, f"The server reason was: {reason}"
) from gace
else:
raise DestinationTransientException(gace) from gace
return job
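To make the wiring above easier to follow: on the streaming path the BigQuery client reuses the generic `DestinationJsonlLoadJob` / `DestinationParquetLoadJob` classes (moved out of `destination.py` later in this diff) and hands them a callable built with `functools.partial`. Each batch the job yields then turns into a `_streaming_load(sql_client, items, table)` call. A tiny, self-contained sketch of that contract with stubbed-in values:

```python
import functools

def _streaming_load(sql_client, items, table):
    # stub standing in for the real function defined at the bottom of this file
    print(f"streaming {len(items)} rows into {table['name']} via {sql_client!r}")

# start_file_load binds the sql_client up front ...
callable_ = functools.partial(_streaming_load, "<bq sql client>")

# ... and DestinationLoadJob later calls self._callable(items, self._table)
# for each batch, which expands to _streaming_load(sql_client, items, table):
callable_([{"id": 1}, {"id": 2}], {"name": "events"})
```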

def _get_table_update_sql(
@@ -328,9 +354,7 @@ def prepare_load_table(

def _get_column_def_sql(self, column: TColumnSchema, table_format: TTableFormat = None) -> str:
name = self.capabilities.escape_identifier(column["name"])
column_def_sql = (
f"{name} {self.type_mapper.to_db_type(column, table_format)} {self._gen_not_null(column.get('nullable', True))}"
)
column_def_sql = f"{name} {self.type_mapper.to_db_type(column, table_format)} {self._gen_not_null(column.get('nullable', True))}"
if column.get(ROUND_HALF_EVEN_HINT, False):
column_def_sql += " OPTIONS (rounding_mode='ROUND_HALF_EVEN')"
if column.get(ROUND_HALF_AWAY_FROM_ZERO_HINT, False):
@@ -425,3 +449,39 @@ def _from_db_type(
self, bq_t: str, precision: Optional[int], scale: Optional[int]
) -> TColumnType:
return self.type_mapper.from_db_type(bq_t, precision, scale)


def _streaming_load(sql_client, items, table):
"""
Upload the given items into a BigQuery table using the streaming API.
The streaming API is meant for small amounts of data, with an optimal
batch size of about 500 rows.

Args:
sql_client (BigQuerySqlClient): BigQuery SQL client used to reach the table.
items (List[Dict[Any, Any]]): List of rows to upload.
table (Dict[Any, Any]): Table schema.
"""

def _should_retry(exc):
"""Predicate to decide if we need to retry the exception.

Args:
exc (google.api_core.exceptions.GoogleAPIError):
Exception raised by the client.

Returns:
bool: True if the exception is retryable, False otherwise.
"""
reason = exc.errors[0]["reason"]
return reason in (list(_RETRYABLE_REASONS) + ["notFound"])

full_name = sql_client.make_qualified_table_name(table["name"], escape=False)

bq_client = sql_client._client
bq_client.insert_rows_json(
full_name,
items,
retry=retry.Retry(predicate=_should_retry, deadline=_DEFAULT_RETRY_DEADLINE),
)
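Worth noting for reviewers: besides raising `GoogleAPICallError` on transport failures, `insert_rows_json` returns a list of per-row insert errors, which the call above does not inspect. A hedged, standalone sketch of the same call that also surfaces those row errors (project, dataset, and rows are illustrative):

```python
from google.api_core import retry
from google.cloud import bigquery
from google.cloud.bigquery.retry import _DEFAULT_RETRY_DEADLINE, _RETRYABLE_REASONS


def _should_retry(exc) -> bool:
    # same predicate shape as above: BigQuery's retryable reasons plus notFound
    reason = exc.errors[0]["reason"]
    return reason in (list(_RETRYABLE_REASONS) + ["notFound"])


client = bigquery.Client()
row_errors = client.insert_rows_json(
    "my-project.my_dataset.events",  # illustrative fully qualified table id
    [{"id": 1, "value": "a"}],
    retry=retry.Retry(predicate=_should_retry, deadline=_DEFAULT_RETRY_DEADLINE),
)
if row_errors:
    raise RuntimeError(f"streaming insert reported row errors: {row_errors}")
```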
3 changes: 3 additions & 0 deletions dlt/destinations/impl/bigquery/bigquery_adapter.py
@@ -30,6 +30,7 @@ def bigquery_adapter(
round_half_even: TColumnNames = None,
table_description: Optional[str] = None,
table_expiration_datetime: Optional[str] = None,
insert_api: Optional[Literal["streaming", "default"]] = "default",
) -> DltResource:
"""
Prepares data for loading into BigQuery.
@@ -144,6 +145,8 @@
except ValueError as e:
raise ValueError(f"{table_expiration_datetime} could not be parsed!") from e

additional_table_hints |= {"insert_api": insert_api} # type: ignore[operator]

if column_hints or additional_table_hints:
resource.apply_hints(columns=column_hints, additional_table_hints=additional_table_hints)
else:
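For context, applying the new option from user code looks like any other adapter hint. A minimal sketch (pipeline and resource names are illustrative; the import path is the module shown in this diff):

```python
import dlt
from dlt.destinations.impl.bigquery.bigquery_adapter import bigquery_adapter


@dlt.resource(name="events")
def events():
    yield [{"id": 1, "value": "a"}]


# Send this resource through BigQuery's streaming API instead of load jobs.
bigquery_adapter(events, insert_api="streaming")

pipeline = dlt.pipeline(pipeline_name="bq_streaming_demo", destination="bigquery")
pipeline.run(events)
```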
2 changes: 2 additions & 0 deletions dlt/destinations/impl/bigquery/configuration.py
@@ -19,6 +19,8 @@ class BigQueryClientConfiguration(DestinationClientDwhWithStagingConfiguration):
retry_deadline: float = (
60.0 # how long to retry the operation in case of error; the backoff is 60 s
)
loading_api: str = "default"
batch_size: int = 500

__config_gen_annotations__: ClassVar[List[str]] = ["location"]

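Assuming these fields resolve like other `BigQueryClientConfiguration` options, through dlt's config providers with the usual double-underscore naming (an assumption, not something this diff shows), they could be set per environment rather than in code:

```python
import os

# Assumed env-var equivalents of the new config fields above; adjust the
# section path if your destination is configured under a different name.
os.environ["DESTINATION__BIGQUERY__LOADING_API"] = "streaming"
os.environ["DESTINATION__BIGQUERY__BATCH_SIZE"] = "500"
```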
127 changes: 8 additions & 119 deletions dlt/destinations/impl/destination/destination.py
@@ -1,140 +1,29 @@
from abc import ABC, abstractmethod
from types import TracebackType
from typing import ClassVar, Dict, Optional, Type, Iterable, Iterable, cast, Dict, List
from copy import deepcopy
from types import TracebackType
from typing import ClassVar, Optional, Type, Iterable, cast, List

from dlt.common.destination.reference import LoadJob
from dlt.destinations.job_impl import EmptyLoadJob
from dlt.common.typing import TDataItems, AnyFun
from dlt.common import json
from dlt.pipeline.current import (
destination_state,
commit_load_package_state,
)
from dlt.common.typing import AnyFun
from dlt.pipeline.current import destination_state
from dlt.common.configuration import create_resolved_partial

from dlt.common.schema import Schema, TTableSchema, TSchemaTables
from dlt.common.schema.typing import TTableSchema
from dlt.common.storages import FileStorage
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import (
TLoadJobState,
LoadJob,
DoNothingJob,
JobClientBase,
)

from dlt.destinations.impl.destination import capabilities
from dlt.destinations.impl.destination.configuration import (
CustomDestinationClientConfiguration,
TDestinationCallable,
from dlt.destinations.impl.destination.configuration import CustomDestinationClientConfiguration
from dlt.destinations.job_impl import (
DestinationJsonlLoadJob,
DestinationParquetLoadJob,
)


class DestinationLoadJob(LoadJob, ABC):
def __init__(
self,
table: TTableSchema,
file_path: str,
config: CustomDestinationClientConfiguration,
schema: Schema,
destination_state: Dict[str, int],
destination_callable: TDestinationCallable,
skipped_columns: List[str],
) -> None:
super().__init__(FileStorage.get_file_name_from_file_path(file_path))
self._file_path = file_path
self._config = config
self._table = table
self._schema = schema
# we create pre_resolved callable here
self._callable = destination_callable
self._state: TLoadJobState = "running"
self._storage_id = f"{self._parsed_file_name.table_name}.{self._parsed_file_name.file_id}"
self.skipped_columns = skipped_columns
try:
if self._config.batch_size == 0:
# on batch size zero we only call the callable with the filename
self.call_callable_with_items(self._file_path)
else:
current_index = destination_state.get(self._storage_id, 0)
for batch in self.run(current_index):
self.call_callable_with_items(batch)
current_index += len(batch)
destination_state[self._storage_id] = current_index

self._state = "completed"
except Exception as e:
self._state = "retry"
raise e
finally:
# save progress
commit_load_package_state()

@abstractmethod
def run(self, start_index: int) -> Iterable[TDataItems]:
pass

def call_callable_with_items(self, items: TDataItems) -> None:
if not items:
return
# call callable
self._callable(items, self._table)

def state(self) -> TLoadJobState:
return self._state

def exception(self) -> str:
raise NotImplementedError()


class DestinationParquetLoadJob(DestinationLoadJob):
def run(self, start_index: int) -> Iterable[TDataItems]:
# stream items
from dlt.common.libs.pyarrow import pyarrow

# guard against changed batch size after restart of loadjob
assert (
start_index % self._config.batch_size
) == 0, "Batch size was changed during processing of one load package"

# on record batches we cannot drop columns, we need to
# select the ones we want to keep
keep_columns = list(self._table["columns"].keys())
start_batch = start_index / self._config.batch_size
with pyarrow.parquet.ParquetFile(self._file_path) as reader:
for record_batch in reader.iter_batches(
batch_size=self._config.batch_size, columns=keep_columns
):
if start_batch > 0:
start_batch -= 1
continue
yield record_batch


class DestinationJsonlLoadJob(DestinationLoadJob):
def run(self, start_index: int) -> Iterable[TDataItems]:
current_batch: TDataItems = []

# stream items
with FileStorage.open_zipsafe_ro(self._file_path) as f:
encoded_json = json.typed_loads(f.read())

for item in encoded_json:
# find correct start position
if start_index > 0:
start_index -= 1
continue
# skip internal columns
for column in self.skipped_columns:
item.pop(column, None)
current_batch.append(item)
if len(current_batch) == self._config.batch_size:
yield current_batch
current_batch = []
yield current_batch


class DestinationClient(JobClientBase):
"""Sink Client"""

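The job classes removed here are not dropped; they move to `dlt/destinations/job_impl.py` (see the imports above) so the BigQuery client can reuse them. What they implement is a resumable batching pattern: progress per storage id is kept in the load-package state, and `run(start_index)` skips rows that were already delivered before a restart. A stripped-down, dlt-independent sketch of that pattern:

```python
from typing import Callable, Dict, List


def run_resumable(
    rows: List[dict],
    state: Dict[str, int],
    storage_id: str,
    batch_size: int,
    send_batch: Callable[[List[dict]], None],
) -> None:
    """Deliver rows in batches, recording progress so a retry resumes mid-file."""
    start_index = state.get(storage_id, 0)
    batch: List[dict] = []
    for i, row in enumerate(rows):
        if i < start_index:  # already delivered before the restart
            continue
        batch.append(row)
        if len(batch) == batch_size:
            send_batch(batch)
            state[storage_id] = i + 1  # commit progress after each full batch
            batch = []
    if batch:
        send_batch(batch)
        state[storage_id] = len(rows)


# illustrative usage: delivers five rows in batches of two, tracking progress
state: Dict[str, int] = {}
run_resumable([{"id": n} for n in range(5)], state, "events.0001", 2, print)
```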