diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 444fbf38178..10959f7eb06 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -18,7 +18,7 @@ VALID_SNOWPARK_TYPES_FOR_LITERAL_VALUE, infer_type, ) -from snowflake.snowpark.types import DataType +from snowflake.snowpark.types import DataType, StringType COLUMN_DEPENDENCY_DOLLAR = frozenset( "$" @@ -214,6 +214,8 @@ def __init__(self, value: Any, datatype: Optional[DataType] = None) -> None: self.datatype = datatype else: self.datatype = infer_type(value) + if isinstance(self.datatype, StringType): + self.datatype = StringType(len(value)) class Like(Expression): diff --git a/src/snowflake/snowpark/mock/connection.py b/src/snowflake/snowpark/mock/connection.py index 6e5f9b9fe91..9f71bd4dee3 100644 --- a/src/snowflake/snowpark/mock/connection.py +++ b/src/snowflake/snowpark/mock/connection.py @@ -9,6 +9,7 @@ import sys import time from copy import copy +from decimal import Decimal from logging import getLogger from typing import IO, Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union from unittest.mock import Mock @@ -403,8 +404,23 @@ def execute( # to align with snowflake behavior, we unquote name here columns = [unquote_if_quoted(col_name) for col_name in res.columns] rows = [] + # TODO: SNOW-976145, move to index based approach to store col type mapping + # for now we only use the index based approach in aggregation functions + if res.sf_types_by_col_index: + keys = sorted(res.sf_types_by_col_index.keys()) + sf_types = [res.sf_types_by_col_index[key] for key in keys] + else: + sf_types = list(res.sf_types.values()) for pdr in res.itertuples(index=False, name=None): - row = Row(*pdr) + row = Row( + *[ + Decimal(str(v)) + if isinstance(sf_types[i].datatype, DecimalType) + and v is not None + else v + for i, v in enumerate(pdr) + ] + ) row._fields = columns rows.append(row) elif isinstance(res, list): diff --git a/src/snowflake/snowpark/mock/functions.py b/src/snowflake/snowpark/mock/functions.py index 18624d7f41c..13dfd595722 100644 --- a/src/snowflake/snowpark/mock/functions.py +++ b/src/snowflake/snowpark/mock/functions.py @@ -745,16 +745,27 @@ def mock_iff(condition: ColumnEmulator, expr1: ColumnEmulator, expr2: ColumnEmul if ( all(condition) or all(~condition) + or ( + isinstance(expr1.sf_type.datatype, StringType) + and isinstance(expr2.sf_type.datatype, StringType) + ) or expr1.sf_type.datatype == expr2.sf_type.datatype or isinstance(expr1.sf_type.datatype, NullType) or isinstance(expr2.sf_type.datatype, NullType) ): res = ColumnEmulator(data=[None] * len(condition), dtype=object) - sf_data_type = ( - expr1.sf_type.datatype - if any(condition) and not isinstance(expr1.sf_type.datatype, NullType) - else expr2.sf_type.datatype - ) + if isinstance(expr1.sf_type.datatype, StringType) and isinstance( + expr2.sf_type.datatype, StringType + ): + l1 = expr1.sf_type.datatype.length or StringType._MAX_LENGTH + l2 = expr2.sf_type.datatype.length or StringType._MAX_LENGTH + sf_data_type = StringType(max(l1, l2)) + else: + sf_data_type = ( + expr1.sf_type.datatype + if any(condition) and not isinstance(expr1.sf_type.datatype, NullType) + else expr2.sf_type.datatype + ) nullability = expr1.sf_type.nullable and expr2.sf_type.nullable res.sf_type = ColumnType(sf_data_type, nullability) res.where(condition, other=expr2, inplace=True) @@ -903,4 +914,3 @@ def mock_to_variant(expr: ColumnEmulator): res = expr.copy() res.sf_type = ColumnType(VariantType(), expr.sf_type.nullable) return res - diff --git a/src/snowflake/snowpark/mock/plan.py b/src/snowflake/snowpark/mock/plan.py index 093b63b9b9b..4a8fe2dba50 100644 --- a/src/snowflake/snowpark/mock/plan.py +++ b/src/snowflake/snowpark/mock/plan.py @@ -343,6 +343,7 @@ def execute_mock_plan( result_df.insert(len(result_df.columns), str(i), from_df.iloc[:, i]) result_df.columns = from_df.columns result_df.sf_types = from_df.sf_types + result_df.sf_types_by_col_index = from_df.sf_types_by_col_index elif ( isinstance(exp, UnresolvedAlias) and exp.child @@ -537,6 +538,8 @@ def execute_mock_plan( ) result_df_sf_Types = {} + result_df_sf_Types_by_col_idx = {} + column_exps = [ ( plan.session._analyzer.analyze(exp), @@ -547,10 +550,11 @@ def execute_mock_plan( ) for exp in source_plan.grouping_expressions ] - for column_name, _, column_type in column_exps: + for idx, (column_name, _, column_type) in enumerate(column_exps): result_df_sf_Types[ column_name ] = column_type # TODO: fix this, this does not work + result_df_sf_Types_by_col_idx[idx] = column_type # Aggregate may not have column_exps, which is allowed in the case of `Dataframe.agg`, in this case we pass # lambda x: True as the `by` parameter # also pandas group by takes None and nan as the same, so we use .astype to differentiate the two @@ -607,10 +611,16 @@ def aggregate_by_groups(cur_group: TableEmulator): values.append(cal_exp_res.iat[0]) result_df_sf_Types[ columns[idx + len(column_exps)] + ] = result_df_sf_Types_by_col_idx[ + idx + len(column_exps) ] = cal_exp_res.sf_type else: values.append(cal_exp_res) - result_df_sf_Types[columns[idx] + len(column_exps)] = ColumnType( + result_df_sf_Types[ + columns[idx + len(column_exps)] + ] = result_df_sf_Types_by_col_idx[ + idx + len(column_exps) + ] = ColumnType( infer_type(cal_exp_res), nullable=True ) data.append(values) @@ -634,6 +644,7 @@ def aggregate_by_groups(cur_group: TableEmulator): result_df[intermediate_mapped_column[col]] = series_data result_df.sf_types = result_df_sf_Types + result_df.sf_types_by_col_index = result_df_sf_Types_by_col_idx result_df.columns = columns return result_df if isinstance(source_plan, Range): @@ -1126,6 +1137,12 @@ def describe(plan: MockExecutionPlan) -> List[Attribute]: data_type = LongType() elif isinstance(data_type, FloatType): data_type = DoubleType() + elif ( + isinstance(data_type, DecimalType) + and data_type.precision == 38 + and data_type.scale == 0 + ): + data_type = LongType() ret.append( Attribute( quote_name(result[c].name.strip()), @@ -1243,7 +1260,7 @@ def calculate_expression( ) return ColumnEmulator( data=[bool(data is None) for data in child_column], - sf_type=ColumnType(BooleanType(), False), + sf_type=ColumnType(BooleanType(), True), ) if isinstance(exp, IsNotNull): child_column = calculate_expression( @@ -1251,7 +1268,7 @@ def calculate_expression( ) return ColumnEmulator( data=[bool(data is not None) for data in child_column], - sf_type=ColumnType(BooleanType(), False), + sf_type=ColumnType(BooleanType(), True), ) if isinstance(exp, IsNaN): child_column = calculate_expression( @@ -1264,7 +1281,7 @@ def calculate_expression( except TypeError: res.append(False) return ColumnEmulator( - data=res, dtype=object, sf_type=ColumnType(BooleanType(), False) + data=res, dtype=object, sf_type=ColumnType(BooleanType(), True) ) if isinstance(exp, Not): child_column = calculate_expression( @@ -1286,6 +1303,8 @@ def calculate_expression( if isinstance(exp, BinaryExpression): left = calculate_expression(exp.left, input_data, analyzer, expr_to_alias) right = calculate_expression(exp.right, input_data, analyzer, expr_to_alias) + # TODO: Address mixed type calculation here. For instance Snowflake allows to add a date to a number, but + # pandas doesn't allow. Type coercion will address it. if isinstance(exp, Multiply): new_column = left * right elif isinstance(exp, Divide): @@ -1354,7 +1373,7 @@ def calculate_expression( except re.error: raise SnowparkSQLException(f"Invalid regular expression {raw_pattern}") result = lhs.str.match(pattern) - result.sf_type = ColumnType(BooleanType(), exp.nullable) + result.sf_type = ColumnType(BooleanType(), True) return result if isinstance(exp, Like): lhs = calculate_expression(exp.expr, input_data, analyzer, expr_to_alias) @@ -1366,7 +1385,7 @@ def calculate_expression( ) ) result = lhs.str.match(pattern) - result.sf_type = ColumnType(BooleanType(), exp.nullable) + result.sf_type = ColumnType(BooleanType(), True) return result if isinstance(exp, InExpression): lhs = calculate_expression(exp.columns, input_data, analyzer, expr_to_alias) @@ -1661,7 +1680,12 @@ def calculate_expression( ).sf_type if not calculated_sf_type: calculated_sf_type = cur_windows_sf_type - elif calculated_sf_type.datatype != cur_windows_sf_type.datatype: + elif calculated_sf_type != cur_windows_sf_type and ( + not ( + isinstance(calculated_sf_type.datatype, StringType) + and isinstance(cur_windows_sf_type.datatype, StringType) + ) + ): if isinstance(calculated_sf_type.datatype, NullType): calculated_sf_type = sub_window_res.sf_type # the result calculated upon a windows can be None, this is still valid and we can keep diff --git a/src/snowflake/snowpark/mock/snowflake_data_type.py b/src/snowflake/snowpark/mock/snowflake_data_type.py index d576a1380ed..99aa13520fb 100644 --- a/src/snowflake/snowpark/mock/snowflake_data_type.py +++ b/src/snowflake/snowpark/mock/snowflake_data_type.py @@ -1,11 +1,7 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # - -# -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. -# -from typing import Dict, NamedTuple, NoReturn, Optional, Union +from typing import Dict, NamedTuple, Optional, Union import pandas as pd @@ -127,6 +123,12 @@ def normalize_decimal(d: DecimalType): d.precision = min(38, d.precision) +def normalize_output_sf_type(t: DataType) -> DataType: + if t == DecimalType(38, 0): + return LongType() + return t + + def calculate_type(c1: ColumnType, c2: Optional[ColumnType], op: Union[str]): """op, left, right decide what's next.""" t1, t2 = c1.datatype, c2.datatype @@ -142,7 +144,9 @@ def calculate_type(c1: ColumnType, c2: Optional[ColumnType], op: Union[str]): res_scale = max(min(s1 + division_min_scale, division_max_scale), s1) res_lead = l1 + s2 res_precision = min(38, res_scale + res_lead) - result_type = DecimalType(res_precision, res_scale) + result_type = normalize_output_sf_type( + DecimalType(res_precision, res_scale) + ) return ColumnType(result_type, nullable) elif op == "*": multiplication_max_scale = 12 @@ -152,6 +156,7 @@ def calculate_type(c1: ColumnType, c2: Optional[ColumnType], op: Union[str]): result_precision = min(38, result_scale + l1 + l2) result_type = DecimalType(result_precision, result_scale) normalize_decimal(result_type) + result_type = normalize_output_sf_type(result_type) return ColumnType(result_type, nullable) elif op in ("+", "-"): # widen the number with smaller scale @@ -167,13 +172,15 @@ def calculate_type(c1: ColumnType, c2: Optional[ColumnType], op: Union[str]): gap = gap + 1 p1 += gap s1 += gap - result_type = DecimalType(min(38, max(p1, p2) + 1), max(s1, s2)) + result_type = normalize_output_sf_type( + DecimalType(min(38, max(p1, p2) + 1), max(s1, s2)) + ) return ColumnType(result_type, nullable) elif op == "%": new_scale = max(s1, s2) new_decimal = max(p1 - s1, p2 - s2) new_decimal = new_decimal + new_scale - result_type = DecimalType(new_decimal, new_scale) + result_type = normalize_output_sf_type(DecimalType(new_decimal, new_scale)) return ColumnType(result_type, nullable) else: return NotImplementedError( @@ -204,7 +211,7 @@ def calculate_type(c1: ColumnType, c2: Optional[ColumnType], op: Union[str]): class TableEmulator(pd.DataFrame): - _metadata = ["sf_types", "_null_rows_idxs_map"] + _metadata = ["sf_types", "sf_types_by_col_index", "_null_rows_idxs_map"] @property def _constructor(self): @@ -215,10 +222,18 @@ def _constructor_sliced(self): return ColumnEmulator def __init__( - self, *args, sf_types: Optional[Dict[str, ColumnType]] = None, **kwargs - ) -> NoReturn: + self, + *args, + sf_types: Optional[Dict[str, ColumnType]] = None, + sf_types_by_col_index: Optional[Dict[int, ColumnType]] = None, + **kwargs, + ) -> None: super().__init__(*args, **kwargs) self.sf_types = {} if not sf_types else sf_types + # TODO: SNOW-976145, move to index based approach to store col type mapping + self.sf_types_by_col_index = ( + {} if not sf_types_by_col_index else sf_types_by_col_index + ) self._null_rows_idxs_map = {} def __getitem__(self, item): @@ -282,7 +297,7 @@ def _constructor(self): def _constructor_expanddim(self): return TableEmulator - def __init__(self, *args, **kwargs) -> NoReturn: + def __init__(self, *args, **kwargs) -> None: sf_type = kwargs.pop("sf_type", None) super().__init__(*args, **kwargs) self.sf_type: ColumnType = sf_type @@ -349,57 +364,59 @@ def __bool__(self): def __and__(self, other): result = super().__and__(other) - result.sf_type = ColumnType(BooleanType(), self.sf_type.nullable) + result.sf_type = ColumnType(BooleanType(), True) return result def __or__(self, other): result = super().__or__(other) - result.sf_type = ColumnType(BooleanType(), self.sf_type.nullable) + result.sf_type = ColumnType(BooleanType(), True) return result def __ne__(self, other): result = super().__ne__(other) - result.sf_type = ColumnType(BooleanType(), self.sf_type.nullable) + result.sf_type = ColumnType(BooleanType(), True) return result def __xor__(self, other): result = super().__xor__(other) - result.sf_type = ColumnType(BooleanType(), self.sf_type.nullable) + result.sf_type = ColumnType(BooleanType(), True) return result def __pow__(self, power): result = super().__pow__(power) - result.sf_type = ColumnType(DoubleType(), self.sf_type.nullable) + result.sf_type = ColumnType( + DoubleType(), self.sf_type.nullable or power.sf_type.nullable + ) return result def __ge__(self, other): result = super().__ge__(other) - result.sf_type = ColumnType(BooleanType(), self.sf_type.nullable) + result.sf_type = ColumnType(BooleanType(), True) return result def __gt__(self, other): result = super().__gt__(other) - result.sf_type = ColumnType(BooleanType(), self.sf_type.nullable) + result.sf_type = ColumnType(BooleanType(), True) return result def __invert__(self): result = super().__invert__() - result.sf_type = ColumnType(BooleanType(), self.sf_type.nullable) + result.sf_type = ColumnType(BooleanType(), True) return result def __le__(self, other): result = super().__le__(other) - result.sf_type = ColumnType(BooleanType(), self.sf_type.nullable) + result.sf_type = ColumnType(BooleanType(), True) return result def __lt__(self, other): result = super().__lt__(other) - result.sf_type = ColumnType(BooleanType(), self.sf_type.nullable) + result.sf_type = ColumnType(BooleanType(), True) return result def __eq__(self, other): result = super().__eq__(other) - result.sf_type = ColumnType(BooleanType(), self.sf_type.nullable) + result.sf_type = ColumnType(BooleanType(), True) return result def __neg__(self): @@ -409,7 +426,7 @@ def __neg__(self): def __rand__(self, other): result = super().__rand__(other) - result.sf_type = ColumnType(BooleanType(), self.sf_type.nullable) + result.sf_type = ColumnType(BooleanType(), True) return result def __mod__(self, other): @@ -424,7 +441,7 @@ def __rmod__(self, other): def __ror__(self, other): result = super().__ror__(other) - result.sf_type = ColumnType(BooleanType(), self.sf_type.nullable) + result.sf_type = ColumnType(BooleanType(), True) return result def __round__(self, n=None): @@ -445,25 +462,29 @@ def __round__(self, n=None): def __rpow__(self, other): result = super().__rpow__(other) - result.sf_type = ColumnType(DoubleType(), self.sf_type.nullable) + result.sf_type = ColumnType(DoubleType(), True) return result def __rtruediv__(self, other): - result = super().__rtruediv__(other) - result.sf_type = calculate_type(other.sf_type, self.sf_type, op="/") - return result + return other.__truediv__(self) def __truediv__(self, other): result = super().__truediv__(other) - result.sf_type = calculate_type(self.sf_type, other.sf_type, op="/") + sf_type = calculate_type(self.sf_type, other.sf_type, op="/") + if isinstance(sf_type.datatype, DecimalType): + result = result.astype("double").round(sf_type.datatype.scale) + elif isinstance(sf_type.datatype, (FloatType, DoubleType)): + result = result.astype("double").round(16) + result.sf_type = sf_type + return result def isna(self): result = super().isna() - result.sf_type = ColumnType(BooleanType(), self.sf_type.nullable) + result.sf_type = ColumnType(BooleanType(), True) return result def isnull(self): result = super().isnull() - result.sf_type = ColumnType(BooleanType(), self.sf_type.nullable) + result.sf_type = ColumnType(BooleanType(), True) return result diff --git a/src/snowflake/snowpark/types.py b/src/snowflake/snowpark/types.py index a23bc0b07f5..7db5698faff 100644 --- a/src/snowflake/snowpark/types.py +++ b/src/snowflake/snowpark/types.py @@ -80,7 +80,7 @@ def __init__(self, length: Optional[int] = None) -> None: self.length = length def __repr__(self) -> str: - if self.length: + if self.length and self.length < self._MAX_LENGTH: return f"StringType({self.length})" return "StringType()" diff --git a/tests/mock/test_datatypes.py b/tests/integ/test_datatypes.py similarity index 72% rename from tests/mock/test_datatypes.py rename to tests/integ/test_datatypes.py index 1f38e2fdb3f..25a39e9e3f9 100644 --- a/tests/mock/test_datatypes.py +++ b/tests/integ/test_datatypes.py @@ -1,10 +1,12 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # +from decimal import Decimal -from snowflake.snowpark import DataFrame, Session +import pytest + +from snowflake.snowpark import DataFrame, Row from snowflake.snowpark.functions import lit -from snowflake.snowpark.mock.connection import MockServerConnection from snowflake.snowpark.types import ( BooleanType, DecimalType, @@ -15,11 +17,11 @@ StructField, StructType, ) - -session = Session(MockServerConnection()) +from tests.utils import Utils -def test_basic_filter(): +@pytest.mark.localtest +def test_basic_filter(session): df: DataFrame = session.create_dataframe( [ [1, 2, "abc"], @@ -42,7 +44,8 @@ def test_basic_filter(): ) -def test_plus_basic(): +@pytest.mark.localtest +def test_plus_basic(session): df = session.create_dataframe( [[1, 1.1, 2.2, 3.3]], schema=StructType( @@ -63,7 +66,7 @@ def test_plus_basic(): assert repr(df.schema) == repr( StructType( [ - StructField("NEW_A", DecimalType(38, 0), nullable=False), + StructField("NEW_A", LongType(), nullable=False), StructField("NEW_B", DecimalType(5, 2), nullable=False), StructField("NEW_C", DoubleType(), nullable=False), ] @@ -71,7 +74,8 @@ def test_plus_basic(): ) -def test_minus_basic(): +@pytest.mark.localtest +def test_minus_basic(session): df = session.create_dataframe( [[1, 1.1, 2.2, 3.3]], schema=StructType( @@ -92,7 +96,7 @@ def test_minus_basic(): assert repr(df.schema) == repr( StructType( [ - StructField("NEW_A", DecimalType(38, 0), nullable=False), + StructField("NEW_A", LongType(), nullable=False), StructField("NEW_B", DecimalType(5, 2), nullable=False), StructField("NEW_C", DoubleType(), nullable=False), ] @@ -100,14 +104,15 @@ def test_minus_basic(): ) -def test_multiple_basic(): +@pytest.mark.localtest +def test_multiple_basic(session): df = session.create_dataframe( [[1, 1.1, 2.2, 3.3]], schema=StructType( [ StructField("a", LongType(), nullable=False), StructField("b", DecimalType(3, 1), nullable=False), - StructField("c", DoubleType(), nullable=False), + StructField("c", FloatType(), nullable=False), StructField("d", DecimalType(4, 2), nullable=False), ] ), @@ -121,7 +126,7 @@ def test_multiple_basic(): assert repr(df.schema) == repr( StructType( [ - StructField("NEW_A", DecimalType(38, 0), nullable=False), + StructField("NEW_A", LongType(), nullable=False), StructField("NEW_B", DecimalType(7, 3), nullable=False), StructField("NEW_C", DoubleType(), nullable=False), ] @@ -129,7 +134,8 @@ def test_multiple_basic(): ) -def test_divide_basic(): +@pytest.mark.localtest +def test_divide_basic(session): df = session.create_dataframe( [[1, 1.1, 2.2, 3.3]], schema=StructType( @@ -156,9 +162,29 @@ def test_divide_basic(): ] ) ) + Utils.check_answer( + df, [Row(Decimal("1.0"), Decimal("0.3333333"), 0.7333333333333334)] + ) -def test_modulo_basic(): +@pytest.mark.localtest +def test_div_decimal_double(session): + df = session.create_dataframe( + [[11.0, 13.0]], + schema=StructType( + [StructField("a", DoubleType()), StructField("b", DoubleType())] + ), + ) + df = df.select([df["a"] / df["b"]]) + Utils.check_answer(df, [Row(0.8461538461538461)]) + + df2 = session.create_dataframe([[11, 13]], schema=["a", "b"]) + df2 = df2.select([df2["a"] / df2["b"]]) + Utils.check_answer(df2, [Row(Decimal("0.846154"))]) + + +@pytest.mark.localtest +def test_modulo_basic(session): df = session.create_dataframe( [[1, 1.1, 2.2, 3.3]], schema=StructType( @@ -179,7 +205,7 @@ def test_modulo_basic(): assert repr(df.schema) == repr( StructType( [ - StructField("NEW_A", DecimalType(38, 0), nullable=False), + StructField("NEW_A", LongType(), nullable=False), StructField("NEW_B", DecimalType(4, 2), nullable=False), StructField("NEW_C", DoubleType(), nullable=False), ] @@ -187,7 +213,8 @@ def test_modulo_basic(): ) -def test_binary_ops_bool(): +@pytest.mark.localtest +def test_binary_ops_bool(session): df = session.create_dataframe( [[1, 1.1]], schema=StructType( @@ -208,14 +235,12 @@ def test_binary_ops_bool(): assert repr(df1.schema) == repr( StructType( [ - StructField('GREATERTHAN("A", "B")', BooleanType(), nullable=False), - StructField( - 'GREATERTHANOREQUAL("A", "B")', BooleanType(), nullable=False - ), - StructField('EQUALTO("A", "B")', BooleanType(), nullable=False), - StructField('NOTEQUALTO("A", "B")', BooleanType(), nullable=False), - StructField('LESSTHAN("A", "B")', BooleanType(), nullable=False), - StructField('LESSTHANOREQUAL("A", "B")', BooleanType(), nullable=False), + StructField('"(""A"" > ""B"")"', BooleanType(), nullable=True), + StructField('"(""A"" >= ""B"")"', BooleanType(), nullable=True), + StructField('"(""A"" = ""B"")"', BooleanType(), nullable=True), + StructField('"(""A"" != ""B"")"', BooleanType(), nullable=True), + StructField('"(""A"" < ""B"")"', BooleanType(), nullable=True), + StructField('"(""A"" <= ""B"")"', BooleanType(), nullable=True), ] ) ) @@ -228,21 +253,22 @@ def test_binary_ops_bool(): StructType( [ StructField( - 'AND(GREATERTHAN("A", "B"), GREATERTHANOREQUAL("A", "B"))', + '"((""A"" > ""B"") AND (""A"" >= ""B""))"', BooleanType(), - nullable=False, + nullable=True, ), StructField( - 'OR(GREATERTHAN("A", "B"), GREATERTHANOREQUAL("A", "B"))', + '"((""A"" > ""B"") OR (""A"" >= ""B""))"', BooleanType(), - nullable=False, + nullable=True, ), ] ) ) -def test_unary_ops_bool(): +@pytest.mark.localtest +def test_unary_ops_bool(session): df = session.create_dataframe( [[1, 1.1]], schema=StructType( @@ -262,39 +288,42 @@ def test_unary_ops_bool(): assert repr(df.schema) == repr( StructType( [ - StructField('ISNULL("A")', BooleanType(), nullable=False), - StructField('ISNOTNULL("A")', BooleanType(), nullable=False), - StructField('ISNAN("A")', BooleanType(), nullable=False), - StructField('NOT(ISNULL("A"))', BooleanType(), nullable=False), + StructField('"""A"" IS NULL"', BooleanType(), nullable=True), + StructField('"""A"" IS NOT NULL"', BooleanType(), nullable=True), + StructField('"""A"" = \'NAN\'"', BooleanType(), nullable=True), + StructField('"NOT ""A"" IS NULL"', BooleanType(), nullable=True), ] ) ) -def test_literal(): +@pytest.mark.localtest +def test_literal(session): df = session.create_dataframe( [[1]], schema=StructType([StructField("a", LongType(), nullable=False)]) ) df = df.select(lit("lit_value")) assert repr(df.schema) == repr( - StructType([StructField("LITERAL()", StringType(), nullable=False)]) + StructType([StructField("\"'LIT_VALUE'\"", StringType(9), nullable=False)]) ) -def test_string_op_bool(): +@pytest.mark.localtest +def test_string_op_bool(session): df = session.create_dataframe([["value"]], schema=["a"]) df = df.select(df["a"].like("v%"), df["a"].regexp("v")) assert repr(df.schema) == repr( StructType( [ - StructField('LIKE("A")', BooleanType(), nullable=False), - StructField('REGEXP("A")', BooleanType(), nullable=False), + StructField('"""A"" LIKE \'V%\'"', BooleanType(), nullable=True), + StructField('"""A"" REGEXP \'V\'"', BooleanType(), nullable=True), ] ) ) -def test_filter(): +@pytest.mark.localtest +def test_filter(session): df = session.create_dataframe( [[1, 1.1, 2.2, 3.3]], schema=StructType( @@ -311,7 +340,8 @@ def test_filter(): assert repr(df1.schema) == repr(df.schema) -def test_sort(): +@pytest.mark.localtest +def test_sort(session): df = session.create_dataframe( [[1, 1.1, 2.2, 3.3]], schema=StructType( @@ -328,7 +358,8 @@ def test_sort(): assert repr(df1.schema) == repr(df.schema) -def test_limit(): +@pytest.mark.localtest +def test_limit(session): df = session.create_dataframe( [[1, 1.1, 2.2, 3.3]], schema=StructType( @@ -345,7 +376,8 @@ def test_limit(): assert repr(df1.schema) == repr(df.schema) -def test_chain_filter_sort_limit(): +@pytest.mark.localtest +def test_chain_filter_sort_limit(session): df = session.create_dataframe( [[1, 1.1, 2.2, 3.3]], schema=StructType( @@ -366,7 +398,8 @@ def test_chain_filter_sort_limit(): assert repr(df1.schema) == repr(df.schema) -def test_join_basic(): +@pytest.mark.localtest +def test_join_basic(session): df = session.create_dataframe( [[1, 1.1, 2.2, 3.3]], schema=["a", "b", "c"], @@ -380,13 +413,13 @@ def test_join_basic(): StructType( [ StructField("A_L", LongType(), nullable=False), - StructField("B_L", FloatType(), nullable=False), - StructField("C_L", FloatType(), nullable=False), - StructField("_4_L", FloatType(), nullable=False), + StructField("B_L", DoubleType(), nullable=False), + StructField("C_L", DoubleType(), nullable=False), + StructField("_4_L", DoubleType(), nullable=False), StructField("A_R", LongType(), nullable=False), - StructField("B_R", FloatType(), nullable=False), - StructField("C_R", FloatType(), nullable=False), - StructField("_4_R", FloatType(), nullable=False), + StructField("B_R", DoubleType(), nullable=False), + StructField("C_R", DoubleType(), nullable=False), + StructField("_4_R", DoubleType(), nullable=False), ] ) ) diff --git a/tests/mock/test_create_df_from_pandas.py b/tests/mock/test_create_df_from_pandas.py index 3091f6b8004..c1d216e27f8 100644 --- a/tests/mock/test_create_df_from_pandas.py +++ b/tests/mock/test_create_df_from_pandas.py @@ -73,7 +73,7 @@ def test_create_from_pandas_basic_pandas_types(): str(sp_df.schema) == """\ StructType([\ -StructField('"sTr"', StringType(16777216), nullable=True), \ +StructField('"sTr"', StringType(), nullable=True), \ StructField('"dOublE"', DoubleType(), nullable=True), \ StructField('"LoNg"', LongType(), nullable=True), \ StructField('"booL"', BooleanType(), nullable=True), \