Skip to content

Commit

Permalink
feat: Add support for array and float32 SQL query params
Browse files Browse the repository at this point in the history
  • Loading branch information
jackdingilian committed Jan 17, 2025
1 parent e7ecfeb commit 2dee92c
Show file tree
Hide file tree
Showing 6 changed files with 303 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
Tuple,
TYPE_CHECKING,
)

from google.api_core import retry as retries

from google.cloud.bigtable.data.execute_query._byte_cursor import _ByteCursor
Expand Down Expand Up @@ -116,7 +115,6 @@ def __init__(
exception_factory=_retry_exception_factory,
)
self._req_metadata = req_metadata

try:
self._register_instance_task = CrossSync.create_task(
self._client._register_instance,
Expand Down
17 changes: 14 additions & 3 deletions google/cloud/bigtable/data/execute_query/_parameters_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional
import datetime
from typing import Any, Dict, Optional

from google.api_core.datetime_helpers import DatetimeWithNanoseconds

from google.cloud.bigtable.data.exceptions import ParameterTypeInferenceFailed
from google.cloud.bigtable.data.execute_query.values import ExecuteQueryValueType
from google.cloud.bigtable.data.execute_query.metadata import SqlType
from google.cloud.bigtable.data.execute_query.values import ExecuteQueryValueType


def _format_execute_query_params(
Expand Down Expand Up @@ -48,7 +50,6 @@ def _format_execute_query_params(
parameter_types = parameter_types or {}

result_values = {}

for key, value in params.items():
user_provided_type = parameter_types.get(key)
try:
Expand Down Expand Up @@ -109,6 +110,16 @@ def _detect_type(value: ExecuteQueryValueType) -> SqlType.Type:
"Cannot infer type of None, please provide the type manually."
)

if isinstance(value, list):
raise ParameterTypeInferenceFailed(
"Cannot infer type of ARRAY parameters, please provide the type manually."
)

if isinstance(value, float):
raise ParameterTypeInferenceFailed(
"Cannot infer type of float, must specify either FLOAT32 or FLOAT64 type manually."
)

for field_type, type_dict in _TYPES_TO_TYPE_DICTS:
if isinstance(value, field_type):
return type_dict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
SqlType.Bytes: "bytes_value",
SqlType.String: "string_value",
SqlType.Int64: "int_value",
SqlType.Float32: "float_value",
SqlType.Float64: "float_value",
SqlType.Bool: "bool_value",
SqlType.Timestamp: "timestamp_value",
Expand Down
46 changes: 30 additions & 16 deletions google/cloud/bigtable/data/execute_query/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,16 @@
"""

from collections import defaultdict
from typing import (
Optional,
List,
Dict,
Set,
Type,
Union,
Tuple,
Any,
)
import datetime
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

from google.api_core.datetime_helpers import DatetimeWithNanoseconds
from google.protobuf import timestamp_pb2 # type: ignore
from google.type import date_pb2 # type: ignore

from google.cloud.bigtable.data.execute_query.values import _NamedList
from google.cloud.bigtable_v2 import ResultSetMetadata
from google.cloud.bigtable_v2 import Type as PBType
from google.type import date_pb2 # type: ignore
from google.protobuf import timestamp_pb2 # type: ignore
from google.api_core.datetime_helpers import DatetimeWithNanoseconds
import datetime


class SqlType:
Expand Down Expand Up @@ -127,6 +120,8 @@ class Array(Type):
def __init__(self, element_type: "SqlType.Type"):
    """Create an array type whose entries all share ``element_type``.

    Args:
        element_type: the SQL type of each entry in the array.

    Raises:
        ValueError: if ``element_type`` is itself an ``Array`` or a ``Map``.
    """
    # Element kinds that may not be nested in an array, checked in order,
    # each paired with the message reported to the caller.
    rejected_kinds = (
        (SqlType.Array, "Arrays of arrays are not supported."),
        (SqlType.Map, "Arrays of Maps are not supported."),
    )
    for kind, message in rejected_kinds:
        if isinstance(element_type, kind):
            raise ValueError(message)
    self._element_type = element_type

@property
Expand All @@ -140,10 +135,21 @@ def from_pb_type(cls, type_pb: Optional[PBType] = None) -> "SqlType.Array":
return cls(_pb_type_to_metadata_type(type_pb.array_type.element_type))

def _to_value_pb_dict(self, value: Any):
    """Build the value dict for an array query parameter.

    Each entry is encoded by the element type's own ``_to_value_pb_dict``,
    so a ``None`` entry is encoded however the element type encodes ``None``.

    Args:
        value: an iterable of entries for the array, or ``None``.

    Returns:
        An empty dict when ``value`` is ``None``; otherwise a dict of the
        form ``{"array_value": {"values": [...]}}`` with one encoded entry
        per input entry, in the same order.
    """
    # A missing array is represented by an empty value dict.
    if value is None:
        return {}

    encode_entry = self.element_type._to_value_pb_dict
    return {"array_value": {"values": [encode_entry(entry) for entry in value]}}

def _to_type_pb_dict(self) -> Dict[str, Any]:
    """Build the type dict describing this array type.

    Returns:
        A dict of the form ``{"array_type": {"element_type": ...}}`` where
        the inner value is whatever the element type's own
        ``_to_type_pb_dict`` returns.
    """
    element_type_dict = self.element_type._to_type_pb_dict()
    return {"array_type": {"element_type": element_type_dict}}

def __eq__(self, other):
return super().__eq__(other) and self.element_type == other.element_type
Expand Down Expand Up @@ -222,6 +228,13 @@ class Float64(Type):
value_pb_dict_field_name = "float_value"
type_field_name = "float64_type"

class Float32(Type):
    """Float32 SQL type.

    Declares the same Python value type and the same value field name as
    ``Float64``; only the type field name differs.
    """

    # Key naming this type in a type dict (presumably consumed by the
    # ``Type`` base class -- confirm against its implementation).
    type_field_name = "float32_type"
    # Key under which the payload is stored in a value dict.
    value_pb_dict_field_name = "float_value"
    # Python type declared for values of this SQL type.
    expected_type = float

class Bool(Type):
"""Bool SQL type."""

Expand Down Expand Up @@ -376,6 +389,7 @@ def _pb_metadata_to_metadata_types(
"bytes_type": SqlType.Bytes,
"string_type": SqlType.String,
"int64_type": SqlType.Int64,
"float32_type": SqlType.Float32,
"float64_type": SqlType.Float64,
"bool_type": SqlType.Bool,
"timestamp_type": SqlType.Timestamp,
Expand Down
83 changes: 83 additions & 0 deletions tests/system/data/test_system_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@

import pytest
import asyncio
import datetime
import uuid
import os
from google.api_core import retry
from google.api_core.exceptions import ClientError

from google.cloud.bigtable.data.execute_query.metadata import SqlType
from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE
from google.cloud.environment_vars import BIGTABLE_EMULATOR
from google.type import date_pb2

from google.cloud.bigtable.data._cross_sync import CrossSync

Expand Down Expand Up @@ -1027,3 +1030,83 @@ async def test_execute_query_simple(self, client, table_id, instance_id):
row = rows[0]
assert row["a"] == 1
assert row["b"] == "foo"

@CrossSync.pytest
@pytest.mark.usefixtures("client")
@CrossSync.Retry(
    predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
)
async def test_execute_query_params(self, client, table_id, instance_id):
    """Echo scalar and array query parameters back through ``execute_query``.

    A single SELECT returns every bound parameter as a column. Explicit
    parameter types are supplied only for the two float scalars and for
    all of the arrays; the remaining scalars are sent without one. Floats
    are compared with ``pytest.approx`` and dates are expected back as
    ``date_pb2.Date`` messages. The test is retried on ``ClientError``.
    """
    utc = datetime.timezone.utc
    first_ts = datetime.datetime.fromtimestamp(1000, tz=utc)
    second_ts = datetime.datetime.fromtimestamp(2000, tz=utc)
    first_day = datetime.date(2025, 1, 16)
    second_day = datetime.date(2025, 1, 17)

    sql = (
        "SELECT @stringParam AS strCol, @bytesParam as bytesCol, @int64Param AS intCol, "
        "@float32Param AS float32Col, @float64Param AS float64Col, @boolParam AS boolCol, "
        "@tsParam AS tsCol, @dateParam AS dateCol, @byteArrayParam AS byteArrayCol, "
        "@stringArrayParam AS stringArrayCol, @intArrayParam AS intArrayCol, "
        "@float32ArrayParam AS float32ArrayCol, @float64ArrayParam AS float64ArrayCol, "
        "@boolArrayParam AS boolArrayCol, @tsArrayParam AS tsArrayCol, "
        "@dateArrayParam AS dateArrayCol"
    )
    bound_values = {
        "stringParam": "foo",
        "bytesParam": b"bar",
        "int64Param": 12,
        "float32Param": 1.1,
        "float64Param": 1.2,
        "boolParam": True,
        "tsParam": first_ts,
        "dateParam": first_day,
        "byteArrayParam": [b"foo", b"bar", None],
        "stringArrayParam": ["foo", "bar", None],
        "intArrayParam": [1, None, 2],
        "float32ArrayParam": [1.2, None, 1.3],
        "float64ArrayParam": [1.4, None, 1.5],
        "boolArrayParam": [None, False, True],
        "tsArrayParam": [first_ts, second_ts, None],
        "dateArrayParam": [first_day, second_day, None],
    }
    explicit_types = {
        "float32Param": SqlType.Float32(),
        "float64Param": SqlType.Float64(),
        "byteArrayParam": SqlType.Array(SqlType.Bytes()),
        "stringArrayParam": SqlType.Array(SqlType.String()),
        "intArrayParam": SqlType.Array(SqlType.Int64()),
        "float32ArrayParam": SqlType.Array(SqlType.Float32()),
        "float64ArrayParam": SqlType.Array(SqlType.Float64()),
        "boolArrayParam": SqlType.Array(SqlType.Bool()),
        "tsArrayParam": SqlType.Array(SqlType.Timestamp()),
        "dateArrayParam": SqlType.Array(SqlType.Date()),
    }

    stream = await client.execute_query(
        sql, instance_id, parameters=bound_values, parameter_types=explicit_types
    )
    received = [item async for item in stream]
    assert len(received) == 1
    record = received[0]

    # (column, expected) pairs, checked in this order. Float columns are
    # wrapped in pytest.approx; date columns are compared to date_pb2.Date.
    expectations = [
        ("strCol", bound_values["stringParam"]),
        ("bytesCol", bound_values["bytesParam"]),
        ("intCol", bound_values["int64Param"]),
        ("float32Col", pytest.approx(bound_values["float32Param"])),
        ("float64Col", pytest.approx(bound_values["float64Param"])),
        ("boolCol", bound_values["boolParam"]),
        ("tsCol", bound_values["tsParam"]),
        ("dateCol", date_pb2.Date(year=2025, month=1, day=16)),
        ("stringArrayCol", bound_values["stringArrayParam"]),
        ("byteArrayCol", bound_values["byteArrayParam"]),
        ("intArrayCol", bound_values["intArrayParam"]),
        ("float32ArrayCol", pytest.approx(bound_values["float32ArrayParam"])),
        ("float64ArrayCol", pytest.approx(bound_values["float64ArrayParam"])),
        ("boolArrayCol", bound_values["boolArrayParam"]),
        ("tsArrayCol", bound_values["tsArrayParam"]),
        (
            "dateArrayCol",
            [
                date_pb2.Date(year=2025, month=1, day=16),
                date_pb2.Date(year=2025, month=1, day=17),
                None,
            ],
        ),
    ]
    for column, wanted in expectations:
        assert record[column] == wanted
Loading

0 comments on commit 2dee92c

Please sign in to comment.