
Commit

fix conflicts
Signed-off-by: Labanya Mukhopadhyay <[email protected]>
sfc-gh-lmukhopadhyay committed Jun 4, 2024
2 parents 8a6b313 + e7813c1 commit 1fc6794
Showing 85 changed files with 2,405 additions and 386 deletions.
6 changes: 6 additions & 0 deletions .github/labeler.yml
@@ -3,3 +3,9 @@ snowpark-pandas:
- any-glob-to-any-file:
- src/snowflake/snowpark/modin/**
- tests/integ/modin/**

local-testing:
- changed-files:
- any-glob-to-any-file:
- src/snowflake/snowpark/mock/**
- tests/mock/**
20 changes: 20 additions & 0 deletions CHANGELOG.md
@@ -6,11 +6,30 @@

#### Improvements


#### Bug Fixes

- Fixed a bug where a Python stored procedure with a table return type failed when run in a task.
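
  A minimal sketch of a Python stored procedure with a table return type, the case addressed by the fix above. The handler, procedure name, schema, and the pre-existing `session` are illustrative assumptions; the fix itself concerns invoking such a procedure from a Snowflake task.

  ```python
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType

  # `session` is assumed to be an existing Snowpark Session.

  def list_items(session: Session) -> DataFrame:
      # A table stored procedure returns a Snowpark DataFrame.
      return session.create_dataframe([[1, "a"], [2, "b"]], schema=["id", "name"])

  session.sproc.register(
      list_items,
      name="list_items_sp",
      return_type=StructType(
          [StructField("id", IntegerType()), StructField("name", StringType())]
      ),
      packages=["snowflake-snowpark-python"],
      replace=True,
  )

  # Calling directly returns a DataFrame; the reported failure occurred when a
  # Snowflake task issued the equivalent CALL.
  session.call("list_items_sp").show()
  ```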

### Snowpark Local Testing Updates

#### New Features

- Added support for the `strict` parameter when registering UDFs and Stored Procedures.
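
  A minimal sketch of the new `strict` option against a local testing session; the session configuration, UDF name, and sample data are assumptions for illustration, not part of this change.

  ```python
  from snowflake.snowpark import Session
  from snowflake.snowpark.functions import call_udf, col
  from snowflake.snowpark.types import IntegerType

  # Local testing session; no Snowflake connection is made.
  session = Session.builder.config("local_testing", True).create()

  # With strict=True, a NULL argument short-circuits to NULL and the handler is never invoked.
  session.udf.register(
      lambda x: x + 1,
      return_type=IntegerType(),
      input_types=[IntegerType()],
      name="add_one",
      strict=True,
  )

  df = session.create_dataframe([[1], [None]], schema=["a"])
  df.select(call_udf("add_one", col("a"))).show()  # 2, then NULL
  ```

  The same `strict` keyword is accepted by `session.sproc.register` for stored procedures.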

#### Bug Fixes

- Fixed a bug in `convert_timezone` where setting the `source_timezone` parameter returned an error (see the sketch below).
- Fixed a bug where creating a DataFrame with empty data of type `DateType` raised an `AttributeError`.
- Fixed a bug where a table merge failed when an update clause existed but no update took place.
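
  A minimal sketch exercising the `convert_timezone` fix against a local testing session; the session setup and sample timestamp are assumptions.

  ```python
  import datetime

  from snowflake.snowpark import Session
  from snowflake.snowpark.functions import col, convert_timezone, lit

  session = Session.builder.config("local_testing", True).create()
  df = session.create_dataframe([[datetime.datetime(2024, 6, 4, 12, 0)]], schema=["ts"])

  # Supplying the third (source_timezone) argument previously raised an error in local testing.
  df.select(convert_timezone(lit("America/Los_Angeles"), col("ts"), lit("UTC"))).show()
  ```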

### Snowpark pandas API Updates

#### New Features

- Added partial support for `DataFrame.pct_change` and `Series.pct_change` without the `freq` and `limit` parameters.
- Added support for `Series.str.get`.
- Added support for `Series.dt.dayofweek`, `Series.dt.day_of_week`, `Series.dt.dayofyear`, and `Series.dt.day_of_year`.
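
  A minimal sketch of the newly supported Snowpark pandas APIs listed above; it assumes an active Snowpark session configured for Snowpark pandas, and the sample data is illustrative.

  ```python
  import modin.pandas as pd
  import snowflake.snowpark.modin.plugin  # noqa: F401 -- enables the Snowflake backend

  s = pd.Series(["apple", "banana", None])
  print(s.str.get(1))  # "p", "a", None

  dates = pd.to_datetime(pd.Series(["2024-06-03", "2024-06-04"]))
  print(dates.dt.dayofweek)  # 0 (Monday), 1 (Tuesday)

  prices = pd.Series([100.0, 110.0, 99.0])
  print(prices.pct_change())  # NaN, 0.10, -0.10
  ```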

#### Bug Fixes

@@ -20,6 +39,7 @@

- Added support for named aggregations in `DataFrame.aggregate` and `Series.aggregate` with `axis=0` (see the sketch below).
- `pd.read_csv` now reads using the native pandas CSV parser, then uploads the data to Snowflake using Parquet. This enables most of the parameters supported by `read_csv`, including date parsing and numeric conversions. Uploading via Parquet is roughly twice as fast as uploading via CSV.
- Initial work to support an `Index` directly in Snowpark pandas. Currently, this class is a simple wrapper around a pandas `Index`. Support for `Index` as a first-class component of Snowpark pandas is coming soon.
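
  A minimal sketch of the named-aggregation support noted above; the frame contents and aggregation names are assumptions.

  ```python
  import modin.pandas as pd
  import snowflake.snowpark.modin.plugin  # noqa: F401

  df = pd.DataFrame({"value": [1, 2, 3], "other": [4, 5, 6]})

  # With axis=0, the keyword names become the labels of the aggregated output.
  result = df.aggregate(total=("value", "sum"), largest=("other", "max"))
  print(result)
  ```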

## 1.18.0 (2024-05-28)

8 changes: 4 additions & 4 deletions docs/source/modin/supported/series_dt_supported.rst
@@ -36,19 +36,19 @@ the method in the left column.
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``nanosecond`` | N | |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``day_of_week`` | N | |
| ``day_of_week`` | Y | |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``week`` | N | |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``weekofyear`` | N | |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``dayofweek`` | N | |
| ``dayofweek`` | Y | |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``weekday`` | N | |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``dayofyear`` | N | |
| ``dayofyear`` | Y | |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``day_of_year`` | N | |
| ``day_of_year`` | Y | |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``quarter`` | Y | |
+-----------------------------+---------------------------------+----------------------------------------------------+
4 changes: 3 additions & 1 deletion docs/source/modin/supported/series_str_supported.rst
@@ -44,7 +44,9 @@ the method in the left column.
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``fullmatch`` | N | |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``get`` | N | |
| ``get`` | P | ``N`` if the `i` parameter is set to a non-int |
| | | value. Also non-string data values such as list |
| | | and dict are not yet supported. |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``get_dummies`` | N | |
+-----------------------------+---------------------------------+----------------------------------------------------+
8 changes: 7 additions & 1 deletion src/snowflake/snowpark/_internal/server_connection.py
@@ -390,6 +390,7 @@ def run_query(
case_sensitive: bool = True,
params: Optional[Sequence[Any]] = None,
num_statements: Optional[int] = None,
ignore_results: bool = False,
**kwargs,
) -> Union[Dict[str, Any], AsyncJob]:
try:
@@ -422,6 +423,8 @@
# have non-select statements, and it shouldn't fail if the user
# calls to_pandas() to execute the query.
if block:
if ignore_results:
return {"data": None, "sfqid": results_cursor.sfqid}
return self._to_data_or_iter(
results_cursor=results_cursor, to_pandas=to_pandas, to_iter=to_iter
)
@@ -541,6 +544,7 @@ def get_result_set(
data_type: _AsyncResultType = _AsyncResultType.ROW,
log_on_exception: bool = False,
case_sensitive: bool = True,
ignore_results: bool = False,
**kwargs,
) -> Tuple[
Dict[
@@ -592,6 +596,7 @@
case_sensitive=case_sensitive,
num_statements=len(plan.queries),
params=params,
ignore_results=ignore_results,
**kwargs,
)

@@ -620,6 +625,7 @@
log_on_exception=log_on_exception,
case_sensitive=case_sensitive,
params=query.params,
ignore_results=ignore_results,
**kwargs,
)
placeholders[query.query_id_place_holder] = (
@@ -656,7 +662,7 @@ def get_result_and_metadata(

def get_result_query_id(self, plan: SnowflakePlan, **kwargs) -> str:
# get the iterator such that the data is not fetched
result_set, _ = self.get_result_set(plan, to_iter=True, **kwargs)
result_set, _ = self.get_result_set(plan, ignore_results=True, **kwargs)
return result_set["sfqid"]

@_Decorator.wrap_exception
66 changes: 38 additions & 28 deletions src/snowflake/snowpark/mock/_functions.py
@@ -338,12 +338,11 @@ def mock_to_date(
"""
import dateutil.parser

fmt = [fmt] * len(column) if not isinstance(fmt, ColumnEmulator) else fmt
if not isinstance(fmt, ColumnEmulator):
fmt = ColumnEmulator([fmt] * len(column), index=column.index)

def convert_date(row):
def convert_date(data, _fmt):
try:
_fmt = fmt[row.name]
data = row[0]
auto_detect = _fmt is None or _fmt.lower() == "auto"
date_format, _ = convert_snowflake_datetime_format(
_fmt, default_format="%Y-%m-%d"
@@ -399,7 +398,7 @@ def convert_date(row):
else:
SnowparkLocalTestingException.raise_from_error(exc)

res = column.to_frame().apply(convert_date, axis=1)
res = column.combine(fmt, convert_date)
res.sf_type = ColumnType(DateType(), column.sf_type.nullable)
return res

@@ -1700,42 +1699,53 @@ def mock_initcap(values: ColumnEmulator, delimiters: ColumnEmulator):

@patch("convert_timezone")
def mock_convert_timezone(
target_timezone: ColumnEmulator,
source_time: ColumnEmulator,
source_timezone: Optional[ColumnEmulator] = None,
*args: ColumnEmulator,
) -> ColumnEmulator:
"""Converts the given source_time to the target timezone.
For timezone information, refer to the `Snowflake SQL convert_timezone notes <https://docs.snowflake.com/en/sql-reference/functions/convert_timezone.html#usage-notes>`_
"""
import dateutil

is_ntz = source_time.sf_type.datatype.tz is TimestampTimeZone.NTZ
if source_timezone is not None and not is_ntz:
SnowparkLocalTestingException.raise_from_error(
ValueError(
# mock_convert_timezone matches the sql function call semantics.
# It has different parameters when called with 2 or 3 args.
# When called with two args, the third will be replaced with None.
if args[2] is None:
target_timezone, source_time, _ = args
source_timezone = pandas.Series([None] * len(source_time))
return_type = TimestampTimeZone.TZ
else:
source_timezone, target_timezone, source_time = args
return_type = TimestampTimeZone.NTZ
if source_time.sf_type.datatype.tz is not TimestampTimeZone.NTZ:
raise ValueError(
"[Local Testing] convert_timezone can only convert NTZ timestamps when source_timezone is specified."
)
)

# Using dateutil because it uses iana timezones while pytz would use Olson tzdb.
from_tz = None if source_timezone is None else dateutil.tz.gettz(source_timezone)
combined = pandas.concat(
[source_timezone, target_timezone, source_time], axis=1, ignore_index=True
)

if from_tz is not None:
timestamps = [ts.replace(tzinfo=from_tz) for ts in source_time]
return_type = TimestampTimeZone.NTZ
else:
timestamps = list(source_time)
return_type = TimestampTimeZone.TZ
def _convert(row):
source_timezone, target_timezone, source_time = row
if source_time is None:
return None

res = []
for tz, ts in zip(target_timezone, timestamps):
# Add local tz if info is missing
if ts.tzinfo is None:
ts = LocalTimezone.replace_tz(ts)
if source_timezone is not None:
# Using dateutil because it uses iana timezones while pytz would use Olson tzdb.
source_time = source_time.replace(tzinfo=dateutil.tz.gettz(source_timezone))

if source_time.tzinfo is None:
source_time = LocalTimezone.replace_tz(source_time)

result = source_time.astimezone(dateutil.tz.gettz(target_timezone))

if return_type == TimestampTimeZone.NTZ:
result = result.replace(tzinfo=None)

return result

# Convert all timestamps to the target tz
res.append(ts.astimezone(dateutil.tz.gettz(tz)))
res = combined.apply(_convert, axis=1)

return ColumnEmulator(
res,
26 changes: 19 additions & 7 deletions src/snowflake/snowpark/mock/_plan.py
@@ -605,11 +605,17 @@ def cleanup_imports():
# And these code would look like:
# res=input.apply(...)
# res.set_sf_type(ColumnType(exp.datatype, exp.nullable)) # fixes the drift and removes NaT

data = []
for _, row in function_input.iterrows():
if udf.strict and any([v is None for v in row]):
result = None
else:
result = remove_null_wrapper(udf_handler(*row))
data.append(result)

res = ColumnEmulator(
data=[
remove_null_wrapper(udf_handler(*row))
for _, row in function_input.iterrows()
],
data=data,
sf_type=ColumnType(exp.datatype, exp.nullable),
name=quote_name(
f"{exp.udf_name}({', '.join(input_data.columns)})".upper()
@@ -1352,11 +1358,15 @@ def outer_join(base_df):
# (2) A target row is selected to be both updated and deleted

inserted_rows = []
insert_clause_specified = (
update_clause_specified
) = delete_clause_specified = False
inserted_row_idx = set() # source_row_id
deleted_row_idx = set()
updated_row_idx = set()
for clause in source_plan.clauses:
if isinstance(clause, UpdateMergeExpression):
update_clause_specified = True
# Select rows to update
if clause.condition:
condition = calculate_expression(
@@ -1387,6 +1397,7 @@ def outer_join(base_df):
updated_row_idx.add(row[ROW_ID])

elif isinstance(clause, DeleteMergeExpression):
delete_clause_specified = True
# Select rows to delete
if clause.condition:
condition = calculate_expression(
@@ -1409,6 +1420,7 @@ def outer_join(base_df):
target = target[~matched]

elif isinstance(clause, InsertMergeExpression):
insert_clause_specified = True
# calculate unmatched rows in the source
matched = source.apply(tuple, 1).isin(
join_result[source.columns].apply(tuple, 1)
@@ -1493,11 +1505,11 @@ def outer_join(base_df):

# Generate metadata result
res = []
if inserted_rows:
if insert_clause_specified:
res.append(len(inserted_row_idx))
if updated_row_idx:
if update_clause_specified:
res.append(len(updated_row_idx))
if deleted_row_idx:
if delete_clause_specified:
res.append(len(deleted_row_idx))

return [Row(*res)]
5 changes: 5 additions & 0 deletions src/snowflake/snowpark/mock/_stored_procedure.py
@@ -48,8 +48,10 @@ def __init__(
imports: Set[str],
execute_as: typing.Literal["caller", "owner"] = "owner",
anonymous_sp_sql: Optional[str] = None,
strict=False,
) -> None:
self.imports = imports
self.strict = strict
super().__init__(
func,
return_type,
@@ -66,6 +68,8 @@ def __call__(
statement_params: Optional[Dict[str, str]] = None,
) -> Any:
args, session = self._validate_call(args, session)
if self.strict and any([arg is None for arg in args]):
return None

# Unpack columns if passed
parsed_args = []
@@ -344,6 +348,7 @@ def _do_register_sp(
sproc_name,
sproc_imports,
execute_as=execute_as,
strict=strict,
)

self._registry[sproc_name] = sproc
26 changes: 16 additions & 10 deletions src/snowflake/snowpark/mock/_udf.py
@@ -16,6 +16,12 @@
from snowflake.snowpark.udf import UDFRegistration, UserDefinedFunction


class MockUserDefinedFunction(UserDefinedFunction):
def __init__(self, *args, strict=False, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.strict = strict


class MockUDFRegistration(UDFRegistration):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@@ -121,13 +127,7 @@ def _do_register_udf(
raise ValueError("options replace and if_not_exists are incompatible")

if udf_name in self._registry and if_not_exists:
return UserDefinedFunction(
self._registry[udf_name],
return_type,
input_types,
udf_name,
packages=packages,
)
return self._registry[udf_name]

if udf_name in self._registry and not replace:
raise SnowparkSQLException(
@@ -151,17 +151,23 @@

if type(func) is tuple: # register from file
module_name = self._import_file(func[0], udf_name=udf_name)
self._registry[udf_name] = UserDefinedFunction(
self._registry[udf_name] = MockUserDefinedFunction(
(module_name, func[1]),
return_type,
input_types,
udf_name,
strict=strict,
packages=packages,
)
else:
# register from callable
self._registry[udf_name] = UserDefinedFunction(
func, return_type, input_types, udf_name, packages=packages
self._registry[udf_name] = MockUserDefinedFunction(
func,
return_type,
input_types,
udf_name,
strict=strict,
packages=packages,
)

return self._registry[udf_name]