Skip to content

Commit

Permalink
Refactor based on feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jrose committed Dec 5, 2024
1 parent 0b79825 commit f05c058
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 34 deletions.
6 changes: 3 additions & 3 deletions src/snowflake/snowpark/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from snowflake.connector.constants import FIELD_ID_TO_NAME
from snowflake.connector.cursor import ResultMetadata
from snowflake.connector.options import installed_pandas, pandas
from snowflake.snowpark._internal.utils import quote_name
from snowflake.snowpark.types import (
LTZ,
NTZ,
Expand Down Expand Up @@ -157,9 +156,10 @@ def convert_metadata_to_sp_type(
return StructType(
[
StructField(
quote_name(field.name, keep_case=True),
field.name,
convert_metadata_to_sp_type(field, max_string_size),
nullable=field.is_nullable,
is_column=False,
)
for field in metadata.fields
],
Expand Down Expand Up @@ -292,7 +292,7 @@ def convert_sp_to_sf_type(datatype: DataType) -> str:
if isinstance(datatype, StructType):
if datatype.structured:
fields = ", ".join(
f"{field.raw_name} {convert_sp_to_sf_type(field.datatype)}"
f"{field.name} {convert_sp_to_sf_type(field.datatype)}"
for field in datatype.fields
)
return f"OBJECT({fields})"
Expand Down
6 changes: 6 additions & 0 deletions src/snowflake/snowpark/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@
StringType,
TimestampTimeZone,
TimestampType,
ArrayType,
MapType,
StructType,
)
from snowflake.snowpark.window import Window, WindowSpec

Expand Down Expand Up @@ -916,6 +919,9 @@ def _cast(
if isinstance(to, str):
to = type_string_to_type_object(to)

if isinstance(to, (ArrayType, MapType, StructType)):
to = to._as_nested()

if self._ast is None:
_emit_ast = False

Expand Down
75 changes: 52 additions & 23 deletions src/snowflake/snowpark/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,12 @@ def __init__(
def __repr__(self) -> str:
return f"ArrayType({repr(self.element_type) if self.element_type else ''})"

def _as_nested(self) -> "ArrayType":
element_type = self.element_type
if isinstance(element_type, (ArrayType, MapType, StructType)):
element_type = element_type._as_nested()
return ArrayType(element_type, self.structured)

def is_primitive(self):
return False

Expand Down Expand Up @@ -391,6 +397,12 @@ def __repr__(self) -> str:
def is_primitive(self):
return False

def _as_nested(self) -> "MapType":
value_type = self.value_type
if isinstance(value_type, (ArrayType, MapType, StructType)):
value_type = value_type._as_nested()
return MapType(self.key_type, value_type, self.structured)

@classmethod
def from_json(cls, json_dict: Dict[str, Any]) -> "MapType":
return MapType(
Expand Down Expand Up @@ -482,7 +494,6 @@ class ColumnIdentifier:
"""Represents a column identifier."""

def __init__(self, normalized_name: str) -> None:
self.raw_name = normalized_name
self.normalized_name = quote_name(normalized_name)
self._original_name = normalized_name

Expand Down Expand Up @@ -553,33 +564,41 @@ def __init__(
column_identifier: Union[ColumnIdentifier, str],
datatype: DataType,
nullable: bool = True,
is_column: bool = True,
) -> None:
self.column_identifier = (
ColumnIdentifier(column_identifier)
if isinstance(column_identifier, str)
else column_identifier
)
self.name = column_identifier
self.is_column = is_column
self.datatype = datatype
self.nullable = nullable

@property
def name(self) -> str:
"""Returns the column name."""
return self.column_identifier.name

@property
def raw_name(self) -> str:
return self.column_identifier.raw_name
return self.column_identifier.name if self.is_column else self._name

@name.setter
def name(self, n: str) -> None:
self.column_identifier = ColumnIdentifier(n)
def name(self, n: Union[ColumnIdentifier, str]) -> None:
if isinstance(n, ColumnIdentifier):
self._name = n.name
self.column_identifier = n
else:
self._name = n
self.column_identifier = ColumnIdentifier(n)

def _as_nested(self) -> "StructField":
datatype = self.datatype
if isinstance(datatype, (ArrayType, MapType, StructType)):
datatype = datatype._as_nested()
# Nested StructFields do not follow column naming conventions
return StructField(self._name, datatype, self.nullable, is_column=False)

def __repr__(self) -> str:
return f"StructField({self.name!r}, {repr(self.datatype)}, nullable={self.nullable})"

def __eq__(self, other):
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
return isinstance(other, self.__class__) and (
(self.name, self.is_column, self.datatype, self.nullable)
== (other.name, other.is_column, other.datatype, other.nullable)
)

@classmethod
def from_json(cls, json_dict: Dict[str, Any]) -> "StructField":
Expand Down Expand Up @@ -625,30 +644,40 @@ def __init__(
self, fields: Optional[List["StructField"]] = None, structured=False
) -> None:
self.structured = structured
if fields is None:
fields = []
self.fields = fields
self.fields = []
for field in fields:
self.add(field)

def add(
self,
field: Union[str, ColumnIdentifier, "StructField"],
datatype: Optional[DataType] = None,
nullable: Optional[bool] = True,
) -> "StructType":
if isinstance(field, StructField):
self.fields.append(field)
elif isinstance(field, (str, ColumnIdentifier)):
if isinstance(field, (str, ColumnIdentifier)):
if datatype is None:
raise ValueError(
"When field argument is str or ColumnIdentifier, datatype must not be None."
)
self.fields.append(StructField(field, datatype, nullable))
else:
field = StructField(field, datatype, nullable)
elif not isinstance(field, StructField):
__import__("pdb").set_trace()
raise ValueError(
f"field argument must be one of str, ColumnIdentifier or StructField. Got: '{type(field)}'"
)

# Nested data does not follow the same schema conventions as top level fields.
if isinstance(field.datatype, (ArrayType, MapType, StructType)):
field.datatype = field.datatype._as_nested()

self.fields.append(field)
return self

def _as_nested(self) -> "StructType":
return StructType(
[field._as_nested() for field in self.fields], self.structured
)

@classmethod
def _from_attributes(cls, attributes: list) -> "StructType":
return cls([StructField(a.name, a.datatype, a.nullable) for a in attributes])
Expand Down
16 changes: 8 additions & 8 deletions tests/integ/scala/test_datatype_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _create_test_dataframe(s):
StructType(
[
StructField("A", StringType(16777216), nullable=True),
StructField('"b"', DoubleType(), nullable=True),
StructField("b", DoubleType(), nullable=True),
],
structured=True,
),
Expand Down Expand Up @@ -524,27 +524,27 @@ def test_iceberg_nested_fields(
"NESTED_DATA",
StructType(
[
StructField('"camelCase"', StringType(), nullable=True),
StructField('"snake_case"', StringType(), nullable=True),
StructField('"PascalCase"', StringType(), nullable=True),
StructField("camelCase", StringType(), nullable=True),
StructField("snake_case", StringType(), nullable=True),
StructField("PascalCase", StringType(), nullable=True),
StructField(
'"nested_map"',
"nested_map",
MapType(
StringType(),
StructType(
[
StructField(
'"inner_camelCase"',
"inner_camelCase",
StringType(),
nullable=True,
),
StructField(
'"inner_snake_case"',
"inner_snake_case",
StringType(),
nullable=True,
),
StructField(
'"inner_PascalCase"',
"inner_PascalCase",
StringType(),
nullable=True,
),
Expand Down

0 comments on commit f05c058

Please sign in to comment.