Skip to content

Commit

Permalink
feat: Add support for array and float32 SQL query params
Browse files Browse the repository at this point in the history
  • Loading branch information
jackdingilian committed Jan 17, 2025
1 parent e7ecfeb commit 23689d2
Show file tree
Hide file tree
Showing 8 changed files with 379 additions and 36 deletions.
2 changes: 2 additions & 0 deletions google/cloud/bigtable/data/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,8 @@ async def execute_query(
will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions
from any retries that failed
google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error
google.cloud.bigtable.data.exceptions.ParameterTypeInferenceFailed: Raised if
a parameter is passed without an explicit type, and the type cannot be infered
"""
warnings.warn(
"ExecuteQuery is in preview and may change in the future.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
Tuple,
TYPE_CHECKING,
)

from google.api_core import retry as retries

from google.cloud.bigtable.data.execute_query._byte_cursor import _ByteCursor
Expand Down Expand Up @@ -116,7 +115,6 @@ def __init__(
exception_factory=_retry_exception_factory,
)
self._req_metadata = req_metadata

try:
self._register_instance_task = CrossSync.create_task(
self._client._register_instance,
Expand Down
17 changes: 14 additions & 3 deletions google/cloud/bigtable/data/execute_query/_parameters_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional
import datetime
from typing import Any, Dict, Optional

from google.api_core.datetime_helpers import DatetimeWithNanoseconds

from google.cloud.bigtable.data.exceptions import ParameterTypeInferenceFailed
from google.cloud.bigtable.data.execute_query.values import ExecuteQueryValueType
from google.cloud.bigtable.data.execute_query.metadata import SqlType
from google.cloud.bigtable.data.execute_query.values import ExecuteQueryValueType


def _format_execute_query_params(
Expand Down Expand Up @@ -48,7 +50,6 @@ def _format_execute_query_params(
parameter_types = parameter_types or {}

result_values = {}

for key, value in params.items():
user_provided_type = parameter_types.get(key)
try:
Expand Down Expand Up @@ -109,6 +110,16 @@ def _detect_type(value: ExecuteQueryValueType) -> SqlType.Type:
"Cannot infer type of None, please provide the type manually."
)

if isinstance(value, list):
raise ParameterTypeInferenceFailed(
"Cannot infer type of ARRAY parameters, please provide the type manually."
)

if isinstance(value, float):
raise ParameterTypeInferenceFailed(
"Cannot infer type of float, must specify either FLOAT32 or FLOAT64 type manually."
)

for field_type, type_dict in _TYPES_TO_TYPE_DICTS:
if isinstance(value, field_type):
return type_dict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
SqlType.Bytes: "bytes_value",
SqlType.String: "string_value",
SqlType.Int64: "int_value",
SqlType.Float32: "float_value",
SqlType.Float64: "float_value",
SqlType.Bool: "bool_value",
SqlType.Timestamp: "timestamp_value",
Expand Down
46 changes: 30 additions & 16 deletions google/cloud/bigtable/data/execute_query/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,16 @@
"""

from collections import defaultdict
from typing import (
Optional,
List,
Dict,
Set,
Type,
Union,
Tuple,
Any,
)
import datetime
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

from google.api_core.datetime_helpers import DatetimeWithNanoseconds
from google.protobuf import timestamp_pb2 # type: ignore
from google.type import date_pb2 # type: ignore

from google.cloud.bigtable.data.execute_query.values import _NamedList
from google.cloud.bigtable_v2 import ResultSetMetadata
from google.cloud.bigtable_v2 import Type as PBType
from google.type import date_pb2 # type: ignore
from google.protobuf import timestamp_pb2 # type: ignore
from google.api_core.datetime_helpers import DatetimeWithNanoseconds
import datetime


class SqlType:
Expand Down Expand Up @@ -127,6 +120,8 @@ class Array(Type):
def __init__(self, element_type: "SqlType.Type"):
if isinstance(element_type, SqlType.Array):
raise ValueError("Arrays of arrays are not supported.")
if isinstance(element_type, SqlType.Map):
raise ValueError("Arrays of Maps are not supported.")
self._element_type = element_type

@property
Expand All @@ -140,10 +135,21 @@ def from_pb_type(cls, type_pb: Optional[PBType] = None) -> "SqlType.Array":
return cls(_pb_type_to_metadata_type(type_pb.array_type.element_type))

def _to_value_pb_dict(self, value: Any):
raise NotImplementedError("Array is not supported as a query parameter")
if value is None:
return {}

return {
"array_value": {
"values": [
self.element_type._to_value_pb_dict(entry) for entry in value
]
}
}

def _to_type_pb_dict(self) -> Dict[str, Any]:
raise NotImplementedError("Array is not supported as a query parameter")
return {
"array_type": {"element_type": self.element_type._to_type_pb_dict()}
}

def __eq__(self, other):
return super().__eq__(other) and self.element_type == other.element_type
Expand Down Expand Up @@ -222,6 +228,13 @@ class Float64(Type):
value_pb_dict_field_name = "float_value"
type_field_name = "float64_type"

class Float32(Type):
"""Float32 SQL type."""

expected_type = float
value_pb_dict_field_name = "float_value"
type_field_name = "float32_type"

class Bool(Type):
"""Bool SQL type."""

Expand Down Expand Up @@ -376,6 +389,7 @@ def _pb_metadata_to_metadata_types(
"bytes_type": SqlType.Bytes,
"string_type": SqlType.String,
"int64_type": SqlType.Int64,
"float32_type": SqlType.Float32,
"float64_type": SqlType.Float64,
"bool_type": SqlType.Bool,
"timestamp_type": SqlType.Timestamp,
Expand Down
83 changes: 83 additions & 0 deletions tests/system/data/test_system_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@

import pytest
import asyncio
import datetime
import uuid
import os
from google.api_core import retry
from google.api_core.exceptions import ClientError

from google.cloud.bigtable.data.execute_query.metadata import SqlType
from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE
from google.cloud.environment_vars import BIGTABLE_EMULATOR
from google.type import date_pb2

from google.cloud.bigtable.data._cross_sync import CrossSync

Expand Down Expand Up @@ -1027,3 +1030,83 @@ async def test_execute_query_simple(self, client, table_id, instance_id):
row = rows[0]
assert row["a"] == 1
assert row["b"] == "foo"

@CrossSync.pytest
@pytest.mark.usefixtures("client")
@CrossSync.Retry(
predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
)
async def test_execute_query_params(self, client, table_id, instance_id):
query = (
"SELECT @stringParam AS strCol, @bytesParam as bytesCol, @int64Param AS intCol, "
"@float32Param AS float32Col, @float64Param AS float64Col, @boolParam AS boolCol, "
"@tsParam AS tsCol, @dateParam AS dateCol, @byteArrayParam AS byteArrayCol, "
"@stringArrayParam AS stringArrayCol, @intArrayParam AS intArrayCol, "
"@float32ArrayParam AS float32ArrayCol, @float64ArrayParam AS float64ArrayCol, "
"@boolArrayParam AS boolArrayCol, @tsArrayParam AS tsArrayCol, "
"@dateArrayParam AS dateArrayCol"
)
parameters = {
"stringParam": "foo",
"bytesParam": b"bar",
"int64Param": 12,
"float32Param": 1.1,
"float64Param": 1.2,
"boolParam": True,
"tsParam": datetime.datetime.fromtimestamp(1000, tz=datetime.timezone.utc),
"dateParam": datetime.date(2025, 1, 16),
"byteArrayParam": [b"foo", b"bar", None],
"stringArrayParam": ["foo", "bar", None],
"intArrayParam": [1, None, 2],
"float32ArrayParam": [1.2, None, 1.3],
"float64ArrayParam": [1.4, None, 1.5],
"boolArrayParam": [None, False, True],
"tsArrayParam": [
datetime.datetime.fromtimestamp(1000, tz=datetime.timezone.utc),
datetime.datetime.fromtimestamp(2000, tz=datetime.timezone.utc),
None,
],
"dateArrayParam": [
datetime.date(2025, 1, 16),
datetime.date(2025, 1, 17),
None,
],
}
param_types = {
"float32Param": SqlType.Float32(),
"float64Param": SqlType.Float64(),
"byteArrayParam": SqlType.Array(SqlType.Bytes()),
"stringArrayParam": SqlType.Array(SqlType.String()),
"intArrayParam": SqlType.Array(SqlType.Int64()),
"float32ArrayParam": SqlType.Array(SqlType.Float32()),
"float64ArrayParam": SqlType.Array(SqlType.Float64()),
"boolArrayParam": SqlType.Array(SqlType.Bool()),
"tsArrayParam": SqlType.Array(SqlType.Timestamp()),
"dateArrayParam": SqlType.Array(SqlType.Date()),
}
result = await client.execute_query(
query, instance_id, parameters=parameters, parameter_types=param_types
)
rows = [r async for r in result]
assert len(rows) == 1
row = rows[0]
assert row["strCol"] == parameters["stringParam"]
assert row["bytesCol"] == parameters["bytesParam"]
assert row["intCol"] == parameters["int64Param"]
assert row["float32Col"] == pytest.approx(parameters["float32Param"])
assert row["float64Col"] == pytest.approx(parameters["float64Param"])
assert row["boolCol"] == parameters["boolParam"]
assert row["tsCol"] == parameters["tsParam"]
assert row["dateCol"] == date_pb2.Date(year=2025, month=1, day=16)
assert row["stringArrayCol"] == parameters["stringArrayParam"]
assert row["byteArrayCol"] == parameters["byteArrayParam"]
assert row["intArrayCol"] == parameters["intArrayParam"]
assert row["float32ArrayCol"] == pytest.approx(parameters["float32ArrayParam"])
assert row["float64ArrayCol"] == pytest.approx(parameters["float64ArrayParam"])
assert row["boolArrayCol"] == parameters["boolArrayParam"]
assert row["tsArrayCol"] == parameters["tsArrayParam"]
assert row["dateArrayCol"] == [
date_pb2.Date(year=2025, month=1, day=16),
date_pb2.Date(year=2025, month=1, day=17),
None,
]
74 changes: 74 additions & 0 deletions tests/system/data/test_system_autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
# This file is automatically generated by CrossSync. Do not edit manually.

import pytest
import datetime
import uuid
import os
from google.api_core import retry
from google.api_core.exceptions import ClientError
from google.cloud.bigtable.data.execute_query.metadata import SqlType
from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE
from google.cloud.environment_vars import BIGTABLE_EMULATOR
from google.type import date_pb2
from google.cloud.bigtable.data._cross_sync import CrossSync
from . import TEST_FAMILY, TEST_FAMILY_2

Expand Down Expand Up @@ -838,3 +841,74 @@ def test_execute_query_simple(self, client, table_id, instance_id):
row = rows[0]
assert row["a"] == 1
assert row["b"] == "foo"

@pytest.mark.usefixtures("client")
@CrossSync._Sync_Impl.Retry(
predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
)
def test_execute_query_params(self, client, table_id, instance_id):
query = "SELECT @stringParam AS strCol, @bytesParam as bytesCol, @int64Param AS intCol, @float32Param AS float32Col, @float64Param AS float64Col, @boolParam AS boolCol, @tsParam AS tsCol, @dateParam AS dateCol, @byteArrayParam AS byteArrayCol, @stringArrayParam AS stringArrayCol, @intArrayParam AS intArrayCol, @float32ArrayParam AS float32ArrayCol, @float64ArrayParam AS float64ArrayCol, @boolArrayParam AS boolArrayCol, @tsArrayParam AS tsArrayCol, @dateArrayParam AS dateArrayCol"
parameters = {
"stringParam": "foo",
"bytesParam": b"bar",
"int64Param": 12,
"float32Param": 1.1,
"float64Param": 1.2,
"boolParam": True,
"tsParam": datetime.datetime.fromtimestamp(1000, tz=datetime.timezone.utc),
"dateParam": datetime.date(2025, 1, 16),
"byteArrayParam": [b"foo", b"bar", None],
"stringArrayParam": ["foo", "bar", None],
"intArrayParam": [1, None, 2],
"float32ArrayParam": [1.2, None, 1.3],
"float64ArrayParam": [1.4, None, 1.5],
"boolArrayParam": [None, False, True],
"tsArrayParam": [
datetime.datetime.fromtimestamp(1000, tz=datetime.timezone.utc),
datetime.datetime.fromtimestamp(2000, tz=datetime.timezone.utc),
None,
],
"dateArrayParam": [
datetime.date(2025, 1, 16),
datetime.date(2025, 1, 17),
None,
],
}
param_types = {
"float32Param": SqlType.Float32(),
"float64Param": SqlType.Float64(),
"byteArrayParam": SqlType.Array(SqlType.Bytes()),
"stringArrayParam": SqlType.Array(SqlType.String()),
"intArrayParam": SqlType.Array(SqlType.Int64()),
"float32ArrayParam": SqlType.Array(SqlType.Float32()),
"float64ArrayParam": SqlType.Array(SqlType.Float64()),
"boolArrayParam": SqlType.Array(SqlType.Bool()),
"tsArrayParam": SqlType.Array(SqlType.Timestamp()),
"dateArrayParam": SqlType.Array(SqlType.Date()),
}
result = client.execute_query(
query, instance_id, parameters=parameters, parameter_types=param_types
)
rows = [r for r in result]
assert len(rows) == 1
row = rows[0]
assert row["strCol"] == parameters["stringParam"]
assert row["bytesCol"] == parameters["bytesParam"]
assert row["intCol"] == parameters["int64Param"]
assert row["float32Col"] == pytest.approx(parameters["float32Param"])
assert row["float64Col"] == pytest.approx(parameters["float64Param"])
assert row["boolCol"] == parameters["boolParam"]
assert row["tsCol"] == parameters["tsParam"]
assert row["dateCol"] == date_pb2.Date(year=2025, month=1, day=16)
assert row["stringArrayCol"] == parameters["stringArrayParam"]
assert row["byteArrayCol"] == parameters["byteArrayParam"]
assert row["intArrayCol"] == parameters["intArrayParam"]
assert row["float32ArrayCol"] == pytest.approx(parameters["float32ArrayParam"])
assert row["float64ArrayCol"] == pytest.approx(parameters["float64ArrayParam"])
assert row["boolArrayCol"] == parameters["boolArrayParam"]
assert row["tsArrayCol"] == parameters["tsArrayParam"]
assert row["dateArrayCol"] == [
date_pb2.Date(year=2025, month=1, day=16),
date_pb2.Date(year=2025, month=1, day=17),
None,
]
Loading

0 comments on commit 23689d2

Please sign in to comment.