Can track types for getting schema, but can't use the types to customize the generated sql. Comments in new_types_demo

Signed-off-by: sfc-gh-mvashishtha <[email protected]>
sfc-gh-mvashishtha committed Jul 23, 2024
1 parent c5c8004 commit f8491a1
Showing 10 changed files with 198 additions and 3 deletions.
17 changes: 17 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
@@ -7,6 +7,7 @@
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union

import snowflake.snowpark
from snowflake.snowpark.column import TimestampType, TimedeltaType
from snowflake.snowpark._internal.analyzer.analyzer_utils import (
alias_expression,
binary_arithmetic_expression,
@@ -45,6 +46,7 @@
from snowflake.snowpark._internal.analyzer.binary_expression import (
BinaryArithmeticExpression,
BinaryExpression,
Subtract
)
from snowflake.snowpark._internal.analyzer.binary_plan_node import Join, SetOperation
from snowflake.snowpark._internal.analyzer.datatype_mapper import (
@@ -688,7 +690,22 @@ def binary_operator_extractor(
expr.right, df_aliased_col_name_to_real_col_name, parse_local_name
)
if isinstance(expr, BinaryArithmeticExpression):
# TODO: it doesn't seem appropriate to rewrite the expression at this stage,
# but on the other hand Column and Expression themselves do not have access
# to types.
# WRONG-- we still don't have datatypes here.
if isinstance(expr, Subtract) and isinstance(expr.left.datatype, TimestampType) and isinstance(expr.right.datatype, TimestampType):
return f'datediff("ns", {left_sql_expr}, {right_sql_expr})'

# from .functions import datediff
# right_expression = Column._to_expr(other)
# if isinstance(self._expression.datatype, TimestampType) and isinstance(right_expression.datatype, TimestampType):
# result = datediff("ns", self._expression, right_expression)
# result.return_type = TimedeltaType
# return result
# return Column(Subtract(self._expression, right_expression))
return binary_arithmetic_expression(

expr.sql_operator,
left_sql_expr,
right_sql_expr,
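The TODO and the commented-out block in this hunk sketch an alternative: do the rewrite at the Column level, where operand types would already be resolved, instead of in the analyzer. Below is a minimal, self-contained sketch of that idea, assuming types are known by SQL-generation time; the Expr class and its attributes are illustrative stand-ins, not the real Snowpark classes.

from dataclasses import dataclass
from typing import Tuple


class DataType: ...
class TimestampType(DataType): ...
class TimedeltaType(DataType): ...


@dataclass
class Expr:
    sql: str
    datatype: DataType  # assumed to be resolved before SQL generation


def subtract_sql(left: Expr, right: Expr) -> Tuple[str, DataType]:
    # When both operands are timestamps, plain subtraction is rewritten to
    # DATEDIFF in nanoseconds (the formulation used in the hunk above) and the
    # result is tracked as a Timedelta; otherwise fall back to ordinary `-`.
    if isinstance(left.datatype, TimestampType) and isinstance(right.datatype, TimestampType):
        return f'datediff("ns", {left.sql}, {right.sql})', TimedeltaType()
    return f"({left.sql} - {right.sql})", left.datatype


if __name__ == "__main__":
    a = Expr('"COL_A"', TimestampType())
    b = Expr('"COL_B"', TimestampType())
    print(subtract_sql(a, b)[0])  # datediff("ns", "COL_A", "COL_B")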
10 changes: 10 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/expression.py
@@ -282,6 +282,16 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.COLUMN

def resolve(self, input_attributes) -> Attribute:
# copied from DataFrame._resolve
normalized_col_name = snowflake.snowpark._internal.utils.quote_name(self.name)
cols = list(filter(lambda attr: attr.name == normalized_col_name, input_attributes))
if len(cols) == 1:
return cols[0].with_name(self.name)
else:
raise SnowparkClientExceptionMessages.DF_CANNOT_RESOLVE_COLUMN_NAME(
self.name
)

class Literal(Expression):
def __init__(self, value: Any, datatype: Optional[DataType] = None) -> None:
7 changes: 7 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
@@ -456,12 +456,14 @@ class SelectSnowflakePlan(Selectable):
"""Wrap a SnowflakePlan to a subclass of Selectable."""

def __init__(self, snowflake_plan: LogicalPlan, *, analyzer: "Analyzer") -> None:
# First, the snowflake_plan we get here is the SnowflakeValues object.
super().__init__(analyzer)
self._snowflake_plan: SnowflakePlan = (
snowflake_plan
if isinstance(snowflake_plan, SnowflakePlan)
else analyzer.resolve(snowflake_plan)
)
# now we can look at self.snowflake_plan.attributes
self.expr_to_alias.update(self._snowflake_plan.expr_to_alias)
self.df_aliased_col_name_to_real_col_name.update(
self._snowflake_plan.df_aliased_col_name_to_real_col_name
@@ -502,6 +504,11 @@ def query_params(self) -> Optional[Sequence[Any]]:
@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
return self.snowflake_plan.individual_node_complexity


# def attributes(self) -> List[Attribute]:
# # override the usual SnowflakePlan
# input_attributes = self.from_.snowflake_plan.attributes


class SelectStatement(Selectable):
60 changes: 60 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -8,7 +8,10 @@
import uuid
from collections import defaultdict
from enum import Enum
from .expression import FunctionExpression, UnresolvedAttribute
from snowflake.snowpark._internal.utils import quote_name
from functools import cached_property
from snowflake.snowpark.column import TimestampType, TimedeltaType
from typing import (
TYPE_CHECKING,
Any,
@@ -360,6 +363,63 @@ def with_subqueries(self, subquery_plans: List["SnowflakePlan"]) -> "SnowflakePl

@cached_property
def attributes(self) -> List[Attribute]:
# The first time we get here, self.source_plan is a SnowflakeValues node; we don't want to touch its attributes.
# The second time we get here, self.source_plan is a SelectStatement selecting function calls and aliases from the inner node.
from .select_statement import SelectStatement
if isinstance(self.source_plan, SelectStatement):
# The first case is a 'select *' from another SelectStatement whose projection carries the real columns.
if self.source_plan.projection is None:
return self.source_plan.from_.snowflake_plan.attributes
input_attributes = self.source_plan.from_.snowflake_plan.attributes
input_names = [c.name for c in input_attributes]
my_attributes = []
# We have real projections. Resolve each of self.projection against self.source_plan.
#
# TODO:
# We want methods like resolve() on each of the expression objects so that they can
# recursively do this resolution on their own. We shouldn't have to unroll expression objects
# as we are doing here with projection.child.return_type.
from .analyzer import Alias
for projection in self.source_plan.projection:
if isinstance(projection, Attribute):
# Attribute already has a DataType.
my_attributes.append(projection)
elif isinstance(projection, Alias):
# Get the return type from the function, but the return type depends on the input types.
# Here we have a hack for t_timestamp.
# TODO: add a method to functions called resolve() that will resolve
# the output types given the input schema.
# Functions like sum() could initialize their expressions with a `resolver` attribute
# that tells how they should resolve their output types.

# copied from DataFrame._resolve
from .binary_expression import Subtract

if isinstance(projection.child, UnresolvedAttribute):
my_attributes.append(projection.child.resolve(input_attributes).with_name(projection.name))
elif (
isinstance(projection.child, Subtract) and
isinstance(projection.child.children[0], UnresolvedAttribute) and
isinstance(projection.child.children[1], UnresolvedAttribute) and
isinstance(projection.child.children[0].resolve(input_attributes).datatype, TimestampType) and
isinstance(projection.child.children[1].resolve(input_attributes).datatype, TimestampType)
):
my_attributes.append(Attribute(name=projection.name, datatype=TimedeltaType))
else:
breakpoint()
if not hasattr(projection.child, "return_type"):
breakpoint()
raise NotImplementedError(f'cannot handle projection.child of type {type(projection.child)}')
my_attributes.append(Attribute(name=projection.name, datatype=projection.child.return_type))
elif isinstance(projection, UnresolvedAttribute):
my_attributes.append(projection.resolve(input_attributes))
else:
raise NotImplementedError(f'cannot handle projection type {type(projection)}')
# only return if we found all the attributes, including their types
# otherwise fall back to getting types from Snowflake.
return my_attributes


output = analyze_attributes(self.schema_query, self.session)
# No simplifier case relies on this schema_query change to update SHOW TABLES to a nested sql friendly query.
if not self.schema_query or not self.session.sql_simplifier_enabled:
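The TODO in the attributes property above asks for a resolve() method on every expression class, so type derivation can recurse through the tree instead of being unrolled by hand, and for functions to carry a resolver that maps input types to an output type. A rough sketch of that shape, using toy classes rather than the real Snowpark expression hierarchy:

from typing import Callable, Dict, List


class DataType: ...
class TimestampType(DataType): ...
class TimedeltaType(DataType): ...


class Expr:
    def resolve(self, schema: Dict[str, DataType]) -> DataType:
        raise NotImplementedError


class Col(Expr):
    def __init__(self, name: str) -> None:
        self.name = name

    def resolve(self, schema: Dict[str, DataType]) -> DataType:
        # Look the column up in the input schema instead of asking Snowflake.
        return schema[self.name]


class Sub(Expr):
    def __init__(self, left: Expr, right: Expr) -> None:
        self.left, self.right = left, right

    def resolve(self, schema: Dict[str, DataType]) -> DataType:
        lt, rt = self.left.resolve(schema), self.right.resolve(schema)
        # timestamp - timestamp yields a timedelta; otherwise keep the left type.
        if isinstance(lt, TimestampType) and isinstance(rt, TimestampType):
            return TimedeltaType()
        return lt


class Func(Expr):
    # Each function carries a resolver that maps its argument types to an output type.
    def __init__(self, resolver: Callable[[List[DataType]], DataType], *args: Expr) -> None:
        self.resolver, self.args = resolver, list(args)

    def resolve(self, schema: Dict[str, DataType]) -> DataType:
        return self.resolver([a.resolve(schema) for a in self.args])


if __name__ == "__main__":
    schema = {"T1": TimestampType(), "T2": TimestampType()}
    print(type(Sub(Col("T1"), Col("T2")).resolve(schema)).__name__)  # TimedeltaType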
src/snowflake/snowpark/_internal/analyzer/window_expression.py
@@ -134,6 +134,7 @@ def __init__(
super().__init__()
self.window_function = window_function
self.window_spec = window_spec
self.return_type = window_function.return_type

def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.window_function, self.window_spec)
@@ -151,7 +152,6 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
self.window_spec.cumulative_node_complexity,
)


class RankRelatedFunctionExpression(Expression):
sql: str

1 change: 1 addition & 0 deletions src/snowflake/snowpark/column.py
@@ -77,6 +77,7 @@
StringType,
TimestampTimeZone,
TimestampType,
TimedeltaType
)
from snowflake.snowpark.window import Window, WindowSpec

9 changes: 7 additions & 2 deletions src/snowflake/snowpark/functions.py
@@ -212,6 +212,7 @@
StringType,
StructType,
TimestampType,
LongType
)
from snowflake.snowpark.udaf import UDAFRegistration, UserDefinedAggregateFunction
from snowflake.snowpark.udf import UDFRegistration, UserDefinedFunction
@@ -697,11 +698,13 @@ def count(e: ColumnOrName) -> Column:
<BLANKLINE>
"""
c = _to_col_if_str(e, "count")
return (
return_expression = (
builtin("count")(Literal(1))
if isinstance(c._expression, Star)
else builtin("count")(c._expression)
)
return_expression.return_type = LongType
return return_expression


def count_distinct(*cols: ColumnOrName) -> Column:
@@ -3181,11 +3184,13 @@ def to_timestamp(e: ColumnOrName, fmt: Optional["Column"] = None) -> Column:
[Row(ANS=datetime.datetime(1970, 1, 1, 0, 0, 20)), Row(ANS=datetime.datetime(1971, 1, 1, 0, 0)), Row(ANS=datetime.datetime(1971, 1, 1, 0, 0)), Row(ANS=datetime.datetime(1971, 1, 1, 0, 0))]
"""
c = _to_col_if_str(e, "to_timestamp")
return (
return_value = (
builtin("to_timestamp")(c, fmt)
if fmt is not None
else builtin("to_timestamp")(c)
)
return_value._expression.return_type = TimestampType()
return return_value


def to_timestamp_ntz(
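Both function changes above follow the same pattern: build the Column exactly as before, then annotate the result (the Column itself in count(), the underlying expression in to_timestamp()) with the type it will produce so schema derivation can read it later. A toy sketch of that pattern; the typed() helper is hypothetical and just stands in for the per-function assignments in the diff.

class DataType: ...
class LongType(DataType): ...
class TimestampType(DataType): ...


class Column:
    """Toy stand-in for snowflake.snowpark.Column."""

    def __init__(self, sql: str) -> None:
        self.sql = sql
        self.return_type = None  # filled in by the typed builders below


def typed(col: Column, return_type: DataType) -> Column:
    # Hypothetical helper: record the declared output type and return the column.
    col.return_type = return_type
    return col


def count(col: Column) -> Column:
    return typed(Column(f"count({col.sql})"), LongType())


def to_timestamp(col: Column) -> Column:
    return typed(Column(f"to_timestamp({col.sql})"), TimestampType())


if __name__ == "__main__":
    print(type(count(Column('"A"')).return_type).__name__)  # LongType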
41 changes: 41 additions & 0 deletions src/snowflake/snowpark/modin/pandas/new_types_demo.py
@@ -0,0 +1,41 @@
import modin.pandas as pd
import snowflake.snowpark.modin.plugin
import numpy as np
import pandas as native_pd
from snowflake.snowpark.session import Session; session = Session.builder.create()
import logging; logging.getLogger("snowflake.snowpark").setLevel(logging.DEBUG)


### Section 1. Timedelta
df = pd.DataFrame([[pd.Timestamp(year=2020,month=11,day=11,second=30), pd.Timestamp(year=2019,month=10,day=10,second=1)]])

# check we can print dataframe
print(df)

# The schema has the correct snowpark types.
print(df._query_compiler._modin_frame.ordered_dataframe._dataframe_ref.snowpark_dataframe.schema)

timedelta_result = df[0] - df[1]

# The timedelta type shows up as the last column in the schema!
print(timedelta_result._query_compiler._modin_frame.ordered_dataframe._dataframe_ref.snowpark_dataframe.schema)


# However, Snowflake still raises a type error because the expression types aren't available at the point where we generate the SQL,
# so we can't decide to use datediff instead of regular subtraction.
print(df[0] - df[1])

# adding timestamp to timedelta

# adding two timedelta

# timestamp + (timedelta + timedelta)

### Section 2. Interval

# Interval
# df = pd.DataFrame([pd.Interval(1, 3, closed='left'), pd.Interval(5, 7, closed='both')])
# print(df)
# dfp = df._to_pandas()
# print(dfp)
# print(list(dfp[0]))
2 changes: 2 additions & 0 deletions src/snowflake/snowpark/types.py
@@ -460,6 +460,8 @@ class GeometryType(DataType):

pass

sfc-gh-lspiegelberg (Contributor) commented on Jul 23, 2024:

I would probably go the route here of adding functions register_compound_user_type(name="TimeDelta", members={"days":LongType(), "seconds":LongType(), ...}) and unregister_compound_user_type(name="TimeDelta").

Then, within a session (for which the types should be alive), I would keep these around and track them separately. This would make it possible to simplify the logic.

class TimedeltaType(DataType):
pass

class _PandasType(DataType):
pass
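The review comment above suggests a session-scoped registry of compound user types instead of adding permanent DataType subclasses. A very rough sketch of what such a registry could look like; the function names follow the comment, and nothing like this exists in the package today.

from typing import Dict


class DataType: ...
class LongType(DataType): ...


# Session-scoped registry: compound type name -> member name -> member type.
_compound_user_types: Dict[str, Dict[str, DataType]] = {}


def register_compound_user_type(name: str, members: Dict[str, DataType]) -> None:
    """Register a compound type such as TimeDelta with Long-typed members."""
    _compound_user_types[name] = dict(members)


def unregister_compound_user_type(name: str) -> None:
    _compound_user_types.pop(name, None)


if __name__ == "__main__":
    register_compound_user_type(
        "TimeDelta",
        {"days": LongType(), "seconds": LongType(), "microseconds": LongType()},
    )
    print(sorted(_compound_user_types["TimeDelta"]))  # ['days', 'microseconds', 'seconds']
    unregister_compound_user_type("TimeDelta")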
52 changes: 52 additions & 0 deletions type_tracking_notes
@@ -0,0 +1,52 @@
problem: we build an expression col('a') - col('b') that has no type at any level.

it only becomes typed once we select it into a Selectable (like "SELECT [...] FROM [...]"), at which point it's not an expression but a Selectable.

something like a Selectable could deduce the new types and also rewrite the expressions into the correct form.

2a. Could do it in SelectStatement, which deals with Expression. But then we cannot make any sense of the myriad snowpark function calls.

2b. Could do it in Snowpark DataFrame, which deals with Column.
- problem: dataframe.select also gets expressions. How can the dataframe even identify what it's selecting, let alone tell us how to translate the types?
chicken and egg: I'm trying to select f(g(h("A"))), but I can't tell how to invoke h without knowing the type of "A" itself.
How did we solve this problem in our prototype? We didn't. We never really successfully propagated the snowpark type up to the pandas layer,
and we were probably not tracking the snowpark type correctly.
We did the SQL translation in the pandas layer, and when converting to pandas we did not count on our type tracking to tell us which columns consist of native_pd.Interval:
we guessed from the JSON values themselves that we were dealing with timedelta / interval, rather than knowing the correct type.

functions, which deal with Column objects, need to be able to tell snowpark what the result types are.

what I prototyped for Option 2 wasn't really what I wrote about for Option 2 in the design doc.

pd.DataFrame([[pd.Timestamp(year=2020,month=11,day=11,second=30), pd.Timestamp(year=2019,month=10,day=10,second=1)]])

makes something like

Select(
Column("__index__"),
Column(alias to "0",
FunctionExpression(
'to_timestamp',
UnresolvedAttribute("0"),
)
)
Column(alias to "1",
FunctionExpression(
'to_timestamp',
UnresolvedAttribute("1"),
)
)
)

currently we get the schema for that by generating the SQL and asking snowflake.

but if we knew the schema of what we're querying, shouldn't we be able to deduce the schema of the result without asking snowflake again?

on SelectStatement:

self.input_schema = self.input_data.schema

say I call SelectStatement.select(), then for each thing I'm selecting, I use its lazy type inference code to deduce its type based on the input types.


select_statement.select(h(f(g(A + B)), 3))
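The core idea in these notes is that SelectStatement already knows the schema of its input, so it should be able to derive the output schema of a select() locally instead of asking Snowflake again. A toy sketch of that flow; the classes and the callable-based projections below are illustrative assumptions, not the real SelectStatement API.

from typing import Callable, Dict, List, Tuple


class DataType: ...
class TimestampType(DataType): ...
class TimedeltaType(DataType): ...


# A projection is (output name, function deriving its type from the input schema).
Projection = Tuple[str, Callable[[Dict[str, DataType]], DataType]]


class SelectStatement:
    def __init__(self, input_schema: Dict[str, DataType]) -> None:
        self.input_schema = input_schema

    def select(self, projections: List[Projection]) -> Dict[str, DataType]:
        # Derive each output type from the input schema; no SQL round trip needed.
        return {name: derive(self.input_schema) for name, derive in projections}


if __name__ == "__main__":
    stmt = SelectStatement({"0": TimestampType(), "1": TimestampType()})

    def ts_minus_ts(schema: Dict[str, DataType]) -> DataType:
        # col("0") - col("1"): timestamp minus timestamp is a timedelta.
        return TimedeltaType()

    out = stmt.select([("diff", ts_minus_ts)])
    print({k: type(v).__name__ for k, v in out.items()})  # {'diff': 'TimedeltaType'}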
