Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(bigquery): add streaming inserts support #1123

Merged
merged 22 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 82 additions & 10 deletions dlt/destinations/impl/bigquery/bigquery.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import functools
import os
from pathlib import Path
from typing import ClassVar, Optional, Sequence, Tuple, List, cast, Dict
from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, cast

import google.cloud.bigquery as bigquery # noqa: I250
from google.api_core import exceptions as api_core_exceptions
from google.cloud import exceptions as gcp_exceptions
from google.api_core import retry
from google.cloud.bigquery.retry import _DEFAULT_RETRY_DEADLINE, _RETRYABLE_REASONS

from dlt.common import json, logger
from dlt.common.destination import DestinationCapabilitiesContext
Expand All @@ -22,8 +25,11 @@
from dlt.common.schema.utils import table_schema_has_type
from dlt.common.storages.file_storage import FileStorage
from dlt.common.typing import DictStrAny
from dlt.destinations.job_impl import DestinationJsonlLoadJob, DestinationParquetLoadJob
from dlt.destinations.sql_client import SqlClientBase
from dlt.destinations.exceptions import (
DestinationSchemaWillNotUpdate,
DestinationTerminalException,
DestinationTransientException,
LoadJobNotExistsException,
LoadJobTerminalException,
Expand All @@ -43,6 +49,7 @@
from dlt.destinations.job_impl import NewReferenceJob
from dlt.destinations.sql_jobs import SqlMergeJob
from dlt.destinations.type_mapping import TypeMapper
from dlt.pipeline.current import destination_state


class BigQueryTypeMapper(TypeMapper):
Expand Down Expand Up @@ -220,13 +227,41 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) ->
job = super().start_file_load(table, file_path, load_id)
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved

if not job:
insert_api = table.get("x-insert-api", "default")
try:
job = BigQueryLoadJob(
FileStorage.get_file_name_from_file_path(file_path),
self._create_load_job(table, file_path),
self.config.http_timeout,
self.config.retry_deadline,
)
if insert_api == "streaming":
if table["write_disposition"] != "append":
raise DestinationTerminalException(
(
"BigQuery streaming insert can only be used with `append` write_disposition, while "
f'the given resource has `{table["write_disposition"]}`.'
)
)
if file_path.endswith(".jsonl"):
job_cls = DestinationJsonlLoadJob
elif file_path.endswith(".parquet"):
job_cls = DestinationParquetLoadJob # type: ignore
else:
raise ValueError(
f"Unsupported file type for BigQuery streaming inserts: {file_path}"
)

job = job_cls(
table,
file_path,
self.config, # type: ignore
self.schema,
destination_state(),
functools.partial(_streaming_load, self.sql_client),
[],
)
else:
job = BigQueryLoadJob(
FileStorage.get_file_name_from_file_path(file_path),
self._create_load_job(table, file_path),
self.config.http_timeout,
self.config.retry_deadline,
)
except api_core_exceptions.GoogleAPICallError as gace:
reason = BigQuerySqlClient._get_reason_from_errors(gace)
if reason == "notFound":
Expand All @@ -243,6 +278,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) ->
) from gace
else:
raise DestinationTransientException(gace) from gace

return job

def _get_table_update_sql(
Expand Down Expand Up @@ -328,9 +364,7 @@ def prepare_load_table(

def _get_column_def_sql(self, column: TColumnSchema, table_format: TTableFormat = None) -> str:
name = self.capabilities.escape_identifier(column["name"])
column_def_sql = (
f"{name} {self.type_mapper.to_db_type(column, table_format)} {self._gen_not_null(column.get('nullable', True))}"
)
column_def_sql = f"{name} {self.type_mapper.to_db_type(column, table_format)} {self._gen_not_null(column.get('nullable', True))}"
if column.get(ROUND_HALF_EVEN_HINT, False):
column_def_sql += " OPTIONS (rounding_mode='ROUND_HALF_EVEN')"
if column.get(ROUND_HALF_AWAY_FROM_ZERO_HINT, False):
Expand Down Expand Up @@ -425,3 +459,41 @@ def _from_db_type(
self, bq_t: str, precision: Optional[int], scale: Optional[int]
) -> TColumnType:
return self.type_mapper.from_db_type(bq_t, precision, scale)


def _streaming_load(
sql_client: SqlClientBase[BigQueryClient], items: List[Dict[Any, Any]], table: Dict[str, Any]
) -> None:
"""
Upload the given items into BigQuery table, using streaming API.
Streaming API is used for small amounts of data, with optimal
batch size equal to 500 rows.

Args:
sql_client (dlt.destinations.impl.bigquery.bigquery.BigQueryClient):
BigQuery client.
items (List[Dict[Any, Any]]): List of rows to upload.
table (Dict[Any, Any]): Table schema.
"""

def _should_retry(exc: api_core_exceptions.GoogleAPICallError) -> bool:
"""Predicate to decide if we need to retry the exception.

Args:
exc (google.api_core.exceptions.GoogleAPICallError):
Exception raised by the client.

Returns:
bool: True if the exception is retryable, False otherwise.
"""
reason = exc.errors[0]["reason"]
return reason in _RETRYABLE_REASONS

full_name = sql_client.make_qualified_table_name(table["name"], escape=False)

bq_client = sql_client._client
bq_client.insert_rows_json(
full_name,
items,
retry=retry.Retry(predicate=_should_retry, deadline=_DEFAULT_RETRY_DEADLINE),
)
18 changes: 17 additions & 1 deletion dlt/destinations/impl/bigquery/bigquery_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def bigquery_adapter(
round_half_even: TColumnNames = None,
table_description: Optional[str] = None,
table_expiration_datetime: Optional[str] = None,
insert_api: Optional[Literal["streaming", "default"]] = None,
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
) -> DltResource:
"""
Prepares data for loading into BigQuery.
Expand All @@ -56,6 +57,11 @@ def bigquery_adapter(
table_description (str, optional): A description for the BigQuery table.
table_expiration_datetime (str, optional): String representing the datetime when the BigQuery table expires.
This is always interpreted as UTC, BigQuery's default.
insert_api (Optional[Literal["streaming", "default"]]): The API to use for inserting data into BigQuery.
If "default" is chosen, the original SQL query mechanism is used.
If "streaming" is chosen, the streaming API (https://cloud.google.com/bigquery/docs/streaming-data-into-bigquery)
is used.
NOTE: due to BigQuery features, streaming insert is only available for `append` write_disposition.

Returns:
A `DltResource` object that is ready to be loaded into BigQuery.
Expand Down Expand Up @@ -134,7 +140,7 @@ def bigquery_adapter(
if not isinstance(table_expiration_datetime, str):
raise ValueError(
"`table_expiration_datetime` must be string representing the datetime when the"
" BigQuery table."
" BigQuery table will be deleted."
)
try:
parsed_table_expiration_datetime = parser.parse(table_expiration_datetime).replace(
Expand All @@ -144,6 +150,16 @@ def bigquery_adapter(
except ValueError as e:
raise ValueError(f"{table_expiration_datetime} could not be parsed!") from e

if insert_api is not None:
if insert_api == "streaming" and data.write_disposition != "append":
raise ValueError(
(
"BigQuery streaming insert can only be used with `append` write_disposition, while "
f"the given resource has `{data.write_disposition}`."
)
)
additional_table_hints |= {"x-insert-api": insert_api} # type: ignore[operator]

if column_hints or additional_table_hints:
resource.apply_hints(columns=column_hints, additional_table_hints=additional_table_hints)
else:
Expand Down
1 change: 1 addition & 0 deletions dlt/destinations/impl/bigquery/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class BigQueryClientConfiguration(DestinationClientDwhWithStagingConfiguration):
retry_deadline: float = (
60.0 # how long to retry the operation in case of error, the backoff 60 s.
)
batch_size: int = 500
rudolfix marked this conversation as resolved.
Show resolved Hide resolved

__config_gen_annotations__: ClassVar[List[str]] = ["location"]

Expand Down
4 changes: 3 additions & 1 deletion dlt/destinations/impl/bigquery/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ class BigQueryDBApiCursorImpl(DBApiCursorImpl):
def df(self, chunk_size: int = None, **kwargs: Any) -> DataFrame:
if chunk_size is not None:
return super().df(chunk_size=chunk_size)
query_job: bigquery.QueryJob = self.native_cursor._query_job
query_job: bigquery.QueryJob = getattr(
self.native_cursor, "_query_job", self.native_cursor.query_job
)

try:
return query_job.to_dataframe(**kwargs)
Expand Down
127 changes: 8 additions & 119 deletions dlt/destinations/impl/destination/destination.py
Original file line number Diff line number Diff line change
@@ -1,140 +1,29 @@
from abc import ABC, abstractmethod
from types import TracebackType
from typing import ClassVar, Dict, Optional, Type, Iterable, Iterable, cast, Dict, List
from copy import deepcopy
from types import TracebackType
from typing import ClassVar, Optional, Type, Iterable, cast, List

from dlt.common.destination.reference import LoadJob
from dlt.destinations.job_impl import EmptyLoadJob
from dlt.common.typing import TDataItems, AnyFun
from dlt.common import json
from dlt.pipeline.current import (
destination_state,
commit_load_package_state,
)
from dlt.common.typing import AnyFun
from dlt.pipeline.current import destination_state
from dlt.common.configuration import create_resolved_partial

from dlt.common.schema import Schema, TTableSchema, TSchemaTables
from dlt.common.schema.typing import TTableSchema
from dlt.common.storages import FileStorage
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import (
TLoadJobState,
LoadJob,
DoNothingJob,
JobClientBase,
)

from dlt.destinations.impl.destination import capabilities
from dlt.destinations.impl.destination.configuration import (
CustomDestinationClientConfiguration,
TDestinationCallable,
from dlt.destinations.impl.destination.configuration import CustomDestinationClientConfiguration
from dlt.destinations.job_impl import (
DestinationJsonlLoadJob,
DestinationParquetLoadJob,
)


class DestinationLoadJob(LoadJob, ABC):
def __init__(
self,
table: TTableSchema,
file_path: str,
config: CustomDestinationClientConfiguration,
schema: Schema,
destination_state: Dict[str, int],
destination_callable: TDestinationCallable,
skipped_columns: List[str],
) -> None:
super().__init__(FileStorage.get_file_name_from_file_path(file_path))
self._file_path = file_path
self._config = config
self._table = table
self._schema = schema
# we create pre_resolved callable here
self._callable = destination_callable
self._state: TLoadJobState = "running"
self._storage_id = f"{self._parsed_file_name.table_name}.{self._parsed_file_name.file_id}"
self.skipped_columns = skipped_columns
try:
if self._config.batch_size == 0:
# on batch size zero we only call the callable with the filename
self.call_callable_with_items(self._file_path)
else:
current_index = destination_state.get(self._storage_id, 0)
for batch in self.run(current_index):
self.call_callable_with_items(batch)
current_index += len(batch)
destination_state[self._storage_id] = current_index

self._state = "completed"
except Exception as e:
self._state = "retry"
raise e
finally:
# save progress
commit_load_package_state()

@abstractmethod
def run(self, start_index: int) -> Iterable[TDataItems]:
pass

def call_callable_with_items(self, items: TDataItems) -> None:
if not items:
return
# call callable
self._callable(items, self._table)

def state(self) -> TLoadJobState:
return self._state

def exception(self) -> str:
raise NotImplementedError()


class DestinationParquetLoadJob(DestinationLoadJob):
def run(self, start_index: int) -> Iterable[TDataItems]:
# stream items
from dlt.common.libs.pyarrow import pyarrow

# guard against changed batch size after restart of loadjob
assert (
start_index % self._config.batch_size
) == 0, "Batch size was changed during processing of one load package"

# on record batches we cannot drop columns, we need to
# select the ones we want to keep
keep_columns = list(self._table["columns"].keys())
start_batch = start_index / self._config.batch_size
with pyarrow.parquet.ParquetFile(self._file_path) as reader:
for record_batch in reader.iter_batches(
batch_size=self._config.batch_size, columns=keep_columns
):
if start_batch > 0:
start_batch -= 1
continue
yield record_batch


class DestinationJsonlLoadJob(DestinationLoadJob):
def run(self, start_index: int) -> Iterable[TDataItems]:
current_batch: TDataItems = []

# stream items
with FileStorage.open_zipsafe_ro(self._file_path) as f:
encoded_json = json.typed_loads(f.read())

for item in encoded_json:
# find correct start position
if start_index > 0:
start_index -= 1
continue
# skip internal columns
for column in self.skipped_columns:
item.pop(column, None)
current_batch.append(item)
if len(current_batch) == self._config.batch_size:
yield current_batch
current_batch = []
yield current_batch


class DestinationClient(JobClientBase):
"""Sink Client"""

Expand Down
Loading
Loading