diff --git a/snowpark_streaming_demo.py b/snowpark_streaming_demo.py
index 92de4d204d5..ff5da216221 100644
--- a/snowpark_streaming_demo.py
+++ b/snowpark_streaming_demo.py
@@ -1,8 +1,9 @@
 from snowflake.snowpark.session import Session
 from snowflake.snowpark.functions import parse_json, col
-from snowflake.snowpark.types import StructType, MapType, StructField, StringType
+from snowflake.snowpark.types import StructType, MapType, StructField, StringType, IntegerType, FloatType, TimestampType
 import logging; logging.getLogger("snowflake.snowpark").setLevel(logging.DEBUG)
 import pandas as pd
+from snowflake.snowpark.async_job import AsyncJob
 
 
 # Function to generate random JSON data
@@ -34,6 +35,16 @@ def generate_json_data():
 
 static_df = session.table("static_df")
 
+kafka_event_schema = StructType(
+    [
+        StructField(column_identifier="ID", datatype=IntegerType()),
+        StructField(column_identifier="NAME", datatype=StringType()),
+        StructField(column_identifier="PRICE", datatype=FloatType()),
+        StructField(column_identifier="TIMESTAMP", datatype=TimestampType()),
+    ]
+)
+
+
 # Subscribe to 1 topic
 kafka_ingest_df = (
     session
@@ -42,29 +53,37 @@ def generate_json_data():
     .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
     .option("topic", "topic1")
     .option("partition_id", 1)
-    .schema(
-        StructType(
-            [
-                StructField(column_identifier="KEY", datatype=StringType()),
-                StructField(column_identifier="STREAM_VALUE", datatype=StringType())
-            ]
-        )
-    )
+    .schema(kafka_event_schema)
     .load()
 )
 
-# Join kafka ingest to static table, and write result to dynamic table.
-joined = kafka_ingest_df.join(static_df, on='KEY')
-joined.create_or_replace_dynamic_table(
-    'dynamic_join_result',
-    warehouse=session.connection.warehouse,
-    lag='1 hour',
-
-    )
+RESULT_TABLE_NAME = "dynamic_join_result"
+
+transformed_df = kafka_ingest_df \
+    .select(col("id"), col("timestamp"), col("name")) \
+    .filter(col("price") > 100.0)
+
+
+"""
+This query looks like
+
+SELECT write_stream_udf('dynamic_join_result', "id", "timestamp", "name")
+FROM (SELECT id,
+             name,
+             price,
+             timestamp
+      FROM ( TABLE (my_streaming_udtf('host1:port1,host2:port2', 'topic1', 1 :: INT) )))
+WHERE ( "price" > 100.0 )
+"""
+
+streaming_query: AsyncJob = transformed_df \
+    .writeStream \
+    .toTable(RESULT_TABLE_NAME)
+
+streaming_query.cancel()
 
-# Clean up dynamic table.
-drop_result = session.connection.cursor().execute('DROP DYNAMIC TABLE dynamic_join_result;')
-assert drop_result is not None
 
 # # Write streaming dataframe to output data sink
 # sink_query = (
diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py
index 2e7f08e527f..e7970265fe0 100644
--- a/src/snowflake/snowpark/dataframe_reader.py
+++ b/src/snowflake/snowpark/dataframe_reader.py
@@ -1000,12 +1000,13 @@ def load(self) -> DataFrame:
         bootstrap_servers = self._cur_options["kafka.bootstrap.servers".upper()]
         topic = self._cur_options["topic".upper()]
         partition_id = self._cur_options["partition_id".upper()]
 
+        self._session.custom_package_usage_config['force_push'] = True
         self._session.custom_package_usage_config['enabled'] = True
         self._session.add_import(snowflake.snowpark.kafka_ingest_udtf.__file__, import_path="snowflake.snowpark.kafka_ingest_udtf")
         self._session.add_packages(["python-confluent-kafka"])
 
-        self._session.sql("create or replace stage mystage").collect()
+
         kafka_udtf = udtf(
             KafkaFetch,
             output_schema=self._user_schema,
diff --git a/src/snowflake/snowpark/dataframe_writer.py b/src/snowflake/snowpark/dataframe_writer.py
index 635871884bb..8729c8c0caf 100644
--- a/src/snowflake/snowpark/dataframe_writer.py
+++ b/src/snowflake/snowpark/dataframe_writer.py
@@ -8,12 +8,14 @@
 import snowflake.snowpark  # for forward references of type hints
 import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
+from snowflake.snowpark.write_stream_to_table import write_stream_to_table
 from snowflake.snowpark._internal.analyzer.snowflake_plan_node import (
     CopyIntoLocationNode,
     SaveMode,
     SnowflakeCreateTable,
     TableCreationSource,
 )
+from snowflake.snowpark.async_job import AsyncJob
 from snowflake.snowpark._internal.ast.utils import (
     build_expr_from_snowpark_column_or_col_name,
     debug_check_missing_ast,
@@ -40,8 +42,9 @@
     warning,
 )
 from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType
+from snowflake.snowpark.types import StringType
 from snowflake.snowpark.column import Column, _to_col_if_str
-from snowflake.snowpark.functions import sql_expr
+from snowflake.snowpark.functions import sql_expr, udf, lit, col
 from snowflake.snowpark.mock._connection import MockServerConnection
 from snowflake.snowpark.row import Row
@@ -917,6 +920,32 @@ def parquet(
 
     saveAsTable = save_as_table
 
+
 class DataStreamWriter(DataFrameWriter):
-    def start(self):
-        raise NotImplementedError("cannot write a data stream yet.")
\ No newline at end of file
+    def toTable(self, table_name: str) -> AsyncJob:
+        self._dataframe.session.custom_package_usage_config['force_push'] = True
+        self._dataframe.session.custom_package_usage_config['enabled'] = True
+        self._dataframe.session.add_import(snowflake.snowpark.write_stream_to_table.__file__, import_path="snowflake.snowpark.write_stream_to_table")
+        self._dataframe.session.sql("create or replace stage mystage").collect()
+
+        write_stream_udf = udf(
+            write_stream_to_table,
+            input_types=[
+                StringType(),
+                *(f.datatype for f in self._dataframe.schema.fields)
+            ],
+            is_permanent=True,
+            replace=True,
+            name='write_stream_udf',
+            stage_location="@mystage"
+        )
+
+        return self._dataframe.select(write_stream_udf(
+            lit(table_name),
+            *(
+                col(f.name)
+                for f in self._dataframe.schema.fields
+            )
+        )).collect_nowait()
\ No newline at end of file
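Note (not part of the patch): a minimal caller-side sketch of how the new DataStreamWriter.toTable path is driven, mirroring the demo script above. kafka_ingest_df stands for the streaming DataFrame built there, and the target table name is illustrative.

    from snowflake.snowpark.functions import col
    from snowflake.snowpark.async_job import AsyncJob

    # kafka_ingest_df is the streaming DataFrame built in snowpark_streaming_demo.py above.
    transformed_df = (
        kafka_ingest_df
        .select(col("id"), col("timestamp"), col("name"))
        .filter(col("price") > 100.0)
    )

    # toTable() registers write_stream_udf over the DataFrame's columns and starts the
    # query with collect_nowait(), so it returns an AsyncJob rather than rows.
    streaming_query: AsyncJob = transformed_df.writeStream.toTable("dynamic_join_result")

    # The ingest keeps running until the async job is cancelled.
    streaming_query.cancel()
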
diff --git a/src/snowflake/snowpark/kafka_ingest_udtf.py b/src/snowflake/snowpark/kafka_ingest_udtf.py
index f3ed856870f..0a16b0579ee 100644
--- a/src/snowflake/snowpark/kafka_ingest_udtf.py
+++ b/src/snowflake/snowpark/kafka_ingest_udtf.py
@@ -68,7 +68,7 @@ def process(self, bootstrap_servers: str, topic: str, partition_id: int):
                 # logging.info(f"Received message: {msg.value().decode('utf-8')}")
                 # yield (msg.value().decode('utf-8'),)
 
-                yield (str(i), str(generate_json_data()))
+                yield tuple(generate_json_data().values())
             except:
                 logging.error("Consumer Error")
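Note (not part of the patch): the new yield only lines up with the UDTF's declared output columns if generate_json_data() (defined in the demo script, not shown in this diff) returns a dict whose values follow the ID, NAME, PRICE, TIMESTAMP order of kafka_event_schema. A purely illustrative sketch of such a generator:

    import random
    import string
    from datetime import datetime

    def generate_json_data():
        # Illustrative only: the insertion order of the keys must match the
        # column order of kafka_event_schema, because the UDTF emits
        # tuple(generate_json_data().values()).
        return {
            "id": random.randint(1, 1_000_000),
            "name": "".join(random.choices(string.ascii_lowercase, k=8)),
            "price": round(random.uniform(1.0, 500.0), 2),
            "timestamp": datetime.now(),
        }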