Skip to content

Commit

Permalink
Can subtract two dates
Browse files Browse the repository at this point in the history
Signed-off-by: sfc-gh-mvashishtha <[email protected]>
  • Loading branch information
sfc-gh-mvashishtha committed Jul 24, 2024
1 parent f8491a1 commit 3e34c76
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 49 deletions.
2 changes: 1 addition & 1 deletion src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ def binary_operator_extractor(
# to types.
# WRONG-- we still don't have datatypes here.
if isinstance(expr, Subtract) and isinstance(expr.left.datatype, TimestampType) and isinstance(expr.right.datatype, TimestampType):
return f'datediff("ns", {left_sql_expr}, {right_sql_expr})'
return f'datediff("ns", {right_sql_expr}, {left_sql_expr})'

# from .functions import datediff
# right_expression = Column._to_expr(other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
)
from snowflake.snowpark.types import TimestampType, TimedeltaType


class BinaryExpression(Expression):
Expand Down Expand Up @@ -81,6 +82,12 @@ class Add(BinaryArithmeticExpression):
class Subtract(BinaryArithmeticExpression):
    sql_operator = "-"

    def resolve_datatype(self, input_attributes):
        """Resolve this subtraction's output datatype from the input schema.

        Both operands are resolved first. When both operands resolve to a
        ``TimestampType``, their difference is a ``TimedeltaType``; in every
        other case ``datatype`` is left untouched (unresolved).
        """
        lhs, rhs = self.children[0], self.children[1]
        lhs.resolve_datatype(input_attributes)
        rhs.resolve_datatype(input_attributes)
        both_timestamps = isinstance(lhs.datatype, TimestampType) and isinstance(
            rhs.datatype, TimestampType
        )
        if both_timestamps:
            self.datatype = TimedeltaType()


class Multiply(BinaryArithmeticExpression):
sql_operator = "*"
Expand Down
18 changes: 18 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.COLUMN

def resolve_datatype(self, input_attributes):
    """No-op: an ``Attribute`` is constructed with its datatype already set,
    so there is nothing to resolve against the input schema."""
    return None

class Star(Expression):
def __init__(
Expand Down Expand Up @@ -292,6 +295,17 @@ def resolve(self, input_attributes) -> Attribute:
raise SnowparkClientExceptionMessages.DF_CANNOT_RESOLVE_COLUMN_NAME(
self.name
)

def resolve_datatype(self, input_attributes):
    """Resolve this column's datatype by name against ``input_attributes``.

    The column name is normalized via ``quote_name`` before matching.
    Raises the standard cannot-resolve-column error unless exactly one
    input attribute has the normalized name.
    """
    normalized_col_name = snowflake.snowpark._internal.utils.quote_name(self.name)
    matches = [
        attr for attr in input_attributes if attr.name == normalized_col_name
    ]
    if len(matches) != 1:
        raise SnowparkClientExceptionMessages.DF_CANNOT_RESOLVE_COLUMN_NAME(
            self.name
        )
    self.datatype = matches[0].datatype


class Literal(Expression):
def __init__(self, value: Any, datatype: Optional[DataType] = None) -> None:
Expand Down Expand Up @@ -517,6 +531,10 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.FUNCTION

def resolve_datatype(self, input_attributes):
    """Resolve the datatype of this function call.

    The input schema is not consulted: the datatype is taken from the
    ``return_type`` previously attached to this expression.
    TODO: Column class should tell this class what its return type is.
    """
    self.datatype = self.return_type


class WithinGroup(Expression):
Expand Down
11 changes: 11 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,17 @@ def __init__(
) # will be replaced by new api calls if any operation.
self._placeholder_query = None

# try to append datatypes onto our projections
input_attributes = from_.snowflake_plan.attributes

if projection is None:
# TODO: formerly we had a "*", but having multiple datatypes
# in the star expression gets into sketchy semantic territory.
self.projection = input_attributes
else:
for each_projection in projection:
each_projection.resolve_datatype(input_attributes)

def __copy__(self):
new = SelectStatement(
projection=self.projection,
Expand Down
52 changes: 6 additions & 46 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,55 +370,15 @@ def attributes(self) -> List[Attribute]:
# first is 'select *' from another SelectStatement whose projection has the goods
if self.source_plan.projection is None:
return self.source_plan.from_.snowflake_plan.attributes
input_attributes = self.source_plan.from_.snowflake_plan.attributes
input_names = [c.name for c in input_attributes]
my_attributes = []
# We have real projections. resolve each of self.projection against self.source_plan.
#
# TODO:
# We want methods like resolve() on each of the expression objects so that they can
# recursively do this resolution on their own. We shouldn't have to unroll expression objects
# as we are doing here with projction.child.return_type.
from .analyzer import Alias
for projection in self.source_plan.projection:
if isinstance(projection, Attribute):
# Attribute already has a DataType.
my_attributes.append(projection)
elif isinstance(projection, Alias):
# get return type from function. but return type depends on input types.
# here have a hack for t_timestamp.
# TODO: add a method to functions called resolve() that will resolve
# the output types given the input schema.
# functions like sum() can initialize functions with a `resolver` attribute
# that tells how they should resolve types.

# copied from DataFrame._resolve
from .binary_expression import Subtract

if isinstance(projection.child, UnresolvedAttribute):
my_attributes.append(projection.child.resolve(input_attributes).with_name(projection.name))
elif (
isinstance(projection.child, Subtract) and
isinstance(projection.child.children[0], UnresolvedAttribute) and
isinstance(projection.child.children[1], UnresolvedAttribute) and
isinstance(projection.child.children[0].resolve(input_attributes).datatype, TimestampType) and
isinstance(projection.child.children[1].resolve(input_attributes).datatype, TimestampType)
):
my_attributes.append(Attribute(name=projection.name, datatype=TimedeltaType))
else:
breakpoint()
if not hasattr(projection.child, "return_type"):
breakpoint()
raise NotImplementedError(f'cannot handle projection.child')
my_attributes.append(Attribute(name=projection.name, datatype=projection.child.return_type))
elif isinstance(projection, UnresolvedAttribute):
my_attributes.append(projection.resolve(input_attributes))
else:
raise NotImplementedError(f'cannot handle projection type {type(projection)}')
# only return if we found all the attributes, including their types
# otherwise fall back to getting types from Snowflake.
# TODO: what if we don't know the name?
for each_projection in self.source_plan.projection:
if not hasattr(each_projection, 'name'):
raise NotImplementedError('cannot find name of projection')
my_attributes.append(Attribute(name=each_projection.name, datatype=each_projection.datatype))
return my_attributes

# otherwise, fall back to snowflake

output = analyze_attributes(self.schema_query, self.session)
# No simplifier case relies on this schema_query change to update SHOW TABLES to a nested sql friendly query.
Expand Down
3 changes: 3 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/unary_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
# do not add additional complexity for alias
return {}

def resolve_datatype(self, input_attributes):
    """An alias has the same datatype as the expression it wraps:
    resolve the child, then copy its datatype up."""
    wrapped = self.child
    wrapped.resolve_datatype(input_attributes)
    self.datatype = wrapped.datatype

class UnresolvedAlias(UnaryExpression, NamedExpression):
sql_operator = "AS"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ def __init__(
super().__init__()
self.window_function = window_function
self.window_spec = window_spec
self.return_type = window_function.return_type

def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.window_function, self.window_spec)
Expand All @@ -152,6 +151,10 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
self.window_spec.cumulative_node_complexity,
)

def resolve_datatype(self, input_datatypes):
    """A window expression's datatype is that of its window function:
    resolve the function, then copy its datatype up."""
    fn = self.window_function
    fn.resolve_datatype(input_datatypes)
    self.datatype = fn.datatype

class RankRelatedFunctionExpression(Expression):
sql: str

Expand Down
6 changes: 6 additions & 0 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,12 @@ def to_pandas(
),
**kwargs,
)
from snowflake.snowpark.types import TimedeltaType
import pandas
breakpoint()
for i, attribute in enumerate(self.schema):
if isinstance(attribute.datatype, TimedeltaType):
result.iloc[:, i] = result.iloc[:, i].apply(lambda v: pandas.Timedelta(v, "ns"))

# if the returned result is not a pandas dataframe, raise Exception
# this might happen when calling this method with non-select commands
Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ def count(e: ColumnOrName) -> Column:
if isinstance(c._expression, Star)
else builtin("count")(c._expression)
)
return_expression.return_type = LongType
return_expression._expression.return_type = LongType
return return_expression


Expand Down

0 comments on commit 3e34c76

Please sign in to comment.