Skip to content

Commit

Permalink
fix(bigquery): enforce int and float dtypes [TCTC-9419] (#1784)
Browse files Browse the repository at this point in the history
* fix(bigquery): enforce int and float dtypes [TCTC-9419]

Signed-off-by: Luka Peschke <[email protected]>

* fix ci

Signed-off-by: Luka Peschke <[email protected]>

---------

Signed-off-by: Luka Peschke <[email protected]>
  • Loading branch information
lukapeschke authored Oct 4, 2024
1 parent c19dfe7 commit 0b0ac9a
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 29 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ jobs:
ATHENA_ACCESS_KEY_ID: '${{ secrets.ATHENA_ACCESS_KEY_ID }}'
ATHENA_SECRET_ACCESS_KEY: '${{ secrets.ATHENA_SECRET_ACCESS_KEY }}'
ATHENA_REGION: '${{ secrets.ATHENA_REGION }}'
# GBQ
GOOGLE_BIG_QUERY_CREDENTIALS: '${{ secrets.GOOGLE_BIG_QUERY_CREDENTIALS }}'

- name: SonarCloud Scan
# Only executed for one of the tested python version and pandas version
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Fix

- Google BigQuery: If the dtype of a column in the `DataFrame` returned by `_retrieve_data` is `object`,
it gets converted to `Int64` or `float64` when it is defined as a numeric dtype by BigQuery.
- When testing connection, set timeout to 10s when checking if port is opened.

## [7.0.2] 2024-09-13
Expand Down
71 changes: 58 additions & 13 deletions tests/google_big_query/test_google_big_query.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from typing import Any, Callable, Generator
import json
from os import environ
from typing import Any, Generator
from unittest.mock import patch

import numpy as np
import pandas
import pandas as pd
import pytest
import requests
from google.api_core.exceptions import NotFound
from google.cloud import bigquery
from google.cloud.bigquery import ArrayQueryParameter, Client, ScalarQueryParameter
from google.cloud.bigquery.job.query import QueryJob
from google.cloud.bigquery.table import RowIterator
from google.cloud.exceptions import Unauthorized
from google.oauth2.service_account import Credentials
Expand Down Expand Up @@ -178,16 +179,60 @@ def test_http_connect_on_invalid_token(
gbq_connector_with_jwt._get_bigquery_client()


@patch(
"google.cloud.bigquery.table.RowIterator.to_dataframe_iterable",
return_value=iter((pandas.DataFrame({"a": [1, 1], "b": [2, 2]}),)),
)
@patch("google.cloud.bigquery.job.query.QueryJob.result", return_value=RowIterator)
@patch("google.cloud.bigquery.Client.query", return_value=QueryJob)
@patch("google.cloud.bigquery.Client", autospec=True)
def test_execute(client: bigquery.Client, execute: Callable, result: pd.DataFrame, to_dataframe: Callable):
result = GoogleBigQueryConnector._execute_query(client, "SELECT 1 FROM my_table", [])
assert_frame_equal(pandas.DataFrame({"a": [1, 1], "b": [2, 2]}), result)
@pytest.fixture
def gbq_credentials() -> Any:
    """Service-account credentials parsed from the CI environment variable."""
    return json.loads(environ["GOOGLE_BIG_QUERY_CREDENTIALS"])


@pytest.fixture
def gbq_connector(gbq_credentials: Any) -> GoogleBigQueryConnector:
    """A ready-to-use connector wired to the credentials fixture."""
    connector = GoogleBigQueryConnector(name="gqb-test-connector", credentials=gbq_credentials)
    return connector


@pytest.fixture
def gbq_datasource() -> GoogleBigQueryDataSource:
    """A minimal datasource running a constant one-column query."""
    return GoogleBigQueryDataSource(
        name="coucou",
        query="SELECT 1 AS `my_col`;",
        domain="test-domain",
    )


def test_get_df(gbq_connector: GoogleBigQueryConnector, gbq_datasource: GoogleBigQueryDataSource):
    """A constant SELECT should come back as a one-row, one-column DataFrame."""
    df = gbq_connector.get_df(gbq_datasource)
    assert_frame_equal(pandas.DataFrame({"my_col": [1]}), df)


def test_get_df_with_variables(gbq_connector: GoogleBigQueryConnector, gbq_datasource: GoogleBigQueryDataSource):
    """{{variable}} placeholders in the query should be filled in from parameters."""
    gbq_datasource.parameters = {"name": "Superstrong beer"}
    gbq_datasource.query = "SELECT name, price_per_l FROM `beers`.`beers_tiny` WHERE name = {{name}};"
    df = gbq_connector.get_df(gbq_datasource)
    assert_frame_equal(pandas.DataFrame({"name": ["Superstrong beer"], "price_per_l": [0.16]}), df)


def test_get_df_with_type_casts(gbq_connector: GoogleBigQueryConnector, gbq_datasource: GoogleBigQueryDataSource):
    """All-NULL numeric columns should get `Int64`/`float64` dtypes, not `object`."""
    gbq_datasource.parameters = {"name": "Superstrong beer"}
    gbq_datasource.query = """
    WITH with_new_cols AS (
    SELECT *,
    CASE WHEN nullable_name IS NULL THEN NULL ELSE 1 END AS `nullable_int` ,
    CASE WHEN nullable_name IS NULL THEN NULL ELSE 0.5 END AS `nullable_float`,
    FROM `beers`.`beers_tiny` WHERE name = {{name}})
    SELECT name, nullable_name, nullable_int, nullable_float FROM with_new_cols;
    """
    df = gbq_connector.get_df(gbq_datasource)

    expected = pandas.DataFrame(
        {
            "name": ["Superstrong beer"],
            "nullable_name": [None],
            # We should have correct dtypes, not "object"
            "nullable_int": pd.Series([None]).astype("Int64"),
            "nullable_float": pd.Series([None]).astype("float64"),
        }
    )
    assert_frame_equal(expected, df)
    assert df.dtypes["nullable_int"] == pd.Int64Dtype()
    assert df.dtypes["nullable_float"] == np.float64


@patch(
Expand Down
51 changes: 35 additions & 16 deletions toucan_connectors/google_big_query/google_big_query_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from google.cloud.bigquery.dbapi import _helpers as bigquery_helpers
from google.cloud.bigquery.job import QueryJob
from google.oauth2.service_account import Credentials
from pandas.api.types import is_float_dtype, is_integer_dtype

class InvalidJWTToken(GoogleUnauthorized):
"""When the jwt-token is no longer valid (usualy from google as 401)"""
Expand Down Expand Up @@ -128,6 +129,29 @@ def _define_query_param(name: str, value: Any) -> BigQueryParam:
_SAMPLE_QUERY = "Sample BigQuery job"


_GBQ_FLOAT_TYPES = ("FLOAT", "FLOAT64", "NUMERIC", "BIGNUMERIC")
_GBQ_INT_TYPES = ("INTEGER", "INT64")


def _ensure_numeric_columns_dtypes(df: "pd.DataFrame", schema: "list[bigquery.SchemaField]") -> "pd.DataFrame":
"""Ensures that numeric columns have the right dtype.
In some cases (for example all-nulls columns), GBQ will set a numeric columns dtype to `object`,
even if the SQL type is known
"""
for col in schema:
if col.field_type in _GBQ_INT_TYPES:
if not is_integer_dtype(df.dtypes[col.name]):
# WARNING: casing matters here, as Int64 is not int64:
# https://pandas.pydata.org/pandas-docs/version/1.5/user_guide/integer_na.html
df[col.name] = df[col.name].astype("Int64")
elif col.field_type in _GBQ_FLOAT_TYPES:
if not is_float_dtype(df.dtypes[col.name]):
df[col.name] = df[col.name].astype("float64")

return df


class GoogleBigQueryConnector(ToucanConnector, DiscoverableConnector, data_source_model=GoogleBigQueryDataSource):
# for GoogleCredentials
credentials: GoogleCredentials | None = Field(
Expand Down Expand Up @@ -246,14 +270,11 @@ def _execute_query(client: "bigquery.Client", query: str, parameters: list) -> "
try:
start = timer()
query = GoogleBigQueryConnector._clean_query(query)
result_iterator: Iterable[pd.DataFrame] = (
client.query( # type:ignore[assignment]
query,
job_config=bigquery.QueryJobConfig(query_parameters=parameters),
)
.result()
.to_dataframe_iterable()
)
result = client.query(
query,
job_config=bigquery.QueryJobConfig(query_parameters=parameters),
).result()
result_iterator = result.to_dataframe_iterable()
end = timer()
_LOGGER.info(
f"[benchmark][google_big_query] - execute {end - start} seconds",
Expand All @@ -267,7 +288,8 @@ def _execute_query(client: "bigquery.Client", query: str, parameters: list) -> "
)

try:
return pd.concat((df for df in result_iterator), ignore_index=True)
df = pd.concat((df for df in result_iterator), ignore_index=True) # type:ignore[misc]
return _ensure_numeric_columns_dtypes(df, result.schema)
except ValueError as excp: # pragma: no cover
raise NoDataFoundException("No data found, please check your config again.") from excp
except TypeError as e:
Expand Down Expand Up @@ -298,7 +320,7 @@ def _bigquery_variable_transformer(variable: str):
def _bigquery_client_with_google_creds(self) -> "bigquery.Client":
try:
assert self.credentials is not None
credentials = GoogleBigQueryConnector._get_google_credentials(self.credentials, self.scopes)
credentials = self._get_google_credentials(self.credentials, self.scopes)
return self._connect(credentials)
except AssertionError as excp:
raise GoogleClientCreationError from excp
Expand All @@ -309,7 +331,7 @@ def _bigquery_client(self) -> "bigquery.Client":
try:
# We try to instantiate the bigquery.Client with the given jwt-token
session = CustomRequestSession(self.jwt_credentials.jwt_token)
client = GoogleBigQueryConnector._http_connect(http_session=session, project_id=self._get_project_id())
client = self._http_connect(http_session=session, project_id=self._get_project_id())
_LOGGER.debug("BigQuery client created using the provided JWT token")

return client
Expand Down Expand Up @@ -342,12 +364,9 @@ def _get_bigquery_client(self) -> "bigquery.Client":
def _retrieve_data(self, data_source: GoogleBigQueryDataSource) -> "pd.DataFrame":
_LOGGER.debug(f"Play request {data_source.query} with parameters {data_source.parameters}")

query, parameters = GoogleBigQueryConnector._prepare_query_and_parameters(
data_source.query, data_source.parameters
)

query, parameters = self._prepare_query_and_parameters(data_source.query, data_source.parameters)
client = self._get_bigquery_client()
result = GoogleBigQueryConnector._execute_query(client, query, parameters)
result = self._execute_query(client, query, parameters)

return result

Expand Down

0 comments on commit 0b0ac9a

Please sign in to comment.