diff --git a/.github/workflows/precommit.yml b/.github/workflows/precommit.yml index 37af2ad2815..33f831332c2 100644 --- a/.github/workflows/precommit.yml +++ b/.github/workflows/precommit.yml @@ -319,6 +319,66 @@ jobs: .tox/.coverage .tox/coverage.xml + test-local-testing: + name: Test Local Testing Module ${{ matrix.os.download_name }}-${{ matrix.python-version }} + needs: build + runs-on: ${{ matrix.os.image_name }} + strategy: + fail-fast: false + matrix: + os: + - image_name: macos-latest + download_name: macos # it includes doctest + python-version: ["3.8", "3.9", "3.10", "3.11"] + cloud-provider: [aws] + steps: + - name: Checkout Code + uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Display Python version + run: python -c "import sys; print(sys.version)" + - name: Decrypt parameters.py + shell: bash + run: .github/scripts/decrypt_parameters.sh + env: + PARAMETER_PASSWORD: ${{ secrets.PARAMETER_PASSWORD }} + CLOUD_PROVIDER: ${{ matrix.cloud-provider }} + - name: Download wheel(s) + uses: actions/download-artifact@v2 + with: + name: wheel + path: dist + - name: Show wheels downloaded + run: ls -lh dist + shell: bash + - name: Upgrade setuptools, pip and wheel + run: python -m pip install -U setuptools pip wheel + - name: Install tox + run: python -m pip install tox + - name: Run tests + run: python -m tox -e "py${PYTHON_VERSION/\./}-local" + env: + PYTHON_VERSION: ${{ matrix.python-version }} + cloud_provider: ${{ matrix.cloud-provider }} + PYTEST_ADDOPTS: --color=yes --tb=short + TOX_PARALLEL_NO_SPINNER: 1 + SNOWFLAKE_IS_PYTHON_RUNTIME_TEST: 1 + shell: bash + - name: Combine coverages + run: python -m tox -e coverage --skip-missing-interpreters false + shell: bash + env: + SNOWFLAKE_IS_PYTHON_RUNTIME_TEST: 1 + - uses: actions/upload-artifact@v2 + with: + name: coverage_${{ matrix.os.download_name }}-${{ matrix.python-version }}-local-testing + path: | + .tox/.coverage + .tox/coverage.xml + combine-coverage: if: ${{ success() || failure() }} name: Combine coverage diff --git a/CHANGELOG.md b/CHANGELOG.md index f489b24fc90..17196edba13 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ## New Features - Added support for `RelationalGroupedDataframe.pivot()` to access `pivot` in the following pattern `Dataframe.group_by(...).pivot(...)`. +- Added experimental feature: Local Testing Mode. 
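A minimal usage sketch of the new Local Testing Mode, assuming a `Session` can be built directly on the in-memory `MockServerConnection` added in this change (the exact session-construction API is an assumption; later releases expose the same mode through a session builder option). Everything is evaluated locally on pandas, so no Snowflake account is needed:

```python
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col
from snowflake.snowpark.mock.connection import MockServerConnection

# Back the session with the in-memory mock connection instead of a real
# Snowflake connection (assumes Session accepts a connection object directly).
session = Session(MockServerConnection())

df = session.create_dataframe([[1, "a"], [2, "b"], [3, "c"]], schema=["id", "name"])
print(df.filter(col("id") > 1).collect())  # computed locally, no SQL is sent
```

Built-in function behavior can also be overridden for local runs with the `patch` decorator exported from `snowflake.snowpark.mock`. The single-`ColumnEmulator` signature below mirrors the `mock_min`/`mock_max` implementations in this patch and is otherwise an assumption:

```python
from snowflake.snowpark.functions import upper
from snowflake.snowpark.mock import patch
from snowflake.snowpark.mock.snowflake_data_type import ColumnEmulator, ColumnType
from snowflake.snowpark.types import StringType

@patch(upper)
def mock_upper(column: ColumnEmulator) -> ColumnEmulator:
    # Override the local implementation of UPPER: upper-case each value, keep NULLs.
    return ColumnEmulator(
        data=[v.upper() if v is not None else None for v in column],
        sf_type=ColumnType(StringType(), True),
    )
```

The optional dependencies for this mode (`pandas`, `pyarrow`) ship behind the new `localtest` extra added to `setup.py`, i.e. `pip install "snowflake-snowpark-python[localtest]"`.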
### New Features diff --git a/setup.py b/setup.py index 9dcead7dd22..5bc64585fea 100644 --- a/setup.py +++ b/setup.py @@ -68,6 +68,7 @@ "snowflake.snowpark", "snowflake.snowpark._internal", "snowflake.snowpark._internal.analyzer", + "snowflake.snowpark.mock", ], package_dir={ "": "src", @@ -90,6 +91,10 @@ "cachetools", # used in UDF doctest "pytest-timeout", ], + "localtest": [ + "pandas", + "pyarrow", + ], }, classifiers=[ "Development Status :: 5 - Production/Stable", diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 54d3cf200a6..ea23116c2c6 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -76,7 +76,13 @@ GroupingSet, GroupingSetsExpression, ) -from snowflake.snowpark._internal.analyzer.select_statement import Selectable +from snowflake.snowpark._internal.analyzer.select_statement import ( + Selectable, + SelectableEntity, + SelectSnowflakePlan, + SelectStatement, + SelectTableFunction, +) from snowflake.snowpark._internal.analyzer.snowflake_plan import ( SnowflakePlan, SnowflakePlanBuilder, @@ -1112,3 +1118,15 @@ def do_resolve_with_resolved_children( raise TypeError( f"Cannot resolve type logical_plan of {type(logical_plan).__name__} to a SnowflakePlan" ) + + def create_select_statement(self, *args, **kwargs): + return SelectStatement(*args, **kwargs) + + def create_selectable_entity(self, *args, **kwargs): + return SelectableEntity(*args, **kwargs) + + def create_select_snowflake_plan(self, *args, **kwargs): + return SelectSnowflakePlan(*args, **kwargs) + + def create_select_table_function(self, *args, **kwargs): + return SelectTableFunction(*args, **kwargs) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py index bca93c76e27..2b7a11f4d02 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py @@ -23,6 +23,7 @@ from snowflake.snowpark._internal.analyzer.expression import Attribute from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type from snowflake.snowpark._internal.utils import ( + ALREADY_QUOTED, DOUBLE_QUOTE, EMPTY_STRING, TempObjectType, @@ -1342,6 +1343,10 @@ def quote_name_without_upper_casing(name: str) -> str: return DOUBLE_QUOTE + escape_quotes(name) + DOUBLE_QUOTE +def unquote_if_quoted(string): + return string[1:-1].replace('""', '"') if ALREADY_QUOTED.match(string) else string + + # Most integer types map to number(38,0) # https://docs.snowflake.com/en/sql-reference/ # data-types-numeric.html#int-integer-bigint-smallint-tinyint-byteint diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index 3349f01437b..5be1fe638df 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -22,10 +22,12 @@ def __init__( child: LogicalPlan, probability_fraction: Optional[float] = None, row_count: Optional[int] = None, + seed: Optional[int] = None, ) -> None: super().__init__(child) self.probability_fraction = probability_fraction self.row_count = row_count + self.seed = seed class Sort(UnaryNode): diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index ca5064bf13d..e99ee845608 100644 --- a/src/snowflake/snowpark/dataframe.py +++ 
b/src/snowflake/snowpark/dataframe.py @@ -139,6 +139,7 @@ stddev, to_char, ) +from snowflake.snowpark.mock.select_statement import MockSelectStatement from snowflake.snowpark.row import Row from snowflake.snowpark.table_function import ( TableFunctionCall, @@ -506,7 +507,7 @@ def __init__( ) -> None: self._session = session self._plan = self._session._analyzer.resolve(plan) - if isinstance(plan, SelectStatement): + if isinstance(plan, (SelectStatement, MockSelectStatement)): self._select_statement = plan plan.expr_to_alias.update(self._plan.expr_to_alias) plan.df_aliased_col_name_to_real_col_name.update( @@ -1072,8 +1073,8 @@ def select( if self._select_statement: if join_plan: return self._with_plan( - SelectStatement( - from_=SelectSnowflakePlan( + self._session._analyzer.create_select_statement( + from_=self._session._analyzer.create_select_snowflake_plan( join_plan, analyzer=self._session._analyzer ), analyzer=self._session._analyzer, @@ -2014,8 +2015,11 @@ def natural_join( None, ) if self._select_statement: - select_plan = SelectStatement( - from_=SelectSnowflakePlan(join_plan, analyzer=self._session._analyzer), + select_plan = self._session._analyzer.create_select_statement( + from_=self._session._analyzer.create_select_snowflake_plan( + join_plan, + analyzer=self._session._analyzer, + ), analyzer=self._session._analyzer, ) return self._with_plan(select_plan) @@ -2332,7 +2336,7 @@ def join_table_function( project_cols = [*old_cols, *alias_cols] if self._session.sql_simplifier_enabled: - select_plan = SelectStatement( + select_plan = self._session._analyzer.create_select_statement( from_=SelectTableFunction( func_expr, other_plan=self._plan, @@ -2463,8 +2467,8 @@ def _join_dataframes( ) if self._select_statement: return self._with_plan( - SelectStatement( - from_=SelectSnowflakePlan( + self._session._analyzer.create_select_statement( + from_=self._session._analyzer.create_select_snowflake_plan( join_logical_plan, analyzer=self._session._analyzer ), analyzer=self._session._analyzer, @@ -2493,9 +2497,10 @@ def _join_dataframes_internal( ) if self._select_statement: return self._with_plan( - SelectStatement( - from_=SelectSnowflakePlan( - join_logical_plan, analyzer=self._session._analyzer + self._session._analyzer.create_select_statement( + from_=self._session._analyzer.create_select_snowflake_plan( + join_logical_plan, + analyzer=self._session._analyzer, ), analyzer=self._session._analyzer, ) @@ -3305,8 +3310,8 @@ def sample( sample_plan = Sample(self._plan, probability_fraction=frac, row_count=n) if self._select_statement: return self._with_plan( - SelectStatement( - from_=SelectSnowflakePlan( + self._session._analyzer.create_select_statement( + from_=self._session._analyzer.create_select_snowflake_plan( sample_plan, analyzer=self._session._analyzer ), analyzer=self._session._analyzer, @@ -3625,22 +3630,27 @@ def cache_result( A :class:`Table` object that holds the cached result in a temporary table. All operations on this new DataFrame have no effect on the original. 
""" + from snowflake.snowpark.mock.connection import MockServerConnection + temp_table_name = f'{self._session.get_current_database()}.{self._session.get_current_schema()}."{random_name_for_temp_object(TempObjectType.TABLE)}"' - create_temp_table = self._session._plan_builder.create_temp_table( - temp_table_name, - self._plan, - use_scoped_temp_objects=self._session._use_scoped_temp_objects, - is_generated=True, - ) - self._session._conn.execute( - create_temp_table, - _statement_params=create_or_update_statement_params_with_query_tag( - statement_params or self._statement_params, - self._session.query_tag, - SKIP_LEVELS_TWO, - ), - ) + if isinstance(self._session._conn, MockServerConnection): + self.write.save_as_table(temp_table_name, create_temp_table=True) + else: + create_temp_table = self._session._analyzer.plan_builder.create_temp_table( + temp_table_name, + self._plan, + use_scoped_temp_objects=self._session._use_scoped_temp_objects, + is_generated=True, + ) + self._session._conn.execute( + create_temp_table, + _statement_params=create_or_update_statement_params_with_query_tag( + statement_params or self._statement_params, + self._session.query_tag, + SKIP_LEVELS_TWO, + ), + ) cached_df = self._session.table(temp_table_name) cached_df.is_cached = True return cached_df diff --git a/src/snowflake/snowpark/dataframe_na_functions.py b/src/snowflake/snowpark/dataframe_na_functions.py index 4f643970875..387a2e722d5 100644 --- a/src/snowflake/snowpark/dataframe_na_functions.py +++ b/src/snowflake/snowpark/dataframe_na_functions.py @@ -4,6 +4,7 @@ # import copy +import math import sys from logging import getLogger from typing import Dict, Optional, Union @@ -200,7 +201,7 @@ def drop( df_col_type_dict[normalized_col_name], (FloatType, DoubleType) ): # iff(col = 'NaN' or col is null, 0, 1) - is_na = iff((col == "NaN") | col.is_null(), 0, 1) + is_na = iff((col == math.nan) | col.is_null(), 0, 1) else: # iff(col is null, 0, 1) is_na = iff(col.is_null(), 0, 1) @@ -355,7 +356,7 @@ def fill( if isinstance(datatype, (FloatType, DoubleType)): # iff(col = 'NaN' or col is null, value, col) res_columns.append( - iff((col == "NaN") | col.is_null(), value, col).as_( + iff((col == math.nan) | col.is_null(), value, col).as_( col_name ) ) diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py index b5777bd6c7c..254500030b7 100644 --- a/src/snowflake/snowpark/dataframe_reader.py +++ b/src/snowflake/snowpark/dataframe_reader.py @@ -400,9 +400,9 @@ def csv(self, path: str) -> DataFrame: if self._session.sql_simplifier_enabled: df = DataFrame( self._session, - SelectStatement( - from_=SelectSnowflakePlan( - self._session._plan_builder.read_file( + self._session._analyzer.create_select_statement( + from_=self._session._analyzer.create_select_snowflake_plan( + self._session._analyzer.plan_builder.read_file( path, self._file_type, self._cur_options, @@ -619,6 +619,13 @@ def _infer_schema_for_file_format( return new_schema, schema_to_cast, read_file_transformations, None def _read_semi_structured_file(self, path: str, format: str) -> DataFrame: + from snowflake.snowpark.mock.connection import MockServerConnection + + if isinstance(self._session._conn, MockServerConnection): + raise NotImplementedError( + f"[Local Testing] Support for semi structured file {format} is not implemented." 
+ ) + if self._user_schema: raise ValueError(f"Read {format} does not support user schema") self._file_path = path diff --git a/src/snowflake/snowpark/file_operation.py b/src/snowflake/snowpark/file_operation.py index 03388e4b672..c77fd9b70bb 100644 --- a/src/snowflake/snowpark/file_operation.py +++ b/src/snowflake/snowpark/file_operation.py @@ -128,7 +128,7 @@ def put( ) raise ne.with_traceback(tb) from None else: - plan = self._session._plan_builder.file_operation_plan( + plan = self._session._analyzer.plan_builder.file_operation_plan( "put", normalize_local_file(local_file_name), normalize_remote_file_or_dir(stage_location), diff --git a/src/snowflake/snowpark/functions.py b/src/snowflake/snowpark/functions.py index cb1a34c16c0..23d8880a1f4 100644 --- a/src/snowflake/snowpark/functions.py +++ b/src/snowflake/snowpark/functions.py @@ -2923,7 +2923,7 @@ def char(col: ColumnOrName) -> Column: return builtin("char")(c) -def to_char(c: ColumnOrName, format: Optional[ColumnOrLiteralStr] = None) -> Column: +def to_char(c: ColumnOrName, format: Optional[str] = None) -> Column: """Converts a Unicode code point (including 7-bit ASCII) into the character that matches the input Unicode. diff --git a/src/snowflake/snowpark/mock/__init__.py b/src/snowflake/snowpark/mock/__init__.py new file mode 100644 index 00000000000..1a67d313f9a --- /dev/null +++ b/src/snowflake/snowpark/mock/__init__.py @@ -0,0 +1,7 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from .functions import patch + +__all__ = ["patch"] diff --git a/src/snowflake/snowpark/mock/analyzer.py b/src/snowflake/snowpark/mock/analyzer.py new file mode 100644 index 00000000000..5cd2bde61d1 --- /dev/null +++ b/src/snowflake/snowpark/mock/analyzer.py @@ -0,0 +1,775 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from collections import Counter +from typing import Dict, List, Optional, Union + +import snowflake.snowpark +from snowflake.snowpark._internal.analyzer.analyzer_utils import ( + alias_expression, + binary_arithmetic_expression, + block_expression, + case_when_expression, + cast_expression, + collate_expression, + delete_merge_statement, + flatten_expression, + function_expression, + in_expression, + insert_merge_statement, + like_expression, + list_agg, + named_arguments_function, + order_expression, + quote_name, + range_statement, + rank_related_function_expression, + regexp_expression, + specified_window_frame_expression, + subfield_expression, + subquery_expression, + table_function_partition_spec, + unary_expression, + update_merge_statement, + window_expression, + window_frame_boundary_expression, + window_spec_expression, + within_group_expression, +) +from snowflake.snowpark._internal.analyzer.binary_expression import ( + BinaryArithmeticExpression, + BinaryExpression, +) +from snowflake.snowpark._internal.analyzer.binary_plan_node import Join, SetOperation +from snowflake.snowpark._internal.analyzer.datatype_mapper import ( + str_to_sql, + to_sql, + to_sql_without_cast, +) +from snowflake.snowpark._internal.analyzer.expression import ( + Attribute, + CaseWhen, + Collate, + Expression, + FunctionExpression, + InExpression, + Like, + ListAgg, + Literal, + MultipleExpression, + NamedExpression, + RegExp, + ScalarSubquery, + SnowflakeUDF, + Star, + SubfieldInt, + SubfieldString, + UnresolvedAttribute, + WithinGroup, +) +from snowflake.snowpark._internal.analyzer.grouping_set import ( + GroupingSet, + GroupingSetsExpression, +) +from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan +from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( + CopyIntoLocationNode, + CopyIntoTableNode, + Limit, + LogicalPlan, + Range, + SnowflakeCreateTable, + SnowflakeValues, + UnresolvedRelation, +) +from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder +from snowflake.snowpark._internal.analyzer.table_function import ( + FlattenFunction, + GeneratorTableFunction, + Lateral, + NamedArgumentsTableFunction, + PosArgumentsTableFunction, + TableFunctionExpression, + TableFunctionJoin, + TableFunctionPartitionSpecDefinition, + TableFunctionRelation, +) +from snowflake.snowpark._internal.analyzer.table_merge_expression import ( + DeleteMergeExpression, + InsertMergeExpression, + TableDelete, + TableMerge, + TableUpdate, + UpdateMergeExpression, +) +from snowflake.snowpark._internal.analyzer.unary_expression import ( + Alias, + Cast, + UnaryExpression, + UnresolvedAlias, +) +from snowflake.snowpark._internal.analyzer.unary_plan_node import ( + Aggregate, + CreateDynamicTableCommand, + CreateViewCommand, + Filter, + Pivot, + Project, + Sample, + Sort, + Unpivot, +) +from snowflake.snowpark._internal.analyzer.window_expression import ( + RankRelatedFunctionExpression, + SpecialFrameBoundary, + SpecifiedWindowFrame, + UnspecifiedFrame, + WindowExpression, + WindowSpecDefinition, +) +from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages +from snowflake.snowpark._internal.telemetry import TelemetryField +from snowflake.snowpark.mock.plan import MockExecutionPlan +from snowflake.snowpark.mock.plan_builder import MockSnowflakePlanBuilder +from snowflake.snowpark.mock.select_statement import ( + MockSelectable, + MockSelectableEntity, + MockSelectExecutionPlan, + MockSelectStatement, +) +from snowflake.snowpark.types 
import _NumericType + + +def serialize_expression(exp: Expression): + if isinstance(exp, Attribute): + return str(exp) + elif isinstance(exp, UnresolvedAttribute): + return str(exp) + else: + raise TypeError(f"{type(exp)} isn't supported yet in mocking.") + + +class MockAnalyzer: + def __init__(self, session: "snowflake.snowpark.session.Session") -> None: + self.session = session + self.plan_builder = MockSnowflakePlanBuilder(self.session) + self.generated_alias_maps = {} + self.subquery_plans = [] + self.alias_maps_to_use = None + + def analyze( + self, + expr: Union[Expression, NamedExpression], + expr_to_alias: Optional[Dict[str, str]] = None, + parse_local_name=False, + escape_column_name=False, + keep_alias=True, + ) -> Union[str, List[str]]: + """ + Args: + keep_alias: if true, return the column name as "aa as bb", else return the desired column name. + e.g., analyzing an expression sum(col('b')).as_("totB"), we want keep_alias to be true in the + sql simplifier process, which returns column name as sum('b') as 'totB', + so that it will detect column name change. + however, in the result calculation, we want to column name to be the output name, which is 'totB', + so we set keep_alias to False in the execution. + """ + if expr_to_alias is None: + expr_to_alias = {} + if isinstance(expr, GroupingSetsExpression): + raise NotImplementedError( + "[Local Testing] group by grouping sets is not implemented." + ) + + if isinstance(expr, Like): + return like_expression( + self.analyze(expr.expr, expr_to_alias, parse_local_name), + self.analyze(expr.pattern, expr_to_alias, parse_local_name), + ) + + if isinstance(expr, RegExp): + return regexp_expression( + self.analyze(expr.expr, expr_to_alias, parse_local_name), + self.analyze(expr.pattern, expr_to_alias, parse_local_name), + ) + + if isinstance(expr, Collate): + collation_spec = ( + expr.collation_spec.upper() if parse_local_name else expr.collation_spec + ) + return collate_expression( + self.analyze(expr.expr, expr_to_alias, parse_local_name), collation_spec + ) + + if isinstance(expr, (SubfieldString, SubfieldInt)): + field = expr.field + if parse_local_name and isinstance(field, str): + field = field.upper() + return subfield_expression( + self.analyze(expr.expr, expr_to_alias, parse_local_name), field + ) + + if isinstance(expr, CaseWhen): + return case_when_expression( + [ + ( + self.analyze(condition, expr_to_alias, parse_local_name), + self.analyze(value, expr_to_alias, parse_local_name), + ) + for condition, value in expr.branches + ], + self.analyze(expr.else_value, expr_to_alias, parse_local_name) + if expr.else_value + else "NULL", + ) + + if isinstance(expr, MultipleExpression): + return block_expression( + [ + self.analyze(expression, expr_to_alias, parse_local_name) + for expression in expr.expressions + ] + ) + + if isinstance(expr, InExpression): + return in_expression( + self.analyze(expr.columns, expr_to_alias, parse_local_name), + [ + self.analyze(expression, expr_to_alias, parse_local_name) + for expression in expr.values + ], + ) + + if isinstance(expr, GroupingSet): + raise NotImplementedError( + "[Local Testing] group by grouping sets is not implemented." 
+ ) + + if isinstance(expr, WindowExpression): + return window_expression( + self.analyze( + expr.window_function, + parse_local_name=parse_local_name, + ), + self.analyze( + expr.window_spec, + parse_local_name=parse_local_name, + ), + ) + + if isinstance(expr, WindowSpecDefinition): + return window_spec_expression( + [ + self.analyze(x, parse_local_name=parse_local_name) + for x in expr.partition_spec + ], + [ + self.analyze(x, parse_local_name=parse_local_name) + for x in expr.order_spec + ], + self.analyze( + expr.frame_spec, + parse_local_name=parse_local_name, + ), + ) + + if isinstance(expr, SpecifiedWindowFrame): + return specified_window_frame_expression( + expr.frame_type.sql, + self.window_frame_boundary(self.to_sql_avoid_offset(expr.lower, {})), + self.window_frame_boundary(self.to_sql_avoid_offset(expr.upper, {})), + ) + + if isinstance(expr, UnspecifiedFrame): + return "" + if isinstance(expr, SpecialFrameBoundary): + return expr.sql + + if isinstance(expr, Literal): + sql = to_sql(expr.value, expr.datatype) + if parse_local_name: + sql = sql.upper() + return f"{sql}" + + if isinstance(expr, Attribute): + name = expr_to_alias.get(expr.expr_id, expr.name) + return quote_name(name) + + if isinstance(expr, UnresolvedAttribute): + if escape_column_name: + # TODO: ideally we should not escape here + return f"`{expr.name}`" + return expr.name + + if isinstance(expr, FunctionExpression): + if expr.api_call_source is not None: + self.session._conn._telemetry_client.send_function_usage_telemetry( + expr.api_call_source, TelemetryField.FUNC_CAT_USAGE.value + ) + func_name = expr.name.upper() + return function_expression( + func_name, + [self.to_sql_avoid_offset(c, expr_to_alias) for c in expr.children], + expr.is_distinct, + ) + + if isinstance(expr, Star): + if not expr.expressions: + return "*" + else: + return [self.analyze(e, expr_to_alias) for e in expr.expressions] + + if isinstance(expr, SnowflakeUDF): + if expr.api_call_source is not None: + self.session._conn._telemetry_client.send_function_usage_telemetry( + expr.api_call_source, TelemetryField.FUNC_CAT_USAGE.value + ) + func_name = expr.udf_name.upper() if parse_local_name else expr.udf_name + return function_expression( + func_name, + [ + self.analyze(x, expr_to_alias, parse_local_name) + for x in expr.children + ], + False, + ) + + if isinstance(expr, TableFunctionExpression): + if expr.api_call_source is not None: + self.session._conn._telemetry_client.send_function_usage_telemetry( + expr.api_call_source, TelemetryField.FUNC_CAT_USAGE.value + ) + return self.table_function_expression_extractor(expr, expr_to_alias) + + if isinstance(expr, TableFunctionPartitionSpecDefinition): + return table_function_partition_spec( + expr.over, + [ + self.analyze(x, expr_to_alias, parse_local_name) + for x in expr.partition_spec + ] + if expr.partition_spec + else [], + [ + self.analyze(x, expr_to_alias, parse_local_name) + for x in expr.order_spec + ] + if expr.order_spec + else [], + ) + + if isinstance(expr, UnaryExpression): + return self.unary_expression_extractor( + expr, + expr_to_alias, + parse_local_name, + keep_alias=keep_alias, + ) + + if isinstance(expr, SortOrder): + return order_expression( + self.analyze(expr.child, expr_to_alias, parse_local_name), + expr.direction.sql, + expr.null_ordering.sql, + ) + + if isinstance(expr, ScalarSubquery): + self.subquery_plans.append(expr.plan) + return subquery_expression(expr.plan.queries[-1].sql) + + if isinstance(expr, WithinGroup): + return within_group_expression( + 
self.analyze(expr.expr, expr_to_alias, parse_local_name), + [self.analyze(e, expr_to_alias) for e in expr.order_by_cols], + ) + + if isinstance(expr, BinaryExpression): + return self.binary_operator_extractor( + expr, + expr_to_alias, + parse_local_name, + escape_column_name=escape_column_name, + ) + + if isinstance(expr, InsertMergeExpression): + return insert_merge_statement( + self.analyze(expr.condition, expr_to_alias) if expr.condition else None, + [self.analyze(k, expr_to_alias) for k in expr.keys], + [self.analyze(v, expr_to_alias) for v in expr.values], + ) + + if isinstance(expr, UpdateMergeExpression): + return update_merge_statement( + self.analyze(expr.condition, expr_to_alias) if expr.condition else None, + { + self.analyze(k, expr_to_alias): self.analyze(v, expr_to_alias) + for k, v in expr.assignments.items() + }, + ) + + if isinstance(expr, DeleteMergeExpression): + return delete_merge_statement( + self.analyze(expr.condition, expr_to_alias) if expr.condition else None + ) + + if isinstance(expr, ListAgg): + return list_agg( + self.analyze(expr.col, expr_to_alias, parse_local_name), + str_to_sql(expr.delimiter), + expr.is_distinct, + ) + + if isinstance(expr, RankRelatedFunctionExpression): + return rank_related_function_expression( + expr.sql, + self.analyze(expr.expr, expr_to_alias, parse_local_name), + expr.offset, + self.analyze(expr.default, expr_to_alias, parse_local_name) + if expr.default + else None, + expr.ignore_nulls, + ) + + raise SnowparkClientExceptionMessages.PLAN_INVALID_TYPE( + str(expr) + ) # pragma: no cover + + def table_function_expression_extractor( + self, + expr: TableFunctionExpression, + expr_to_alias: Dict[str, str], + parse_local_name=False, + ) -> str: + if isinstance(expr, FlattenFunction): + return flatten_expression( + self.analyze(expr.input, expr_to_alias, parse_local_name), + expr.path, + expr.outer, + expr.recursive, + expr.mode, + ) + elif isinstance(expr, PosArgumentsTableFunction): + sql = function_expression( + expr.func_name, + [self.analyze(x, expr_to_alias, parse_local_name) for x in expr.args], + False, + ) + elif isinstance(expr, (NamedArgumentsTableFunction, GeneratorTableFunction)): + sql = named_arguments_function( + expr.func_name, + { + key: self.analyze(value, expr_to_alias, parse_local_name) + for key, value in expr.args.items() + }, + ) + else: # pragma: no cover + raise TypeError( + "A table function expression should be any of PosArgumentsTableFunction, " + "NamedArgumentsTableFunction, GeneratorTableFunction, or FlattenFunction." 
+ ) + partition_spec_sql = ( + self.analyze(expr.partition_spec, expr_to_alias) + if expr.partition_spec + else "" + ) + return f"{sql} {partition_spec_sql}" + + def unary_expression_extractor( + self, + expr: UnaryExpression, + expr_to_alias: Dict[str, str], + parse_local_name=False, + keep_alias=True, + ) -> str: + if isinstance(expr, Alias): + quoted_name = quote_name(expr.name) + if isinstance(expr.child, Attribute): + expr_to_alias[expr.child.expr_id] = quoted_name + for k, v in expr_to_alias.items(): + if v == expr.child.name: + expr_to_alias[k] = quoted_name + alias_exp = alias_expression( + self.analyze(expr.child, expr_to_alias, parse_local_name), quoted_name + ) + + expr_str = alias_exp if keep_alias else expr.name or keep_alias + expr_str = expr_str.upper() if parse_local_name else expr_str + return expr_str + if isinstance(expr, UnresolvedAlias): + expr_str = self.analyze(expr.child, expr_to_alias, parse_local_name) + if parse_local_name: + expr_str = expr_str.upper() + return expr_str + elif isinstance(expr, Cast): + return cast_expression( + self.analyze(expr.child, expr_to_alias, parse_local_name), + expr.to, + expr.try_, + ) + else: + return unary_expression( + self.analyze(expr.child, expr_to_alias, parse_local_name), + expr.sql_operator, + expr.operator_first, + ) + + def binary_operator_extractor( + self, + expr: BinaryExpression, + expr_to_alias: Dict[str, str], + parse_local_name=False, + escape_column_name=False, + ) -> str: + operator = expr.sql_operator.lower() + if isinstance(expr, BinaryArithmeticExpression): + return binary_arithmetic_expression( + operator, + self.analyze( + expr.left, + expr_to_alias, + parse_local_name, + escape_column_name=escape_column_name, + ), + self.analyze( + expr.right, + expr_to_alias, + parse_local_name, + escape_column_name=escape_column_name, + ), + ) + else: + return function_expression( + operator, + [ + self.analyze( + expr.left, + expr_to_alias, + parse_local_name, + escape_column_name=escape_column_name, + ), + self.analyze( + expr.right, + expr_to_alias, + parse_local_name, + escape_column_name=escape_column_name, + ), + ], + False, + ) + + def grouping_extractor( + self, expr: GroupingSet, expr_to_alias: Dict[str, str] + ) -> str: + return self.analyze( + FunctionExpression( + expr.pretty_name.upper(), + [c.child if isinstance(c, Alias) else c for c in expr.children], + False, + ), + expr_to_alias, + ) + + def window_frame_boundary(self, offset: str) -> str: + try: + num = int(offset) + return window_frame_boundary_expression(str(abs(num)), num >= 0) + except Exception: + return offset + + def to_sql_avoid_offset( + self, expr: Expression, expr_to_alias: Dict[str, str], parse_local_name=False + ) -> str: + # if expression is a numeric literal, return the number without casting, + # otherwise process as normal + if isinstance(expr, Literal) and isinstance(expr.datatype, _NumericType): + return to_sql_without_cast(expr.value, expr.datatype) + else: + return self.analyze(expr, expr_to_alias, parse_local_name) + + def resolve( + self, logical_plan: LogicalPlan, expr_to_alias: Optional[Dict[str, str]] = None + ) -> MockExecutionPlan: + self.subquery_plans = [] + if expr_to_alias is None: + expr_to_alias = {} + result = self.do_resolve(logical_plan, expr_to_alias) + + if self.subquery_plans: + result = result.with_subqueries(self.subquery_plans) + + return result + + def do_resolve( + self, logical_plan: LogicalPlan, expr_to_alias: Dict[str, str] + ) -> MockExecutionPlan: + resolved_children = {} + expr_to_alias_maps = {} + for c 
in logical_plan.children: + _expr_to_alias = {} + resolved_children[c] = self.resolve(c, _expr_to_alias) + expr_to_alias_maps[c] = _expr_to_alias + + # get counts of expr_to_alias keys + counts = Counter() + for v in expr_to_alias_maps.values(): + counts.update(list(v.keys())) + + # Keep only non-shared expr_to_alias keys + # let (df1.join(df2)).join(df2.join(df3)).select(df2) report error + for v in expr_to_alias_maps.values(): + expr_to_alias.update({p: q for p, q in v.items() if counts[p] < 2}) + + return self.do_resolve_with_resolved_children( + logical_plan, resolved_children, expr_to_alias + ) + + def do_resolve_with_resolved_children( + self, + logical_plan: LogicalPlan, + resolved_children: Dict[LogicalPlan, SnowflakePlan], + expr_to_alias: Dict[str, str], + ) -> MockExecutionPlan: + if isinstance(logical_plan, MockExecutionPlan): + return logical_plan + if isinstance(logical_plan, TableFunctionJoin): + raise NotImplementedError( + "[Local Testing] Table function is currently not supported." + ) + + if isinstance(logical_plan, TableFunctionRelation): + raise NotImplementedError( + "[Local Testing] table function is not implemented." + ) + + if isinstance(logical_plan, Lateral): + raise NotImplementedError("[Local Testing] Lateral is not implemented.") + + if isinstance(logical_plan, Aggregate): + return MockExecutionPlan( + logical_plan, + self.session, + ) + + if isinstance(logical_plan, Project): + return logical_plan + + if isinstance(logical_plan, Filter): + return logical_plan + + # Add a sample stop to the plan being built + if isinstance(logical_plan, Sample): + return MockExecutionPlan(logical_plan, self.session) + + if isinstance(logical_plan, Join): + return MockExecutionPlan(logical_plan, self.session) + + if isinstance(logical_plan, Sort): + return self.plan_builder.sort( + list(map(self.analyze, logical_plan.order)), + resolved_children[logical_plan.child], + logical_plan, + ) + + if isinstance(logical_plan, SetOperation): + return self.plan_builder.set_operator( + resolved_children[logical_plan.left], + resolved_children[logical_plan.right], + logical_plan.sql, + logical_plan, + ) + + if isinstance(logical_plan, Range): + # schema of Range. Since this corresponds to the Snowflake column "id" + # (quoted lower-case) it's a little hard for users. So we switch it to + # the column name "ID" == id == Id + return self.plan_builder.query( + range_statement( + logical_plan.start, logical_plan.end, logical_plan.step, "id" + ), + logical_plan, + ) + + if isinstance(logical_plan, SnowflakeValues): + return MockExecutionPlan(logical_plan, self.session) + + if isinstance(logical_plan, UnresolvedRelation): + return MockExecutionPlan(logical_plan, self.session) + + if isinstance(logical_plan, SnowflakeCreateTable): + return MockExecutionPlan(logical_plan, self.session) + + if isinstance(logical_plan, Limit): + on_top_of_order_by = isinstance( + logical_plan.child, SnowflakePlan + ) and isinstance(logical_plan.child.source_plan, Sort) + return self.plan_builder.limit( + self.to_sql_avoid_offset(logical_plan.limit_expr, expr_to_alias), + self.to_sql_avoid_offset(logical_plan.offset_expr, expr_to_alias), + resolved_children[logical_plan.child], + on_top_of_order_by, + logical_plan, + ) + + if isinstance(logical_plan, Pivot): + raise NotImplementedError("[Local Testing] Pivot is not implemented.") + + if isinstance(logical_plan, Unpivot): + raise NotImplementedError( + "[Local Testing] DataFrame.unpivot is not currently supported." 
+ ) + + if isinstance(logical_plan, CreateViewCommand): + return MockExecutionPlan(logical_plan, self.session) + + if isinstance(logical_plan, CopyIntoTableNode): + raise NotImplementedError( + "[Local Testing] Copy into table is currently not supported." + ) + + if isinstance(logical_plan, CopyIntoLocationNode): + return self.plan_builder.copy_into_location( + query=resolved_children[logical_plan.child], + stage_location=logical_plan.stage_location, + partition_by=self.analyze(logical_plan.partition_by, expr_to_alias) + if logical_plan.partition_by + else None, + file_format_name=logical_plan.file_format_name, + file_format_type=logical_plan.file_format_type, + format_type_options=logical_plan.format_type_options, + header=logical_plan.header, + **logical_plan.copy_options, + ) + + if isinstance(logical_plan, TableUpdate): + raise NotImplementedError( + "[Local Testing] Table update is not implemented." + ) + + if isinstance(logical_plan, TableDelete): + raise NotImplementedError( + "[Local Testing] Table delete is not implemented." + ) + + if isinstance(logical_plan, CreateDynamicTableCommand): + raise NotImplementedError( + "[Local Testing] Dynamic tables are currently not supported." + ) + + if isinstance(logical_plan, TableMerge): + raise NotImplementedError( + "[Local Testing] Table merge is currently not implemented." + ) + + if isinstance(logical_plan, MockSelectable): + return MockExecutionPlan(logical_plan, self.session) + + def create_select_statement(self, *args, **kwargs): + return MockSelectStatement(*args, **kwargs) + + def create_select_snowflake_plan(self, *args, **kwargs): + return MockSelectExecutionPlan(*args, **kwargs) + + def create_selectable_entity(self, *args, **kwargs): + return MockSelectableEntity(*args, **kwargs) diff --git a/src/snowflake/snowpark/mock/connection.py b/src/snowflake/snowpark/mock/connection.py new file mode 100644 index 00000000000..dd777b3c774 --- /dev/null +++ b/src/snowflake/snowpark/mock/connection.py @@ -0,0 +1,629 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +import functools +import json +import os +import sys +import time +from copy import copy +from decimal import Decimal +from logging import getLogger +from typing import IO, Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union +from unittest.mock import Mock + +import snowflake.connector +from snowflake.connector.cursor import ResultMetadata, SnowflakeCursor +from snowflake.connector.errors import NotSupportedError, ProgrammingError +from snowflake.connector.network import ReauthenticationRequest +from snowflake.connector.options import pandas +from snowflake.snowpark._internal.analyzer.analyzer_utils import ( + escape_quotes, + quote_name, + quote_name_without_upper_casing, + unquote_if_quoted, +) +from snowflake.snowpark._internal.analyzer.expression import Attribute +from snowflake.snowpark._internal.analyzer.snowflake_plan import ( + BatchInsertQuery, + SnowflakePlan, +) +from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SaveMode +from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages +from snowflake.snowpark._internal.utils import ( + is_in_stored_procedure, + normalize_local_file, + parse_table_name, + result_set_to_rows, + unwrap_stage_location_single_quote, +) +from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType +from snowflake.snowpark.exceptions import SnowparkSQLException +from snowflake.snowpark.mock.plan import MockExecutionPlan, execute_mock_plan +from snowflake.snowpark.mock.snowflake_data_type import TableEmulator +from snowflake.snowpark.row import Row +from snowflake.snowpark.types import ( + ArrayType, + DecimalType, + MapType, + VariantType, + _IntegralType, +) + +logger = getLogger(__name__) + +# set `paramstyle` to qmark for batch insertion +snowflake.connector.paramstyle = "qmark" + +# parameters needed for usage tracking +PARAM_APPLICATION = "application" +PARAM_INTERNAL_APPLICATION_NAME = "internal_application_name" +PARAM_INTERNAL_APPLICATION_VERSION = "internal_application_version" + +# The module variable _CUSTOM_JSON_ENCODER is used to customize JSONEncoder when dumping json object +# into string, to use it, set: +# snowflake.snowpark.mock.connection._CUSTOM_JSON_ENCODER = +_CUSTOM_JSON_ENCODER = None + + +def _build_put_statement(*args, **kwargs): + raise NotImplementedError() + + +def _build_target_path(stage_location: str, dest_prefix: str = "") -> str: + qualified_stage_name = unwrap_stage_location_single_quote(stage_location) + dest_prefix_name = ( + dest_prefix + if not dest_prefix or dest_prefix.startswith("/") + else f"/{dest_prefix}" + ) + return f"{qualified_stage_name}{dest_prefix_name if dest_prefix_name else ''}" + + +class MockServerConnection: + class TabularEntityRegistry: + # Registry to store tables and views. 
+ def __init__(self, conn: "MockServerConnection") -> None: + self.table_registry = {} + self.view_registry = {} + self.conn = conn + + def get_fully_qualified_name(self, name: Union[str, Iterable[str]]) -> str: + current_schema = self.conn._get_current_parameter("schema") + current_database = self.conn._get_current_parameter("database") + if isinstance(name, str): + name = parse_table_name(name) + if len(name) == 1: + name = [current_schema] + name + if len(name) == 2: + name = [current_database] + name + return ".".join(quote_name(n) for n in name) + + def is_existing_table(self, name: Union[str, Iterable[str]]) -> bool: + qualified_name = self.get_fully_qualified_name(name) + return qualified_name in self.table_registry + + def is_existing_view(self, name: Union[str, Iterable[str]]) -> bool: + qualified_name = self.get_fully_qualified_name(name) + return qualified_name in self.view_registry + + def read_table(self, name: Union[str, Iterable[str]]) -> TableEmulator: + qualified_name = self.get_fully_qualified_name(name) + if qualified_name in self.table_registry: + return copy(self.table_registry[qualified_name]) + else: + raise SnowparkSQLException( + f"Object '{name}' does not exist or not authorized." + ) + + def write_table( + self, name: Union[str, Iterable[str]], table: TableEmulator, mode: SaveMode + ) -> Row: + name = self.get_fully_qualified_name(name) + table = copy(table) + if mode == SaveMode.APPEND: + # Fix append by index + if name in self.table_registry: + target_table = self.table_registry[name] + table.columns = target_table.columns + self.table_registry[name] = pandas.concat( + [target_table, table], ignore_index=True + ) + self.table_registry[name].sf_types = target_table.sf_types + else: + self.table_registry[name] = table + elif mode == SaveMode.IGNORE: + if name not in self.table_registry: + self.table_registry[name] = table + elif mode == SaveMode.OVERWRITE: + self.table_registry[name] = table + elif mode == SaveMode.ERROR_IF_EXISTS: + if name in self.table_registry: + raise SnowparkSQLException(f"Table {name} already exists") + else: + self.table_registry[name] = table + else: + raise ProgrammingError(f"Unrecognized mode: {mode}") + return [ + Row(status=f"Table {name} successfully created.") + ] # TODO: match message + + def drop_table(self, name: Union[str, Iterable[str]]) -> None: + name = self.get_fully_qualified_name(name) + if name in self.table_registry: + self.table_registry.pop(name) + + def create_or_replace_view( + self, execution_plan: MockExecutionPlan, name: Union[str, Iterable[str]] + ): + name = self.get_fully_qualified_name(name) + self.view_registry[name] = execution_plan + + def get_review(self, name: Union[str, Iterable[str]]) -> MockExecutionPlan: + name = self.get_fully_qualified_name(name) + if name in self.view_registry: + return self.view_registry[name] + raise SnowparkSQLException(f"View {name} does not exist") + + class _Decorator: + @classmethod + def wrap_exception(cls, func): + def wrap(*args, **kwargs): + try: + return func(*args, **kwargs) + except ReauthenticationRequest as ex: + raise SnowparkClientExceptionMessages.SERVER_SESSION_EXPIRED( + ex.cause + ) + except Exception as ex: + raise ex + + return wrap + + @classmethod + def log_msg_and_perf_telemetry(cls, msg): + def log_and_telemetry(func): + @functools.wraps(func) + def wrap(*args, **kwargs): + logger.debug(msg) + start_time = time.perf_counter() + result = func(*args, **kwargs) + end_time = time.perf_counter() + duration = end_time - start_time + sfqid = result["sfqid"] if 
result and "sfqid" in result else None + # If we don't have a query id, then its pretty useless to send perf telemetry + if sfqid: + args[0]._telemetry_client.send_upload_file_perf_telemetry( + func.__name__, duration, sfqid + ) + logger.debug(f"Finished in {duration:.4f} secs") + return result + + return wrap + + return log_and_telemetry + + def __init__(self) -> None: + self._conn = Mock() + self._cursor = Mock() + self.remove_query_listener = Mock() + self.add_query_listener = Mock() + self._telemetry_client = Mock() + self.entity_registry = MockServerConnection.TabularEntityRegistry(self) + self._conn._session_parameters = { + "ENABLE_ASYNC_QUERY_IN_PYTHON_STORED_PROCS": False, + "_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING": True, + "_PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING": True, + } + + def _get_client_side_session_parameter(self, name: str, default_value: Any) -> Any: + # mock implementation + return ( + self._conn._session_parameters.get(name, default_value) + if self._conn._session_parameters + else default_value + ) + + def get_session_id(self) -> int: + return 1 + + def close(self) -> None: + if self._conn: + self._conn.close() + + def is_closed(self) -> bool: + return self._conn.is_closed() + + @_Decorator.wrap_exception + def _get_current_parameter(self, param: str, quoted: bool = True) -> Optional[str]: + name = getattr(self._conn, param) or self._get_string_datum( + f"SELECT CURRENT_{param.upper()}()" + ) + if param == "database": + return '"mock_database"' + if param == "schema": + return '"mock_schema"' + if param == "warehouse": + return '"mock_warehouse"' + return ( + (quote_name_without_upper_casing(name) if quoted else escape_quotes(name)) + if name + else None + ) + + def _get_string_datum(self, query: str) -> Optional[str]: + rows = result_set_to_rows(self.run_query(query)["data"]) + return rows[0][0] if len(rows) > 0 else None + + # @SnowflakePlan.Decorator.wrap_exception + # def get_result_attributes(self, query: str) -> List[Attribute]: + # return convert_result_meta_to_attribute(self._cursor.describe(query)) + + @_Decorator.log_msg_and_perf_telemetry("Uploading file to stage") + def upload_file( + self, + path: str, + stage_location: str, + dest_prefix: str = "", + parallel: int = 4, + compress_data: bool = True, + source_compression: str = "AUTO_DETECT", + overwrite: bool = False, + ) -> Optional[Dict[str, Any]]: + if is_in_stored_procedure(): # pragma: no cover + file_name = os.path.basename(path) + target_path = _build_target_path(stage_location, dest_prefix) + try: + # upload_stream directly consume stage path, so we don't need to normalize it + self._cursor.upload_stream( + open(path, "rb"), f"{target_path}/{file_name}" + ) + except ProgrammingError as pe: + tb = sys.exc_info()[2] + ne = SnowparkClientExceptionMessages.SQL_EXCEPTION_FROM_PROGRAMMING_ERROR( + pe + ) + raise ne.with_traceback(tb) from None + else: + uri = normalize_local_file(path) + return self.run_query( + _build_put_statement( + uri, + stage_location, + dest_prefix, + parallel, + compress_data, + source_compression, + overwrite, + ) + ) + + @_Decorator.log_msg_and_perf_telemetry("Uploading stream to stage") + def upload_stream( + self, + input_stream: IO[bytes], + stage_location: str, + dest_filename: str, + dest_prefix: str = "", + parallel: int = 4, + compress_data: bool = True, + source_compression: str = "AUTO_DETECT", + overwrite: bool = False, + is_in_udf: bool = False, + ) -> Optional[Dict[str, Any]]: + raise NotImplementedError( + "[Local Testing] PUT stream is currently not 
supported." + ) + + @_Decorator.wrap_exception + def run_query( + self, + query: str, + to_pandas: bool = False, + to_iter: bool = False, + is_ddl_on_temp_object: bool = False, + block: bool = True, + data_type: _AsyncResultType = _AsyncResultType.ROW, + async_job_plan: Optional[ + SnowflakePlan + ] = None, # this argument is currently only used by AsyncJob + **kwargs, + ) -> Union[Dict[str, Any], AsyncJob]: + raise NotImplementedError( + "[Local Testing] Running SQL queries is not supported." + ) + + def _to_data_or_iter( + self, + results_cursor: SnowflakeCursor, + to_pandas: bool = False, + to_iter: bool = False, + ) -> Dict[str, Any]: + if to_pandas: + try: + data_or_iter = ( + map( + functools.partial( + _fix_pandas_df_integer, results_cursor=results_cursor + ), + results_cursor.fetch_pandas_batches(), + ) + if to_iter + else _fix_pandas_df_integer( + results_cursor.fetch_pandas_all(), results_cursor + ) + ) + except NotSupportedError: + data_or_iter = ( + iter(results_cursor) if to_iter else results_cursor.fetchall() + ) + except KeyboardInterrupt: + raise + except BaseException as ex: + raise SnowparkClientExceptionMessages.SERVER_FAILED_FETCH_PANDAS( + str(ex) + ) + else: + data_or_iter = ( + iter(results_cursor) if to_iter else results_cursor.fetchall() + ) + + return {"data": data_or_iter, "sfqid": results_cursor.sfqid} + + def execute( + self, + plan: MockExecutionPlan, + to_pandas: bool = False, + to_iter: bool = False, + block: bool = True, + data_type: _AsyncResultType = _AsyncResultType.ROW, + case_sensitive: bool = True, + **kwargs, + ) -> Union[ + List[Row], "pandas.DataFrame", Iterator[Row], Iterator["pandas.DataFrame"] + ]: + if not case_sensitive: + raise NotImplementedError( + "[Local Testing] Case insensitive DataFrame.collect is currently not supported." + ) + if not block: + raise NotImplementedError( + "[Local Testing] Async jobs are currently not supported." + ) + + res = execute_mock_plan(plan) + if isinstance(res, TableEmulator): + # stringfy the variant type in the result df + for col in res.columns: + if isinstance( + res.sf_types[col].datatype, (ArrayType, MapType, VariantType) + ): + for row in range(len(res[col])): + if res[col][row] is not None: + res.loc[row, col] = json.dumps( + res[col][row], cls=_CUSTOM_JSON_ENCODER, indent=2 + ) + else: + # snowflake returns Python None instead of the str 'null' for DataType data + res.loc[row, col] = ( + "null" if row in res._null_rows_idxs_map[col] else None + ) + + # when setting output rows, snowpark python running against snowflake don't escape double quotes + # in column names. while in the local testing calculation, double quotes are preserved. 
+ # to align with snowflake behavior, we unquote name here + columns = [unquote_if_quoted(col_name) for col_name in res.columns] + rows = [] + # TODO: SNOW-976145, move to index based approach to store col type mapping + # for now we only use the index based approach in aggregation functions + if res.sf_types_by_col_index: + keys = sorted(res.sf_types_by_col_index.keys()) + sf_types = [res.sf_types_by_col_index[key] for key in keys] + else: + sf_types = list(res.sf_types.values()) + for pdr in res.itertuples(index=False, name=None): + row = Row( + *[ + Decimal(str(v)) + if isinstance(sf_types[i].datatype, DecimalType) + and v is not None + else v + for i, v in enumerate(pdr) + ] + ) + row._fields = columns + rows.append(row) + elif isinstance(res, list): + rows = res + + if to_pandas: + pandas_df = pandas.DataFrame() + for col_name in res.columns: + pandas_df[unquote_if_quoted(col_name)] = res[col_name].tolist() + rows = _fix_pandas_df_integer(res) + + # the following implementation is just to make DataFrame.to_pandas_batches API workable + # in snowflake, large data result are split into multiple data chunks + # and sent back to the client, thus it makes sense to have the generator + # however, local testing is designed for local testing + # we do not mock the splitting into data chunks behavior + rows = [rows] if to_iter else rows + + if to_iter: + return iter(rows) + + return rows + + @SnowflakePlan.Decorator.wrap_exception + def get_result_set( + self, + plan: SnowflakePlan, + to_pandas: bool = False, + to_iter: bool = False, + block: bool = True, + data_type: _AsyncResultType = _AsyncResultType.ROW, + **kwargs, + ) -> Tuple[ + Dict[ + str, + Union[ + List[Any], + "pandas.DataFrame", + SnowflakeCursor, + Iterator["pandas.DataFrame"], + str, + ], + ], + List[ResultMetadata], + ]: + action_id = plan.session._generate_new_action_id() + + result, result_meta = None, None + try: + placeholders = {} + is_batch_insert = False + for q in plan.queries: + if isinstance(q, BatchInsertQuery): + is_batch_insert = True + break + # since batch insert does not support async execution (? 
in the query), we handle it separately here + if len(plan.queries) > 1 and not block and not is_batch_insert: + final_query = f"""EXECUTE IMMEDIATE $$ +DECLARE + res resultset; +BEGIN + {";".join(q.sql for q in plan.queries[:-1])}; + res := ({plan.queries[-1].sql}); + return table(res); +END; +$$""" + # In multiple queries scenario, we are unable to get the query id of former query, so we replace + # place holder with fucntion last_query_id() here + for q in plan.queries: + final_query = final_query.replace( + f"'{q.query_id_place_holder}'", "LAST_QUERY_ID()" + ) + + result = self.run_query( + final_query, + to_pandas, + to_iter, + is_ddl_on_temp_object=plan.queries[0].is_ddl_on_temp_object, + block=block, + data_type=data_type, + async_job_plan=plan, + **kwargs, + ) + + # since we will return a AsyncJob instance, result_meta is not needed, we will create result_meta in + # AsyncJob instance when needed + result_meta = None + if action_id < plan.session._last_canceled_id: + raise SnowparkClientExceptionMessages.SERVER_QUERY_IS_CANCELLED() + else: + for i, query in enumerate(plan.queries): + if isinstance(query, BatchInsertQuery): + self.run_batch_insert(query.sql, query.rows, **kwargs) + else: + is_last = i == len(plan.queries) - 1 and not block + final_query = query.sql + for holder, id_ in placeholders.items(): + final_query = final_query.replace(holder, id_) + result = self.run_query( + final_query, + to_pandas, + to_iter and (i == len(plan.queries) - 1), + is_ddl_on_temp_object=query.is_ddl_on_temp_object, + block=not is_last, + data_type=data_type, + async_job_plan=plan, + **kwargs, + ) + placeholders[query.query_id_place_holder] = ( + result["sfqid"] if not is_last else result.query_id + ) + result_meta = self._cursor.description + if action_id < plan.session._last_canceled_id: + raise SnowparkClientExceptionMessages.SERVER_QUERY_IS_CANCELLED() + finally: + # delete created tmp object + if block: + for action in plan.post_actions: + self.run_query( + action.sql, + is_ddl_on_temp_object=action.is_ddl_on_temp_object, + block=block, + **kwargs, + ) + + if result is None: + raise SnowparkClientExceptionMessages.SQL_LAST_QUERY_RETURN_RESULTSET() + + return result, result_meta + + def get_result_and_metadata( + self, plan: SnowflakePlan, **kwargs + ) -> Tuple[List[Row], List[Attribute]]: + res = execute_mock_plan(plan) + attrs = [ + Attribute( + name=quote_name(column_name.strip()), + datatype=res[column_name].sf_type, + ) + for column_name in res.columns.tolist() + ] + + rows = [ + Row(*[res.iloc[i, j] for j in range(len(attrs))]) for i in range(len(res)) + ] + return rows, attrs + + def get_result_query_id(self, plan: SnowflakePlan, **kwargs) -> str: + # get the iterator such that the data is not fetched + result_set, _ = self.get_result_set(plan, to_iter=True, **kwargs) + return result_set["sfqid"] + + +def _fix_pandas_df_integer(table_res: TableEmulator) -> "pandas.DataFrame": + pd_df = pandas.DataFrame() + for col_name in table_res.columns: + col_sf_type = table_res.sf_types[col_name] + pd_df_col_name = unquote_if_quoted(col_name) + if ( + isinstance(col_sf_type.datatype, DecimalType) + and col_sf_type.datatype.precision is not None + and col_sf_type.datatype.scale == 0 + and not str(table_res[col_name].dtype).startswith("int") + ): + # if decimal is set to default 38, we auto-detect the dtype, see the following code + # df = session.create_dataframe( + # data=[[decimal.Decimal(1)]], + # schema=StructType([StructField("d", DecimalType())]) + # ) + # df.to_pandas() # the returned df is 
of dtype int8, instead of dtype int64 + if col_sf_type.datatype.precision == 38: + pd_df[pd_df_col_name] = pandas.to_numeric( + table_res[col_name], downcast="integer" + ) + continue + + # this is to mock the behavior that precision is explicitly set to non-default value 38 + # optimize pd.DataFrame dtype of integer to align the behavior with live connection + if col_sf_type.datatype.precision <= 2: + pd_df[pd_df_col_name] = table_res[col_name].astype("int8") + elif col_sf_type.datatype.precision <= 4: + pd_df[pd_df_col_name] = table_res[col_name].astype("int16") + elif col_sf_type.datatype.precision <= 8: + pd_df[pd_df_col_name] = table_res[col_name].astype("int32") + else: + pd_df[pd_df_col_name] = table_res[col_name].astype("int64") + elif isinstance(col_sf_type.datatype, _IntegralType): + pd_df[pd_df_col_name] = pandas.to_numeric( + table_res[col_name].tolist(), downcast="integer" + ) + else: + pd_df[pd_df_col_name] = table_res[col_name].tolist() + + return pd_df diff --git a/src/snowflake/snowpark/mock/file_operation.py b/src/snowflake/snowpark/mock/file_operation.py new file mode 100644 index 00000000000..f4b64b0950c --- /dev/null +++ b/src/snowflake/snowpark/mock/file_operation.py @@ -0,0 +1,190 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import os +from collections import defaultdict +from functools import partial +from logging import getLogger +from typing import TYPE_CHECKING, Dict, List + +from snowflake.connector.options import pandas as pd +from snowflake.snowpark._internal.analyzer.expression import Attribute +from snowflake.snowpark.exceptions import SnowparkSQLException +from snowflake.snowpark.mock.snowflake_data_type import ( + ColumnEmulator, + ColumnType, + TableEmulator, +) +from snowflake.snowpark.mock.snowflake_to_pandas_converter import CONVERT_MAP +from snowflake.snowpark.types import DecimalType, StringType + +if TYPE_CHECKING: + from snowflake.snowpark.mock.analyzer import MockAnalyzer + + +_logger = getLogger(__name__) + +# this map keeps the information for remote file to local file, +# key is the remote 'stage/file', value is the local file path +_FILE_STAGE_MAP = {} + +# this map keeps the information the files each stage stores +# key is the remote 'stage', value is the set of file names stored in the stage +_STAGE_FILE_MAP = defaultdict(set) + +PUT_RESULT_KEYS = [ + "source", + "target", + "source_size", + "target_size", + "source_compression", + "target_compression", + "status", + "message", +] + +SUPPORTED_CSV_READ_OPTIONS = ( + "SKIP_HEADER", + "SKIP_BLANK_LINES", + "FIELD_DELIMITER", + "FIELD_OPTIONALLY_ENCLOSED_BY", +) + + +def put( + local_file_name: str, stage_location: str, auto_compress: bool = True +) -> TableEmulator: + """ + Put a file into in memory map, key being stage location and value being the local file path + """ + local_file_name = local_file_name[ + len("`file://") : -1 + ] # skip normalized prefix `file:// and suffix ` + file_name = os.path.basename(local_file_name) + remote_file_path = f"{stage_location}/{file_name}" + _FILE_STAGE_MAP[remote_file_path] = local_file_name + _STAGE_FILE_MAP[f"{stage_location}/"].add(local_file_name) + file_size = os.path.getsize(os.path.expanduser(local_file_name)) + + result_df = TableEmulator( + columns=PUT_RESULT_KEYS, + sf_types={ + "source": ColumnType(StringType(), True), + "target": ColumnType(StringType(), True), + "source_size": ColumnType(DecimalType(10, 0), True), + "target_size": ColumnType(DecimalType(10, 0), True), + "source_compression": 
ColumnType(StringType(), True), + "target_compression": ColumnType(StringType(), True), + "status": ColumnType(StringType(), True), + "message": ColumnType(StringType(), True), + }, + dtype=object, + ) + result_df.loc[len(result_df)] = { + k: v + for k, v in zip( + PUT_RESULT_KEYS, + [file_name, file_name, file_size, file_size, None, None, None, None], + ) + } + return result_df + + +def read_file( + stage_location, + format: str, + schema: List[Attribute], + analyzer: "MockAnalyzer", + options: Dict[str, str], +) -> TableEmulator: + try: + local_files = [_FILE_STAGE_MAP[stage_location]] + except KeyError: + if stage_location and stage_location[0] != "@": + raise SnowparkSQLException("SQL compilation error") + if stage_location in _STAGE_FILE_MAP.keys(): + # all files within the stage + local_files = list(_STAGE_FILE_MAP[stage_location]) + else: + raise SnowparkSQLException( + f"[Local Testing] file {stage_location} can not be found, please use " + "`session.file.put` to put the file to a local stage first." + ) + if format.lower() == "csv": + for option in options: + if option not in SUPPORTED_CSV_READ_OPTIONS: + _logger.warning( + f"[Local Testing] read file option {option} is not supported." + ) + skip_header = options.get("SKIP_HEADER", 0) + skip_blank_lines = options.get("SKIP_BLANK_LINES", False) + field_delimiter = options.get("FIELD_DELIMITER", ",") + field_optionally_enclosed_by = options.get("FIELD_OPTIONALLY_ENCLOSED_BY", None) + if ( + field_delimiter[0] + and field_delimiter[-1] == "'" + and len(field_delimiter) >= 2 + ): + # extract the field_delimiter as field_delimiter is normalized to be single quoted + # e.g. field_delimiter="'.'", we should remove the normalized single quotes to extract the single char "." + field_delimiter = field_delimiter[1:-1] + + # construct the returning dataframe + result_df = TableEmulator() + result_df_sf_types = {} + converters_dict = {} + for i in range(len(schema)): + column_name = analyzer.analyze(schema[i]) + column_series = ColumnEmulator(data=None, dtype=object, name=column_name) + column_series.sf_type = ColumnType(schema[i].datatype, schema[i].nullable) + result_df[column_name] = column_series + result_df_sf_types[column_name] = column_series.sf_type + if type(column_series.sf_type.datatype) not in CONVERT_MAP: + _logger.warning( + f"[Local Testing] Reading snowflake data type {type(column_series.sf_type.datatype)} is not supported. It will be treated as a raw string in the dataframe." + ) + continue + converter = CONVERT_MAP[type(column_series.sf_type.datatype)] + converters_dict[i] = ( + partial( + converter, + datatype=column_series.sf_type.datatype, + field_optionally_enclosed_by=field_optionally_enclosed_by, + ) + if field_optionally_enclosed_by + else partial(converter, datatype=column_series.sf_type.datatype) + ) + + for local_file in local_files: + # pre-read to check columns number + df = pd.read_csv( + local_file, + header=None, + skiprows=skip_header, + skip_blank_lines=skip_blank_lines, + delimiter=field_delimiter, + ) + df.dtype = object + if len(df.columns) != len(schema): + raise SnowparkSQLException( + f"Number of columns in file ({len(df.columns)}) does not match that of" + f" the corresponding table ({len(schema)})." 
+ ) + + # read again with converters dict + df = pd.read_csv( + local_file, + header=None, + skiprows=skip_header, + skip_blank_lines=skip_blank_lines, + delimiter=field_delimiter, + dtype=object, + converters=converters_dict, + quoting=3, # QUOTE_NONE + ) + # set df columns to be result_df columns such that it can be concatenated + df.columns = result_df.columns + result_df = pd.concat([result_df, df], ignore_index=True) + result_df.sf_types = result_df_sf_types + return result_df + raise NotImplementedError(f"[Local Testing] File format {format} is not supported.") diff --git a/src/snowflake/snowpark/mock/functions.py b/src/snowflake/snowpark/mock/functions.py new file mode 100644 index 00000000000..13dfd595722 --- /dev/null +++ b/src/snowflake/snowpark/mock/functions.py @@ -0,0 +1,916 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import base64 +import binascii +import datetime +import json +import math +from decimal import Decimal +from functools import partial +from numbers import Real +from typing import Any, Callable, Optional, Union + +from snowflake.snowpark.exceptions import SnowparkSQLException +from snowflake.snowpark.mock.snowflake_data_type import ( + ColumnEmulator, + ColumnType, + TableEmulator, +) +from snowflake.snowpark.types import ( + ArrayType, + BinaryType, + BooleanType, + DateType, + DecimalType, + DoubleType, + LongType, + MapType, + NullType, + StringType, + TimestampType, + TimeType, + VariantType, + _NumericType, +) + +from .util import ( + convert_snowflake_datetime_format, + process_numeric_time, + process_string_time_with_fractional_seconds, +) + +RETURN_TYPE = Union[ColumnEmulator, TableEmulator] + +_MOCK_FUNCTION_IMPLEMENTATION_MAP = {} +# The module variable _CUSTOM_JSON_DECODER is used to custom JSONDecoder when decoing string, to use it, set: +# snowflake.snowpark.mock.functions._CUSTOM_JSON_DECODER = +_CUSTOM_JSON_DECODER = None + + +def _register_func_implementation( + snowpark_func: Union[str, Callable], func_implementation: Callable +): + try: + _MOCK_FUNCTION_IMPLEMENTATION_MAP[snowpark_func.__name__] = func_implementation + except AttributeError: + _MOCK_FUNCTION_IMPLEMENTATION_MAP[snowpark_func] = func_implementation + + +def _unregister_func_implementation(snowpark_func: Union[str, Callable]): + try: + try: + del _MOCK_FUNCTION_IMPLEMENTATION_MAP[snowpark_func.__name__] + except AttributeError: + del _MOCK_FUNCTION_IMPLEMENTATION_MAP[snowpark_func] + except KeyError: + pass + + +def patch(function): + def decorator(mocking_function): + _register_func_implementation(function, mocking_function) + + def wrapper(*args, **kwargs): + mocking_function(*args, **kwargs) + + return wrapper + + return decorator + + +@patch("min") +def mock_min(column: ColumnEmulator) -> ColumnEmulator: + if isinstance( + column.sf_type.datatype, _NumericType + ): # TODO: figure out where 5 is coming from + return ColumnEmulator(data=round(column.min(), 5), sf_type=column.sf_type) + res = ColumnEmulator(data=column.dropna().min(), sf_type=column.sf_type) + try: + if math.isnan(res[0]): + return ColumnEmulator(data=[None], sf_type=column.sf_type) + return ColumnEmulator(data=res, sf_type=column.sf_type) + except TypeError: # math.isnan throws TypeError if res[0] is not a number + return ColumnEmulator(data=res, sf_type=column.sf_type) + + +@patch("max") +def mock_max(column: ColumnEmulator) -> ColumnEmulator: + if isinstance(column.sf_type.datatype, _NumericType): + return ColumnEmulator(data=round(column.max(), 5), sf_type=column.sf_type) 
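+    # For non-numeric types the fallback below drops NULL/NaN entries before taking the
+    # max; if the result is still NaN (e.g. the column is entirely NULL), a single-row
+    # NULL column is returned instead of NaN.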
+ res = ColumnEmulator(data=column.dropna().max(), sf_type=column.sf_type) + try: + if math.isnan(res[0]): + return ColumnEmulator(data=[None], sf_type=column.sf_type) + return ColumnEmulator(data=res, sf_type=column.sf_type) + except TypeError: + return ColumnEmulator(data=res, sf_type=column.sf_type) + + +@patch("sum") +def mock_sum(column: ColumnEmulator) -> ColumnEmulator: + all_item_is_none = True + res = 0 + for data in column: + if data is not None: + try: + if math.isnan(data): + continue + except TypeError: + pass + all_item_is_none = False + try: + res += float(data) + except ValueError: + raise SnowparkSQLException(f"Numeric value '{data}' is not recognized.") + if isinstance(column.sf_type.datatype, DecimalType): + p, s = column.sf_type.datatype.precision, column.sf_type.datatype.scale + new_type = DecimalType(min(38, p + 12), s) + else: + new_type = column.sf_type.datatype + return ( + ColumnEmulator( + data=[res], sf_type=ColumnType(new_type, column.sf_type.nullable) + ) + if not all_item_is_none + else ColumnEmulator( + data=[None], sf_type=ColumnType(new_type, column.sf_type.nullable) + ) + ) + + +@patch("avg") +def mock_avg(column: ColumnEmulator) -> ColumnEmulator: + all_item_is_none = True + ret = 0 + cnt = 0 + for data in column: + if data is not None: + all_item_is_none = False + ret += float(data) + cnt += 1 + + ret = ( + ColumnEmulator(data=[round((ret / cnt), 5)]) + if not all_item_is_none + else ColumnEmulator(data=[None]) + ) + ret.sf_type = column.sf_type + return ret + + +@patch("count") +def mock_count(column: Union[TableEmulator, ColumnEmulator]) -> ColumnEmulator: + if isinstance(column, ColumnEmulator): + count_column = column.count() + return ColumnEmulator(data=count_column, sf_type=ColumnType(LongType(), False)) + else: # TableEmulator + return ColumnEmulator(data=len(column), sf_type=ColumnType(LongType(), False)) + + +@patch("count_distinct") +def mock_count_distinct(*cols: ColumnEmulator) -> ColumnEmulator: + """ + Snowflake does not count rows that contain NULL values, in the mocking implementation + we iterate over each row and then each col to check if there exists NULL value, if the col is NULL, + we do not count that row. 
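+
+    Illustrative example (not from the Snowflake docs): given two input columns
+    [1, 2, None] and ['a', 'b', 'c'], only the first two rows are counted, because the
+    third row contains a NULL in one of its columns.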
+ """ + dict_data = {} + for i in range(len(cols)): + dict_data[f"temp_col_{i}"] = cols[i] + rows = len(cols[0]) + temp_table = TableEmulator(dict_data, index=[i for i in range(len(cols[0]))]) + temp_table = temp_table.reset_index() + to_drop_index = set() + for col in cols: + for i in range(rows): + if col[col.index[i]] is None: + to_drop_index.add(i) + break + temp_table = temp_table.drop(index=list(to_drop_index)) + temp_table = temp_table.drop_duplicates(subset=list(dict_data.keys())) + count_column = temp_table.count() + if isinstance(count_column, ColumnEmulator): + count_column.sf_type = ColumnType(LongType(), False) + return ColumnEmulator( + data=round(count_column, 5), sf_type=ColumnType(LongType(), False) + ) + + +@patch("median") +def mock_median(column: ColumnEmulator) -> ColumnEmulator: + if isinstance(column.sf_type.datatype, DecimalType): + return_type = DecimalType( + column.sf_type.datatype.precision + 3, column.sf_type.datatype.scale + 3 + ) + else: + return_type = column.sf_type.datatype + return ColumnEmulator( + data=round(column.median(), 5), + sf_type=ColumnType(return_type, column.sf_type.nullable), + ) + + +@patch("covar_pop") +def mock_covar_pop(column1: ColumnEmulator, column2: ColumnEmulator) -> ColumnEmulator: + non_nan_cnt = 0 + x_sum, y_sum, x_times_y_sum = 0, 0, 0 + for x, y in zip(column1, column2): + if x is not None and y is not None and not math.isnan(x) and not math.isnan(y): + non_nan_cnt += 1 + x_times_y_sum += x * y + x_sum += x + y_sum += y + data = (x_times_y_sum - x_sum * y_sum / non_nan_cnt) / non_nan_cnt + return ColumnEmulator( + data=data, + sf_type=ColumnType( + DoubleType(), column1.sf_type.nullable or column2.sf_type.nullable + ), + ) + + +@patch("listagg") +def mock_listagg(column: ColumnEmulator, delimiter: str, is_distinct: bool): + columns_data = ColumnEmulator(column.unique()) if is_distinct else column + # nit todo: returns a string that includes all the non-NULL input values, separated by the delimiter. + return ColumnEmulator( + data=delimiter.join([str(v) for v in columns_data.dropna()]), + sf_type=ColumnType(StringType(16777216), column.sf_type.nullable), + ) + + +@patch("to_date") +def mock_to_date( + column: ColumnEmulator, + fmt: str = None, + try_cast: bool = False, +): + """ + Converts an input expression to a date: + + [x] For a string expression, the result of converting the string to a date. + + [x] For a timestamp expression, the date from the timestamp. + + For a variant expression: + + [x] If the variant contains a string, a string conversion is performed. + + [ ] If the variant contains a date, the date value is preserved as is. + + [ ] If the variant contains a JSON null value, the output is NULL. + + [x] For NULL input, the output is NULL. + + [ ] For all other values, a conversion error is generated. 
+ """ + res = [] + auto_detect = bool(not fmt) + + date_format, _, _ = convert_snowflake_datetime_format( + fmt, default_format="%Y-%m-%d" + ) + + for data in column: + if data is None: + res.append(None) + continue + try: + if auto_detect and data.isnumeric(): + res.append( + datetime.datetime.utcfromtimestamp( + process_numeric_time(data) + ).date() + ) + else: + res.append(datetime.datetime.strptime(data, date_format).date()) + except BaseException: + if try_cast: + res.append(None) + else: + raise + return ColumnEmulator( + data=res, sf_type=ColumnType(DateType(), column.sf_type.nullable) + ) + + +@patch("contains") +def mock_contains(expr1: ColumnEmulator, expr2: ColumnEmulator): + if isinstance(expr1, str) and isinstance(expr2, str): + return ColumnEmulator(data=[bool(str(expr2) in str(expr1))]) + if isinstance(expr1, ColumnEmulator) and isinstance(expr2, ColumnEmulator): + res = [bool(str(item2) in str(item1)) for item1, item2 in zip(expr1, expr2)] + elif isinstance(expr1, ColumnEmulator) and isinstance(expr2, str): + res = [bool(str(expr2) in str(item)) for item in expr1] + else: # expr1 is string, while expr2 is column + res = [bool(str(item) in str(expr1)) for item in expr2] + return ColumnEmulator( + data=res, sf_type=ColumnType(BooleanType(), expr1.sf_type.nullable) + ) + + +@patch("abs") +def mock_abs(expr): + if isinstance(expr, ColumnEmulator): + result = expr.abs() + result.sf_type = expr.sf_type + return result + else: + return abs(expr) + + +@patch("to_decimal") +def mock_to_decimal( + e: ColumnEmulator, + precision: Optional[int] = 38, + scale: Optional[int] = 0, + try_cast: bool = False, +): + """ + [x] For NULL input, the result is NULL. + + [ ] For fixed-point numbers: + + Numbers with different scales are converted by either adding zeros to the right (if the scale needs to be increased) or by reducing the number of fractional digits by rounding (if the scale needs to be decreased). + + Note that casts of fixed-point numbers to fixed-point numbers that increase scale might fail. + + [ ] For floating-point numbers: + + Numbers are converted if they are within the representable range, given the scale. + + The conversion between binary and decimal fractional numbers is not precise. This might result in loss of precision or out-of-range errors. + + Values of infinity and NaN (not-a-number) result in conversion errors. + + For floating-point input, omitting the mantissa or exponent is allowed and is interpreted as 0. Thus, E is parsed as 0. + + [ ] Strings are converted as decimal, integer, fractional, or floating-point numbers. + + [x] For fractional input, the precision is deduced as the number of digits after the point. + + For VARIANT input: + + [ ] If the variant contains a fixed-point or a floating-point numeric value, an appropriate numeric conversion is performed. + + [ ] If the variant contains a string, a string conversion is performed. + + [ ] If the variant contains a Boolean value, the result is 0 or 1 (for false and true, correspondingly). + + [ ] If the variant contains JSON null value, the output is NULL. 
+ """ + res = [] + + for data in e: + if data is None: + res.append(data) + continue + try: + try: + float(data) + except ValueError: + raise SnowparkSQLException(f"Numeric value '{data}' is not recognized.") + + integer_part = round(float(data)) + integer_part_str = str(integer_part) + len_integer_part = ( + len(integer_part_str) - 1 + if integer_part_str[0] == "-" + else len(integer_part_str) + ) + if len_integer_part > precision: + raise SnowparkSQLException(f"Numeric value '{data}' is out of range") + remaining_decimal_len = min(precision - len(str(integer_part)), scale) + res.append(Decimal(str(round(float(data), remaining_decimal_len)))) + except BaseException: + if try_cast: + res.append(None) + else: + raise + + return ColumnEmulator( + data=res, + sf_type=ColumnType(DecimalType(precision, scale), nullable=e.sf_type.nullable), + ) + + +@patch("to_time") +def mock_to_time( + column: ColumnEmulator, + fmt: Optional[str] = None, + try_cast: bool = False, +): + """ + [ ] For string_expr, the result of converting the string to a time. + + [ ] For timestamp_expr, the time portion of the input value. + + [ ] For 'integer' (a string containing an integer), the integer is treated as a number of seconds, milliseconds, microseconds, or nanoseconds after the start of the Unix epoch. See the Usage Notes below. + + [ ] For this timestamp, the function gets the number of seconds after the start of the Unix epoch. The function performs a modulo operation to get the remainder from dividing this number by the number of seconds in a day (86400): number_of_seconds % 86400 + + """ + res = [] + + auto_detect = bool(not fmt) + + time_fmt, hour_delta, fractional_seconds = convert_snowflake_datetime_format( + fmt, default_format="%H:%M:%S" + ) + for data in column: + try: + if data is None: + res.append(None) + continue + if auto_detect and data.isnumeric(): + res.append( + datetime.datetime.utcfromtimestamp( + process_numeric_time(data) + ).time() + ) + else: + # handle seconds fraction + data_parts = data.split(".") + if len(data_parts) == 2: + # there is a part of seconds + seconds_part = data_parts[1] + # find the idx that the seconds part ends + idx = 0 + while seconds_part[idx].isdigit(): + idx += 1 + # truncate to precision + seconds_part = ( + seconds_part[: min(idx, fractional_seconds)] + + seconds_part[idx:] + ) + data = f"{data_parts[0]}.{seconds_part}" + res.append( + ( + datetime.datetime.strptime( + process_string_time_with_fractional_seconds( + data, fractional_seconds + ), + time_fmt, + ) + + datetime.timedelta(hours=hour_delta) + ).time() + ) + except BaseException: + if try_cast: + data.append(None) + else: + raise + + return ColumnEmulator( + data=res, sf_type=ColumnType(TimeType(), column.sf_type.nullable) + ) + + +@patch("to_timestamp") +def mock_to_timestamp( + column: ColumnEmulator, + fmt: Optional[str] = None, + try_cast: bool = False, +): + """ + [x] For NULL input, the result will be NULL. + + [ ] For string_expr: timestamp represented by a given string. If the string does not have a time component, midnight will be used. + + [ ] For date_expr: timestamp representing midnight of a given day will be used, according to the specific timestamp flavor (NTZ/LTZ/TZ) semantics. + + [ ] For timestamp_expr: a timestamp with possibly different flavor than the source timestamp. + + [ ] For numeric_expr: a timestamp representing the number of seconds (or fractions of a second) provided by the user. Note, that UTC time is always used to build the result. 
+ + For variant_expr: + + [ ] If the variant contains JSON null value, the result will be NULL. + + [ ] If the variant contains a timestamp value of the same kind as the result, this value will be preserved as is. + + [ ] If the variant contains a timestamp value of the different kind, the conversion will be done in the same way as from timestamp_expr. + + [ ] If the variant contains a string, conversion from a string value will be performed (using automatic format). + + [ ] If the variant contains a number, conversion as if from numeric_expr will be performed. + + [ ] If conversion is not possible, an error is returned. + + If the format of the input parameter is a string that contains an integer: + + After the string is converted to an integer, the integer is treated as a number of seconds, milliseconds, microseconds, or nanoseconds after the start of the Unix epoch (1970-01-01 00:00:00.000000000 UTC). + + [ ] If the integer is less than 31536000000 (the number of milliseconds in a year), then the value is treated as a number of seconds. + + [ ] If the value is greater than or equal to 31536000000 and less than 31536000000000, then the value is treated as milliseconds. + + [ ] If the value is greater than or equal to 31536000000000 and less than 31536000000000000, then the value is treated as microseconds. + + [ ] If the value is greater than or equal to 31536000000000000, then the value is treated as nanoseconds. + """ + res = [] + auto_detect = bool(not fmt) + default_format = "%Y-%m-%d %H:%M:%S.%f" + ( + timestamp_format, + hour_delta, + fractional_seconds, + ) = convert_snowflake_datetime_format(fmt, default_format=default_format) + + for data in column: + try: + if data is None: + res.append(None) + continue + if auto_detect and ( + isinstance(data, int) or (isinstance(data, str) and data.isnumeric()) + ): + res.append( + datetime.datetime.utcfromtimestamp(process_numeric_time(data)) + ) + else: + # handle seconds fraction + try: + datetime_data = datetime.datetime.strptime( + process_string_time_with_fractional_seconds( + data, fractional_seconds + ), + timestamp_format, + ) + except ValueError: + # when creating df from pandas df, datetime doesn't come with microseconds + # leading to ValueError when using the default format + # but it's still a valid format to snowflake, so we use format code without microsecond to parse + if timestamp_format == default_format: + datetime_data = datetime.datetime.strptime( + process_string_time_with_fractional_seconds( + data, fractional_seconds + ), + "%Y-%m-%d %H:%M:%S", + ) + else: + raise + res.append(datetime_data + datetime.timedelta(hours=hour_delta)) + except BaseException: + if try_cast: + res.append(None) + else: + raise + + return ColumnEmulator( + data=res, + sf_type=ColumnType(TimestampType(), column.sf_type.nullable), + dtype=object, + ) + + +def try_convert(convert: Callable, try_cast: bool, val: Any): + if val is None: + return None + try: + return convert(val) + except BaseException: + if try_cast: + return None + else: + raise + + +@patch("to_char") +def mock_to_char( + column: ColumnEmulator, + fmt: Optional[str] = None, + try_cast: bool = False, +) -> ColumnEmulator: # TODO: support more input types + source_datatype = column.sf_type.datatype + + if isinstance(source_datatype, DateType): + date_format, _, _ = convert_snowflake_datetime_format( + fmt, default_format="%Y-%m-%d" + ) + func = partial( + try_convert, lambda x: datetime.datetime.strftime(x, date_format), try_cast + ) + elif isinstance(source_datatype, TimeType): + 
raise NotImplementedError( + "[Local Testing] Use TO_CHAR on Time data is not supported yet" + ) + elif isinstance(source_datatype, (DateType, TimeType, TimestampType)): + raise NotImplementedError( + "[Local Testing] Use TO_CHAR on Timestamp data is not supported yet" + ) + elif isinstance(source_datatype, _NumericType): + if fmt: + raise NotImplementedError( + "[Local Testing] Use format strings with Numeric types in TO_CHAR is not supported yet." + ) + func = partial(try_convert, lambda x: str(x), try_cast) + else: + func = partial(try_convert, lambda x: str(x), try_cast) + new_col = column.apply(func) + new_col.sf_type = ColumnType(StringType(), column.sf_type.nullable) + return new_col + + +@patch("to_double") +def mock_to_double( + column: ColumnEmulator, fmt: Optional[str] = None, try_cast: bool = False +) -> ColumnEmulator: + """ + [ ] Fixed-point numbers are converted to floating point; the conversion cannot fail, but might result in loss of precision. + + [ ] Strings are converted as decimal integer or fractional numbers, scientific notation and special values (nan, inf, infinity) are accepted. + + For VARIANT input: + + [ ] If the variant contains a fixed-point value, the numeric conversion will be performed. + + [ ] If the variant contains a floating-point value, the value will be preserved unchanged. + + [ ] If the variant contains a string, a string conversion will be performed. + + [ ] If the variant contains a Boolean value, the result will be 0 or 1 (for false and true, correspondingly). + + [ ] If the variant contains JSON null value, the output will be NULL. + + Note that conversion of decimal fractions to binary and back is not precise (i.e. printing of a floating-point number converted from decimal representation might produce a slightly diffe + """ + if fmt: + raise NotImplementedError( + "[Local Testing] Using format strings in to_double is not supported yet" + ) + if isinstance(column.sf_type.datatype, (_NumericType, StringType)): + res = column.apply(lambda x: try_convert(float, try_cast, x)) + res.sf_type = ColumnType(DoubleType(), column.sf_type.nullable) + return res + elif isinstance(column.sf_type.datatype, VariantType): + raise NotImplementedError("[Local Testing] Variant is not supported yet") + else: + raise NotImplementedError( + f"[Local Testing[ Invalid type {column.sf_type.datatype} for parameter 'TO_DOUBLE'" + ) + + +@patch("to_boolean") +def mock_to_boolean(column: ColumnEmulator, try_cast: bool = False) -> ColumnEmulator: + """ + [x] For a text expression, the string must be: + + 'true', 't', 'yes', 'y', 'on', '1' return TRUE. + + 'false', 'f', 'no', 'n', 'off', '0' return FALSE. + + All other strings return an error. + + Strings are case-insensitive. + + [x] For a numeric expression: + + 0 returns FALSE. + + All non-zero numeric values return TRUE. + + When converting from the FLOAT data type, non-numeric values, such as ‘NaN’ (not a number) and ‘INF’ (infinity), cause an error. 
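+
+    Illustrative example: for a string column, 'YES' and 'on' map to True while 'F' and '0'
+    map to False (matching is case-insensitive); any other string raises
+    "Boolean value ... is not recognized".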
+ + + """ + if isinstance(column.sf_type, StringType): + + def convert_str_to_bool(x: Optional[str]): + if x is None: + return None + elif x.lower() in ("true", "t", "yes", "y", "on", "1"): + return True + elif x.lower() in ("false", "f", "no", "n", "off", "0"): + return False + raise SnowparkSQLException(f"Boolean value {x} is not recognized") + + new_col = column.apply(lambda x: try_convert(convert_str_to_bool, try_cast, x)) + new_col.sf_type = ColumnType(BooleanType(), column.sf_type.nullable) + return new_col + elif isinstance(column.sf_type, _NumericType): + + def convert_num_to_bool(x: Optional[Real]): + if x is None: + return None + elif math.isnan(x) or math.isinf(x): + raise SnowparkSQLException( + f"Invalid value {x} for parameter 'TO_BOOLEAN'" + ) + else: + return x != 0 + + new_col = column.apply(lambda x: try_convert(convert_num_to_bool, try_cast, x)) + new_col.sf_type = ColumnType(BooleanType(), column.sf_type.nullable) + return new_col + else: + raise SnowparkSQLException( + f"Invalid type {column.sf_type.datatype} for parameter 'TO_BOOLEAN'" + ) + + +@patch("to_binary") +def mock_to_binary( + column: ColumnEmulator, fmt: str = None, try_cast: bool = False +) -> ColumnEmulator: + """ + [x] TO_BINARY( [, ''] ) + [ ] TO_BINARY( ) + """ + if isinstance(column.sf_type.datatype, (StringType, NullType)): + fmt = fmt.upper() if fmt else "HEX" + if fmt == "HEX": + res = column.apply(lambda x: try_convert(binascii.unhexlify, try_cast, x)) + elif fmt == "BASE64": + res = column.apply(lambda x: try_convert(base64.b64decode, try_cast, x)) + elif fmt == "UTF-8": + res = column.apply( + lambda x: try_convert(lambda y: y.encode("utf-8"), try_cast, x) + ) + else: + raise SnowparkSQLException(f"Invalid binary format {fmt}") + res.sf_type = ColumnType(BinaryType(), column.sf_type.nullable) + return res + else: + raise SnowparkSQLException( + f"Invalid type {column.sf_type.datatype} for parameter 'TO_BINARY'" + ) + + +@patch("iff") +def mock_iff(condition: ColumnEmulator, expr1: ColumnEmulator, expr2: ColumnEmulator): + assert isinstance(condition.sf_type.datatype, BooleanType) + if ( + all(condition) + or all(~condition) + or ( + isinstance(expr1.sf_type.datatype, StringType) + and isinstance(expr2.sf_type.datatype, StringType) + ) + or expr1.sf_type.datatype == expr2.sf_type.datatype + or isinstance(expr1.sf_type.datatype, NullType) + or isinstance(expr2.sf_type.datatype, NullType) + ): + res = ColumnEmulator(data=[None] * len(condition), dtype=object) + if isinstance(expr1.sf_type.datatype, StringType) and isinstance( + expr2.sf_type.datatype, StringType + ): + l1 = expr1.sf_type.datatype.length or StringType._MAX_LENGTH + l2 = expr2.sf_type.datatype.length or StringType._MAX_LENGTH + sf_data_type = StringType(max(l1, l2)) + else: + sf_data_type = ( + expr1.sf_type.datatype + if any(condition) and not isinstance(expr1.sf_type.datatype, NullType) + else expr2.sf_type.datatype + ) + nullability = expr1.sf_type.nullable and expr2.sf_type.nullable + res.sf_type = ColumnType(sf_data_type, nullability) + res.where(condition, other=expr2, inplace=True) + res.where([not x for x in condition], other=expr1, inplace=True) + return res + else: + raise SnowparkSQLException( + f"[Local Testing] does not support coercion currently, iff expr1 and expr2 have conflicting data types: {expr1.sf_type} != {expr2.sf_type}" + ) + + +@patch("coalesce") +def mock_coalesce(*exprs): + import pandas + + if len(exprs) < 2: + raise SnowparkSQLException( + f"not enough arguments for function [COALESCE], got {len(exprs)}, 
expected at least two" + ) + res = pandas.Series( + exprs[0] + ) # workaround because sf_type is not inherited properly + for expr in exprs: + res = res.combine_first(expr) + return ColumnEmulator(data=res, sf_type=exprs[0].sf_type, dtype=object) + + +@patch("substring") +def mock_substring( + base_expr: ColumnEmulator, start_expr: ColumnEmulator, length_expr: ColumnEmulator +): + res = [ + x[y - 1 : y + z - 1] if x is not None else None + for x, y, z in zip(base_expr, start_expr, length_expr) + ] + res = ColumnEmulator( + res, sf_type=ColumnType(StringType(), base_expr.sf_type.nullable), dtype=object + ) + return res + + +@patch("startswith") +def mock_startswith(expr1: ColumnEmulator, expr2: ColumnEmulator): + res = [x.startswith(y) if x is not None else None for x, y in zip(expr1, expr2)] + res = ColumnEmulator( + res, sf_type=ColumnType(BooleanType(), expr1.sf_type.nullable), dtype=bool + ) + return res + + +@patch("endswith") +def mock_endswith(expr1: ColumnEmulator, expr2: ColumnEmulator): + res = [x.endswith(y) if x is not None else None for x, y in zip(expr1, expr2)] + res = ColumnEmulator( + res, sf_type=ColumnType(BooleanType(), expr1.sf_type.nullable), dtype=bool + ) + return res + + +@patch("row_number") +def mock_row_number(window: TableEmulator, row_idx: int): + return ColumnEmulator(data=[row_idx + 1], sf_type=ColumnType(LongType(), False)) + + +@patch("upper") +def mock_upper(expr: ColumnEmulator): + res = expr.apply(lambda x: x.upper()) + res.sf_type = ColumnType(StringType(), expr.sf_type.nullable) + return res + + +@patch("parse_json") +def mock_parse_json(expr: ColumnEmulator): + if isinstance(expr.sf_type.datatype, StringType): + res = expr.apply( + lambda x: try_convert( + partial(json.loads, cls=_CUSTOM_JSON_DECODER), False, x + ) + ) + else: + res = expr.copy() + res.sf_type = ColumnType(VariantType(), expr.sf_type.nullable) + return res + + +@patch("to_array") +def mock_to_array(expr: ColumnEmulator): + """ + [x] If the input is an ARRAY, or VARIANT containing an array value, the result is unchanged. + + [ ] For NULL or (TODO:) a JSON null input, returns NULL. + + [x] For any other value, the result is a single-element array containing this value. + """ + if isinstance(expr.sf_type.datatype, ArrayType): + res = expr.copy() + elif isinstance(expr.sf_type.datatype, VariantType): + res = expr.apply( + lambda x: try_convert(lambda y: y if isinstance(y, list) else [y], False, x) + ) + else: + res = expr.apply(lambda x: try_convert(lambda y: [y], False, x)) + res.sf_type = ColumnType(ArrayType(), expr.sf_type.nullable) + return res + + +@patch("to_object") +def mock_to_object(expr: ColumnEmulator): + """ + [x] For a VARIANT value containing an OBJECT, returns the OBJECT. + + [ ] For NULL input, or for (TODO:) a VARIANT value containing only JSON null, returns NULL. + + [x] For an OBJECT, returns the OBJECT itself. + + [x] For all other input values, reports an error. 
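+
+    Illustrative example: a VARIANT column holding {'a': 1} is returned unchanged, whereas a
+    VARIANT value such as [1, 2] raises an "Invalid object of type ..." error.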
+ """ + if isinstance(expr.sf_type.datatype, (MapType,)): + res = expr.copy() + elif isinstance(expr.sf_type.datatype, VariantType): + + def raise_exc(val): + raise SnowparkSQLException( + f"Invalid object of type {type(val)} passed to 'TO_OBJECT'" + ) + + res = expr.apply( + lambda x: try_convert( + lambda y: y if isinstance(y, dict) else raise_exc(y), False, x + ) + ) + else: + + def raise_exc(): + raise SnowparkSQLException( + f"Invalid type {type(expr.sf_type.datatype)} parameter 'TO_OBJECT'" + ) + + res = expr.apply(lambda x: try_convert(raise_exc, False, x)) + res.sf_type = ColumnType(MapType(), expr.sf_type.nullable) + return res + + +@patch("to_variant") +def mock_to_variant(expr: ColumnEmulator): + res = expr.copy() + res.sf_type = ColumnType(VariantType(), expr.sf_type.nullable) + return res diff --git a/src/snowflake/snowpark/mock/pandas_util.py b/src/snowflake/snowpark/mock/pandas_util.py new file mode 100644 index 00000000000..349cff496a1 --- /dev/null +++ b/src/snowflake/snowpark/mock/pandas_util.py @@ -0,0 +1,221 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import math +from typing import TYPE_CHECKING, Any, List, Tuple + +from snowflake.connector.options import pandas as pd +from snowflake.snowpark._internal.analyzer.analyzer_utils import ( + quote_name_without_upper_casing, +) +from snowflake.snowpark._internal.type_utils import infer_type +from snowflake.snowpark.exceptions import SnowparkClientException +from snowflake.snowpark.table import Table +from snowflake.snowpark.types import ( + ArrayType, + BooleanType, + DecimalType, + DoubleType, + LongType, + MapType, + NullType, + StringType, + StructField, + StructType, + TimestampType, + VariantType, +) + +if TYPE_CHECKING: + from snowflake.snowpark import DataFrame, Session + + +def _extract_schema_and_data_from_pandas_df( + data: "pd.DataFrame", +) -> Tuple[StructType, List[List[Any]]]: + """ + infer column types from the pandas data + when running against snowflake, infer_schema (https://docs.snowflake.com/en/sql-reference/functions/infer_schema) + is used to infer schema. + + pandas type related doc: https://pandas.pydata.org/docs/user_guide/basics.html#dtypes + """ + import numpy + + # PANDAS_INTEGER_TYPES defined here to avoid module level referencing pandas lib + # as pandas is optional to snowpark-python + PANDAS_INTEGER_TYPES = ( + pd.Int8Dtype, + pd.Int16Dtype, + pd.Int32Dtype, + pd.Int64Dtype, + pd.UInt8Dtype, + pd.UInt16Dtype, + pd.UInt32Dtype, + pd.UInt64Dtype, + ) + + col_names = [ + quote_name_without_upper_casing(name) for name in data.columns.values.tolist() + ] + plain_data = [data.iloc[i].tolist() for i in range(data.shape[0])] + inferred_type_dict = ( + {} + ) # this map is to store types for columns in which data are of primitive python objects + for row_idx in range(data.shape[0]): + for col_idx in range(data.shape[1]): + if plain_data[row_idx][col_idx] is None: + continue + if isinstance(plain_data[row_idx][col_idx], (float, numpy.float_)): + # in pandas, a float is represented in type numpy.float64 + # which can not be inferred by snowpark python, we cast to built-in float type + if math.isnan(plain_data[row_idx][col_idx]): + # in snowflake, math.nan in a pandas DataFrame is treated as None + plain_data[row_idx][col_idx] = None + else: + # pandas PANDAS_INTEGER_TYPES (e.g. 
INT8Dtye) will also store data in the format of float64 + # here we use the col dtype info to convert data + plain_data[row_idx][col_idx] = ( + int(data.iloc[row_idx][col_idx]) + if isinstance(data.dtypes[col_idx], PANDAS_INTEGER_TYPES) + else float(str(data.iloc[row_idx][col_idx])) + ) + elif isinstance(plain_data[row_idx][col_idx], numpy.float32): + # convert str first and then to float to avoid precision drift as its stored in float32 format + plain_data[row_idx][col_idx] = float(str(plain_data[row_idx][col_idx])) + elif isinstance(plain_data[row_idx][col_idx], numpy.bool_): + plain_data[row_idx][col_idx] = bool(plain_data[row_idx][col_idx]) + elif isinstance( + plain_data[row_idx][col_idx], + (numpy.signedinteger, numpy.unsignedinteger), + ): + plain_data[row_idx][col_idx] = int(plain_data[row_idx][col_idx]) + elif isinstance(plain_data[row_idx][col_idx], pd.Timestamp): + if isinstance(data.dtypes[col_idx], pd.DatetimeTZDtype): + # this is to align with the current snowflake behavior that it + # apply the tz diff to time and then removes the tz information during ingestion + plain_data[row_idx][col_idx] = ( + plain_data[row_idx][col_idx] + .tz_convert("UTC") + .tz_localize(None) + .to_pydatetime() + ) + else: + # pandas.Timestamp.value gives nanoseconds + # snowpark will convert it to microseconds + plain_data[row_idx][col_idx] = int( + plain_data[row_idx][col_idx].value / 1000 + ) + elif isinstance(plain_data[row_idx][col_idx], pd.Timedelta): + # pandas.Timedetla.value gives nanoseconds + # snowflake keeps the unit of nanoarrow seconds + plain_data[row_idx][col_idx] = plain_data[row_idx][col_idx].value + elif isinstance(plain_data[row_idx][col_idx], pd.Interval): + + def convert_to_python_obj(obj): + if isinstance(obj, numpy.float_): + return float(obj) + elif isinstance(obj, numpy.int_): + return int(obj) + elif isinstance(obj, pd.Timestamp): + return int(obj.value / 1000) + else: + raise NotImplementedError( + f"[Local Testing] {type(obj)} within pandas.Interval is not supported." + ) + + plain_data[row_idx][col_idx] = { + "left": convert_to_python_obj(plain_data[row_idx][col_idx].left), + "right": convert_to_python_obj(plain_data[row_idx][col_idx].right), + } + elif isinstance(plain_data[row_idx][col_idx], str): + pass + elif isinstance(plain_data[row_idx][col_idx], pd.Period): + # snowflake returns the ordinal of a period object + plain_data[row_idx][col_idx] = plain_data[row_idx][col_idx].ordinal + else: + previous_inferred_type = inferred_type_dict.get(col_idx) + data_type = infer_type(plain_data[row_idx][col_idx]) + if isinstance(data_type, (MapType, ArrayType)): + # snowflake converts python dict/array to variant + data_type = VariantType() + if isinstance(data_type, DecimalType): + # we need to calculate the precision and scale + decimal_str = str(plain_data[row_idx][col_idx]) + decimal_parts = decimal_str.split(".") + integer_len = ( + len(decimal_str) + if len(decimal_parts) == 1 + else len(decimal_parts[0]) + ) + scale = 0 if len(decimal_parts) == 1 else len(decimal_parts[1]) + precision = integer_len + scale + if precision > 38: + raise SnowparkClientException( + f"[Local Testing] Column precision {precision} and scale {scale} are not supported." 
+ ) + # handle integer and float separately + data_type = DecimalType(precision=precision, scale=scale) + if previous_inferred_type: + if isinstance(previous_inferred_type, NullType): + inferred_type_dict[col_idx] = data_type + if type(data_type) != type(previous_inferred_type): + raise SnowparkClientException( + f"[Local Testing] Detected type {type(data_type)} and type {type(previous_inferred_type)}" + f" in column, coercion is not currently supported" + ) + if isinstance(inferred_type_dict[col_idx], DecimalType): + inferred_type_dict[col_idx] = DecimalType( + precision=max( + previous_inferred_type.precision, data_type.precision + ), + scale=max(previous_inferred_type.scale, data_type.scale), + ) + else: + inferred_type_dict[col_idx] = data_type + + fields = [] + for idx, pandas_type in enumerate(data.dtypes): + if isinstance(pandas_type, pd.IntervalDtype): + data_type = VariantType() + elif isinstance(pandas_type, pd.DatetimeTZDtype): + data_type = TimestampType() + elif pandas_type.type == numpy.float64: + data_type = DoubleType() + elif isinstance(pandas_type, (pd.Float32Dtype, pd.Float64Dtype)): + data_type = DoubleType() + elif ( + pandas_type.type == numpy.int64 + or pandas_type.type == numpy.datetime64 + or pandas_type.type == numpy.timedelta64 + ): + data_type = LongType() + elif isinstance(pandas_type, PANDAS_INTEGER_TYPES): + data_type = LongType() + elif isinstance(pandas_type, pd.PeriodDtype): + data_type = LongType() + elif pandas_type.type == numpy.bool_: + data_type = BooleanType() + else: + data_type = inferred_type_dict.get(idx, StringType(length=16777216)) + # snowpark write_pandas will ignore the nullability of pd dataframe and set nullable to True + struct_field = StructField(col_names[idx], datatype=data_type, nullable=True) + fields.append(struct_field) + + return StructType(fields=fields), plain_data + + +def _convert_dataframe_to_table( + data: "DataFrame", table_name: str, session: "Session" +) -> Table: + """ + used by create_dataframe from a pandas dataframe to convert a mocking dataframe into a table + """ + df_select_statement, df_plan = data._select_statement, data._plan + table = Table(table_name, session) + # the original _select_statement & plan of Table is query table name + # replace the table._select_statement & plan with the df mocking one + table._select_statement, table._plan = df_select_statement, df_plan + table.write.save_as_table(table_name) + return table diff --git a/src/snowflake/snowpark/mock/plan.py b/src/snowflake/snowpark/mock/plan.py new file mode 100644 index 00000000000..e66d7773ef7 --- /dev/null +++ b/src/snowflake/snowpark/mock/plan.py @@ -0,0 +1,1597 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
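+# Descriptive note: this module provides MockExecutionPlan and execute_mock_plan, which
+# evaluate Snowpark logical plans against in-memory, pandas-backed TableEmulator objects
+# so that DataFrame operations can run locally without a Snowflake connection.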
+# + +import importlib +import inspect +import math +import re +import typing +import uuid +from enum import Enum +from functools import cached_property, partial +from typing import TYPE_CHECKING, Dict, List, NoReturn, Optional, Union +from unittest.mock import MagicMock + +from snowflake.snowpark._internal.analyzer.window_expression import ( + FirstValue, + Lag, + LastValue, + Lead, + RangeFrame, + RowFrame, + SpecifiedWindowFrame, + UnboundedFollowing, + UnboundedPreceding, + WindowExpression, +) +from snowflake.snowpark.mock.window_utils import ( + EntireWindowIndexer, + RowFrameIndexer, + is_rank_related_window_function, +) + +if TYPE_CHECKING: + from snowflake.snowpark.mock.analyzer import MockAnalyzer + +import snowflake.snowpark.mock.file_operation as mock_file_operation +from snowflake.connector.options import pandas as pd +from snowflake.snowpark import Column, Row +from snowflake.snowpark._internal.analyzer.analyzer_utils import ( + EXCEPT, + INTERSECT, + UNION, + UNION_ALL, + quote_name, +) +from snowflake.snowpark._internal.analyzer.binary_expression import ( + Add, + And, + BinaryExpression, + Divide, + EqualNullSafe, + EqualTo, + GreaterThan, + GreaterThanOrEqual, + LessThan, + LessThanOrEqual, + Multiply, + NotEqualTo, + Or, + Pow, + Remainder, + Subtract, +) +from snowflake.snowpark._internal.analyzer.binary_plan_node import Join +from snowflake.snowpark._internal.analyzer.expression import ( + Attribute, + CaseWhen, + Expression, + FunctionExpression, + InExpression, + Like, + ListAgg, + Literal, + MultipleExpression, + RegExp, + ScalarSubquery, + Star, + SubfieldInt, + SubfieldString, + UnresolvedAttribute, +) +from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan +from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( + LogicalPlan, + Range, + SnowflakeCreateTable, + SnowflakeValues, + UnresolvedRelation, +) +from snowflake.snowpark._internal.analyzer.sort_expression import ( + Ascending, + NullsFirst, + SortOrder, +) +from snowflake.snowpark._internal.analyzer.unary_expression import ( + Alias, + Cast, + IsNaN, + IsNotNull, + IsNull, + Not, + UnresolvedAlias, +) +from snowflake.snowpark._internal.analyzer.unary_plan_node import ( + Aggregate, + CreateViewCommand, + Sample, +) +from snowflake.snowpark._internal.type_utils import infer_type +from snowflake.snowpark._internal.utils import parse_table_name +from snowflake.snowpark.exceptions import SnowparkSQLException +from snowflake.snowpark.mock.functions import _MOCK_FUNCTION_IMPLEMENTATION_MAP +from snowflake.snowpark.mock.select_statement import ( + MockSelectable, + MockSelectableEntity, + MockSelectExecutionPlan, + MockSelectStatement, + MockSetStatement, +) +from snowflake.snowpark.mock.snowflake_data_type import ( + ColumnEmulator, + ColumnType, + TableEmulator, +) +from snowflake.snowpark.mock.util import convert_wildcard_to_regex, custom_comparator +from snowflake.snowpark.types import ( + ArrayType, + BinaryType, + BooleanType, + ByteType, + DateType, + DecimalType, + DoubleType, + FloatType, + IntegerType, + LongType, + MapType, + NullType, + ShortType, + StringType, + TimestampType, + TimeType, + VariantType, + _NumericType, +) + + +class MockExecutionPlan(LogicalPlan): + def __init__( + self, + source_plan: LogicalPlan, + session, + *, + child: Optional["MockExecutionPlan"] = None, + expr_to_alias: Optional[Dict[uuid.UUID, str]] = None, + df_aliased_col_name_to_real_col_name: Optional[Dict[str, str]] = None, + ) -> NoReturn: + super().__init__() + self.source_plan = 
source_plan + self.session = session + mock_query = MagicMock() + mock_query.sql = "SELECT MOCK_TEST_FAKE_QUERY()" + self.queries = [mock_query] + self.child = child + self.expr_to_alias = expr_to_alias if expr_to_alias is not None else {} + self.df_aliased_col_name_to_real_col_name = ( + df_aliased_col_name_to_real_col_name or {} + ) + self.api_calls = [] + + @property + def attributes(self) -> List[Attribute]: + output = describe(self) + return output + + @cached_property + def output(self) -> List[Attribute]: + return [Attribute(a.name, a.datatype, a.nullable) for a in self.attributes] + + +class MockFileOperation(MockExecutionPlan): + class Operator(str, Enum): + PUT = "put" + READ_FILE = "read_file" + # others are not supported yet + + def __init__( + self, + session, + operator: Union[str, Operator], + *, + options: Dict[str, str], + local_file_name: Optional[str] = None, + stage_location: Optional[str] = None, + child: Optional["MockExecutionPlan"] = None, + source_plan: Optional[LogicalPlan] = None, + format: Optional[str] = None, + schema: Optional[List[Attribute]] = None, + ) -> None: + super().__init__(session=session, child=child, source_plan=source_plan) + self.operator = operator + self.local_file_name = local_file_name + self.stage_location = stage_location + self.api_calls = self.api_calls or [] + self.format = format + self.schema = schema + self.options = options + + +def handle_order_by_clause( + order_by: List[SortOrder], + result_df: TableEmulator, + analyzer: "MockAnalyzer", + expr_to_alias: Optional[Dict[str, str]], +) -> TableEmulator: + """Given an input dataframe `result_df` and a list of SortOrder expressions `order_by`, return the sorted dataframe.""" + sort_columns_array = [] + sort_orders_array = [] + null_first_last_array = [] + added_columns = [] + for exp in order_by: + exp_name = analyzer.analyze(exp.child, expr_to_alias) + if exp_name not in result_df.columns: + result_df[exp_name] = calculate_expression( + exp.child, result_df, analyzer, expr_to_alias + ) + added_columns.append(exp_name) + sort_columns_array.append(exp_name) + sort_orders_array.append(isinstance(exp.direction, Ascending)) + null_first_last_array.append( + isinstance(exp.null_ordering, NullsFirst) or exp.null_ordering == NullsFirst + ) + for column, ascending, null_first in reversed( + list(zip(sort_columns_array, sort_orders_array, null_first_last_array)) + ): + comparator = partial(custom_comparator, ascending, null_first) + result_df = result_df.sort_values(by=column, key=comparator) + result_df = result_df.drop(columns=added_columns) + return result_df + + +def handle_range_frame_indexing( + order_spec: List[SortOrder], + res_index: "pd.Index", + res: "pd.api.typing.DataFrameGroupBy", + analyzer: "MockAnalyzer", + expr_to_alias: Optional[Dict[str, str]], + unbounded_preceding: bool, + unbounded_following: bool, +) -> "pd.api.typing.RollingGroupby": + """Return a list of range between window frames based on the dataframe paritions `res` and the ORDER BY clause `order_spec`.""" + if order_spec: + windows = [] + for current_row, win in zip(res_index, res.rolling(EntireWindowIndexer())): + _win = handle_order_by_clause(order_spec, win, analyzer, expr_to_alias) + row_idx = list(_win.index).index(current_row) + start_idx = 0 if unbounded_preceding else row_idx + end_idx = len(_win) - 1 if unbounded_following else row_idx + + def search_boundary_idx(idx, delta, _win): + while 0 <= idx + delta < len(_win): + cur_expr = list( + calculate_expression( + exp.child, _win.iloc[idx], analyzer, 
expr_to_alias + ) + for exp in order_spec + ) + next_expr = list( + calculate_expression( + exp.child, _win.iloc[idx + delta], analyzer, expr_to_alias + ) + for exp in order_spec + ) + if not cur_expr == next_expr: + break + idx += delta + return idx + + start_idx = search_boundary_idx(start_idx, -1, _win) + end_idx = search_boundary_idx(end_idx, 1, _win) + windows.append(_win[start_idx : end_idx + 1]) + else: # If order by is not specified, just use the entire window + windows = res.rolling(EntireWindowIndexer()) + return windows + + +def execute_mock_plan( + plan: MockExecutionPlan, + expr_to_alias: Optional[Dict[str, str]] = None, +) -> Union[TableEmulator, List[Row]]: + import numpy as np + + if expr_to_alias is None: + expr_to_alias = {} + if isinstance(plan, (MockExecutionPlan, SnowflakePlan)): + source_plan = plan.source_plan + analyzer = plan.session._analyzer + else: + source_plan = plan + analyzer = plan.analyzer + + entity_registry = analyzer.session._conn.entity_registry + + if isinstance(source_plan, SnowflakeValues): + table = TableEmulator( + source_plan.data, + columns=[x.name for x in source_plan.output], + sf_types={ + x.name: ColumnType(x.datatype, x.nullable) for x in source_plan.output + }, + dtype=object, + ) + for column_name in table.columns: + sf_type = table.sf_types[column_name] + table[column_name].sf_type = table.sf_types[column_name] + if not isinstance(sf_type.datatype, _NumericType): + table[column_name].replace(np.nan, None, inplace=True) + return table + if isinstance(source_plan, MockSelectExecutionPlan): + return execute_mock_plan(source_plan.execution_plan, expr_to_alias) + if isinstance(source_plan, MockSelectStatement): + projection: Optional[List[Expression]] = source_plan.projection or [] + from_: Optional[MockSelectable] = source_plan.from_ + where: Optional[Expression] = source_plan.where + order_by: Optional[List[Expression]] = source_plan.order_by + limit_: Optional[int] = source_plan.limit_ + offset: Optional[int] = source_plan.offset + + from_df = execute_mock_plan(from_, expr_to_alias) + + result_df = TableEmulator() + + for exp in projection: + if isinstance(exp, Star): + for i in range(len(from_df.columns)): + result_df.insert(len(result_df.columns), str(i), from_df.iloc[:, i]) + result_df.columns = from_df.columns + result_df.sf_types = from_df.sf_types + result_df.sf_types_by_col_index = from_df.sf_types_by_col_index + elif ( + isinstance(exp, UnresolvedAlias) + and exp.child + and isinstance(exp.child, Star) + ): + for e in exp.child.expressions: + col_name = analyzer.analyze(e, expr_to_alias) + result_df[col_name] = calculate_expression( + e, from_df, analyzer, expr_to_alias + ) + else: + if isinstance(exp, Alias): + column_name = expr_to_alias.get(exp.expr_id, exp.name) + else: + column_name = analyzer.analyze( + exp, expr_to_alias, parse_local_name=True + ) + + column_series = calculate_expression( + exp, from_df, analyzer, expr_to_alias + ) + result_df[column_name] = column_series + + if isinstance(exp, (Alias)): + if isinstance(exp.child, Attribute): + quoted_name = quote_name(exp.name) + expr_to_alias[exp.child.expr_id] = quoted_name + for k, v in expr_to_alias.items(): + if v == exp.child.name: + expr_to_alias[k] = quoted_name + + if where: + condition = calculate_expression(where, result_df, analyzer, expr_to_alias) + result_df = result_df[condition] + + if order_by: + result_df = handle_order_by_clause( + order_by, result_df, analyzer, expr_to_alias + ) + + if limit_ is not None: + if offset is not None: + result_df = 
result_df.iloc[offset:] + result_df = result_df.head(n=limit_) + + return result_df + if isinstance(source_plan, MockSetStatement): + first_operand = source_plan.set_operands[0] + res_df = execute_mock_plan( + MockExecutionPlan( + first_operand.selectable, + source_plan.analyzer.session, + ), + expr_to_alias, + ) + for i in range(1, len(source_plan.set_operands)): + operand = source_plan.set_operands[i] + operator = operand.operator + cur_df = execute_mock_plan( + MockExecutionPlan(operand.selectable, source_plan.analyzer.session), + expr_to_alias, + ) + if len(res_df.columns) != len(cur_df.columns): + raise SnowparkSQLException( + f"SQL compilation error: invalid number of result columns for set operator input branches, expected {len(res_df.columns)}, got {len(cur_df.columns)} in branch {i + 1}" + ) + cur_df.columns = res_df.columns + if operator in (UNION, UNION_ALL): + res_df = pd.concat([res_df, cur_df], ignore_index=True) + res_df = ( + res_df.drop_duplicates().reset_index(drop=True) + if operator == UNION + else res_df + ) + res_df.sf_types = cur_df.sf_types + elif operator in (EXCEPT, INTERSECT): + # NaN == NaN evaluates to False in pandas, so we need to manually process rows that are all None/NaN + if ( + res_df.isnull().all(axis=1).where(lambda x: x).count() > 1 + ): # Dedup rows that are all None/NaN + res_df = res_df.drop(index=res_df.isnull().all(axis=1).index[1:]) + + any_null_rows_in_cur_df = cur_df.isnull().all(axis=1).any() + null_rows_in_res_df = res_df.isnull().all(axis=1) + if operator == INTERSECT: + res_df = res_df[ + (res_df.isin(cur_df.values.ravel()).all(axis=1)).values # IS IN + | ( + any_null_rows_in_cur_df & null_rows_in_res_df.values + ) # Rows that are all None/NaN in both sets + ] + elif operator == EXCEPT: + res_df = res_df[ + ~( + res_df.isin(cur_df.values.ravel()).all(axis=1) + ).values # NOT IS IN + | ( + ~any_null_rows_in_cur_df & null_rows_in_res_df.values + ) # Rows that are all None/NaN only in LEFT + ] + + # Compute drop duplicates + res_df = res_df.drop_duplicates() + else: + raise NotImplementedError( + f"[Local Testing] SetStatement operator {operator} is currently not implemented." + ) + return res_df + if isinstance(source_plan, MockSelectableEntity): + entity_name = source_plan.entity_name + if entity_registry.is_existing_table(entity_name): + return entity_registry.read_table(entity_name) + elif entity_registry.is_existing_view(entity_name): + execution_plan = entity_registry.get_review(entity_name) + res_df = execute_mock_plan(execution_plan) + return res_df + else: + db_schme_table = parse_table_name(entity_name) + raise SnowparkSQLException( + f"Object '{db_schme_table[0][1:-1]}.{db_schme_table[1][1:-1]}.{db_schme_table[2][1:-1]}' does not exist or not authorized." 
+ ) + if isinstance(source_plan, Aggregate): + child_rf = execute_mock_plan(source_plan.child) + if ( + not source_plan.aggregate_expressions + and not source_plan.grouping_expressions + ): + return ( + TableEmulator(child_rf.iloc[0].to_frame().T, sf_types=child_rf.sf_types) + if len(child_rf) + else TableEmulator( + data=None, + dtype=object, + columns=child_rf.columns, + sf_types=child_rf.sf_types, + ) + ) + aggregate_columns = [ + plan.session._analyzer.analyze(exp, keep_alias=False) + for exp in source_plan.aggregate_expressions + ] + intermediate_mapped_column = [ + f"" for i in range(len(aggregate_columns)) + ] + for i in range(len(intermediate_mapped_column)): + agg_expr = source_plan.aggregate_expressions[i] + if isinstance(agg_expr, Alias): + if isinstance(agg_expr.child, Literal) and isinstance( + agg_expr.child.datatype, _NumericType + ): + child_rf.insert( + len(child_rf.columns), + intermediate_mapped_column[i], + ColumnEmulator( + data=[agg_expr.child.value] * len(child_rf), + sf_type=ColumnType( + agg_expr.child.datatype, agg_expr.child.nullable + ), + ), + ) + elif isinstance( + agg_expr.child, (ListAgg, FunctionExpression, BinaryExpression) + ): + # function expression will be evaluated later + child_rf.insert( + len(child_rf.columns), + intermediate_mapped_column[i], + ColumnEmulator( + data=[None] * len(child_rf), + dtype=object, + sf_type=None, # it will be set later when evaluating the function. + ), + ) + else: + raise NotImplementedError( + f"[Local Testing] Aggregate expression {type(agg_expr.child).__name__} is not implemented." + ) + elif isinstance(agg_expr, (Attribute, UnresolvedAlias)): + column_name = plan.session._analyzer.analyze(agg_expr) + try: + child_rf.insert( + len(child_rf.columns), + intermediate_mapped_column[i], + child_rf[column_name], + ) + except KeyError: + raise SnowparkSQLException( + f"[Local Testing] invalid identifier {column_name}" + ) + else: + raise NotImplementedError( + f"[Local Testing] Aggregate expression {type(agg_expr).__name__} is not implemented." 
+ ) + + result_df_sf_Types = {} + result_df_sf_Types_by_col_idx = {} + + column_exps = [ + ( + plan.session._analyzer.analyze(exp), + bool(isinstance(exp, Literal)), + calculate_expression( + exp, child_rf, plan.session._analyzer, expr_to_alias + ).sf_type, + ) + for exp in source_plan.grouping_expressions + ] + for idx, (column_name, _, column_type) in enumerate(column_exps): + result_df_sf_Types[ + column_name + ] = column_type # TODO: fix this, this does not work + result_df_sf_Types_by_col_idx[idx] = column_type + # Aggregate may not have column_exps, which is allowed in the case of `Dataframe.agg`, in this case we pass + # lambda x: True as the `by` parameter + # also pandas group by takes None and nan as the same, so we use .astype to differentiate the two + by_column_expression = [] + try: + for exp in source_plan.grouping_expressions: + if isinstance(exp, Literal) and isinstance(exp.datatype, _NumericType): + col_name = f"" + by_column_expression.append(child_rf[col_name]) + else: + by_column_expression.append( + child_rf[plan.session._analyzer.analyze(exp)] + ) + except KeyError as e: + raise SnowparkSQLException( + f"This is not a valid group by expression due to exception {e!r}" + ) + + children_dfs = child_rf.groupby( + by=by_column_expression or (lambda x: True), sort=False, dropna=False + ) + # we first define the returning DataFrame with its column names + columns = [ + quote_name(plan.session._analyzer.analyze(exp, keep_alias=False)) + for exp in source_plan.aggregate_expressions + ] + intermediate_mapped_column = [str(i) for i in range(len(columns))] + result_df = TableEmulator(columns=intermediate_mapped_column, dtype=object) + data = [] + + def aggregate_by_groups(cur_group: TableEmulator): + values = [] + + if column_exps: + for idx, (expr, is_literal, _) in enumerate(column_exps): + if is_literal: + values.append(source_plan.grouping_expressions[idx].value) + elif not cur_group.empty: + values.append(cur_group.iloc[0][expr]) + + # the first len(column_exps) items of calculate_expression are the group_by column expressions, + # the remaining are the aggregation function expressions + for idx, exp in enumerate( + source_plan.aggregate_expressions[len(column_exps) :] + ): + cal_exp_res = calculate_expression( + exp, + cur_group, + plan.session._analyzer, + expr_to_alias, + ) + # and then append the calculated value + if isinstance(cal_exp_res, ColumnEmulator): + values.append(cal_exp_res.iat[0]) + result_df_sf_Types[ + columns[idx + len(column_exps)] + ] = result_df_sf_Types_by_col_idx[ + idx + len(column_exps) + ] = cal_exp_res.sf_type + else: + values.append(cal_exp_res) + result_df_sf_Types[ + columns[idx + len(column_exps)] + ] = result_df_sf_Types_by_col_idx[ + idx + len(column_exps) + ] = ColumnType( + infer_type(cal_exp_res), nullable=True + ) + data.append(values) + + if not children_dfs.indices: + aggregate_by_groups(child_rf) + else: + for _, indices in children_dfs.indices.items(): + # we construct row by row + cur_group = child_rf.iloc[indices] + # each row starts with group keys/column expressions, if there is no group keys/column expressions + # it means aggregation without group (Datagrame.agg) + aggregate_by_groups(cur_group) + + if len(data): + for col in range(len(data[0])): + series_data = ColumnEmulator( + data=[data[row][col] for row in range(len(data))], + dtype=object, + ) + result_df[intermediate_mapped_column[col]] = series_data + + result_df.sf_types = result_df_sf_Types + result_df.sf_types_by_col_index = result_df_sf_Types_by_col_idx + 
result_df.columns = columns + return result_df + if isinstance(source_plan, Range): + col = ColumnEmulator( + data=[ + num + for num in range( + source_plan.start, source_plan.end, int(source_plan.step) + ) + ], + sf_type=ColumnType(LongType(), False), + ) + result_df = TableEmulator( + col, + columns=['"ID"'], + sf_types={'"ID"': col.sf_type}, + dtype=object, + ) + return result_df + if isinstance(source_plan, Join): + L_expr_to_alias = {} + R_expr_to_alias = {} + left = execute_mock_plan(source_plan.left, L_expr_to_alias).reset_index( + drop=True + ) + right = execute_mock_plan(source_plan.right, R_expr_to_alias).reset_index( + drop=True + ) + # Processing ON clause + using_columns = getattr(source_plan.join_type, "using_columns", None) + on = using_columns + if isinstance(on, list): # USING a list of columns + if on: + on = [quote_name(x.upper()) for x in on] + else: + on = None + elif isinstance(on, Column): # ON a single column + on = on.name + elif isinstance( + on, BinaryExpression + ): # ON a condition, apply where to a Cartesian product + on = None + else: # ON clause not specified, SF returns a Cartesian product + on = None + + # Processing the join type + how = source_plan.join_type.sql + if how.startswith("USING "): + how = how[6:] + if how.startswith("NATURAL "): + how = how[8:] + if how == "LEFT OUTER": + how = "LEFT" + elif how == "RIGHT OUTER": + how = "RIGHT" + elif "FULL" in how: + how = "OUTER" + elif "SEMI" in how: + how = "INNER" + elif "ANTI" in how: + how = "CROSS" + + if ( + "NATURAL" in source_plan.join_type.sql and on is None + ): # natural joins use the list of common names as keys + on = left.columns.intersection(right.columns).values.tolist() + + if on is None: + how = "CROSS" + + result_df = left.merge( + right, + on=on, + how=how.lower(), + ) + + # Restore sf_types information after merging, there should be better way to do this + result_df.sf_types.update(left.sf_types) + result_df.sf_types.update(right.sf_types) + + if on: + result_df = result_df.reset_index(drop=True) + if isinstance(on, list): + # Reorder columns for JOINS with USING clause, where Snowflake puts the key columns to the left + reordered_cols = on + [ + col for col in result_df.columns.tolist() if col not in on + ] + result_df = result_df[reordered_cols] + + common_columns = set(L_expr_to_alias.keys()).intersection( + R_expr_to_alias.keys() + ) + new_expr_to_alias = { + k: v + for k, v in { + **L_expr_to_alias, + **R_expr_to_alias, + }.items() + if k not in common_columns + } + expr_to_alias.update(new_expr_to_alias) + + if source_plan.condition: + + def outer_join(base_df): + ret = base_df.apply(tuple, 1).isin( + result_df[condition][base_df.columns].apply(tuple, 1) + ) + ret.sf_type = ColumnType(BooleanType(), True) + return ret + + condition = calculate_expression( + source_plan.condition, result_df, analyzer, expr_to_alias + ) + sf_types = result_df.sf_types + if "SEMI" in source_plan.join_type.sql: # left semi + result_df = left[outer_join(left)].dropna() + elif "ANTI" in source_plan.join_type.sql: # left anti + result_df = left[~outer_join(left)].dropna() + elif "LEFT" in source_plan.join_type.sql: # left outer join + # rows from LEFT that did not get matched + unmatched_left = left[~outer_join(left)] + unmatched_left[right.columns] = None + result_df = pd.concat( + [result_df[condition], unmatched_left], ignore_index=True + ) + for right_column in right.columns.values: + ct = sf_types[right_column] + sf_types[right_column] = ColumnType(ct.datatype, True) + elif "RIGHT" in 
source_plan.join_type.sql:  # right outer join
+                # rows from RIGHT that did not get matched
+                unmatched_right = right[~outer_join(right)]
+                unmatched_right[left.columns] = None
+                result_df = pd.concat(
+                    [result_df[condition], unmatched_right], ignore_index=True
+                )
+                for left_column in left.columns.values:
+                    ct = sf_types[left_column]
+                    sf_types[left_column] = ColumnType(ct.datatype, True)
+            elif "OUTER" in source_plan.join_type.sql:  # full outer join
+                # rows from LEFT that did not get matched
+                unmatched_left = left[~outer_join(left)]
+                unmatched_left[right.columns] = None
+                # rows from RIGHT that did not get matched
+                unmatched_right = right[~outer_join(right)]
+                unmatched_right[left.columns] = None
+                result_df = pd.concat(
+                    [result_df[condition], unmatched_left, unmatched_right],
+                    ignore_index=True,
+                )
+                for col_name, col_type in sf_types.items():
+                    sf_types[col_name] = ColumnType(col_type.datatype, True)
+            else:
+                result_df = result_df[condition]
+            result_df.sf_types = sf_types
+
+        return result_df.where(result_df.notna(), None)  # Swap np.nan with None
+    if isinstance(source_plan, MockFileOperation):
+        return execute_file_operation(source_plan, analyzer)
+    if isinstance(source_plan, SnowflakeCreateTable):
+        if source_plan.column_names is not None:
+            raise NotImplementedError(
+                "[Local Testing] Inserting data into table by matching column names is currently not supported."
+            )
+        res_df = execute_mock_plan(source_plan.query)
+        return entity_registry.write_table(
+            source_plan.table_name, res_df, source_plan.mode
+        )
+    if isinstance(source_plan, UnresolvedRelation):
+        entity_name = source_plan.name
+        if entity_registry.is_existing_table(entity_name):
+            return entity_registry.read_table(entity_name)
+        elif entity_registry.is_existing_view(entity_name):
+            execution_plan = entity_registry.get_review(entity_name)
+            res_df = execute_mock_plan(execution_plan)
+            return res_df
+        else:
+            db_schema_table = parse_table_name(entity_name)
+            raise SnowparkSQLException(
+                f"Object '{db_schema_table[0][1:-1]}.{db_schema_table[1][1:-1]}.{db_schema_table[2][1:-1]}' does not exist or not authorized."
+            )
+    if isinstance(source_plan, Sample):
+        res_df = execute_mock_plan(source_plan.child)
+
+        if source_plan.row_count and (
+            source_plan.row_count < 0 or source_plan.row_count > 1000000
+        ):
+            raise SnowparkSQLException(
+                "parameter value out of range: size of fixed sample. Must be between 0 and 1,000,000."
+            )
+
+        return res_df.sample(
+            n=None
+            if source_plan.row_count is None
+            else min(source_plan.row_count, len(res_df)),
+            frac=source_plan.probability_fraction,
+            random_state=source_plan.seed,
+        )
+    elif isinstance(source_plan, CreateViewCommand):
+        from_df = execute_mock_plan(source_plan.child, expr_to_alias)
+        view_name = source_plan.name
+        entity_registry.create_or_replace_view(source_plan.child, view_name)
+        return from_df
+    raise NotImplementedError(
+        f"[Local Testing] Mocking SnowflakePlan {type(source_plan).__name__} is not implemented."
+ ) + + +def describe(plan: MockExecutionPlan) -> List[Attribute]: + result = execute_mock_plan(plan) + ret = [] + for c in result.columns: + # Raising an exception here will cause infinite recursion + if isinstance(result[c].sf_type.datatype, NullType): + ret.append( + Attribute( + result[c].name if result[c].name else "NULL", StringType(), True + ) + ) + else: + data_type = result[c].sf_type.datatype + if isinstance(data_type, (ByteType, ShortType, IntegerType)): + data_type = LongType() + elif isinstance(data_type, FloatType): + data_type = DoubleType() + elif ( + isinstance(data_type, DecimalType) + and data_type.precision == 38 + and data_type.scale == 0 + ): + data_type = LongType() + elif isinstance(data_type, StringType): + data_type.length = ( + StringType._MAX_LENGTH + if data_type.length is None + else data_type.length + ) + + ret.append( + Attribute( + quote_name(result[c].name.strip()), + data_type, + result[c].sf_type.nullable, + ) + ) + return ret + + +def calculate_expression( + exp: Expression, + input_data: Union[TableEmulator, ColumnEmulator], + analyzer, + expr_to_alias: Dict[str, str], + *, + keep_literal: bool = False, +) -> ColumnEmulator: + """ + Returns the calculated expression evaluated based on input table/column + setting keep_literal to true returns Python datatype + setting keep_literal to false returns a ColumnEmulator wrapping the Python datatype of a Literal + """ + import numpy as np + + if isinstance(exp, Attribute): + try: + return input_data[expr_to_alias.get(exp.expr_id, exp.name)] + except KeyError: + # expr_id maps to the projected name, but input_data might still have the exp.name + # dealing with the KeyError here, this happens in case df.union(df) + # TODO: check SNOW-831880 for more context + return input_data[exp.name] + if isinstance(exp, (UnresolvedAttribute, Attribute)): + if exp.is_sql_text: + raise NotImplementedError( + "[Local Testing] SQL Text Expression is not supported." + ) + try: + return input_data[exp.name] + except KeyError: + raise SnowparkSQLException(f"[Local Testing] invalid identifier {exp.name}") + if isinstance(exp, (UnresolvedAlias, Alias)): + return calculate_expression(exp.child, input_data, analyzer, expr_to_alias) + if isinstance(exp, FunctionExpression): + + # Special case for count_distinct + if exp.name.lower() == "count" and exp.is_distinct: + func_name = "count_distinct" + else: + func_name = exp.name.lower() + + try: + original_func = getattr( + importlib.import_module("snowflake.snowpark.functions"), func_name + ) + except AttributeError: + raise NotImplementedError( + f"[Local Testing] Mocking function {func_name} is not supported." + ) + + signatures = inspect.signature(original_func) + spec = inspect.getfullargspec(original_func) + if func_name not in _MOCK_FUNCTION_IMPLEMENTATION_MAP: + raise NotImplementedError( + f"[Local Testing] Mocking function {func_name} is not implemented." 
+ ) + to_pass_args = [] + type_hints = typing.get_type_hints(original_func) + for idx, key in enumerate(signatures.parameters): + type_hint = str(type_hints[key]) + keep_literal = "Column" not in type_hint + if key == spec.varargs: + to_pass_args.extend( + [ + calculate_expression( + c, + input_data, + analyzer, + expr_to_alias, + keep_literal=keep_literal, + ) + for c in exp.children[idx:] + ] + ) + else: + try: + to_pass_args.append( + calculate_expression( + exp.children[idx], + input_data, + analyzer, + expr_to_alias, + keep_literal=keep_literal, + ) + ) + except IndexError: + to_pass_args.append(None) + if func_name == "array_agg": + to_pass_args[-1] = exp.is_distinct + if func_name == "sum" and exp.is_distinct: + to_pass_args[0] = ColumnEmulator( + data=to_pass_args[0].unique(), sf_type=to_pass_args[0].sf_type + ) + return _MOCK_FUNCTION_IMPLEMENTATION_MAP[func_name](*to_pass_args) + if isinstance(exp, ListAgg): + lhs = calculate_expression(exp.col, input_data, analyzer, expr_to_alias) + lhs.sf_type = ColumnType(StringType(), exp.col.nullable) + return _MOCK_FUNCTION_IMPLEMENTATION_MAP["listagg"]( + lhs, + is_distinct=exp.is_distinct, + delimiter=exp.delimiter, + ) + if isinstance(exp, IsNull): + child_column = calculate_expression( + exp.child, input_data, analyzer, expr_to_alias + ) + return ColumnEmulator( + data=[bool(data is None) for data in child_column], + sf_type=ColumnType(BooleanType(), True), + ) + if isinstance(exp, IsNotNull): + child_column = calculate_expression( + exp.child, input_data, analyzer, expr_to_alias + ) + return ColumnEmulator( + data=[bool(data is not None) for data in child_column], + sf_type=ColumnType(BooleanType(), True), + ) + if isinstance(exp, IsNaN): + child_column = calculate_expression( + exp.child, input_data, analyzer, expr_to_alias + ) + res = [] + for data in child_column: + try: + res.append(math.isnan(data)) + except TypeError: + res.append(False) + return ColumnEmulator( + data=res, dtype=object, sf_type=ColumnType(BooleanType(), True) + ) + if isinstance(exp, Not): + child_column = calculate_expression( + exp.child, input_data, analyzer, expr_to_alias + ) + return ~child_column + if isinstance(exp, UnresolvedAttribute): + return analyzer.analyze(exp, expr_to_alias) + if isinstance(exp, Literal): + if not keep_literal: + if isinstance(exp.datatype, StringType): + # in live session, literal of string type will have size auto inferred + exp.datatype = StringType(len(exp.value)) + res = ColumnEmulator( + data=[exp.value for _ in range(len(input_data))], + sf_type=ColumnType(exp.datatype, False), + dtype=object, + ) + res.index = input_data.index + return res + return exp.value + if isinstance(exp, BinaryExpression): + left = calculate_expression(exp.left, input_data, analyzer, expr_to_alias) + right = calculate_expression(exp.right, input_data, analyzer, expr_to_alias) + # TODO: Address mixed type calculation here. For instance Snowflake allows to add a date to a number, but + # pandas doesn't allow. Type coercion will address it. 
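+        # Each operator below maps onto the equivalent vectorized pandas operation on the two
+        # ColumnEmulator operands (e.g. Multiply -> left * right, EqualTo -> left == right), with
+        # explicit fix-ups where pandas and Snowflake disagree, such as NaN equality in the EqualTo
+        # branch and EqualNullSafe treating two NULLs as equal.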
+ if isinstance(exp, Multiply): + new_column = left * right + elif isinstance(exp, Divide): + new_column = left / right + elif isinstance(exp, Add): + new_column = left + right + elif isinstance(exp, Subtract): + new_column = left - right + elif isinstance(exp, Remainder): + new_column = left % right + elif isinstance(exp, Pow): + new_column = left**right + elif isinstance(exp, EqualTo): + new_column = left == right + if left.hasnans and right.hasnans: + new_column[ + left.apply(lambda x: x is None) & right.apply(lambda x: x is None) + ] = True + new_column[ + left.apply(lambda x: x is not None and np.isnan(x)) + & right.apply(lambda x: x is not None and np.isnan(x)) + ] = True + # NaN == NaN evaluates to False in pandas, but True in Snowflake + elif isinstance(exp, NotEqualTo): + new_column = left != right + elif isinstance(exp, GreaterThanOrEqual): + new_column = left >= right + elif isinstance(exp, GreaterThan): + new_column = left > right + elif isinstance(exp, LessThanOrEqual): + new_column = left <= right + elif isinstance(exp, LessThan): + new_column = left < right + elif isinstance(exp, And): + new_column = ( + (left & right) + if isinstance(input_data, TableEmulator) or not input_data + else (left & right) & input_data + ) + elif isinstance(exp, Or): + new_column = ( + (left | right) + if isinstance(input_data, TableEmulator) or not input_data + else (left | right) & input_data + ) + elif isinstance(exp, EqualNullSafe): + new_column = ( + (left == right) + | (left.isna() & right.isna()) + | (left.isnull() & right.isnull()) + ) + else: + raise NotImplementedError( + f"[Local Testing] Binary expression {type(exp)} is not implemented." + ) + return new_column + if isinstance(exp, RegExp): + lhs = calculate_expression(exp.expr, input_data, analyzer, expr_to_alias) + raw_pattern = calculate_expression( + exp.pattern, input_data, analyzer, expr_to_alias + )[0] + pattern = f"^{raw_pattern}" if not raw_pattern.startswith("^") else raw_pattern + pattern = f"{pattern}$" if not pattern.endswith("$") else pattern + try: + re.compile(pattern) + except re.error: + raise SnowparkSQLException(f"Invalid regular expression {raw_pattern}") + result = lhs.str.match(pattern) + result.sf_type = ColumnType(BooleanType(), True) + return result + if isinstance(exp, Like): + lhs = calculate_expression(exp.expr, input_data, analyzer, expr_to_alias) + pattern = convert_wildcard_to_regex( + str( + calculate_expression(exp.pattern, input_data, analyzer, expr_to_alias)[ + 0 + ] + ) + ) + result = lhs.str.match(pattern) + result.sf_type = ColumnType(BooleanType(), True) + return result + if isinstance(exp, InExpression): + lhs = calculate_expression(exp.columns, input_data, analyzer, expr_to_alias) + res = ColumnEmulator([False] * len(lhs), dtype=object) + res.sf_type = ColumnType(BooleanType(), True) + for val in exp.values: + rhs = calculate_expression(val, input_data, analyzer, expr_to_alias) + if isinstance(lhs, ColumnEmulator): + if isinstance(rhs, ColumnEmulator): + res = res | lhs.isin(rhs) + elif isinstance(rhs, TableEmulator): + res = res | lhs.isin(rhs.iloc[:, 0]) + else: + raise NotImplementedError( + f"[Local Testing] IN expression does not support {type(rhs)} type on the right" + ) + else: + exists = lhs.apply(tuple, 1).isin(rhs.apply(tuple, 1)) + exists.sf_type = ColumnType(BooleanType(), False) + res = res | exists + return res + if isinstance(exp, ScalarSubquery): + return execute_mock_plan(exp.plan, expr_to_alias) + if isinstance(exp, MultipleExpression): + res = TableEmulator() + for e in 
exp.expressions: + res[analyzer.analyze(e, expr_to_alias)] = calculate_expression( + e, input_data, analyzer, expr_to_alias + ) + return res + if isinstance(exp, Cast): + column = calculate_expression(exp.child, input_data, analyzer, expr_to_alias) + if isinstance(exp.to, DateType): + return _MOCK_FUNCTION_IMPLEMENTATION_MAP["to_date"]( + column, try_cast=exp.try_ + ) + elif isinstance(exp.to, TimeType): + return _MOCK_FUNCTION_IMPLEMENTATION_MAP["to_time"]( + column, try_cast=exp.try_ + ) + elif isinstance(exp.to, TimestampType): + return _MOCK_FUNCTION_IMPLEMENTATION_MAP["to_timestamp"]( + column, try_cast=exp.try_ + ) + elif isinstance(exp.to, DecimalType): + return _MOCK_FUNCTION_IMPLEMENTATION_MAP["to_decimal"]( + column, + precision=exp.to.precision, + scale=exp.to.scale, + try_cast=exp.try_, + ) + elif isinstance(exp.to, IntegerType): + res = _MOCK_FUNCTION_IMPLEMENTATION_MAP["to_decimal"]( + column, try_cast=exp.try_ + ) + res.set_sf_type(ColumnType(IntegerType(), nullable=column.sf_type.nullable)) + return res + elif isinstance(exp.to, BinaryType): + return _MOCK_FUNCTION_IMPLEMENTATION_MAP["to_binary"]( + column, try_cast=exp.try_ + ) + elif isinstance(exp.to, StringType): + return _MOCK_FUNCTION_IMPLEMENTATION_MAP["to_char"]( + column, try_cast=exp.try_ + ) + elif isinstance(exp.to, DoubleType): + return _MOCK_FUNCTION_IMPLEMENTATION_MAP["to_double"]( + column, try_cast=exp.try_ + ) + elif isinstance(exp.to, MapType): + return _MOCK_FUNCTION_IMPLEMENTATION_MAP["to_object"](column) + elif isinstance(exp.to, ArrayType): + return _MOCK_FUNCTION_IMPLEMENTATION_MAP["to_array"](column) + elif isinstance(exp.to, VariantType): + return _MOCK_FUNCTION_IMPLEMENTATION_MAP["to_variant"](column) + else: + raise NotImplementedError( + f"[Local Testing] Cast to {exp.to} is not supported yet" + ) + if isinstance(exp, CaseWhen): + remaining = input_data + output_data = ColumnEmulator([None] * len(input_data)) + for case in exp.branches: + if len(remaining) == 0: + break + condition = calculate_expression( + case[0], input_data, analyzer, expr_to_alias + ) + value = calculate_expression(case[1], input_data, analyzer, expr_to_alias) + + true_index = remaining[condition].index + output_data[true_index] = value[true_index] + remaining = remaining[~remaining.index.isin(true_index)] + + if output_data.sf_type: + if ( + not isinstance(output_data.sf_type.datatype, NullType) + and output_data.sf_type != value.sf_type + ): + raise SnowparkSQLException( + f"CaseWhen expressions have conflicting data types: {output_data.sf_type} != {value.sf_type}" + ) + else: + output_data.sf_type = value.sf_type + + if len(remaining) > 0 and exp.else_value: + value = calculate_expression( + exp.else_value, remaining, analyzer, expr_to_alias + ) + output_data[remaining.index] = value[remaining.index] + if output_data.sf_type: + if ( + not isinstance(output_data.sf_type.datatype, NullType) + and output_data.sf_type.datatype != value.sf_type.datatype + ): + raise SnowparkSQLException( + f"CaseWhen expressions have conflicting data types: {output_data.sf_type.datatype} != {value.sf_type.datatype}" + ) + else: + output_data.sf_type = value.sf_type + return output_data + if isinstance(exp, WindowExpression): + window_function = exp.window_function + window_spec = exp.window_spec + + # Process order by clause + if window_spec.order_spec: + res = handle_order_by_clause( + window_spec.order_spec, input_data, analyzer, expr_to_alias + ) + elif is_rank_related_window_function(window_function): + raise SnowparkSQLException( + 
f"Window function type [{str(window_function)}] requires ORDER BY in window specification" + ) + else: + res = input_data + + res_index = res.index # List of row indexes of the result + + # Process partition_by clause + if window_spec.partition_spec: + res = res.groupby( + [exp.name for exp in window_spec.partition_spec], + sort=False, + as_index=False, + ) + res_index = [] + for r in res: + res_index += list(r[1].index) + + # Process window frame specification + # Reference: https://docs.snowflake.com/en/sql-reference/functions-analytic#window-frame-usage-notes + if not window_spec.frame_spec or not isinstance( + window_spec.frame_spec, SpecifiedWindowFrame + ): + if not is_rank_related_window_function(window_function): + windows = handle_range_frame_indexing( + window_spec.order_spec, + res_index, + res, + analyzer, + expr_to_alias, + True, + False, + ) + else: + indexer = EntireWindowIndexer() + res = res.rolling(indexer) + windows = [input_data.loc[w.index] for w in res] + + elif isinstance(window_spec.frame_spec.frame_type, RowFrame): + indexer = RowFrameIndexer(frame_spec=window_spec.frame_spec) + res = res.rolling(indexer) + windows = [w for w in res] + + elif isinstance(window_spec.frame_spec.frame_type, RangeFrame): + upper = window_spec.frame_spec.upper + lower = window_spec.frame_spec.lower + + if isinstance(upper, Literal) or isinstance(lower, Literal): + raise SnowparkSQLException( + "Range is not supported for sliding window frames." + ) + + windows = handle_range_frame_indexing( + window_spec.order_spec, + res_index, + res, + analyzer, + expr_to_alias, + isinstance(lower, UnboundedPreceding), + isinstance(upper, UnboundedFollowing), + ) + + # compute window function: + if isinstance(window_function, (FunctionExpression,)): + res_cols = [] + for current_row, w in zip(res_index, windows): + evaluated_children = [ + calculate_expression(c, w, analyzer, expr_to_alias) + for c in window_function.children + ] + try: + original_func = getattr( + importlib.import_module("snowflake.snowpark.functions"), + window_function.name.lower(), + ) + except AttributeError: + raise NotImplementedError( + f"[Local Testing] Mocking window function {window_function.name.lower()} is not supported." + ) + + signatures = inspect.signature(original_func) + spec = inspect.getfullargspec(original_func) + if window_function.name not in _MOCK_FUNCTION_IMPLEMENTATION_MAP: + raise NotImplementedError( + f"[Local Testing] Mocking window function {window_function.name} is not implemented." 
+ ) + to_pass_args = [] + for idx, key in enumerate(signatures.parameters): + if key == spec.varargs: + to_pass_args.extend(evaluated_children[idx:]) + else: + try: + to_pass_args.append(evaluated_children[idx]) + except IndexError: + to_pass_args.append(None) + # Rank related function specific arguments + if window_function.name == "row_number": + to_pass_args.append(w) + row_idx = list(w.index).index( + current_row + ) # the row's 0-base index in the window + to_pass_args.append(row_idx) + res_cols.append( + _MOCK_FUNCTION_IMPLEMENTATION_MAP[window_function.name]( + *to_pass_args + ) + ) + res_col = pd.concat(res_cols) + res_col.index = res_index + if res_cols: + res_col.set_sf_type(res_cols[0].sf_type) + else: + res_col.set_sf_type(ColumnType(NullType(), True)) + return res_col.sort_index() + elif isinstance(window_function, (Lead, Lag)): + calculated_sf_type = None + offset = window_function.offset * ( + 1 if isinstance(window_function, Lead) else -1 + ) + ignore_nulls = window_function.ignore_nulls + res_cols = [] + for current_row, w in zip(res_index, windows): + row_idx = list(w.index).index( + current_row + ) # the row's 0-base index in the window + offset_idx = row_idx + offset + if offset_idx < 0 or offset_idx >= len(w): + sub_window_res = calculate_expression( + window_function.default, + w, + analyzer, + expr_to_alias, + ) + if not calculated_sf_type: + calculated_sf_type = sub_window_res.sf_type + elif calculated_sf_type.datatype != sub_window_res.sf_type.datatype: + if isinstance(calculated_sf_type.datatype, NullType): + calculated_sf_type = sub_window_res.sf_type + # the result calculated upon a windows can be None, this is still valid and we can keep + # the calculation + elif not isinstance(sub_window_res.sf_type.datatype, NullType): + raise SnowparkSQLException( + f"[Local Testing] Detected type {type(calculated_sf_type.datatype)} and type {type(sub_window_res.sf_type.datatype)}" + f" in column, coercion is not currently supported" + ) + res_cols.append(sub_window_res.iloc[0]) + elif not ignore_nulls or offset == 0: + sub_window_res = calculate_expression( + window_function.expr, + w.iloc[[offset_idx]], + analyzer, + expr_to_alias, + ) + # we use the whole frame to calculate the type + cur_windows_sf_type = calculate_expression( + window_function.expr, + w, + analyzer, + expr_to_alias, + ).sf_type + if not calculated_sf_type: + calculated_sf_type = cur_windows_sf_type + elif calculated_sf_type != cur_windows_sf_type and ( + not ( + isinstance(calculated_sf_type.datatype, StringType) + and isinstance(cur_windows_sf_type.datatype, StringType) + ) + ): + if isinstance(calculated_sf_type.datatype, NullType): + calculated_sf_type = sub_window_res.sf_type + # the result calculated upon a windows can be None, this is still valid and we can keep + # the calculation + elif not isinstance(sub_window_res.sf_type.datatype, NullType): + raise SnowparkSQLException( + f"[Local Testing] Detected type {type(calculated_sf_type.datatype)} and type {type(cur_windows_sf_type.datatype)}" + f" in column, coercion is not currently supported" + ) + res_cols.append(sub_window_res.iloc[0]) + else: + # skip rows where expr is NULL + delta = 1 if offset > 0 else -1 + cur_idx = row_idx + delta + cur_count = 0 + while 0 <= cur_idx < len(w): + target_expr = calculate_expression( + window_function.expr, + w.iloc[[cur_idx]], + analyzer, + expr_to_alias, + ).iloc[0] + if target_expr is not None: + cur_count += 1 + if cur_count == abs(offset): + break + cur_idx += delta + if cur_idx < 0 or cur_idx >= 
len(w): + res_cols.append( + calculate_expression( + window_function.default, + w, + analyzer, + expr_to_alias, + ).iloc[0] + ) + else: + res_cols.append(target_expr) + res_col = ColumnEmulator( + data=res_cols, dtype=object + ) # dtype=object prevents implicit converting None to Nan + res_col.index = res_index + res_col.sf_type = ( + calculated_sf_type + if calculated_sf_type + else ColumnType(NullType(), True) + ) + return res_col.sort_index() + elif isinstance(window_function, FirstValue): + ignore_nulls = window_function.ignore_nulls + res_cols = [] + for w in windows: + if not ignore_nulls: + res_cols.append( + calculate_expression( + window_function.expr, + w.iloc[[0]], + analyzer, + expr_to_alias, + ).iloc[0] + ) + else: + for cur_idx in range(len(w)): + target_expr = calculate_expression( + window_function.expr, + w.iloc[[cur_idx]], + analyzer, + expr_to_alias, + ).iloc[0] + if target_expr is not None: + res_cols.append(target_expr) + break + else: + res_cols.append(None) + res_col = ColumnEmulator( + data=res_cols, + dtype=object, + sf_type=calculate_expression( + window_function.expr, + input_data, + analyzer, + expr_to_alias, + ).sf_type, + ) # dtype=object prevents implicit converting None to Nan + res_col.index = res_index + return res_col.sort_index() + elif isinstance(window_function, LastValue): + ignore_nulls = window_function.ignore_nulls + res_cols = [] + for w in windows: + if not ignore_nulls: + res_cols.append( + calculate_expression( + window_function.expr, + w.iloc[[len(w) - 1]], + analyzer, + expr_to_alias, + ).iloc[0] + ) + else: + for cur_idx in range(len(w) - 1, -1, -1): + target_expr = calculate_expression( + window_function.expr, + w.iloc[[cur_idx]], + analyzer, + expr_to_alias, + ).iloc[0] + if target_expr is not None: + res_cols.append(target_expr) + break + else: + res_cols.append(None) + res_col = ColumnEmulator( + data=res_cols, + dtype=object, + sf_type=calculate_expression( + window_function.expr, + windows[0], + analyzer, + expr_to_alias, + ).sf_type, + ) # dtype=object prevents implicit converting None to Nan + res_col.index = res_index + return res_col.sort_index() + else: + raise NotImplementedError( + f"[Local Testing] Window Function {window_function} is not implemented." + ) + elif isinstance(exp, SubfieldString): + col = calculate_expression(exp.child, input_data, analyzer, expr_to_alias) + field = str(exp.field) + # in snowflake, two consecutive single quotes means escaping single quote + field = field.replace("''", "'") + col._null_rows_idxs = [ + index + for index in range(len(col)) + if col[index] is not None + and field in col[index] + and col[index][field] is None + ] + res = col.apply(lambda x: None if x is None or field not in x else x[field]) + res.set_sf_type(ColumnType(VariantType(), col.sf_type.nullable)) + return res + elif isinstance(exp, SubfieldInt): + col = calculate_expression(exp.child, input_data, analyzer, expr_to_alias) + res = col.apply(lambda x: None if x is None else x[exp.field]) + res.set_sf_type(ColumnType(VariantType(), col.sf_type.nullable)) + return res + raise NotImplementedError( + f"[Local Testing] Mocking Expression {type(exp).__name__} is not implemented." 
+ ) + + +def execute_file_operation(source_plan: MockFileOperation, analyzer: "MockAnalyzer"): + if source_plan.operator == MockFileOperation.Operator.PUT: + return mock_file_operation.put( + source_plan.local_file_name, source_plan.stage_location + ) + if source_plan.operator == MockFileOperation.Operator.READ_FILE: + return mock_file_operation.read_file( + source_plan.stage_location, + source_plan.format, + source_plan.schema, + analyzer, + source_plan.options, + ) + raise NotImplementedError( + f"[Local Testing] File operation {source_plan.operator.value} is not implemented." + ) diff --git a/src/snowflake/snowpark/mock/plan_builder.py b/src/snowflake/snowpark/mock/plan_builder.py new file mode 100644 index 00000000000..3983670b535 --- /dev/null +++ b/src/snowflake/snowpark/mock/plan_builder.py @@ -0,0 +1,66 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from typing import Dict, List, Optional, Tuple + +from snowflake.snowpark._internal.analyzer.expression import Attribute +from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlanBuilder +from snowflake.snowpark._internal.utils import is_single_quoted +from snowflake.snowpark.mock.plan import MockExecutionPlan, MockFileOperation + + +class MockSnowflakePlanBuilder(SnowflakePlanBuilder): + def create_temp_table(self, *args, **kwargs): + raise NotImplementedError( + "[Local Testing] DataFrame.cache_result is currently not implemented." + ) + + def read_file( + self, + path: str, + format: str, + options: Dict[str, str], + fully_qualified_schema: str, + schema: List[Attribute], + schema_to_cast: Optional[List[Tuple[str, str]]] = None, + transformations: Optional[List[str]] = None, + metadata_project: Optional[List[str]] = None, + ) -> MockExecutionPlan: + if format.upper() != "CSV": + raise NotImplementedError( + "[Local Testing] Reading non CSV data into dataframe is not currently supported." + ) + return MockExecutionPlan( + source_plan=MockFileOperation( + session=self.session, + operator=MockFileOperation.Operator.READ_FILE, + stage_location=path, + format=format, + schema=schema, + options=options, + ), + session=self.session, + ) + + def file_operation_plan( + self, command: str, file_name: str, stage_location: str, options: Dict[str, str] + ) -> MockExecutionPlan: + if options.get("auto_compress", False): + raise NotImplementedError( + "[Local Testing] PUT with auto_compress=True is currently not supported." + ) + if command == "get": + raise NotImplementedError("[Local Testing] GET is currently not supported.") + return MockExecutionPlan( + source_plan=MockFileOperation( + session=self.session, + operator=MockFileOperation.Operator(command), + local_file_name=file_name, + stage_location=stage_location[1:-1] + if is_single_quoted(stage_location) + else stage_location, + options=options, + ), + session=self.session, + ) diff --git a/src/snowflake/snowpark/mock/select_statement.py b/src/snowflake/snowpark/mock/select_statement.py new file mode 100644 index 00000000000..e4bf64ad13c --- /dev/null +++ b/src/snowflake/snowpark/mock/select_statement.py @@ -0,0 +1,473 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +from abc import ABC +from copy import copy +from typing import TYPE_CHECKING, List, Optional, Union + +from snowflake.snowpark._internal.analyzer.select_statement import ( + ColumnChangeState, + ColumnStateDict, + Selectable, + SelectSnowflakePlan, + SelectStatement, + can_clause_dependent_columns_flatten, + can_projection_dependent_columns_be_flattened, + derive_column_states_from_subquery, + initiate_column_states, +) +from snowflake.snowpark.types import LongType + +if TYPE_CHECKING: + from snowflake.snowpark._internal.analyzer.analyzer import ( + Analyzer, + ) # pragma: no cover + +from snowflake.snowpark._internal.analyzer import analyzer_utils +from snowflake.snowpark._internal.analyzer.binary_expression import And +from snowflake.snowpark._internal.analyzer.expression import ( + COLUMN_DEPENDENCY_DOLLAR, + Attribute, + Expression, + Star, + derive_dependent_columns, +) +from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan +from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan, Range +from snowflake.snowpark._internal.analyzer.unary_expression import UnresolvedAlias + +SET_UNION = analyzer_utils.UNION +SET_UNION_ALL = analyzer_utils.UNION_ALL +SET_INTERSECT = analyzer_utils.INTERSECT +SET_EXCEPT = analyzer_utils.EXCEPT + + +class MockSelectable(LogicalPlan, ABC): + """The parent abstract class of a DataFrame's logical plan. It can be converted to and from a SnowflakePlan.""" + + def __init__( + self, + analyzer: "Analyzer", + ) -> None: + super().__init__() + self.analyzer = analyzer + self.pre_actions = None + self.post_actions = None + self.flatten_disabled: bool = False + self._column_states: Optional[ColumnStateDict] = None + self._execution_plan: Optional[SnowflakePlan] = None + self._attributes = None + self.expr_to_alias = {} + self.df_aliased_col_name_to_real_col_name = {} + + @property + def execution_plan(self): + """Convert to a SnowflakePlan""" + from snowflake.snowpark.mock.plan import MockExecutionPlan + + if self._execution_plan is None: + self._execution_plan = MockExecutionPlan(self, self.analyzer.session) + return self._execution_plan + + @property + def attributes(self): + return self._attributes or self.execution_plan.attributes + + @property + def column_states(self) -> ColumnStateDict: + """A dictionary that contains the column states of a query. + Refer to class ColumnStateDict. + """ + if self._column_states is None: + self._column_states = initiate_column_states( + self.attributes, + self.analyzer, + {}, + ) + return self._column_states + + def to_subqueryable(self) -> "Selectable": + """Some queries can be used in a subquery. Some can't. 
For details, refer to class SelectSQL.""" + return self + + +class MockSetOperand: + def __init__(self, selectable: Selectable, operator: Optional[str] = None) -> None: + super().__init__() + self.selectable = selectable + self.operator = operator + + +class MockSetStatement(MockSelectable): + def __init__( + self, *set_operands: MockSetOperand, analyzer: Optional["Analyzer"] + ) -> None: + super().__init__(analyzer=analyzer) + self.set_operands = set_operands + for operand in set_operands: + if operand.selectable.pre_actions: + if not self.pre_actions: + self.pre_actions = [] + self.pre_actions.extend(operand.selectable.pre_actions) + if operand.selectable.post_actions: + if not self.post_actions: + self.post_actions = [] + self.post_actions.extend(operand.selectable.post_actions) + + @property + def sql_query(self) -> str: + sql = f"({self.set_operands[0].selectable.sql_query})" + for i in range(1, len(self.set_operands)): + sql = f"{sql}{self.set_operands[i].operator}({self.set_operands[i].selectable.sql_query})" + return sql + + @property + def schema_query(self) -> str: + """The first operand decide the column attributes of a query with set operations. + Refer to https://docs.snowflake.com/en/sql-reference/operators-query.html#general-usage-notes""" + return self.set_operands[0].selectable.schema_query + + @property + def column_states(self) -> Optional[ColumnStateDict]: + if not self._column_states: + self._column_states = initiate_column_states( + self.set_operands[0].selectable.column_states.projection, + self.analyzer, + {}, + ) + return self._column_states + + +class MockSelectExecutionPlan(MockSelectable): + """Wrap a SnowflakePlan to a subclass of Selectable.""" + + def __init__(self, snowflake_plan: LogicalPlan, *, analyzer: "Analyzer") -> None: + super().__init__(analyzer) + self._execution_plan = analyzer.resolve(snowflake_plan) + + if isinstance(snowflake_plan, Range): + self._attributes = [Attribute('"ID"', LongType(), False)] + + self.api_calls = [] + + +class MockSelectStatement(MockSelectable): + """The main logic plan to be used by a DataFrame. + It structurally has the parts of a query and uses the ColumnState to decide whether a query can be flattened.""" + + def __init__( + self, + *, + projection: Optional[List[Expression]] = None, + from_: Optional["MockSelectable"] = None, + where: Optional[Expression] = None, + order_by: Optional[List[Expression]] = None, + limit_: Optional[int] = None, + offset: Optional[int] = None, + analyzer: "Analyzer", + ) -> None: + super().__init__(analyzer) + self.projection: List[Expression] = projection or [Star([])] + self.from_: Optional["Selectable"] = from_ + self.where: Optional[Expression] = where + self.order_by: Optional[List[Expression]] = order_by + self.limit_: Optional[int] = limit_ + self.offset = offset + self.pre_actions = self.from_.pre_actions + self.post_actions = self.from_.post_actions + self._sql_query = None + self._schema_query = None + self._projection_in_str = None + self.api_calls = ( + self.from_.api_calls.copy() if self.from_.api_calls is not None else None + ) # will be replaced by new api calls if any operation. + + def __copy__(self): + new = MockSelectStatement( + projection=self.projection, + from_=self.from_, + where=self.where, + order_by=self.order_by, + limit_=self.limit_, + offset=self.offset, + analyzer=self.analyzer, + ) + # The following values will change if they're None in the newly copied one so reset their values here + # to avoid problems. 
+ new._column_states = None + new.flatten_disabled = False # by default a SelectStatement can be flattened. + return new + + @property + def column_states(self) -> ColumnStateDict: + if self._column_states is None: + if not self.projection and not self.has_clause: + self._column_states = self.from_.column_states + elif isinstance(self.from_, MockSelectExecutionPlan): + self._column_states = initiate_column_states( + self.from_.attributes, self.analyzer, {} + ) + elif isinstance(self.from_, MockSelectStatement): + self._column_states = self.from_.column_states + else: + super().column_states # will assign value to self._column_states + return self._column_states + + @property + def has_clause_using_columns(self) -> bool: + return any( + ( + self.where is not None, + self.order_by is not None, + ) + ) + + @property + def has_clause(self) -> bool: + return self.has_clause_using_columns or self.limit_ is not None + + @property + def projection_in_str(self) -> str: + if not self._projection_in_str: + self._projection_in_str = ( + analyzer_utils.COMMA.join( + self.analyzer.analyze(x) for x in self.projection + ) + if self.projection + else analyzer_utils.STAR + ) + return self._projection_in_str + + def select(self, cols: List[Expression]) -> "SelectStatement": + """Build a new query. This SelectStatement will be the subquery of the new query. + Possibly flatten the new query and the subquery (self) to form a new flattened query. + """ + if ( + len(cols) == 1 + and isinstance(cols[0], UnresolvedAlias) + and isinstance(cols[0].child, Star) + and not cols[0].child.expressions + # df.select("*") doesn't have the child.expressions + # df.select(df["*"]) has the child.expressions + ): + new = copy(self) # it copies the api_calls + new._projection_in_str = self._projection_in_str + new._schema_query = self._schema_query + new._column_states = self._column_states + new.flatten_disabled = self.flatten_disabled + new._execution_plan = self._execution_plan + return new + final_projection = [] + disable_next_level_flatten = False + new_column_states = derive_column_states_from_subquery(cols, self) + if new_column_states is None: + can_be_flattened = False + disable_next_level_flatten = True + elif len(new_column_states.active_columns) != len(new_column_states.projection): + # There must be duplicate columns in the projection. + # We don't flatten when there are duplicate columns. + can_be_flattened = False + disable_next_level_flatten = True + elif self.flatten_disabled or self.has_clause_using_columns: + can_be_flattened = False + else: + can_be_flattened = True + subquery_column_states = self.column_states + for col, state in new_column_states.items(): + dependent_columns = state.dependent_columns + if dependent_columns == COLUMN_DEPENDENCY_DOLLAR: + can_be_flattened = False + break + subquery_state = subquery_column_states.get(col) + if state.change_state in ( + ColumnChangeState.CHANGED_EXP, + ColumnChangeState.NEW, + ): + can_be_flattened = can_projection_dependent_columns_be_flattened( + dependent_columns, subquery_column_states + ) + if not can_be_flattened: + break + final_projection.append(copy(state.expression)) + elif state.change_state == ColumnChangeState.UNCHANGED_EXP: + # query may change sequence of columns. If subquery has same-level reference, flattened sql may not work. 
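+                    # i.e. a column that is re-selected unchanged can only be flattened by reusing
+                    # the subquery's expression when that column exists in the subquery and does not
+                    # depend on other columns defined at the same level of that subquery.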
+ if ( + col not in subquery_column_states + or subquery_column_states[col].depend_on_same_level + ): + can_be_flattened = False + break + final_projection.append( + copy(subquery_column_states[col].expression) + ) # add subquery's expression for this column name + elif state.change_state == ColumnChangeState.DROPPED: + if ( + subquery_state.change_state == ColumnChangeState.NEW + and subquery_state.is_referenced_by_same_level_column + ): + can_be_flattened = False + break + else: # pragma: no cover + raise ValueError(f"Invalid column state {state}.") + if can_be_flattened: + new = copy(self) + new.projection = final_projection + new.from_ = self.from_ + new.pre_actions = new.from_.pre_actions + new.post_actions = new.from_.post_actions + else: + new = MockSelectStatement( + projection=cols, from_=self, analyzer=self.analyzer + ) + new.flatten_disabled = disable_next_level_flatten + new._column_states = derive_column_states_from_subquery( + new.projection, new.from_ + ) + # If new._column_states is None, when property `column_states` is called later, + # a query will be described and an error like "invalid identifier" will be thrown. + + return new + + def filter(self, col: Expression) -> "MockSelectStatement": + if self.flatten_disabled: + can_be_flattened = False + else: + dependent_columns = derive_dependent_columns(col) + can_be_flattened = can_clause_dependent_columns_flatten( + dependent_columns, self.column_states + ) + if can_be_flattened: + new = copy(self) + new.from_ = self.from_.to_subqueryable() + new.pre_actions = new.from_.pre_actions + new.post_actions = new.from_.post_actions + new.where = And(self.where, col) if self.where is not None else col + new._column_states = self._column_states + else: + new = MockSelectStatement( + from_=self.to_subqueryable(), where=col, analyzer=self.analyzer + ) + return new + + def sort(self, cols: List[Expression]) -> "SelectStatement": + if self.flatten_disabled: + can_be_flattened = False + else: + dependent_columns = derive_dependent_columns(*cols) + can_be_flattened = can_clause_dependent_columns_flatten( + dependent_columns, self.column_states + ) + if can_be_flattened: + new = copy(self) + new.from_ = self.from_.to_subqueryable() + new.pre_actions = new.from_.pre_actions + new.post_actions = new.from_.post_actions + new.order_by = cols + new._column_states = self._column_states + else: + new = MockSelectStatement( + from_=self.to_subqueryable(), order_by=cols, analyzer=self.analyzer + ) + return new + + def set_operator( + self, + *selectables: Union[ + SelectSnowflakePlan, + "SelectStatement", + ], + operator: str, + ) -> "SelectStatement": + if isinstance(self.from_, MockSetStatement) and not self.has_clause: + last_operator = self.from_.set_operands[-1].operator + if operator == last_operator: + existing_set_operands = self.from_.set_operands + set_operands = tuple( + MockSetOperand(x.to_subqueryable(), operator) for x in selectables + ) + elif operator == SET_INTERSECT: + # In Snowflake SQL, intersect has higher precedence than other set operators. + # So we need to put all operands before intersect into a single operand. 
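+                # For example, df1.union(df2).intersect(df3) must keep the meaning
+                # "(df1 UNION df2) INTERSECT df3", so the existing UNION operands are first wrapped
+                # into a single sub-statement before the INTERSECT operands are appended.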
+ existing_set_operands = ( + MockSetOperand( + MockSetStatement( + *self.from_.set_operands, analyzer=self.analyzer + ) + ), + ) + sub_statement = MockSetStatement( + *( + MockSetOperand(x.to_subqueryable(), operator) + for x in selectables + ), + analyzer=self.analyzer, + ) + set_operands = ( + MockSetOperand(sub_statement.to_subqueryable(), operator), + ) + else: + existing_set_operands = self.from_.set_operands + sub_statement = MockSetStatement( + *( + MockSetOperand(x.to_subqueryable(), operator) + for x in selectables + ), + analyzer=self.analyzer, + ) + set_operands = ( + MockSetOperand(sub_statement.to_subqueryable(), operator), + ) + set_statement = MockSetStatement( + *existing_set_operands, *set_operands, analyzer=self.analyzer + ) + else: + set_operands = tuple( + MockSetOperand(x.to_subqueryable(), operator) for x in selectables + ) + set_statement = MockSetStatement( + MockSetOperand(self.to_subqueryable()), + *set_operands, + analyzer=self.analyzer, + ) + api_calls = self.api_calls.copy() + for s in selectables: + if s.api_calls: + api_calls.extend(s.api_calls) + set_statement.api_calls = api_calls + new = MockSelectStatement(analyzer=self.analyzer, from_=set_statement) + new._column_states = set_statement.column_states + return new + + def limit(self, n: int, *, offset: int = 0) -> "SelectStatement": + new = copy(self) + new.from_ = self.from_.to_subqueryable() + new.limit_ = min(self.limit_, n) if self.limit_ else n + new.offset = (self.offset + offset) if self.offset else offset + new._column_states = self._column_states + return new + + def to_subqueryable(self) -> "Selectable": + """When this SelectStatement's subquery is not subqueryable (can't be used in `from` clause of the sql), + convert it to subqueryable and create a new SelectStatement with from_ being the new subqueryable。 + An example is "show tables", which will be converted to a pre-action "show tables" and "select from result_scan(query_id_of_show_tables)". + """ + from_subqueryable = self.from_.to_subqueryable() + if self.from_ is not from_subqueryable: + new = copy(self) + new.pre_actions = from_subqueryable.pre_actions + new.post_actions = from_subqueryable.post_actions + new.from_ = from_subqueryable + new._column_states = self._column_states + return new + return self + + +class MockSelectableEntity(MockSelectable): + """Query from a table, view, or any other Snowflake objects. + Mainly used by session.table(). + """ + + def __init__(self, entity_name: str, *, analyzer: "Analyzer") -> None: + super().__init__(analyzer) + self.entity_name = entity_name + self.api_calls = [] diff --git a/src/snowflake/snowpark/mock/snowflake_data_type.py b/src/snowflake/snowpark/mock/snowflake_data_type.py new file mode 100644 index 00000000000..610e048a974 --- /dev/null +++ b/src/snowflake/snowpark/mock/snowflake_data_type.py @@ -0,0 +1,494 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +from typing import Dict, NamedTuple, Optional, Union + +from snowflake.connector.options import installed_pandas, pandas as pd +from snowflake.snowpark.types import ( + BooleanType, + DataType, + DateType, + DecimalType, + DoubleType, + FloatType, + IntegerType, + LongType, + _IntegralType, + _NumericType, +) + +# pandas is an optional requirement for local test, so make snowpark compatible with env where pandas +# not installed, here we redefine the base class to avoid ImportError +PandasDataframeType = object if not installed_pandas else pd.DataFrame +PandasSeriesType = object if not installed_pandas else pd.Series + + +class Operator: + def op(self, *operands): + pass + + +class Add(Operator): + def op(self, *operands): + if len(operands) == 1: + return type(operands[0]) + + +class Minus(Operator): + ... + + +class Multiply(Operator): + ... + + +class FunctionCall(Operator): + ... + + +""" +https://docs.snowflake.com/en/sql-reference/data-type-conversion +""" + +""" +@dataclasses.dataclass +class SnowDataTypeConversion: + from_type: Type[DataType] + to_type: Type[DataType] + castable: bool + coercible: bool + + +SNOW_DATA_TYPE_CONVERSION_LIST = [ + SnowDataTypeConversion(ArrayType, StringType, True, False), + SnowDataTypeConversion(ArrayType, VariantType, True, True), + SnowDataTypeConversion(BinaryType, StringType, True, False), + SnowDataTypeConversion(BinaryType, VariantType, True, False), + SnowDataTypeConversion(BooleanType, DecimalType, True, False), + SnowDataTypeConversion(BooleanType, StringType, True, True), + SnowDataTypeConversion(BooleanType, VariantType, True, True), + SnowDataTypeConversion(DateType, TimestampType, True, False), + SnowDataTypeConversion(DateType, StringType, True, True), + SnowDataTypeConversion(DateType, VariantType, True, False), + SnowDataTypeConversion(FloatType, BooleanType, True, True), + SnowDataTypeConversion(FloatType, DecimalType, True, True), + SnowDataTypeConversion(FloatType, StringType, True, True), + SnowDataTypeConversion(FloatType, VariantType, True, True), + SnowDataTypeConversion(GeographyType, VariantType, True, False), + # SnowDataTypeConversion(GeometryType, VariantType, True, False), # GeometryType isn't available yet. 
+ SnowDataTypeConversion(DecimalType, BooleanType, True, True), + SnowDataTypeConversion(DecimalType, FloatType, True, True), + SnowDataTypeConversion(DecimalType, TimestampType, True, True), + SnowDataTypeConversion(DecimalType, StringType, True, True), + SnowDataTypeConversion(DecimalType, VariantType, True, True), + SnowDataTypeConversion(MapType, ArrayType, True, False), + SnowDataTypeConversion(MapType, StringType, True, False), + SnowDataTypeConversion(MapType, VariantType, True, True), + SnowDataTypeConversion(TimeType, StringType, True, True), + SnowDataTypeConversion(TimeType, VariantType, True, False), + SnowDataTypeConversion(TimestampType, DateType, True, True), + SnowDataTypeConversion(TimestampType, TimeType, True, True), + SnowDataTypeConversion(TimestampType, StringType, True, True), + SnowDataTypeConversion(TimestampType, VariantType, True, False), + SnowDataTypeConversion(StringType, BooleanType, True, True), + SnowDataTypeConversion(StringType, DateType, True, True), + SnowDataTypeConversion(StringType, FloatType, True, True), + SnowDataTypeConversion(StringType, DecimalType, True, True), + SnowDataTypeConversion(StringType, TimeType, True, True), + SnowDataTypeConversion(StringType, TimestampType, True, True), + SnowDataTypeConversion(StringType, VariantType, True, False), + SnowDataTypeConversion(VariantType, DateType, True, True), + SnowDataTypeConversion(VariantType, FloatType, True, True), + SnowDataTypeConversion(VariantType, GeographyType, True, False), + SnowDataTypeConversion(VariantType, DecimalType, True, True), + SnowDataTypeConversion(VariantType, MapType, True, True), + SnowDataTypeConversion(VariantType, TimeType, True, True), + SnowDataTypeConversion(VariantType, TimestampType, True, True), + SnowDataTypeConversion(VariantType, StringType, True, True), +] + + +SNOW_DATA_TYPE_CONVERSION_DICT = { + (x.from_type, x.to_type): x for x in SNOW_DATA_TYPE_CONVERSION_LIST +} +""" + + +class ColumnType(NamedTuple): + datatype: DataType + nullable: bool + + +def normalize_decimal(d: DecimalType): + if d.scale > d.precision or d.scale > 38 or d.scale < 0 or d.precision < 0: + raise ValueError( + f"Inferred data type DecimalType({d.precision}, {d.scale}) is invalid." 
+        )
+    d.precision = min(38, d.precision)
+
+
+def normalize_output_sf_type(t: DataType) -> DataType:
+    if t == DecimalType(38, 0):
+        return LongType()
+    return t
+
+
+def calculate_type(c1: ColumnType, c2: Optional[ColumnType], op: Union[str]):
+    """Decide the result column type from the two operand types and the operator."""
+    t1, t2 = c1.datatype, c2.datatype
+    nullable = c1.nullable or c2.nullable
+    decimal_types = (IntegerType, LongType, DecimalType)
+    if isinstance(t1, decimal_types) and isinstance(t2, decimal_types):
+        p1, s1 = get_number_precision_scale(t1)
+        p2, s2 = get_number_precision_scale(t2)
+        if op == "/":
+            division_min_scale = 6
+            division_max_scale = 12
+            l1 = p1 - s1
+            res_scale = max(min(s1 + division_min_scale, division_max_scale), s1)
+            res_lead = l1 + s2
+            res_precision = min(38, res_scale + res_lead)
+            result_type = normalize_output_sf_type(
+                DecimalType(res_precision, res_scale)
+            )
+            return ColumnType(result_type, nullable)
+        elif op == "*":
+            multiplication_max_scale = 12
+            l1 = p1 - s1
+            l2 = p2 - s2
+            result_scale = min(s1 + s2, max(multiplication_max_scale, max(s1, s2)))
+            result_precision = min(38, result_scale + l1 + l2)
+            result_type = DecimalType(result_precision, result_scale)
+            normalize_decimal(result_type)
+            result_type = normalize_output_sf_type(result_type)
+            return ColumnType(result_type, nullable)
+        elif op in ("+", "-"):
+            # widen the number with smaller scale
+            if s1 > s2:
+                gap = s1 - s2
+                if p2 - s2 == 1:  # special logic in Snowflake
+                    gap = gap + 1
+                p2 += gap
+                s2 += gap
+            elif s1 < s2:
+                gap = s2 - s1
+                if p1 - s1 == 1:
+                    gap = gap + 1
+                p1 += gap
+                s1 += gap
+            result_type = normalize_output_sf_type(
+                DecimalType(min(38, max(p1, p2) + 1), max(s1, s2))
+            )
+            return ColumnType(result_type, nullable)
+        elif op == "%":
+            new_scale = max(s1, s2)
+            new_decimal = max(p1 - s1, p2 - s2)
+            new_decimal = new_decimal + new_scale
+            result_type = normalize_output_sf_type(DecimalType(new_decimal, new_scale))
+            return ColumnType(result_type, nullable)
+        else:
+            raise NotImplementedError(
+                f"Type inference for operator {op} is not implemented."
+            )
+    elif isinstance(t1, (FloatType, DoubleType)) or isinstance(
+        t2, (FloatType, DoubleType)
+    ):
+        return ColumnType(DoubleType(), nullable)
+    elif isinstance(t1, DateType) or isinstance(t2, DateType):
+        if isinstance(t2, DateType):
+            t1, t2 = t2, t1
+        if not isinstance(
+            t2, (IntegerType, LongType, DecimalType, FloatType, DoubleType)
+        ) or op not in ("+", "-"):
+            raise ValueError(
+                f"Result data type can't be calculated: (type1: {t1}, op: '{op}', type2: {t2})."
+            )
+        return ColumnType(DateType(), nullable)
+
+    raise TypeError(
+        f"Result data type can't be calculated: (type1: {t1}, op: '{op}', type2: {t2})."
+ ) + + +class TableEmulator(PandasDataframeType): + _metadata = ["sf_types", "sf_types_by_col_index", "_null_rows_idxs_map"] + + @property + def _constructor(self): + return TableEmulator + + @property + def _constructor_sliced(self): + return ColumnEmulator + + def __init__( + self, + *args, + sf_types: Optional[Dict[str, ColumnType]] = None, + sf_types_by_col_index: Optional[Dict[int, ColumnType]] = None, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.sf_types = {} if not sf_types else sf_types + # TODO: SNOW-976145, move to index based approach to store col type mapping + self.sf_types_by_col_index = ( + {} if not sf_types_by_col_index else sf_types_by_col_index + ) + self._null_rows_idxs_map = {} + + def __getitem__(self, item): + result = super().__getitem__(item) + if isinstance(result, ColumnEmulator): # pandas.Series + result.sf_type = self.sf_types.get(item) + elif isinstance(result, TableEmulator): # pandas.DataFrame + result.sf_types = self.sf_types + else: + # TODO: figure out what cases, it may can be removed + # list of columns + for ce in result: + ce.sf_type = self.sf_types.get(ce.name) + return result + + def __setitem__(self, key, value): + super().__setitem__(key, value) + if isinstance(value, ColumnEmulator): + self.sf_types[key] = value.sf_type + self._null_rows_idxs_map[key] = value._null_rows_idxs + + def sort_values(self, by, **kwargs): + result = super().sort_values(by, **kwargs) + result.sf_types = self.sf_types + return result + + +def get_number_precision_scale(t: DataType): + if isinstance(t, (IntegerType, LongType)): + return 38, 0 + if isinstance(t, DecimalType): + return t.precision, t.scale + return None, None + + +def add_date_and_number( + col1: "ColumnEmulator", col2: "ColumnEmulator" +) -> Optional["ColumnEmulator"]: + """If one column is DateType and another column is numeric, round and add the numeric to days""" + if isinstance(col2.sf_type.datatype, DateType): + col1, col2 = col2, col1 + if isinstance(col1.sf_type.datatype, DateType) and isinstance( + col2.sf_type.datatype, _NumericType + ): + result = pd.to_datetime(col1) + pd.to_timedelta(round(col2), unit="d") + result.sf_type = ColumnType( + DateType(), col1.sf_type.nullable or col2.sf_type.nullable + ) + return result + raise ValueError(f"Can't add {col1.sf_type.datatype} and {col2.sf_type.datatype}") + + +class ColumnEmulator(PandasSeriesType): + _metadata = ["sf_type", "_null_rows_idxs"] + + @property + def _constructor(self): + return ColumnEmulator + + @property + def _constructor_expanddim(self): + return TableEmulator + + def __init__(self, *args, **kwargs) -> None: + sf_type = kwargs.pop("sf_type", None) + super().__init__(*args, **kwargs) + self.sf_type: ColumnType = sf_type + # record which rows should be marked as null instead of None + # snowflake SubfieldString has this behavior + # suppose there are two Variant objects in table "v": 1. { "a": None } 2. 
None + # if we do sub-field v["a"], snowpark python return ['null', None] instead of [None, None] + # however during the calculation we want to keep using None, so we need extra data structure to store + # the information of null vs None + # check SNOW-960190 for more context + self._null_rows_idxs = [] + + def set_sf_type(self, value): + self.sf_type = value + + def __add__(self, other): + """TODO: needs to calculate date +""" + if isinstance(self.sf_type.datatype, DateType) or isinstance( + other.sf_type.datatype, DateType + ): + return add_date_and_number(self, other) + result = super().__add__(other) + if self.sf_type: + result.sf_type = calculate_type(self.sf_type, other.sf_type, op="+") + return result + + def __radd__(self, other): + if isinstance(self.sf_type.datatype, DateType) or isinstance( + other.sf_type.datatype, DateType + ): + return add_date_and_number(self, other) + result = super().__radd__(other) + result.sf_type = calculate_type(other.sf_type, self.sf_type, op="+") + return result + + def __sub__(self, other): + if isinstance(self.sf_type.datatype, DateType) and isinstance( + other.sf_type.datatype, _NumericType + ): + return add_date_and_number(self, -other) + result = super().__sub__(other) + result.sf_type = calculate_type(self.sf_type, other.sf_type, op="-") + return result + + def __rsub__(self, other): + result = super().__rsub__(other) + result.sf_type = calculate_type(other.sf_type, self.sf_type, op="-") + return result + + def __mul__(self, other): + result = super().__mul__(other) + result.sf_type = calculate_type(self.sf_type, other.sf_type, op="*") + return result + + def __rmul__(self, other): + result = super().__rmul__(other) + result.sf_type = calculate_type(other.sf_type, self.sf_type, op="*") + return result + + def __bool__(self): + result = super().__bool__() + result.sf_type = ColumnType(BooleanType(), self.sf_type.nullable) + return result + + def __and__(self, other): + result = super().__and__(other) + result.sf_type = ColumnType(BooleanType(), True) + return result + + def __or__(self, other): + result = super().__or__(other) + result.sf_type = ColumnType(BooleanType(), True) + return result + + def __ne__(self, other): + result = super().__ne__(other) + result.sf_type = ColumnType(BooleanType(), True) + return result + + def __xor__(self, other): + result = super().__xor__(other) + result.sf_type = ColumnType(BooleanType(), True) + return result + + def __pow__(self, power): + result = super().__pow__(power) + result.sf_type = ColumnType( + DoubleType(), self.sf_type.nullable or power.sf_type.nullable + ) + return result + + def __ge__(self, other): + result = super().__ge__(other) + result.sf_type = ColumnType(BooleanType(), True) + return result + + def __gt__(self, other): + result = super().__gt__(other) + result.sf_type = ColumnType(BooleanType(), True) + return result + + def __invert__(self): + result = super().__invert__() + result.sf_type = ColumnType(BooleanType(), True) + return result + + def __le__(self, other): + result = super().__le__(other) + result.sf_type = ColumnType(BooleanType(), True) + return result + + def __lt__(self, other): + result = super().__lt__(other) + result.sf_type = ColumnType(BooleanType(), True) + return result + + def __eq__(self, other): + result = super().__eq__(other) + result.sf_type = ColumnType(BooleanType(), True) + return result + + def __neg__(self): + result = super().__neg__() + result.sf_type = self.sf_type + return result + + def __rand__(self, other): + result = super().__rand__(other) + 
result.sf_type = ColumnType(BooleanType(), True)
+        return result
+
+    def __mod__(self, other):
+        result = super().__mod__(other)
+        result.sf_type = calculate_type(self.sf_type, other.sf_type, op="%")
+        return result
+
+    def __rmod__(self, other):
+        result = super().__rmod__(other)
+        result.sf_type = calculate_type(other.sf_type, self.sf_type, op="%")
+        return result
+
+    def __ror__(self, other):
+        result = super().__ror__(other)
+        result.sf_type = ColumnType(BooleanType(), True)
+        return result
+
+    def __round__(self, n=None):
+        result = super().__round__(n)
+        if isinstance(self.sf_type.datatype, (FloatType, DoubleType, _IntegralType)):
+            result.sf_type = self.sf_type
+        elif isinstance(self.sf_type.datatype, DecimalType):
+            scale = self.sf_type.datatype.scale
+            if scale <= n:
+                result.sf_type = self.sf_type
+            else:
+                result_scale = 0 if n <= 0 else n
+                result_precision = min(self.sf_type.datatype.precision + 1, 38)
+                result.sf_type = ColumnType(
+                    DecimalType(result_precision, result_scale), self.sf_type.nullable
+                )
+        return result
+
+    def __rpow__(self, other):
+        result = super().__rpow__(other)
+        result.sf_type = ColumnType(DoubleType(), True)
+        return result
+
+    def __rtruediv__(self, other):
+        return other.__truediv__(self)
+
+    def __truediv__(self, other):
+        result = super().__truediv__(other)
+        sf_type = calculate_type(self.sf_type, other.sf_type, op="/")
+        if isinstance(sf_type.datatype, DecimalType):
+            result = result.astype("double").round(sf_type.datatype.scale)
+        elif isinstance(sf_type.datatype, (FloatType, DoubleType)):
+            result = result.astype("double").round(16)
+        result.sf_type = sf_type
+
+        return result
+
+    def isna(self):
+        result = super().isna()
+        result.sf_type = ColumnType(BooleanType(), True)
+        return result
+
+    def isnull(self):
+        result = super().isnull()
+        result.sf_type = ColumnType(BooleanType(), True)
+        return result
diff --git a/src/snowflake/snowpark/mock/snowflake_to_pandas_converter.py b/src/snowflake/snowpark/mock/snowflake_to_pandas_converter.py
new file mode 100644
index 00000000000..4a7a7501854
--- /dev/null
+++ b/src/snowflake/snowpark/mock/snowflake_to_pandas_converter.py
@@ -0,0 +1,227 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+"""
+The converter module converts raw strings into Python objects in a pandas DataFrame, complying
+with the Snowflake data type spec: for example, when we call pandas.read_csv, we use these
+converter functions to validate the data and turn it into Python objects according to the
+target Snowflake data type. Otherwise, pandas.read_csv keeps the data as raw strings in most cases.
+
+For full data type spec, please refer to https://docs.snowflake.com/en/sql-reference/data-types.
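+
+A rough usage sketch (illustrative only; in practice the converters are looked up by the
+mock CSV-reading logic rather than called directly):
+
+    from snowflake.snowpark.types import IntegerType
+
+    converter = CONVERT_MAP[IntegerType]   # -> _integer_converter
+    converter("42", IntegerType())         # -> 42
+    converter("", IntegerType())           # -> None (empty fields become NULL)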
+""" + +import datetime +from decimal import Decimal +from typing import Optional, Union + +from snowflake.snowpark.exceptions import SnowparkSQLException +from snowflake.snowpark.types import ( + BooleanType, + ByteType, + DataType, + DateType, + DecimalType, + DoubleType, + FloatType, + IntegerType, + LongType, + ShortType, + StringType, + TimestampType, + TimeType, +) + +TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S" +DATE_FORMAT = "%Y-%m-%d" +TIME_FORMAT = "%H:%M:%S" + + +def _integer_converter( + value: str, datatype: DataType, field_optionally_enclosed_by: str = None +) -> Optional[int]: + if value is None or value == "": + return None + if ( + field_optionally_enclosed_by + and len(value) >= 2 + and value[0] == field_optionally_enclosed_by + and value[-1] == field_optionally_enclosed_by + ): + value = value[1:-1] + try: + return int(value) + except ValueError: + raise SnowparkSQLException( + f"[Local Testing] Numeric value '{value}' is not recognized." + ) + + +def _fraction_converter( + value: str, datatype: DataType, field_optionally_enclosed_by: str = None +) -> Optional[float]: + if value is None or value == "": + return None + if ( + field_optionally_enclosed_by + and len(field_optionally_enclosed_by) >= 2 + and value[0] == field_optionally_enclosed_by + and value[-1] == field_optionally_enclosed_by + ): + value = value[1:-1] + try: + return float(value) + except ValueError: + raise SnowparkSQLException( + f"[Local Testing] Numeric value '{value}' is not recognized." + ) + + +def _decimal_converter( + value: str, datatype: DecimalType, field_optionally_enclosed_by: str = None +) -> Optional[Union[int, Decimal]]: + if value is None or value == "": + return None + if ( + field_optionally_enclosed_by + and len(field_optionally_enclosed_by) >= 2 + and value[0] == field_optionally_enclosed_by + and value[-1] == field_optionally_enclosed_by + ): + value = value[1:-1] + try: + precision = datatype.precision + scale = datatype.scale + integer_part = round(float(value)) + integer_part_str = str(integer_part) + len_integer_part = ( + len(integer_part_str) - 1 + if integer_part_str[0] == "-" + else len(integer_part_str) + ) + if len_integer_part > precision: + raise SnowparkSQLException(f"Numeric value '{value}' is out of range") + if scale == 0: + return integer_part + remaining_decimal_len = min(precision - len(str(integer_part)), scale) + return Decimal(str(round(float(value), remaining_decimal_len))) + except ValueError: + raise SnowparkSQLException( + f"[Local Testing] Numeric value '{value}' is not recognized." + ) + + +def _bool_converter( + value: str, datatype: DataType, field_optionally_enclosed_by: str = None +) -> Optional[bool]: + if value is None or value == "": + return None + if ( + field_optionally_enclosed_by + and len(field_optionally_enclosed_by) >= 2 + and value[0] == field_optionally_enclosed_by + and value[-1] == field_optionally_enclosed_by + ): + value = value[1:-1] + if value.lower() == "true": + return True + if value.lower() == "false": + return False + try: + float_value = float(value) + return bool(float_value != 0) + except TypeError: + raise SnowparkSQLException( + f"[Local Testing] Boolean value '{value}' is not recognized." 
+ ) + + +def _string_converter( + value: str, datatype: DataType, field_optionally_enclosed_by: str = None +) -> Optional[str]: + if value is None or value == "": + return value + if ( + field_optionally_enclosed_by + and len(field_optionally_enclosed_by) >= 2 + and value[0] == field_optionally_enclosed_by + and value[-1] == field_optionally_enclosed_by + ): + value = value[1:-1] + return value + + +def _date_converter( + value: str, datatype: DataType, field_optionally_enclosed_by: str = None +) -> Optional[datetime.date]: + if value is None or value == "": + return None + if ( + field_optionally_enclosed_by + and len(field_optionally_enclosed_by) >= 2 + and value[0] == field_optionally_enclosed_by + and value[-1] == field_optionally_enclosed_by + ): + value = value[1:-1] + try: + return datetime.datetime.strptime(value, DATE_FORMAT).date() + except Exception as e: + raise SnowparkSQLException( + f"[Local Testing] DATE value '{value}' is not recognized due to error {e!r}." + ) + + +def _timestamp_converter( + value: str, datatype: DataType, field_optionally_enclosed_by: str = None +) -> Optional[datetime.datetime]: + if value is None or value == "": + return None + if ( + field_optionally_enclosed_by + and len(field_optionally_enclosed_by) >= 2 + and value[0] == field_optionally_enclosed_by + and value[-1] == field_optionally_enclosed_by + ): + value = value[1:-1] + try: + return datetime.datetime.strptime(value, TIMESTAMP_FORMAT) + except Exception as e: + raise SnowparkSQLException( + f"[Local Testing] TIMESTAMP value '{value}' is not recognized due to error {e!r}." + ) + + +def _time_converter( + value: str, datatype: DataType, field_optionally_enclosed_by: str = None +) -> Optional[datetime.time]: + if value is None or value == "": + return None + if ( + field_optionally_enclosed_by + and len(field_optionally_enclosed_by) >= 2 + and value[0] == field_optionally_enclosed_by + and value[-1] == field_optionally_enclosed_by + ): + value = value[1:-1] + try: + return datetime.datetime.strptime(value, TIME_FORMAT).time() + except Exception as e: + raise SnowparkSQLException( + f"[Local Testing] TIMESTAMP value '{value}' is not recognized due to error {e!r}." + ) + + +CONVERT_MAP = { + IntegerType: _integer_converter, + LongType: _integer_converter, + ByteType: _integer_converter, + ShortType: _integer_converter, + DoubleType: _fraction_converter, + FloatType: _fraction_converter, + DecimalType: _decimal_converter, + BooleanType: _bool_converter, + DateType: _date_converter, + TimeType: _time_converter, + TimestampType: _timestamp_converter, + StringType: _string_converter, +} diff --git a/src/snowflake/snowpark/mock/util.py b/src/snowflake/snowpark/mock/util.py new file mode 100644 index 00000000000..302a225a9e0 --- /dev/null +++ b/src/snowflake/snowpark/mock/util.py @@ -0,0 +1,166 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import math +from functools import cmp_to_key, partial +from typing import Any, Tuple + +from snowflake.connector.options import pandas as pd + +# placeholder map helps convert wildcard to reg. In practice, we convert wildcard to a middle string first, +# and then convert middle string to regex. See the following example: +# wildcard = "_." -> middle: "_" -> regex = ".\." 
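+# (the middle string swaps each regex metacharacter for a placeholder token first, e.g.
+# "_." -> "_!sf-dot!" -> ".\.", so a literal "." can be told apart from the "." produced
+# by the "_" wildcard)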
+# placeholder string should not contain any special characters used in regex or wildcard
+# (the concrete token values below are illustrative; any unique strings free of regex and
+# wildcard metacharacters would work)
+regex_special_characters_map = {
+    ".": "!sf-dot!",
+    "\\": "!sf-backslash!",
+    "^": "!sf-caret!",
+    "?": "!sf-question!",
+    "+": "!sf-plus!",
+    "|": "!sf-pipe!",
+    "$": "!sf-dollar!",
+    "*": "!sf-star!",
+    "{": "!sf-lbrace!",
+    "}": "!sf-rbrace!",
+    "[": "!sf-lbracket!",
+    "]": "!sf-rbracket!",
+    "(": "!sf-lparen!",
+    ")": "!sf-rparen!",
+}
+
+escape_regex_special_characters_map = {
+    regex_special_characters_map["."]: "\\.",
+    regex_special_characters_map["\\"]: "\\\\",
+    regex_special_characters_map["^"]: "\\^",
+    regex_special_characters_map["?"]: "\\?",
+    regex_special_characters_map["+"]: "\\+",
+    regex_special_characters_map["|"]: "\\|",
+    regex_special_characters_map["$"]: "\\$",
+    regex_special_characters_map["*"]: "\\*",
+    regex_special_characters_map["{"]: "\\{",
+    regex_special_characters_map["}"]: "\\}",
+    regex_special_characters_map["["]: "\\[",
+    regex_special_characters_map["]"]: "\\]",
+    regex_special_characters_map["("]: "\\(",
+    regex_special_characters_map[")"]: "\\)",
+}
+
+
+def convert_wildcard_to_regex(wildcard: str):
+    # protect regex metacharacters in the wildcard with placeholder tokens
+    for k, v in regex_special_characters_map.items():
+        wildcard = wildcard.replace(k, v)
+
+    # replace wildcard special characters with regex
+    wildcard = wildcard.replace("_", ".")
+    wildcard = wildcard.replace("%", ".*")
+
+    # escape the protected regex metacharacters
+    for k, v in escape_regex_special_characters_map.items():
+        wildcard = wildcard.replace(k, v)
+
+    wildcard = f"^{wildcard}$"
+    return wildcard
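+
+
+# Example (illustrative):
+#   convert_wildcard_to_regex("A_C")    -> "^A.C$"
+#   convert_wildcard_to_regex("10.5%")  -> "^10\.5.*$"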
+
+
+def custom_comparator(ascend: bool, null_first: bool, pandas_series: "pd.Series"):
+    origin_array = pandas_series.values.tolist()
+    array_with_pos = list(zip([i for i in range(len(pandas_series))], origin_array))
+    comparator = partial(array_custom_comparator, ascend, null_first)
+    array_with_pos.sort(key=cmp_to_key(comparator))
+    new_pos = [0] * len(array_with_pos)
+    for i in range(len(array_with_pos)):
+        new_pos[array_with_pos[i][0]] = i
+    return new_pos
+
+
+def array_custom_comparator(ascend: bool, null_first: bool, a: Any, b: Any):
+    value_a, value_b = a[1], b[1]
+    if value_a == value_b:
+        return 0
+    if value_a is None:
+        return -1 if null_first else 1
+    elif value_b is None:
+        return 1 if null_first else -1
+    try:
+        if math.isnan(value_a) and math.isnan(value_b):
+            return 0
+        elif math.isnan(value_a):
+            ret = 1
+        elif math.isnan(value_b):
+            ret = -1
+        else:
+            ret = -1 if value_a < value_b else 1
+    except TypeError:
+        ret = -1 if value_a < value_b else 1
+    return ret if ascend else -1 * ret
+
+
+def convert_snowflake_datetime_format(format, default_format) -> Tuple[str, int, int]:
+    """Convert a Snowflake date/time/timestamp format string into a Python datetime
+    format string, together with the 12-hour offset and the fractional-second precision."""
+
+    # if this is a PM time in 12-hour format, add 12 hours when converting
+    hour_delta = 12 if format is not None and "HH12" in format and "PM" in format else 0
+    time_fmt = format.upper() if format else default_format
+    time_fmt = time_fmt.replace("YYYY", "%Y")
+    time_fmt = time_fmt.replace("MM", "%m")
+    time_fmt = time_fmt.replace("MON", "%b")
+    time_fmt = time_fmt.replace("DD", "%d")
+    time_fmt = time_fmt.replace("HH24", "%H")
+    time_fmt = time_fmt.replace("HH12", "%H")
+    time_fmt = time_fmt.replace("MI", "%M")
+    time_fmt = time_fmt.replace("SS", "%S")
+    fractional_seconds = 9
+    if format is not None and "FF" in format:
+        try:
+            ff_index = str(format).index("FF")
+            # handle precision string 'FF[0-9]' which could be like FF0, FF1, ..., FF9
+            if str(format[ff_index + 2 : ff_index + 3]).isdigit():
+                fractional_seconds = int(format[ff_index + 2 : ff_index + 3])
+                # replace FF[0-9] with %f
+                time_fmt = time_fmt[:ff_index] + "%f" + time_fmt[ff_index + 3 :]
+            else:
+                time_fmt = time_fmt[:ff_index] + "%f" + time_fmt[ff_index + 2 :]
+        except ValueError:
+            # 'FF' is not in the fmt
+            pass
+
+    return time_fmt, hour_delta, fractional_seconds
+
+
+def process_numeric_time(time: str) -> int:
+    """Convert a numeric time/timestamp value into a value in seconds that Python datetime accepts.
+    spec here: https://docs.snowflake.com/en/sql-reference/functions/to_time#usage-notes
+    """
+    timestamp_values = int(time)
+    if 31536000000 <= timestamp_values < 31536000000000:  # milliseconds
+        timestamp_values = timestamp_values / 1000
+    elif timestamp_values >= 31536000000000:
+        # microseconds or nanoseconds
+        timestamp_values = timestamp_values / 1000000
+    # timestamp_values < 31536000000 are treated as seconds
+    return int(timestamp_values)
+
+
+def process_string_time_with_fractional_seconds(time: str, fractional_seconds) -> str:
+    # deal with the fractional seconds part of the input time str, apply precision and reconstruct the time string
+    ret = time
+    time_parts = ret.split(".")
+    if len(time_parts) == 2:
+        # there is a fractional-seconds part
+        seconds_part = time_parts[1]
+        # find the index where the digits of the seconds part end
+        idx = 0
+        while idx < len(seconds_part) and seconds_part[idx].isdigit():
+            idx += 1
+        # truncate the digits to the requested precision, keeping any trailing non-digits
+        seconds_part = seconds_part[: min(idx, fractional_seconds)] + seconds_part[idx:]
+        ret = f"{time_parts[0]}.{seconds_part}"
+    return ret
diff --git a/src/snowflake/snowpark/mock/window_utils.py b/src/snowflake/snowpark/mock/window_utils.py
new file mode 100644
index 00000000000..a5f27a54d5b
--- /dev/null
+++ b/src/snowflake/snowpark/mock/window_utils.py
@@ -0,0 +1,84 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+# + +try: + import numpy as np + from pandas.api.indexers import BaseIndexer +except ImportError: + # snowflake dataframe.py imports module that indirectly depends on this window_utils.py + # to avoid impacting the live session features which doesn't need pandas + # we ignore the error for now, there might be other better ways to workaround the issue + BaseIndexer = object + pass + +from snowflake.snowpark._internal.analyzer.expression import FunctionExpression, Literal +from snowflake.snowpark._internal.analyzer.window_expression import ( + CurrentRow, + FirstValue, + Lag, + LastValue, + Lead, + UnboundedFollowing, + UnboundedPreceding, +) + + +class EntireWindowIndexer(BaseIndexer): + def get_window_bounds(self, num_values, min_periods, center, closed, step): + start = np.empty(num_values, dtype=np.int64) + end = np.empty(num_values, dtype=np.int64) + for i in range(num_values): + start[i] = 0 + end[i] = num_values + + return start, end + + +class RowFrameIndexer(BaseIndexer): + def get_window_bounds(self, num_values, min_periods, center, closed, step): + start = np.empty(num_values, dtype=np.int64) + end = np.empty(num_values, dtype=np.int64) + + upper = self.frame_spec.upper + lower = self.frame_spec.lower + + for i in range(num_values): + if isinstance(lower, CurrentRow): + start[i] = i + elif isinstance(lower, UnboundedPreceding): + start[i] = 0 + else: + assert isinstance(lower, Literal) + start[i] = max(0, min(i + lower.value, num_values)) + + if isinstance(upper, CurrentRow): + end[i] = i + 1 # + 1 to include the right endpoint + elif isinstance(upper, UnboundedFollowing): + end[i] = num_values + else: + assert isinstance(upper, Literal) + end[i] = max( + 0, min(i + upper.value + 1, num_values) + ) # + 1 to include the right endpoint + + return start, end + + +# TODO: Add all rank related functions + +RANK_RELATED_FUNCTIONS = ( + Lead, + Lag, + LastValue, + FirstValue, +) + +RANK_RELATED_FUNCTION_NAMES = ("row_number",) + + +def is_rank_related_window_function(func): + return isinstance(func, RANK_RELATED_FUNCTIONS) or ( + isinstance(func, FunctionExpression) + and func.name in RANK_RELATED_FUNCTION_NAMES + ) diff --git a/src/snowflake/snowpark/relational_grouped_dataframe.py b/src/snowflake/snowpark/relational_grouped_dataframe.py index 873e60417e2..1e3dc36295e 100644 --- a/src/snowflake/snowpark/relational_grouped_dataframe.py +++ b/src/snowflake/snowpark/relational_grouped_dataframe.py @@ -18,10 +18,6 @@ GroupingSetsExpression, Rollup, ) -from snowflake.snowpark._internal.analyzer.select_statement import ( - SelectSnowflakePlan, - SelectStatement, -) from snowflake.snowpark._internal.analyzer.unary_expression import ( Alias, UnresolvedAlias, @@ -183,9 +179,9 @@ def _to_df(self, agg_exprs: List[Expression]) -> DataFrame: raise TypeError(f"Wrong group by type {self._group_type}") if self._df._select_statement: - group_plan = SelectStatement( - from_=SelectSnowflakePlan( - snowflake_plan=group_plan, analyzer=self._df._session._analyzer + group_plan = self._df._session._analyzer.create_select_statement( + from_=self._df._session._analyzer.create_select_snowflake_plan( + group_plan, analyzer=self._df._session._analyzer ), analyzer=self._df._session._analyzer, ) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 9c9eb714ad4..354689194e8 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -28,7 +28,6 @@ from snowflake.snowpark._internal.analyzer.datatype_mapper import str_to_sql from 
snowflake.snowpark._internal.analyzer.expression import Attribute from snowflake.snowpark._internal.analyzer.select_statement import ( - SelectSnowflakePlan, SelectSQL, SelectStatement, SelectTableFunction, @@ -124,6 +123,13 @@ to_timestamp, to_variant, ) +from snowflake.snowpark.mock.analyzer import MockAnalyzer +from snowflake.snowpark.mock.connection import MockServerConnection +from snowflake.snowpark.mock.pandas_util import ( + _convert_dataframe_to_table, + _extract_schema_and_data_from_pandas_df, +) +from snowflake.snowpark.mock.plan_builder import MockSnowflakePlanBuilder from snowflake.snowpark.query_history import QueryHistory from snowflake.snowpark.row import Row from snowflake.snowpark.stored_procedure import StoredProcedureRegistration @@ -313,7 +319,11 @@ def configs( def create(self) -> "Session": """Creates a new Session.""" - session = self._create_internal(self._options.get("connection")) + if self._options.get("local_testing", False): + session = Session(MockServerConnection(), self._options) + _add_session(session) + else: + session = self._create_internal(self._options.get("connection")) return session def getOrCreate(self) -> "Session": @@ -327,7 +337,8 @@ def getOrCreate(self) -> "Session": raise ex def _create_internal( - self, conn: Optional[SnowflakeConnection] = None + self, + conn: Optional[SnowflakeConnection] = None, ) -> "Session": # If no connection object and no connection parameter is provided, # we read from the default config file @@ -359,7 +370,9 @@ def __get__(self, obj, objtype=None): builder: SessionBuilder = SessionBuilder() def __init__( - self, conn: ServerConnection, options: Optional[Dict[str, Any]] = None + self, + conn: Union[ServerConnection, MockServerConnection], + options: Optional[Dict[str, Any]] = None, ) -> None: if len(_active_sessions) >= 1 and is_in_stored_procedure(): raise SnowparkClientExceptionMessages.DONT_CREATE_SESSION_IN_SP() @@ -381,7 +394,11 @@ def __init__( self._udtf_registration = UDTFRegistration(self) self._udaf_registration = UDAFRegistration(self) self._sp_registration = StoredProcedureRegistration(self) - self._plan_builder = SnowflakePlanBuilder(self) + self._plan_builder = ( + SnowflakePlanBuilder(self) + if isinstance(self._conn, ServerConnection) + else MockSnowflakePlanBuilder(self) + ) self._last_action_id = 0 self._last_canceled_id = 0 self._use_scoped_temp_objects: bool = ( @@ -391,7 +408,9 @@ def __init__( ) ) self._file = FileOperation(self) - self._analyzer = Analyzer(self) + self._analyzer = ( + Analyzer(self) if isinstance(conn, ServerConnection) else MockAnalyzer(self) + ) self._sql_simplifier_enabled: bool = ( self._conn._get_client_side_session_parameter( _PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING, True @@ -588,6 +607,10 @@ def add_import(self, path: str, import_path: Optional[str] = None) -> None: ``imports`` argument in :func:`functions.udf` or :meth:`session.udf.register() `. """ + if isinstance(self._conn, MockServerConnection): + raise NotImplementedError( + "[Local Testing] Stored procedures are not currently supported." + ) path, checksum, leading_path = self._resolve_import_path(path, import_path) self._import_paths[path] = (checksum, leading_path) @@ -850,6 +873,10 @@ def add_packages( to ensure the consistent experience of a UDF between your local environment and the Snowflake server. """ + if isinstance(self._conn, MockServerConnection): + raise NotImplementedError( + "[Local Testing] Add session packages is not currently supported." 
+ ) self._resolve_packages( parse_positional_args_to_list(*packages), self._packages, @@ -1586,6 +1613,10 @@ def table_function( - :meth:`Session.generator`, which is used to instantiate a :class:`DataFrame` using Generator table function. Generator functions are not supported with :meth:`Session.table_function`. """ + if isinstance(self._conn, MockServerConnection): + raise NotImplementedError( + "[Local Testing] Table function is not currently supported." + ) func_expr = _create_table_function_expression( func_name, *func_arguments, **func_named_arguments ) @@ -1593,7 +1624,7 @@ def table_function( if self.sql_simplifier_enabled: d = DataFrame( self, - SelectStatement( + self._analyzer.create_select_statement( from_=SelectTableFunction(func_expr, analyzer=self._analyzer), analyzer=self._analyzer, ), @@ -1654,6 +1685,10 @@ def generator( Returns: A new :class:`DataFrame` with data from calling the generator table function. """ + if isinstance(self._conn, MockServerConnection): + raise NotImplementedError( + "[Local Testing] DataFrame.generator is currently not supported." + ) if not columns: raise ValueError("Columns cannot be empty for generator table function") named_args = {} @@ -1706,10 +1741,15 @@ def sql(self, query: str, params: Optional[Sequence[Any]] = None) -> DataFrame: >>> session.sql("select * from values (?, ?), (?, ?)", params=[1, "a", 2, "b"]).sort("column1").collect() [Row(COLUMN1=1, COLUMN2='a'), Row(COLUMN1=2, COLUMN2='b')] """ + if isinstance(self._conn, MockServerConnection): + raise NotImplementedError( + "[Local Testing] `Session.sql` is currently not supported." + ) + if self.sql_simplifier_enabled: d = DataFrame( self, - SelectStatement( + self._analyzer.create_select_statement( from_=SelectSQL(query, analyzer=self._analyzer, params=params), analyzer=self._analyzer, ), @@ -1717,7 +1757,9 @@ def sql(self, query: str, params: Optional[Sequence[Any]] = None) -> DataFrame: else: d = DataFrame( self, - self._plan_builder.query(query, source_plan=None, params=params), + self._analyzer.plan_builder.query( + query, source_plan=None, params=params + ), ) set_api_call_source(d, "Session.sql") return d @@ -1993,24 +2035,31 @@ def create_dataframe( # check to see if it is a Pandas DataFrame and if so, write that to a temp # table and return as a DataFrame + origin_data = data if installed_pandas and isinstance(data, pandas.DataFrame): - table_name = escape_quotes( + temp_table_name = escape_quotes( random_name_for_temp_object(TempObjectType.TABLE) ) - sf_database = self._conn._get_current_parameter("database", quoted=False) - sf_schema = self._conn._get_current_parameter("schema", quoted=False) - - t = self.write_pandas( - data, - table_name, - database=sf_database, - schema=sf_schema, - quote_identifiers=True, - auto_create_table=True, - table_type="temporary", - ) - set_api_call_source(t, "Session.create_dataframe[pandas]") - return t + if isinstance(self._conn, MockServerConnection): + schema, data = _extract_schema_and_data_from_pandas_df(data) + # we do not return here as live connection and keep using the data frame logic and compose table + else: + sf_database = self._conn._get_current_parameter( + "database", quoted=False + ) + sf_schema = self._conn._get_current_parameter("schema", quoted=False) + + t = self.write_pandas( + data, + temp_table_name, + database=sf_database, + schema=sf_schema, + quote_identifiers=True, + auto_create_table=True, + table_type="temporary", + ) + set_api_call_source(t, "Session.create_dataframe[pandas]") + return t # infer the schema based 
on the data names = None @@ -2175,8 +2224,8 @@ def convert_row_to_list( if self.sql_simplifier_enabled: df = DataFrame( self, - SelectStatement( - from_=SelectSnowflakePlan( + self._analyzer.create_select_statement( + from_=self._analyzer.create_select_snowflake_plan( SnowflakeValues(attrs, converted), analyzer=self._analyzer ), analyzer=self._analyzer, @@ -2187,6 +2236,14 @@ def convert_row_to_list( project_columns ) set_api_call_source(df, "Session.create_dataframe[values]") + + if ( + installed_pandas + and isinstance(origin_data, pandas.DataFrame) + and isinstance(self._conn, MockServerConnection) + ): + return _convert_dataframe_to_table(df, temp_table_name, self) + return df def range(self, start: int, end: Optional[int] = None, step: int = 1) -> DataFrame: @@ -2215,8 +2272,10 @@ def range(self, start: int, end: Optional[int] = None, step: int = 1) -> DataFra if self.sql_simplifier_enabled: df = DataFrame( self, - SelectStatement( - from_=SelectSnowflakePlan(range_plan, analyzer=self._analyzer), + self._analyzer.create_select_statement( + from_=self._analyzer.create_select_snowflake_plan( + range_plan, analyzer=self._analyzer + ), analyzer=self._analyzer, ), ) @@ -2241,6 +2300,10 @@ def create_async_job(self, query_id: str) -> AsyncJob: raise NotImplementedError( "Async query is not supported in stored procedure yet" ) + if isinstance(self._conn, MockServerConnection): + raise NotImplementedError( + "[Local Testing] Async query is currently not supported." + ) return AsyncJob(query_id, None, self) def get_current_account(self) -> Optional[str]: @@ -2386,6 +2449,8 @@ def udf(self) -> UDFRegistration: Returns a :class:`udf.UDFRegistration` object that you can use to register UDFs. See details of how to use this object in :class:`udf.UDFRegistration`. """ + if isinstance(self._conn, MockServerConnection): + raise NotImplementedError("[Local Testing] UDF is not currently supported.") return self._udf_registration @property @@ -2394,6 +2459,10 @@ def udtf(self) -> UDTFRegistration: Returns a :class:`udtf.UDTFRegistration` object that you can use to register UDTFs. See details of how to use this object in :class:`udtf.UDTFRegistration`. """ + if isinstance(self._conn, MockServerConnection): + raise NotImplementedError( + "[Local Testing] UDTF is not currently supported." + ) return self._udtf_registration @property @@ -2411,6 +2480,10 @@ def sproc(self) -> StoredProcedureRegistration: Returns a :class:`stored_procedure.StoredProcedureRegistration` object that you can use to register stored procedures. See details of how to use this object in :class:`stored_procedure.StoredProcedureRegistration`. """ + if isinstance(self, MockServerConnection): + raise NotImplementedError( + "[Local Testing] Stored procedures are not currently supported." + ) return self._sp_registration def _infer_is_return_table( @@ -2598,7 +2671,8 @@ def flatten( - :meth:`DataFrame.flatten`, which creates a new :class:`DataFrame` by exploding a VARIANT column of an existing :class:`DataFrame`. - :meth:`Session.table_function`, which can be used for any Snowflake table functions, including ``flatten``. 
""" - + if isinstance(self._conn, MockServerConnection): + raise NotImplementedError("[Local Testing] flatten is not implemented.") mode = mode.upper() if mode not in ("OBJECT", "ARRAY", "BOTH"): raise ValueError("mode must be one of ('OBJECT', 'ARRAY', 'BOTH')") diff --git a/src/snowflake/snowpark/table.py b/src/snowflake/snowpark/table.py index 7a151f4c7e3..8142153640a 100644 --- a/src/snowflake/snowpark/table.py +++ b/src/snowflake/snowpark/table.py @@ -4,14 +4,11 @@ # import sys +from logging import getLogger from typing import Dict, List, NamedTuple, Optional, Union, overload import snowflake.snowpark from snowflake.snowpark._internal.analyzer.binary_plan_node import create_join_type -from snowflake.snowpark._internal.analyzer.select_statement import ( - SelectableEntity, - SelectStatement, -) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import UnresolvedRelation from snowflake.snowpark._internal.analyzer.table_merge_expression import ( DeleteMergeExpression, @@ -21,6 +18,7 @@ TableUpdate, UpdateMergeExpression, ) +from snowflake.snowpark._internal.analyzer.unary_plan_node import Sample from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages from snowflake.snowpark._internal.telemetry import add_api_call, set_api_call_source from snowflake.snowpark._internal.type_utils import ColumnOrLiteral @@ -36,6 +34,8 @@ else: from collections.abc import Iterable +_logger = getLogger(__name__) + class UpdateResult(NamedTuple): """Result of updating rows in a :class:`Table`.""" @@ -274,8 +274,10 @@ def __init__( self.table_name: str = table_name #: The table name if self._session.sql_simplifier_enabled: - self._select_statement = SelectStatement( - from_=SelectableEntity(table_name, analyzer=session._analyzer), + self._select_statement = session._analyzer.create_select_statement( + from_=session._analyzer.create_selectable_entity( + table_name, analyzer=session._analyzer + ), analyzer=session._analyzer, ) # By default, the set the initial API call to say 'Table.__init__' since @@ -297,7 +299,7 @@ def sample( frac: Optional[float] = None, n: Optional[int] = None, *, - seed: Optional[float] = None, + seed: Optional[int] = None, sampling_method: Optional[str] = None, ) -> "DataFrame": """Samples rows based on either the number of rows to be returned or a percentage of rows to be returned. @@ -335,6 +337,26 @@ def sample( f"'sampling_method' value {sampling_method} must be None or one of 'BERNOULLI', 'ROW', 'SYSTEM', or 'BLOCK'." ) + from snowflake.snowpark.mock.connection import MockServerConnection + + if isinstance(self._session._conn, MockServerConnection): + if sampling_method in ("SYSTEM", "BLOCK"): + _logger.warning( + "[Local Testing] SYSTEM/BLOCK sampling is not supported for Local Testing, falling back to ROW sampling" + ) + + sample_plan = Sample( + self._plan, probability_fraction=frac, row_count=n, seed=seed + ) + return self._with_plan( + self._session._analyzer.create_select_statement( + from_=self._session._analyzer.create_select_snowflake_plan( + sample_plan, analyzer=self._session._analyzer + ), + analyzer=self._session._analyzer, + ) + ) + # The analyzer will generate a sql with subquery. So we build the sql directly without using the analyzer. 
sampling_method_text = sampling_method or "" frac_or_rowcount_text = str(frac * 100.0) if frac is not None else f"{n} ROWS" @@ -664,6 +686,12 @@ def drop_table(self) -> None: Note that subsequent operations such as :meth:`DataFrame.select`, :meth:`DataFrame.collect` on this ``Table`` instance and the derived DataFrame will raise errors because the underlying table in the Snowflake database no longer exists. """ - self._session.sql( - f"drop table {self.table_name}" - )._internal_collect_with_tag_no_telemetry() + from snowflake.snowpark.mock.connection import MockServerConnection + + if isinstance(self._session._conn, MockServerConnection): + # only mock connection has entity_registry + self._session._conn.entity_registry.drop_table(self.table_name) + else: + self._session.sql( + f"drop table {self.table_name}" + )._internal_collect_with_tag_no_telemetry() diff --git a/tests/conftest.py b/tests/conftest.py index 718f1aeb9b1..a337eb86346 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ def pytest_addoption(parser): parser.addoption("--disable_sql_simplifier", action="store_true", default=False) + parser.addoption("--local_testing_mode", action="store_true", default=False) def pytest_collection_modifyitems(items) -> None: @@ -40,3 +41,8 @@ def pytest_collection_modifyitems(items) -> None: def sql_simplifier_enabled(pytestconfig): disable_sql_simplifier = pytestconfig.getoption("disable_sql_simplifier") return not disable_sql_simplifier + + +@pytest.fixture(scope="session") +def local_testing_mode(pytestconfig): + return pytestconfig.getoption("local_testing_mode") diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index acf43f1c5e0..6e276fca656 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -11,6 +11,7 @@ import snowflake.connector from snowflake.snowpark import Session +from snowflake.snowpark.mock.connection import MockServerConnection from tests.parameters import CONNECTION_PARAMETERS from tests.utils import Utils @@ -67,25 +68,30 @@ def resources_path() -> str: @pytest.fixture(scope="session") -def connection(db_parameters): - _keys = [ - "user", - "password", - "host", - "port", - "database", - "account", - "protocol", - "role", - ] - with snowflake.connector.connect( - **{k: db_parameters[k] for k in _keys if k in db_parameters} - ) as con: - yield con +def connection(db_parameters, local_testing_mode): + if local_testing_mode: + yield MockServerConnection() + else: + _keys = [ + "user", + "password", + "host", + "port", + "database", + "account", + "protocol", + "role", + ] + with snowflake.connector.connect( + **{k: db_parameters[k] for k in _keys if k in db_parameters} + ) as con: + yield con @pytest.fixture(scope="session") -def is_sample_data_available(connection) -> bool: +def is_sample_data_available(connection, local_testing_mode) -> bool: + if local_testing_mode: + return False return ( len( connection.cursor() @@ -97,34 +103,46 @@ def is_sample_data_available(connection) -> bool: @pytest.fixture(scope="session", autouse=True) -def test_schema(connection) -> None: +def test_schema(connection, local_testing_mode) -> None: """Set up and tear down the test schema. 
This is automatically called per test session.""" - with connection.cursor() as cursor: - cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {TEST_SCHEMA}") - # This is needed for test_get_schema_database_works_after_use_role in test_session_suite - cursor.execute(f"GRANT ALL PRIVILEGES ON SCHEMA {TEST_SCHEMA} TO ROLE PUBLIC") + if local_testing_mode: yield - cursor.execute(f"DROP SCHEMA IF EXISTS {TEST_SCHEMA}") + else: + with connection.cursor() as cursor: + cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {TEST_SCHEMA}") + # This is needed for test_get_schema_database_works_after_use_role in test_session_suite + cursor.execute( + f"GRANT ALL PRIVILEGES ON SCHEMA {TEST_SCHEMA} TO ROLE PUBLIC" + ) + yield + cursor.execute(f"DROP SCHEMA IF EXISTS {TEST_SCHEMA}") @pytest.fixture(scope="module") -def session(db_parameters, resources_path, sql_simplifier_enabled): - session = Session.builder.configs(db_parameters).create() +def session(db_parameters, resources_path, sql_simplifier_enabled, local_testing_mode): + session = ( + Session.builder.configs(db_parameters) + .config("local_testing", local_testing_mode) + .create() + ) session.sql_simplifier_enabled = sql_simplifier_enabled yield session session.close() @pytest.fixture(scope="module") -def temp_schema(connection, session) -> None: +def temp_schema(connection, session, local_testing_mode) -> None: """Set up and tear down a temp schema for cross-schema test. This is automatically called per test module.""" temp_schema_name = Utils.get_fully_qualified_temp_schema(session) - with connection.cursor() as cursor: - cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {temp_schema_name}") - # This is needed for test_get_schema_database_works_after_use_role in test_session_suite - cursor.execute( - f"GRANT ALL PRIVILEGES ON SCHEMA {temp_schema_name} TO ROLE PUBLIC" - ) + if local_testing_mode: yield temp_schema_name - cursor.execute(f"DROP SCHEMA IF EXISTS {temp_schema_name}") + else: + with connection.cursor() as cursor: + cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {temp_schema_name}") + # This is needed for test_get_schema_database_works_after_use_role in test_session_suite + cursor.execute( + f"GRANT ALL PRIVILEGES ON SCHEMA {temp_schema_name} TO ROLE PUBLIC" + ) + yield temp_schema_name + cursor.execute(f"DROP SCHEMA IF EXISTS {temp_schema_name}") diff --git a/tests/integ/scala/test_async_job_suite.py b/tests/integ/scala/test_async_job_suite.py index cb5549d390b..97f93f98935 100644 --- a/tests/integ/scala/test_async_job_suite.py +++ b/tests/integ/scala/test_async_job_suite.py @@ -37,6 +37,12 @@ test_file_csv = "testCSV.csv" tmp_stage_name1 = Utils.random_stage_name() +pytestmark = pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) + def test_async_collect_common(session): df = session.create_dataframe( @@ -366,6 +372,10 @@ def test_async_place_holder(session): @pytest.mark.skipif(not is_pandas_available, reason="Pandas is not available") +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Use of _execute_and_get_query_id", +) @pytest.mark.parametrize("create_async_job_from_query_id", [True, False]) def test_create_async_job(session, create_async_job_from_query_id): df = session.range(3) @@ -453,6 +463,10 @@ def test_get_query_from_async_job_negative(session, caplog): assert "result is empty" in caplog.text +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Use of _execute_and_get_query_id", +) 
@pytest.mark.parametrize("create_async_job_from_query_id", [True, False]) def test_async_job_to_df(session, create_async_job_from_query_id): df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) diff --git a/tests/integ/scala/test_column_suite.py b/tests/integ/scala/test_column_suite.py index 4c73827a064..0194f887e75 100644 --- a/tests/integ/scala/test_column_suite.py +++ b/tests/integ/scala/test_column_suite.py @@ -2,8 +2,9 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # - +import datetime import math +from decimal import Decimal import pytest @@ -16,10 +17,26 @@ SnowparkSQLUnexpectedAliasException, ) from snowflake.snowpark.functions import avg, col, in_, lit, parse_json, sql_expr, when -from snowflake.snowpark.types import StringType +from snowflake.snowpark.types import ( + BinaryType, + BooleanType, + DateType, + DecimalType, + DoubleType, + FloatType, + IntegerType, + LongType, + ShortType, + StringType, + StructField, + StructType, + TimestampType, + TimeType, +) from tests.utils import TestData, Utils +@pytest.mark.localtest def test_column_names_with_space(session): c1 = '"name with space"' c2 = '"name.with.dot"' @@ -33,6 +50,7 @@ def test_column_names_with_space(session): assert df.select(df[c2]).collect() == [Row("a")] +@pytest.mark.localtest def test_column_alias_and_case_insensitive_name(session): df = session.create_dataframe([1, 2]).to_df(["a"]) assert df.select(df["a"].as_("b")).schema.fields[0].name == "B" @@ -40,6 +58,7 @@ def test_column_alias_and_case_insensitive_name(session): assert df.select(df["a"].name("b")).schema.fields[0].name == "B" +@pytest.mark.localtest def test_column_alias_and_case_sensitive_name(session): df = session.create_dataframe([1, 2]).to_df(["a"]) assert df.select(df["a"].as_('"b"')).schema.fields[0].name == '"b"' @@ -47,6 +66,11 @@ def test_column_alias_and_case_sensitive_name(session): assert df.select(df["a"].name('"b"')).schema.fields[0].name == '"b"' +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_unary_operator(session): test_data1 = TestData.test_data1(session) # unary minus @@ -58,6 +82,7 @@ def test_unary_operator(session): ] +@pytest.mark.localtest def test_alias(session): test_data1 = TestData.test_data1(session) assert test_data1.select(test_data1["NUM"]).schema.fields[0].name == "NUM" @@ -75,6 +100,7 @@ def test_alias(session): ) +@pytest.mark.localtest def test_equal_and_not_equal(session): test_data1 = TestData.test_data1(session) assert test_data1.where(test_data1["BOOL"] == True).collect() == [ # noqa: E712 @@ -97,6 +123,7 @@ def test_equal_and_not_equal(session): ] +@pytest.mark.localtest def test_gt_and_lt(session): test_data1 = TestData.test_data1(session) assert test_data1.where(test_data1["NUM"] > 1).collect() == [Row(2, False, "b")] @@ -107,6 +134,7 @@ def test_gt_and_lt(session): assert test_data1.where(test_data1["NUM"] < lit(2)).collect() == [Row(1, True, "a")] +@pytest.mark.localtest def test_leq_and_geq(session): test_data1 = TestData.test_data1(session) assert test_data1.where(test_data1["NUM"] >= 2).collect() == [Row(2, False, "b")] @@ -125,8 +153,9 @@ def test_leq_and_geq(session): ] +@pytest.mark.localtest def test_null_safe_operators(session): - df = session.sql("select * from values(null, 1),(2, 2),(null, null) as T(a,b)") + df = session.create_dataframe([[None, 1], [2, 2], [None, None]], schema=["a", "b"]) assert df.select(df["A"].equal_null(df["B"])).collect() == [ Row(False), Row(True), 
@@ -134,9 +163,10 @@ def test_null_safe_operators(session): ] +@pytest.mark.localtest def test_nan_and_null(session): - df = session.sql( - "select * from values(1.1,1),(null,2),('NaN' :: Float,3) as T(a, b)" + df = session.create_dataframe( + [[1.1, 1], [None, 2], [math.nan, 3]], schema=["a", "b"] ) res = df.where(df["A"].equal_nan()).collect() assert len(res) == 1 @@ -150,9 +180,10 @@ def test_nan_and_null(session): assert res_row2[1] == 3 +@pytest.mark.localtest def test_and_or(session): - df = session.sql( - "select * from values(true,true),(true,false),(false,true),(false,false) as T(a, b)" + df = session.create_dataframe( + [[True, True], [True, False], [False, True], [False, False]], schema=["a", "b"] ) assert df.where(df["A"] & df["B"]).collect() == [Row(True, True)] assert df.where(df["A"] | df["B"]).collect() == [ @@ -162,8 +193,12 @@ def test_and_or(session): ] +@pytest.mark.xfail( + reason="Divide is expected to return decimal instead of float", + raises=AttributeError, +) def test_add_subtract_multiply_divide_mod_pow(session): - df = session.sql("select * from values(11, 13) as T(a, b)") + df = session.create_dataframe([[11, 13]], schema=["a", "b"]) assert df.select(df["A"] + df["B"]).collect() == [Row(24)] assert df.select(df["A"] - df["B"]).collect() == [Row(-2)] assert df.select(df["A"] * df["B"]).collect() == [Row(143)] @@ -186,15 +221,17 @@ def test_add_subtract_multiply_divide_mod_pow(session): assert res[0][0].to_eng_string() == "0.153846" +@pytest.mark.localtest def test_cast(session): test_data1 = TestData.test_data1(session) sc = test_data1.select(test_data1["NUM"].cast(StringType())).schema assert len(sc.fields) == 1 - assert sc.fields[0].column_identifier == '"CAST (""NUM"" AS STRING)"' + assert sc.fields[0].name == '"CAST (""NUM"" AS STRING)"' assert type(sc.fields[0].datatype) == StringType assert not sc.fields[0].nullable +@pytest.mark.localtest def test_order(session): null_data1 = TestData.null_data1(session) assert null_data1.sort(null_data1["A"].asc()).collect() == [ @@ -241,13 +278,19 @@ def test_order(session): ] +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_bitwise_operator(session): - df = session.sql("select * from values(1, 2) as T(a, b)") + df = session.create_dataframe([[1, 2]], schema=["a", "b"]) assert df.select(df["A"].bitand(df["B"])).collect() == [Row(0)] assert df.select(df["A"].bitor(df["B"])).collect() == [Row(3)] assert df.select(df["A"].bitxor(df["B"])).collect() == [Row(3)] +@pytest.mark.localtest def test_withcolumn_with_special_column_names(session): # Ensure that One and "One" are different column names Utils.check_answer( @@ -290,6 +333,7 @@ def test_withcolumn_with_special_column_names(session): ) +@pytest.mark.localtest def test_toDF_with_special_column_names(session): assert ( session.create_dataframe([[1]]).to_df(["ONE"]).schema @@ -317,6 +361,7 @@ def test_toDF_with_special_column_names(session): ) +@pytest.mark.localtest def test_column_resolution_with_different_kins_of_names(session): df = session.create_dataframe([[1]]).to_df(["One"]) assert df.select(df["one"]).collect() == [Row(1)] @@ -341,6 +386,7 @@ def test_column_resolution_with_different_kins_of_names(session): df.col('"ONE ONE"') +@pytest.mark.localtest def test_drop_columns_by_string(session): df = session.create_dataframe([[1, 2]]).to_df(["One", '"One"']) assert df.drop("one").schema.fields[0].name == '"One"' @@ -356,6 +402,7 @@ def test_drop_columns_by_string(session): assert "Cannot 
drop all columns" in str(ex_info) +@pytest.mark.localtest def test_drop_columns_by_column(session): df = session.create_dataframe([[1, 2]]).to_df(["One", '"One"']) assert df.drop(col("one")).schema.fields[0].name == '"One"' @@ -378,6 +425,11 @@ def test_drop_columns_by_column(session): assert df.drop(df2["one"]).schema.fields[0].name == '"One"' +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_fully_qualified_column_name(session): random_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) schema = "{}.{}".format( @@ -408,6 +460,7 @@ def test_fully_qualified_column_name(session): session._run_query(f"drop function if exists {schema}.{udf_name}(integer)") +@pytest.mark.localtest def test_column_names_with_quotes(session): df = session.create_dataframe([[1, 2, 3]]).to_df('col"', '"col"', '"""col"') assert df.select(col('col"')).collect() == [Row(1)] @@ -426,6 +479,7 @@ def test_column_names_with_quotes(session): assert "Invalid identifier" in str(ex_info) +@pytest.mark.localtest def test_column_constructors_col(session): df = session.create_dataframe([[1, 2, 3]]).to_df("col", '"col"', "col .") assert df.select(col("col")).collect() == [Row(1)] @@ -446,6 +500,7 @@ def test_column_constructors_col(session): assert "invalid identifier" in str(ex_info) +@pytest.mark.localtest def test_column_constructors_select(session): df = session.create_dataframe([[1, 2, 3]]).to_df("col", '"col"', "col .") assert df.select("col").collect() == [Row(1)] @@ -463,6 +518,11 @@ def test_column_constructors_select(session): assert "invalid identifier" in str(ex_info) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_sql_expr_column(session): df = session.create_dataframe([[1, 2, 3]]).to_df("col", '"col"', "col .") assert df.select(sql_expr("col")).collect() == [Row(1)] @@ -489,16 +549,30 @@ def test_sql_expr_column(session): assert "syntax error" in str(ex_info) -def test_errors_for_aliased_columns(session): +@pytest.mark.localtest +def test_errors_for_aliased_columns(session, local_testing_mode): df = session.create_dataframe([[1]]).to_df("c") - with pytest.raises(SnowparkSQLUnexpectedAliasException) as ex_info: + # TODO: align exc experience between local testing and snowflake + exc = ( + SnowparkSQLUnexpectedAliasException + if not local_testing_mode + else SnowparkSQLException + ) + with pytest.raises(exc) as ex_info: df.select(col("a").as_("b") + 10).collect() - assert "You can only define aliases for the root" in str(ex_info) - with pytest.raises(SnowparkSQLUnexpectedAliasException) as ex_info: + if not local_testing_mode: + assert "You can only define aliases for the root" in str(ex_info) + else: + assert "invalid identifier" in str(ex_info) + with pytest.raises(exc) as ex_info: df.group_by(col("a")).agg(avg(col("a").as_("b"))).collect() - assert "You can only define aliases for the root" in str(ex_info) + if not local_testing_mode: + assert "You can only define aliases for the root" in str(ex_info) + else: + assert "invalid identifier" in str(ex_info) +@pytest.mark.localtest def test_like(session): assert TestData.string4(session).where(col("A").like(lit("%p%"))).collect() == [ Row("apple"), @@ -514,21 +588,31 @@ def test_like(session): assert TestData.string4(session).where(col("A").like("")).collect() == [] -def test_subfield(session): +@pytest.mark.localtest +def test_subfield(session, local_testing_mode): assert 
TestData.null_json1(session).select(col("v")["a"]).collect() == [ Row("null"), Row('"foo"'), Row(None), ] - assert TestData.array2(session).select(col("arr1")[0]).collect() == [ - Row("1"), - Row("6"), - ] - assert TestData.array2(session).select(parse_json(col("f"))[0]["a"]).collect() == [ - Row("1"), - Row("1"), - ] + if not local_testing_mode: + assert TestData.array2(session).select(col("arr1")[0]).collect() == [ + Row("1"), + Row("6"), + ] + assert TestData.array2(session).select( + parse_json(col("f"))[0]["a"] + ).collect() == [ + Row("1"), + Row("1"), + ] + else: + # TODO: function array_construct is not supported in local testing + # we use the array in variant2 for testing purpose + assert TestData.variant2(session).select( + col("src")["vehicle"][0]["extras"][1] + ).collect() == [Row('"paint protection"')] # Row name is not case-sensitive. field name is case-sensitive assert TestData.variant2(session).select( @@ -555,6 +639,7 @@ def test_subfield(session): ).collect() == [Row(None)] +@pytest.mark.localtest def test_regexp(session): assert TestData.string4(session).where(col("a").regexp(lit("ap.le"))).collect() == [ Row("apple") @@ -569,6 +654,11 @@ def test_regexp(session): assert "Invalid regular expression" in str(ex_info) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) @pytest.mark.parametrize("spec", ["en_US-trim", "'en_US-trim'"]) def test_collate(session, spec): Utils.check_answer( @@ -577,12 +667,14 @@ def test_collate(session, spec): ) +@pytest.mark.localtest def test_get_column_name(session): assert TestData.integer1(session).col("a").getName() == '"A"' assert not (col("col") > 100).getName() -def test_when_case(session): +@pytest.mark.localtest +def test_when_case(session, local_testing_mode): assert TestData.null_data1(session).select( when(col("a").is_null(), lit(5)) .when(col("a") == 1, lit(6)) @@ -606,14 +698,17 @@ def test_when_case(session): TestData.null_data1(session).select( when(col("a").is_null(), lit("a")).when(col("a") == 1, lit(6)).as_("a") ).collect() - assert "Numeric value 'a' is not recognized" in str(ex_info) + if not local_testing_mode: + assert "Numeric value 'a' is not recognized" in str(ex_info) +@pytest.mark.localtest def test_lit_contains_single_quote(session): df = session.create_dataframe([[1, "'"], [2, "''"]]).to_df(["a", "b"]) assert df.where(col("b") == "'").collect() == [Row(1, "'")] +@pytest.mark.localtest def test_in_expression_1_in_with_constant_value_list(session): df = session.create_dataframe( [[1, "a", 1, 1], [2, "b", 2, 2], [3, "b", 33, 33]] @@ -645,6 +740,7 @@ def test_in_expression_1_in_with_constant_value_list(session): Utils.check_answer([Row(False), Row(False), Row(True)], df4, sort=False) +@pytest.mark.localtest def test_in_expression_2_in_with_subquery(session): df0 = session.create_dataframe([[1], [2], [5]]).to_df(["a"]) df = session.create_dataframe( @@ -668,6 +764,83 @@ def test_in_expression_2_in_with_subquery(session): Utils.check_answer(df4, [Row(False), Row(True), Row(True)]) +@pytest.mark.localtest +def test_in_expression_3_with_all_types(session, local_testing_mode): + schema = StructType( + [ + StructField("id", LongType()), + StructField("string", StringType()), + StructField("byte", BinaryType()), + StructField("short", ShortType()), + StructField("int", IntegerType()), + StructField("float", FloatType()), + StructField("double", DoubleType()), + StructField("decimal", DecimalType(10, 3)), + StructField("boolean", BooleanType()), + 
StructField("timestamp", TimestampType()), + StructField("date", DateType()), + StructField("time", TimeType()), + ] + ) + now = datetime.datetime.now() + utcnow = datetime.datetime.utcnow() + + first_row = [ + 1, + "one", + b"123", + 123, + 123, + 12.34, + 12.34, + Decimal("1.234"), + True, + now, + datetime.date(1989, 12, 7), + datetime.time(11, 11, 11), + ] + second_row = [ + 2, + "two", + b"456", + 456, + 456, + 45.67, + 45.67, + Decimal("4.567"), + False, + utcnow, + datetime.date(2018, 10, 31), + datetime.time(23, 23, 23), + ] + + df = session.create_dataframe([first_row, second_row], schema=schema) + if local_testing_mode: + # There seems to be a bug in live connection with timestamp precision + Utils.check_answer( + df.filter( + col("id").isin([1]) + & col("string").isin(["one"]) + & col("byte").isin([b"123"]) + & col("short").isin([123]) + & col("int").isin([123]) + & col("float").isin([12.34]) + & col("double").isin([12.34]) + & col("decimal").isin([Decimal("1.234")]) + & col("boolean").isin([True]) + & col("timestamp").isin([now]) + & col("date").isin([datetime.date(1989, 12, 7)]) + & col("time").isin([datetime.time(11, 11, 11)]) + ), + [first_row], + ) + Utils.check_answer(df.filter(col("timestamp").isin([utcnow])), [second_row]) + Utils.check_answer(df.filter(col("decimal").isin([Decimal("1.234")])), [first_row]) + Utils.check_answer(df.filter(col("id").isin([2])), [second_row]) + Utils.check_answer(df.filter(col("string").isin(["three"])), []) + + +@pytest.mark.localtest def test_in_expression_4_negative_test_to_input_column_in_value_list(session): df = session.create_dataframe( [[1, "a", 1, 1], [2, "b", 2, 2], [3, "b", 33, 33]] @@ -701,6 +874,7 @@ def test_in_expression_4_negative_test_to_input_column_in_value_list(session): ) +@pytest.mark.localtest def test_in_expression_5_negative_test_that_sub_query_has_multiple_columns(session): df = session.create_dataframe( [[1, "a", 1, 1], [2, "b", 2, 2], [3, "b", 33, 33]] @@ -712,14 +886,15 @@ def test_in_expression_5_negative_test_that_sub_query_has_multiple_columns(sessi assert "does not match the number of columns" in str(ex_info) +@pytest.mark.localtest def test_in_expression_6_multiple_columns_with_const_values(session): df = session.create_dataframe( - [[1, "a", 1, 1], [2, "b", 2, 2], [3, "b", 33, 33]] + [[1, "a", -1, 1], [2, "b", -2, 2], [3, "b", 33, 33]] ).to_df("a", "b", "c", "d") # filter without NOT df1 = df.filter(in_([col("a"), col("b")], [[1, "a"], [2, "b"], [3, "c"]])) - Utils.check_answer(df1, [Row(1, "a", 1, 1), Row(2, "b", 2, 2)]) + Utils.check_answer(df1, [Row(1, "a", -1, 1), Row(2, "b", -2, 2)]) # filter with NOT df2 = df.filter(~in_([col("a"), col("b")], [[1, "a"], [2, "b"], [3, "c"]])) @@ -727,17 +902,18 @@ def test_in_expression_6_multiple_columns_with_const_values(session): # select without NOT df3 = df.select( - in_([col("a"), col("c")], [[1, 1], [2, 2], [3, 3]]).as_("in_result") + in_([col("a"), col("c")], [[1, -1], [2, -2], [3, 3]]).as_("in_result") ) Utils.check_answer(df3, [Row(True), Row(True), Row(False)]) # select with NOT df4 = df.select( - ~in_([col("a"), col("c")], [[1, 1], [2, 2], [3, 3]]).as_("in_result") + ~in_([col("a"), col("c")], [[1, -1], [2, -2], [3, 3]]).as_("in_result") ) Utils.check_answer(df4, [Row(False), Row(False), Row(True)]) +@pytest.mark.localtest def test_in_expression_7_multiple_columns_with_sub_query(session): df0 = session.create_dataframe([[1, "a"], [2, "b"], [3, "c"]]).to_df("a", "b") df = session.create_dataframe( @@ -761,6 +937,7 @@ def 
test_in_expression_7_multiple_columns_with_sub_query(session): Utils.check_answer(df4, [Row(False), Row(False), Row(True)]) +@pytest.mark.localtest def test_in_expression_8_negative_test_to_input_column_in_value_list(session): df = session.create_dataframe( [[1, "a", 1, 1], [2, "b", 2, 2], [3, "b", 33, 33]] @@ -776,6 +953,7 @@ def test_in_expression_8_negative_test_to_input_column_in_value_list(session): ) +@pytest.mark.localtest def test_in_expression_9_negative_test_for_the_column_count_doesnt_match_the_value_list( session, ): @@ -794,6 +972,7 @@ def test_in_expression_9_negative_test_for_the_column_count_doesnt_match_the_val assert "does not match the number of columns" in str(ex_info) +@pytest.mark.localtest def test_in_expression_with_multiple_queries(session): from snowflake.snowpark._internal.analyzer import analyzer diff --git a/tests/integ/scala/test_complex_dataframe_suite.py b/tests/integ/scala/test_complex_dataframe_suite.py index 54f20914c85..9ef7cf0872f 100644 --- a/tests/integ/scala/test_complex_dataframe_suite.py +++ b/tests/integ/scala/test_complex_dataframe_suite.py @@ -17,6 +17,7 @@ from tests.utils import IS_IN_STORED_PROC_LOCALFS, TestFiles, Utils +@pytest.mark.localtest def test_combination_of_multiple_operators(session): df1 = session.create_dataframe([1, 2]).to_df("a") df2 = session.create_dataframe([[i, f"test{i}"] for i in [1, 2]]).to_df("a", "b") @@ -47,6 +48,7 @@ def test_combination_of_multiple_operators(session): ] +@pytest.mark.localtest def test_combination_of_multiple_operators_with_filters(session): df1 = session.create_dataframe([i for i in range(1, 11)]).to_df("a") df2 = session.create_dataframe([[i, f"test{i}"] for i in range(1, 11)]).to_df( @@ -78,6 +80,7 @@ def test_combination_of_multiple_operators_with_filters(session): assert df.collect() == [Row(i, f"test{i}") for i in range(1, 11)] +@pytest.mark.localtest def test_join_on_top_of_unions(session): df1 = session.create_dataframe([i for i in range(1, 6)]).to_df("a") df2 = session.create_dataframe([i for i in range(6, 11)]).to_df("a") @@ -92,6 +95,11 @@ def test_join_on_top_of_unions(session): assert res == [Row(i, f"test{i}") for i in range(1, 11)] +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) @pytest.mark.skipif(IS_IN_STORED_PROC_LOCALFS, reason="need resources") def test_combination_of_multiple_data_sources(session, resources_path): test_files = TestFiles(resources_path) diff --git a/tests/integ/scala/test_dataframe_aggregate_suite.py b/tests/integ/scala/test_dataframe_aggregate_suite.py index 4711344e505..f6e88918e7f 100644 --- a/tests/integ/scala/test_dataframe_aggregate_suite.py +++ b/tests/integ/scala/test_dataframe_aggregate_suite.py @@ -132,6 +132,7 @@ def test_pivot_on_join(session): ) +@pytest.mark.localtest def test_rel_grouped_dataframe_agg(session): df = ( session.create_dataframe([[1, "One"], [2, "Two"], [3, "Three"]]) @@ -156,6 +157,38 @@ def test_rel_grouped_dataframe_agg(session): ] +@pytest.mark.localtest +def test_group_by(session): + result = ( + TestData.nurse(session) + .group_by("medical_license") + .agg(count(col("*")).as_("count")) + .with_column("radio_license", lit(None)) + .select("medical_license", "radio_license", "count") + .union_all( + TestData.nurse(session) + .group_by("radio_license") + .agg(count(col("*")).as_("count")) + .with_column("medical_license", lit(None)) + .select("medical_license", "radio_license", "count") + ) + .sort(col("count")) + .collect() + ) + Utils.check_answer( + 
result, + [ + Row(None, "General", 1), + Row(None, "Amateur Extra", 1), + Row("RN", None, 2), + Row(None, "Technician", 2), + Row(None, None, 3), + Row("LVN", None, 5), + ], + sort=False, + ) + + def test_group_by_grouping_sets(session): result = ( TestData.nurse(session) @@ -257,6 +290,7 @@ def test_group_by_grouping_sets(session): ) +@pytest.mark.localtest def test_rel_grouped_dataframe_max(session): df1 = session.create_dataframe( [("a", 1, 11, "b"), ("b", 2, 22, "c"), ("a", 3, 33, "d"), ("b", 4, 44, "e")] @@ -275,6 +309,7 @@ def test_rel_grouped_dataframe_max(session): assert df1.group_by("key").agg([max("value1"), max("value2")]).collect() == expected +@pytest.mark.localtest def test_rel_grouped_dataframe_avg_mean(session): df1 = session.create_dataframe( [("a", 1, 11, "b"), ("b", 2, 22, "c"), ("a", 3, 33, "d"), ("b", 4, 44, "e")] @@ -303,6 +338,7 @@ def test_rel_grouped_dataframe_avg_mean(session): ) +@pytest.mark.localtest def test_rel_grouped_dataframe_median(session): df1 = session.create_dataframe( [ @@ -337,6 +373,7 @@ def test_rel_grouped_dataframe_median(session): ) +@pytest.mark.localtest def test_builtin_functions(session): df = session.create_dataframe([(1, 11), (2, 12), (1, 13)]).to_df(["a", "b"]) @@ -350,6 +387,7 @@ def test_builtin_functions(session): ] +@pytest.mark.localtest def test_non_empty_arg_functions(session): func_name = "avg" with pytest.raises(ValueError) as ex_info: @@ -392,6 +430,7 @@ def test_non_empty_arg_functions(session): ) +@pytest.mark.localtest def test_null_count(session): assert TestData.test_data3(session).group_by("a").agg( count(col("b")) @@ -419,6 +458,7 @@ def test_null_count(session): ).collect() == [Row(1, 1, 2)] +@pytest.mark.localtest def test_distinct(session): df = session.create_dataframe( [(1, "one", 1.0), (2, "one", 2.0), (2, "two", 1.0)] @@ -443,6 +483,7 @@ def test_distinct(session): assert df.filter(col("i") < 0).distinct().collect() == [] +@pytest.mark.localtest def test_distinct_and_joins(session): lhs = session.create_dataframe([(1, "one", 1.0), (2, "one", 2.0)]).to_df( "i", "s", '"i"' @@ -472,6 +513,7 @@ def test_distinct_and_joins(session): assert res == [Row("one", "one")] +@pytest.mark.localtest def test_groupBy(session): assert TestData.test_data2(session).group_by("a").agg(sum(col("b"))).collect() == [ Row(1, 3), @@ -523,6 +565,7 @@ def test_groupBy(session): ] +@pytest.mark.localtest def test_agg_should_be_order_preserving(session): df = ( session.range(2) @@ -539,6 +582,7 @@ def test_agg_should_be_order_preserving(session): assert df.collect() == [Row(0, 0, 1, 0), Row(1, 1, 1, 1)] +@pytest.mark.localtest def test_count(session): assert TestData.test_data2(session).agg( [count(col("a")), sum_distinct(col("a"))] @@ -661,6 +705,9 @@ def test_sn_null_moments(session): ) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="sql is not supported" +) def test_decimal_sum_over_window_should_work(session): assert session.sql( "select sum(a) over () from values (1.0), (2.0), (3.0) T(a)" @@ -670,12 +717,14 @@ def test_decimal_sum_over_window_should_work(session): ).collect() == [Row(2.0), Row(2.0), Row(2.0)] +@pytest.mark.localtest def test_aggregate_function_in_groupby(session): with pytest.raises(SnowparkSQLException) as ex_info: TestData.test_data4(session).group_by(sum(col('"KEY"'))).count().collect() assert "is not a valid group by expression" in str(ex_info) +@pytest.mark.localtest def test_ints_in_agg_exprs_are_taken_as_groupby_ordinal(session): assert 
TestData.test_data2(session).group_by(lit(3), lit(4)).agg( [lit(6), lit(7), sum(col("b"))] @@ -685,6 +734,12 @@ def test_ints_in_agg_exprs_are_taken_as_groupby_ordinal(session): [lit(6), col("b"), sum(col("b"))] ).collect() == [Row(3, 4, 6, 1, 3), Row(3, 4, 6, 2, 6)] + +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="sql is not supported" +) +def test_ints_in_agg_exprs_are_taken_as_groupby_ordinal_sql(session): + testdata2str = "(SELECT * FROM VALUES (1,1),(1,2),(2,1),(2,2),(3,1),(3,2) T(a, b) )" assert session.sql( f"SELECT 3, 4, SUM(b) FROM {testdata2str} GROUP BY 1, 2" @@ -695,7 +750,8 @@ def test_ints_in_agg_exprs_are_taken_as_groupby_ordinal(session): ).collect() == [Row(3, 4, 9)] -def test_distinct_and_unions(session): +@pytest.mark.localtest +def test_distinct_and_unions(session: object) -> object: lhs = session.create_dataframe([(1, "one", 1.0), (2, "one", 2.0)]).to_df( "i", "s", '"i"' ) @@ -722,6 +778,7 @@ def test_distinct_and_unions(session): assert res == [Row("one")] +@pytest.mark.localtest def test_distinct_and_unionall(session): lhs = session.create_dataframe([(1, "one", 1.0), (2, "one", 2.0)]).to_df( "i", "s", '"i"' @@ -800,14 +857,17 @@ def test_count_if(session): session.sql(f"SELECT COUNT_IF(x) FROM {temp_view_name}").collect() +@pytest.mark.localtest def test_agg_without_groups(session): assert TestData.test_data2(session).agg(sum(col("b"))).collect() == [Row(9)] +@pytest.mark.localtest def test_agg_without_groups_and_functions(session): assert TestData.test_data2(session).agg(lit(1)).collect() == [Row(1)] +@pytest.mark.localtest def test_null_average(session): assert TestData.test_data3(session).agg(avg(col("b"))).collect() == [Row(2.0)] @@ -820,6 +880,7 @@ def test_null_average(session): ).collect() == [Row(2.0, 2.0)] +@pytest.mark.localtest def test_zero_average(session): df = session.create_dataframe([[]]).to_df(["a"]) assert df.agg(avg(col("a"))).collect() == [Row(None)] @@ -829,6 +890,7 @@ def test_zero_average(session): ] +@pytest.mark.localtest def test_multiple_column_distinct_count(session): df1 = session.create_dataframe( [ @@ -855,6 +917,7 @@ def test_multiple_column_distinct_count(session): assert res == [Row("a", 2), Row("x", 1)] +@pytest.mark.localtest def test_zero_count(session): empty_table = session.create_dataframe([[]]).to_df(["a"]) assert empty_table.agg([count(col("a")), sum_distinct(col("a"))]).collect() == [ @@ -869,16 +932,19 @@ def test_zero_stddev(session): ).collect() == [Row(None, None, None)] +@pytest.mark.localtest def test_zero_sum(session): df = session.create_dataframe([[]]).to_df(["a"]) assert df.agg([sum(col("a"))]).collect() == [Row(None)] +@pytest.mark.localtest def test_zero_sum_distinct(session): df = session.create_dataframe([[]]).to_df(["a"]) assert df.agg([sum_distinct(col("a"))]).collect() == [Row(None)] +@pytest.mark.localtest def test_limit_and_aggregates(session): df = session.create_dataframe([("a", 1), ("b", 2), ("c", 1), ("d", 5)]).to_df( "id", "value" @@ -889,6 +955,7 @@ def test_limit_and_aggregates(session): ) +@pytest.mark.localtest def test_listagg(session): df = session.create_dataframe( [ @@ -907,6 +974,20 @@ def test_listagg(session): # result is unpredictable without within group assert len(result) == 4 + +def test_listagg_within_group(session): + df = session.create_dataframe( + [ + (2, 1, 35, "red", 99), + (7, 2, 24, "red", 99), + (7, 9, 77, "green", 99), + (8, 5, 11, "green", 99), + (8, 4, 14, "blue", 99), + (8, 3, 21, "red", 99), + (9, 9, 12, "orange", 99), + ], + 
schema=["v1", "v2", "length", "color", "unused"], + ) Utils.check_answer( df.group_by("color").agg(listagg("length", ",").within_group(df.length.asc())), [ diff --git a/tests/integ/scala/test_dataframe_copy_into.py b/tests/integ/scala/test_dataframe_copy_into.py index 329a539a705..b0c03acaaa8 100644 --- a/tests/integ/scala/test_dataframe_copy_into.py +++ b/tests/integ/scala/test_dataframe_copy_into.py @@ -69,34 +69,52 @@ def create_df_for_file_format( return df +pytestmark = pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", raises=NotImplementedError +) + + @pytest.fixture(scope="module") -def tmp_stage_name1(session): +def tmp_stage_name1(session, local_testing_mode): stage_name = Utils.random_stage_name() - Utils.create_stage(session, stage_name) + if not local_testing_mode: + Utils.create_stage(session, stage_name) try: yield stage_name finally: - Utils.drop_stage(session, stage_name) + if not local_testing_mode: + Utils.drop_stage(session, stage_name) @pytest.fixture(scope="module") -def tmp_stage_name2(session): +def tmp_stage_name2(session, local_testing_mode): stage_name = Utils.random_stage_name() - Utils.create_stage(session, stage_name) + if not local_testing_mode: + Utils.create_stage(session, stage_name) try: yield stage_name finally: - Utils.drop_stage(session, stage_name) + if not local_testing_mode: + Utils.drop_stage(session, stage_name) @pytest.fixture(scope="module") def tmp_table_name(session): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) - Utils.create_table(session, table_name, "a Int, b String, c Double") + session.create_dataframe( + [], + StructType( + [ + StructField("a", IntegerType()), + StructField("b", StringType()), + StructField("c", DoubleType()), + ] + ), + ).write.save_as_table(table_name) try: yield table_name finally: - Utils.drop_table(session, table_name) + session.table(table_name).drop_table() @pytest.fixture(scope="module", autouse=True) @@ -641,6 +659,7 @@ def test_transormation_as_clause_no_effect(session, tmp_stage_name1): Utils.drop_table(session, table_name) +@pytest.mark.localtest def test_copy_with_wrong_dataframe(session): with pytest.raises(SnowparkDataframeException) as exec_info: session.table("a_table_name").copy_into_table("a_table_name") diff --git a/tests/integ/scala/test_dataframe_join_suite.py b/tests/integ/scala/test_dataframe_join_suite.py index fd47079d4a3..c15c31463ef 100644 --- a/tests/integ/scala/test_dataframe_join_suite.py +++ b/tests/integ/scala/test_dataframe_join_suite.py @@ -13,12 +13,21 @@ from snowflake.snowpark.exceptions import ( SnowparkJoinException, SnowparkSQLAmbiguousJoinException, + SnowparkSQLException, SnowparkSQLInvalidIdException, ) from snowflake.snowpark.functions import coalesce, col, count, is_null, lit +from snowflake.snowpark.types import ( + IntegerType, + StringType, + StructField, + StructType, + TimestampType, +) from tests.utils import Utils +@pytest.mark.localtest def test_join_using(session): df = session.create_dataframe([[i, str(i)] for i in range(1, 4)]).to_df( ["int", "str"] @@ -33,6 +42,7 @@ def test_join_using(session): ] +@pytest.mark.localtest def test_join_using_multiple_columns(session): df = session.create_dataframe([[i, i + 1, str(i)] for i in range(1, 4)]).to_df( ["int", "int2", "str"] @@ -49,6 +59,7 @@ def test_join_using_multiple_columns(session): ] +@pytest.mark.localtest def test_full_outer_join_followed_by_inner_join(session): a = session.create_dataframe([[1, 2], [2, 3]]).to_df(["a", "b"]) b = session.create_dataframe([[2, 5], [3, 
4]]).to_df(["a", "c"]) @@ -59,6 +70,7 @@ def test_full_outer_join_followed_by_inner_join(session): assert abc.collect() == [Row(3, None, 4, 1)] +@pytest.mark.localtest def test_limit_with_join(session): df = session.create_dataframe([[1, 1, "1"], [2, 2, "3"]]).to_df( ["int", "int2", "str"] @@ -66,6 +78,7 @@ def test_limit_with_join(session): df2 = session.create_dataframe([[1, 1, "1"], [2, 3, "5"]]).to_df( ["int", "int2", "str"] ) + limit = 1310721 inner = ( df.limit(limit) @@ -75,6 +88,7 @@ def test_limit_with_join(session): assert inner.collect() == [Row(1)] +@pytest.mark.localtest def test_default_inner_join(session): df = session.create_dataframe([1, 2]).to_df(["a"]) df2 = session.create_dataframe([[i, f"test{i}"] for i in range(1, 3)]).to_df( @@ -91,6 +105,7 @@ def test_default_inner_join(session): ] +@pytest.mark.localtest def test_default_inner_join_using_column(session): df = session.create_dataframe([1, 2]).to_df(["a"]) df2 = session.create_dataframe([[i, f"test{i}"] for i in range(1, 3)]).to_df( @@ -101,6 +116,7 @@ def test_default_inner_join_using_column(session): assert df.join(df2, "a").filter(col("a") > 1).collect() == [Row(2, "test2")] +@pytest.mark.localtest def test_3_way_joins(session): df1 = session.create_dataframe([1, 2]).to_df(["a"]) df2 = session.create_dataframe([[i, f"test{i}"] for i in range(1, 3)]).to_df( @@ -115,6 +131,7 @@ def test_3_way_joins(session): assert res == [Row("test1", 1, "hello1"), Row("test2", 2, "hello2")] +@pytest.mark.localtest def test_default_inner_join_with_join_conditions(session): df1 = session.create_dataframe([[i, f"test{i}"] for i in range(1, 3)]).to_df( ["a", "b"] @@ -130,6 +147,7 @@ def test_default_inner_join_with_join_conditions(session): ] +@pytest.mark.localtest def test_join_with_multiple_conditions(session): df1 = session.create_dataframe([[i, f"test{i}"] for i in range(1, 3)]).to_df( ["a", "b"] @@ -153,82 +171,75 @@ def test_join_with_ambiguous_column_in_condidtion(session): assert "The reference to the column 'A' is ambiguous." 
in ex_info.value.message +@pytest.mark.localtest def test_join_using_multiple_columns_and_specifying_join_type(session): table_name1 = Utils.random_name_for_temp_object(TempObjectType.TABLE) table_name2 = Utils.random_name_for_temp_object(TempObjectType.TABLE) - try: - Utils.create_table(session, table_name1, "int int, int2 int, str string") - session.sql( - f"insert into {table_name1} values(1, 2, '1'),(3, 4, '3')" - ).collect() - Utils.create_table(session, table_name2, "int int, int2 int, str string") - session.sql( - f"insert into {table_name2} values(1, 3, '1'),(5, 6, '5')" - ).collect() + schema = StructType( + [ + StructField("int", IntegerType()), + StructField("int2", IntegerType()), + StructField("str", StringType()), + ] + ) + session.create_dataframe( + [[1, 2, "1"], [3, 4, "3"]], schema=schema + ).write.save_as_table(table_name1, table_type="temporary") + session.create_dataframe( + [[1, 3, "1"], [5, 6, "5"]], schema=schema + ).write.save_as_table(table_name2, table_type="temporary") - df = session.table(table_name1) - df2 = session.table(table_name2) + df = session.table(table_name1) + df2 = session.table(table_name2) - assert df.join(df2, ["int", "str"], "inner").collect() == [Row(1, "1", 2, 3)] + assert df.join(df2, ["int", "str"], "inner").collect() == [Row(1, "1", 2, 3)] - res = df.join(df2, ["int", "str"], "left").collect() - assert sorted(res, key=lambda x: x[0]) == [ - Row(1, "1", 2, 3), - Row(3, "3", 4, None), - ] + res = df.join(df2, ["int", "str"], "left").collect() + assert sorted(res, key=lambda x: x[0]) == [ + Row(1, "1", 2, 3), + Row(3, "3", 4, None), + ] - res = df.join(df2, ["int", "str"], "right").collect() - assert sorted(res, key=lambda x: x[0]) == [ - Row(1, "1", 2, 3), - Row(5, "5", None, 6), - ] + res = df.join(df2, ["int", "str"], "right").collect() + assert sorted(res, key=lambda x: x[0]) == [ + Row(1, "1", 2, 3), + Row(5, "5", None, 6), + ] - res = df.join(df2, ["int", "str"], "outer").collect() - res.sort(key=lambda x: x[0]) - assert res == [ - Row(1, "1", 2, 3), - Row(3, "3", 4, None), - Row(5, "5", None, 6), - ] + res = df.join(df2, ["int", "str"], "outer").collect() + res.sort(key=lambda x: x[0]) + assert res == [ + Row(1, "1", 2, 3), + Row(3, "3", 4, None), + Row(5, "5", None, 6), + ] - assert df.join(df2, ["int", "str"], "left_semi").collect() == [Row(1, 2, "1")] - assert df.join(df2, ["int", "str"], "semi").collect() == [Row(1, 2, "1")] + assert df.join(df2, ["int", "str"], "left_semi").collect() == [Row(1, 2, "1")] + assert df.join(df2, ["int", "str"], "semi").collect() == [Row(1, 2, "1")] - assert df.join(df2, ["int", "str"], "left_anti").collect() == [Row(3, 4, "3")] - assert df.join(df2, ["int", "str"], "anti").collect() == [Row(3, 4, "3")] - finally: - Utils.drop_table(session, table_name1) - Utils.drop_table(session, table_name2) + assert df.join(df2, ["int", "str"], "left_anti").collect() == [Row(3, 4, "3")] + assert df.join(df2, ["int", "str"], "anti").collect() == [Row(3, 4, "3")] +@pytest.mark.localtest def test_join_using_conditions_and_specifying_join_type(session): - table_name1 = Utils.random_name_for_temp_object(TempObjectType.TABLE) - table_name2 = Utils.random_name_for_temp_object(TempObjectType.TABLE) - - try: - Utils.create_table(session, table_name1, "a1 int, b1 int, str1 string") - session.sql( - f"insert into {table_name1} values(1, 2, '1'),(3, 4, '3')" - ).collect() - Utils.create_table(session, table_name2, "a2 int, b2 int, str2 string") - session.sql( - f"insert into {table_name2} values(1, 3, '1'),(5, 6, '5')" - 
).collect() - df = session.table(table_name1) - df2 = session.table(table_name2) + df1 = session.create_dataframe( + [[1, 2, "1"], [3, 4, "3"]], schema=["a1", "b1", "str1"] + ) + df2 = session.create_dataframe( + [[1, 3, "1"], [5, 6, "5"]], schema=["a2", "b2", "str2"] + ) - join_cond = (df["a1"] == df2["a2"]) & (df["str1"] == df2["str2"]) + join_cond = (df1["a1"] == df2["a2"]) & (df1["str1"] == df2["str2"]) - Utils.check_answer(df.join(df2, join_cond, "left_semi"), [Row(1, 2, "1")]) - Utils.check_answer(df.join(df2, join_cond, "semi"), [Row(1, 2, "1")]) - Utils.check_answer(df.join(df2, join_cond, "left_anti"), [Row(3, 4, "3")]) - Utils.check_answer(df.join(df2, join_cond, "anti"), [Row(3, 4, "3")]) - finally: - Utils.drop_table(session, table_name1) - Utils.drop_table(session, table_name2) + Utils.check_answer(df1.join(df2, join_cond, "left_semi"), [Row(1, 2, "1")]) + Utils.check_answer(df1.join(df2, join_cond, "semi"), [Row(1, 2, "1")]) + Utils.check_answer(df1.join(df2, join_cond, "left_anti"), [Row(3, 4, "3")]) + Utils.check_answer(df1.join(df2, join_cond, "anti"), [Row(3, 4, "3")]) +@pytest.mark.localtest def test_natural_join(session): df = session.create_dataframe([1, 2]).to_df("a") df2 = session.create_dataframe([[i, f"test{i}"] for i in range(1, 3)]).to_df( @@ -237,6 +248,7 @@ def test_natural_join(session): Utils.check_answer(df.natural_join(df2), [Row(1, "test1"), Row(2, "test2")]) +@pytest.mark.localtest def test_natural_outer_join(session): df1 = session.create_dataframe([[1, "1"], [3, "3"]]).to_df("a", "b") df2 = session.create_dataframe([[1, "1"], [4, "4"]]).to_df("a", "c") @@ -252,6 +264,7 @@ def test_natural_outer_join(session): ) +@pytest.mark.localtest def test_cross_join(session): df1 = session.create_dataframe([[1, "1"], [3, "3"]]).to_df(["int", "str"]) df2 = session.create_dataframe([[2, "2"], [4, "4"]]).to_df(["int", "str"]) @@ -275,7 +288,10 @@ def test_cross_join(session): ] -def test_join_ambiguous_columns_with_specified_sources(session): +@pytest.mark.localtest +def test_join_ambiguous_columns_with_specified_sources( + session, +): df = session.create_dataframe([1, 2]).to_df(["a"]) df2 = session.create_dataframe([[i, f"test{i}"] for i in range(1, 3)]).to_df( ["a", "b"] @@ -317,7 +333,10 @@ def test_join_ambiguous_columns_without_specified_sources(session): ) -def test_join_expression_ambiguous_columns(session): +@pytest.mark.localtest +def test_join_expression_ambiguous_columns( + session, +): lhs = session.create_dataframe([[1, -1, "one"], [2, -2, "two"]]).to_df( ["intcol", "negcol", "lhscol"] ) @@ -361,7 +380,10 @@ def test_semi_join_expression_ambiguous_columns(session): assert "not present" in str(ex_info) -def test_semi_join_with_columns_from_LHS(session): +@pytest.mark.localtest +def test_semi_join_with_columns_from_LHS( + session, +): lhs = session.create_dataframe([[1, -1, "one"], [2, -2, "two"]]).to_df( ["intcol", "negcol", "lhscol"] ) @@ -420,7 +442,9 @@ def test_semi_join_with_columns_from_LHS(session): assert sorted(res, key=lambda x: x[0]) == [Row(1), Row(2)] -def test_using_joins(session): +@pytest.mark.localtest +@pytest.mark.parametrize("join_type", ["inner", "leftouter", "rightouter", "fullouter"]) +def test_using_joins(session, join_type, local_testing_mode): lhs = session.create_dataframe([[1, -1, "one"], [2, -2, "two"]]).to_df( ["intcol", "negcol", "lhscol"] ) @@ -428,31 +452,36 @@ def test_using_joins(session): ["intcol", "negcol", "rhscol"] ) - for join_type in ["inner", "leftouter", "rightouter", "full_outer"]: - res = lhs.join(rhs, 
["intcol"], join_type).select("*").collect() - assert res == [ - Row(1, -1, "one", -10, "one"), - Row(2, -2, "two", -20, "two"), - ] + res = lhs.join(rhs, ["intcol"], join_type).select("*").collect() + assert res == [ + Row(1, -1, "one", -10, "one"), + Row(2, -2, "two", -20, "two"), + ] - res = lhs.join(rhs, ["intcol"], join_type).collect() - assert res == [ - Row(1, -1, "one", -10, "one"), - Row(2, -2, "two", -20, "two"), - ] + res = lhs.join(rhs, ["intcol"], join_type).collect() + assert res == [ + Row(1, -1, "one", -10, "one"), + Row(2, -2, "two", -20, "two"), + ] + if local_testing_mode: + # TODO: [local testing] align error experience + with pytest.raises(SnowparkSQLException) as ex_info: + lhs.join(rhs, ["intcol"], join_type).select("negcol").collect() + assert 'invalid identifier "NEGCOL"' in ex_info.value.message + else: with pytest.raises(SnowparkSQLAmbiguousJoinException) as ex_info: lhs.join(rhs, ["intcol"], join_type).select("negcol").collect() assert "reference to the column 'NEGCOL' is ambiguous" in ex_info.value.message - res = lhs.join(rhs, ["intcol"], join_type).select("intcol").collect() - assert res == [Row(1), Row(2)] - res = ( - lhs.join(rhs, ["intcol"], join_type) - .select(lhs["negcol"], rhs["negcol"]) - .collect() - ) - assert sorted(res, key=lambda x: -x[0]) == [Row(-1, -10), Row(-2, -20)] + res = lhs.join(rhs, ["intcol"], join_type).select("intcol").collect() + assert res == [Row(1), Row(2)] + res = ( + lhs.join(rhs, ["intcol"], join_type) + .select(lhs["negcol"], rhs["negcol"]) + .collect() + ) + assert sorted(res, key=lambda x: -x[0]) == [Row(-1, -10), Row(-2, -20)] def test_columns_with_and_without_quotes(session): @@ -486,7 +515,10 @@ def test_columns_with_and_without_quotes(session): assert "reference to the column 'INTCOL' is ambiguous." 
in ex_info.value.message -def test_aliases_multiple_levels_deep(session): +@pytest.mark.localtest +def test_aliases_multiple_levels_deep( + session, +): lhs = session.create_dataframe([[1, -1, "one"], [2, -2, "two"]]).to_df( ["intcol", "negcol", "lhscol"] ) @@ -577,126 +609,119 @@ def test_negative_test_for_self_join_with_conditions(session): Utils.drop_table(session, table_name1) +@pytest.mark.localtest def test_clone_can_help_these_self_joins(session): table_name1 = Utils.random_name_for_temp_object(TempObjectType.TABLE) - try: - Utils.create_table(session, table_name1, "c1 int, c2 int") - session.sql(f"insert into {table_name1} values(1, 2), (2, 3)").collect() - df = session.table(table_name1) - cloned_df = copy.copy(df) - - # inner self join - assert df.join(cloned_df, df["c1"] == cloned_df["c2"]).collect() == [ - Row(2, 3, 1, 2) - ] + schema = StructType( + [StructField("c1", IntegerType()), StructField("c2", IntegerType())] + ) + session.create_dataframe([[1, 2], [2, 3]], schema=schema).write.save_as_table( + table_name1, table_type="temporary" + ) + df = session.table(table_name1) + cloned_df = copy.copy(df) - # left self join - res = df.join(cloned_df, df["c1"] == cloned_df["c2"], "left").collect() - res.sort(key=lambda x: x[0]) - assert res == [Row(1, 2, None, None), Row(2, 3, 1, 2)] + # inner self join + assert df.join(cloned_df, df["c1"] == cloned_df["c2"]).collect() == [ + Row(2, 3, 1, 2) + ] - # right self join - res = df.join(cloned_df, df["c1"] == cloned_df["c2"], "right").collect() - res.sort(key=lambda x: x[0] or 0) - assert res == [Row(None, None, 2, 3), Row(2, 3, 1, 2)] + # left self join + res = df.join(cloned_df, df["c1"] == cloned_df["c2"], "left").collect() + res.sort(key=lambda x: x[0]) + assert res == [Row(1, 2, None, None), Row(2, 3, 1, 2)] - # outer self join - res = df.join(cloned_df, df["c1"] == cloned_df["c2"], "outer").collect() - res.sort(key=lambda x: x[0] or 0) - assert res == [ - Row(None, None, 2, 3), - Row(1, 2, None, None), - Row(2, 3, 1, 2), - ] + # right self join + res = df.join(cloned_df, df["c1"] == cloned_df["c2"], "right").collect() + res.sort(key=lambda x: x[0] or 0) + assert res == [Row(None, None, 2, 3), Row(2, 3, 1, 2)] - finally: - Utils.drop_table(session, table_name1) + # outer self join + res = df.join(cloned_df, df["c1"] == cloned_df["c2"], "outer").collect() + res.sort(key=lambda x: x[0] or 0) + assert res == [ + Row(None, None, 2, 3), + Row(1, 2, None, None), + Row(2, 3, 1, 2), + ] +@pytest.mark.localtest def test_natural_cross_joins(session): - table_name1 = Utils.random_name_for_temp_object(TempObjectType.TABLE) - try: - Utils.create_table(session, table_name1, "c1 int, c2 int") - session.sql(f"insert into {table_name1} values(1, 2), (2, 3)").collect() - df = session.table(table_name1) - df2 = df # Another reference of "df" - cloned_df = copy.copy(df) + df1 = session.create_dataframe([[1, 2], [2, 3]], schema=["c1", "c2"]) + df2 = df1 # Another reference of "df" + cloned_df1 = copy.copy(df1) - # "natural join" supports self join - assert df.natural_join(df2).collect() == [Row(1, 2), Row(2, 3)] - assert df.natural_join(cloned_df).collect() == [Row(1, 2), Row(2, 3)] - - # "cross join" supports self join - res = df.cross_join(df2).collect() - res.sort(key=lambda x: x[0]) - assert res == [ - Row(1, 2, 1, 2), - Row(1, 2, 2, 3), - Row(2, 3, 1, 2), - Row(2, 3, 2, 3), - ] + # "natural join" supports self join + assert df1.natural_join(df2).collect() == [Row(1, 2), Row(2, 3)] + assert df1.natural_join(cloned_df1).collect() == [Row(1, 2), 
Row(2, 3)] - res = df.cross_join(df2).collect() - res.sort(key=lambda x: x[0]) - assert res == [ - Row(1, 2, 1, 2), - Row(1, 2, 2, 3), - Row(2, 3, 1, 2), - Row(2, 3, 2, 3), - ] + # "cross join" supports self join + res = df1.cross_join(df2).collect() + res.sort(key=lambda x: x[0]) + assert res == [ + Row(1, 2, 1, 2), + Row(1, 2, 2, 3), + Row(2, 3, 1, 2), + Row(2, 3, 2, 3), + ] - finally: - Utils.drop_table(session, table_name1) + res = df1.cross_join(df2).collect() + res.sort(key=lambda x: x[0]) + assert res == [ + Row(1, 2, 1, 2), + Row(1, 2, 2, 3), + Row(2, 3, 1, 2), + Row(2, 3, 2, 3), + ] +@pytest.mark.localtest def test_clone_with_join_dataframe(session): table_name1 = Utils.random_name_for_temp_object(TempObjectType.TABLE) - try: - Utils.create_table(session, table_name1, "c1 int, c2 int") - session.sql(f"insert into {table_name1} values(1, 2), (2, 3)").collect() - df = session.table(table_name1) + session.create_dataframe([[1, 2], [2, 3]], schema=["c1", "c2"]).write.save_as_table( + table_name1, table_type="temporary" + ) - assert df.collect() == [Row(1, 2), Row(2, 3)] + df = session.table(table_name1) - cloned_df = copy.copy(df) - # Cloned DF has the same conent with original DF - assert cloned_df.collect() == [Row(1, 2), Row(2, 3)] + assert df.collect() == [Row(1, 2), Row(2, 3)] - join_df = df.join(cloned_df, df["c1"] == cloned_df["c2"]) - assert join_df.collect() == [Row(2, 3, 1, 2)] - # Cloned join DF - cloned_join_df = copy.copy(join_df) - assert cloned_join_df.collect() == [Row(2, 3, 1, 2)] + cloned_df = copy.copy(df) + # Cloned DF has the same conent with original DF + assert cloned_df.collect() == [Row(1, 2), Row(2, 3)] - finally: - Utils.drop_table(session, table_name1) + join_df = df.join(cloned_df, df["c1"] == cloned_df["c2"]) + assert join_df.collect() == [Row(2, 3, 1, 2)] + # Cloned join DF + cloned_join_df = copy.copy(join_df) + assert cloned_join_df.collect() == [Row(2, 3, 1, 2)] +@pytest.mark.localtest def test_join_of_join(session): table_name1 = Utils.random_name_for_temp_object(TempObjectType.TABLE) - try: - Utils.create_table(session, table_name1, "c1 int, c2 int") - session.sql(f"insert into {table_name1} values(1, 1), (2, 2)").collect() - df_l = session.table(table_name1) - df_r = copy.copy(df_l) - df_j = df_l.join(df_r, df_l["c1"] == df_r["c1"]) - - assert df_j.collect() == [Row(1, 1, 1, 1), Row(2, 2, 2, 2)] + session.create_dataframe([[1, 1], [2, 2]], schema=["c1", "c2"]).write.save_as_table( + table_name1, table_type="temporary" + ) + df_l = session.table(table_name1) + df_r = copy.copy(df_l) + df_j = df_l.join(df_r, df_l["c1"] == df_r["c1"]) - df_j_clone = copy.copy(df_j) - # Because of duplicate column name rename, we have to get a name. - col_name = df_j.schema.fields[0].name - df_j_j = df_j.join(df_j_clone, df_j[col_name] == df_j_clone[col_name]) + assert df_j.collect() == [Row(1, 1, 1, 1), Row(2, 2, 2, 2)] - assert df_j_j.collect() == [ - Row(1, 1, 1, 1, 1, 1, 1, 1), - Row(2, 2, 2, 2, 2, 2, 2, 2), - ] + df_j_clone = copy.copy(df_j) + # Because of duplicate column name rename, we have to get a name. 
+ col_name = df_j.schema.fields[0].name + df_j_j = df_j.join(df_j_clone, df_j[col_name] == df_j_clone[col_name]) - finally: - Utils.drop_table(session, table_name1) + assert df_j_j.collect() == [ + Row(1, 1, 1, 1, 1, 1, 1, 1), + Row(2, 2, 2, 2, 2, 2, 2, 2), + ] +# TODO: [Local Testing] Fix simplifier copy def test_negative_test_join_of_join(session): table_name1 = Utils.random_name_for_temp_object(TempObjectType.TABLE) try: @@ -715,67 +740,60 @@ def test_negative_test_join_of_join(session): Utils.drop_table(session, table_name1) -def test_drop_on_join(session): +def test_drop_on_join( + session, +): # TODO: [Local Testing] Fix drop table_name_1 = Utils.random_name_for_temp_object(TempObjectType.TABLE) table_name_2 = Utils.random_name_for_temp_object(TempObjectType.TABLE) - try: - session.create_dataframe([[1, "a", True], [2, "b", False]]).to_df( - "a", "b", "c" - ).write.save_as_table(table_name_1) - session.create_dataframe([[3, "a", True], [4, "b", False]]).to_df( - "a", "b", "c" - ).write.save_as_table(table_name_2) - df1 = session.table(table_name_1) - df2 = session.table(table_name_2) - df3 = df1.join(df2, df1["c"] == df2["c"]).drop(df1["a"], df2["b"], df1["c"]) - Utils.check_answer(df3, [Row("a", 3, True), Row("b", 4, False)]) - df4 = df3.drop(df2["c"], df1["b"], col("other")) - Utils.check_answer(df4, [Row(3), Row(4)]) - finally: - Utils.drop_table(session, table_name_1) - Utils.drop_table(session, table_name_2) - -def test_drop_on_self_join(session): + session.create_dataframe([[1, "a", True], [2, "b", False]]).to_df( + "a", "b", "c" + ).write.save_as_table(table_name_1, table_type="temporary") + session.create_dataframe([[3, "a", True], [4, "b", False]]).to_df( + "a", "b", "c" + ).write.save_as_table(table_name_2, table_type="temporary") + df1 = session.table(table_name_1) + df2 = session.table(table_name_2) + df3 = df1.join(df2, df1["c"] == df2["c"]).drop(df1["a"], df2["b"], df1["c"]) + Utils.check_answer(df3, [Row("a", 3, True), Row("b", 4, False)]) + df4 = df3.drop(df2["c"], df1["b"], col("other")) + Utils.check_answer(df4, [Row(3), Row(4)]) + + +def test_drop_on_self_join(session): # TODO: Fix drop table_name_1 = Utils.random_name_for_temp_object(TempObjectType.TABLE) - try: - session.create_dataframe([[1, "a", True], [2, "b", False]]).to_df( - "a", "b", "c" - ).write.save_as_table(table_name_1) - df1 = session.table(table_name_1) - df2 = copy.copy(df1) - df3 = df1.join(df2, df1["c"] == df2["c"]).drop(df1["a"], df2["b"], df1["c"]) - Utils.check_answer(df3, [Row("a", 1, True), Row("b", 2, False)]) - df4 = df3.drop(df2["c"], df1["b"], col("other")) - Utils.check_answer(df4, [Row(1), Row(2)]) - finally: - Utils.drop_table(session, table_name_1) - - -def test_with_column_on_join(session): + session.create_dataframe([[1, "a", True], [2, "b", False]]).to_df( + "a", "b", "c" + ).write.save_as_table(table_name_1, table_type="temporary") + df1 = session.table(table_name_1) + df2 = copy.copy(df1) + df3 = df1.join(df2, df1["c"] == df2["c"]).drop(df1["a"], df2["b"], df1["c"]) + Utils.check_answer(df3, [Row("a", 1, True), Row("b", 2, False)]) + df4 = df3.drop(df2["c"], df1["b"], col("other")) + Utils.check_answer(df4, [Row(1), Row(2)]) + + +def test_with_column_on_join(session): # TODO: Fix drop table_name_1 = Utils.random_name_for_temp_object(TempObjectType.TABLE) table_name_2 = Utils.random_name_for_temp_object(TempObjectType.TABLE) - try: - session.create_dataframe([[1, "a", True], [2, "b", False]]).to_df( - "a", "b", "c" - ).write.save_as_table(table_name_1) - 
session.create_dataframe([[3, "a", True], [4, "b", False]]).to_df( - "a", "b", "c" - ).write.save_as_table(table_name_2) - df1 = session.table(table_name_1) - df2 = session.table(table_name_2) - Utils.check_answer( - df1.join(df2, df1["c"] == df2["c"]) - .drop(df1["b"], df2["b"], df1["c"]) - .with_column("newColumn", df1["a"] + df2["a"]), - [Row(1, 3, True, 4), Row(2, 4, False, 6)], - ) - finally: - Utils.drop_table(session, table_name_1) - Utils.drop_table(session, table_name_2) + session.create_dataframe([[1, "a", True], [2, "b", False]]).to_df( + "a", "b", "c" + ).write.save_as_table(table_name_1, table_type="temporary") + session.create_dataframe([[3, "a", True], [4, "b", False]]).to_df( + "a", "b", "c" + ).write.save_as_table(table_name_2, table_type="temporary") + df1 = session.table(table_name_1) + df2 = session.table(table_name_2) + Utils.check_answer( + df1.join(df2, df1["c"] == df2["c"]) + .drop(df1["b"], df2["b"], df1["c"]) + .with_column("newColumn", df1["a"] + df2["a"]), + [Row(1, 3, True, 4), Row(2, 4, False, 6)], + ) -def test_process_outer_join_results_using_the_non_nullable_columns_in_the_join_outpu( +@pytest.mark.localtest +def test_process_outer_join_results_using_the_non_nullable_columns_in_the_join_output( session, ): df1 = session.create_dataframe([(0, 0), (1, 0), (2, 0), (3, 0), (4, 0)]).to_df( @@ -801,6 +819,7 @@ def test_process_outer_join_results_using_the_non_nullable_columns_in_the_join_o ) +@pytest.mark.localtest def test_outer_join_conversion(session): df = session.create_dataframe([(1, 2, "1"), (3, 4, "3")]).to_df( ["int", "int2", "str"] @@ -842,7 +861,11 @@ def test_outer_join_conversion(session): assert left_join_2_inner == [Row(1, 2, "1", 1, 3, "1")] -def test_dont_throw_analysis_exception_in_check_cartesian(session): +@pytest.mark.localtest +def test_dont_throw_analysis_exception_in_check_cartesian( + session, +): + # Can't this be a unit test """Don't throw Analysis Exception in CheckCartesianProduct when join condition is false or null""" df = session.range(10).to_df(["id"]) dfNull = session.range(10).select(lit(None).as_("b")) @@ -853,18 +876,30 @@ def test_dont_throw_analysis_exception_in_check_cartesian(session): dfOne.join(dfTwo, col("a") == col("b"), "left").collect() +@pytest.mark.localtest def test_name_alias_on_multiple_join(session): table_trips = Utils.random_name_for_temp_object(TempObjectType.TABLE) table_stations = Utils.random_name_for_temp_object(TempObjectType.TABLE) try: - session.sql( - f"create or replace temp table {table_trips} (starttime timestamp, " - f"start_station_id int, end_station_id int)" - ).collect() - session.sql( - f"create or replace temp table {table_stations} " - f"(station_id int, station_name string)" - ).collect() + session.create_dataframe( + [], + schema=StructType( + [ + StructField("starttime", TimestampType()), + StructField("start_station_id", IntegerType()), + StructField("end_station_id", IntegerType()), + ] + ), + ).write.save_as_table(table_trips, table_type="temporary") + session.create_dataframe( + [], + schema=StructType( + [ + StructField("station_id", IntegerType()), + StructField("station_name", StringType()), + ] + ), + ).write.save_as_table(table_stations, table_type="temporary") df_trips = session.table(table_trips) df_start_stations = session.table(table_stations) @@ -949,6 +984,7 @@ def test_report_error_when_refer_common_col(session): assert "The reference to the column 'C' is ambiguous." 
in ex_info.value.message +@pytest.mark.localtest def test_select_all_on_join_result(session): df_left = session.create_dataframe([[1, 2]]).to_df("a", "b") df_right = session.create_dataframe([[3, 4]]).to_df("c", "d") @@ -994,6 +1030,7 @@ def test_select_all_on_join_result(session): ) +@pytest.mark.localtest def test_select_left_right_on_join_result(session): df_left = session.create_dataframe([[1, 2]]).to_df("a", "b") df_right = session.create_dataframe([[3, 4]]).to_df("c", "d") @@ -1020,6 +1057,7 @@ def test_select_left_right_on_join_result(session): ) +@pytest.mark.localtest def test_select_left_right_combination_on_join_result(session): df_left = session.create_dataframe([[1, 2]]).to_df("a", "b") df_right = session.create_dataframe([[3, 4]]).to_df("c", "d") @@ -1077,7 +1115,10 @@ def test_select_left_right_combination_on_join_result(session): ) -def test_select_columns_on_join_result_with_conflict_name(session): +@pytest.mark.localtest +def test_select_columns_on_join_result_with_conflict_name( + session, +): df_left = session.create_dataframe([[1, 2]]).to_df("a", "b") df_right = session.create_dataframe([[3, 4]]).to_df("a", "d") df = df_left.join(df_right) @@ -1118,7 +1159,9 @@ def test_select_columns_on_join_result_with_conflict_name(session): assert df4.collect() == [Row(3, 4, 1)] -def test_nested_join_diamond_shape_error(session): +def test_nested_join_diamond_shape_error( + session, +): # TODO: local testing match error behavior """This is supposed to work but currently we don't handle it correctly. We should fix this with a good design.""" df1 = session.create_dataframe([[1]], schema=["a"]) df2 = session.create_dataframe([[1]], schema=["a"]) @@ -1134,6 +1177,7 @@ def test_nested_join_diamond_shape_error(session): df5.collect() +@pytest.mark.localtest def test_nested_join_diamond_shape_workaround(session): df1 = session.create_dataframe([[1]], schema=["a"]) df2 = session.create_dataframe([[1]], schema=["a"]) diff --git a/tests/integ/scala/test_dataframe_range_suite.py b/tests/integ/scala/test_dataframe_range_suite.py index 048cd2cf24f..bc45bc7aaf4 100644 --- a/tests/integ/scala/test_dataframe_range_suite.py +++ b/tests/integ/scala/test_dataframe_range_suite.py @@ -12,26 +12,30 @@ from snowflake.snowpark.functions import col, count, sum as sum_ +@pytest.mark.localtest def test_range(session): assert session.range(5).collect() == [Row(i) for i in range(5)] assert session.range(3, 5).collect() == [Row(i) for i in range(3, 5)] assert session.range(3, 10, 2).collect() == [Row(i) for i in range(3, 10, 2)] +@pytest.mark.localtest def test_negative_test(session): with pytest.raises(ValueError) as ex_info: session.range(-3, 5, 0) assert "The step for range() cannot be 0." 
in str(ex_info) +@pytest.mark.localtest def test_empty_result_and_negative_start_end_step(session): - assert session.range(3, 5, -1).count() == 0 - assert session.range(-3, -5, 1).count() == 0 + assert session.range(3, 5, -1).collect() == [] + assert session.range(-3, -5, 1).collect() == [] assert session.range(-3, -10, -2).collect() == [Row(i) for i in range(-3, -10, -2)] assert session.range(10, 3, -3).collect() == [Row(i) for i in range(10, 3, -3)] +@pytest.mark.localtest def test_range_api(session): res3 = session.range(1, -2).select("id") assert res3.count() == 0 @@ -61,10 +65,7 @@ def test_range_api(session): assert res16.count() == 500 -def test_range_test(session): - assert len(session.range(3).select("id").collect()) == 3 - - +@pytest.mark.localtest def test_range_with_randomized_parameters(session): MAX_NUM_STEPS = 10 * 1000 MAX_VALUE = 2**31 - 1 @@ -99,10 +100,11 @@ def random_bound(): assert res[0][1] == expected_sum +@pytest.mark.localtest def test_range_with_max_and_min(session): MAX_VALUE = 0x7FFFFFFFFFFFFFFF MIN_VALUE = -0x8000000000000000 start = MAX_VALUE - 3 end = MIN_VALUE + 2 - assert session.range(start, end, 1).count() == 0 - assert session.range(start, start, 1).count() == 0 + assert session.range(start, end, 1).collect() == [] + assert session.range(start, start, 1).collect() == [] diff --git a/tests/integ/scala/test_dataframe_reader_suite.py b/tests/integ/scala/test_dataframe_reader_suite.py index 1d0e4eadde7..e3cfcfc3940 100644 --- a/tests/integ/scala/test_dataframe_reader_suite.py +++ b/tests/integ/scala/test_dataframe_reader_suite.py @@ -27,9 +27,11 @@ ) from snowflake.snowpark.functions import col, lit, sql_expr from snowflake.snowpark.types import ( + BooleanType, DateType, DecimalType, DoubleType, + FloatType, IntegerType, LongType, StringType, @@ -42,6 +44,7 @@ from tests.utils import IS_IN_STORED_PROC, TestFiles, Utils test_file_csv = "testCSV.csv" +test_file_cvs_various_data = "testCSVvariousData.csv" test_file2_csv = "test2CSV.csv" test_file_csv_colon = "testCSVcolon.csv" test_file_csv_header = "testCSVheader.csv" @@ -117,13 +120,20 @@ def get_df_from_reader_and_file_format(reader, file_format): @pytest.fixture(scope="module", autouse=True) -def setup(session, resources_path): +def setup(session, resources_path, local_testing_mode): test_files = TestFiles(resources_path) - Utils.create_stage(session, tmp_stage_name1, is_temporary=True) - Utils.create_stage(session, tmp_stage_name2, is_temporary=True) + if not local_testing_mode: + Utils.create_stage(session, tmp_stage_name1, is_temporary=True) + Utils.create_stage(session, tmp_stage_name2, is_temporary=True) Utils.upload_to_stage( session, "@" + tmp_stage_name1, test_files.test_file_csv, compress=False ) + Utils.upload_to_stage( + session, + "@" + tmp_stage_name1, + test_files.test_file_csv_various_data, + compress=False, + ) Utils.upload_to_stage( session, "@" + tmp_stage_name1, @@ -196,10 +206,12 @@ def setup(session, resources_path): yield # tear down the resources after yield (pytest fixture feature) # https://docs.pytest.org/en/6.2.x/fixture.html#yield-fixtures-recommended - session.sql(f"DROP STAGE IF EXISTS {tmp_stage_name1}").collect() - session.sql(f"DROP STAGE IF EXISTS {tmp_stage_name2}").collect() + if not local_testing_mode: + session.sql(f"DROP STAGE IF EXISTS {tmp_stage_name1}").collect() + session.sql(f"DROP STAGE IF EXISTS {tmp_stage_name2}").collect() +@pytest.mark.localtest @pytest.mark.parametrize("mode", ["select", "copy"]) def test_read_csv(session, mode): reader = 
get_reader(session, mode) @@ -229,6 +241,85 @@ def test_read_csv(session, mode): df2.collect() assert "Numeric value 'one' is not recognized" in ex_info.value.message + cvs_schema = StructType( + [ + StructField("a", IntegerType()), + StructField("b", LongType()), + StructField("c", StringType()), + StructField("d", DoubleType()), + StructField("e", DecimalType(scale=0)), + StructField("f", DecimalType(scale=2)), + StructField("g", DecimalType(precision=2)), + StructField("h", DecimalType(precision=10, scale=3)), + StructField("i", FloatType()), + StructField("j", BooleanType()), + StructField("k", DateType()), + StructField("l", TimestampType()), + StructField("m", TimeType()), + ] + ) + df3 = reader.schema(cvs_schema).csv( + f"@{tmp_stage_name1}/{test_file_cvs_various_data}" + ) + res = df3.collect() + res.sort(key=lambda x: x[0]) + assert res == [ + Row( + 1, + 234, + "one", + 1.2, + 12, + Decimal("12.35"), + -12, + Decimal("12.346"), + 56.78, + True, + datetime.date(2023, 6, 6), + datetime.datetime(2023, 6, 6, 12, 34, 56), + datetime.time(12, 34, 56), + ), + Row( + 2, + 567, + "two", + 2.2, + 57, + Decimal("56.79"), + -57, + Decimal("56.787"), + 89.01, + False, + datetime.date(2023, 6, 6), + datetime.datetime(2023, 6, 6, 12, 34, 56), + datetime.time(12, 34, 56), + ), + ] + + cvs_schema = StructType( + [ + StructField("a", IntegerType()), + StructField("b", LongType()), + StructField("c", StringType()), + StructField("d", DoubleType()), + StructField("e", DecimalType(scale=0)), + StructField("f", DecimalType(scale=2)), + StructField("g", DecimalType(precision=1)), + StructField("h", DecimalType(precision=10, scale=3)), + StructField("i", FloatType()), + StructField("j", BooleanType()), + StructField("k", DateType()), + StructField("l", TimestampType()), + StructField("m", TimeType()), + ] + ) + df3 = reader.schema(cvs_schema).csv( + f"@{tmp_stage_name1}/{test_file_cvs_various_data}" + ) + with pytest.raises(SnowparkSQLException) as ex_info: + df3.collect() + assert "is out of range" in str(ex_info) + @pytest.mark.parametrize("mode", ["select", "copy"]) @pytest.mark.parametrize("parse_header", [True, False]) @@ -317,6 +408,7 @@ def test_save_as_table_work_with_df_created_from_read(session): Utils.drop_table(session, xml_table_name) +@pytest.mark.localtest def test_read_csv_with_more_operations(session): test_file_on_stage = f"@{tmp_stage_name1}/{test_file_csv}" df1 = session.read.schema(user_schema).csv(test_file_on_stage).filter(col("a") < 2) @@ -364,6 +456,7 @@ def test_read_csv_with_more_operations(session): ] +@pytest.mark.localtest @pytest.mark.parametrize("mode", ["select", "copy"]) def test_read_csv_with_format_type_options(session, mode): test_file_colon = f"@{tmp_stage_name1}/{test_file_csv_colon}" @@ -424,11 +517,13 @@ def test_read_csv_with_format_type_options(session, mode): ] +@pytest.mark.localtest @pytest.mark.parametrize("mode", ["select", "copy"]) -def test_to_read_files_from_stage(session, resources_path, mode): +def test_to_read_files_from_stage(session, resources_path, mode, local_testing_mode): data_files_stage = Utils.random_stage_name() - Utils.create_stage(session, data_files_stage, is_temporary=True) test_files = TestFiles(resources_path) + if not local_testing_mode: + Utils.create_stage(session, data_files_stage, is_temporary=True) Utils.upload_to_stage( session, "@" + data_files_stage, test_files.test_file_csv, False ) @@ -453,7 +548,8 @@ def test_to_read_files_from_stage(session, resources_path, mode): Row(4, "four", 4.4), ] finally: - session.sql(f"DROP STAGE 
IF EXISTS {data_files_stage}") + if not local_testing_mode: + session.sql(f"DROP STAGE IF EXISTS {data_files_stage}") @pytest.mark.xfail(reason="SNOW-575700 flaky test", strict=False) @@ -492,6 +588,7 @@ def test_for_all_csv_compression_keywords(session, temp_schema, mode): session.sql(f"drop file format {format_name}") +@pytest.mark.localtest @pytest.mark.parametrize("mode", ["select", "copy"]) def test_read_csv_with_special_chars_in_format_type_options(session, mode): schema1 = StructType( @@ -900,6 +997,10 @@ def test_read_xml_with_no_schema(session, mode): ] +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="on_error is not supported", +) def test_copy(session): test_file_on_stage = f"@{tmp_stage_name1}/{test_file_csv}" @@ -940,6 +1041,9 @@ def test_copy(session): assert df2.collect() == [] +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="force is not supported." +) def test_copy_option_force(session): test_file_on_stage = f"@{tmp_stage_name1}/{test_file_csv}" @@ -983,6 +1087,10 @@ def test_copy_option_force(session): ).collect() +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="on_error is not supported.", +) def test_read_file_on_error_continue_on_csv(session, db_parameters, resources_path): broken_file = f"@{tmp_stage_name1}/{test_broken_csv}" @@ -998,6 +1106,10 @@ def test_read_file_on_error_continue_on_csv(session, db_parameters, resources_pa assert res == [Row(1, "one", 1.1), Row(3, "three", 3.3)] +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="on_error is not supported.", +) def test_read_file_on_error_continue_on_avro(session): broken_file = f"@{tmp_stage_name1}/{test_file_avro}" @@ -1034,6 +1146,9 @@ def test_select_and_copy_on_non_csv_format_have_same_result_schema(session): assert c.column_identifier.quoted_name == f.column_identifier.quoted_name +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="pattern is not supported" +) @pytest.mark.parametrize("mode", ["select", "copy"]) def test_pattern(session, mode): assert ( @@ -1047,6 +1162,9 @@ def test_pattern(session, mode): ) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="sql is not supported." 
+) def test_read_staged_file_no_commit(session): path = f"@{tmp_stage_name1}/{test_file_csv}" @@ -1065,6 +1183,10 @@ def test_read_staged_file_no_commit(session): assert not Utils.is_active_transaction(session) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="local test does not have queries", +) def test_read_csv_with_sql_simplifier(session): if session.sql_simplifier_enabled is False: pytest.skip("Applicable only when sql simplifier is enabled") diff --git a/tests/integ/scala/test_dataframe_set_operations_suite.py b/tests/integ/scala/test_dataframe_set_operations_suite.py index 23b7cb113c0..b9f2c120dab 100644 --- a/tests/integ/scala/test_dataframe_set_operations_suite.py +++ b/tests/integ/scala/test_dataframe_set_operations_suite.py @@ -16,6 +16,7 @@ from tests.utils import TestData, Utils +@pytest.mark.localtest def test_union_with_filters(session): """Tests union queries with a filter added""" @@ -49,6 +50,7 @@ def check(new_col: Column, cfilter: Column, result: List[Row]): check(lit(2).cast(IntegerType()), col("c") != 2, list()) +@pytest.mark.localtest def test_union_all_with_filters(session): """Tests union queries with a filter added""" @@ -82,6 +84,7 @@ def check(new_col: Column, cfilter: Column, result: List[Row]): check(lit(2).cast(IntegerType()), col("c") != 2, list()) +@pytest.mark.localtest def test_except(session): lower_case_data = TestData.lower_case_data(session) upper_case_data = TestData.upper_case_data(session) @@ -122,6 +125,7 @@ def test_except(session): Utils.check_answer(all_nulls.filter(lit(0) == 1).except_(all_nulls), []) +@pytest.mark.localtest def test_except_between_two_projects_without_references_used_in_filter(session): df = session.create_dataframe(((1, 2, 4), (1, 3, 5), (2, 2, 3), (2, 4, 5))).to_df( "a", "b", "c" @@ -132,6 +136,7 @@ def test_except_between_two_projects_without_references_used_in_filter(session): Utils.check_answer(df1.select("b").except_(df2.select("c")), Row(2)) +@pytest.mark.localtest def test_union_unionall_unionbyname_unionallbyname_in_one_case(session): df1 = session.create_dataframe([(1, 2, 3)]).to_df("a", "b", "c") df2 = session.create_dataframe([(3, 1, 2)]).to_df("c", "a", "b") @@ -148,6 +153,7 @@ def test_union_unionall_unionbyname_unionallbyname_in_one_case(session): Utils.check_answer(df1.union_all_by_name(df3), [Row(1, 2, 3), Row(3, 1, 2)]) +@pytest.mark.localtest def test_nondeterministic_expressions_should_not_be_pushed_down(session): df1 = session.create_dataframe([(i,) for i in range(1, 21)]).to_df("i") df2 = session.create_dataframe([(i,) for i in range(1, 11)]).to_df("i") @@ -165,6 +171,7 @@ def test_nondeterministic_expressions_should_not_be_pushed_down(session): Utils.check_answer(except_.collect(), except_.collect()) +@pytest.mark.localtest def test_union_all(session): td4 = TestData.test_data4(session) union_df = td4.union(td4).union(td4).union(td4).union(td4) @@ -177,6 +184,7 @@ def test_union_all(session): assert res == [Row(1, 25250)] +@pytest.mark.localtest def test_union_by_name(session): df1 = session.create_dataframe([(1, 2, 3)]).to_df("a", "b", "c") df2 = session.create_dataframe([(3, 1, 2)]).to_df("c", "a", "b") @@ -197,6 +205,7 @@ def test_union_by_name(session): df1.union_by_name(df2) +@pytest.mark.localtest def test_unionall_by_name(session): df1 = session.create_dataframe([(1, 2, 3)]).to_df("a", "b", "c") df2 = session.create_dataframe([(3, 1, 2)]).to_df("c", "a", "b") @@ -217,6 +226,7 @@ def test_unionall_by_name(session): df1.union_all_by_name(df2) 
+@pytest.mark.localtest def test_union_by_quoted_name(session): df1 = session.create_dataframe([(1, 2, 3)]).to_df('"a"', "a", "c") df2 = session.create_dataframe([(3, 1, 2)]).to_df("c", '"a"', "a") @@ -232,6 +242,7 @@ def test_union_by_quoted_name(session): df1.union_by_name(df2) +@pytest.mark.localtest def test_unionall_by_quoted_name(session): df1 = session.create_dataframe([(1, 2, 3)]).to_df('"a"', "a", "c") df2 = session.create_dataframe([(3, 1, 2)]).to_df("c", '"a"', "a") @@ -247,6 +258,7 @@ def test_unionall_by_quoted_name(session): df1.union_by_name(df2) +@pytest.mark.localtest def test_intersect_nullability(session): non_nullable_ints = session.create_dataframe([[1], [3]]).to_df("a") null_ints = TestData.null_ints(session) @@ -280,6 +292,7 @@ def test_intersect_nullability(session): assert all(not i.nullable for i in df4.schema.fields) +@pytest.mark.localtest def test_performing_set_ops_on_non_native_types(session): dates = session.create_dataframe( [ @@ -303,6 +316,7 @@ def test_performing_set_ops_on_non_native_types(session): dates.except_(widen_typed_rows).collect() +@pytest.mark.localtest def test_union_by_name_check_name_duplication(session): c0 = "ab" c1 = "AB" @@ -319,6 +333,7 @@ def test_union_by_name_check_name_duplication(session): df1.union_by_name(df2) +@pytest.mark.localtest def test_unionall_by_name_check_name_duplication(session): c0 = "ab" c1 = "AB" @@ -335,6 +350,7 @@ def test_unionall_by_name_check_name_duplication(session): df1.union_all_by_name(df2) +@pytest.mark.localtest def test_intersect(session): lcd = TestData.lower_case_data(session) res = lcd.intersect(lcd).collect() @@ -362,6 +378,7 @@ def test_intersect(session): assert res == [Row("id", 1), Row("id1", 1), Row("id1", 2)] +@pytest.mark.localtest def test_project_should_not_be_pushed_down_through_intersect_or_except(session): df1 = session.create_dataframe([[i] for i in range(1, 101)]).to_df("i") df2 = session.create_dataframe([[i] for i in range(1, 31)]).to_df("i") @@ -370,8 +387,9 @@ def test_project_should_not_be_pushed_down_through_intersect_or_except(session): assert df1.except_(df2).count() == 70 +@pytest.mark.localtest def test_except_nullability(session): - non_nullable_ints = session.create_dataframe(((11,), (3,))).to_df("a") + non_nullable_ints = session.create_dataframe(((11,), (3,))).to_df(["a"]) for attribute in non_nullable_ints.schema._to_attributes(): assert not attribute.nullable @@ -397,12 +415,18 @@ def test_except_nullability(session): assert not attribute.nullable +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_except_distinct_sql_compliance(session): df_left = session.create_dataframe([(1,), (2,), (2,), (3,), (3,), (4,)]).to_df("id") df_right = session.create_dataframe([(1,), (3,)]).to_df("id") Utils.check_answer(df_left.except_(df_right), [Row(2), Row(4)]) +@pytest.mark.localtest def test_mix_set_operator(session): df1 = session.create_dataframe([1]).to_df("a") df2 = session.create_dataframe([2]).to_df("a") diff --git a/tests/integ/scala/test_dataframe_suite.py b/tests/integ/scala/test_dataframe_suite.py index 9dc77762848..dc2da3e2538 100644 --- a/tests/integ/scala/test_dataframe_suite.py +++ b/tests/integ/scala/test_dataframe_suite.py @@ -68,17 +68,28 @@ SAMPLING_DEVIATION = 0.4 -def test_null_data_in_tables(session): +@pytest.mark.localtest +def test_null_data_in_tables(session, local_testing_mode): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) try: - 
Utils.create_table(session, table_name, "num int") - session.sql(f"insert into {table_name} values(null),(null),(null)").collect() + if not local_testing_mode: + Utils.create_table(session, table_name, "num int") + session.sql( + f"insert into {table_name} values(null),(null),(null)" + ).collect() + else: + session.create_dataframe( + [[None], [None], [None]], + schema=StructType([StructField("num", IntegerType())]), + ).write.save_as_table(table_name) res = session.table(table_name).collect() assert res == [Row(None), Row(None), Row(None)] finally: - Utils.drop_table(session, table_name) + if not local_testing_mode: + Utils.drop_table(session, table_name) +@pytest.mark.localtest def test_null_data_in_local_relation_with_filters(session): df = session.create_dataframe([[1, None], [2, "NotNull"], [3, None]]).to_df( ["a", "b"] @@ -101,6 +112,7 @@ def test_null_data_in_local_relation_with_filters(session): ] +@pytest.mark.localtest def test_project_null_values(session): """Tests projecting null values onto different columns in a dataframe""" df = session.create_dataframe([1, 2]).to_df("a").with_column("b", lit(None)) @@ -136,17 +148,43 @@ def test_bulk_insert_from_collected_result(session): Utils.drop_table(session, table_name_copied) -def test_write_null_data_to_table(session): +@pytest.mark.localtest +def test_write_null_data_to_table(session, local_testing_mode): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) df = session.create_dataframe([(1, None), (2, None), (3, None)]).to_df("a", "b") try: df.write.save_as_table(table_name) Utils.check_answer(session.table(table_name), df, True) finally: - Utils.drop_table(session, table_name) + if not local_testing_mode: + Utils.drop_table(session, table_name) -def test_create_or_replace_view_with_null_data(session): +@pytest.mark.localtest +def test_view_should_be_updated(session, local_testing_mode): + """Assert views should reflect changes if the underlying data is updated.""" + table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) + view_name = Utils.random_name_for_temp_object(TempObjectType.VIEW) + df = session.create_dataframe([(1, 2), (3, 4)], schema=["a", "b"]) + try: + df.write.save_as_table(table_name, table_type="temporary") + session.table(table_name).select( + sum_(col("a")), sum_(col("b")) + ).create_or_replace_view(view_name) + Utils.check_answer(session.table(view_name), [Row(4, 6)]) + + session.create_dataframe( + [(5, 6), (7, 8)], schema=["a", "b"] + ).write.save_as_table(table_name, mode="append") + Utils.check_answer(session.table(view_name), [Row(16, 20)]) + finally: + if not local_testing_mode: + Utils.drop_table(session, table_name) + Utils.drop_view(session, view_name) + + +@pytest.mark.localtest +def test_create_or_replace_view_with_null_data(session, local_testing_mode): df = session.create_dataframe([[1, None], [2, "NotNull"], [3, None]]).to_df( ["a", "b"] ) @@ -154,13 +192,15 @@ def test_create_or_replace_view_with_null_data(session): try: df.create_or_replace_view(view_name) - res = session.sql(f"select * from {view_name}").collect() + res = session.table(view_name).collect() res.sort(key=lambda x: x[0]) assert res == [Row(1, None), Row(2, "NotNull"), Row(3, None)] finally: - Utils.drop_view(session, view_name) + if not local_testing_mode: + Utils.drop_view(session, view_name) +@pytest.mark.localtest def test_adjust_column_width_of_show(session): df = session.create_dataframe([[1, None], [2, "NotNull"]]).to_df("a", "b") # run show(), make sure no error is reported @@ -179,6 +219,7 @@ 
def test_adjust_column_width_of_show(session): ) +@pytest.mark.localtest def test_show_with_null_data(session): df = session.create_dataframe([[1, None], [2, "NotNull"]]).to_df("a", "b") # run show(), make sure no error is reported @@ -197,6 +238,7 @@ def test_show_with_null_data(session): ) +@pytest.mark.localtest def test_show_multi_lines_row(session): df = session.create_dataframe( [ @@ -221,6 +263,9 @@ def test_show_multi_lines_row(session): ) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="sql use is not supported" +) def test_show(session): TestData.test_data1(session).show() @@ -255,19 +300,23 @@ def test_show(session): ) +@pytest.mark.localtest def test_cache_result(session): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) - session._run_query(f"create temp table {table_name} (num int)") - session._run_query(f"insert into {table_name} values(1),(2)") + session.create_dataframe([[1], [2]], schema=["num"]).write.save_as_table(table_name) df = session.table(table_name) Utils.check_answer(df, [Row(1), Row(2)]) - session._run_query(f"insert into {table_name} values (3)") + session.create_dataframe([[3]], schema=["num"]).write.save_as_table( + table_name, mode="append" + ) Utils.check_answer(df, [Row(1), Row(2), Row(3)]) df1 = df.cache_result() - session._run_query(f"insert into {table_name} values (4)") + session.create_dataframe([[4]], schema=["num"]).write.save_as_table( + table_name, mode="append" + ) Utils.check_answer(df1, [Row(1), Row(2), Row(3)]) Utils.check_answer(df, [Row(1), Row(2), Row(3), Row(4)]) @@ -280,7 +329,7 @@ def test_cache_result(session): df4 = df1.cache_result() Utils.check_answer(df4, [Row(1), Row(2), Row(3)]) - session._run_query(f"drop table {table_name}") + session.table(table_name).drop_table() Utils.check_answer(df1, [Row(1), Row(2), Row(3)]) Utils.check_answer(df2, [Row(3)]) @@ -312,8 +361,9 @@ def test_cache_result_with_show(session): session._run_query(f"drop table {table_name1}") +@pytest.mark.localtest def test_drop_cache_result_try_finally(session): - df = session.sql("select 1 as a, 2 as b") + df = session.create_dataframe([[1, 2]], schema=["a", "b"]) cached = df.cache_result() try: df_after_cached = cached.select("a") @@ -335,8 +385,9 @@ def test_drop_cache_result_try_finally(session): df_after_cached.collect() +@pytest.mark.localtest def test_drop_cache_result_context_manager(session): - df = session.sql("select 1 as a, 2 as b") + df = session.create_dataframe([[1, 2]], schema=["a", "b"]) with df.cache_result() as cached: df_after_cached = cached.select("a") df_after_cached.collect() @@ -439,6 +490,9 @@ def test_non_select_query_composition_self_unionall(session): Utils.drop_table(session, table_name) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="sql use is not supported" +) def test_only_use_result_scan_when_composing_queries(session): df = session.sql("show tables") assert len(df._plan.queries) == 1 @@ -449,6 +503,9 @@ def test_only_use_result_scan_when_composing_queries(session): assert "RESULT_SCAN" in df2._plan.queries[-1].sql +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="sql use is not supported" +) def test_joins_on_result_scan(session): df1 = session.sql("show tables").select(['"name"', '"kind"']) df2 = session.sql("show tables").select(['"name"', '"rows"']) @@ -757,6 +814,7 @@ def test_df_stat_crosstab_max_column_test(session): assert res_4[0]["A"] == 1 and res_4[0]["CAST(1 AS NUMBER(38,0))"] == 1001 
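The hunks above (for example `test_cache_result`) replace `session._run_query` and `session.sql` setup with the DataFrame API, so the same test body runs against a live Snowflake session or the local emulator. A condensed sketch of that pattern, assuming an active `session`; the table name is purely illustrative:

# Seed a table through the DataFrame API instead of "create table" / "insert" SQL.
table_name = "EXAMPLE_NUMS"  # hypothetical name for the sketch
session.create_dataframe([[1], [2]], schema=["num"]).write.save_as_table(table_name)

# Append more rows the same way instead of "insert into ... values (...)".
session.create_dataframe([[3]], schema=["num"]).write.save_as_table(
    table_name, mode="append"
)
assert session.table(table_name).count() == 3

# Cleanup also goes through the API rather than a "drop table" statement.
session.table(table_name).drop_table()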
+@pytest.mark.localtest def test_select_star(session): double2 = TestData.double2(session) expected = TestData.double2(session).collect() @@ -764,6 +822,7 @@ def test_select_star(session): assert double2.select(double2.col("*")).collect() == expected +@pytest.mark.localtest def test_first(session): assert TestData.integer1(session).first() == Row(1) assert TestData.null_data1(session).first() == Row(None) @@ -783,6 +842,7 @@ def test_first(session): assert sorted(res, key=lambda x: x[0]) == [Row(1), Row(2), Row(3)] +@pytest.mark.localtest @pytest.mark.skipif(IS_IN_STORED_PROC_LOCALFS, reason="Large result") def test_sample_with_row_count(session): """Tests sample using n (row count)""" @@ -799,6 +859,7 @@ def test_sample_with_row_count(session): assert len(df.sample(n=row_count + 10).collect()) == row_count +@pytest.mark.localtest @pytest.mark.skipif(IS_IN_STORED_PROC_LOCALFS, reason="Large result") def test_sample_with_frac(session): """Tests sample using frac""" @@ -820,6 +881,7 @@ def test_sample_with_frac(session): assert len(df.sample(frac=1.0).collect()) == row_count +@pytest.mark.localtest def test_sample_with_seed(session): row_count = 10000 temp_table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) @@ -862,9 +924,10 @@ def test_sample_with_sampling_method(session): ) assert len(df.sample(frac=1.0, sampling_method="BLOCK").collect()) == row_count finally: - Utils.drop_table(session, temp_table_name) + df.drop_table() +@pytest.mark.localtest def test_sample_negative(session): """Tests negative test cases for sample""" row_count = 10000 @@ -889,6 +952,7 @@ def test_sample_negative(session): table.sample(frac=0.1, sampling_method="InvalidValue") +@pytest.mark.localtest def test_sample_on_join(session): """Tests running sample on a join statement""" row_count = 10000 @@ -905,6 +969,7 @@ def test_sample_on_join(session): ) +@pytest.mark.localtest @pytest.mark.skipif(IS_IN_STORED_PROC_LOCALFS, reason="Large result") def test_sample_on_union(session): """Tests running sample on union statements""" @@ -930,6 +995,7 @@ def test_sample_on_union(session): ) +@pytest.mark.localtest def test_toDf(session): # to_df(*str) with 1 column df1 = session.create_dataframe([1, 2, 3]).to_df("a") @@ -939,6 +1005,17 @@ def test_toDf(session): and df1.schema.fields[0].name == "A" ) df1.show() + assert ( + df1._show_string() + == """ +------- +|"A" | +------- +|1 | +|2 | +|3 | +-------\n""".lstrip() + ) # to_df([str]) with 1 column df2 = session.create_dataframe([1, 2, 3]).to_df(["a"]) assert ( @@ -947,6 +1024,17 @@ def test_toDf(session): and df2.schema.fields[0].name == "A" ) df2.show() + assert ( + df2._show_string() + == """ +------- +|"A" | +------- +|1 | +|2 | +|3 | +-------\n""".lstrip() + ) # to_df(*str) with 2 columns df3 = session.create_dataframe([(1, None), (2, "NotNull"), (3, None)]).to_df( @@ -975,6 +1063,7 @@ def test_toDf(session): assert df6.schema.fields[0].name == "A" and df6.schema.fields[-1].name == "C" +@pytest.mark.localtest def test_toDF_negative_test(session): values = session.create_dataframe([[1, None], [2, "NotNull"], [3, None]]) @@ -1001,6 +1090,7 @@ def test_toDF_negative_test(session): assert "The number of columns doesn't match" in ex_info.value.args[0] +@pytest.mark.localtest def test_sort(session): df = session.create_dataframe( [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3), (3, 1), (3, 2), (3, 3)] @@ -1046,6 +1136,7 @@ def test_sort(session): assert "sort() needs at least one sort expression" in ex_info.value.args[0] +@pytest.mark.localtest def 
test_select(session): df = session.create_dataframe([(1, "a", 10), (2, "b", 20), (3, "c", 30)]).to_df( ["a", "b", "c"] @@ -1111,6 +1202,7 @@ def test_select_negative_select(session): assert "SQL compilation error" in str(ex_info) +@pytest.mark.localtest def test_drop_and_dropcolumns(session): df = session.create_dataframe([(1, "a", 10), (2, "b", 20), (3, "c", 30)]).to_df( ["a", "b", "c"] @@ -1163,6 +1255,7 @@ def test_drop_and_dropcolumns(session): assert "Cannot drop all column" in str(ex_info) +@pytest.mark.localtest def test_dataframe_agg(session): df = session.create_dataframe([(1, "One"), (2, "Two"), (3, "Three")]).to_df( "empid", "name" @@ -1266,6 +1359,7 @@ def test_rollup(session): ) +@pytest.mark.localtest def test_groupby(session): df = session.create_dataframe( [ @@ -1597,6 +1691,7 @@ def test_createDataFrame_with_given_schema(session): Utils.check_answer(result, data, sort=False) +@pytest.mark.localtest def test_createDataFrame_with_given_schema_time(session): schema = StructType( [ @@ -1660,6 +1755,9 @@ def test_createDataFrame_with_given_schema_timestamp(session): Utils.check_answer(df, expected, sort=False) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="sql use is not supported" +) @pytest.mark.skipif(IS_IN_STORED_PROC, reason="need to support PUT/GET command") def test_show_collect_with_misc_commands(session, resources_path, tmpdir): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) @@ -1758,7 +1856,8 @@ def test_createDataFrame_with_given_schema_array_map_variant(session): Utils.check_answer(df, expected, sort=False) -def test_variant_in_array_and_map(session): +@pytest.mark.localtest +def test_variant_in_array_and_map(session, local_testing_mode): schema = StructType( [StructField("array", ArrayType(None)), StructField("map", MapType(None, None))] ) @@ -1767,6 +1866,7 @@ def test_variant_in_array_and_map(session): Utils.check_answer(df, [Row('[\n 1,\n "\\"\'"\n]', '{\n "a": "\\"\'"\n}')]) +@pytest.mark.localtest def test_escaped_character(session): df = session.create_dataframe(["'", "\\", "\n"]).to_df("a") res = df.collect() @@ -1777,36 +1877,39 @@ def test_escaped_character(session): IS_IN_STORED_PROC, reason="creating new sessions within stored proc is not supported", ) -def test_create_or_replace_temporary_view(session, db_parameters): +@pytest.mark.localtest +def test_create_or_replace_temporary_view(session, db_parameters, local_testing_mode): view_name = Utils.random_name_for_temp_object(TempObjectType.VIEW) view_name1 = f'"{view_name}%^11"' view_name2 = f'"{view_name}"' - try: - df = session.create_dataframe([1, 2, 3]).to_df("a") - df.create_or_replace_temp_view(view_name) - res = session.table(view_name).collect() - res.sort(key=lambda x: x[0]) - assert res == [Row(1), Row(2), Row(3)] - - # test replace - df2 = session.create_dataframe(["a", "b", "c"]).to_df("b") - df2.create_or_replace_temp_view(view_name) - res = session.table(view_name).collect() - assert res == [Row("a"), Row("b"), Row("c")] - - # view name has special char - df.create_or_replace_temp_view(view_name1) - res = session.table(view_name1).collect() - res.sort(key=lambda x: x[0]) - assert res == [Row(1), Row(2), Row(3)] - - # view name has quote - df.create_or_replace_temp_view(view_name2) - res = session.table(view_name2).collect() - res.sort(key=lambda x: x[0]) - assert res == [Row(1), Row(2), Row(3)] - + df = session.create_dataframe([1, 2, 3]).to_df("a") + df.create_or_replace_temp_view(view_name) + res = session.table(view_name).collect() 
+ res.sort(key=lambda x: x[0]) + assert res == [Row(1), Row(2), Row(3)] + + # test replace + df2 = session.create_dataframe(["a", "b", "c"]).to_df("b") + df2.create_or_replace_temp_view(view_name) + res = session.table(view_name).collect() + assert res == [Row("a"), Row("b"), Row("c")] + + # view name has special char + df.create_or_replace_temp_view(view_name1) + res = session.table(view_name1).collect() + res.sort(key=lambda x: x[0]) + assert res == [Row(1), Row(2), Row(3)] + + # view name has quote + df.create_or_replace_temp_view(view_name2) + res = session.table(view_name2).collect() + res.sort(key=lambda x: x[0]) + assert res == [Row(1), Row(2), Row(3)] + + if ( + not local_testing_mode + ): # Having multiple sessions are not supported, Local Testing doesn't maintain states across sessions # Get a second session object session2 = Session.builder.configs(db_parameters).create() session2.sql_simplifier_enabled = session.sql_simplifier_enabled @@ -1814,13 +1917,10 @@ def test_create_or_replace_temporary_view(session, db_parameters): assert session is not session2 with pytest.raises(SnowparkSQLException) as ex_info: session2.table(view_name).collect() - assert "does not exist or not authorized" in str(ex_info) - finally: - Utils.drop_view(session, view_name) - Utils.drop_view(session, view_name1) - Utils.drop_view(session, view_name2) + assert "does not exist or not authorized" in str(ex_info) +@pytest.mark.localtest def test_createDataFrame_with_schema_inference(session): df1 = session.create_dataframe([1, 2, 3]).to_df("int") Utils.check_answer(df1, [Row(1), Row(2), Row(3)]) @@ -1836,6 +1936,7 @@ def test_createDataFrame_with_schema_inference(session): Utils.check_answer(df2, [Row(True, "a"), Row(False, "b")], False) +@pytest.mark.localtest def test_create_nullable_dataframe_with_schema_inference(session): df = session.create_dataframe([(1, 1, None), (2, 3, True)]).to_df("a", "b", "c") assert ( @@ -1846,6 +1947,7 @@ def test_create_nullable_dataframe_with_schema_inference(session): Utils.check_answer(df, [Row(1, 1, None), Row(2, 3, True)]) +@pytest.mark.localtest def test_schema_inference_binary_type(session): df = session.create_dataframe( [ @@ -1860,12 +1962,16 @@ def test_schema_inference_binary_type(session): ) -def test_primitive_array(session): +@pytest.mark.localtest +def test_primitive_array(session, local_testing_mode): schema = StructType([StructField("arr", ArrayType(None))]) df = session.create_dataframe([Row([1])], schema) Utils.check_answer(df, Row("[\n 1\n]")) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="sql use is not supported" +) def test_time_date_and_timestamp_test(session): assert str(session.sql("select '00:00:00' :: Time").collect()[0][0]) == "00:00:00" assert ( @@ -1875,7 +1981,8 @@ def test_time_date_and_timestamp_test(session): assert str(session.sql("select '1970-1-1' :: Date").collect()[0][0]) == "1970-01-01" -def test_quoted_column_names(session): +@pytest.mark.localtest +def test_quoted_column_names(session, local_testing_mode): normalName = "NORMAL_NAME" lowerCaseName = '"lower_case"' quoteStart = '"""quote_start"' @@ -1885,13 +1992,26 @@ def test_quoted_column_names(session): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) try: - Utils.create_table( - session, - table_name, - f"{normalName} int, {lowerCaseName} int, {quoteStart} int," - f"{quoteEnd} int, {quoteMiddle} int, {quoteAllCases} int", - ) - session.sql(f"insert into {table_name} values(1, 2, 3, 4, 5, 6)").collect() + if not 
local_testing_mode: + Utils.create_table( + session, + table_name, + f"{normalName} int, {lowerCaseName} int, {quoteStart} int," + f"{quoteEnd} int, {quoteMiddle} int, {quoteAllCases} int", + ) + session.sql(f"insert into {table_name} values(1, 2, 3, 4, 5, 6)").collect() + else: + session.create_dataframe( + [[1, 2, 3, 4, 5, 6]], + schema=[ + normalName, + lowerCaseName, + quoteStart, + quoteEnd, + quoteMiddle, + quoteAllCases, + ], + ).write.save_as_table(table_name) # test select() df1 = session.table(table_name).select( @@ -1964,10 +2084,12 @@ def test_quoted_column_names(session): assert df4.collect() == [Row(1)] finally: - Utils.drop_table(session, table_name) + if not local_testing_mode: + Utils.drop_table(session, table_name) -def test_column_names_without_surrounding_quote(session): +@pytest.mark.localtest +def test_column_names_without_surrounding_quote(session, local_testing_mode): normalName = "NORMAL_NAME" lowerCaseName = '"lower_case"' quoteStart = '"""quote_start"' @@ -1977,13 +2099,26 @@ def test_column_names_without_surrounding_quote(session): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) try: - Utils.create_table( - session, - table_name, - f"{normalName} int, {lowerCaseName} int, {quoteStart} int," - f"{quoteEnd} int, {quoteMiddle} int, {quoteAllCases} int", - ) - session.sql(f"insert into {table_name} values(1, 2, 3, 4, 5, 6)").collect() + if not local_testing_mode: + Utils.create_table( + session, + table_name, + f"{normalName} int, {lowerCaseName} int, {quoteStart} int," + f"{quoteEnd} int, {quoteMiddle} int, {quoteAllCases} int", + ) + session.sql(f"insert into {table_name} values(1, 2, 3, 4, 5, 6)").collect() + else: + session.create_dataframe( + [[1, 2, 3, 4, 5, 6]], + schema=[ + normalName, + lowerCaseName, + quoteStart, + quoteEnd, + quoteMiddle, + quoteAllCases, + ], + ).write.save_as_table(table_name) quoteStart2 = '"quote_start' quoteEnd2 = 'quote_end"' @@ -2001,9 +2136,11 @@ def test_column_names_without_surrounding_quote(session): assert df1.collect() == [Row(3, 4, 5)] finally: - Utils.drop_table(session, table_name) + if not local_testing_mode: + Utils.drop_table(session, table_name) +@pytest.mark.localtest def test_negative_test_for_user_input_invalid_quoted_name(session): df = session.create_dataframe([1, 2, 3]).to_df("a") with pytest.raises(SnowparkPlanException) as ex_info: @@ -2011,12 +2148,18 @@ def test_negative_test_for_user_input_invalid_quoted_name(session): assert "Invalid identifier" in str(ex_info) -def test_clone_with_union_dataframe(session): +@pytest.mark.localtest +def test_clone_with_union_dataframe(session, local_testing_mode): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) try: - Utils.create_table(session, table_name, "c1 int, c2 int") + if not local_testing_mode: + Utils.create_table(session, table_name, "c1 int, c2 int") + session.sql(f"insert into {table_name} values(1, 1),(2, 2)").collect() + else: + session.create_dataframe( + [[1, 1], [2, 2]], schema=["c1", "c2"] + ).write.save_as_table(table_name) - session.sql(f"insert into {table_name} values(1, 1),(2, 2)").collect() df = session.table(table_name) union_df = df.union(df) @@ -2025,15 +2168,22 @@ def test_clone_with_union_dataframe(session): res.sort(key=lambda x: x[0]) assert res == [Row(1, 1), Row(2, 2)] finally: - Utils.drop_table(session, table_name) + if not local_testing_mode: + Utils.drop_table(session, table_name) -def test_clone_with_unionall_dataframe(session): +@pytest.mark.localtest +def 
test_clone_with_unionall_dataframe(session, local_testing_mode): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) try: - Utils.create_table(session, table_name, "c1 int, c2 int") + if not local_testing_mode: + Utils.create_table(session, table_name, "c1 int, c2 int") + session.sql(f"insert into {table_name} values(1, 1),(2, 2)").collect() + else: + session.create_dataframe( + [[1, 1], [2, 2]], schema=["c1", "c2"] + ).write.save_as_table(table_name) - session.sql(f"insert into {table_name} values(1, 1),(2, 2)").collect() df = session.table(table_name) union_df = df.union_all(df) @@ -2042,9 +2192,11 @@ def test_clone_with_unionall_dataframe(session): res.sort(key=lambda x: x[0]) assert res == [Row(1, 1), Row(1, 1), Row(2, 2), Row(2, 2)] finally: - Utils.drop_table(session, table_name) + if not local_testing_mode: + Utils.drop_table(session, table_name) +@pytest.mark.localtest def test_dataframe_show_with_new_line(session): df = session.create_dataframe( ["line1\nline1.1\n", "line2", "\n", "line4", "\n\n", None] @@ -2100,6 +2252,7 @@ def test_dataframe_show_with_new_line(session): ) +@pytest.mark.localtest def test_negative_test_to_input_invalid_table_name_for_saveAsTable(session): df = session.create_dataframe([(1, None), (2, "NotNull"), (3, None)]).to_df( "a", "b" @@ -2109,6 +2262,7 @@ def test_negative_test_to_input_invalid_table_name_for_saveAsTable(session): assert re.compile("The object name .* is invalid.").match(ex_info.value.message) +@pytest.mark.localtest def test_negative_test_to_input_invalid_view_name_for_createOrReplaceView(session): df = session.create_dataframe([[2, "NotNull"]]).to_df(["a", "b"]) with pytest.raises(SnowparkInvalidObjectNameException) as ex_info: @@ -2116,6 +2270,7 @@ def test_negative_test_to_input_invalid_view_name_for_createOrReplaceView(sessio assert re.compile("The object name .* is invalid.").match(ex_info.value.message) +@pytest.mark.localtest def test_toDF_with_array_schema(session): df = session.create_dataframe([[1, "a"]]).to_df("a", "b") schema = df.schema @@ -2124,6 +2279,7 @@ def test_toDF_with_array_schema(session): assert schema.fields[1].name == "B" +@pytest.mark.localtest def test_sort_with_array_arg(session): df = session.create_dataframe([(1, 1, 1), (2, 0, 4), (1, 2, 3)]).to_df( "col1", "col2", "col3" @@ -2132,28 +2288,33 @@ def test_sort_with_array_arg(session): Utils.check_answer(df_sorted, [Row(1, 2, 3), Row(1, 1, 1), Row(2, 0, 4)], False) +@pytest.mark.localtest def test_select_with_array_args(session): df = session.create_dataframe([[1, 2]]).to_df("col1", "col2") df_selected = df.select(df.col("col1"), lit("abc"), df.col("col1") + df.col("col2")) Utils.check_answer(df_selected, Row(1, "abc", 3)) +@pytest.mark.localtest def test_select_string_with_array_args(session): df = session.create_dataframe([[1, 2, 3]]).to_df("col1", "col2", "col3") df_selected = df.select(["col1", "col2"]) Utils.check_answer(df_selected, [Row(1, 2)]) +@pytest.mark.localtest def test_drop_string_with_array_args(session): df = session.create_dataframe([[1, 2, 3]]).to_df("col1", "col2", "col3") Utils.check_answer(df.drop(["col3"]), [Row(1, 2)]) +@pytest.mark.localtest def test_drop_with_array_args(session): df = session.create_dataframe([[1, 2, 3]]).to_df("col1", "col2", "col3") Utils.check_answer(df.drop([df["col3"]]), [Row(1, 2)]) +@pytest.mark.localtest def test_agg_with_array_args(session): df = session.create_dataframe([[1, 2], [4, 5]]).to_df("col1", "col2") Utils.check_answer(df.agg([max(col("col1")), mean(col("col2"))]), [Row(4, 3.5)]) @@ 
-2225,6 +2386,7 @@ def test_rollup_string_with_array_args(session): ) +@pytest.mark.localtest def test_groupby_with_array_args(session): df = session.create_dataframe( [ @@ -2251,6 +2413,7 @@ def test_groupby_with_array_args(session): ) +@pytest.mark.localtest def test_groupby_string_with_array_args(session): df = session.create_dataframe( [ @@ -2277,6 +2440,7 @@ def test_groupby_string_with_array_args(session): ) +@pytest.mark.localtest def test_rename_basic(session): df = session.create_dataframe([[1, 2]], schema=["a", "b"]) df2 = df.with_column_renamed("b", "b1") @@ -2341,7 +2505,8 @@ def test_rename_to_df_and_joined_dataframe(session): Utils.check_answer(df5, [Row(1, 2, 1, 2)]) -def test_rename_negative_test(session): +@pytest.mark.localtest +def test_rename_negative_test(session, local_testing_mode): df = session.create_dataframe([[1, 2]], schema=["a", "b"]) # rename un-qualified column @@ -2359,13 +2524,14 @@ def test_rename_negative_test(session): in str(exec_info) ) - df2 = session.sql("select 1 as A, 2 as A, 3 as A") - with pytest.raises(SnowparkColumnException) as col_exec_info: - df2.rename("A", "B") - assert ( - 'Unable to rename the column "A" as "B" because this DataFrame has 3 columns named "A".' - in str(col_exec_info) - ) + if not local_testing_mode: + df2 = session.sql("select 1 as A, 2 as A, 3 as A") + with pytest.raises(SnowparkColumnException) as col_exec_info: + df2.rename("A", "B") + assert ( + 'Unable to rename the column "A" as "B" because this DataFrame has 3 columns named "A".' + in str(col_exec_info) + ) # If single parameter, it has to be dict with pytest.raises(ValueError) as exec_info: @@ -2419,6 +2585,7 @@ def test_with_columns_keep_order(session): ) +@pytest.mark.localtest def test_with_columns_input_doesnt_match_each_other(session): df = session.create_dataframe([Row(1, 2, 3)]).to_df(["a", "b", "c"]) with pytest.raises(ValueError) as ex_info: @@ -2429,6 +2596,7 @@ def test_with_columns_input_doesnt_match_each_other(session): ) +@pytest.mark.localtest def test_with_columns_replace_existing(session): df = session.create_dataframe([Row(1, 2, 3)]).to_df(["a", "b", "c"]) replaced = df.with_columns(["b", "d"], [lit(5), lit(6)]) @@ -2449,6 +2617,7 @@ def test_with_columns_replace_existing(session): ) +@pytest.mark.localtest def test_drop_duplicates(session): df = session.create_dataframe( [[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4]], @@ -2494,6 +2663,7 @@ def test_drop_duplicates(session): assert "The DataFrame does not contain the column named e." 
in str(exec_info) +@pytest.mark.localtest def test_consecutively_drop_duplicates(session): df = session.create_dataframe( [[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4]], @@ -2511,36 +2681,58 @@ def test_consecutively_drop_duplicates(session): assert row1 in [Row(1, 1, 1, 1), Row(1, 1, 1, 2), Row(1, 1, 2, 3), Row(1, 2, 3, 4)] -def test_dropna(session): +@pytest.mark.local +def test_dropna(session, local_testing_mode): + Utils.check_answer( - TestData.double3(session).na.drop(thresh=1, subset=["a"]), + TestData.double3(session, local_testing_mode).na.drop(thresh=1, subset=["a"]), [Row(1.0, 1), Row(4.0, None)], ) - res = TestData.double3(session).na.drop(thresh=1, subset=["a", "b"]).collect() + res = ( + TestData.double3(session, local_testing_mode) + .na.drop(thresh=1, subset=["a", "b"]) + .collect() + ) assert res[0] == Row(1.0, 1) assert math.isnan(res[1][0]) assert res[1][1] == 2 assert res[2] == Row(None, 3) assert res[3] == Row(4.0, None) - assert TestData.double3(session).na.drop(thresh=0, subset=["a"]).count() == 6 - assert TestData.double3(session).na.drop(thresh=3, subset=["a", "b"]).count() == 0 - assert TestData.double3(session).na.drop(thresh=1, subset=[]).count() == 6 + assert ( + TestData.double3(session, local_testing_mode) + .na.drop(thresh=0, subset=["a"]) + .count() + == 6 + ) + assert ( + TestData.double3(session, local_testing_mode) + .na.drop(thresh=3, subset=["a", "b"]) + .count() + == 0 + ) + assert ( + TestData.double3(session, local_testing_mode) + .na.drop(thresh=1, subset=[]) + .count() + == 6 + ) # wrong column name with pytest.raises(SnowparkColumnException) as ex_info: - TestData.double3(session).na.drop(thresh=1, subset=["c"]) + TestData.double3(session, local_testing_mode).na.drop(thresh=1, subset=["c"]) assert "The DataFrame does not contain the column named" in str(ex_info) with pytest.raises(ValueError) as exc_info: - TestData.double3(session).na.drop(how="bad") + TestData.double3(session, local_testing_mode).na.drop(how="bad") assert "how ('bad') should be 'any' or 'all'" in str(exc_info) -def test_fillna(session): +@pytest.mark.localtest +def test_fillna(session, local_testing_mode): Utils.check_answer( - TestData.null_data3(session).na.fill( + TestData.null_data3(session, local_testing_mode).na.fill( {"flo": 12.3, "int": 11, "boo": False, "str": "f"} ), [ @@ -2554,7 +2746,7 @@ def test_fillna(session): sort=False, ) Utils.check_answer( - TestData.null_data3(session).na.fill( + TestData.null_data3(session, local_testing_mode).na.fill( {"flo": 22.3, "int": 22, "boo": False, "str": "f"} ), [ @@ -2569,7 +2761,7 @@ def test_fillna(session): ) # wrong type Utils.check_answer( - TestData.null_data3(session).na.fill( + TestData.null_data3(session, local_testing_mode).na.fill( {"flo": 12.3, "int": "11", "boo": False, "str": 1} ), [ @@ -2584,14 +2776,15 @@ def test_fillna(session): ) # wrong column name with pytest.raises(SnowparkColumnException) as ex_info: - TestData.null_data3(session).na.fill({"wrong": 11}) + TestData.null_data3(session, local_testing_mode).na.fill({"wrong": 11}) assert "The DataFrame does not contain the column named" in str(ex_info) -def test_replace(session): +@pytest.mark.localtest +def test_replace(session, local_testing_mode): res = ( - TestData.null_data3(session) - .na.replace({2: 300, 1: 200}, subset=["flo"]) + TestData.null_data3(session, local_testing_mode) + .na.replace({2: 300.0, 1: 200.0}, subset=["flo"]) .collect() ) assert res[0] == Row(200.0, 1, True, "a") @@ -2607,7 +2800,9 @@ def 
test_replace(session): # replace null res = ( - TestData.null_data3(session).na.replace({None: True}, subset=["boo"]).collect() + TestData.null_data3(session, local_testing_mode) + .na.replace({None: True}, subset=["boo"]) + .collect() ) assert res[0] == Row(1.0, 1, True, "a") assert math.isnan(res[1][0]) @@ -2622,21 +2817,25 @@ def test_replace(session): # replace NaN Utils.check_answer( - TestData.null_data3(session).na.replace({float("nan"): 11}, subset=["flo"]), + TestData.null_data3(session, local_testing_mode).na.replace( + {float("nan"): 11.0}, subset=["flo"] + ), [ Row(1.0, 1, True, "a"), - Row(11, 2, None, "b"), + Row(11.0, 2, None, "b"), Row(None, 3, False, None), Row(4.0, None, None, "d"), Row(None, None, None, None), - Row(11, None, None, None), + Row(11.0, None, None, None), ], sort=False, ) # incompatible type (skip that replacement and do nothing) res = ( - TestData.null_data3(session).na.replace({None: "aa"}, subset=["flo"]).collect() + TestData.null_data3(session, local_testing_mode) + .na.replace({None: "aa"}, subset=["flo"]) + .collect() ) assert res[0] == Row(1.0, 1, True, "a") assert math.isnan(res[1][0]) @@ -2651,7 +2850,9 @@ def test_replace(session): # replace NaN with None Utils.check_answer( - TestData.null_data3(session).na.replace({float("nan"): None}, subset=["flo"]), + TestData.null_data3(session, local_testing_mode).na.replace( + {float("nan"): None}, subset=["flo"] + ), [ Row(1.0, 1, True, "a"), Row(None, 2, None, "b"), @@ -2685,6 +2886,7 @@ def test_explain(session): assert "Logical Execution Plan" not in explain_string +@pytest.mark.localtest def test_to_local_iterator(session): df = session.create_dataframe([1, 2, 3]).toDF("a") iterator = df.to_local_iterator() @@ -2736,6 +2938,7 @@ def check_random_split_result(weights, seed=None): check_random_split_result([0.11111, 0.6666, 1.3]) +@pytest.mark.localtest def test_random_split_negative(session): df1 = session.range(10) @@ -2752,6 +2955,7 @@ def test_random_split_negative(session): assert "weights must be positive numbers" in str(ex_info) +@pytest.mark.localtest def test_to_df(session): df = session.create_dataframe( [[1], [3], [5], [7], [9]], diff --git a/tests/integ/scala/test_dataframe_writer_suite.py b/tests/integ/scala/test_dataframe_writer_suite.py index 03bd6e1c7c3..fb68108797b 100644 --- a/tests/integ/scala/test_dataframe_writer_suite.py +++ b/tests/integ/scala/test_dataframe_writer_suite.py @@ -19,9 +19,22 @@ from tests.utils import TestFiles, Utils +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_write_with_target_column_name_order(session): table_name = Utils.random_table_name() - Utils.create_table(session, table_name, "a int, b int", is_temporary=True) + session.create_dataframe( + [], + schema=StructType( + [ + StructField("a", IntegerType()), + StructField("b", IntegerType()), + ] + ), + ).write.save_as_table(table_name, table_type="temporary") try: df1 = session.create_dataframe([[1, 2]], schema=["b", "a"]) @@ -48,7 +61,7 @@ def test_write_with_target_column_name_order(session): df1.write.saveAsTable(table_name, mode="append", column_order="name") Utils.check_answer(session.table(table_name), [Row(1, 2)]) finally: - Utils.drop_table(session, table_name) + session.table(table_name).drop_table() # column name and table name with special characters special_table_name = '"test table name"' @@ -65,6 +78,9 @@ def test_write_with_target_column_name_order(session): Utils.drop_table(session, special_table_name) 
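The writer-suite hunks above swap `Utils.create_table` for an empty DataFrame with an explicit `StructType` schema, then write with `column_order="name"` so columns are matched by name rather than position. A minimal sketch of that flow, assuming an active `session` and a hypothetical table name:

from snowflake.snowpark.types import IntegerType, StructField, StructType

target = "EXAMPLE_TARGET"  # hypothetical
schema = StructType([StructField("a", IntegerType()), StructField("b", IntegerType())])

# Create the (empty) target table from the schema; works in both execution modes.
session.create_dataframe([], schema=schema).write.save_as_table(
    target, table_type="temporary"
)

# Source columns arrive as (b, a); matching by name writes them into the right slots.
session.create_dataframe([[1, 2]], schema=["b", "a"]).write.save_as_table(
    target, mode="append", column_order="name"
)

session.table(target).drop_table()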
+@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="Testing SQL-only feature" +) def test_write_with_target_table_autoincrement( session, ): # Scala doesn't support this yet. @@ -82,9 +98,19 @@ def test_write_with_target_table_autoincrement( Utils.drop_table(session, table_name) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_negative_write_with_target_column_name_order(session): table_name = Utils.random_table_name() - Utils.create_table(session, table_name, "a int, b int", is_temporary=True) + session.create_dataframe( + [], + schema=StructType( + [StructField("a", IntegerType()), StructField("b", IntegerType())] + ), + ).write.save_as_table(table_name, table_type="temporary") try: df1 = session.create_dataframe([[1, 2]], schema=["a", "c"]) # The "columnOrder = name" needs the DataFrame has the same column name set @@ -104,14 +130,24 @@ def test_negative_write_with_target_column_name_order(session): table_type="temp", ) finally: - Utils.drop_table(session, table_name) + session.table(table_name).drop_table() +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_write_with_target_column_name_order_all_kinds_of_dataframes( session, resources_path ): table_name = Utils.random_table_name() - Utils.create_table(session, table_name, "a int, b int", is_temporary=True) + session.create_dataframe( + [], + schema=StructType( + [StructField("a", IntegerType()), StructField("b", IntegerType())] + ), + ).write.save_as_table(table_name, table_type="temporary") try: df1 = session.create_dataframe([[1, 2]], schema=["b", "a"]) # DataFrame.cache_result() @@ -140,7 +176,7 @@ def test_write_with_target_column_name_order_all_kinds_of_dataframes( for row in rows: assert row["B"] == 1 and row["A"] == 2 finally: - Utils.drop_table(session, table_name) + session.table(table_name).drop_table() # show tables # Create a DataFrame from SQL `show tables` and then filter on it not supported yet. Enable the following test after it's supported. 
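For tests or whole modules that local testing does not implement yet, the diff uses `xfail(..., raises=NotImplementedError, strict=True)` rather than a plain skip. A minimal, self-contained illustration (not taken from the repo) of what `strict=True` buys: the test is recorded as an expected failure while the feature raises, and becomes a hard failure the moment it starts passing, prompting removal of the marker.

import pytest

def emulator_feature():
    # Hypothetical stand-in for functionality the emulator does not support yet.
    raise NotImplementedError("not supported in local testing yet")

@pytest.mark.xfail(raises=NotImplementedError, strict=True)
def test_emulator_feature():
    # Reported as xfail while NotImplementedError is raised; reported as a
    # failure (unexpected pass) once the feature works, thanks to strict=True.
    emulator_feature()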
diff --git a/tests/integ/scala/test_datatype_suite.py b/tests/integ/scala/test_datatype_suite.py index ab1a4d52fd4..2feddc4ac63 100644 --- a/tests/integ/scala/test_datatype_suite.py +++ b/tests/integ/scala/test_datatype_suite.py @@ -5,6 +5,8 @@ # Many of the tests have been moved to unit/scala/test_datattype_suite.py from decimal import Decimal +import pytest + from snowflake.snowpark import Row from snowflake.snowpark.functions import lit from snowflake.snowpark.types import ( @@ -33,6 +35,11 @@ from tests.utils import Utils +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_verify_datatypes_reference(session): schema = StructType( [ @@ -108,6 +115,7 @@ def test_verify_datatypes_reference(session): Utils.is_schema_same(df.schema, expected_schema, case_sensitive=False) +@pytest.mark.localtest def test_verify_datatypes_reference2(session): d1 = DecimalType(2, 1) d2 = DecimalType(2, 1) @@ -126,6 +134,11 @@ def test_verify_datatypes_reference2(session): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_dtypes(session): schema = StructType( [ diff --git a/tests/integ/scala/test_file_operation_suite.py b/tests/integ/scala/test_file_operation_suite.py index fa44388cdda..5aeb9faf4c8 100644 --- a/tests/integ/scala/test_file_operation_suite.py +++ b/tests/integ/scala/test_file_operation_suite.py @@ -16,6 +16,12 @@ ) from tests.utils import IS_IN_STORED_PROC, TestFiles, Utils +pytestmark = pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) + def random_alphanumeric_name(): return "".join( @@ -75,16 +81,18 @@ def path4(temp_source_directory): @pytest.fixture(scope="module") -def temp_stage(session, resources_path): +def temp_stage(session, resources_path, local_testing_mode): tmp_stage_name = Utils.random_stage_name() test_files = TestFiles(resources_path) - Utils.create_stage(session, tmp_stage_name, is_temporary=True) + if not local_testing_mode: + Utils.create_stage(session, tmp_stage_name, is_temporary=True) Utils.upload_to_stage( session, tmp_stage_name, test_files.test_file_parquet, compress=False ) yield tmp_stage_name - Utils.drop_stage(session, tmp_stage_name) + if not local_testing_mode: + Utils.drop_stage(session, tmp_stage_name) def test_put_with_one_file(session, temp_stage, path1, path2, path3): diff --git a/tests/integ/scala/test_function_suite.py b/tests/integ/scala/test_function_suite.py index e1bce923eba..5d68f40569d 100644 --- a/tests/integ/scala/test_function_suite.py +++ b/tests/integ/scala/test_function_suite.py @@ -198,6 +198,7 @@ from tests.utils import IS_IN_STORED_PROC, TestData, Utils +@pytest.mark.localtest def test_col(session): test_data1 = TestData.test_data1(session) Utils.check_answer(test_data1.select(col("bool")), [Row(True), Row(False)]) @@ -208,6 +209,7 @@ def test_col(session): Utils.check_answer(test_data1.select(col("num")), [Row(1), Row(2)]) +@pytest.mark.localtest def test_lit(session): res = TestData.test_data1(session).select(lit(1)).collect() assert res == [Row(1), Row(1)] @@ -397,6 +399,7 @@ def test_variance(session): ) +@pytest.mark.localtest def test_coalesce(session): Utils.check_answer( TestData.null_data2(session).select(coalesce(col("A"), col("B"), col("C"))), @@ -459,6 +462,7 @@ def test_sqrt(session): ) +@pytest.mark.localtest def test_abs(session): Utils.check_answer( TestData.number2(session).select(abs(col("X"))), 
[Row(1), Row(0), Row(5)], False @@ -538,6 +542,7 @@ def test_builtin_function(session): ) +@pytest.mark.localtest def test_sub_string(session): Utils.check_answer( TestData.string1(session).select(substring(col("A"), lit(2), lit(4))), @@ -1051,6 +1056,7 @@ def test_split(session): ) +@pytest.mark.localtest def test_contains(session): Utils.check_answer( TestData.string4(session).select(contains(col("a"), lit("app"))), @@ -1065,6 +1071,7 @@ def test_contains(session): ) +@pytest.mark.localtest @pytest.mark.parametrize("col_a", ["a", col("a")]) def test_startswith(session, col_a): Utils.check_answer( @@ -1074,6 +1081,7 @@ def test_startswith(session, col_a): ) +@pytest.mark.localtest @pytest.mark.parametrize("col_a", ["a", col("a")]) def test_endswith(session, col_a): Utils.check_answer( @@ -1184,6 +1192,7 @@ def test_json_extract_path_text(session): ) +@pytest.mark.localtest def test_parse_json(session): null_json1 = TestData.null_json1(session) Utils.check_answer( @@ -2339,6 +2348,7 @@ def test_time_from_parts(session): ) +@pytest.mark.localtest def test_columns_from_timestamp_parts(): func_name = "test _columns_from_timestamp_parts" y, m, d = _columns_from_timestamp_parts(func_name, "year", "month", 8) @@ -2357,11 +2367,13 @@ def test_columns_from_timestamp_parts(): assert s._expression.value == 17 +@pytest.mark.localtest def test_columns_from_timestamp_parts_negative(): with pytest.raises(ValueError, match="Incorrect number of args passed"): _columns_from_timestamp_parts("neg test", "year", "month") +@pytest.mark.localtest def test_timestamp_from_parts_internal(): func_name = "test _timestamp_from_parts_internal" date_expr, time_expr = _timestamp_from_parts_internal(func_name, "date", "time") @@ -2418,6 +2430,7 @@ def test_timestamp_from_parts_internal(): assert s._expression.name == '"S"' +@pytest.mark.localtest def test_timestamp_from_parts_internal_negative(): func_name = "negative test" with pytest.raises(ValueError, match="expected 2 or 6 required arguments"): @@ -2884,7 +2897,8 @@ def test_approx_percentile_combine(session, col_a, col_b): ) -def test_iff(session): +@pytest.mark.localtest +def test_iff(session, local_testing_mode): df = session.create_dataframe( [(True, 2, 2, 4), (False, 12, 12, 14), (True, 22, 23, 24)], schema=["a", "b", "c", "d"], @@ -2900,12 +2914,13 @@ def test_iff(session): sort=False, ) - # accept sql expression - Utils.check_answer( - df.select("b", "c", "d", iff("b = c", col("b"), col("d"))), - [Row(2, 2, 4, 2), Row(12, 12, 14, 12), Row(22, 23, 24, 24)], - sort=False, - ) + if not local_testing_mode: + # accept sql expression + Utils.check_answer( + df.select("b", "c", "d", iff("b = c", col("b"), col("d"))), + [Row(2, 2, 4, 2), Row(12, 12, 14, 12), Row(22, 23, 24, 24)], + sort=False, + ) def test_cume_dist(session): @@ -2926,14 +2941,15 @@ def test_dense_rank(session): ) +@pytest.mark.localtest @pytest.mark.parametrize("col_z", ["Z", col("Z")]) -def test_lag(session, col_z): +def test_lag(session, col_z, local_testing_mode): Utils.check_answer( TestData.xyz(session).select( lag(col_z, 1, 0).over(Window.partition_by(col("X")).order_by(col("X"))) ), [Row(0), Row(10), Row(1), Row(0), Row(1)], - sort=False, + sort=local_testing_mode, ) Utils.check_answer( @@ -2941,7 +2957,7 @@ def test_lag(session, col_z): lag(col_z, 1).over(Window.partition_by(col("X")).order_by(col("X"))) ), [Row(None), Row(10), Row(1), Row(None), Row(1)], - sort=False, + sort=local_testing_mode, ) Utils.check_answer( @@ -2949,18 +2965,19 @@ def test_lag(session, col_z): 
lag(col_z).over(Window.partition_by(col("X")).order_by(col("X"))) ), [Row(None), Row(10), Row(1), Row(None), Row(1)], - sort=False, + sort=local_testing_mode, ) +@pytest.mark.localtest @pytest.mark.parametrize("col_z", ["Z", col("Z")]) -def test_lead(session, col_z): +def test_lead(session, col_z, local_testing_mode): Utils.check_answer( TestData.xyz(session).select( lead(col_z, 1, 0).over(Window.partition_by(col("X")).order_by(col("X"))) ), [Row(1), Row(3), Row(0), Row(3), Row(0)], - sort=False, + sort=local_testing_mode, ) Utils.check_answer( @@ -2968,7 +2985,7 @@ def test_lead(session, col_z): lead(col_z, 1).over(Window.partition_by(col("X")).order_by(col("X"))) ), [Row(1), Row(3), Row(None), Row(3), Row(None)], - sort=False, + sort=local_testing_mode, ) Utils.check_answer( @@ -2976,10 +2993,11 @@ def test_lead(session, col_z): lead(col_z).over(Window.partition_by(col("X")).order_by(col("X"))) ), [Row(1), Row(3), Row(None), Row(3), Row(None)], - sort=False, + sort=local_testing_mode, ) +@pytest.mark.localtest @pytest.mark.parametrize("col_z", ["Z", col("Z")]) def test_last_value(session, col_z): Utils.check_answer( @@ -2991,6 +3009,7 @@ def test_last_value(session, col_z): ) +@pytest.mark.localtest @pytest.mark.parametrize("col_z", ["Z", col("Z")]) def test_first_value(session, col_z): Utils.check_answer( diff --git a/tests/integ/scala/test_large_dataframe_suite.py b/tests/integ/scala/test_large_dataframe_suite.py index 80d4b71818f..46323e469cf 100644 --- a/tests/integ/scala/test_large_dataframe_suite.py +++ b/tests/integ/scala/test_large_dataframe_suite.py @@ -91,6 +91,9 @@ def test_limit_on_order_by(session, is_sample_data_available): assert int(e1[0]) < int(e2[0]) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="Testing SQL generation" +) @pytest.mark.parametrize("use_scoped_temp_objects", [True, False]) def test_create_dataframe_for_large_values_check_plan(session, use_scoped_temp_objects): origin_use_scoped_temp_objects_setting = session._use_scoped_temp_objects @@ -120,6 +123,7 @@ def check_plan(df, data): session._use_scoped_temp_objects = origin_use_scoped_temp_objects_setting +@pytest.mark.localtest def test_create_dataframe_for_large_values_basic_types(session): schema = StructType( [ @@ -182,6 +186,7 @@ def test_create_dataframe_for_large_values_basic_types(session): assert df.sort("id").collect() == large_data +# TODO: enable for local testing after emulating sf data types def test_create_dataframe_for_large_values_array_map_variant(session): schema = StructType( [ diff --git a/tests/integ/scala/test_literal_suite.py b/tests/integ/scala/test_literal_suite.py index d8eab11cae3..f20990021cf 100644 --- a/tests/integ/scala/test_literal_suite.py +++ b/tests/integ/scala/test_literal_suite.py @@ -6,6 +6,8 @@ import json from decimal import Decimal +import pytest + from snowflake.snowpark import Column, Row from snowflake.snowpark._internal.analyzer.expression import Literal from snowflake.snowpark._internal.utils import PythonObjJSONEncoder @@ -22,6 +24,10 @@ ) from tests.utils import Utils +pytestmark = pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", raises=NotImplementedError +) + def test_literal_basic_types(session): df = ( diff --git a/tests/integ/scala/test_permanent_udf_suite.py b/tests/integ/scala/test_permanent_udf_suite.py index bb77a94284a..f55b6975a4a 100644 --- a/tests/integ/scala/test_permanent_udf_suite.py +++ b/tests/integ/scala/test_permanent_udf_suite.py @@ -19,7 +19,12 @@ from snowflake.snowpark.functions 
import call_udf, col from tests.utils import TempObjectType, TestFiles, Utils -pytestmark = pytest.mark.udf +pytestmark = [ + pytest.mark.udf, + pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + ), +] @pytest.fixture(scope="module") diff --git a/tests/integ/scala/test_query_tag_suite.py b/tests/integ/scala/test_query_tag_suite.py index 36c7478d0ac..bdb9beb6833 100644 --- a/tests/integ/scala/test_query_tag_suite.py +++ b/tests/integ/scala/test_query_tag_suite.py @@ -13,6 +13,10 @@ from snowflake.snowpark._internal.utils import TempObjectType from tests.utils import Utils +pytestmark = pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="usage of sql" +) + @pytest.mark.parametrize( "query_tag", diff --git a/tests/integ/scala/test_result_attributes_suite.py b/tests/integ/scala/test_result_attributes_suite.py index 4798ada8394..78b7e607874 100644 --- a/tests/integ/scala/test_result_attributes_suite.py +++ b/tests/integ/scala/test_result_attributes_suite.py @@ -25,6 +25,11 @@ ) from tests.utils import IS_IN_STORED_PROC, Utils +pytestmark = pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Testing session._get_result_attributes", +) + def get_table_attributes(session: Session, name: str) -> List[Attribute]: return session._get_result_attributes(f"select * from {name}") diff --git a/tests/integ/scala/test_result_schema_suite.py b/tests/integ/scala/test_result_schema_suite.py index 29fd66bddec..26599d7e607 100644 --- a/tests/integ/scala/test_result_schema_suite.py +++ b/tests/integ/scala/test_result_schema_suite.py @@ -22,6 +22,11 @@ tmp_full_types_table_name2 = Utils.random_name_for_temp_object(TempObjectType.TABLE) +pytestmark = pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="usage of sql" +) + + @pytest.fixture(scope="module", autouse=True) def setup(session): Utils.create_stage(session, tmp_stage_name, is_temporary=True) diff --git a/tests/integ/scala/test_session_suite.py b/tests/integ/scala/test_session_suite.py index 8848e6bb364..47d745b272a 100644 --- a/tests/integ/scala/test_session_suite.py +++ b/tests/integ/scala/test_session_suite.py @@ -26,6 +26,10 @@ from tests.utils import IS_IN_STORED_PROC, IS_IN_STORED_PROC_LOCALFS, Utils +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Testing session parameters", +) @pytest.mark.skipif( IS_IN_STORED_PROC, reason="creating new session is not allowed in stored proc" ) @@ -43,6 +47,10 @@ def test_invalid_configs(session, db_parameters): assert "Incorrect username or password was specified" in str(ex_info) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Testing database specific operations", +) @pytest.mark.skipif(IS_IN_STORED_PROC, reason="db_parameters is not available") def test_current_database_and_schema(session, db_parameters): database = quote_name(db_parameters["database"]) @@ -64,6 +72,7 @@ def test_current_database_and_schema(session, db_parameters): session._run_query(f"use schema {schema}") +@pytest.mark.localtest def test_quote_all_database_and_schema_names(session): def is_quoted(name: str) -> bool: return name[0] == '"' and name[-1] == '"' @@ -72,6 +81,7 @@ def is_quoted(name: str) -> bool: assert is_quoted(session.get_current_schema()) +@pytest.mark.localtest def test_create_dataframe_sequence(session): df = session.create_dataframe([[1, "one", 1.0], [2, "two", 2.0]]) assert [field.name for field in df.schema.fields] == ["_1", "_2", "_3"] @@ 
-87,6 +97,7 @@ def test_create_dataframe_sequence(session): assert df.collect() == [Row("one"), Row("two")] +@pytest.mark.localtest def test_create_dataframe_namedtuple(session): class P1(NamedTuple): a: int @@ -101,6 +112,10 @@ class P1(NamedTuple): # and the public role has the privilege to access the current database and # schema of the current role @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Not enough privilege to run this test") +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Testing database specific operations", +) def test_get_schema_database_works_after_use_role(session): current_role = session._conn._get_string_datum("select current_role()") try: @@ -116,6 +131,10 @@ def test_get_schema_database_works_after_use_role(session): @pytest.mark.skipif( IS_IN_STORED_PROC, reason="creating new session is not allowed in stored proc" ) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Testing database specific operations", +) def test_negative_test_for_missing_required_parameter_schema( db_parameters, sql_simplifier_enabled ): @@ -130,12 +149,17 @@ def test_negative_test_for_missing_required_parameter_schema( @pytest.mark.skipif(IS_IN_STORED_PROC, reason="client is regression test specific") +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Testing database specific operations", +) def test_select_current_client(session): current_client = session.sql("select current_client()")._show_string(10) assert get_application_name() in current_client assert get_version() in current_client +@pytest.mark.localtest def test_negative_test_to_invalid_table_name(session): with pytest.raises(SnowparkInvalidObjectNameException) as ex_info: session.table("negative.test.invalid.table.name") @@ -144,7 +168,8 @@ def test_negative_test_to_invalid_table_name(session): ) -def test_create_dataframe_from_seq_none(session): +@pytest.mark.localtest +def test_create_dataframe_from_seq_none(session, local_testing_mode): assert session.create_dataframe([None, 1]).to_df("int").collect() == [ Row(None), Row(1), @@ -155,6 +180,7 @@ def test_create_dataframe_from_seq_none(session): ] +# should be enabled after emulating snowflake types def test_create_dataframe_from_array(session): data = [Row(1, "a"), Row(2, "b")] schema = StructType( @@ -173,6 +199,10 @@ def test_create_dataframe_from_array(session): @pytest.mark.skipif( IS_IN_STORED_PROC, reason="creating new session is not allowed in stored proc" ) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Testing database specific operations", +) def test_dataframe_created_before_session_close_are_not_usable_after_closing_session( session, db_parameters ): @@ -190,18 +220,19 @@ def test_dataframe_created_before_session_close_are_not_usable_after_closing_ses assert ex_info.value.error_code == "1404" +@pytest.mark.localtest def test_load_table_from_array_multipart_identifier(session): name = Utils.random_name_for_temp_object(TempObjectType.TABLE) - try: - Utils.create_table(session, name, "col int") - db = session.get_current_database() - sc = session.get_current_schema() - multipart = [db, sc, name] - assert len(session.table(multipart).schema.fields) == 1 - finally: - Utils.drop_table(session, name) + session.create_dataframe( + [], schema=StructType([StructField("col", IntegerType())]) + ).write.save_as_table(name, table_type="temporary") + db = session.get_current_database() + sc = session.get_current_schema() + multipart = [db, sc, 
name] + assert len(session.table(multipart).schema.fields) == 1 +@pytest.mark.localtest def test_session_info(session): session_info = session._session_info assert get_version() in session_info @@ -213,6 +244,10 @@ def test_session_info(session): @pytest.mark.skipif( IS_IN_STORED_PROC, reason="creating new session is not allowed in stored proc" ) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Testing database specific operations", +) def test_dataframe_close_session(session, db_parameters): new_session = Session.builder.configs(db_parameters).create() new_session.sql_simplifier_enabled = session.sql_simplifier_enabled @@ -232,6 +267,10 @@ def test_dataframe_close_session(session, db_parameters): @pytest.mark.skipif(IS_IN_STORED_PROC_LOCALFS, reason="Large result") +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Testing database specific operations", +) def test_create_temp_table_no_commit(session): # test large local relation session.sql("begin").collect() diff --git a/tests/integ/scala/test_snowflake_plan_suite.py b/tests/integ/scala/test_snowflake_plan_suite.py index 1490c4c5bc0..873f6227f90 100644 --- a/tests/integ/scala/test_snowflake_plan_suite.py +++ b/tests/integ/scala/test_snowflake_plan_suite.py @@ -2,6 +2,8 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # +import pytest + from snowflake.snowpark import Row from snowflake.snowpark._internal.analyzer.analyzer_utils import schema_value_statement from snowflake.snowpark._internal.analyzer.expression import Attribute @@ -11,6 +13,11 @@ from snowflake.snowpark.types import IntegerType, LongType from tests.utils import Utils +pytestmark = pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="build plan not suitable for local testing", +) + def test_single_query(session): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) diff --git a/tests/integ/scala/test_sql_suite.py b/tests/integ/scala/test_sql_suite.py index aa77c53435b..2e6206e1936 100644 --- a/tests/integ/scala/test_sql_suite.py +++ b/tests/integ/scala/test_sql_suite.py @@ -13,6 +13,12 @@ from snowflake.snowpark.types import LongType, StructField, StructType from tests.utils import IS_IN_STORED_PROC, TestFiles, Utils +pytestmark = pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) + @pytest.mark.skipif( IS_IN_STORED_PROC, diff --git a/tests/integ/scala/test_table_function_suite.py b/tests/integ/scala/test_table_function_suite.py index 05f336118c5..bd191426976 100644 --- a/tests/integ/scala/test_table_function_suite.py +++ b/tests/integ/scala/test_table_function_suite.py @@ -2,12 +2,18 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
# +import pytest + from snowflake.snowpark import Row from snowflake.snowpark._internal.utils import TempObjectType from snowflake.snowpark.functions import array_agg, col, lit, parse_json from snowflake.snowpark.types import StructField, StructType, VariantType from tests.utils import Utils +pytestmark = pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", raises=NotImplementedError +) + def test_dataframe_join_table_function(session): df = session.create_dataframe(["[1,2]", "[3,4]"], schema=["a"]) diff --git a/tests/integ/scala/test_table_suite.py b/tests/integ/scala/test_table_suite.py index 3298de1b3a4..5fe4b6deba8 100644 --- a/tests/integ/scala/test_table_suite.py +++ b/tests/integ/scala/test_table_suite.py @@ -14,20 +14,29 @@ ArrayType, GeographyType, GeometryType, + IntegerType, MapType, StringType, + StructField, + StructType, + TimeType, VariantType, ) from tests.utils import Utils @pytest.fixture(scope="function") -def table_name_1(session: Session): +def table_name_1(session: Session, local_testing_mode: bool): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) - Utils.create_table(session, table_name, "num int") - session._run_query(f"insert into {table_name} values (1), (2), (3)") + if not local_testing_mode: + Utils.create_table(session, table_name, "num int") + session._run_query(f"insert into {table_name} values (1), (2), (3)") + else: + session.create_dataframe( + [[1], [2], [3]], schema=StructType([StructField("num", IntegerType())]) + ).write.save_as_table(table_name) yield table_name - Utils.drop_table(session, table_name) + session.table(table_name).drop_table() @pytest.fixture(scope="function") @@ -36,7 +45,7 @@ def table_name_4(session: Session): Utils.create_table(session, table_name, "num int") session._run_query(f"insert into {table_name} values (1), (2), (3)") yield table_name - Utils.drop_table(session, table_name) + session.table(table_name).drop_table() @pytest.fixture(scope="function") @@ -54,31 +63,43 @@ def semi_structured_table(session: Session): ) session._run_query(query) yield table_name - Utils.drop_table(session, table_name) + session.table(table_name).drop_table() @pytest.fixture(scope="function") -def temp_table_name(session: Session, temp_schema: str): +def temp_table_name(session: Session, temp_schema: str, local_testing_mode: bool): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) table_name_with_schema = f"{temp_schema}.{table_name}" - Utils.create_table(session, table_name_with_schema, "str string") - session._run_query(f"insert into {table_name_with_schema} values ('abc')") + if not local_testing_mode: + Utils.create_table(session, table_name_with_schema, "str string") + session._run_query(f"insert into {table_name_with_schema} values ('abc')") + else: + session.create_dataframe( + [["abc"]], schema=StructType([StructField("str", StringType())]) + ).write.saveAsTable(table_name_with_schema) yield table_name - Utils.drop_table(session, table_name_with_schema) + session.table(table_name_with_schema).drop_table() @pytest.fixture(scope="function") -def table_with_time(session: Session): +def table_with_time(session: Session, local_testing_mode: bool): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) - Utils.create_table(session, table_name, "time time") - session._run_query( - f"insert into {table_name} select to_time(a) from values('09:15:29')," - f"('09:15:29.99999999') as T(a)" - ) + if not local_testing_mode: + Utils.create_table(session, table_name, "time time") + 
session._run_query( + f"insert into {table_name} select to_time(a) from values('09:15:29')," + f"('09:15:29.99999999') as T(a)" + ) + else: + session.create_dataframe( + [[datetime.time(9, 15, 29)], [datetime.time(9, 15, 29, 999999)]], + schema=StructType([StructField("time", TimeType())]), + ).write.saveAsTable(table_name) yield table_name - Utils.drop_table(session, table_name) + session.table(table_name).drop_table() +@pytest.mark.localtest def test_read_snowflake_table(session, table_name_1): df = session.table(table_name_1) Utils.check_answer(df, [Row(1), Row(2), Row(3)]) @@ -94,7 +115,8 @@ def test_read_snowflake_table(session, table_name_1): Utils.check_answer(df3, [Row(1), Row(2), Row(3)]) -def test_save_as_snowflake_table(session, table_name_1): +@pytest.mark.localtest +def test_save_as_snowflake_table(session, table_name_1, local_testing_mode): df = session.table(table_name_1) assert df.collect() == [Row(1), Row(2), Row(3)] table_name_2 = Utils.random_name_for_temp_object(TempObjectType.TABLE) @@ -131,23 +153,12 @@ def test_save_as_snowflake_table(session, table_name_1): with pytest.raises(SnowparkSQLException): df.write.mode("errorifexists").save_as_table(table_name_2) finally: - Utils.drop_table(session, table_name_2) - Utils.drop_table(session, table_name_3) + session.table(table_name_2).drop_table() + session.table(table_name_3).drop_table() -@pytest.mark.skip( - "Python doesn't have non-string argument for mode. Scala has this test but python doesn't need to." -) -def test_save_as_snowflake_table_string_argument(table_name_4): - """ - Scala's `DataFrameWriter.mode()` accepts both enum values of SaveMode and str. - It's conventional that python uses str. - `test_save_as_snowflake_table` already tests the string argument. This test will be the same as - `test_save_as_snowflake_table` if ported from Scala so it's omitted. - """ - - -def test_multipart_identifier(session, table_name_1): +@pytest.mark.localtest +def test_multipart_identifier(session, table_name_1, local_testing_mode): name1 = table_name_1 name2 = session.get_current_schema() + "." + name1 name3 = session.get_current_database() + "." + name2 @@ -165,37 +176,41 @@ def test_multipart_identifier(session, table_name_1): try: assert session.table(name4).collect() == expected finally: - Utils.drop_table(session, name4) - + session.table(name4).drop_table() session.table(name1).write.mode("Overwrite").save_as_table(name5) try: assert session.table(name4).collect() == expected finally: - Utils.drop_table(session, name5) + session.table(name5).drop_table() session.table(name1).write.mode("Overwrite").save_as_table(name6) try: assert session.table(name6).collect() == expected finally: - Utils.drop_table(session, name5) + session.table(name6).drop_table() -def test_write_table_to_different_schema(session, temp_schema, table_name_1): +@pytest.mark.localtest +def test_write_table_to_different_schema( + session, temp_schema, table_name_1, local_testing_mode +): name1 = table_name_1 name2 = temp_schema + "." 
+ name1 session.table(name1).write.save_as_table(name2) try: assert session.table(name2).collect() == [Row(1), Row(2), Row(3)] finally: - Utils.drop_table(session, name2) + session.table(name2).drop_table() +@pytest.mark.localtest def test_read_from_different_schema(session, temp_schema, temp_table_name): table_from_different_schema = f"{temp_schema}.{temp_table_name}" df = session.table(table_from_different_schema) Utils.check_answer(df, [Row("abc")]) +@pytest.mark.localtest def test_quotes_upper_and_lower_case_name(session, table_name_1): tested_table_names = [ '"' + table_name_1 + '"', @@ -206,6 +221,7 @@ def test_quotes_upper_and_lower_case_name(session, table_name_1): Utils.check_answer(session.table(table_name), [Row(1), Row(2), Row(3)]) +# TODO: enable for local testing after emulating snowflake data types def test_table_with_semi_structured_types(session, semi_structured_table): df = session.table(semi_structured_table) types = [s.datatype for s in df.schema.fields] @@ -240,6 +256,7 @@ def test_table_with_semi_structured_types(session, semi_structured_table): ) +# TODO: enable for local testing. Emulate data type def test_table_with_time_type(session, table_with_time): df = session.table(table_with_time) # snowflake time has accuracy to 0.99999999. Python has accuracy to 0.999999. @@ -250,7 +267,8 @@ def test_table_with_time_type(session, table_with_time): ) -def test_consistent_table_name_behaviors(session): +@pytest.mark.localtest +def test_consistent_table_name_behaviors(session, local_testing_mode): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) db = session.get_current_database() sc = session.get_current_schema() @@ -267,11 +285,11 @@ def test_consistent_table_name_behaviors(session): for tn in table_names: Utils.check_answer(session.table(tn), [Row(1), Row(2), Row(3)]) finally: - Utils.drop_table(session, table_name) + session.table(table_name).drop_table() for tn in table_names: df.write.mode("Overwrite").save_as_table(tn) try: Utils.check_answer(session.table(table_name), [Row(1), Row(2), Row(3)]) finally: - Utils.drop_table(session, table_name) + session.table(table_name).drop_table() diff --git a/tests/integ/scala/test_udf_suite.py b/tests/integ/scala/test_udf_suite.py index d1d97ccf5ac..8cb93752379 100644 --- a/tests/integ/scala/test_udf_suite.py +++ b/tests/integ/scala/test_udf_suite.py @@ -35,7 +35,12 @@ ) from tests.utils import TestData, TestFiles, Utils -pytestmark = pytest.mark.udf +pytestmark = [ + pytest.mark.udf, + pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", raises=NotImplementedError + ), +] tmp_stage_name = Utils.random_stage_name() tmp_table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) diff --git a/tests/integ/scala/test_udtf_suite.py b/tests/integ/scala/test_udtf_suite.py index b54f11d6c4e..48eaa517c8a 100644 --- a/tests/integ/scala/test_udtf_suite.py +++ b/tests/integ/scala/test_udtf_suite.py @@ -37,6 +37,12 @@ wordcount_table_name = Utils.random_table_name() +pytestmark = pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) + @pytest.fixture(scope="module", autouse=True) def setup_data(session): diff --git a/tests/integ/scala/test_update_delete_merge_suite.py b/tests/integ/scala/test_update_delete_merge_suite.py index 916eac0771d..e803c1b550f 100644 --- a/tests/integ/scala/test_update_delete_merge_suite.py +++ b/tests/integ/scala/test_update_delete_merge_suite.py @@ -34,6 +34,11 @@ table_name3 = 
Utils.random_name_for_temp_object(TempObjectType.TABLE) +pytestmark = pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", raises=NotImplementedError +) + + def test_update_rows_in_table(session): TestData.test_data2(session).write.save_as_table( table_name, mode="overwrite", table_type="temporary" diff --git a/tests/integ/scala/test_view_suite.py b/tests/integ/scala/test_view_suite.py index 32a8c2c5430..f14ebf39fb8 100644 --- a/tests/integ/scala/test_view_suite.py +++ b/tests/integ/scala/test_view_suite.py @@ -11,22 +11,29 @@ from snowflake.snowpark._internal.utils import TempObjectType, quote_name from snowflake.snowpark.exceptions import SnowparkCreateViewException from snowflake.snowpark.functions import col, sql_expr, sum -from snowflake.snowpark.types import LongType +from snowflake.snowpark.types import ( + DecimalType, + IntegerType, + LongType, + StructField, + StructType, +) from tests.utils import TestData, Utils -def test_create_view(session): +@pytest.mark.localtest +def test_create_view(session, local_testing_mode): view_name = Utils.random_name_for_temp_object(TempObjectType.VIEW) try: TestData.integer1(session).create_or_replace_view(view_name) - res = session.sql(f"select * from {view_name}").collect() + res = session.table(view_name).collect() # don't sort assert res == [Row(1), Row(2), Row(3)] # Test replace TestData.double1(session).create_or_replace_view(view_name) - res = session.sql(f"select * from {view_name}").collect() + res = session.table(view_name).collect() # don't sort assert res == [ Row(Decimal("1.111")), @@ -34,21 +41,27 @@ def test_create_view(session): Row(Decimal("3.333")), ] finally: - Utils.drop_view(session, view_name) + if not local_testing_mode: + Utils.drop_view(session, view_name) -def test_view_name_with_special_character(session): +@pytest.mark.localtest +def test_view_name_with_special_character(session, local_testing_mode): view_name = Utils.random_name_for_temp_object(TempObjectType.VIEW) try: TestData.column_has_special_char(session).create_or_replace_view(view_name) - res = session.sql(f"select * from {quote_name(view_name)}").collect() + res = session.table(quote_name(view_name)).collect() # don't sort assert res == [Row(1, 2), Row(3, 4)] finally: - Utils.drop_view(session, view_name) + if not local_testing_mode: + Utils.drop_view(session, view_name) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="sql is not supported" +) def test_view_with_with_sql_statement(session): view_name = Utils.random_name_for_temp_object(TempObjectType.VIEW) try: @@ -62,13 +75,17 @@ def test_view_with_with_sql_statement(session): Utils.drop_view(session, view_name) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="sql use is not supported" +) def test_only_works_on_select(session): view_name = Utils.random_name_for_temp_object(TempObjectType.VIEW) with pytest.raises(SnowparkCreateViewException): session.sql("show tables").create_or_replace_view(view_name) -def test_consistent_view_name_behaviors(session): +@pytest.mark.localtest +def test_consistent_view_name_behaviors(session, local_testing_mode): view_name = Utils.random_name_for_temp_object(TempObjectType.VIEW) sc = session.get_current_schema() db = session.get_current_database() @@ -82,82 +99,80 @@ def test_consistent_view_name_behaviors(session): res = session.table(view_name).collect() res.sort(key=lambda x: x[0]) assert res == [Row(1), Row(2), Row(3)] - Utils.drop_view(session, view_name) 
df.create_or_replace_view(name_parts) res = session.table(view_name).collect() res.sort(key=lambda x: x[0]) assert res == [Row(1), Row(2), Row(3)] - Utils.drop_view(session, view_name) df.create_or_replace_view([sc, view_name]) res = session.table(view_name).collect() res.sort(key=lambda x: x[0]) assert res == [Row(1), Row(2), Row(3)] - Utils.drop_view(session, view_name) df.create_or_replace_view([view_name]) res = session.table(view_name).collect() res.sort(key=lambda x: x[0]) assert res == [Row(1), Row(2), Row(3)] - Utils.drop_view(session, view_name) df.create_or_replace_view(f"{db}.{sc}.{view_name}") res = session.table(view_name).collect() res.sort(key=lambda x: x[0]) assert res == [Row(1), Row(2), Row(3)] - Utils.drop_view(session, view_name) # create temp view df.create_or_replace_temp_view(view_name) res = session.table(view_name).collect() res.sort(key=lambda x: x[0]) assert res == [Row(1), Row(2), Row(3)] - Utils.drop_view(session, view_name) df.create_or_replace_temp_view(name_parts) res = session.table(view_name).collect() res.sort(key=lambda x: x[0]) assert res == [Row(1), Row(2), Row(3)] - Utils.drop_view(session, view_name) df.create_or_replace_temp_view([sc, view_name]) res = session.table(view_name).collect() res.sort(key=lambda x: x[0]) assert res == [Row(1), Row(2), Row(3)] - Utils.drop_view(session, view_name) df.create_or_replace_temp_view([view_name]) res = session.table(view_name).collect() res.sort(key=lambda x: x[0]) assert res == [Row(1), Row(2), Row(3)] - Utils.drop_view(session, view_name) df.create_or_replace_temp_view(f"{db}.{sc}.{view_name}") res = session.table(view_name).collect() res.sort(key=lambda x: x[0]) assert res == [Row(1), Row(2), Row(3)] - Utils.drop_view(session, view_name) finally: - Utils.drop_view(session, view_name) + if not local_testing_mode: + Utils.drop_view(session, view_name) -def test_create_temp_view_on_functions(session): +@pytest.mark.localtest +def test_create_temp_view_on_functions(session, local_testing_mode): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) view_name = Utils.random_name_for_temp_object(TempObjectType.VIEW) try: - Utils.create_table(session, table_name, "id int, val int") + session.create_dataframe( + [], + schema=StructType( + [StructField("id", IntegerType()), StructField("val", IntegerType())] + ), + ).write.save_as_table(table_name) t = session.table(table_name) - a = t.group_by(col("id")).agg(sql_expr("max(val)")) - a.create_or_replace_temp_view(view_name) - schema = session.table(view_name).schema - assert len(schema.fields) == 2 - assert schema.fields[0].datatype == LongType() - assert schema.fields[0].name == "ID" - assert schema.fields[1].datatype == LongType() - assert schema.fields[1].name == '"MAX(VAL)"' + if not local_testing_mode: # Use of sql_expr is not supported in local testing + a = t.group_by(col("id")).agg(sql_expr("max(val)")) + a.create_or_replace_temp_view(view_name) + schema = session.table(view_name).schema + assert len(schema.fields) == 2 + assert schema.fields[0].datatype == LongType() + assert schema.fields[0].name == "ID" + assert schema.fields[1].datatype == LongType() + assert schema.fields[1].name == '"MAX(VAL)"' a2 = t.group_by(col("id")).agg(sum(col("val"))) a2.create_or_replace_temp_view(view_name) @@ -171,12 +186,13 @@ def test_create_temp_view_on_functions(session): a3 = t.group_by(col("id")).agg(sum(col("val")) + 1) a3.create_or_replace_temp_view(view_name) schema2 = session.table(view_name).schema - assert len(schema.fields) == 2 + assert 
len(schema2.fields) == 2 assert schema2.fields[0].datatype == LongType() assert schema2.fields[0].name == "ID" - assert schema2.fields[1].datatype == LongType() + assert schema2.fields[1].datatype in (LongType(), DecimalType(38, 0)) assert schema2.fields[1].name == '"ADD(SUM(VAL), LITERAL())"' finally: Utils.drop_table(session, table_name) - Utils.drop_view(session, view_name) + if not local_testing_mode: + Utils.drop_view(session, view_name) diff --git a/tests/integ/scala/test_window_frame_suite.py b/tests/integ/scala/test_window_frame_suite.py index 4dc15497b5d..5cdef539afe 100644 --- a/tests/integ/scala/test_window_frame_suite.py +++ b/tests/integ/scala/test_window_frame_suite.py @@ -24,6 +24,7 @@ from tests.utils import Utils +@pytest.mark.localtest def test_lead_lag_with_positive_offset(session): df = session.create_dataframe( [(1, "1"), (2, "2"), (1, "3"), (2, "4")], schema=["key", "value"] @@ -35,6 +36,7 @@ def test_lead_lag_with_positive_offset(session): ) +@pytest.mark.localtest def test_reverse_lead_lag_with_positive_offset(session): df = session.create_dataframe( [(1, "1"), (2, "2"), (1, "3"), (2, "4")], schema=["key", "value"] @@ -46,6 +48,7 @@ def test_reverse_lead_lag_with_positive_offset(session): ) +@pytest.mark.localtest def test_lead_lag_with_negative_offset(session): df = session.create_dataframe( [(1, "1"), (2, "2"), (1, "3"), (2, "4")], schema=["key", "value"] @@ -57,6 +60,7 @@ def test_lead_lag_with_negative_offset(session): ) +@pytest.mark.localtest def test_reverse_lead_lag_with_negative_offset(session): df = session.create_dataframe( [(1, "1"), (2, "2"), (1, "3"), (2, "4")], schema=["key", "value"] @@ -68,6 +72,7 @@ def test_reverse_lead_lag_with_negative_offset(session): ) +@pytest.mark.localtest @pytest.mark.parametrize("default", [None, "10"]) def test_lead_lag_with_default_value(session, default): df = session.create_dataframe( @@ -92,6 +97,7 @@ def test_lead_lag_with_default_value(session, default): ) +@pytest.mark.localtest def test_lead_lag_with_ignore_or_respect_nulls(session): df = session.create_dataframe( [(1, 5), (2, 4), (3, None), (4, 2), (5, None), (6, None), (7, 6)], @@ -118,6 +124,7 @@ def test_lead_lag_with_ignore_or_respect_nulls(session): ) +@pytest.mark.localtest def test_first_last_value_with_ignore_or_respect_nulls(session): df = session.create_dataframe( [(1, None), (2, 4), (3, None), (4, 2), (5, None), (6, 6), (7, None)], @@ -144,6 +151,7 @@ def test_first_last_value_with_ignore_or_respect_nulls(session): ) +@pytest.mark.localtest def test_unbounded_rows_range_between_with_aggregation(session): df = session.create_dataframe( [("one", 1), ("two", 2), ("one", 3), ("two", 4)] @@ -167,11 +175,20 @@ def test_unbounded_rows_range_between_with_aggregation(session): ) +@pytest.mark.localtest def test_rows_between_boundary(session): # This test is different from scala as `int` in Python is unbounded df = session.create_dataframe( - [(1, "1"), (1, "1"), (sys.maxsize, "1"), (3, "2"), (2, "1"), (sys.maxsize, "2")] + [ + (1, "1"), + (1, "1"), + (sys.maxsize, "1"), + (3, "2"), + (2, "1"), + (sys.maxsize, "2"), + ] ).to_df("key", "value") + Utils.check_answer( df.select( "key", @@ -226,8 +243,9 @@ def test_rows_between_boundary(session): ) -def test_range_between_should_accept_at_most_one_order_by_expression_when_unbounded( - session, +@pytest.mark.localtest +def test_range_between_should_accept_at_most_one_order_by_expression_when_bounded( + session, local_testing_mode ): df = session.create_dataframe([(1, 1)]).to_df("key", "value") window = 
Window.order_by("key", "value") @@ -247,20 +265,30 @@ def test_range_between_should_accept_at_most_one_order_by_expression_when_unboun df.select( min_("key").over(window.range_between(Window.unboundedPreceding, 1)) ).collect() - assert "Cumulative window frame unsupported for function MIN" in str(ex_info) + if not local_testing_mode: + assert "Cumulative window frame unsupported for function MIN" in str( + ex_info + ) with pytest.raises(SnowparkSQLException) as ex_info: df.select( min_("key").over(window.range_between(-1, Window.unboundedFollowing)) ).collect() - assert "Cumulative window frame unsupported for function MIN" in str(ex_info) + if not local_testing_mode: + assert "Cumulative window frame unsupported for function MIN" in str( + ex_info + ) with pytest.raises(SnowparkSQLException) as ex_info: df.select(min_("key").over(window.range_between(-1, 1))).collect() - assert "Sliding window frame unsupported for function MIN" in str(ex_info) + if not local_testing_mode: + assert "Sliding window frame unsupported for function MIN" in str(ex_info) -def test_range_between_should_accept_numeric_values_only_when_bounded(session): +@pytest.mark.localtest +def test_range_between_should_accept_non_numeric_values_only_when_unbounded( + session, local_testing_mode +): df = session.create_dataframe(["non_numeric"]).to_df("value") window = Window.order_by("value") Utils.check_answer( @@ -279,19 +307,28 @@ def test_range_between_should_accept_numeric_values_only_when_bounded(session): df.select( min_("value").over(window.range_between(Window.unboundedPreceding, 1)) ).collect() - assert "Cumulative window frame unsupported for function MIN" in str(ex_info) + if not local_testing_mode: + assert "Cumulative window frame unsupported for function MIN" in str( + ex_info + ) with pytest.raises(SnowparkSQLException) as ex_info: df.select( min_("value").over(window.range_between(-1, Window.unboundedFollowing)) ).collect() - assert "Cumulative window frame unsupported for function MIN" in str(ex_info) + if not local_testing_mode: + assert "Cumulative window frame unsupported for function MIN" in str( + ex_info + ) with pytest.raises(SnowparkSQLException) as ex_info: df.select(min_("value").over(window.range_between(-1, 1))).collect() - assert "Sliding window frame unsupported for function MIN" in str(ex_info) + if not local_testing_mode: + assert "Sliding window frame unsupported for function MIN" in str(ex_info) +# [Local Testing PuPr] TODO: enable for local testing when we align precision. +# In avg, the output column has 3 more decimal digits than NUMBER(38, 0) def test_sliding_rows_between_with_aggregation(session): df = session.create_dataframe( [(1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")] @@ -309,10 +346,13 @@ def test_sliding_rows_between_with_aggregation(session): ) +# [Local Testing PuPr] TODO: enable for local testing when we align precision. 
+# In avg, the output column has 3 more decimal digits than NUMBER(38, 0) def test_reverse_sliding_rows_between_with_aggregation(session): df = session.create_dataframe( [(1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")] ).to_df("key", "value") + window = ( Window.partition_by("value").order_by(col("key").desc()).rows_between(-1, 2) ) @@ -326,3 +366,25 @@ def test_reverse_sliding_rows_between_with_aggregation(session): Row(2, Decimal("2.000")), ], ) + + +@pytest.mark.localtest +def test_range_between_should_include_rows_equal_to_current_row(session): + df1 = session.create_dataframe( + [("b", 10), ("a", 10), ("a", 10), ("d", 15), ("e", 20), ("f", 20)], + schema=["c1", "c2"], + ) + win = Window.order_by(col("c2").asc(), col("c1").desc()).range_between( + -sys.maxsize, 0 + ) + Utils.check_answer( + df1.select(col("c1"), col("c2"), (sum_(col("c2")).over(win)).alias("win_sum")), + [ + Row(C1="b", C2=10, WIN_SUM=10), + Row(C1="a", C2=10, WIN_SUM=30), + Row(C1="a", C2=10, WIN_SUM=30), + Row(C1="d", C2=15, WIN_SUM=45), + Row(C1="e", C2=20, WIN_SUM=85), + Row(C1="f", C2=20, WIN_SUM=65), + ], + ) diff --git a/tests/integ/scala/test_window_spec_suite.py b/tests/integ/scala/test_window_spec_suite.py index 713c621c502..5973c0966bb 100644 --- a/tests/integ/scala/test_window_spec_suite.py +++ b/tests/integ/scala/test_window_spec_suite.py @@ -42,7 +42,9 @@ from tests.utils import TestData, Utils -def test_partition_by_order_by_rows_between(session): +# [Local Testing PuPr] TODO: enable for local testing when we align precision. +# In avg, the output column has 3 more decimal digits than NUMBER(38, 0) +def test_partition_by_order_by_rows_between(session, local_testing_mode): df = session.create_dataframe( [(1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")] ).to_df("key", "value") @@ -68,10 +70,11 @@ def test_partition_by_order_by_rows_between(session): Row(1, Decimal("1.666")), Row(1, Decimal("1.333")), ], - sort=False, + sort=local_testing_mode, ) +@pytest.mark.localtest def test_range_between(session): df = session.create_dataframe(["non_numeric"]).to_df("value") window = Window.order_by("value") @@ -95,6 +98,7 @@ def test_range_between(session): ) +# [Local Testing GA] TODO: enable for local testing def test_window_function_with_aggregates(session): df = session.create_dataframe( [("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)] @@ -108,6 +112,7 @@ def test_window_function_with_aggregates(session): ) +# [Local Testing GA] TODO: Align error behavior with live connection def test_window_function_inside_where_and_having_clauses(session): with pytest.raises(SnowparkSQLException) as ex_info: TestData.test_data2(session).select("a").where( @@ -145,6 +150,7 @@ def test_window_function_inside_where_and_having_clauses(session): assert "outside of SELECT, QUALIFY, and ORDER BY clauses" in str(ex_info) +@pytest.mark.localtest def test_reuse_window_partition_by(session): df = session.create_dataframe([(1, "1"), (2, "2"), (1, "1"), (2, "2")]).to_df( "key", "value" @@ -157,6 +163,7 @@ def test_reuse_window_partition_by(session): ) +@pytest.mark.localtest def test_reuse_window_order_by(session): df = session.create_dataframe([(1, "1"), (2, "2"), (1, "1"), (2, "2")]).to_df( "key", "value" @@ -197,12 +204,11 @@ def test_rank_functions_in_unspecific_window(session): ) +@pytest.mark.localtest def test_empty_over_spec(session): df = session.create_dataframe([("a", 1), ("a", 1), ("a", 2), ("b", 2)]).to_df( "key", "value" ) - view_name = Utils.random_name_for_temp_object(TempObjectType.VIEW) - 
df.create_or_replace_temp_view(view_name) Utils.check_answer( df.select("key", "value", sum_("value").over(), avg("value").over()), [ @@ -212,9 +218,12 @@ def test_empty_over_spec(session): Row("b", 2, 6, 1.5), ], ) + view_name = Utils.random_name_for_temp_object(TempObjectType.VIEW) + df.create_or_replace_temp_view(view_name) + Utils.check_answer( - session.sql( - f"select key, value, sum(value) over(), avg(value) over() from {view_name}" + session.table(view_name).select( + "key", "value", sum_("value").over(), avg("value").over() ), [ Row("a", 1, 6, 1.5), @@ -225,6 +234,7 @@ def test_empty_over_spec(session): ) +@pytest.mark.localtest def test_null_inputs(session): df = session.create_dataframe( [("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)] @@ -247,6 +257,7 @@ def test_null_inputs(session): ) +@pytest.mark.localtest def test_window_function_should_fail_if_order_by_clause_is_not_specified(session): df = session.create_dataframe([(1, "1"), (2, "2"), (1, "2"), (2, "2")]).to_df( "key", "value" @@ -376,6 +387,7 @@ def test_covar_samp_var_samp_stddev_samp_functions_in_specific_window(session): ) +@pytest.mark.localtest def test_aggregation_function_on_invalid_column(session): df = session.create_dataframe([(1, "1")]).to_df("key", "value") with pytest.raises(SnowparkSQLException) as ex_info: @@ -428,17 +440,20 @@ def test_skewness_and_kurtosis_functions_in_window(session): ) +@pytest.mark.localtest def test_window_functions_in_multiple_selects(session): df = session.create_dataframe( [("S1", "P1", 100), ("S1", "P1", 700), ("S2", "P1", 200), ("S2", "P2", 300)] ).to_df("sno", "pno", "qty") w1 = Window.partition_by("sno") w2 = Window.partition_by("sno", "pno") + select = df.select( "sno", "pno", "qty", sum_("qty").over(w2).alias("sum_qty_2") ).select( "sno", "pno", "qty", col("sum_qty_2"), sum_("qty").over(w1).alias("sum_qty_1") ) + Utils.check_answer( select, [ diff --git a/tests/integ/test_bind_variable.py b/tests/integ/test_bind_variable.py index f1bc9c5945f..fa18676bce9 100644 --- a/tests/integ/test_bind_variable.py +++ b/tests/integ/test_bind_variable.py @@ -29,6 +29,12 @@ except ImportError: is_pandas_available = False +pytestmark = pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) + def test_basic_query(session): df1 = session.sql("select * from values (?, ?), (?, ?)", params=[1, "a", 2, "b"]) diff --git a/tests/integ/test_column.py b/tests/integ/test_column.py index e65b14a2911..692827d8559 100644 --- a/tests/integ/test_column.py +++ b/tests/integ/test_column.py @@ -14,6 +14,7 @@ from tests.utils import TestData, Utils +@pytest.mark.localtest def test_column_constructors_subscriptable(session): df = session.create_dataframe([[1, 2, 3]]).to_df("col", '"col"', "col .") assert df.select(df["col"]).collect() == [Row(1)] @@ -31,6 +32,7 @@ def test_column_constructors_subscriptable(session): assert "The DataFrame does not contain the column" in str(ex_info) +@pytest.mark.localtest def test_between(session): df = session.create_dataframe([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]).to_df( ["a", "b"] @@ -46,6 +48,7 @@ def test_between(session): ) +@pytest.mark.localtest def test_try_cast(session): df = session.create_dataframe([["2018-01-01"]], schema=["a"]) cast_res = df.select(df["a"].cast("date")).collect() @@ -53,16 +56,22 @@ def test_try_cast(session): assert cast_res[0][0] == try_cast_res[0][0] == datetime.date(2018, 1, 1) -def test_try_cast_work_cast_not_work(session): +@pytest.mark.localtest +def 
test_try_cast_work_cast_not_work(session, local_testing_mode): df = session.create_dataframe([["aaa"]], schema=["a"]) - with pytest.raises(SnowparkSQLException) as execinfo: + with pytest.raises( + ValueError if local_testing_mode else SnowparkSQLException + ) as execinfo: df.select(df["a"].cast("date")).collect() - assert "Date 'aaa' is not recognized" in str(execinfo) + if not local_testing_mode: + assert "Date 'aaa' is not recognized" in str(execinfo) + Utils.check_answer( df.select(df["a"].try_cast("date")), [Row(None)] ) # try_cast doesn't throw exception +@pytest.mark.localtest def test_cast_try_cast_negative(session): df = session.create_dataframe([["aaa"]], schema=["a"]) with pytest.raises(ValueError) as execinfo: @@ -73,6 +82,11 @@ def test_cast_try_cast_negative(session): assert "'wrong_type' is not a supported type" in str(execinfo) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) @pytest.mark.parametrize("number_word", ["decimal", "number", "numeric"]) def test_cast_decimal(session, number_word): df = session.create_dataframe([[5.2354]], schema=["a"]) @@ -81,18 +95,21 @@ def test_cast_decimal(session, number_word): ) +@pytest.mark.localtest def test_cast_map_type(session): df = session.create_dataframe([['{"key": "1"}']], schema=["a"]) result = df.select(parse_json(df["a"]).cast("object")).collect() assert json.loads(result[0][0]) == {"key": "1"} +@pytest.mark.localtest def test_cast_array_type(session): df = session.create_dataframe([["[1,2,3]"]], schema=["a"]) result = df.select(parse_json(df["a"]).cast("array")).collect() assert json.loads(result[0][0]) == [1, 2, 3] +@pytest.mark.localtest def test_startswith(session): Utils.check_answer( TestData.string4(session).select(col("a").startswith(lit("a"))), @@ -101,6 +118,7 @@ def test_startswith(session): ) +@pytest.mark.localtest def test_endswith(session): Utils.check_answer( TestData.string4(session).select(col("a").endswith(lit("ana"))), @@ -109,6 +127,7 @@ def test_endswith(session): ) +@pytest.mark.localtest def test_substring(session): Utils.check_answer( TestData.string4(session).select( @@ -119,6 +138,7 @@ def test_substring(session): ) +@pytest.mark.localtest def test_contains(session): Utils.check_answer( TestData.string4(session).filter(col("a").contains(lit("e"))), @@ -127,6 +147,7 @@ def test_contains(session): ) +@pytest.mark.localtest def test_when_accept_literal_value(session): assert TestData.null_data1(session).select( when(col("a").is_null(), 5).when(col("a") == 1, 6).otherwise(7).as_("a") @@ -141,6 +162,7 @@ def test_when_accept_literal_value(session): ).collect() == [Row(5), Row(None), Row(6), Row(None), Row(5)] +@pytest.mark.localtest def test_logical_operator_raise_error(session): df = session.create_dataframe([[1, 2]], schema=["a", "b"]) with pytest.raises(TypeError) as execinfo: @@ -148,6 +170,11 @@ def test_logical_operator_raise_error(session): assert "Cannot convert a Column object into bool" in str(execinfo) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_when_accept_sql_expr(session): assert TestData.null_data1(session).select( when("a is NULL", 5).when("a = 1", 6).otherwise(7).as_("a") diff --git a/tests/integ/test_column_names.py b/tests/integ/test_column_names.py index b87e969203d..ba3cf75ef69 100644 --- a/tests/integ/test_column_names.py +++ b/tests/integ/test_column_names.py @@ -24,17 +24,23 @@ upper, when, ) +from 
snowflake.snowpark.mock.connection import MockServerConnection from tests.utils import Utils def get_metadata_names(session, df): + if isinstance(session._conn, MockServerConnection): + return [col.name for col in session._conn.get_result_and_metadata(df._plan)[1]] + description = session._conn._cursor.describe(df.queries["queries"][-1]) return [quote_name(metadata.name) for metadata in description] +@pytest.mark.localtest def test_like(session): - df1 = session.sql("select 'v' as c") + df1 = session.create_dataframe(["v"], schema=["c"]) df2 = df1.select(df1["c"].like(lit("v%"))) + assert ( df2._output[0].name == df2.columns[0] @@ -42,8 +48,9 @@ def test_like(session): == '"""C"" LIKE \'V%\'"' ) - df1 = session.sql("select 'v' as \"c c\"") + df1 = session.create_dataframe(["v"], schema=['"c c"']) df2 = df1.select(df1["c c"].like(lit("v%"))) + assert ( df2._output[0].name == df2.columns[0] @@ -52,8 +59,9 @@ def test_like(session): ) +@pytest.mark.localtest def test_regexp(session): - df1 = session.sql("select 'v' as c") + df1 = session.create_dataframe(["v"], schema=["c"]) df2 = df1.select(df1["c"].regexp(lit("v%"))) assert ( df2._output[0].name @@ -62,8 +70,8 @@ def test_regexp(session): == '"""C"" REGEXP \'V%\'"' ) - df1 = session.sql("select 'v' as \"c c\"") - df2 = df1.select(df1["c c"].regexp(lit("v%"))) + df1 = session.create_dataframe(["v"], schema=['"c c"']) + df2 = df1.select(df1['"c c"'].regexp(lit("v%"))) assert ( df2._output[0].name == df2.columns[0] @@ -72,6 +80,11 @@ def test_regexp(session): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_collate(session): df1 = session.sql("select 'v' as c") df2 = df1.select(df1["c"].collate("en")) @@ -92,8 +105,11 @@ def test_collate(session): ) +@pytest.mark.localtest def test_subfield(session): - df1 = session.sql('select [1, 2, 3] as c, parse_json(\'{"a": "b"}\') as "c c"') + df1 = session.create_dataframe( + data=[[[1, 2, 3], {"a": "b"}]], schema=["c", '"c c"'] + ) df2 = df1.select(df1["C"][0], df1["c c"]["a"]) assert ( [x.name for x in df2._output] @@ -103,6 +119,11 @@ def test_subfield(session): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_case_when(session): df1 = session.sql('select 1 as c, 2 as "c c"') df2 = df1.select(when(df1["c"] == 1, lit(True)).when(df1["c"] == 2, lit("abc"))) @@ -114,6 +135,11 @@ def test_case_when(session): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_multiple_expression(session): df1 = session.sql("select 1 as c, 'v' as \"c c\"") df2 = df1.select(in_(["c", "c c"], [[lit(1), lit("v")]])) @@ -125,6 +151,11 @@ def test_multiple_expression(session): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_in_expression(session): df1 = session.sql("select 1 as c, 'v' as \"c c\"") df2 = df1.select(df1["c"].in_(1, 2, 3), df1["c c"].in_("v")) @@ -148,6 +179,11 @@ def test_scalar_subquery(session): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_specified_window_frame(session): df1 = session.sql("select 'v' as \" a\"") assert df1._output[0].name == '" a"' @@ -161,8 +197,10 @@ def test_specified_window_frame(session): ) +@pytest.mark.localtest def test_cast(session): - df1 = session.sql("select 1 as a, 
'v' as \" a\"") + + df1 = session.create_dataframe([[1, "v"]], schema=["a", '" a"']) df2 = df1.select( df1["a"].cast("string(23)"), df1[" a"].try_cast("integer"), @@ -180,6 +218,11 @@ def test_cast(session): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_unspecified_frame(session): df1 = session.sql("select 'v' as \" a\"") assert ( @@ -197,6 +240,11 @@ def test_unspecified_frame(session): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_special_frame_boundry(session): df1 = session.sql("select 'v' as \" a\"") assert df1._output[0].name == '" a"' @@ -217,6 +265,11 @@ def test_special_frame_boundry(session): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_rank_related_function_expression(session): "Lag, Lead, FirstValue, LastValue" df1 = session.sql("select 1 as a, 'v' as \" a\"") @@ -257,6 +310,9 @@ def test_rank_related_function_expression(session): ) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="Relies on DUAL table" +) def test_literal(session): df1 = session.table("dual") df2 = df1.select(lit("a"), lit(1), lit(True), lit([1])) @@ -273,9 +329,11 @@ def test_literal(session): ) +@pytest.mark.localtest def test_attribute(session): - df1 = session.sql('select 1 as " a", 2 as a') + df1 = session.create_dataframe([[1, 2]], schema=[" a", "a"]) df2 = df1.select(df1[" a"], df1["a"]) + assert ( [x.name for x in df2._output] == get_metadata_names(session, df2) @@ -287,9 +345,12 @@ def test_attribute(session): ] # In class ColumnIdentifier, the "" is removed for '"A"'. +@pytest.mark.localtest def test_unresolved_attribute(session): - df1 = session.sql('select 1 as " a", 2 as a') + df1 = session.create_dataframe([[1, 2]], schema=[" a", "a"]) + df2 = df1.select(" a", "a") + assert ( [x.name for x in df2._output] == get_metadata_names(session, df2) @@ -301,8 +362,9 @@ def test_unresolved_attribute(session): ] # In class ColumnIdentifier, the "" is removed for '"A"'. +@pytest.mark.localtest def test_star(session): - df1 = session.sql('select 1 as " a", 2 as a') + df1 = session.create_dataframe([[1, 2]], schema=[" a", "a"]) df2 = df1.select(df1["*"]) assert ( [x.name for x in df2._output] @@ -325,15 +387,18 @@ def test_star(session): ] # In class ColumnIdentifier, the "" is removed for '"A"'. 
-def test_function_expression(session): - df1 = session.sql("select 'a' as a") - df2 = df1.select(upper(df1["A"])) - assert ( - df2._output[0].name - == df2.columns[0] - == get_metadata_names(session, df2)[0] - == '"UPPER(""A"")"' - ) +@pytest.mark.localtest +def test_function_expression(session, local_testing_mode): + df1 = session.create_dataframe(["a"], schema=["a"]) + if not local_testing_mode: + # local testing does not support upper + df2 = df1.select(upper(df1["A"])) + assert ( + df2._output[0].name + == df2.columns[0] + == get_metadata_names(session, df2)[0] + == '"UPPER(""A"")"' + ) df3 = df1.select(count_distinct("a")) assert ( @@ -346,6 +411,11 @@ def test_function_expression(session): @pytest.mark.udf @pytest.mark.parametrize("use_qualified_name", [True, False]) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_udf(session, use_qualified_name): def add_one(x: int) -> int: return x + 1 @@ -415,6 +485,11 @@ def add_one(x: int) -> int: Utils.drop_stage(session, stage_name) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_unary_expression(session): """Alias, UnresolvedAlias, Cast, UnaryMinus, IsNull, IsNotNull, IsNaN, Not""" df1 = session.sql('select 1 as " a", 2 as a') @@ -475,6 +550,11 @@ def test_unary_expression(session): ] # In class ColumnIdentifier, the "" is removed for '"B"'. +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_list_agg_within_group_sort_order(session): df1 = session.sql( 'select c as "a b" from (select c from values((1), (2), (3)) as t(c))' @@ -492,6 +572,11 @@ def test_list_agg_within_group_sort_order(session): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_binary_expression(session): """=, !=, >, <, >=, <=, EQUAL_NULL, AND, OR, +, -, *, /, %, POWER, BITAND, BITOR, BITXOR""" df1 = session.sql("select 1 as \" a\", 'x' as \" b\", 1 as a, 'x' as b") @@ -578,8 +663,9 @@ def test_binary_expression(session): ) +@pytest.mark.localtest def test_cast_nan_column_name(session): - df1 = session.sql("select 'a' as a") + df1 = session.create_dataframe([["a"]], schema=["a"]) df2 = df1.select(df1["A"] == math.nan) assert ( df2._output[0].name @@ -589,6 +675,11 @@ def test_cast_nan_column_name(session): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_inf_column_name(session): df1 = session.sql("select 'inf'") df2 = df1.select(df1["'INF'"] == math.inf) diff --git a/tests/integ/test_context.py b/tests/integ/test_context.py index d73e37234e7..c2e992b78d0 100644 --- a/tests/integ/test_context.py +++ b/tests/integ/test_context.py @@ -3,8 +3,15 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
# +import pytest + from snowflake.snowpark.context import get_active_session +pytestmark = pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="no active session in local testing", +) + def test_get_active_session(session): assert session == get_active_session() diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py index bf746e66258..c523a7237f3 100644 --- a/tests/integ/test_dataframe.py +++ b/tests/integ/test_dataframe.py @@ -100,23 +100,26 @@ @pytest.fixture(scope="module", autouse=True) -def setup(session, resources_path): - test_files = TestFiles(resources_path) - Utils.create_stage(session, tmp_stage_name, is_temporary=True) - Utils.upload_to_stage( - session, f"@{tmp_stage_name}", test_files.test_file_csv, compress=False - ) +def setup(session, resources_path, local_testing_mode): + if not local_testing_mode: + test_files = TestFiles(resources_path) + Utils.create_stage(session, tmp_stage_name, is_temporary=True) + Utils.upload_to_stage( + session, f"@{tmp_stage_name}", test_files.test_file_csv, compress=False + ) @pytest.fixture(scope="function") def table_name_1(session): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) - Utils.create_table(session, table_name, "num int") - session._run_query(f"insert into {table_name} values (1), (2), (3)") + session.create_dataframe( + [[1], [2], [3]], schema=StructType([StructField("num", IntegerType())]) + ).write.save_as_table(table_name) yield table_name Utils.drop_table(session, table_name) +@pytest.mark.localtest def test_dataframe_get_item(session): df = session.create_dataframe([[1, "a"], [2, "b"], [3, "c"], [4, "d"]]).to_df( "id", "value" @@ -131,6 +134,7 @@ def test_dataframe_get_item(session): assert "Unexpected item type: " in str(exc_info) +@pytest.mark.localtest def test_dataframe_get_attr(session): df = session.create_dataframe([[1, "a"], [2, "b"], [3, "c"], [4, "d"]]).to_df( "id", "value" @@ -143,14 +147,16 @@ def test_dataframe_get_attr(session): assert "object has no attribute" in str(exc_info) +@pytest.mark.localtest @pytest.mark.skipif(IS_IN_STORED_PROC_LOCALFS, reason="need resources") -def test_read_stage_file_show(session, resources_path): +def test_read_stage_file_show(session, resources_path, local_testing_mode): tmp_stage_name = Utils.random_stage_name() test_files = TestFiles(resources_path) test_file_on_stage = f"@{tmp_stage_name}/testCSV.csv" try: - Utils.create_stage(session, tmp_stage_name, is_temporary=True) + if not local_testing_mode: + Utils.create_stage(session, tmp_stage_name, is_temporary=True) Utils.upload_to_stage( session, "@" + tmp_stage_name, test_files.test_file_csv, compress=False ) @@ -179,9 +185,15 @@ def test_read_stage_file_show(session, resources_path): """.lstrip() ) finally: - Utils.drop_stage(session, tmp_stage_name) + if not local_testing_mode: + Utils.drop_stage(session, tmp_stage_name) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_show_using_with_select_statement(session): df = session.sql( "with t1 as (select 1 as a union all select 2 union all select 3 " @@ -210,6 +222,7 @@ def test_show_using_with_select_statement(session): ) +@pytest.mark.localtest def test_distinct(session): """Tests df.distinct().""" @@ -246,6 +259,7 @@ def test_distinct(session): assert res == [Row(None), Row(1), Row(2), Row(3), Row(4), Row(5)] +@pytest.mark.localtest def test_first(session): """Tests df.first().""" @@ -291,6 +305,7 @@ def test_first(session): 
assert "Invalid type of argument passed to first()" in str(ex_info) +@pytest.mark.localtest def test_new_df_from_range(session): """Tests df.range().""" @@ -334,6 +349,7 @@ def test_new_df_from_range(session): assert res == expected +@pytest.mark.localtest def test_select_single_column(session): """Tests df.select() on dataframes with a single column.""" @@ -360,6 +376,7 @@ def test_select_single_column(session): assert res == expected +@pytest.mark.localtest def test_select_star(session): """Tests df.select('*').""" @@ -376,6 +393,10 @@ def test_select_star(session): assert res == expected +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Relies on get_active_session", +) @pytest.mark.udf def test_select_table_function(session): df = session.create_dataframe( @@ -498,6 +519,11 @@ def process(self, n: int): Utils.check_answer(df.select(df.a, table_func(df.b, lit(" "))), expected_result) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_generator_table_function(session): # works with rowcount expected_result = [Row(-108, 3), Row(-107, 3), Row(0, 3)] @@ -551,6 +577,11 @@ def test_generator_table_function(session): Utils.check_answer(df, expected_result) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_generator_table_function_negative(session): # fails when no operators added with pytest.raises(ValueError) as ex_info: @@ -558,6 +589,10 @@ def test_generator_table_function_negative(session): assert "Columns cannot be empty for generator table function" in str(ex_info) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Relies on get_active_session", +) @pytest.mark.udf def test_select_table_function_negative(session): df = session.create_dataframe([(1, "a", 10), (2, "b", 20), (3, "c", 30)]).to_df( @@ -649,6 +684,11 @@ def process(self, n: int): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_explode(session): df = session.create_dataframe( [[1, [1, 2, 3], {"a": "b"}, "Kimura"]], schema=["idx", "lists", "maps", "strs"] @@ -703,6 +743,11 @@ def test_explode(session): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_explode_negative(session): df = session.create_dataframe( [[1, [1, 2, 3], {"a": "b"}, "Kimura"]], schema=["idx", "lists", "maps", "strs"] @@ -737,6 +782,10 @@ def test_explode_negative(session): df.select(explode(col("DOES_NOT_EXIST"))) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Relies on get_active_session", +) @pytest.mark.udf def test_with_column(session): df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) @@ -752,6 +801,10 @@ def process(self, a: int, b: int) -> Iterable[Tuple[int]]: Utils.check_answer(df.with_column("total", sum_udtf(df.a, df.b)), expected) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Relies on get_active_session", +) @pytest.mark.udf def test_with_column_negative(session): df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) @@ -770,6 +823,10 @@ def process(self, a: int, b: int) -> Iterable[Tuple[int, int]]: ) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Relies on get_active_session", +) @pytest.mark.udf def 
test_with_columns(session): df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) @@ -841,6 +898,10 @@ def process(self, a: int, b: int) -> Iterable[Tuple[int, int]]: ) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Relies on get_active_session", +) @pytest.mark.udf def test_with_columns_negative(session): df = session.create_dataframe( @@ -893,6 +954,7 @@ def process(self, a: int, b: int) -> Iterable[Tuple[int, int]]: ) +@pytest.mark.localtest def test_df_subscriptable(session): """Tests select & filter as df[...]""" @@ -942,6 +1004,7 @@ def test_df_subscriptable(session): assert res == expected +@pytest.mark.localtest def test_filter(session): """Tests for df.filter().""" df = session.range(1, 10, 2) @@ -965,6 +1028,14 @@ def test_filter(session): expected = [] assert res == expected + +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) +def test_filter_with_sql_str(session): + df = session.range(1, 10, 2) # sql text assert ( df.filter(col("id") > 4).collect() @@ -983,6 +1054,7 @@ def test_filter(session): ) +@pytest.mark.localtest def test_filter_incorrect_type(session): """Tests for incorrect type passed to DataFrame.filter().""" df = session.range(1, 10, 2) @@ -995,6 +1067,7 @@ def test_filter_incorrect_type(session): ) +@pytest.mark.localtest def test_filter_chained(session): """Tests for chained DataFrame.filter() operations""" @@ -1026,6 +1099,7 @@ def test_filter_chained(session): assert res == expected +@pytest.mark.localtest def test_filter_chained_col_objects_int(session): """Tests for chained DataFrame.filter() operations.""" @@ -1055,6 +1129,7 @@ def test_filter_chained_col_objects_int(session): assert res == expected +@pytest.mark.localtest def test_drop(session): """Test for dropping columns from a dataframe.""" @@ -1088,8 +1163,10 @@ def test_drop(session): assert res == expected +@pytest.mark.localtest def test_alias(session): """Test for dropping columns from a dataframe.""" + # Selecting non-existing column (already renamed) should fail with pytest.raises(SnowparkSQLException): session.range(3, 8).select(col("id").alias("id_prime")).select( @@ -1109,6 +1186,7 @@ def test_alias(session): assert res == expected +@pytest.mark.localtest def test_join_inner(session): """Test for inner join of dataframes.""" @@ -1140,6 +1218,7 @@ def test_join_inner(session): assert res == expected +@pytest.mark.localtest def test_join_left_anti(session): """Test for left-anti join of dataframes.""" @@ -1164,6 +1243,7 @@ def test_join_left_anti(session): assert sorted(res, key=lambda r: r[0]) == expected +@pytest.mark.localtest def test_join_left_outer(session): """Test for left-outer join of dataframes.""" @@ -1200,6 +1280,7 @@ def test_join_left_outer(session): assert sorted(res, key=lambda r: r[0]) == expected +@pytest.mark.localtest def test_join_right_outer(session): """Test for right-outer join of dataframes.""" @@ -1236,6 +1317,7 @@ def test_join_right_outer(session): assert sorted(res, key=lambda r: r[0]) == expected +@pytest.mark.localtest def test_join_left_semi(session): """Test for left semi join of dataframes.""" @@ -1260,6 +1342,7 @@ def test_join_left_semi(session): assert sorted(res, key=lambda r: r[0]) == expected +@pytest.mark.localtest def test_join_cross(session): """Test for cross join of dataframes.""" @@ -1299,6 +1382,7 @@ def test_join_cross(session): assert sorted(res, key=lambda r: (r[0], r[1])) == expected +@pytest.mark.localtest def 
test_join_outer(session): """Test for outer join of dataframes.""" @@ -1347,6 +1431,7 @@ def test_join_outer(session): assert sorted(res, key=lambda r: r[0]) == expected +@pytest.mark.localtest def test_toDF(session): """Test df.to_df().""" @@ -1374,6 +1459,7 @@ def test_toDF(session): assert sorted(res, key=lambda r: r[0]) == expected +@pytest.mark.localtest def test_df_col(session): """Test df.col()""" @@ -1432,6 +1518,7 @@ def test_create_dataframe_with_basic_data_types(session): assert df.select(expected_names).collect() == expected_rows +@pytest.mark.localtest def test_create_dataframe_with_semi_structured_data_types(session): data = [ [ @@ -1478,6 +1565,10 @@ def test_create_dataframe_with_semi_structured_data_types(session): ) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="TODO: enable for local testing after supporting more snowflake data types", +) def test_create_dataframe_with_dict(session): data = {f"snow_{idx + 1}": idx**3 for idx in range(5)} expected_names = [name.upper() for name in data.keys()] @@ -1509,6 +1600,7 @@ ) +@pytest.mark.localtest def test_create_dataframe_with_dict_given_schema(session): schema = StructType( [ @@ -1571,6 +1663,7 @@ Utils.check_answer(df, [Row(None, None), Row(None, None)]) +@pytest.mark.localtest def test_create_dataframe_with_namedtuple(session): Data = namedtuple("Data", [f"snow_{idx + 1}" for idx in range(5)]) data = Data(*[idx**3 for idx in range(5)]) @@ -1592,6 +1685,7 @@ Utils.check_answer(df, [Row(1, 2, None, None), Row(None, None, 3, 4)]) +@pytest.mark.localtest def test_create_dataframe_with_row(session): row1 = Row(a=1, b=2) row2 = Row(a=3, b=4) @@ -1627,6 +1721,7 @@ assert "4 fields are required by schema but 2 values are provided" in str(ex_info) +@pytest.mark.localtest def test_create_dataframe_with_mixed_dict_namedtuple_row(session): d = {"a": 1, "b": 2} Data = namedtuple("Data", ["a", "b"]) @@ -1644,6 +1739,7 @@ ) +@pytest.mark.localtest def test_create_dataframe_with_schema_col_names(session): col_names = ["a", "b", "c", "d"] df = session.create_dataframe([[1, 2, 3, 4]], schema=col_names) @@ -1664,6 +1760,10 @@ assert Utils.equals_ignore_case(field.name, expected_name) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="TODO: enable for local testing after supporting more snowflake data types", +) def test_create_dataframe_with_variant(session): data = [ 1, @@ -1701,6 +1801,7 @@ ] +@pytest.mark.localtest @pytest.mark.parametrize("data", [[0, 1, 2, 3], ["", "a"], [False, True], [None]]) def test_create_dataframe_with_single_value(session, data): expected_names = ["_1"] @@ -1765,6 +1866,7 @@ assert df.with_column("c", lit(2)).columns == ["A", "B", "C"] +@pytest.mark.localtest @pytest.mark.skipif(IS_IN_STORED_PROC_LOCALFS, reason="Large result") def test_create_dataframe_from_none_data(session): assert session.create_dataframe([None, None]).collect() == [ @@ -1784,6 +1886,10 @@ assert session.create_dataframe([None] * 20000).collect() == [Row(None)] * 20000 +@pytest.mark.skipif( + 
condition="config.getvalue('local_testing_mode')", + reason="batch insert does not apply for local testing", +) def test_create_dataframe_large_without_batch_insert(session): from snowflake.snowpark._internal.analyzer import analyzer @@ -1798,6 +1904,7 @@ def test_create_dataframe_large_without_batch_insert(session): analyzer.ARRAY_BIND_THRESHOLD = original_value +@pytest.mark.localtest def test_create_dataframe_with_invalid_data(session): # None input with pytest.raises(ValueError) as ex_info: @@ -1850,6 +1957,7 @@ def test_create_dataframe_with_invalid_data(session): assert "data consists of rows with different lengths" in str(ex_info) +@pytest.mark.localtest def test_attribute_reference_to_sql(session): from snowflake.snowpark.functions import sum as sum_ @@ -1870,6 +1978,10 @@ def test_attribute_reference_to_sql(session): Utils.check_answer([Row(1, 1)], agg_results) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="TODO: selecting duplicate column names are not supported in Local Testing", +) def test_dataframe_duplicated_column_names(session): df = session.sql("select 1 as a, 2 as a") # collect() works and return a row with duplicated keys @@ -1886,6 +1998,11 @@ def test_dataframe_duplicated_column_names(session): assert "duplicate column name 'A'" in str(ex_info) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) @pytest.mark.skipif( IS_IN_STORED_PROC, reason="Async query is not supported in stored procedure yet" ) @@ -2004,10 +2121,13 @@ def test_case_insensitive_local_iterator(session): assert row["P@$$W0RD"] == "test" -def test_dropna(session): - Utils.check_answer(TestData.double3(session).dropna(), [Row(1.0, 1)]) +@pytest.mark.localtest +def test_dropna(session, local_testing_mode): + Utils.check_answer( + TestData.double3(session, local_testing_mode).dropna(), [Row(1.0, 1)] + ) - res = TestData.double3(session).dropna(how="all").collect() + res = TestData.double3(session, local_testing_mode).dropna(how="all").collect() assert res[0] == Row(1.0, 1) assert math.isnan(res[1][0]) assert res[1][1] == 2 @@ -2015,10 +2135,11 @@ def test_dropna(session): assert res[3] == Row(4.0, None) Utils.check_answer( - TestData.double3(session).dropna(subset=["a"]), [Row(1.0, 1), Row(4.0, None)] + TestData.double3(session, local_testing_mode).dropna(subset=["a"]), + [Row(1.0, 1), Row(4.0, None)], ) - res = TestData.double3(session).dropna(thresh=1).collect() + res = TestData.double3(session, local_testing_mode).dropna(thresh=1).collect() assert res[0] == Row(1.0, 1) assert math.isnan(res[1][0]) assert res[1][1] == 2 @@ -2026,26 +2147,28 @@ def test_dropna(session): assert res[3] == Row(4.0, None) with pytest.raises(TypeError) as ex_info: - TestData.double3(session).dropna(subset={1: "a"}) + TestData.double3(session, local_testing_mode).dropna(subset={1: "a"}) assert "subset should be a list or tuple of column names" in str(ex_info) -def test_fillna(session): - Utils.check_answer( - TestData.double3(session).fillna(11), - [ - Row(1.0, 1), - Row(11.0, 2), - Row(11.0, 3), - Row(4.0, 11), - Row(11.0, 11), - Row(11.0, 11), - ], - sort=False, - ) +@pytest.mark.localtest +def test_fillna(session, local_testing_mode): + if not local_testing_mode: # Enable for local testing after coercion support + Utils.check_answer( + TestData.double3(session, local_testing_mode).fillna(11), + [ + Row(1.0, 1), + Row(11.0, 2), + Row(11.0, 3), + Row(4.0, 11), + Row(11.0, 11), + Row(11.0, 11), + ], + sort=False, + ) 
Utils.check_answer( - TestData.double3(session).fillna(11, subset=["a"]), + TestData.double3(session, local_testing_mode).fillna(11.0, subset=["a"]), [ Row(1.0, 1), Row(11.0, 2), @@ -2058,7 +2181,7 @@ def test_fillna(session): ) Utils.check_answer( - TestData.double3(session).fillna(None), + TestData.double3(session, local_testing_mode).fillna(None), [ Row(1.0, 1), Row(None, 2), @@ -2107,20 +2230,21 @@ def test_fillna(session): Utils.check_answer( session.create_dataframe( [[1, 1.1], [None, None]], schema=["col1", "col2"] - ).fillna({"col1": 1.1, "col2": 1}), - [Row(1, 1.1), Row(None, 1)], + ).fillna({"col1": 1.1, "col2": 1.1}), + [Row(1, 1.1), Row(None, 1.1)], ) - df = session.create_dataframe( - [[[1, 2], (1, 3)], [None, None]], schema=["col1", "col2"] - ) - Utils.check_answer( - df.fillna([1, 3]), - [ - Row("[\n 1,\n 2\n]", "[\n 1,\n 3\n]"), - Row("[\n 1,\n 3\n]", "[\n 1,\n 3\n]"), - ], - ) + if not local_testing_mode: # TODO: enable this after rebasing on support-variant + df = session.create_dataframe( + [[[1, 2], (1, 3)], [None, None]], schema=["col1", "col2"] + ) + Utils.check_answer( + df.fillna([1, 3]), + [ + Row("[\n 1,\n 2\n]", "[\n 1,\n 3\n]"), + Row("[\n 1,\n 3\n]", "[\n 1,\n 3\n]"), + ], + ) # negative case with pytest.raises(TypeError) as ex_info: @@ -2128,7 +2252,8 @@ def test_fillna(session): assert "subset should be a list or tuple of column names" in str(ex_info) -def test_replace(session): +# TODO: enable this for local testing after supporting type coercion +def test_replace_with_coercion(session, local_testing_mode): df = session.create_dataframe( [[1, 1.0, "1.0"], [2, 2.0, "2.0"]], schema=["a", "b", "c"] ) @@ -2194,6 +2319,7 @@ def test_replace(session): assert "to_replace and value lists should be of the same length" in str(ex_info) +@pytest.mark.localtest def test_select_case_expr(session): df = session.create_dataframe([1, 2, 3], schema=["a"]) Utils.check_answer( @@ -2201,6 +2327,11 @@ def test_select_case_expr(session): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_select_expr(session): df = session.create_dataframe([-1, 2, 3], schema=["a"]) Utils.check_answer( @@ -2213,6 +2344,11 @@ def test_select_expr(session): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_describe(session): assert TestData.test_data2(session).describe().columns == [ "SUMMARY", @@ -2338,19 +2474,20 @@ def test_describe(session): assert "invalid identifier" in str(ex_info) +@pytest.mark.localtest @pytest.mark.parametrize("table_type", ["", "temp", "temporary", "transient"]) @pytest.mark.parametrize( "save_mode", ["append", "overwrite", "ignore", "errorifexists"] ) -def test_table_types_in_save_as_table(session, save_mode, table_type): +def test_table_types_in_save_as_table( + session, save_mode, table_type, local_testing_mode +): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) df = session.create_dataframe([(1, 2), (3, 4)]).toDF("a", "b") - try: - df.write.save_as_table(table_name, mode=save_mode, table_type=table_type) - Utils.check_answer(session.table(table_name), df, True) + df.write.save_as_table(table_name, mode=save_mode, table_type=table_type) + Utils.check_answer(session.table(table_name), df, True) + if not local_testing_mode: Utils.assert_table_type(session, table_name, table_type) - finally: - Utils.drop_table(session, table_name) @pytest.mark.parametrize( @@ -2504,11 +2641,14 @@ def 
test_write_table_with_clustering_keys(session, save_mode): Utils.drop_table(session, table_name3) +@pytest.mark.localtest @pytest.mark.parametrize("table_type", ["temp", "temporary", "transient"]) @pytest.mark.parametrize( "save_mode", ["append", "overwrite", "ignore", "errorifexists"] ) -def test_write_temp_table_no_breaking_change(session, save_mode, table_type, caplog): +def test_write_temp_table_no_breaking_change( + session, save_mode, table_type, caplog, local_testing_mode +): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) df = session.create_dataframe([(1, 2), (3, 4)]).toDF("a", "b") try: @@ -2521,13 +2661,15 @@ def test_write_temp_table_no_breaking_change(session, save_mode, table_type, cap ) assert "create_temp_table is deprecated" in caplog.text Utils.check_answer(session.table(table_name), df, True) - Utils.assert_table_type(session, table_name, "temp") + if not local_testing_mode: + Utils.assert_table_type(session, table_name, "temp") finally: Utils.drop_table(session, table_name) # clear the warning dict otherwise it will affect the future tests warning_dict.clear() +@pytest.mark.localtest def test_write_invalid_table_type(session): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) df = session.create_dataframe([(1, 2), (3, 4)]).toDF("a", "b") @@ -2535,6 +2677,9 @@ def test_write_invalid_table_type(session): df.write.save_as_table(table_name, table_type="invalid") +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="Testing query history" +) def test_append_existing_table(session): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) Utils.create_table(session, table_name, "a int, b int", is_temporary=True) @@ -2550,6 +2695,11 @@ def test_append_existing_table(session): Utils.drop_table(session, table_name) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_create_dynamic_table(session, table_name_1): try: df = session.table(table_name_1) @@ -2566,6 +2716,11 @@ def test_create_dynamic_table(session, table_name_1): Utils.drop_dynamic_table(session, dt_name) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_write_copy_into_location_basic(session): temp_stage = Utils.random_name_for_temp_object(TempObjectType.STAGE) Utils.create_stage(session, temp_stage, is_temporary=True) @@ -2582,6 +2737,11 @@ def test_write_copy_into_location_basic(session): Utils.drop_stage(session, temp_stage) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) @pytest.mark.parametrize( "partition_by", [ @@ -2615,6 +2775,9 @@ def test_write_copy_into_location_csv(session, partition_by): Utils.drop_stage(session, temp_stage) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="Tests query generation" +) def test_queries(session): df = TestData.column_has_special_char(session) queries = df.queries @@ -2633,14 +2796,14 @@ def test_queries(session): assert post_actions[0].startswith("DROP") +@pytest.mark.localtest def test_df_columns(session): assert session.create_dataframe([1], schema=["a"]).columns == ["A"] temp_table = Utils.random_name_for_temp_object(TempObjectType.TABLE) - Utils.create_table( - session, temp_table, '"a b" int, "a""b" int, "a" int, a int', is_temporary=True - ) - session.sql(f"insert into {temp_table} values (1, 2, 3, 4)").collect() + 
session.create_dataframe( + [[1, 2, 3, 4]], schema=['"a b"', '"a""b"', '"a"', "a"] + ).write.save_as_table(temp_table, table_type="temporary") try: df = session.table(temp_table) assert df.columns == ['"a b"', '"a""b"', '"a"', "A"] @@ -2665,6 +2828,11 @@ def test_df_columns(session): Utils.drop_table(session, temp_table) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) @pytest.mark.parametrize( "column_list", [["jan", "feb", "mar", "apr"], [col("jan"), col("feb"), col("mar"), col("apr")]], @@ -2692,6 +2860,11 @@ def test_unpivot(session, column_list): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_create_dataframe_string_length(session): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) df = session.create_dataframe(["ab", "abc", "abcd"], schema=["a"]) @@ -2707,6 +2880,10 @@ def test_create_dataframe_string_length(session): ) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Batch insert does not apply", +) @pytest.mark.skipif(IS_IN_STORED_PROC_LOCALFS, reason="need resources") def test_create_table_twice_no_error(session): from snowflake.snowpark._internal.analyzer import analyzer @@ -2732,6 +2909,9 @@ def test_create_table_twice_no_error(session): ) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="Relies on internal API" +) def check_df_with_query_id_result_scan(session, df): query_id = df._execute_and_get_query_id() df_from_result_scan = session.sql(result_scan_statement(query_id)) @@ -2739,6 +2919,9 @@ def check_df_with_query_id_result_scan(session, df): Utils.check_answer(df, df_from_result_scan) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="Relies on internal API" +) @pytest.mark.skipif(IS_IN_STORED_PROC_LOCALFS, reason="need resources") def test_query_id_result_scan(session): from snowflake.snowpark._internal.analyzer import analyzer @@ -2762,6 +2945,10 @@ def test_query_id_result_scan(session): @pytest.mark.skipif(not is_pandas_available, reason="pandas is required") +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Statement parameters are not supported in Local Testing", +) def test_call_with_statement_params(session): statement_params_wrong_date_format = { "DATE_INPUT_FORMAT": "YYYY-MM-DD", @@ -2934,12 +3121,14 @@ def test_call_with_statement_params(session): Utils.drop_stage(session, temp_stage) +@pytest.mark.localtest def test_limit_offset(session): df = session.create_dataframe([[1, 2, 3], [4, 5, 6]], schema=["a", "b", "c"]) assert df.limit(1).collect() == [Row(A=1, B=2, C=3)] assert df.limit(1, offset=1).collect() == [Row(A=4, B=5, C=6)] +@pytest.mark.localtest def test_df_join_how_on_overwrite(session): df1 = session.create_dataframe([[1, 1, "1"], [2, 2, "3"]]).to_df( ["int", "int2", "str"] @@ -2955,6 +3144,7 @@ def test_df_join_how_on_overwrite(session): Utils.check_answer(df, [Row(1, 1, "1"), Row(2, 3, "5")]) +@pytest.mark.localtest def test_create_dataframe_special_char_column_name(session): df1 = session.create_dataframe( [[1, 2, 3], [1, 2, 3]], schema=["a b", '"abc"', "@%!^@&#"] @@ -2975,6 +3165,7 @@ def test_create_dataframe_with_tuple_schema(session): Utils.check_answer(df, [Row(20000101, 1, "x"), Row(20000101, 2, "y")]) +@pytest.mark.localtest def test_df_join_suffix(session): df1 = session.create_dataframe([[1, 1, "1"], [2, 2, "3"]]).to_df(["a", "b", 
"c"]) df2 = session.create_dataframe([[1, 1, "1"], [2, 3, "5"]]).to_df(["a", "b", "c"]) @@ -3037,6 +3228,7 @@ def test_df_join_suffix(session): assert df14.columns == ['"a_l"', '"a_r"'] +@pytest.mark.localtest def test_df_cross_join_suffix(session): df1 = session.create_dataframe([[1, 1, "1"]]).to_df(["a", "b", "c"]) df2 = session.create_dataframe([[1, 1, "1"]]).to_df(["a", "b", "c"]) @@ -3080,6 +3272,7 @@ def test_df_cross_join_suffix(session): assert df14.columns == ['"a_l"', '"a_r"'] +@pytest.mark.localtest def test_suffix_negative(session): df1 = session.create_dataframe([[1, 1, "1"]]).to_df(["a", "b", "c"]) df2 = session.create_dataframe([[1, 1, "1"]]).to_df(["a", "b", "c"]) @@ -3095,6 +3288,10 @@ def test_suffix_negative(session): df1.join(df2, lsuffix="suffix", rsuffix="suffix") +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Relies on generating SQL queries", +) def test_create_or_replace_view_with_multiple_queries(session): df = session.read.option("purge", False).schema(user_schema).csv(test_file_on_stage) with pytest.raises( @@ -3104,6 +3301,11 @@ def test_create_or_replace_view_with_multiple_queries(session): df.create_or_replace_view("temp") +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_create_or_replace_dynamic_table_with_multiple_queries(session): df = session.read.option("purge", False).schema(user_schema).csv(test_file_on_stage) with pytest.raises( @@ -3115,6 +3317,7 @@ def test_create_or_replace_dynamic_table_with_multiple_queries(session): ) +@pytest.mark.localtest def test_nested_joins(session): df1 = session.create_dataframe([[1, 2], [4, 5]], schema=["a", "b"]) df2 = session.create_dataframe([[1, 3], [4, 6]], schema=["c", "d"]) diff --git a/tests/integ/test_datatypes.py b/tests/integ/test_datatypes.py new file mode 100644 index 00000000000..86d91f5084b --- /dev/null +++ b/tests/integ/test_datatypes.py @@ -0,0 +1,425 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +from decimal import Decimal + +import pytest + +from snowflake.snowpark import DataFrame, Row +from snowflake.snowpark.functions import lit +from snowflake.snowpark.types import ( + BooleanType, + DecimalType, + DoubleType, + FloatType, + LongType, + StringType, + StructField, + StructType, +) +from tests.utils import Utils + + +@pytest.mark.localtest +def test_basic_filter(session): + df: DataFrame = session.create_dataframe( + [ + [1, 2, "abc"], + [3, 4, "def"], + [6, 5, "ghi"], + [8, 7, "jkl"], + [100, 200, "mno"], + [400, 300, "pqr"], + ], + schema=["a", "b", "c"], + ).select("a", "b", "c") + assert repr(df.schema) == repr( + StructType( + [ + StructField("A", LongType(), nullable=False), + StructField("B", LongType(), nullable=False), + StructField("C", StringType(16777216), nullable=False), + ] + ) + ) + + +@pytest.mark.localtest +def test_plus_basic(session): + df = session.create_dataframe( + [[1, 1.1, 2.2, 3.3]], + schema=StructType( + [ + StructField("a", LongType(), nullable=False), + StructField("b", DecimalType(3, 1), nullable=False), + StructField("c", DoubleType(), nullable=False), + StructField("d", DecimalType(4, 2), nullable=False), + ] + ), + ) + + df = df.select( + (df["a"] + 1).as_("new_a"), + (df["b"] + df["d"]).as_("new_b"), + (df["c"] + 3).as_("new_c"), + ) + assert repr(df.schema) == repr( + StructType( + [ + StructField("NEW_A", LongType(), nullable=False), + StructField("NEW_B", DecimalType(5, 2), nullable=False), + StructField("NEW_C", DoubleType(), nullable=False), + ] + ) + ) + + +@pytest.mark.localtest +def test_minus_basic(session): + df = session.create_dataframe( + [[1, 1.1, 2.2, 3.3]], + schema=StructType( + [ + StructField("a", LongType(), nullable=False), + StructField("b", DecimalType(3, 1), nullable=False), + StructField("c", DoubleType(), nullable=False), + StructField("d", DecimalType(4, 2), nullable=False), + ] + ), + ) + + df = df.select( + (df["a"] - 1).as_("new_a"), + (df["b"] - df["d"]).as_("new_b"), + (df["c"] - 3).as_("new_c"), + ) + assert repr(df.schema) == repr( + StructType( + [ + StructField("NEW_A", LongType(), nullable=False), + StructField("NEW_B", DecimalType(5, 2), nullable=False), + StructField("NEW_C", DoubleType(), nullable=False), + ] + ) + ) + + +@pytest.mark.localtest +def test_multiple_basic(session): + df = session.create_dataframe( + [[1, 1.1, 2.2, 3.3]], + schema=StructType( + [ + StructField("a", LongType(), nullable=False), + StructField("b", DecimalType(3, 1), nullable=False), + StructField("c", FloatType(), nullable=False), + StructField("d", DecimalType(4, 2), nullable=False), + ] + ), + ) + + df = df.select( + (df["a"] * 1).as_("new_a"), + (df["b"] * df["d"]).as_("new_b"), + (df["c"] * 3).as_("new_c"), + ) + assert repr(df.schema) == repr( + StructType( + [ + StructField("NEW_A", LongType(), nullable=False), + StructField("NEW_B", DecimalType(7, 3), nullable=False), + StructField("NEW_C", DoubleType(), nullable=False), + ] + ) + ) + + +@pytest.mark.localtest +def test_divide_basic(session): + df = session.create_dataframe( + [[1, 1.1, 2.2, 3.3]], + schema=StructType( + [ + StructField("a", LongType(), nullable=False), + StructField("b", DecimalType(3, 1), nullable=False), + StructField("c", DoubleType(), nullable=False), + StructField("d", DecimalType(4, 2), nullable=False), + ] + ), + ) + + df = df.select( + (df["a"] / 1).as_("new_a"), + (df["b"] / df["d"]).as_("new_b"), + (df["c"] / 3).as_("new_c"), + ) + assert repr(df.schema) == repr( + StructType( + [ + StructField("NEW_A", DecimalType(38, 6), 
nullable=False), + StructField("NEW_B", DecimalType(11, 7), nullable=False), + StructField("NEW_C", DoubleType(), nullable=False), + ] + ) + ) + Utils.check_answer( + df, [Row(Decimal("1.0"), Decimal("0.3333333"), 0.7333333333333334)] + ) + + +@pytest.mark.localtest +def test_div_decimal_double(session): + df = session.create_dataframe( + [[11.0, 13.0]], + schema=StructType( + [StructField("a", DoubleType()), StructField("b", DoubleType())] + ), + ) + df = df.select([df["a"] / df["b"]]) + Utils.check_answer(df, [Row(0.8461538461538461)]) + + df2 = session.create_dataframe([[11, 13]], schema=["a", "b"]) + df2 = df2.select([df2["a"] / df2["b"]]) + Utils.check_answer(df2, [Row(Decimal("0.846154"))]) + + +@pytest.mark.localtest +def test_modulo_basic(session): + df = session.create_dataframe( + [[1, 1.1, 2.2, 3.3]], + schema=StructType( + [ + StructField("a", LongType(), nullable=False), + StructField("b", DecimalType(3, 1), nullable=False), + StructField("c", DoubleType(), nullable=False), + StructField("d", DecimalType(4, 2), nullable=False), + ] + ), + ) + + df = df.select( + (df["a"] % 1).as_("new_a"), + (df["b"] % df["d"]).as_("new_b"), + (df["c"] % 3).as_("new_c"), + ) + assert repr(df.schema) == repr( + StructType( + [ + StructField("NEW_A", LongType(), nullable=False), + StructField("NEW_B", DecimalType(4, 2), nullable=False), + StructField("NEW_C", DoubleType(), nullable=False), + ] + ) + ) + + +@pytest.mark.localtest +def test_binary_ops_bool(session): + df = session.create_dataframe( + [[1, 1.1]], + schema=StructType( + [ + StructField("a", LongType(), nullable=False), + StructField("b", DecimalType(3, 1), nullable=False), + ] + ), + ) + df1 = df.select( + df["a"] > df["b"], + df["a"] >= df["b"], + df["a"] == df["b"], + df["a"] != df["b"], + df["a"] < df["b"], + df["a"] <= df["b"], + ) + assert repr(df1.schema) == repr( + StructType( + [ + StructField('"(""A"" > ""B"")"', BooleanType(), nullable=True), + StructField('"(""A"" >= ""B"")"', BooleanType(), nullable=True), + StructField('"(""A"" = ""B"")"', BooleanType(), nullable=True), + StructField('"(""A"" != ""B"")"', BooleanType(), nullable=True), + StructField('"(""A"" < ""B"")"', BooleanType(), nullable=True), + StructField('"(""A"" <= ""B"")"', BooleanType(), nullable=True), + ] + ) + ) + + df2 = df.select( + (df["a"] > df["b"]) & (df["a"] >= df["b"]), + (df["a"] > df["b"]) | (df["a"] >= df["b"]), + ) + assert repr(df2.schema) == repr( + StructType( + [ + StructField( + '"((""A"" > ""B"") AND (""A"" >= ""B""))"', + BooleanType(), + nullable=True, + ), + StructField( + '"((""A"" > ""B"") OR (""A"" >= ""B""))"', + BooleanType(), + nullable=True, + ), + ] + ) + ) + + +@pytest.mark.localtest +def test_unary_ops_bool(session): + df = session.create_dataframe( + [[1, 1.1]], + schema=StructType( + [ + StructField("a", LongType(), nullable=False), + StructField("b", DecimalType(3, 1), nullable=False), + ] + ), + ) + + df = df.select( + df["a"].is_null(), + df["a"].is_not_null(), + df["a"].equal_nan(), + ~df["a"].is_null(), + ) + assert repr(df.schema) == repr( + StructType( + [ + StructField('"""A"" IS NULL"', BooleanType(), nullable=True), + StructField('"""A"" IS NOT NULL"', BooleanType(), nullable=True), + StructField('"""A"" = \'NAN\'"', BooleanType(), nullable=True), + StructField('"NOT ""A"" IS NULL"', BooleanType(), nullable=True), + ] + ) + ) + + +@pytest.mark.localtest +def test_literal(session): + df = session.create_dataframe( + [[1]], schema=StructType([StructField("a", LongType(), nullable=False)]) + ) + df = 
df.select(lit("lit_value")) + assert repr(df.schema) == repr( + StructType([StructField("\"'LIT_VALUE'\"", StringType(9), nullable=False)]) + ) + + +@pytest.mark.localtest +def test_string_op_bool(session): + df = session.create_dataframe([["value"]], schema=["a"]) + df = df.select(df["a"].like("v%"), df["a"].regexp("v")) + assert repr(df.schema) == repr( + StructType( + [ + StructField('"""A"" LIKE \'V%\'"', BooleanType(), nullable=True), + StructField('"""A"" REGEXP \'V\'"', BooleanType(), nullable=True), + ] + ) + ) + + +@pytest.mark.localtest +def test_filter(session): + df = session.create_dataframe( + [[1, 1.1, 2.2, 3.3]], + schema=StructType( + [ + StructField("a", LongType(), nullable=False), + StructField("b", DecimalType(3, 1), nullable=False), + StructField("c", DoubleType(), nullable=False), + StructField("d", DecimalType(4, 2), nullable=False), + ] + ), + ) + + df1 = df.filter(df["a"] > 1).filter(df["b"] > 1) + assert repr(df1.schema) == repr(df.schema) + + +@pytest.mark.localtest +def test_sort(session): + df = session.create_dataframe( + [[1, 1.1, 2.2, 3.3]], + schema=StructType( + [ + StructField("a", LongType(), nullable=False), + StructField("b", DecimalType(3, 1), nullable=False), + StructField("c", DoubleType(), nullable=False), + StructField("d", DecimalType(4, 2), nullable=False), + ] + ), + ) + + df1 = df.sort(df["a"].asc_nulls_last()) + assert repr(df1.schema) == repr(df.schema) + + +@pytest.mark.localtest +def test_limit(session): + df = session.create_dataframe( + [[1, 1.1, 2.2, 3.3]], + schema=StructType( + [ + StructField("a", LongType(), nullable=False), + StructField("b", DecimalType(3, 1), nullable=False), + StructField("c", DoubleType(), nullable=False), + StructField("d", DecimalType(4, 2), nullable=False), + ] + ), + ) + + df1 = df.limit(5) + assert repr(df1.schema) == repr(df.schema) + + +@pytest.mark.localtest +def test_chain_filter_sort_limit(session): + df = session.create_dataframe( + [[1, 1.1, 2.2, 3.3]], + schema=StructType( + [ + StructField("a", LongType(), nullable=False), + StructField("b", DecimalType(3, 1), nullable=False), + StructField("c", DoubleType(), nullable=False), + StructField("d", DecimalType(4, 2), nullable=False), + ] + ), + ) + df1 = ( + df.filter(df["a"] > 1) + .filter(df["b"] > 1) + .sort(df["a"].asc_nulls_last()) + .limit(5) + ) + assert repr(df1.schema) == repr(df.schema) + + +@pytest.mark.localtest +def test_join_basic(session): + df = session.create_dataframe( + [[1, 1.1, 2.2, 3.3]], + schema=["a", "b", "c"], + ) + df2 = session.create_dataframe( + [[1, 1.1, 2.2, 3.3]], + schema=["a", "b", "c"], + ) + df3 = df.join(df2, lsuffix="_l", rsuffix="_r") + assert repr(df3.schema) == repr( + StructType( + [ + StructField("A_L", LongType(), nullable=False), + StructField("B_L", DoubleType(), nullable=False), + StructField("C_L", DoubleType(), nullable=False), + StructField("_4_L", DoubleType(), nullable=False), + StructField("A_R", LongType(), nullable=False), + StructField("B_R", DoubleType(), nullable=False), + StructField("C_R", DoubleType(), nullable=False), + StructField("_4_R", DoubleType(), nullable=False), + ] + ) + ) diff --git a/tests/integ/test_df_aggregate.py b/tests/integ/test_df_aggregate.py index 551aa64fbcf..c1228570479 100644 --- a/tests/integ/test_df_aggregate.py +++ b/tests/integ/test_df_aggregate.py @@ -10,6 +10,72 @@ from tests.utils import Utils +@pytest.mark.localtest +def test_df_agg_tuples_basic_without_std(session): + df = session.create_dataframe([[1, 4], [1, 4], [2, 5], [2, 6]]).to_df( + ["first", 
"second"] + ) + + # Aggregations on 'first' column + res = df.agg([("first", "min")]).collect() + Utils.assert_rows(res, [Row(1)]) + + res = df.agg([("first", "count")]).collect() + Utils.assert_rows(res, [Row(4)]) + + res = df.agg([("first", "max")]).collect() + Utils.assert_rows(res, [Row(2)]) + + res = df.agg([("first", "avg")]).collect() + Utils.assert_rows(res, [Row(1.5)]) + + # combine those together + res = df.agg( + [ + ("first", "min"), + ("first", "count"), + ("first", "max"), + ("first", "avg"), + ] + ).collect() + Utils.assert_rows(res, [Row(1, 4, 2, 1.5)]) + + # Aggregations on 'second' column + res = df.agg([("second", "min")]).collect() + Utils.assert_rows(res, [Row(4)]) + + res = df.agg([("second", "count")]).collect() + Utils.assert_rows(res, [Row(4)]) + + res = df.agg([("second", "max")]).collect() + Utils.assert_rows(res, [Row(6)]) + + res = df.agg([("second", "avg")]).collect() + Utils.assert_rows(res, [Row(4.75)]) + + # combine those together + res = df.agg( + [ + ("second", "min"), + ("second", "count"), + ("second", "max"), + ("second", "avg"), + ] + ).collect() + Utils.assert_rows(res, [Row(4, 4, 6, 4.75)]) + + # Get aggregations for both columns + res = df.agg( + [ + ("first", "min"), + ("second", "count"), + ("first", "max"), + ("second", "avg"), + ] + ).collect() + Utils.assert_rows(res, [Row(1, 4, 2, 4.75)]) + + def test_df_agg_tuples_basic(session): df = session.create_dataframe([[1, 4], [1, 4], [2, 5], [2, 6]]).to_df( ["first", "second"] @@ -84,6 +150,7 @@ def test_df_agg_tuples_basic(session): Utils.assert_rows(res, [Row(1, 4, 2, 4.75, 0.577349980514419)]) +@pytest.mark.localtest def test_df_agg_tuples_avg_basic(session): """Test for making sure all avg word-variations work as expected""" @@ -115,6 +182,7 @@ def test_df_agg_tuples_std_basic(session): Utils.assert_rows(res, [Row(0.577349980514419)]) +@pytest.mark.localtest def test_df_agg_tuples_count_basic(session): """Test for making sure all count variations work as expected""" @@ -129,6 +197,7 @@ def test_df_agg_tuples_count_basic(session): Utils.assert_rows(res, [Row(4)]) +@pytest.mark.localtest def test_df_group_by_invalid_input(session): """Test for check invalid input for group_by function""" @@ -149,6 +218,7 @@ def test_df_group_by_invalid_input(session): ) +@pytest.mark.localtest def test_df_agg_tuples_sum_basic(session): """Test for making sure sum works as expected""" @@ -175,6 +245,7 @@ def test_df_agg_tuples_sum_basic(session): Utils.assert_rows(res, [Row(1, 8), Row(2, 11)]) +@pytest.mark.localtest def test_df_agg_dict_arg(session): """Test for making sure dict when passed to agg() works as expected""" @@ -216,6 +287,7 @@ def test_df_agg_dict_arg(session): ) +@pytest.mark.localtest def test_df_agg_invalid_args_in_list(session): """Test for making sure when a list passed to agg() produces correct errors.""" @@ -271,6 +343,7 @@ def test_df_agg_invalid_args_in_list(session): ) +@pytest.mark.localtest def test_df_agg_empty_args(session): """Test for making sure dict when passed to agg() works as expected""" @@ -281,6 +354,7 @@ def test_df_agg_empty_args(session): Utils.assert_rows(df.agg({}).collect(), [Row(1, 4)]) +@pytest.mark.localtest def test_df_agg_varargs_tuple_list(session): df = session.create_dataframe([[1, 4], [1, 4], [2, 5], [2, 6]]).to_df( ["first", "second"] @@ -294,6 +368,7 @@ def test_df_agg_varargs_tuple_list(session): Utils.check_answer(df.agg(["first", "count"], ("second", "sum")), [Row(4, 19)]) +@pytest.mark.localtest @pytest.mark.parametrize( "col1,col2,alias1,alias2", [ diff 
--git a/tests/integ/test_df_sort.py b/tests/integ/test_df_sort.py index 79fd83debc3..6d4fad59aa7 100644 --- a/tests/integ/test_df_sort.py +++ b/tests/integ/test_df_sort.py @@ -8,6 +8,7 @@ from snowflake.snowpark import Column +@pytest.mark.localtest def test_sort_different_inputs(session): df = session.create_dataframe( [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3), (3, 1), (3, 2), (3, 3)] @@ -48,6 +49,7 @@ def test_sort_different_inputs(session): ) +@pytest.mark.localtest def test_sort_invalid_inputs(session): df = session.create_dataframe( [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3), (3, 1), (3, 2), (3, 3)] diff --git a/tests/integ/test_df_to_pandas.py b/tests/integ/test_df_to_pandas.py index b21a20e20d5..efbee64ea11 100644 --- a/tests/integ/test_df_to_pandas.py +++ b/tests/integ/test_df_to_pandas.py @@ -26,7 +26,12 @@ from snowflake.snowpark.types import DecimalType, IntegerType from tests.utils import IS_IN_STORED_PROC, Utils +pytestmark = pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", raises=NotImplementedError +) + +@pytest.mark.localtest def test_to_pandas_new_df_from_range(session): # Single column snowpark_df = session.range(3, 8) @@ -52,8 +57,9 @@ def test_to_pandas_new_df_from_range(session): assert all(pandas_df["OTHER"][i] == i + 3 for i in range(5)) +@pytest.mark.localtest @pytest.mark.parametrize("to_pandas_api", ["to_pandas", "to_pandas_batches"]) -def test_to_pandas_cast_integer(session, to_pandas_api): +def test_to_pandas_cast_integer(session, to_pandas_api, local_testing_mode): snowpark_df = session.create_dataframe( [["1", "1" * 20], ["2", "2" * 20]], schema=["a", "b"] ).select( @@ -93,16 +99,21 @@ def test_to_pandas_cast_integer(session, to_pandas_api): if to_pandas_api == "to_pandas" else next(timestamp_snowpark_df.to_pandas_batches()) ) - # Starting from pyarrow 13, pyarrow no longer coerces non-nanosecond to nanosecond for pandas >=2.0 - # https://arrow.apache.org/release/13.0.0.html and https://github.com/apache/arrow/issues/33321 - pyarrow_major_version = int(pa.__version__.split(".")[0]) - pandas_major_version = int(pd.__version__.split(".")[0]) - expected_dtype = ( - "datetime64[s]" - if pyarrow_major_version >= 13 and pandas_major_version >= 2 - else "datetime64[ns]" - ) - assert str(timestamp_pandas_df.dtypes[0]) == expected_dtype + + if not local_testing_mode: + # Starting from pyarrow 13, pyarrow no longer coerces non-nanosecond to nanosecond for pandas >=2.0 + # https://arrow.apache.org/release/13.0.0.html and https://github.com/apache/arrow/issues/33321 + pyarrow_major_version = int(pa.__version__.split(".")[0]) + pandas_major_version = int(pd.__version__.split(".")[0]) + expected_dtype = ( + "datetime64[s]" + if pyarrow_major_version >= 13 and pandas_major_version >= 2 + else "datetime64[ns]" + ) + assert str(timestamp_pandas_df.dtypes[0]) == expected_dtype + else: + # TODO: mock the non-nanosecond unit pyarrow+pandas behavior in local test + assert str(timestamp_pandas_df.dtypes[0]) == "datetime64[ns]" def test_to_pandas_precision_for_number_38_0(session): @@ -166,14 +177,18 @@ def check_fetch_data_exception(query: str) -> None: @pytest.mark.skipif( IS_IN_STORED_PROC, reason="SNOW-507565: Need localaws for large result" ) -def test_to_pandas_batches(session): +@pytest.mark.localtest +def test_to_pandas_batches(session, local_testing_mode): df = session.range(100000).cache_result() iterator = df.to_pandas_batches() assert isinstance(iterator, Iterator) entire_pandas_df = df.to_pandas() pandas_df_list = 
list(df.to_pandas_batches()) - assert len(pandas_df_list) > 1 + if not local_testing_mode: + # in live session, large data result will be split into multiple chunks by snowflake + # local test does not split the data result chunk/is not intended for large data result chunk + assert len(pandas_df_list) > 1 assert_frame_equal(pd.concat(pandas_df_list, ignore_index=True), entire_pandas_df) for df_batch in df.to_pandas_batches(): diff --git a/tests/integ/test_function.py b/tests/integ/test_function.py index 8648b95f5dc..f693bd677ea 100644 --- a/tests/integ/test_function.py +++ b/tests/integ/test_function.py @@ -154,6 +154,7 @@ from tests.utils import TestData, Utils +@pytest.mark.localtest def test_order(session): null_data1 = TestData.null_data1(session) assert null_data1.sort(asc(null_data1["A"])).collect() == [ @@ -200,6 +201,11 @@ def test_order(session): ] +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_current_date_and_time(session): df1 = session.sql("select current_date(), current_time(), current_timestamp()") df2 = session.create_dataframe([1]).select( @@ -208,6 +214,11 @@ def test_current_date_and_time(session): assert len(df1.union(df2).collect()) == 1 +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) @pytest.mark.parametrize("col_a", ["a", col("a")]) def test_regexp_replace(session, col_a): df = session.create_dataframe( @@ -229,6 +240,11 @@ def test_regexp_replace(session, col_a): assert res[0][0] == "lastname, firstname middlename" +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_regexp_extract(session): df = session.createDataFrame([["id_20_30", 10], ["id_40_50", 30]], ["id", "age"]) res = df.select(regexp_extract("id", r"(\d+)", 1).alias("RES")).collect() @@ -237,6 +253,11 @@ def test_regexp_extract(session): assert res[0]["RES"] == "30" and res[1]["RES"] == "50" +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) @pytest.mark.parametrize( "col_a, col_b, col_c", [("a", "b", "c"), (col("a"), col("b"), col("c"))] ) @@ -246,6 +267,11 @@ def test_concat(session, col_a, col_b, col_c): assert res[0][0] == "123" +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) @pytest.mark.parametrize( "col_a, col_b, col_c", [("a", "b", "c"), (col("a"), col("b"), col("c"))] ) @@ -255,6 +281,7 @@ def test_concat_ws(session, col_a, col_b, col_c): assert res[0][0] == "1,2,3" +@pytest.mark.localtest @pytest.mark.parametrize("col_a", ["a", col("a")]) def test_to_char(session, col_a): df = session.create_dataframe([[1]], schema=["a"]) @@ -262,12 +289,18 @@ def test_to_char(session, col_a): assert res[0][0] == "1" +@pytest.mark.localtest def test_date_to_char(session): df = session.create_dataframe([[datetime.date(2021, 12, 21)]], schema=["a"]) res = df.select(to_char(col("a"), "mm-dd-yyyy")).collect() assert res[0][0] == "12-21-2021" +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_format_number(session): # Create a dataframe with a column of numbers data = [ @@ -284,6 +317,11 @@ def test_format_number(session): assert res[2].VALUE_FORMATTED == "1.41" +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, 
+ strict=True, +) @pytest.mark.parametrize("col_a, col_b", [("a", "b"), (col("a"), col("b"))]) def test_months_between(session, col_a, col_b): df = session.create_dataframe( @@ -293,6 +331,7 @@ def test_months_between(session, col_a, col_b): assert res[0][0] == 1.0 +@pytest.mark.localtest @pytest.mark.parametrize("col_a", ["a", col("a")]) def test_cast(session, col_a): df = session.create_dataframe([["2018-01-01"]], schema=["a"]) @@ -301,6 +340,7 @@ def test_cast(session, col_a): assert cast_res[0][0] == try_cast_res[0][0] == datetime.date(2018, 1, 1) +@pytest.mark.localtest @pytest.mark.parametrize("number_word", ["decimal", "number", "numeric"]) def test_cast_decimal(session, number_word): df = session.create_dataframe([[5.2354]], schema=["a"]) @@ -309,18 +349,21 @@ def test_cast_decimal(session, number_word): ) +@pytest.mark.localtest def test_cast_map_type(session): df = session.create_dataframe([['{"key": "1"}']], schema=["a"]) result = df.select(cast(parse_json(df["a"]), "object")).collect() assert json.loads(result[0][0]) == {"key": "1"} +@pytest.mark.localtest def test_cast_array_type(session): df = session.create_dataframe([["[1,2,3]"]], schema=["a"]) result = df.select(cast(parse_json(df["a"]), "array")).collect() assert json.loads(result[0][0]) == [1, 2, 3] +@pytest.mark.localtest def test_startswith(session): Utils.check_answer( TestData.string4(session).select(col("a").startswith(lit("a"))), @@ -329,6 +372,11 @@ def test_startswith(session): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_struct(session): df = session.createDataFrame([("Bob", 80), ("Alice", None)], ["name", "age"]) # case sensitive @@ -360,6 +408,11 @@ def test_struct(session): assert re.sub(r"\s", "", res[1].STRUCT) == '{"A":null,"B":"Alice"}' +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_strtok_to_array(session): # Create a dataframe data = [("a.b.c")] @@ -368,6 +421,11 @@ def test_strtok_to_array(session): assert res[0] == "a" and res[1] == "b" and res[2] == "c" +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) @pytest.mark.parametrize( "col_a, col_b, col_c", [("a", "b", "c"), (col("a"), col("b"), col("c"))] ) @@ -377,6 +435,11 @@ def test_greatest(session, col_a, col_b, col_c): assert res[0][0] == 3 +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) @pytest.mark.parametrize( "col_a, col_b, col_c", [("a", "b", "c"), (col("a"), col("b"), col("c"))] ) @@ -386,6 +449,11 @@ def test_least(session, col_a, col_b, col_c): assert res[0][0] == 1 +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) @pytest.mark.parametrize("col_a, col_b", [("a", "b"), (col("a"), col("b"))]) def test_hash(session, col_a, col_b): df = session.create_dataframe([[10, "10"]], schema=["a", "b"]) @@ -396,6 +464,11 @@ def test_hash(session, col_a, col_b): assert res[0][1] == 3622494980440108984 +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_basic_numerical_operations_negative(session): # sqrt df = session.sql("select 4").to_df("a") @@ -446,6 +519,11 @@ def test_basic_numerical_operations_negative(session): assert "'CEIL' expected Column or str, got: " in str(ex_info) 
+@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_basic_string_operations(session): # Substring df = session.sql("select 'a not that long string'").to_df("a") @@ -531,6 +609,11 @@ def test_basic_string_operations(session): assert "'REVERSE' expected Column or str, got: " in str(ex_info) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_substring_index(session): """test calling substring_index with delimiter as string""" df = session.create_dataframe([[0, "a.b.c.d"], [1, ""], [2, None]], ["id", "s"]) @@ -551,6 +634,11 @@ def test_substring_index(session): assert respos[2][0] is None +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_substring_index_col(session): """test calling substring_index with delimiter as column""" df = session.create_dataframe([["a,b,c,d", ","]], ["s", "delimiter"]) @@ -562,6 +650,11 @@ def test_substring_index_col(session): assert reslit[0][0] == "b,c,d" +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_bitshiftright(session): # Create a dataframe data = [(65504), (1), (4)] @@ -570,6 +663,11 @@ def test_bitshiftright(session): assert res[0][0] == 32752 and res[1][0] == 0 and res[2][0] == 2 +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_bround(session): # Create a dataframe data = [(decimal.Decimal(1.235)), decimal.Decimal(3.5)] @@ -580,6 +678,7 @@ def test_bround(session): assert str(res[0][0]) == "1" and str(res[1][0]) == "4" +# Enable for local testing after addressing SNOW-850268 def test_count_distinct(session): df = session.create_dataframe( [["a", 1, 1], ["b", 2, 2], ["c", 1, None], ["d", 5, None]] @@ -603,6 +702,9 @@ def test_count_distinct(session): assert df.select(count_distinct(df["*"])).collect() == [Row(2)] +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="Testing builtin" +) def test_builtin_avg_from_range(session): """Tests the builtin functionality, using avg().""" avg = builtin("avg") @@ -643,6 +745,9 @@ def test_builtin_avg_from_range(session): assert res == expected +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", reason="Testing builtin" +) def test_call_builtin_avg_from_range(session): """Tests the builtin functionality, using avg().""" df = session.range(1, 10, 2).select(call_builtin("avg", col("id"))) @@ -685,6 +790,11 @@ def test_call_builtin_avg_from_range(session): assert res == expected +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_is_negative(session): td = TestData.string1(session) @@ -817,6 +927,7 @@ def test_is_negative(session): assert "Invalid argument types for function 'IS_TIMESTAMP_TZ'" in str(ex_info) +@pytest.mark.localtest def test_parse_json(session): assert TestData.null_json1(session).select(parse_json(col("v"))).collect() == [ Row('{\n "a": null\n}'), @@ -832,6 +943,11 @@ def test_parse_json(session): ] +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_as_negative(session): td = TestData.string1(session) @@ -1010,6 +1126,7 @@ def test_as_negative(session): ) +@pytest.mark.localtest def 
test_to_date_to_array_to_variant_to_object(session): df = ( session.create_dataframe([["2013-05-17", 1, 3.14, '{"a":1}']]) @@ -1035,6 +1152,7 @@ def test_to_date_to_array_to_variant_to_object(session): assert df1.schema.fields[3].datatype == MapType(StringType(), StringType()) +@pytest.mark.localtest def test_to_binary(session): res = ( TestData.test_data1(session) @@ -1151,6 +1269,7 @@ def test_array_sort(session): Utils.check_answer(res, [Row(SORTED_A="[\n null,\n 20,\n 10,\n 0\n]")]) +@pytest.mark.localtest def test_coalesce(session): # Taken from FunctionSuite.scala Utils.check_answer( @@ -1169,6 +1288,11 @@ def test_coalesce(session): assert "'COALESCE' expected Column or str, got: " in str(ex_info) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_uniform(session): df = session.sql("select 1").to_df("a") @@ -1203,6 +1327,11 @@ def test_uniform(session): assert decimal_int == decimal_decimal +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_uniform_negative(session): df = session.sql("select 1").to_df("a") with pytest.raises(SnowparkSQLException) as ex_info: @@ -1210,6 +1339,7 @@ def test_uniform_negative(session): assert "Numeric value 'z' is not recognized" in str(ex_info) +@pytest.mark.localtest def test_negate_and_not_negative(session): with pytest.raises(TypeError) as ex_info: TestData.null_data2(session).select(negate(["A", "B", "C"])) @@ -1220,6 +1350,11 @@ def test_negate_and_not_negative(session): assert "'NOT_' expected Column or str, got: " in str(ex_info) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_random_negative(session): df = session.sql("select 1") with pytest.raises(SnowparkSQLException) as ex_info: @@ -1227,6 +1362,11 @@ def test_random_negative(session): assert "Numeric value 'abc' is not recognized" in str(ex_info) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_check_functions_negative(session): df = session.sql("select 1").to_df("a") @@ -1255,6 +1395,11 @@ def test_parse_functions_negative(session): assert "'PARSE_XML' expected Column or str, got: " in str(ex_info) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_json_functions_negative(session): df = session.sql("select 1").to_df("a") @@ -1274,6 +1419,11 @@ def test_json_functions_negative(session): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_to_filetype_negative(session): df = session.sql("select 1").to_df("a") # to_json @@ -1287,6 +1437,11 @@ def test_to_filetype_negative(session): assert "'TO_XML' expected Column or str, got: " in str(ex_info) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_array_distinct(session): df = session.sql("select 1 A") df = df.withColumn( @@ -1299,6 +1454,11 @@ def test_array_distinct(session): assert array[0] == 1 and array[1] == 2 and array[2] == 3 +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_array_negative(session): df = session.sql("select 1").to_df("a") @@ -1390,6 +1550,11 @@ def test_array_negative(session): ) 
+@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_object_negative(session): df = session.sql("select 1").to_df("a") @@ -1423,6 +1588,11 @@ def test_object_negative(session): assert "'OBJECT_PICK' expected Column or str, got: " in str(ex_info) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_date_operations_negative(session): df = session.sql("select 1").to_df("a") @@ -1435,6 +1605,7 @@ def test_date_operations_negative(session): assert "'DATEADD' expected Column or str, got: " in str(ex_info) +# TODO: enable for local testing after addressing SNOW-850263 def test_date_add_date_sub(session): df = session.createDataFrame( [("2019-01-23"), ("2019-06-24"), ("2019-09-20")], ["date"] @@ -1450,12 +1621,22 @@ def test_date_add_date_sub(session): assert res[2].DATE == datetime.date(2019, 9, 16) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_daydiff(session): df = session.createDataFrame([("2015-04-08", "2015-05-10")], ["d1", "d2"]) res = df.select(daydiff(to_date(df.d2), to_date(df.d1)).alias("diff")).collect() assert res[0].DIFF == 32 +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_get_negative(session): df = session.sql("select 1").to_df("a") @@ -1464,6 +1645,11 @@ def test_get_negative(session): assert "'GET' expected Column, int or str, got: " in str(ex_info) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_array_generate_range(session): df = session.createDataFrame([(-2, 2)], ["C1", "C2"]) Utils.check_answer( @@ -1494,6 +1680,11 @@ def test_array_generate_range(session): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_sequence_negative(session): df = session.sql("select 1").to_df("a") @@ -1502,6 +1693,11 @@ def test_sequence_negative(session): assert "'SEQUENCE' expected Column or str, got: " in str(ex_info) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_sequence(session): df = session.createDataFrame([(-2, 2)], ["C1", "C2"]) Utils.check_answer( @@ -1553,6 +1749,11 @@ def test_sequence(session): ) +@pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) def test_array_unique_agg(session): def _result_str2lst(result): col_str = result[0][0] diff --git a/tests/integ/test_pandas_to_df.py b/tests/integ/test_pandas_to_df.py index f1c429eaaf4..fbb80f28fa5 100644 --- a/tests/integ/test_pandas_to_df.py +++ b/tests/integ/test_pandas_to_df.py @@ -26,6 +26,10 @@ from snowflake.snowpark.exceptions import SnowparkPandasException from tests.utils import Utils +pytestmark = pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", raises=NotImplementedError +) + @pytest.fixture(scope="module") def tmp_table_basic(session): @@ -260,7 +264,8 @@ def test_write_temp_table_no_breaking_change(session, table_type, caplog): warning_dict.clear() -def test_create_dataframe_from_pandas(session): +@pytest.mark.localtest +def test_create_dataframe_from_pandas(session, local_testing_mode): pd = PandasDF( [ (1, 4.5, "t1", True), @@ -319,7 +324,8 @@ def 
test_write_pandas_temp_table_and_irregular_column_names(session, table_type) Utils.drop_table(session, table_name) -def test_write_pandas_with_timestamps(session): +@pytest.mark.localtest +def test_write_pandas_with_timestamps(session, local_testing_mode): datetime_with_tz = datetime( 1997, 6, 3, 14, 21, 32, 00, tzinfo=timezone(timedelta(hours=+10)) ) @@ -330,14 +336,23 @@ def test_write_pandas_with_timestamps(session): ], columns=["tm_tz", "tm_ntz"], ) - table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) - try: - session.write_pandas(pd, table_name, auto_create_table=True, table_type="temp") - data = session.sql(f'select * from "{table_name}"').collect() - assert data[0]["tm_tz"] is not None - assert data[0]["tm_ntz"] is not None - finally: - Utils.drop_table(session, table_name) + + if local_testing_mode: + sp_df = session.create_dataframe(pd) + data = sp_df.select("*").collect() + assert data[0]["tm_tz"] == datetime(1997, 6, 3, 4, 21, 32, 00) + assert data[0]["tm_ntz"] == 865347692000000 + else: + table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) + try: + session.write_pandas( + pd, table_name, auto_create_table=True, table_type="temp" + ) + data = session.sql(f'select * from "{table_name}"').collect() + assert data[0]["tm_tz"] is not None + assert data[0]["tm_ntz"] is not None + finally: + Utils.drop_table(session, table_name) def test_auto_create_table_similar_column_names(session): diff --git a/tests/integ/test_query_history.py b/tests/integ/test_query_history.py index eb04ae5a370..6989a173c9a 100644 --- a/tests/integ/test_query_history.py +++ b/tests/integ/test_query_history.py @@ -7,6 +7,12 @@ from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD from tests.utils import IS_IN_STORED_PROC +pytestmark = pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=NotImplementedError, + strict=True, +) + def test_query_history(session): with session.query_history() as query_listener: @@ -60,12 +66,14 @@ def test_query_history_multiple_actions(session): assert query_listener.queries[2].sql_text == "select 2" +@pytest.mark.skipif(condition="config.getvalue('local_testing_mode')") def test_query_history_no_actions(session): with session.query_history() as query_listener: pass # no action assert len(query_listener.queries) == 0 +@pytest.mark.skipif(condition="config.getvalue('local_testing_mode')") @pytest.mark.skipif( IS_IN_STORED_PROC, reason="alter session is not supported in owner's right stored proc", diff --git a/tests/integ/test_scoped_temp_objects.py b/tests/integ/test_scoped_temp_objects.py index f53792056c4..8322484ea36 100644 --- a/tests/integ/test_scoped_temp_objects.py +++ b/tests/integ/test_scoped_temp_objects.py @@ -10,6 +10,11 @@ random_name_for_temp_object, ) +pytestmark = pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="local test does not support", +) + def test_create_scoped_temp_objects_syntax(session): snowpark_temp_table_name = random_name_for_temp_object(TempObjectType.TABLE) diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py index 39604679daf..c0113a151b8 100644 --- a/tests/integ/test_session.py +++ b/tests/integ/test_session.py @@ -25,6 +25,11 @@ ) from tests.utils import IS_IN_STORED_PROC, IS_IN_STORED_PROC_LOCALFS, TestFiles, Utils +pytestmark = pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Tests are creating sessions from connection parameters", +) + @pytest.mark.skipif(IS_IN_STORED_PROC, 
reason="Cannot create session in SP") def test_runtime_config(db_parameters): @@ -191,7 +196,7 @@ def test_close_session_in_sp(session): @pytest.mark.skipif(IS_IN_STORED_PROC_LOCALFS, reason="need resources") -def test_list_files_in_stage(session, resources_path): +def test_list_files_in_stage(session, resources_path, local_testing_mode): stage_name = Utils.random_stage_name() special_name = f'"{stage_name}/aa"' single_quoted_name = f"'{stage_name}/b\\' b'" @@ -241,9 +246,17 @@ def test_list_files_in_stage(session, resources_path): assert os.path.basename(test_files.test_file_csv) in files6 Utils.create_stage(session, single_quoted_name, is_temporary=False) - Utils.upload_to_stage( - session, single_quoted_name, test_files.test_file_csv, compress=False - ) + if not local_testing_mode: + # TODO: session.file.put has a bug that it can not add '@' to single quoted name stage when normalizing path + session._conn.upload_file( + stage_location=single_quoted_name, + path=test_files.test_file_csv, + compress_data=False, + ) + else: + Utils.upload_to_stage( + session, single_quoted_name, test_files.test_file_csv, compress=False + ) files7 = session._list_files_in_stage(single_quoted_name) assert len(files7) == 1 assert os.path.basename(test_files.test_file_csv) in files7 diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index 55df8432f80..6cc5307498e 100644 --- a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -36,6 +36,12 @@ from collections.abc import Iterable +pytestmark = pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="local test does not support sql generation", +) + + @pytest.fixture(scope="module", autouse=True) def skip(pytestconfig): if pytestconfig.getoption("disable_sql_simplifier"): diff --git a/tests/integ/test_stored_procedure.py b/tests/integ/test_stored_procedure.py index 1e60c9812d9..680a0306509 100644 --- a/tests/integ/test_stored_procedure.py +++ b/tests/integ/test_stored_procedure.py @@ -26,6 +26,7 @@ from snowflake.snowpark.dataframe import DataFrame from snowflake.snowpark.exceptions import ( SnowparkInvalidObjectNameException, + SnowparkSessionException, SnowparkSQLException, ) from snowflake.snowpark.functions import ( @@ -54,16 +55,24 @@ Utils, ) -pytestmark = pytest.mark.udf +pytestmark = [ + pytest.mark.udf, + pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=(NotImplementedError, SnowparkSessionException), + strict=True, + ), +] tmp_stage_name = Utils.random_stage_name() @pytest.fixture(scope="module", autouse=True) -def setup(session, resources_path): +def setup(session, resources_path, local_testing_mode): test_files = TestFiles(resources_path) - Utils.create_stage(session, tmp_stage_name, is_temporary=True) - session.add_packages("snowflake-snowpark-python") + if not local_testing_mode: + Utils.create_stage(session, tmp_stage_name, is_temporary=True) + session.add_packages("snowflake-snowpark-python") Utils.upload_to_stage( session, tmp_stage_name, test_files.test_sp_py_file, compress=False ) @@ -563,6 +572,7 @@ def return_datetime(_: Session) -> datetime.datetime: assert return_datetime_sp() == dt +@pytest.mark.skipif(condition="config.getvalue('local_testing_mode')") @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") def test_permanent_sp(session, db_parameters): stage_name = Utils.random_stage_name() diff --git a/tests/integ/test_table_function.py b/tests/integ/test_table_function.py index 
27f6d0387ca..385a1200ee6 100644 --- a/tests/integ/test_table_function.py +++ b/tests/integ/test_table_function.py @@ -14,6 +14,10 @@ ) from tests.utils import Utils +pytestmark = pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", raises=NotImplementedError +) + def test_query_args(session): split_to_table = table_function("split_to_table") diff --git a/tests/integ/test_telemetry.py b/tests/integ/test_telemetry.py index dfe7a782e12..e3d2b8b4336 100644 --- a/tests/integ/test_telemetry.py +++ b/tests/integ/test_telemetry.py @@ -47,6 +47,11 @@ else: from collections.abc import Iterable +pytestmark = pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="Telemetry is not public API and currently not supported in local testing", +) + class TelemetryDataTracker: def __init__(self, session: Session) -> None: @@ -777,8 +782,12 @@ def test_dataframe_stat_functions_api_calls(session): ] -def test_dataframe_na_functions_api_calls(session): - df1 = TestData.double3(session) +@pytest.mark.skipif( + condition="config.getvalue('local_testing_mode')", + reason="api calls is not the same in local testing", +) +def test_dataframe_na_functions_api_calls(session, local_testing_mode): + df1 = TestData.double3(session, local_testing_mode) assert df1._plan.api_calls == [{"name": "Session.sql"}] drop = df1.na.drop(thresh=1, subset=["a"]) @@ -792,7 +801,7 @@ def test_dataframe_na_functions_api_calls(session): # check to make sure that the original DF is unchanged assert df1._plan.api_calls == [{"name": "Session.sql"}] - df2 = TestData.null_data3(session) + df2 = TestData.null_data3(session, local_testing_mode) assert df2._plan.api_calls == [{"name": "Session.sql"}] fill = df2.na.fill({"flo": 12.3, "int": 11, "boo": False, "str": "f"}) diff --git a/tests/integ/test_udf.py b/tests/integ/test_udf.py index ee4d0d3b7cc..8ee77fd5c13 100644 --- a/tests/integ/test_udf.py +++ b/tests/integ/test_udf.py @@ -89,7 +89,12 @@ Utils, ) -pytestmark = pytest.mark.udf +pytestmark = [ + pytest.mark.udf, + pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", raises=NotImplementedError + ), +] tmp_stage_name = Utils.random_stage_name() diff --git a/tests/integ/test_udtf.py b/tests/integ/test_udtf.py index f12fb2e3587..c596cdffb6d 100644 --- a/tests/integ/test_udtf.py +++ b/tests/integ/test_udtf.py @@ -10,7 +10,7 @@ from snowflake.snowpark import Row, Table from snowflake.snowpark._internal.utils import TempObjectType -from snowflake.snowpark.exceptions import SnowparkSQLException +from snowflake.snowpark.exceptions import SnowparkSessionException, SnowparkSQLException from snowflake.snowpark.functions import lit, udtf from snowflake.snowpark.session import Session from snowflake.snowpark.types import ( @@ -43,7 +43,14 @@ except ImportError: is_pandas_available = False -pytestmark = pytest.mark.udf +pytestmark = [ + pytest.mark.udf, + pytest.mark.xfail( + condition="config.getvalue('local_testing_mode')", + raises=(NotImplementedError, SnowparkSessionException), + strict=True, + ), +] @pytest.fixture(scope="module") diff --git a/tests/mock/readme.md b/tests/mock/readme.md new file mode 100644 index 00000000000..c4306a8a91d --- /dev/null +++ b/tests/mock/readme.md @@ -0,0 +1,194 @@ +# Snowpark Local Testing + +Snowpark Local Testing allows the usage of creating DataFrames and +performing operations without a connection to Snowflake. 
+
+## Quickstart
+
+Instead of using `Session.SessionBuilder` to create a session, in local testing you
+instantiate `Session` directly with a `MockServerConnection` object.
+
+```python
+from snowflake.snowpark.session import Session
+from snowflake.snowpark.mock.connection import MockServerConnection
+session = Session(MockServerConnection())
+df = session.create_dataframe(
+    [
+        [1, 2, "welcome"],
+        [3, 4, "to"],
+        [5, 6, "the"],
+        [7, 8, "private"],
+        [9, 0, "preview"]
+    ],
+    schema=['a', 'b', 'c']
+)
+df.select('c').show()
+```
+
+## General Usage Documentation
+
+### Installation
+
+The Snowpark local testing framework can be installed from the [development branch](https://github.com/snowflakedb/snowpark-python/tree/dev/local-testing) of the public repository. Install it with pip:
+
+```bash
+pip install "snowflake-snowpark-python[pandas]@git+https://github.com/snowflakedb/snowpark-python@dev/local-testing"
+```
+
+#### Install from specific commit
+
+The installation instructions above install Snowpark from the development branch, meaning that any new commits pushed to that branch will be picked up the next time you install the package. You can also pin the installation to a specific commit using the syntax below:
+
+```bash
+pip install "snowflake-snowpark-python[pandas]@git+https://github.com/snowflakedb/snowpark-python@dev/local-testing@29385014c755fe20122fac536358a8ebdeb761bc"
+```
+
+If you use this approach, you will need to update the pinned commit manually to pick up features added to the branch in later commits.
+
+
+## Patching Built-In Functions
+
+Not all of the built-in functions under `snowflake.snowpark.functions` have been re-implemented for the local testing framework.
+If you use a function that is not yet supported, use the `@patch` decorator from `snowflake.snowpark.mock.functions` to create a patch.
+
+When defining and implementing the patched function, its signature (the parameter list) must align with that of the built-in function, and the local testing framework
+passes parameters to the patched function according to the following rules:
+- for a parameter of type `ColumnOrName` in the signature of the built-in function, a `ColumnEmulator` is passed to the patched function; `ColumnEmulator` is a pandas.Series-like object containing the column data.
+- for a parameter of type `LiteralType` in the signature of the built-in function, the literal value is passed to the patched function.
+- for a parameter of any other non-string Python type in the signature of the built-in function, the raw value is passed to the patched function.
+
+As for the return type, the patched function is expected to return a `ColumnEmulator`, in correspondence with the `Column` return type of the built-in function.
+
+For example, the built-in function `to_timestamp()` can be patched as follows:
+
+```python
+import datetime
+from snowflake.snowpark.session import Session
+from snowflake.snowpark.mock.connection import MockServerConnection
+from snowflake.snowpark.mock.functions import patch
+from snowflake.snowpark.functions import to_timestamp
+from snowflake.snowpark.mock.snowflake_data_type import ColumnEmulator
+from snowflake.snowpark.types import TimestampType
+
+@patch(to_timestamp)
+def mock_to_timestamp(column: ColumnEmulator, format=None) -> ColumnEmulator:
+    ret_column = ColumnEmulator(data=[datetime.datetime.strptime(row, '%Y-%m-%dT%H:%M:%S%z') for row in column])
+    ret_column.sf_type = TimestampType()
+    return ret_column
+
+session = Session(MockServerConnection())
+df = session.create_dataframe(
+    [
+        ["2023-06-20T12:00:00+00:00"],
+    ],
+    schema=["a"],
+)
+df = df.select(to_timestamp("a"))
+df.collect()
+```
+
+Let's go through the above patch conceptually: the first line of the function body iterates over the rows of the given column, using `strptime()` to do the timestamp conversion. The next line sets the type of the column, and then the column is returned. Note that the implementation of the patch does not necessarily need to re-implement the built-in; the patch can return static values *if* that fulfills the test case.
+
+Let's do another example, this time for `parse_json()`. Similarly, the implementation iterates over the given column and uses a Python function to transform the data, in this case `json.loads()`.
+
+```python
+import json
+from snowflake.snowpark.mock.functions import patch
+from snowflake.snowpark.functions import parse_json
+from snowflake.snowpark.types import VariantType
+from snowflake.snowpark.mock.snowflake_data_type import ColumnEmulator
+
+@patch(parse_json)
+def mock_parse_json(column: ColumnEmulator):
+    ret_column = ColumnEmulator(data=[json.loads(row) for row in column])
+    ret_column.sf_type = VariantType()
+    return ret_column
+```
+
+## SQL and Table Operations
+
+`Session.sql` is not supported in local testing due to the complexity of parsing SQL text. Please use the DataFrame APIs where possible; otherwise, we suggest mocking this method with Python's built-in `unittest.mock` module:
+
+#### Example
+
+```python
+from functools import partial
+from unittest import mock
+
+from snowflake.snowpark import Row
+
+def test_something(pytestconfig, session):
+
+    def mock_sql(session, sql_string):  # patch for SQL operations
+        if sql_string == "select 1,2,3":
+            return session.create_dataframe([[1, 2, 3]])
+        elif sql_string == "select * from shared_table":
+            return session.table("shared_table")
+        elif sql_string == "drop table shared_table":
+            session.table("shared_table").drop_table()
+            return session.create_dataframe([])
+        else:
+            raise RuntimeError(f"Unexpected query execution: {sql_string}")
+
+    if pytestconfig.getoption('--snowflake-session'):
+        # start() applies the patch for the remainder of the test
+        mock.patch.object(session, 'sql', wraps=partial(mock_sql, session)).start()
+
+    assert session.sql("select 1,2,3").collect() == [Row(1, 2, 3)]
+    assert session.sql("select * from shared_table").collect() == [Row(1, 2), Row(3, 4)]
+
+    session.table("shared_table").delete()
+    assert session.sql("select * from shared_table").collect() == []
+```
+
+Currently, the only supported operations on `snowflake.snowpark.Table` are:
+
+- `DataFrame.save_as_table()`
+- `Session.table()`
+- `Table.drop_table()`
+
+All tables created by `DataFrame.save_as_table` are saved as temporary tables in memory and can be retrieved with `Session.table`. You can use the supported `DataFrame` operations on a `Table` as usual, but `Table`-specific APIs other than `Table.drop_table` are not supported yet; a minimal sketch of these operations follows.
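+
+The sketch below is only an illustration against a `MockServerConnection`-backed session: the table name `example_table` is made up, and `save_as_table` is assumed to be reached through the `DataFrameWriter` (`df.write.save_as_table(...)`).
+
+```python
+from snowflake.snowpark import Row
+from snowflake.snowpark.session import Session
+from snowflake.snowpark.mock.connection import MockServerConnection
+
+session = Session(MockServerConnection())
+
+# Save a DataFrame as a table; in local testing it lives in memory as a temporary table.
+df = session.create_dataframe([[1, "a"], [2, "b"]], schema=["id", "val"])
+df.write.save_as_table("example_table")
+
+# Retrieve it as a Table and use the supported DataFrame operations on it.
+t = session.table("example_table")
+assert t.filter(t["id"] == 1).collect() == [Row(1, "a")]
+
+# Drop it when the test is done.
+t.drop_table()
+```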
+
+
+## Supported APIs
+
+### Session
+
+- `.create_dataframe()`
+
+### DataFrame
+
+- `.select()`
+- `.sort()`
+- `.filter()` and `.where()`
+- `.agg()`
+- `.join()`
+- `.union()`
+- `.take()`
+- `.first()`
+- `.with_column()`
+
+### Scalar Functions
+
+- `min()`
+- `max()`
+- `sum()`
+- `count()`
+- `contains()`
+- `abs()`
+
+> If a scalar function is not in the list above, you can [patch it](#patching-built-in-functions).
+
+## Limitations
+
+Apart from the unsupported APIs (those not listed in the section above), here is the list of known limitations that will be addressed in the future.
+Note that there may also be limitations we are not yet aware of; if you run into one, please feel free to reach out and share feedback.
+
+- SQL Simplifier must be enabled on the Session (this is the default in the latest Snowpark Python version)
+- Altering warehouses, schemas, and other session properties is not currently supported
+- Stored Procedures and UDFs are not supported
+- Window Functions are not supported
+- `DataFrame.join` does not support joining with another DataFrame that has the same column names
+- `DataFrame.with_column` does not support replacing an existing column of the same name
+- Raw SQL strings are in general not supported, e.g., `DataFrame.filter` does not accept a raw SQL string expression
+- There could be gaps between the local testing framework and Snowflake in the following areas:
+  - Results of calculations, both values and data types
+  - Column names
+  - Error behavior
diff --git a/tests/mock/test_agg.py b/tests/mock/test_agg.py
new file mode 100644
index 00000000000..98eaf3b6ab0
--- /dev/null
+++ b/tests/mock/test_agg.py
@@ -0,0 +1,280 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+# + +import math + +import pytest + +import snowflake.snowpark.mock.functions as snowpark_mock_functions +from snowflake.snowpark import DataFrame, Row, Session +from snowflake.snowpark.functions import ( + approx_percentile_combine, + array_agg, + avg, + col, + count, + covar_pop, + covar_samp, + function, + grouping, + listagg, + lit, + max, + mean, + median, + min, + stddev, + stddev_pop, + sum, +) +from snowflake.snowpark.mock.connection import MockServerConnection +from snowflake.snowpark.mock.snowflake_data_type import ColumnEmulator, ColumnType +from snowflake.snowpark.types import DoubleType +from tests.utils import Utils + +session = Session(MockServerConnection()) + + +@pytest.mark.localtest +def test_agg_single_column(): + origin_df: DataFrame = session.create_dataframe( + [[1], [8], [6], [3], [100], [400], [None]], schema=["v"] + ) + assert origin_df.select(sum("v")).collect() == [Row(518)] + assert origin_df.select(max("v")).collect() == [Row(400)] + assert origin_df.select(min("v")).collect() == [Row(1)] + assert origin_df.select(median("v")).collect() == [Row(7.0)] + assert origin_df.select(avg("v")).collect() == [ + Row(86.33333) + ] # snowflake keeps scale of 5 + assert origin_df.select(mean("v")).collect() == [Row(86.33333)] + assert origin_df.select(count("v")).collect() == [Row(6)] + assert origin_df.count() == 7 + + +@pytest.mark.localtest +def test_agg_double_column(): + origin_df: DataFrame = session.create_dataframe( + [ + [10.0, 11.0], + [20.0, 22.0], + [25.0, 0.0], + [30.0, 35.0], + [999.0, None], + [None, 1234.0], + [math.nan, None], + [math.nan, 1.0], + ], + schema=["m", "n"], + ) + assert origin_df.select(covar_pop("m", "n")).collect() == [Row(38.75)] + assert origin_df.select(sum(col("m") + col("n"))).collect() == [Row(153.0)] + assert origin_df.select(sum(col("m") - col("n"))).collect() == [Row(17.0)] + + +@pytest.mark.localtest +def test_agg_function_multiple_parameters(): + origin_df: DataFrame = session.create_dataframe( + ["k1", "k1", "k3", "k4", [None]], schema=["v"] + ) + assert origin_df.select(listagg("v", delimiter='~!1,."')).collect() == [ + Row('k1~!1,."k1~!1,."k3~!1,."k4') + ] + + assert origin_df.select( + listagg("v", delimiter='~!1,."', is_distinct=True) + ).collect() == [Row('k1~!1,."k3~!1,."k4')] + + +@pytest.mark.localtest +def test_register_new_methods(): + origin_df: DataFrame = session.create_dataframe( + [ + [10.0, 11.0], + [20.0, 22.0], + [25.0, 0.0], + [30.0, 35.0], + ], + schema=["m", "n"], + ) + + # approx_percentile + with pytest.raises(NotImplementedError): + origin_df.select(function("approx_percentile")(col("m"), lit(0.5))).collect() + # snowflake.snowpark.functions.approx_percentile is being updated to use lit + # so `function` won't be needed here. 
+ + @snowpark_mock_functions.patch("approx_percentile") + def mock_approx_percentile( + column: ColumnEmulator, percentile: float + ) -> ColumnEmulator: + assert column.tolist() == [10.0, 20.0, 25.0, 30.0] + assert percentile == 0.5 + return ColumnEmulator(data=123, sf_type=ColumnType(DoubleType(), False)) + + assert origin_df.select( + function("approx_percentile")(col("m"), lit(0.5)) + ).collect() == [Row(123)] + + # covar_samp + with pytest.raises(NotImplementedError): + origin_df.select(covar_samp(col("m"), "n")).collect() + + @snowpark_mock_functions.patch(covar_samp) + def mock_covar_samp( + column1: ColumnEmulator, + column2: ColumnEmulator, + ): + assert column1.tolist() == [10.0, 20.0, 25.0, 30.0] + assert column2.tolist() == [11.0, 22.0, 0.0, 35.0] + return ColumnEmulator(data=123, sf_type=ColumnType(DoubleType(), False)) + + assert origin_df.select(covar_samp(col("m"), "n")).collect() == [Row(123)] + + # stddev + with pytest.raises(NotImplementedError): + origin_df.select(stddev("n")).collect() + + @snowpark_mock_functions.patch(stddev) + def mock_stddev(column: ColumnEmulator): + assert column.tolist() == [11.0, 22.0, 0.0, 35.0] + return ColumnEmulator(data=123, sf_type=ColumnType(DoubleType(), False)) + + assert origin_df.select(stddev("n")).collect() == [Row(123)] + + # array_agg + with pytest.raises(NotImplementedError): + origin_df.select(array_agg("n", False)).collect() + + # instead of kwargs, positional argument also works + @snowpark_mock_functions.patch(array_agg) + def mock_mock_array_agg(column: ColumnEmulator, is_distinct): + assert is_distinct is True + assert column.tolist() == [11.0, 22.0, 0.0, 35.0] + return ColumnEmulator(data=123, sf_type=ColumnType(DoubleType(), False)) + + assert origin_df.select(array_agg("n", True)).collect() == [Row(123)] + + # grouping + with pytest.raises(NotImplementedError): + origin_df.select(grouping("m", col("n"))).collect() + + @snowpark_mock_functions.patch(grouping) + def mock_mock_grouping(*columns): + assert len(columns) == 2 + assert columns[0].tolist() == [10.0, 20.0, 25.0, 30.0] + assert columns[1].tolist() == [11.0, 22.0, 0.0, 35.0] + return ColumnEmulator(data=123, sf_type=ColumnType(DoubleType(), False)) + + assert origin_df.select(grouping("m", col("n"))).collect() == [Row(123)] + + +@pytest.mark.localtest +def test_group_by(): + origin_df: DataFrame = session.create_dataframe( + [ + ["a", "ddd", 11.0], + ["a", "ddd", 22.0], + ["b", "ccc", 9.0], + ["b", "ccc", 9.0], + ["b", "aaa", 35.0], + ["b", "aaa", 99.0], + ], + schema=["m", "n", "q"], + ) + + Utils.check_answer( + origin_df.group_by("m").agg(sum("q")).collect(), + [ + Row("a", 33.0), + Row("b", 152.0), + ], + ) + + Utils.check_answer( + origin_df.group_by("n").agg(min("q")).collect(), + [ + Row("ddd", 11.0), + Row("ccc", 9.0), + Row("aaa", 35.0), + ], + ) + + with pytest.raises(NotImplementedError): + origin_df.group_by("n", "m").agg(approx_percentile_combine("q")).collect() + + @snowpark_mock_functions.patch(approx_percentile_combine) + def mock_approx_percentile_combine(state: ColumnEmulator): + if state.iat[0] == 11: + return ColumnEmulator(data=-1.0, sf_type=ColumnType(DoubleType(), False)) + if state.iat[0] == 9: + return ColumnEmulator(data=0.0, sf_type=ColumnType(DoubleType(), False)) + if state.iat[0] == 35: + return ColumnEmulator(data=1.0, sf_type=ColumnType(DoubleType(), False)) + raise RuntimeError("This error shall never be raised") + + Utils.check_answer( + origin_df.group_by("n").agg(approx_percentile_combine("q")).collect(), + [ + Row("ddd", -1.0), 
+ Row("ccc", 0.0), + Row("aaa", 1.0), + ], + ) + + Utils.check_answer( + origin_df.group_by("m", "n").agg(mean("q")).collect(), + [ + Row("a", "ddd", 16.5), + Row("b", "ccc", 9.0), + Row("b", "aaa", 67.0), + ], + ) + + +@pytest.mark.localtest +def test_agg(): + origin_df: DataFrame = session.create_dataframe( + [ + [15.0, 11.0], + [2.0, 22.0], + [29.0, 9.0], + [30.0, 9.0], + [4.0, 35.0], + [54.0, 99.0], + ], + schema=["m", "n"], + ) + + Utils.check_answer(origin_df.agg(sum("m")).collect(), Row(134.0)) + + Utils.check_answer(origin_df.agg(min("m"), max("n")).collect(), Row(2.0, 99.0)) + + Utils.check_answer( + origin_df.agg({"m": "count", "n": "sum"}).collect(), Row(6.0, 185.0) + ) + + snowpark_mock_functions._unregister_func_implementation("stddev") + snowpark_mock_functions._unregister_func_implementation("stddev_pop") + + with pytest.raises(NotImplementedError): + origin_df.select(stddev("n"), stddev_pop("m")).collect() + + @snowpark_mock_functions.patch("stddev") + def mock_stddev(column: ColumnEmulator): + assert column.tolist() == [11.0, 22.0, 9.0, 9.0, 35.0, 99.0] + return ColumnEmulator(data=123, sf_type=ColumnType(DoubleType(), False)) + + # stddev_pop is not implemented yet + with pytest.raises(NotImplementedError): + origin_df.select(stddev("n"), stddev_pop("m")).collect() + + @snowpark_mock_functions.patch("stddev_pop") + def mock_stddev_pop(column: ColumnEmulator): + assert column.tolist() == [15.0, 2.0, 29.0, 30.0, 4.0, 54.0] + return ColumnEmulator(data=456, sf_type=ColumnType(DoubleType(), False)) + + Utils.check_answer( + origin_df.select(stddev("n"), stddev_pop("m")).collect(), Row(123.0, 456.0) + ) diff --git a/tests/mock/test_column_names.py b/tests/mock/test_column_names.py new file mode 100644 index 00000000000..97b2ca06d8d --- /dev/null +++ b/tests/mock/test_column_names.py @@ -0,0 +1,72 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +import pytest + +from snowflake.snowpark import Session +from snowflake.snowpark.functions import avg, col +from snowflake.snowpark.mock.connection import MockServerConnection + +session = Session(MockServerConnection()) + + +@pytest.mark.localtest +def test_str_column_name_no_quotes(): + df = session.create_dataframe([1, 2], schema=["a"]) + assert str(df.select(col("a")).collect()) == "[Row(A=1), Row(A=2)]" + assert str(df.select(avg(col("a"))).collect()) == '[Row(AVG("A")=1.5)]' + + # column name with quotes + df = session.create_dataframe([1, 2], schema=['"a"']) + assert str(df.select(col('"a"')).collect()) == "[Row(a=1), Row(a=2)]" + assert str(df.select(avg(col('"a"'))).collect()) == '[Row(AVG("A")=1.5)]' + + +@pytest.mark.localtest +def test_show_column_name_with_quotes(): + df = session.create_dataframe([1, 2], schema=["a"]) + assert ( + df.select(col("a"))._show_string() + == """\ +------- +|"A" | +------- +|1 | +|2 | +------- +""" + ) + assert ( + df.select(avg(col("a")))._show_string() + == """\ +---------------- +|"AVG(""A"")" | +---------------- +|1.5 | +---------------- +""" + ) + + # column name with quotes + df = session.create_dataframe([1, 2], schema=['"a"']) + assert ( + df.select(col('"a"'))._show_string() + == """\ +------- +|"a" | +------- +|1 | +|2 | +------- +""" + ) + assert ( + df.select(avg(col('"a"')))._show_string() + == """\ +---------------- +|"AVG(""A"")" | +---------------- +|1.5 | +---------------- +""" + ) diff --git a/tests/mock/test_create_df_from_pandas.py b/tests/mock/test_create_df_from_pandas.py new file mode 100644 index 00000000000..e608326d9b3 --- /dev/null +++ b/tests/mock/test_create_df_from_pandas.py @@ -0,0 +1,347 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import datetime +import decimal +import json +import math + +import pytest +import pytz + +from snowflake.snowpark import Row, Session, Table +from snowflake.snowpark.mock.connection import MockServerConnection +from snowflake.snowpark.types import BooleanType, DoubleType, LongType, StringType + +try: + import pandas as pd +except ImportError: + pytest.skip("pandas is not installed, skipping the tests", allow_module_level=True) + +session = Session(MockServerConnection()) + + +@pytest.mark.localtest +def test_create_from_pandas_basic_pandas_types(): + now_time = datetime.datetime( + year=2023, month=10, day=25, hour=13, minute=46, second=12, microsecond=123 + ) + delta_time = datetime.timedelta(days=1) + pandas_df = pd.DataFrame( + data=[ + ("Name1", 1.2, 1234567890, True, now_time, delta_time), + ("nAme_2", 20, 1, False, now_time - delta_time, delta_time), + ], + columns=[ + "sTr", + "dOublE", + "LoNg", + "booL", + "timestamp", + "TIMEDELTA", # note that in the current snowpark, column name with all upper case is not double quoted + ], + ) + sp_df = session.create_dataframe(data=pandas_df) + assert ( + sp_df.schema[0].name == '"sTr"' + and isinstance(sp_df.schema[0].datatype, StringType) + and sp_df.schema[0].nullable + ) + assert ( + sp_df.schema[1].name == '"dOublE"' + and isinstance(sp_df.schema[1].datatype, DoubleType) + and sp_df.schema[1].nullable + ) + assert ( + sp_df.schema[2].name == '"LoNg"' + and isinstance(sp_df.schema[2].datatype, LongType) + and sp_df.schema[2].nullable + ) + assert ( + sp_df.schema[3].name == '"booL"' + and isinstance(sp_df.schema[3].datatype, BooleanType) + and sp_df.schema[3].nullable + ) + assert ( + sp_df.schema[4].name == '"timestamp"' + and isinstance(sp_df.schema[4].datatype, LongType) + and 
sp_df.schema[4].nullable + ) + assert ( + sp_df.schema[5].name == "TIMEDELTA" + and isinstance(sp_df.schema[5].datatype, LongType) + and sp_df.schema[5].nullable + ) + assert isinstance(sp_df, Table) + assert ( + str(sp_df.schema) + == """\ +StructType([\ +StructField('"sTr"', StringType(16777216), nullable=True), \ +StructField('"dOublE"', DoubleType(), nullable=True), \ +StructField('"LoNg"', LongType(), nullable=True), \ +StructField('"booL"', BooleanType(), nullable=True), \ +StructField('"timestamp"', LongType(), nullable=True), \ +StructField('TIMEDELTA', LongType(), nullable=True)\ +])\ +""" + ) + assert sp_df.select('"sTr"').collect() == [Row("Name1"), Row("nAme_2")] + assert sp_df.select('"dOublE"').collect() == [Row(1.2), Row(20)] + assert sp_df.select('"LoNg"').collect() == [Row(1234567890), Row(1)] + assert sp_df.select('"booL"').collect() == [Row(True), Row(False)] + assert sp_df.select('"timestamp"').collect() == [ + Row(1698241572000123), + Row(1698155172000123), + ] + assert sp_df.select("TIMEDELTA").collect() == [ + Row(86400000000000), + Row(86400000000000), + ] + + pandas_df = pd.DataFrame( + data=[ + float("inf"), + float("-inf"), + ], + columns=["float"], + ) + sp_df = session.create_dataframe(data=pandas_df) + assert ( + sp_df.schema[0].name == '"float"' + and isinstance(sp_df.schema[0].datatype, DoubleType) + and sp_df.schema[0].nullable + ) + + assert sp_df.select('"float"').collect() == [ + Row(float("inf")), + Row(float("-inf")), + ] + + +@pytest.mark.localtest +def test_create_from_pandas_basic_python_types(): + date_data = datetime.date(year=2023, month=10, day=26) + time_data = datetime.time(hour=12, minute=12, second=12) + byte_data = b"bytedata" + dict_data = {"a": 123} + array_data = [1, 2, 3, 4] + decimal_data = decimal.Decimal("1.23") + pandas_df = pd.DataFrame( + { + "A": pd.Series([date_data]), + "B": pd.Series([time_data]), + "C": pd.Series([byte_data]), + "D": pd.Series([dict_data]), + "E": pd.Series([array_data]), + "F": pd.Series([decimal_data]), + } + ) + sp_df = session.create_dataframe(data=pandas_df) + assert ( + str(sp_df.schema) + == """\ +StructType([StructField('A', DateType(), nullable=True), StructField('B', TimeType(), nullable=True), StructField('C', BinaryType(), nullable=True), StructField('D', VariantType(), nullable=True), StructField('E', VariantType(), nullable=True), StructField('F', DecimalType(3, 2), nullable=True)])\ +""" + ) + assert sp_df.select("*").collect() == [ + Row( + date_data, + time_data, + bytearray(byte_data), + json.dumps(dict_data, indent=2), + json.dumps(array_data, indent=2), + decimal_data, + ) + ] + + +@pytest.mark.localtest +def test_create_from_pandas_datetime_types(): + now_time = datetime.datetime( + year=2023, + month=10, + day=25, + hour=13, + minute=46, + second=12, + microsecond=123, + tzinfo=pytz.UTC, + ) + now_time_without_tz = datetime.datetime( + year=2023, month=10, day=25, hour=13, minute=46, second=12, microsecond=123 + ) + delta_time = datetime.timedelta(days=1) + pandas_df = pd.DataFrame( + { + "A": pd.Series([now_time]).dt.tz_localize(None), + "B": pd.Series([delta_time]), + "C": pd.Series([now_time], dtype=pd.DatetimeTZDtype(tz=pytz.UTC)), + } + ) + sp_df = session.create_dataframe(data=pandas_df) + assert sp_df.select("A").collect() == [Row(1698241572000123)] + assert sp_df.select("B").collect() == [Row(86400000000000)] + assert sp_df.select("C").collect() == [Row(now_time_without_tz)] + + pandas_df = pd.DataFrame( + { + "A": pd.Series( + [ + datetime.datetime( + 1997, + 6, + 3, + 14, + 21, 
+ 32, + 00, + tzinfo=datetime.timezone(datetime.timedelta(hours=+10)), + ) + ] + ) + } + ) + sp_df = session.create_dataframe(data=pandas_df) + assert ( + str(sp_df.schema) + == "StructType([StructField('A', TimestampType(), nullable=True)])" + ) + assert sp_df.select("A").collect() == [ + Row(datetime.datetime(1997, 6, 3, 4, 21, 32, 00)) + ] + + +@pytest.mark.localtest +def test_create_from_pandas_extension_types(): + """ + + notes: + pd.SparseDtype is not supported in the live mode due to pyarrow + """ + pandas_df = pd.DataFrame( + { + "A": pd.Series(["a", "b", "c", "a"], dtype=pd.CategoricalDtype()), + } + ) + sp_df = session.create_dataframe(data=pandas_df) + assert sp_df.select("A").collect() == [Row("a"), Row("b"), Row("c"), Row("a")] + + pandas_df = pd.DataFrame( + { + "A": pd.Series([1, 2, 3], dtype=pd.Int8Dtype()), + "B": pd.Series([1, 2, 3], dtype=pd.Int16Dtype()), + "C": pd.Series([1, 2, 3], dtype=pd.Int32Dtype()), + "D": pd.Series([1, 2, 3], dtype=pd.Int64Dtype()), + "E": pd.Series([1, 2, 3], dtype=pd.UInt8Dtype()), + "F": pd.Series([1, 2, 3], dtype=pd.UInt16Dtype()), + "G": pd.Series([1, 2, 3], dtype=pd.UInt32Dtype()), + "H": pd.Series([1, 2, 3], dtype=pd.UInt64Dtype()), + } + ) + sp_df = session.create_dataframe(data=pandas_df) + assert ( + sp_df.select("A").collect() + == sp_df.select("B").collect() + == sp_df.select("C").collect() + == sp_df.select("D").collect() + == sp_df.select("E").collect() + == sp_df.select("F").collect() + == sp_df.select("G").collect() + == sp_df.select("H").collect() + == [Row(1), Row(2), Row(3)] + ) + + pandas_df = pd.DataFrame( + { + "A": pd.Series([1.1, 2.2, 3.3], dtype=pd.Float32Dtype()), + } + ) + sp_df = session.create_dataframe(data=pandas_df) + assert sp_df.select("A").collect() == [Row(1.1), Row(2.2), Row(3.3)] + + pandas_df = pd.DataFrame( + { + "A": pd.Series([1.1, 2.2, 3.3], dtype=pd.Float64Dtype()), + } + ) + sp_df = session.create_dataframe(data=pandas_df) + assert sp_df.select("A").collect() == [Row(1.1), Row(2.2), Row(3.3)] + + pandas_df = pd.DataFrame( + { + "A": pd.Series(["a", "b", "c"], dtype=pd.StringDtype()), + } + ) + sp_df = session.create_dataframe(data=pandas_df) + assert sp_df.select("A").collect() == [Row("a"), Row("b"), Row("c")] + + pandas_df = pd.DataFrame( + { + "A": pd.Series([True, False, True], dtype=pd.BooleanDtype()), + } + ) + sp_df = session.create_dataframe(data=pandas_df) + assert sp_df.select("A").collect() == [Row(True), Row(False), Row(True)] + + pandas_df = pd.DataFrame( + { + "A": pd.Series([pd.Period("2022-01", freq="M")], dtype=pd.PeriodDtype()), + } + ) + sp_df = session.create_dataframe(data=pandas_df) + assert sp_df.select("A").collect() == [Row(624)] + + pandas_df = pd.DataFrame( + { + "A": pd.Series([pd.Interval(left=0, right=5)], dtype=pd.IntervalDtype()), + "B": pd.Series( + [ + pd.Interval( + pd.Timestamp("2017-01-01 00:00:00"), + pd.Timestamp("2018-01-01 00:00:00"), + ) + ], + dtype=pd.IntervalDtype(), + ), + } + ) + sp_df = session.create_dataframe(data=pandas_df) + ret = sp_df.select("*").collect() + assert ( + str(sp_df.schema) + == """\ +StructType([StructField('A', VariantType(), nullable=True), StructField('B', VariantType(), nullable=True)])\ +""" + ) + assert ( + str(ret) + == """\ +[Row(A='{\\n "left": 0,\\n "right": 5\\n}', B='{\\n "left": 1483228800000000,\\n "right": 1514764800000000\\n}')]\ +""" + ) + assert ret == [ + Row( + '{\n "left": 0,\n "right": 5\n}', + '{\n "left": 1483228800000000,\n "right": 1514764800000000\n}', + ) + ] + + +@pytest.mark.localtest +def 
test_na_and_null_data(): + pandas_df = pd.DataFrame( + data={ + "A": pd.Series([1, None, 2, math.nan]), + } + ) + sp_df = session.create_dataframe(data=pandas_df) + assert sp_df.select("A").collect() == [Row(1.0), Row(None), Row(2.0), Row(None)] + + pandas_df = pd.DataFrame( + data={ + "A": pd.Series(["abc", None, "a", ""]), + } + ) + sp_df = session.create_dataframe(data=pandas_df) + assert sp_df.select("A").collect() == [Row("abc"), Row(None), Row("a"), Row("")] diff --git a/tests/mock/test_filter.py b/tests/mock/test_filter.py new file mode 100644 index 00000000000..535eb715019 --- /dev/null +++ b/tests/mock/test_filter.py @@ -0,0 +1,222 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import math + +import pytest + +from snowflake.snowpark import DataFrame, Row, Session +from snowflake.snowpark.functions import col +from snowflake.snowpark.mock.connection import MockServerConnection + +session = Session(MockServerConnection()) + + +@pytest.mark.localtest +def test_basic_filter(): + origin_df: DataFrame = session.create_dataframe( + [ + [1, 2, "abc"], + [3, 4, "def"], + [6, 5, "ghi"], + [8, 7, "jkl"], + [100, 200, "mno"], + [400, 300, "pqr"], + ], + schema=["a", "b", "c"], + ).select("a", "b", "c") + + # equal + assert origin_df.filter(col("a") == 1).collect() == [Row(1, 2, "abc")] + + # not equal + assert origin_df.filter(col("a") != 1).collect() == [ + Row(3, 4, "def"), + Row(6, 5, "ghi"), + Row(8, 7, "jkl"), + Row(100, 200, "mno"), + Row(400, 300, "pqr"), + ] + + # greater + assert origin_df.filter(col("a") > 8).collect() == [ + Row(100, 200, "mno"), + Row(400, 300, "pqr"), + ] + + # greater or equal + assert origin_df.filter(col("a") >= 8).collect() == [ + Row(8, 7, "jkl"), + Row(100, 200, "mno"), + Row(400, 300, "pqr"), + ] + + # less + assert origin_df.filter(col("a") < 8).collect() == [ + Row(1, 2, "abc"), + Row(3, 4, "def"), + Row(6, 5, "ghi"), + ] + + # less or equal + assert origin_df.filter(col("a") <= 8).collect() == [ + Row(1, 2, "abc"), + Row(3, 4, "def"), + Row(6, 5, "ghi"), + Row(8, 7, "jkl"), + ] + + # and expression + assert origin_df.filter((col("a") >= 4) & (col("b") == 300)).collect() == [ + Row(400, 300, "pqr") + ] + + # or expression + assert origin_df.filter((col("a") > 300) | (col("c") == "ghi")).collect() == [ + Row(6, 5, "ghi"), + Row(400, 300, "pqr"), + ] + + assert origin_df.filter(origin_df["a"].between(6, 100)).collect() == [ + Row(6, 5, "ghi"), + Row(8, 7, "jkl"), + Row(100, 200, "mno"), + ] + + # in expression + assert origin_df.filter(col("a").in_([1, 6, 100])).collect() == [ + Row(1, 2, "abc"), + Row(6, 5, "ghi"), + Row(100, 200, "mno"), + ] + + # not expression + assert origin_df.filter(~col("a").in_([1, 6, 100])).collect() == [ + Row(3, 4, "def"), + Row(8, 7, "jkl"), + Row(400, 300, "pqr"), + ] + + +@pytest.mark.localtest +def test_null_nan_filter(): + origin_df: DataFrame = session.create_dataframe( + [ + [float("nan"), 2, "abc"], + [3.0, 4, "def"], + [6.0, 5, "ghi"], + [8.0, 7, None], + [float("nan"), 200, None], + ], + schema=["a", "b", "c"], + ).select("a", "b", "c") + + # is null + res = origin_df.filter(origin_df["c"].is_null()).collect() + assert len(res) == 2 + assert res[0] == Row(8.0, 7, None) + assert math.isnan(res[1][0]) + assert res[1][1] == 200 and res[1][2] is None + + # is not null + res = origin_df.filter(origin_df["c"].is_not_null()).collect() + assert len(res) == 3 + assert math.isnan(res[0][0]) and res[0][1] == 2 and res[0][2] == "abc" + assert res[1] == Row(3.0, 4, "def") and res[2] == 
Row(6.0, 5, "ghi") + + # equal na + res = origin_df.filter(origin_df["a"].equal_nan()).collect() + assert len(res) == 2 + assert math.isnan(res[0][0]) and res[0][1] == 2 and res[0][2] == "abc" + assert math.isnan(res[1][0]) and res[1][1] == 200 and res[1][2] is None + + res = origin_df.filter(origin_df["c"].is_null()).collect() + assert len(res) == 2 + assert res[0] == Row(8.0, 7, None) + assert math.isnan(res[1][0]) + assert res[1][1] == 200 and res[1][2] is None + + # equal_null + origin_df: DataFrame = session.create_dataframe( + [ + [float("nan"), float("nan")], + [float("nan"), 15.0], + [1.0, 1.0], + [1.0, 2.0], + [99.0, 100.0], + [None, None], + ], + schema=["a", "b"], + ).select("a", "b") + res = origin_df.filter(origin_df["a"].equal_null(origin_df["b"])).collect() + assert len(res) == 3 + assert math.isnan(res[0][0]) and math.isnan(res[0][1]) + assert res[1] == Row(1.0, 1.0) + assert res[2] == Row(None, None) + + +@pytest.mark.localtest +def test_chain_filter(): + origin_df: DataFrame = session.create_dataframe( + [ + [1, 2, "abc"], + [3, 4, "def"], + [6, 5, "ghi"], + [8, 7, "jkl"], + [100, 200, "mno"], + [400, 300, "pqr"], + ], + schema=["a", "b", "c"], + ).select("a", "b", "c") + + assert origin_df.filter(col("a") > 8).filter(col("c") == "pqr").collect() == [ + Row(400, 300, "pqr"), + ] + + +@pytest.mark.localtest +def test_like_filter(): + origin_df: DataFrame = session.create_dataframe( + [["test"], ["tttest"], ["tett"], ["ess"], ["es#!s"], ["es#)s"]], schema=["a"] + ).select("a") + + assert origin_df.filter(col("a").like("test")).collect() == [Row("test")] + assert origin_df.filter(col("a").like("%est%")).collect() == [ + Row("test"), + Row("tttest"), + ] + assert origin_df.filter(col("a").like(".e.*")).collect() == [] + assert origin_df.filter(col("a").like("es__s")).collect() == [ + Row("es#!s"), + Row("es#)s"), + ] + assert origin_df.filter(col("a").like("es___s")).collect() == [] + assert origin_df.filter(col("a").like("es%s")).collect() == [ + Row("ess"), + Row("es#!s"), + Row("es#)s"), + ] + assert origin_df.filter(col("a").like("%tt%")).collect() == [ + Row("tttest"), + Row("tett"), + ] + + +@pytest.mark.localtest +def test_regex_filter(): + origin_df: DataFrame = session.create_dataframe( + [["test"], ["tttest"], ["tett"], ["ess"], ["es#%s"]], schema=["a"] + ).select("a") + + assert origin_df.filter(col("a").regexp("test")).collect() == [Row("test")] + assert origin_df.filter(col("a").regexp("^est")).collect() == [] + assert origin_df.filter(col("a").regexp(".e.*")).collect() == [ + Row("test"), + Row("tett"), + ] + assert origin_df.filter(col("a").regexp(".*s")).collect() == [ + Row("ess"), + Row("es#%s"), + ] + assert origin_df.filter(col("a").regexp("...%.")).collect() == [Row("es#%s")] diff --git a/tests/mock/test_functions.py b/tests/mock/test_functions.py new file mode 100644 index 00000000000..283c69943c5 --- /dev/null +++ b/tests/mock/test_functions.py @@ -0,0 +1,286 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +import datetime +import math + +import pytest + +from snowflake.snowpark import DataFrame, Row, Session +from snowflake.snowpark.functions import ( # count,; is_null,; + abs, + asc, + col, + contains, + count, + desc, + is_null, + lit, + max, + min, + to_date, +) +from snowflake.snowpark.mock.connection import MockServerConnection + +session = Session(MockServerConnection()) + + +@pytest.mark.localtest +def test_col(): + origin_df: DataFrame = session.create_dataframe( + [ + [1, "a", True], + [6, "c", False], + [None, None, None], + ], + schema=["m", "n", "o"], + ) + assert origin_df.select(col("m")).collect() == [Row(1), Row(6), Row(None)] + assert origin_df.select(col("n")).collect() == [Row("a"), Row("c"), Row(None)] + assert origin_df.select(col("o")).collect() == [Row(True), Row(False), Row(None)] + + +@pytest.mark.localtest +def test_max(): + origin_df: DataFrame = session.create_dataframe( + [ + ["a", "ddd", 11.0, None, None, True, math.nan], + ["a", "ddd", 22.0, None, None, True, math.nan], + ["b", None, 99.0, None, math.nan, False, math.nan], + ["b", "g", None, None, math.nan, None, math.nan], + ], + schema=["m", "n", "o", "p", "q", "r", "s"], + ) + # JIRA for same name alias support: https://snowflakecomputing.atlassian.net/browse/SNOW-845619 + assert origin_df.select(max("m").as_("a")).collect() == [Row("b")] + assert origin_df.select(max("n").as_("b")).collect() == [Row("g")] + assert origin_df.select(max("o").as_("c")).collect() == [Row(99.0)] + assert origin_df.select(max("p").as_("d")).collect() == [Row(None)] + assert math.isnan(origin_df.select(max("q").as_("e")).collect()[0][0]) + assert origin_df.select(max("r").as_("f")).collect() == [Row(True)] + assert math.isnan(origin_df.select(max("s").as_("g")).collect()[0][0]) + + +@pytest.mark.localtest +def test_min(): + origin_df: DataFrame = session.create_dataframe( + [ + ["a", "ddd", 11.0, None, None, True, math.nan], + ["a", "ddd", 22.0, None, None, True, math.nan], + ["b", None, 99.0, None, math.nan, False, math.nan], + ["b", "g", None, None, math.nan, None, math.nan], + ], + schema=["m", "n", "o", "p", "q", "r", "s"], + ) + + # JIRA for same name alias support: https://snowflakecomputing.atlassian.net/browse/SNOW-845619 + assert origin_df.select(min("m").as_("a")).collect() == [Row("a")] + assert origin_df.select(min("n").as_("b")).collect() == [Row("ddd")] + assert origin_df.select(min("o").as_("c")).collect() == [Row(11.0)] + assert origin_df.select(min("p").as_("d")).collect() == [Row(None)] + assert math.isnan(origin_df.select(min("q").as_("e")).collect()[0][0]) + assert origin_df.select(min("r").as_("f")).collect() == [Row(False)] + assert math.isnan(origin_df.select(min("s").as_("g")).collect()[0][0]) + + +@pytest.mark.localtest +def test_to_date(): + origin_df: DataFrame = session.create_dataframe( + ["2013-05-17", "31536000000000"], + schema=["m"], + ) + + assert origin_df.select(to_date("m")).collect() == [ + Row(datetime.date(2013, 5, 17)), + Row(datetime.date(1971, 1, 1)), + ] + + +@pytest.mark.localtest +def test_contains(): + origin_df: DataFrame = session.create_dataframe( + [ + ["1", "2"], + ["3", "4"], + ["5", "5"], + ], + schema=["m", "n"], + ) + + assert origin_df.select(contains(col("m"), col("n"))).collect() == [ + Row(False), + Row(False), + Row(True), + ] + + origin_df: DataFrame = session.create_dataframe( + [ + ["abcd", "bc"], + ["defgg", "gg"], + ["xx", "zz"], + ], + schema=["m", "n"], + ) + + assert origin_df.select(contains(col("m"), col("n"))).collect() == [ + Row(True), + Row(True), + 
Row(False), + ] + + assert origin_df.select(contains(col("m"), lit("xx"))).collect() == [ + Row(False), + Row(False), + Row(True), + ] + + +@pytest.mark.localtest +def test_abs(): + origin_df: DataFrame = session.create_dataframe( + [ + [1, -4], + [-1, -5], + [2, -6], + ], + schema=["m", "n"], + ) + assert origin_df.select(abs(col("m"))).collect() == [Row(1), Row(1), Row(2)] + + +@pytest.mark.localtest +def test_asc_and_desc(): + origin_df: DataFrame = session.create_dataframe( + [ + [1], + [8], + [6], + [3], + [100], + [400], + ], + schema=["v"], + ) + expected = [Row(1), Row(3), Row(6), Row(8), Row(100), Row(400)] + assert origin_df.sort(asc(col("v"))).collect() == expected + expected.reverse() + assert origin_df.sort(desc(col("v"))).collect() == expected + + +@pytest.mark.localtest +def test_count(): + origin_df: DataFrame = session.create_dataframe( + [ + [1], + [8], + [6], + [3], + [100], + [400], + ], + schema=["v"], + ) + assert origin_df.select(count("v")).collect() == [Row(6)] + + +@pytest.mark.localtest +def test_is_null(): + origin_df: DataFrame = session.create_dataframe( + [ + [float("nan"), 2, "abc"], + [3.0, 4, "def"], + [6.0, 5, "ghi"], + [8.0, 7, None], + [float("nan"), 200, None], + ], + schema=["a", "b", "c"], + ) + assert origin_df.select(is_null("a"), is_null("b"), is_null("c")).collect() == [ + Row(False, False, False), + Row(False, False, False), + Row(False, False, False), + Row(False, False, True), + Row(False, False, True), + ] + + +@pytest.mark.localtest +def test_take_first(): + origin_df: DataFrame = session.create_dataframe( + [ + [float("nan"), 2, "abc"], + [3.0, 4, "def"], + [6.0, 5, "ghi"], + [8.0, 7, None], + [float("nan"), 200, None], + ], + schema=["a", "b", "c"], + ) + assert ( + math.isnan(origin_df.select("a").first()[0]) + and len(origin_df.select("a").first()) == 1 + ) + assert origin_df.select("a", "c").order_by("c", ascending=False).first(2) == [ + Row(6.0, "ghi"), + Row(3.0, "def"), + ] + + res = origin_df.select("a", "b", "c").take(10) + assert len(res) == 5 + assert math.isnan(res[0][0]) and res[0][1] == 2 and res[0][2] == "abc" + assert res[1:4] == [ + Row(3.0, 4, "def"), + Row(6.0, 5, "ghi"), + Row(8.0, 7, None), + ] + assert math.isnan(res[4][0]) and res[4][1] == 200 and res[4][2] is None + + res = origin_df.select("a", "b", "c").take(-1) + assert len(res) == 5 + assert math.isnan(res[0][0]) and res[0][1] == 2 and res[0][2] == "abc" + assert res[1:4] == [ + Row(3.0, 4, "def"), + Row(6.0, 5, "ghi"), + Row(8.0, 7, None), + ] + assert math.isnan(res[4][0]) and res[4][1] == 200 and res[4][2] is None + + +@pytest.mark.localtest +def test_show(): + origin_df: DataFrame = session.create_dataframe( + [ + [float("nan"), 2, "abc"], + [3.0, 4, "def"], + [6.0, 5, "ghi"], + [8.0, 7, None], + [float("nan"), 200, None], + ], + schema=["a", "b", "c"], + ) + + origin_df.show() + assert ( + origin_df._show_string() + == """ +-------------------- +|"A" |"B" |"C" | +-------------------- +|nan |2 |abc | +|3.0 |4 |def | +|6.0 |5 |ghi | +|8.0 |7 |NULL | +|nan |200 |NULL | +--------------------\n""".lstrip() + ) + + assert ( + origin_df._show_string(2, 2) + == """ +---------------- +|"A...|"B...|"C...| +---------------- +|na...|2 |ab...| +|3....|4 |de...| +----------------\n""".lstrip() + ) diff --git a/tests/mock/test_sort.py b/tests/mock/test_sort.py new file mode 100644 index 00000000000..58b6d68a39d --- /dev/null +++ b/tests/mock/test_sort.py @@ -0,0 +1,176 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +import pytest + +from snowflake.snowpark import DataFrame, Row, Session +from snowflake.snowpark.functions import col +from snowflake.snowpark.mock.connection import MockServerConnection +from tests.utils import Utils + +session = Session(MockServerConnection()) + + +@pytest.mark.localtest +def test_sort_single_column(): + origin_df: DataFrame = session.create_dataframe( + [ + [1], + [8], + [6], + [3], + [100], + [400], + ], + schema=["v"], + ) + expected = [Row(1), Row(3), Row(6), Row(8), Row(100), Row(400)] + assert origin_df.sort(col("v")).collect() == expected + expected.reverse() + assert origin_df.sort(col("v").desc()).collect() == expected + + origin_df: DataFrame = session.create_dataframe( + [[1.0], [8.0], [6.0], [None], [3.0], [100.0], [400.0], [float("nan")]], + schema=["v"], + ) + expected_null_first = [ + Row(None), + Row(1.0), + Row(3.0), + Row(6.0), + Row(8.0), + Row(100.0), + Row(400.0), + Row(float("nan")), + ] + expected_null_last = [ + Row(1), + Row(3), + Row(6), + Row(8), + Row(100), + Row(400), + Row(float("nan")), + Row(None), + ] + Utils.check_answer( + origin_df.sort(col("v").asc_nulls_first()).collect(), + expected_null_first, + sort=False, + ) + Utils.check_answer( + origin_df.sort(col("v").asc_nulls_last()).collect(), + expected_null_last, + sort=False, + ) + + +@pytest.mark.localtest +def test_sort_multiple_column(): + origin_df: DataFrame = session.create_dataframe( + [ + [1.0, 2.3], + [8.0, 1.9], + [None, 7.8], + [3.0, 5.6], + [3.0, 4.7], + [3.0, None], + [float("nan"), 0.9], + ], + schema=["m", "n"], + ) + Utils.check_answer( + origin_df.sort(col("m")).collect(), + [ + Row(None, 7.8), + Row(1.0, 2.3), + Row(3.0, 5.6), + Row(3.0, 4.7), + Row(3.0, None), + Row(8.0, 1.9), + Row(float("nan"), 0.9), + ], + sort=False, + ) + Utils.check_answer( + origin_df.sort(col("m").asc_nulls_last()).collect(), + [ + Row(1.0, 2.3), + Row(3.0, 5.6), + Row(3.0, 4.7), + Row(3.0, None), + Row(8.0, 1.9), + Row(float("nan"), 0.9), + Row(None, 7.8), + ], + sort=False, + ) + Utils.check_answer( + origin_df.sort(col("m").desc_nulls_first()).collect(), + [ + Row(None, 7.8), + Row(float("nan"), 0.9), + Row(8.0, 1.9), + Row(3.0, 5.6), + Row(3.0, 4.7), + Row(3.0, None), + Row(1.0, 2.3), + ], + sort=False, + ) + Utils.check_answer( + origin_df.sort(col("m").desc_nulls_last()).collect(), + [ + Row(float("nan"), 0.9), + Row(8.0, 1.9), + Row(3.0, 5.6), + Row(3.0, 4.7), + Row(3.0, None), + Row(1.0, 2.3), + Row(None, 7.8), + ], + sort=False, + ) + + Utils.check_answer( + origin_df.sort([col("m"), col("n")], ascending=[1, 1]).collect(), + [ + Row(None, 7.8), + Row(1.0, 2.3), + Row(3.0, None), + Row(3.0, 4.7), + Row(3.0, 5.6), + Row(8.0, 1.9), + Row(float("nan"), 0.9), + ], + sort=False, + ) + + Utils.check_answer( + origin_df.sort([col("m"), col("n")], ascending=[True, False]).collect(), + [ + Row(None, 7.8), + Row(1.0, 2.3), + Row(3.0, 5.6), + Row(3.0, 4.7), + Row(3.0, None), + Row(8.0, 1.9), + Row(float("nan"), 0.9), + ], + sort=False, + ) + + Utils.check_answer( + origin_df.sort([col("m"), col("n").desc_nulls_first()]).collect(), + [ + Row(None, 7.8), + Row(1.0, 2.3), + Row(3.0, None), + Row(3.0, 5.6), + Row(3.0, 4.7), + Row(8.0, 1.9), + Row(float("nan"), 0.9), + ], + sort=False, + ) diff --git a/tests/mock/test_to_pandas.py b/tests/mock/test_to_pandas.py new file mode 100644 index 00000000000..3b3707b8435 --- /dev/null +++ b/tests/mock/test_to_pandas.py @@ -0,0 +1,166 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +import datetime +import decimal + +import pytest + +from snowflake.snowpark import Session +from snowflake.snowpark.mock.connection import MockServerConnection +from snowflake.snowpark.types import ( + ArrayType, + BinaryType, + BooleanType, + ByteType, + DateType, + DecimalType, + DoubleType, + FloatType, + IntegerType, + LongType, + MapType, + NullType, + ShortType, + StringType, + StructField, + StructType, + TimestampType, + TimeType, + VariantType, +) + +try: + import numpy as np + import pandas as pd + from pandas.testing import assert_frame_equal +except ImportError: + pytest.skip("pandas is not installed, skipping the tests", allow_module_level=True) + + +session = Session(MockServerConnection()) + + +@pytest.mark.localtest +def test_df_to_pandas_df(): + df = session.create_dataframe( + [ + [ + 1, + 1234567890, + True, + 1.23, + "abc", + b"abc", + datetime.datetime( + year=2023, month=10, day=30, hour=12, minute=12, second=12 + ), + ] + ], + schema=[ + "aaa", + "BBB", + "cCc", + "DdD", + "e e", + "ff ", + " gg", + ], + ) + + to_compare_df = pd.DataFrame( + { + "AAA": pd.Series( + [1], dtype=np.int8 + ), # int8 is the snowpark behavior, by default pandas use int64 + "BBB": pd.Series( + [1234567890], dtype=np.int32 + ), # int32 is the snowpark behavior, by default pandas use int64 + "CCC": pd.Series([True]), + "DDD": pd.Series([1.23]), + "e e": pd.Series(["abc"]), + "ff ": pd.Series([b"abc"]), + " gg": pd.Series( + [ + datetime.datetime( + year=2023, month=10, day=30, hour=12, minute=12, second=12 + ) + ] + ), + } + ) + + # assert_frame_equal also checks dtype + assert_frame_equal(df.to_pandas(), to_compare_df) + assert_frame_equal(list(df.to_pandas_batches())[0], to_compare_df) + + # check snowflake types explicitly + df = session.create_dataframe( + data=[ + [ + [1, 2, 3, 4], + b"123", + True, + 1, + datetime.date(year=2023, month=10, day=30), + decimal.Decimal(1), + 1.23, + 1.23, + 100, + 100, + None, + 100, + "abc", + datetime.datetime(2023, 10, 30, 12, 12, 12), + datetime.time(12, 12, 12), + {"a": "b"}, + {"a": "b"}, + ], + ], + schema=StructType( + [ + StructField("a", ArrayType()), + StructField("b", BinaryType()), + StructField("c", BooleanType()), + StructField("d", ByteType()), + StructField("e", DateType()), + StructField("f", DecimalType()), + StructField("g", DoubleType()), + StructField("h", FloatType()), + StructField("i", IntegerType()), + StructField("j", LongType()), + StructField("k", NullType()), + StructField("l", ShortType()), + StructField("m", StringType()), + StructField("n", TimestampType()), + StructField("o", TimeType()), + StructField("p", VariantType()), + StructField("q", MapType(StringType(), StringType())), + ] + ), + ) + + pandas_df = pd.DataFrame( + { + "A": pd.Series(["[\n 1,\n 2,\n 3,\n 4\n]"], dtype=object), + "B": pd.Series([b"123"], dtype=object), + "C": pd.Series([True], dtype=bool), + "D": pd.Series([1], dtype=np.int8), + "E": pd.Series([datetime.date(year=2023, month=10, day=30)], dtype=object), + "F": pd.Series([decimal.Decimal(1)], dtype=np.int8), + "G": pd.Series([1.23], dtype=np.float64), + "H": pd.Series([1.23], dtype=np.float64), + "I": pd.Series([100], dtype=np.int8), + "J": pd.Series([100], dtype=np.int8), + "K": pd.Series([None], dtype=object), + "L": pd.Series([100], dtype=np.int8), + "M": pd.Series(["abc"], dtype=object), + "N": pd.Series( + [datetime.datetime(2023, 10, 30, 12, 12, 12)], dtype="datetime64[ns]" + ), + "O": pd.Series([datetime.time(12, 12, 12)], dtype=object), + "P": pd.Series(['{\n "a": "b"\n}'], dtype=object), 
+ "Q": pd.Series(['{\n "a": "b"\n}'], dtype=object), + } + ) + assert_frame_equal(df.to_pandas(), pandas_df) diff --git a/tests/mock/test_union.py b/tests/mock/test_union.py new file mode 100644 index 00000000000..520b7da1ab3 --- /dev/null +++ b/tests/mock/test_union.py @@ -0,0 +1,122 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import pytest + +from snowflake.snowpark import DataFrame, Row, Session +from snowflake.snowpark.mock.connection import MockServerConnection +from tests.utils import Utils + +session = Session(MockServerConnection()) + + +@pytest.mark.localtest +def test_union_basic(): + df1: DataFrame = session.create_dataframe( + [ + [1, 2], + [3, 4], + ], + schema=["a", "b"], + ) + + df2: DataFrame = session.create_dataframe( + [ + [1, 2], + [5, 6], + ], + schema=["a", "b"], + ) + + df3: DataFrame = session.create_dataframe( + [ + [5, 6], + [9, 10], + [11, 12], + ], + schema=["a", "b"], + ) + + Utils.check_answer( + df1.union(df2).collect(), + [ + Row(1, 2), + Row(3, 4), + Row(5, 6), + ], + ) + + Utils.check_answer( + df1.union_all(df2).collect(), + [ + Row(1, 2), + Row(3, 4), + Row(1, 2), + Row(5, 6), + ], + ) + + Utils.check_answer( + df1.union(df2).union(df3).collect(), + [ + Row(1, 2), + Row(3, 4), + Row(5, 6), + Row(9, 10), + Row(11, 12), + ], + ) + + Utils.check_answer( + df1.union_all(df2).union_all(df3).collect(), + [ + Row(1, 2), + Row(3, 4), + Row(1, 2), + Row(5, 6), + Row(5, 6), + Row(9, 10), + Row(11, 12), + ], + ) + + +@pytest.mark.localtest +def test_union_by_name(): + df1: DataFrame = session.create_dataframe( + [ + [1, 2], + [3, 4], + ], + schema=["a", "b"], + ) + + df2: DataFrame = session.create_dataframe( + [ + [1, 2], + [2, 1], + [5, 6], + ], + schema=["b", "a"], + ) + Utils.check_answer( + df1.union_by_name(df2).collect(), + [ + Row(1, 2), + Row(3, 4), + Row(2, 1), + Row(6, 5), + ], + ) + + Utils.check_answer( + df1.union_all_by_name(df2).collect(), + [ + Row(1, 2), + Row(3, 4), + Row(2, 1), + Row(1, 2), + Row(6, 5), + ], + ) diff --git a/tests/resources/testCSVvariousData.csv b/tests/resources/testCSVvariousData.csv new file mode 100644 index 00000000000..0131941aff7 --- /dev/null +++ b/tests/resources/testCSVvariousData.csv @@ -0,0 +1,2 @@ +1,234,one,1.2,12.3456,12.3456,-12.3456,12.3456,56.78,true,2023-06-06,2023-06-06 12:34:56,12:34:56 +2,567,two,2.2,56.7867,56.7867,-56.7867,56.7867,89.01,false,2023-06-06,2023-06-06 12:34:56,12:34:56 \ No newline at end of file diff --git a/tests/unit/scala/test_utils_suite.py b/tests/unit/scala/test_utils_suite.py index ab2e9d584b8..a720cec5226 100644 --- a/tests/unit/scala/test_utils_suite.py +++ b/tests/unit/scala/test_utils_suite.py @@ -245,6 +245,7 @@ def check_zip_files_and_close_stream(input_stream, expected_files): "resources/test.xml", "resources/test2CSV.csv", "resources/testCSV.csv", + "resources/testCSVvariousData.csv", "resources/testCSVcolon.csv", "resources/testCSVheader.csv", "resources/testCSVquotes.csv", diff --git a/tests/utils.py b/tests/utils.py index 6efb7cfcfa9..037900ceb98 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -24,6 +24,8 @@ is_in_stored_procedure, quote_name, ) +from snowflake.snowpark.functions import col, parse_json +from snowflake.snowpark.mock.connection import MockServerConnection from snowflake.snowpark.types import ( ArrayType, BinaryType, @@ -37,6 +39,7 @@ LongType, MapType, StringType, + StructField, StructType, TimestampType, TimeType, @@ -116,7 +119,10 @@ def drop_stage(session: "Session", name: str): @staticmethod def 
drop_table(session: "Session", name: str):
-        session._run_query(f"drop table if exists {quote_name(name)}")
+        if isinstance(session._conn, MockServerConnection):
+            session.table(name).drop_table()
+        else:
+            session._run_query(f"drop table if exists {quote_name(name)}")

     @staticmethod
     def drop_dynamic_table(session: "Session", name: str):
@@ -146,8 +152,8 @@ def unset_query_tag(session: "Session"):
     def upload_to_stage(
         session: "Session", stage_name: str, filename: str, compress: bool
     ):
-        session._conn.upload_file(
-            stage_location=stage_name, path=filename, compress_data=compress
+        session.file.put(
+            local_file_name=filename, stage_location=stage_name, auto_compress=compress
         )

     @staticmethod
@@ -418,42 +424,81 @@ def all_nulls(cls, session: "Session") -> DataFrame:
     @classmethod
     def null_data1(cls, session: "Session") -> DataFrame:
-        return session.sql("select * from values(null),(2),(1),(3),(null) as T(a)")
+        return session.create_dataframe([[None], [2], [1], [3], [None]], schema=["a"])

     @classmethod
     def null_data2(cls, session: "Session") -> DataFrame:
-        return session.sql(
-            "select * from values(1,2,3),(null,2,3),(null,null,3),(null,null,null),"
-            "(1,null,3),(1,null,null),(1,2,null) as T(a,b,c)"
+        return session.create_dataframe(
+            [
+                [1, 2, 3],
+                [None, 2, 3],
+                [None, None, 3],
+                [None, None, None],
+                [1, None, 3],
+                [1, None, None],
+                [1, 2, None],
+            ],
+            schema=["a", "b", "c"],
         )

     @classmethod
-    def null_data3(cls, session: "Session") -> DataFrame:
-        return session.sql(
-            "select * from values(1.0, 1, true, 'a'),('NaN'::Double, 2, null, 'b'),"
-            "(null, 3, false, null), (4.0, null, null, 'd'), (null, null, null, null),"
-            "('NaN'::Double, null, null, null) as T(flo, int, boo, str)"
+    def null_data3(cls, session: "Session", local_testing_mode=False) -> DataFrame:
+        return (
+            session.sql(
+                "select * from values(1.0, 1, true, 'a'),('NaN'::Double, 2, null, 'b'),"
+                "(null, 3, false, null), (4.0, null, null, 'd'), (null, null, null, null),"
+                "('NaN'::Double, null, null, null) as T(flo, int, boo, str)"
+            )
+            if not local_testing_mode
+            else session.create_dataframe(
+                [
+                    [1.0, 1, True, "a"],
+                    [math.nan, 2, None, "b"],
+                    [None, 3, False, None],
+                    [4.0, None, None, "d"],
+                    [None, None, None, None],
+                    [math.nan, None, None, None],
+                ],
+                schema=["flo", "int", "boo", "str"],
+            )
         )

     @classmethod
     def integer1(cls, session: "Session") -> DataFrame:
-        return session.sql("select * from values(1),(2),(3) as T(a)")
+        return session.create_dataframe([[1], [2], [3]]).to_df(["a"])

     @classmethod
     def double1(cls, session: "Session") -> DataFrame:
-        return session.sql("select * from values(1.111),(2.222),(3.333) as T(a)")
+        return session.create_dataframe(
+            [[1.111], [2.222], [3.333]],
+            schema=StructType([StructField("a", DecimalType(scale=3))]),
+        )

     @classmethod
     def double2(cls, session: "Session") -> DataFrame:
-        return session.sql(
-            "select * from values(0.1, 0.5),(0.2, 0.6),(0.3, 0.7) as T(a,b)"
+        return session.create_dataframe(
+            [[0.1, 0.5], [0.2, 0.6], [0.3, 0.7]], schema=["a", "b"]
         )

     @classmethod
-    def double3(cls, session: "Session") -> DataFrame:
-        return session.sql(
-            "select * from values(1.0, 1),('NaN'::Double, 2),(null, 3),"
-            "(4.0, null), (null, null), ('NaN'::Double, null) as T(a, b)"
+    def double3(cls, session: "Session", local_testing_mode=False) -> DataFrame:
+        return (
+            session.sql(
+                "select * from values(1.0, 1),('NaN'::Double, 2),(null, 3),"
+                "(4.0, null), (null, null), ('NaN'::Double, null) as T(a, b)"
+            )
+            if not local_testing_mode
+            else session.create_dataframe(
+                [
+                    [1.0, 1],
+                    [math.nan, 2],
+                    [None, 3],
+                    [4.0, None],
+                    [None, None],
+                    [math.nan, None],
+                ],
+                schema=["a", "b"],
+            )
         )

     @classmethod
@@ -472,8 +517,8 @@ def duplicated_numbers(cls, session: "Session") -> DataFrame:

     @classmethod
     def approx_numbers(cls, session: "Session") -> DataFrame:
-        return session.sql(
-            "select * from values(1),(2),(3),(4),(5),(6),(7),(8),(9),(0) as T(a)"
+        return session.create_dataframe(
+            [[1], [2], [3], [4], [5], [6], [7], [8], [9], [0]], schema=["a"]
         )

     @classmethod
@@ -485,35 +530,40 @@ def approx_numbers2(cls, session: "Session") -> DataFrame:

     @classmethod
     def string1(cls, session: "Session") -> DataFrame:
-        return session.sql(
-            "select * from values('test1', 'a'),('test2', 'b'),('test3', 'c') as T(a, b)"
+        return session.create_dataframe(
+            [["test1", "a"], ["test2", "b"], ["test3", "c"]],
+            schema=StructType(
+                [StructField("a", StringType(5)), StructField("b", StringType(1))]
+            ),
         )

     @classmethod
     def string2(cls, session: "Session") -> DataFrame:
-        return session.sql("select * from values('asdFg'),('qqq'),('Qw') as T(a)")
+        return session.create_dataframe([["asdFg"], ["qqq"], ["Qw"]], schema=["a"])

     @classmethod
     def string3(cls, session: "Session") -> DataFrame:
-        return session.sql("select * from values(' abcba '), (' a12321a ') as T(a)")
+        return session.create_dataframe([[" abcba "], [" a12321a "]], schema=["a"])

     @classmethod
     def string4(cls, session: "Session") -> DataFrame:
-        return session.sql("select * from values('apple'),('banana'),('peach') as T(a)")
+        return session.create_dataframe(
+            [["apple"], ["banana"], ["peach"]], schema=["a"]
+        )

     @classmethod
     def string5(cls, session: "Session") -> DataFrame:
-        return session.sql("select * from values('1,2,3,4,5') as T(a)")
+        return session.create_dataframe([["1,2,3,4,5"]], schema=["a"])

     @classmethod
     def string6(cls, session: "Session") -> DataFrame:
-        return session.sql(
-            "select * from values('1,2,3,4,5', ','),('1 2 3 4 5', ' ') as T(a, b)"
+        return session.create_dataframe(
+            [["1,2,3,4,5", ","], ["1 2 3 4 5", " "]], schema=["a", "b"]
         )

     @classmethod
     def string7(cls, session: "Session") -> DataFrame:
-        return session.sql("select * from values('str', 1),(null, 2) as T(a, b)")
+        return session.create_dataframe([["str", 1], [None, 2]], schema=["a", "b"])

     @classmethod
     def array1(cls, session: "Session") -> DataFrame:
@@ -589,25 +639,27 @@ def variant1(cls, session: "Session") -> DataFrame:

     @classmethod
     def variant2(cls, session: "Session") -> DataFrame:
-        return session.sql(
-            """
-            select parse_json(column1) as src
-            from values
-            ('{
-                "date with '' and ." : "2017-04-28",
-                "salesperson" : {
-                    "id": "55",
-                    "name": "Frank Beasley"
-                },
-                "customer" : [
-                    {"name": "Joyce Ridgely", "phone": "16504378889", "address": "San Francisco, CA"}
-                ],
-                "vehicle" : [
-                    {"make": "Honda", "extras":["ext warranty", "paint protection"]}
-                ]
-            }')
-            """
+        df = session.create_dataframe(
+            data=[
+                """\
+{
+    "date with ' and .": "2017-04-28",
+    "salesperson": {
+        "id": "55",
+        "name": "Frank Beasley"
+    },
+    "customer": [
+        {"name": "Joyce Ridgely", "phone": "16504378889", "address": "San Francisco, CA"}
+    ],
+    "vehicle": [
+        {"make": "Honda", "extras": ["ext warranty", "paint protection"]}
+    ]
+}\
+"""
+            ],
+            schema=["values"],
         )
+        return df.select(parse_json("values").as_("src"))

     @classmethod
     def geography(cls, session: "Session") -> DataFrame:
@@ -675,10 +727,8 @@ def geometry_type(cls, session: "Session") -> DataFrame:

     @classmethod
     def null_json1(cls, session: "Session") -> DataFrame:
-        return session.sql(
-            'select parse_json(column1) as v from values (\'{"a": null}\'), (\'{"a": "foo"}\'),'
-            " (null)"
-        )
+        res = session.create_dataframe([['{"a": null}'], ['{"a": "foo"}'], [None]])
+        return res.select(parse_json(col("_1")).as_("v"))

     @classmethod
     def valid_json1(cls, session: "Session") -> DataFrame:
@@ -871,6 +921,10 @@ def __init__(self, resources_path) -> None:
     def test_file_csv(self):
         return os.path.join(self.resources_path, "testCSV.csv")

+    @property
+    def test_file_csv_various_data(self):
+        return os.path.join(self.resources_path, "testCSVvariousData.csv")
+
     @property
     def test_file2_csv(self):
         return os.path.join(self.resources_path, "test2CSV.csv")
diff --git a/tox.ini b/tox.ini
index 8f58baa0fa8..00dae2ed4d2 100644
--- a/tox.ini
+++ b/tox.ini
@@ -69,6 +69,7 @@ commands =
     udf: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} or udf" {posargs:} src/snowflake/snowpark tests
     notdoctest: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} or udf" {posargs:} tests
     notudfdoctest: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} and not udf" {posargs:} tests
+    local: {env:SNOWFLAKE_PYTEST_CMD} --local_testing_mode -m "(integ and localtest) or unit" {posargs:} tests

[testenv:nopandas]
allowlist_externals = bash
@@ -148,6 +149,7 @@ markers =
    doctest: doctest tests
    # Other markers
    timeout: tests that need a timeout time
+    localtest: local tests
addopts = --doctest-modules --timeout=1200

[flake8]