Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dagster-snowflake-pandas] pandas timestamp conversion fix #12190

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
from typing import Mapping, Union, cast

import pandas as pd
import pandas.core.dtypes.common as pd_core_dtypes_common
from dagster import (
InputContext,
MetadataValue,
OutputContext,
TableColumn,
TableSchema,
)
from dagster import InputContext, MetadataValue, OutputContext, TableColumn, TableSchema
from dagster._core.definitions.metadata import RawMetadataValue
from dagster._core.errors import DagsterInvalidInvocationError
from dagster._core.storage.db_io_manager import DbTypeHandler, TableSlice
from dagster_snowflake import build_snowflake_io_manager
from dagster_snowflake.resources import SnowflakeConnection
from dagster_snowflake.snowflake_io_manager import SnowflakeDbClient
from snowflake.connector.pandas_tools import pd_writer
from sqlalchemy.exc import InterfaceError


def _connect_snowflake(context: Union[InputContext, OutputContext], table_slice: TableSlice):
Expand All @@ -33,34 +28,6 @@ def _connect_snowflake(context: Union[InputContext, OutputContext], table_slice:
).get_connection(raw_conn=False)


def _convert_timestamp_to_string(s: pd.Series) -> pd.Series:
"""
Converts columns of data of type pd.Timestamp to string so that it can be stored in
snowflake.
"""
if pd_core_dtypes_common.is_datetime_or_timedelta_dtype(s): # type: ignore # (bad stubs)
return s.dt.strftime("%Y-%m-%d %H:%M:%S.%f %z")
else:
return s


def _convert_string_to_timestamp(s: pd.Series) -> pd.Series:
"""
Converts columns of strings in Timestamp format to pd.Timestamp to undo the conversion in
_convert_timestamp_to_string.

This will not convert non-timestamp strings into timestamps (pd.to_datetime will raise an
exception if the string cannot be converted)
"""
if isinstance(s[0], str):
try:
return pd.to_datetime(s.values) # type: ignore # (bad stubs)
except ValueError:
return s
else:
return s


class SnowflakePandasTypeHandler(DbTypeHandler[pd.DataFrame]):
"""
Plugin for the Snowflake I/O Manager that can store and load Pandas DataFrames as Snowflake tables.
Expand All @@ -86,16 +53,23 @@ def handle_output(
connector.paramstyle = "pyformat"
with _connect_snowflake(context, table_slice) as con:
with_uppercase_cols = obj.rename(str.upper, copy=False, axis="columns")
with_uppercase_cols = with_uppercase_cols.apply(
_convert_timestamp_to_string, axis="index"
)
with_uppercase_cols.to_sql(
table_slice.table,
con=con.engine,
if_exists="append",
index=False,
method=pd_writer,
)

try:
with_uppercase_cols.to_sql(
table_slice.table,
con=con.engine,
if_exists="append",
index=False,
method=pd_writer,
)
except InterfaceError as e:
Copy link
Contributor Author

@jamiedemaria jamiedemaria Feb 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another option instead of reacting to a failed write would be to pre-emptively check if the dataframe has Timestamp data. If so, we check if there is timezone information, and raise an error if it doesn't have that info

if "out of range" in e.orig.msg:
raise DagsterInvalidInvocationError(
f"Could not store output {context.name} of step {context.step_key}. If the"
" DataFrame includes pandas Timestamp values, ensure that they have"
" timezones."
) from e
raise e

return {
"row_count": obj.shape[0],
Expand All @@ -111,8 +85,18 @@ def handle_output(

def load_input(self, context: InputContext, table_slice: TableSlice) -> pd.DataFrame:
with _connect_snowflake(context, table_slice) as con:
result = pd.read_sql(sql=SnowflakeDbClient.get_select_statement(table_slice), con=con)
result = result.apply(_convert_string_to_timestamp, axis="index")
try:
result = pd.read_sql(
sql=SnowflakeDbClient.get_select_statement(table_slice), con=con
)
except InterfaceError as e:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we go down the route of https://github.com/dagster-io/dagster/pull/12190/files#r1100731596 this error check would get removed

if "out of range" in e.orig.msg:
raise DagsterInvalidInvocationError(
f"Could not load input {context.name} of {context.op_def.name}. If the"
" DataFrame includes pandas Timestamp values, ensure that they have"
" timezones."
) from e
raise e
result.columns = map(str.lower, result.columns) # type: ignore # (bad stubs)
return result

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pandas
import pytest
from dagster import (
DagsterInvalidInvocationError,
DailyPartitionsDefinition,
IOManagerDefinition,
MetadataValue,
Expand All @@ -26,10 +27,6 @@
from dagster_snowflake import build_snowflake_io_manager
from dagster_snowflake.resources import SnowflakeConnection
from dagster_snowflake_pandas import SnowflakePandasTypeHandler, snowflake_pandas_io_manager
from dagster_snowflake_pandas.snowflake_pandas_type_handler import (
_convert_string_to_timestamp,
_convert_timestamp_to_string,
)
from pandas import DataFrame, Timestamp

resource_config = {
Expand Down Expand Up @@ -112,31 +109,6 @@ def test_load_input():
assert df.equals(DataFrame([{"col1": "a", "col2": 1}]))


def test_type_conversions():
    """Round-trip and passthrough behavior of the timestamp <-> string helpers."""
    # A purely numeric column must come back unchanged after a full round trip.
    numeric = pandas.Series([1, 2, 3, 4, 5])
    round_tripped = _convert_string_to_timestamp(_convert_timestamp_to_string(numeric))
    assert (round_tripped == numeric).all()

    # Timestamp data must survive the string round trip exactly.
    timestamps = pandas.Series(
        [
            pandas.Timestamp("2017-01-01T12:30:45.35"),
            pandas.Timestamp("2017-02-01T12:30:45.35"),
            pandas.Timestamp("2017-03-01T12:30:45.35"),
        ]
    )
    restored = _convert_string_to_timestamp(_convert_timestamp_to_string(timestamps))
    assert (timestamps == restored).all()

    # Strings that are not timestamps must be left alone, not coerced.
    plain_strings = pandas.Series(["not", "a", "timestamp"])
    assert (_convert_string_to_timestamp(plain_strings) == plain_strings).all()


def test_build_snowflake_pandas_io_manager():
assert isinstance(
build_snowflake_io_manager([SnowflakePandasTypeHandler()]), IOManagerDefinition
Expand Down Expand Up @@ -201,8 +173,8 @@ def test_io_manager_with_snowflake_pandas_timestamp_data():
{
"foo": ["bar", "baz"],
"date": [
pandas.Timestamp("2017-01-01T12:30:45.350"),
pandas.Timestamp("2017-02-01T12:30:45.350"),
pandas.Timestamp("2017-01-01T12:30:15+00:00"),
pandas.Timestamp("2017-02-01T01:30:15+00:00"),
],
}
)
Expand All @@ -219,6 +191,7 @@ def emit_time_df(_):

@op
def read_time_df(df: pandas.DataFrame):
df["date"] = df["date"].dt.tz_localize("UTC")
assert set(df.columns) == {"foo", "date"}
assert (df["date"] == time_df["date"]).all()

Expand All @@ -242,6 +215,59 @@ def io_manager_timestamp_test_job():
assert res.success


@pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE snowflake DB")
def test_io_manager_with_snowflake_pandas_timestamp_data_error():
    """Storing timezone-naive Timestamps must raise DagsterInvalidInvocationError."""
    with temporary_snowflake_table(
        schema_name="SNOWFLAKE_IO_MANAGER_SCHEMA",
        db_name="TEST_SNOWFLAKE_IO_MANAGER",
        column_str="foo string, date TIMESTAMP_NTZ(9)",
    ) as table_name:
        # Deliberately timezone-naive timestamps: writing these should fail.
        naive_df = pandas.DataFrame(
            {
                "foo": ["bar", "baz"],
                "date": [
                    pandas.Timestamp("2017-01-01T12:30:15"),
                    pandas.Timestamp("2017-02-01T01:30:15"),
                ],
            }
        )

        @op(
            out={
                table_name: Out(
                    io_manager_key="snowflake", metadata={"schema": "SNOWFLAKE_IO_MANAGER_SCHEMA"}
                )
            }
        )
        def emit_time_df(_):
            return naive_df

        @op
        def read_time_df(df: pandas.DataFrame):
            df["date"] = df["date"].dt.tz_localize("UTC")
            assert set(df.columns) == {"foo", "date"}
            assert (df["date"] == naive_df["date"]).all()

        snowflake_config = {
            "resources": {
                "snowflake": {
                    "config": {
                        **SHARED_BUILDKITE_SNOWFLAKE_CONF,
                        "database": "TEST_SNOWFLAKE_IO_MANAGER",
                    }
                }
            }
        }

        @job(
            resource_defs={"snowflake": snowflake_pandas_io_manager},
            config=snowflake_config,
        )
        def io_manager_timestamp_test_job():
            read_time_df(emit_time_df())

        # The write of naive timestamps should surface the targeted error.
        with pytest.raises(DagsterInvalidInvocationError):
            io_manager_timestamp_test_job.execute_in_process()


@pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE snowflake DB")
def test_time_window_partitioned_asset(tmp_path):
with temporary_snowflake_table(
Expand All @@ -258,7 +284,7 @@ def test_time_window_partitioned_asset(tmp_path):
name=table_name,
)
def daily_partitioned(context):
partition = Timestamp(context.asset_partition_key_for_output())
partition = Timestamp(context.asset_partition_key_for_output()).tz_localize("UTC")
value = context.op_config["value"]

return DataFrame(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def delete_table_slice(context: OutputContext, table_slice: TableSlice) -> None:
dict(schema=table_slice.schema, **no_schema_config), context.log
).get_connection() as con:
try:
print("DELETING DATA")
con.execute_string(_get_cleanup_statement(table_slice))
except ProgrammingError:
# table doesn't exist yet, so ignore the error
Expand Down