
Commit

Merge branch 'main' into aalam-SNOW-1644950-add-explicit-option-for-use-logical-type
sfc-gh-aalam authored Aug 30, 2024
2 parents b4821ea + 4bcd987 commit 4db27a2
Showing 54 changed files with 8,284 additions and 864 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
@@ -34,6 +34,7 @@
- Fixed a bug in query generation from set operations that allowed generation of duplicate queries when children have common subqueries.
- Fixed a bug in `session.get_session_stage` that referenced a non-existing stage after switching database or schema.
- Fixed a bug where calling `DataFrame.to_snowpark_pandas_dataframe` without explicitly initializing the Snowpark pandas plugin caused an error.
- Fixed a bug where using the `explode` function in dynamic table creation caused a SQL compilation error due to improper boolean type casting on the `outer` parameter.
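
A hedged sketch of the pattern that previously failed (session, table, and warehouse names are placeholders):

```python
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col, explode

session = Session.builder.create()  # assumes a configured connection

# explode() passes a boolean `outer` argument to the underlying FLATTEN
# table function; before this fix, that boolean rendered with an improper
# cast that CREATE DYNAMIC TABLE rejected at SQL compilation time.
df = session.table("src").select(col("id"), explode(col("arr")))
df.create_or_replace_dynamic_table("my_dt", warehouse="wh", lag="1 minute")
```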

### Snowpark Local Testing Updates

@@ -61,6 +62,7 @@
- support for binary arithmetic between two `Timedelta` values.
- support for lazy `TimedeltaIndex`.
- support for `pd.to_timedelta`.
- support for `GroupBy` aggregations `min`, `max`, `mean`, `idxmax`, `idxmin`, `std`, `sum`, `median`, `count`, `any`, `all`, `size`, `nunique`.
- Added support for index's arithmetic and comparison operators.
- Added support for `Series.dt.round`.
- Added documentation pages for `DatetimeIndex`.
@@ -78,12 +80,21 @@
#### Improvements

- Refactored `quoted_identifier_to_snowflake_type` to avoid making metadata queries if the types have been cached locally.
- Improved `pd.to_datetime` to handle all local input cases.

#### Bug Fixes

- Stopped ignoring nanoseconds in `pd.Timedelta` scalars.
- Fixed an `AssertionError` raised in trees of binary operations.

#### Behavior Change

- When calling `DataFrame.set_index`, or setting `DataFrame.index` or `Series.index`, with a new index whose length does not match the current length of the `Series`/`DataFrame` object, a `ValueError` is no longer raised. When the object is longer than the new index, the object's new index is padded with `NaN` values for the "extra" elements. When the object is shorter than the new index, the extra values in the new index are ignored; the `Series`/`DataFrame` keeps its length `n` and uses only the first `n` values of the new index.
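
A minimal sketch of the new behavior (requires an active Snowpark pandas session; in native pandas both assignments below raise `ValueError`):

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # registers the Snowpark backend

s = pd.Series([10, 20, 30])

# Object longer than the new index: the index is padded with NaN.
s.index = ["a", "b"]            # index becomes ["a", "b", nan]; length stays 3

# Object shorter than the new index: extra labels are ignored.
s.index = ["a", "b", "c", "d"]  # index becomes ["a", "b", "c"]; length stays 3
```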

#### Improvements

- Improved `concat` and `join` performance when operations are performed on `Series` objects coming from the same DataFrame, by avoiding unnecessary joins.

## 1.21.0 (2024-08-19)

### Snowpark Python API Updates
10 changes: 8 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
@@ -151,7 +151,7 @@
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.telemetry import TelemetryField
from snowflake.snowpark._internal.utils import quote_name
from snowflake.snowpark.types import _NumericType
from snowflake.snowpark.types import BooleanType, _NumericType

ARRAY_BIND_THRESHOLD = 512

@@ -605,7 +605,7 @@ def table_function_expression_extractor(
sql = named_arguments_function(
expr.func_name,
{
key: self.analyze(
key: self.to_sql_try_avoid_cast(
value, df_aliased_col_name_to_real_col_name, parse_local_name
)
for key, value in expr.args.items()
@@ -745,6 +745,12 @@ def to_sql_try_avoid_cast(
# otherwise process as normal
if isinstance(expr, Literal) and isinstance(expr.datatype, _NumericType):
return numeric_to_sql_without_cast(expr.value, expr.datatype)
elif (
isinstance(expr, Literal)
and isinstance(expr.datatype, BooleanType)
and isinstance(expr.value, bool)
):
return str(expr.value).upper()
else:
return self.analyze(
expr, df_aliased_col_name_to_real_col_name, parse_local_name
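
The new branch renders Python boolean literals as bare SQL keywords instead of sending them through the generic `analyze` path, which could wrap them in a cast. A standalone sketch of the idea (hypothetical helper, not the library's API):

```python
def bool_literal_to_sql(value: bool) -> str:
    # Emit the bare SQL keyword (TRUE/FALSE) for a Python bool, so that
    # table function arguments such as FLATTEN's OUTER avoid the boolean
    # cast that broke dynamic table creation.
    return str(value).upper()

assert bool_literal_to_sql(True) == "TRUE"
assert bool_literal_to_sql(False) == "FALSE"
```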
8 changes: 6 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
@@ -856,6 +856,9 @@ def create_table_as_select_statement(
max_data_extension_time: Optional[int] = None,
change_tracking: Optional[bool] = None,
copy_grants: bool = False,
*,
use_scoped_temp_objects: bool = False,
is_generated: bool = False,
) -> str:
column_definition_sql = (
f"{LEFT_PARENTHESIS}{column_definition}{RIGHT_PARENTHESIS}"
@@ -877,8 +880,9 @@
}
)
return (
f"{CREATE}{OR + REPLACE if replace else EMPTY_STRING} {table_type.upper()} {TABLE}"
f"{IF + NOT + EXISTS if not replace and not error else EMPTY_STRING} "
f"{CREATE}{OR + REPLACE if replace else EMPTY_STRING}"
f" {(get_temp_type_for_object(use_scoped_temp_objects, is_generated) if table_type.lower() in TEMPORARY_STRING_SET else table_type).upper()} "
f"{TABLE}{IF + NOT + EXISTS if not replace and not error else EMPTY_STRING} "
f"{table_name}{column_definition_sql}{cluster_by_clause}{options_statement}"
f"{COPY_GRANTS if copy_grants else EMPTY_STRING}{comment_sql} {AS}{project_statement([], child)}"
)
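
With the new keyword-only flags, temporary tables generated internally by the library can be created as scoped temporary tables. A hedged sketch, assuming `get_temp_type_for_object` selects the type roughly like this:

```python
TEMPORARY_STRING_SET = {"temp", "temporary"}

def get_temp_type_for_object(use_scoped_temp_objects: bool, is_generated: bool) -> str:
    # A library-generated temp object may use the scoped variant, which is
    # cleaned up at the end of its scope instead of lingering for the session.
    if use_scoped_temp_objects and is_generated:
        return "SCOPED TEMPORARY"
    return "TEMPORARY"
```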
2 changes: 2 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -929,6 +929,8 @@ def get_create_table_as_select_plan(child: SnowflakePlan, replace, error):
max_data_extension_time=max_data_extension_time,
change_tracking=change_tracking,
copy_grants=copy_grants,
use_scoped_temp_objects=use_scoped_temp_objects,
is_generated=is_generated,
),
child,
source_plan,
27 changes: 11 additions & 16 deletions src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py
@@ -6,8 +6,6 @@
from collections import defaultdict
from typing import List, Optional, Tuple

from sortedcontainers import SortedList

from snowflake.snowpark._internal.analyzer.analyzer_utils import (
drop_table_if_exists_statement,
)
@@ -201,11 +199,11 @@ def _find_node_to_breakdown(self, root: TreeNode) -> Optional[TreeNode]:
1. Traverse the plan tree and find the valid nodes for partitioning.
2. If no valid node is found, return None.
3. Keep valid nodes in a sorted list based on the complexity score.
4. Return the node with the highest complexity score.
3. Return the node with the highest complexity score.
"""
current_level = [root]
pipeline_breaker_list = SortedList(key=lambda x: x[0])
candidate_node = None
candidate_score = -1 # start with -1 since score is always > 0

while current_level:
next_level = []
@@ -215,23 +213,20 @@
self._parent_map[child].add(node)
valid_to_breakdown, score = self._is_node_valid_to_breakdown(child)
if valid_to_breakdown:
# Append score and child to the pipeline breaker sorted list
# so that the valid child with the highest complexity score
# is at the end of the list.
pipeline_breaker_list.add((score, child))
# If the score for valid node is higher than the last candidate,
# update the candidate node and score.
if score > candidate_score:
candidate_score = score
candidate_node = child
else:
# don't traverse subtrees if parent is a valid candidate
next_level.append(child)

current_level = next_level

if not pipeline_breaker_list:
# Return None if no valid node is found for partitioning.
return None

# Get the node with the highest complexity score
_, child = pipeline_breaker_list.pop()
return child
# If no valid node is found, candidate_node will be None.
# Otherwise, return the node with the highest complexity score.
return candidate_node

def _get_partitioned_plan(self, root: TreeNode, child: TreeNode) -> SnowflakePlan:
"""This method takes cuts the child out from the root, creates a temp table plan for the
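
The refactor above drops the `sortedcontainers` dependency: since only the highest-scoring node is ever used, a running maximum suffices. A minimal illustration of the pattern, with hypothetical nodes and scores:

```python
# Tracking a single best candidate replaces a sorted container when only
# the maximum is needed: O(n) time and O(1) extra space.
candidates = [("node_a", 12), ("node_b", 47), ("node_c", 33)]

best_node, best_score = None, -1  # scores are always > 0
for node, score in candidates:
    if score > best_score:
        best_node, best_score = node, score

assert best_node == "node_b"
```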
9 changes: 0 additions & 9 deletions src/snowflake/snowpark/_internal/compiler/utils.py
@@ -268,15 +268,6 @@ def update_resolvable_node(
elif isinstance(node, (SelectSnowflakePlan, SelectTableFunction)):
assert node.snowflake_plan is not None
update_resolvable_node(node.snowflake_plan, query_generator)
node.pre_actions = node.snowflake_plan.queries[:-1]
node.post_actions = node.snowflake_plan.post_actions
node._api_calls = node.snowflake_plan.api_calls
if isinstance(node, SelectSnowflakePlan):
node._query_params = []
for query in node._snowflake_plan.queries:
if query.params:
node._query_params.extend(query.params)

node.analyzer = query_generator

node.pre_actions = node._snowflake_plan.queries[:-1]
79 changes: 71 additions & 8 deletions src/snowflake/snowpark/modin/pandas/__init__.py
@@ -85,10 +85,16 @@
timedelta_range,
)

import modin.pandas

# TODO: SNOW-851745 make sure to add all Snowpark pandas API general functions
from modin.pandas import plotting # type: ignore[import]

from snowflake.snowpark.modin.pandas.dataframe import DataFrame
from snowflake.snowpark.modin.pandas.api.extensions import (
register_dataframe_accessor,
register_series_accessor,
)
from snowflake.snowpark.modin.pandas.dataframe import _DATAFRAME_EXTENSIONS_, DataFrame
from snowflake.snowpark.modin.pandas.general import (
concat,
crosstab,
@@ -140,15 +146,15 @@
read_xml,
to_pickle,
)
from snowflake.snowpark.modin.pandas.series import Series
from snowflake.snowpark.modin.pandas.series import _SERIES_EXTENSIONS_, Series
from snowflake.snowpark.modin.plugin._internal.session import SnowpandasSessionHolder
from snowflake.snowpark.modin.plugin._internal.telemetry import (
try_add_telemetry_to_attribute,
)

# The extensions assigned to this module
_PD_EXTENSIONS_: dict = {}

# base needs to be re-exported in order to properly override docstrings for BasePandasDataset
# moving this import higher prevents sphinx from building documentation (??)
from snowflake.snowpark.modin.pandas import base # isort: skip # noqa: E402,F401

import snowflake.snowpark.modin.plugin.extensions.pd_extensions as pd_extensions # isort: skip # noqa: E402,F401
import snowflake.snowpark.modin.plugin.extensions.pd_overrides # isort: skip # noqa: E402,F401
@@ -157,12 +163,71 @@
DatetimeIndex,
TimedeltaIndex,
)

# this must occur before overrides are applied
_attrs_defined_on_modin_base = set(dir(modin.pandas.base.BasePandasDataset))
_attrs_defined_on_series = set(
dir(Series)
) # TODO: SNOW-1063347 revisit when series.py is removed
_attrs_defined_on_dataframe = set(
dir(DataFrame)
) # TODO: SNOW-1063346 revisit when dataframe.py is removed

# base overrides occur before subclass overrides in case subclasses override a base method
import snowflake.snowpark.modin.plugin.extensions.base_extensions # isort: skip # noqa: E402,F401
import snowflake.snowpark.modin.plugin.extensions.base_overrides # isort: skip # noqa: E402,F401
import snowflake.snowpark.modin.plugin.extensions.dataframe_extensions # isort: skip # noqa: E402,F401
import snowflake.snowpark.modin.plugin.extensions.dataframe_overrides # isort: skip # noqa: E402,F401
import snowflake.snowpark.modin.plugin.extensions.series_extensions # isort: skip # noqa: E402,F401
import snowflake.snowpark.modin.plugin.extensions.series_overrides # isort: skip # noqa: E402,F401

# For any method defined on Series/DF, add telemetry to it if it meets all of the following conditions:
# 1. The method was defined directly on upstream BasePandasDataset (_attrs_defined_on_modin_base)
# 2. The method is not overridden by a child class (this will change)
# 3. The method is not overridden by an extensions module
# 4. The method name does not start with an _
#
# TODO: SNOW-1063347
# Since we still use the vendored version of Series and the overrides for the top-level
# namespace haven't been performed yet, we need to set properties on the vendored version
_base_telemetry_added_attrs = set()

_series_ext = _SERIES_EXTENSIONS_.copy()
for attr_name in dir(Series):
if (
attr_name in _attrs_defined_on_modin_base
and attr_name in _attrs_defined_on_series
and attr_name not in _series_ext
and not attr_name.startswith("_")
):
register_series_accessor(attr_name)(
try_add_telemetry_to_attribute(attr_name, getattr(Series, attr_name))
)
_base_telemetry_added_attrs.add(attr_name)

# TODO: SNOW-1063346
# Since we still use the vendored version of DataFrame and the overrides for the top-level
# namespace haven't been performed yet, we need to set properties on the vendored version
_dataframe_ext = _DATAFRAME_EXTENSIONS_.copy()
for attr_name in dir(DataFrame):
if (
attr_name in _attrs_defined_on_modin_base
and attr_name in _attrs_defined_on_dataframe
and attr_name not in _dataframe_ext
and not attr_name.startswith("_")
):
# If telemetry was already added via Series, register the override but don't re-wrap
# the method in the telemetry annotation. If we don't do this check, we will end up
# double-reporting telemetry on some methods.
original_attr = getattr(DataFrame, attr_name)
new_attr = (
original_attr
if attr_name in _base_telemetry_added_attrs
else try_add_telemetry_to_attribute(attr_name, original_attr)
)
register_dataframe_accessor(attr_name)(new_attr)
_base_telemetry_added_attrs.add(attr_name)


def __getattr__(name: str) -> Any:
"""
@@ -220,7 +285,6 @@ def __getattr__(name: str) -> Any:
"date_range",
"Index",
"MultiIndex",
"Series",
"bdate_range",
"period_range",
"DatetimeIndex",
@@ -318,8 +382,7 @@ def __getattr__(name: str) -> Any:
# Manually re-export the members of the pd_extensions namespace, which are not declared in __all__.
_EXTENSION_ATTRS = ["read_snowflake", "to_snowflake", "to_snowpark", "to_pandas"]
# We also need to re-export native_pd.offsets, since modin.pandas doesn't re-export it.
# snowflake.snowpark.pandas.base also needs to be re-exported to make docstring overrides for BasePandasDataset work.
_ADDITIONAL_ATTRS = ["offsets", "base"]
_ADDITIONAL_ATTRS = ["offsets"]

# This code should eventually be moved into the `snowflake.snowpark.modin.plugin` module instead.
# Currently, trying to do so would result in incorrect results because `snowflake.snowpark.modin.pandas`
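
The `_base_telemetry_added_attrs` set above guards against double-reporting: a method already wrapped when registered for `Series` is registered on `DataFrame` as-is. A stripped-down sketch of that pattern (`wrap` and `register` are hypothetical callables):

```python
def register_with_telemetry_once(attrs, already_wrapped, wrap, register):
    # Wrap each public attribute at most once across classes; names wrapped
    # for an earlier class are registered unwrapped to avoid double-counting.
    for name, attr in attrs.items():
        if name.startswith("_"):
            continue
        register(name, attr if name in already_wrapped else wrap(name, attr))
        already_wrapped.add(name)
```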
16 changes: 0 additions & 16 deletions src/snowflake/snowpark/modin/pandas/base.py
@@ -604,14 +604,6 @@ def _to_series_list(self, index: pd.Index) -> list[pd.Series]:
return [pd.Series(index)]

def _set_index(self, new_index: Axes) -> None:
"""
Set the index for this DataFrame.
Parameters
----------
new_index : pandas.Index
The new index to set this.
"""
# TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
self._update_inplace(
new_query_compiler=self._query_compiler.set_index(
@@ -655,14 +647,6 @@ def set_axis(
return obj

def _get_index(self):
"""
Get the index for this DataFrame.
Returns
-------
pandas.Index
The union of all indexes across the partitions.
"""
# TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
from snowflake.snowpark.modin.plugin.extensions.index import Index

6 changes: 4 additions & 2 deletions src/snowflake/snowpark/modin/pandas/dataframe.py
@@ -37,6 +37,7 @@
import numpy as np
import pandas
from modin.pandas.accessor import CachedAccessor, SparseFrameAccessor
from modin.pandas.base import BasePandasDataset

# from . import _update_engine
from modin.pandas.iterator import PartitionIterator
@@ -73,7 +74,6 @@
from pandas.util._validators import validate_bool_kwarg

from snowflake.snowpark.modin import pandas as pd
from snowflake.snowpark.modin.pandas.base import _ATTRS_NO_LOOKUP, BasePandasDataset
from snowflake.snowpark.modin.pandas.groupby import (
DataFrameGroupBy,
validate_groupby_args,
@@ -91,12 +91,14 @@
replace_external_data_keys_with_empty_pandas_series,
replace_external_data_keys_with_query_compiler,
)
from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta
from snowflake.snowpark.modin.plugin._internal.utils import is_repr_truncated
from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike
from snowflake.snowpark.modin.plugin.utils.error_message import (
ErrorMessage,
dataframe_not_implemented,
)
from snowflake.snowpark.modin.plugin.utils.frontend_constants import _ATTRS_NO_LOOKUP
from snowflake.snowpark.modin.plugin.utils.warning_message import (
SET_DATAFRAME_ATTRIBUTE_WARNING,
WarningMessage,
@@ -136,7 +138,7 @@
],
apilink="pandas.DataFrame",
)
class DataFrame(BasePandasDataset):
class DataFrame(BasePandasDataset, metaclass=TelemetryMeta):
_pandas_class = pandas.DataFrame

def __init__(
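
`TelemetryMeta` instruments the class at creation time rather than by per-attribute registration. A toy illustration of the metaclass approach (the real implementation differs; `CALLS` is a stand-in telemetry sink):

```python
import functools

CALLS: list = []  # stand-in telemetry sink

def add_telemetry(func):
    # Hypothetical decorator: record the call, then delegate.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        CALLS.append(func.__qualname__)
        return func(*args, **kwargs)
    return wrapper

class TelemetryMeta(type):
    def __new__(mcs, name, bases, namespace):
        # Wrap public callables as the class body is assembled, so every
        # method defined on the class is instrumented automatically.
        for attr, value in list(namespace.items()):
            if callable(value) and not attr.startswith("_"):
                namespace[attr] = add_telemetry(value)
        return super().__new__(mcs, name, bases, namespace)
```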
[Diff truncated: the remaining changed files are not shown here.]