diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a77e16ff8..0e1bbfa54 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -84,6 +84,8 @@ jobs: ATHENA_ACCESS_KEY_ID: '${{ secrets.ATHENA_ACCESS_KEY_ID }}' ATHENA_SECRET_ACCESS_KEY: '${{ secrets.ATHENA_SECRET_ACCESS_KEY }}' ATHENA_REGION: '${{ secrets.ATHENA_REGION }}' + # GBQ + GOOGLE_BIG_QUERY_CREDENTIALS: '${{ secrets.GOOGLE_BIG_QUERY_CREDENTIALS }}' - name: SonarCloud Scan # Only executed for one of the tested python version and pandas version diff --git a/CHANGELOG.md b/CHANGELOG.md index 0bcbc582c..ea39bfc05 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ ### Fix +- Google BigQuery: If the dtype of a column in the `DataFrame` returned by `_retrive_data` is `object`, + it gets converted to `Int64` or `float64` when it is defined as a numeric dtype by Big Query. - When testing connection, set timeout to 10s when checking if port is opened. ## [7.0.2] 2024-09-13 diff --git a/tests/google_big_query/test_google_big_query.py b/tests/google_big_query/test_google_big_query.py index cf10b8a5d..7feab8d90 100644 --- a/tests/google_big_query/test_google_big_query.py +++ b/tests/google_big_query/test_google_big_query.py @@ -1,14 +1,15 @@ -from typing import Any, Callable, Generator +import json +from os import environ +from typing import Any, Generator from unittest.mock import patch +import numpy as np import pandas import pandas as pd import pytest import requests from google.api_core.exceptions import NotFound -from google.cloud import bigquery from google.cloud.bigquery import ArrayQueryParameter, Client, ScalarQueryParameter -from google.cloud.bigquery.job.query import QueryJob from google.cloud.bigquery.table import RowIterator from google.cloud.exceptions import Unauthorized from google.oauth2.service_account import Credentials @@ -178,16 +179,60 @@ def test_http_connect_on_invalid_token( gbq_connector_with_jwt._get_bigquery_client() -@patch( - "google.cloud.bigquery.table.RowIterator.to_dataframe_iterable", - return_value=iter((pandas.DataFrame({"a": [1, 1], "b": [2, 2]}),)), -) -@patch("google.cloud.bigquery.job.query.QueryJob.result", return_value=RowIterator) -@patch("google.cloud.bigquery.Client.query", return_value=QueryJob) -@patch("google.cloud.bigquery.Client", autospec=True) -def test_execute(client: bigquery.Client, execute: Callable, result: pd.DataFrame, to_dataframe: Callable): - result = GoogleBigQueryConnector._execute_query(client, "SELECT 1 FROM my_table", []) - assert_frame_equal(pandas.DataFrame({"a": [1, 1], "b": [2, 2]}), result) +@pytest.fixture +def gbq_credentials() -> Any: + raw_creds = environ["GOOGLE_BIG_QUERY_CREDENTIALS"] + return json.loads(raw_creds) + + +@pytest.fixture +def gbq_connector(gbq_credentials: Any) -> GoogleBigQueryConnector: + return GoogleBigQueryConnector(name="gqb-test-connector", credentials=gbq_credentials) + + +@pytest.fixture +def gbq_datasource() -> GoogleBigQueryDataSource: + return GoogleBigQueryDataSource(name="coucou", query="SELECT 1 AS `my_col`;", domain="test-domain") + + +def test_get_df(gbq_connector: GoogleBigQueryConnector, gbq_datasource: GoogleBigQueryDataSource): + result = gbq_connector.get_df(gbq_datasource) + expected = pandas.DataFrame({"my_col": [1]}) + assert_frame_equal(expected, result) + + +def test_get_df_with_variables(gbq_connector: GoogleBigQueryConnector, gbq_datasource: GoogleBigQueryDataSource): + gbq_datasource.parameters = {"name": "Superstrong beer"} + gbq_datasource.query = "SELECT name, price_per_l FROM `beers`.`beers_tiny` WHERE name = {{name}};" + result = gbq_connector.get_df(gbq_datasource) + expected = pandas.DataFrame({"name": ["Superstrong beer"], "price_per_l": [0.16]}) + assert_frame_equal(expected, result) + + +def test_get_df_with_type_casts(gbq_connector: GoogleBigQueryConnector, gbq_datasource: GoogleBigQueryDataSource): + gbq_datasource.parameters = {"name": "Superstrong beer"} + gbq_datasource.query = """ + WITH with_new_cols AS ( + SELECT *, + CASE WHEN nullable_name IS NULL THEN NULL ELSE 1 END AS `nullable_int` , + CASE WHEN nullable_name IS NULL THEN NULL ELSE 0.5 END AS `nullable_float`, + FROM `beers`.`beers_tiny` WHERE name = {{name}}) + SELECT name, nullable_name, nullable_int, nullable_float FROM with_new_cols; + """ + result = gbq_connector.get_df(gbq_datasource) + + expected = pandas.DataFrame( + { + "name": ["Superstrong beer"], + "nullable_name": [None], + # We should have correct dtypes, not "object" + "nullable_int": pd.Series([None]).astype("Int64"), + "nullable_float": pd.Series([None]).astype("float64"), + } + ) + assert_frame_equal(expected, result) + assert result.dtypes["nullable_int"] == pd.Int64Dtype() + assert result.dtypes["nullable_float"] == np.float64 @patch( diff --git a/toucan_connectors/google_big_query/google_big_query_connector.py b/toucan_connectors/google_big_query/google_big_query_connector.py index 0d3cdbfe2..c509ff61d 100644 --- a/toucan_connectors/google_big_query/google_big_query_connector.py +++ b/toucan_connectors/google_big_query/google_big_query_connector.py @@ -35,6 +35,7 @@ from google.cloud.bigquery.dbapi import _helpers as bigquery_helpers from google.cloud.bigquery.job import QueryJob from google.oauth2.service_account import Credentials + from pandas.api.types import is_float_dtype, is_integer_dtype class InvalidJWTToken(GoogleUnauthorized): """When the jwt-token is no longer valid (usualy from google as 401)""" @@ -128,6 +129,29 @@ def _define_query_param(name: str, value: Any) -> BigQueryParam: _SAMPLE_QUERY = "Sample BigQuery job" +_GBQ_FLOAT_TYPES = ("FLOAT", "FLOAT64", "NUMERIC", "BIGNUMERIC") +_GBQ_INT_TYPES = ("INTEGER", "INT64") + + +def _ensure_numeric_columns_dtypes(df: "pd.DataFrame", schema: "list[bigquery.SchemaField]") -> "pd.DataFrame": + """Ensures that numeric columns have the right dtype. + + In some cases (for example all-nulls columns), GBQ will set a numeric columns dtype to `object`, + even if the SQL type is known + """ + for col in schema: + if col.field_type in _GBQ_INT_TYPES: + if not is_integer_dtype(df.dtypes[col.name]): + # WARNING: casing matters here, as Int64 is not int64: + # https://pandas.pydata.org/pandas-docs/version/1.5/user_guide/integer_na.html + df[col.name] = df[col.name].astype("Int64") + elif col.field_type in _GBQ_FLOAT_TYPES: + if not is_float_dtype(df.dtypes[col.name]): + df[col.name] = df[col.name].astype("float64") + + return df + + class GoogleBigQueryConnector(ToucanConnector, DiscoverableConnector, data_source_model=GoogleBigQueryDataSource): # for GoogleCredentials credentials: GoogleCredentials | None = Field( @@ -246,14 +270,11 @@ def _execute_query(client: "bigquery.Client", query: str, parameters: list) -> " try: start = timer() query = GoogleBigQueryConnector._clean_query(query) - result_iterator: Iterable[pd.DataFrame] = ( - client.query( # type:ignore[assignment] - query, - job_config=bigquery.QueryJobConfig(query_parameters=parameters), - ) - .result() - .to_dataframe_iterable() - ) + result = client.query( + query, + job_config=bigquery.QueryJobConfig(query_parameters=parameters), + ).result() + result_iterator = result.to_dataframe_iterable() end = timer() _LOGGER.info( f"[benchmark][google_big_query] - execute {end - start} seconds", @@ -267,7 +288,8 @@ def _execute_query(client: "bigquery.Client", query: str, parameters: list) -> " ) try: - return pd.concat((df for df in result_iterator), ignore_index=True) + df = pd.concat((df for df in result_iterator), ignore_index=True) # type:ignore[misc] + return _ensure_numeric_columns_dtypes(df, result.schema) except ValueError as excp: # pragma: no cover raise NoDataFoundException("No data found, please check your config again.") from excp except TypeError as e: @@ -298,7 +320,7 @@ def _bigquery_variable_transformer(variable: str): def _bigquery_client_with_google_creds(self) -> "bigquery.Client": try: assert self.credentials is not None - credentials = GoogleBigQueryConnector._get_google_credentials(self.credentials, self.scopes) + credentials = self._get_google_credentials(self.credentials, self.scopes) return self._connect(credentials) except AssertionError as excp: raise GoogleClientCreationError from excp @@ -309,7 +331,7 @@ def _bigquery_client(self) -> "bigquery.Client": try: # We try to instantiate the bigquery.Client with the given jwt-token session = CustomRequestSession(self.jwt_credentials.jwt_token) - client = GoogleBigQueryConnector._http_connect(http_session=session, project_id=self._get_project_id()) + client = self._http_connect(http_session=session, project_id=self._get_project_id()) _LOGGER.debug("BigQuery client created using the provided JWT token") return client @@ -342,12 +364,9 @@ def _get_bigquery_client(self) -> "bigquery.Client": def _retrieve_data(self, data_source: GoogleBigQueryDataSource) -> "pd.DataFrame": _LOGGER.debug(f"Play request {data_source.query} with parameters {data_source.parameters}") - query, parameters = GoogleBigQueryConnector._prepare_query_and_parameters( - data_source.query, data_source.parameters - ) - + query, parameters = self._prepare_query_and_parameters(data_source.query, data_source.parameters) client = self._get_bigquery_client() - result = GoogleBigQueryConnector._execute_query(client, query, parameters) + result = self._execute_query(client, query, parameters) return result