diff --git a/python/hsfs/core/constants.py b/python/hsfs/core/constants.py index c2cbd33a8b..d6af380185 100644 --- a/python/hsfs/core/constants.py +++ b/python/hsfs/core/constants.py @@ -1,6 +1,19 @@ import importlib.util +# Avro +HAS_FAST_AVRO: bool = importlib.util.find_spec("fastavro") is not None +HAS_AVRO: bool = importlib.util.find_spec("avro") is not None + +# Confluent Kafka +HAS_CONFLUENT_KAFKA: bool = importlib.util.find_spec("confluent_kafka") is not None +confluent_kafka_not_installed_message = ( + "Confluent Kafka package not found. " + "If you want to use Kafka with Hopsworks, you can install the corresponding extras with " + """`pip install hopsworks[python]` or `pip install "hopsworks[python]"` if using zsh. """ + "You can also install confluent-kafka directly in your environment, e.g. `pip install confluent-kafka`. " + "You will need to restart your kernel if applicable." +) # Data Validation / Great Expectations HAS_GREAT_EXPECTATIONS: bool = ( importlib.util.find_spec("great_expectations") is not None diff --git a/python/hsfs/core/kafka_engine.py b/python/hsfs/core/kafka_engine.py new file mode 100644 index 0000000000..ca1a50f0a5 --- /dev/null +++ b/python/hsfs/core/kafka_engine.py @@ -0,0 +1,248 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from __future__ import annotations + +import json +from io import BytesIO +from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Optional, Tuple, Union + +from hsfs import client +from hsfs.client import hopsworks +from hsfs.core import storage_connector_api +from hsfs.core.constants import HAS_AVRO, HAS_CONFLUENT_KAFKA, HAS_FAST_AVRO +from tqdm import tqdm + + +if HAS_CONFLUENT_KAFKA: + from confluent_kafka import Consumer, KafkaError, Producer, TopicPartition + +if HAS_FAST_AVRO: + from fastavro import schemaless_writer + from fastavro.schema import parse_schema +elif HAS_AVRO: + import avro.io + import avro.schema + + +if TYPE_CHECKING: + from hsfs.feature_group import ExternalFeatureGroup, FeatureGroup + + +def init_kafka_consumer( + feature_store_id: int, + offline_write_options: Dict[str, Any], +) -> Consumer: + # setup kafka consumer + consumer_config = get_kafka_config(feature_store_id, offline_write_options) + if "group.id" not in consumer_config: + consumer_config["group.id"] = "hsfs_consumer_group" + + return Consumer(consumer_config) + + +def init_kafka_resources( + feature_group: Union[FeatureGroup, ExternalFeatureGroup], + offline_write_options: Dict[str, Any], + project_id: int, +) -> Tuple[ + Producer, Dict[str, bytes], Dict[str, Callable[..., bytes]], Callable[..., bytes] +]: + # this function is a caching wrapper around _init_kafka_resources + if feature_group._multi_part_insert and feature_group._kafka_producer: + return ( + feature_group._kafka_producer, + feature_group._kafka_headers, + feature_group._feature_writers, + feature_group._writer, + ) + producer, headers, feature_writers, writer = _init_kafka_resources( + feature_group, offline_write_options, project_id + ) + if feature_group._multi_part_insert: + feature_group._kafka_producer = producer + feature_group._kafka_headers = headers + feature_group._feature_writers = feature_writers + feature_group._writer = writer + return producer, headers, feature_writers, writer + + +def _init_kafka_resources( + feature_group: Union[FeatureGroup, ExternalFeatureGroup], + offline_write_options: Dict[str, Any], + project_id: int, +) -> Tuple[ + Producer, Dict[str, bytes], Dict[str, Callable[..., bytes]], Callable[..., bytes] +]: + # setup kafka producer + producer = init_kafka_producer( + feature_group.feature_store_id, offline_write_options + ) + # setup complex feature writers + feature_writers = { + feature: get_encoder_func(feature_group._get_feature_avro_schema(feature)) + for feature in feature_group.get_complex_features() + } + # setup row writer function + writer = get_encoder_func(feature_group._get_encoded_avro_schema()) + + # custom headers for hopsworks onlineFS + headers = { + "projectId": str(project_id).encode("utf8"), + "featureGroupId": str(feature_group._id).encode("utf8"), + "subjectId": str(feature_group.subject["id"]).encode("utf8"), + } + return producer, headers, feature_writers, writer + + +def init_kafka_producer( + feature_store_id: int, + offline_write_options: Dict[str, Any], +) -> Producer: + # setup kafka producer + return Producer(get_kafka_config(feature_store_id, offline_write_options)) + + +def kafka_get_offsets( + topic_name: str, + feature_store_id: int, + offline_write_options: Dict[str, Any], + high: bool, +) -> str: + consumer = init_kafka_consumer(feature_store_id, offline_write_options) + topics = consumer.list_topics( + timeout=offline_write_options.get("kafka_timeout", 6) + ).topics + if topic_name in topics.keys(): + # topic exists + offsets = "" + tuple_value = 
int(high) + for partition_metadata in topics.get(topic_name).partitions.values(): + partition = TopicPartition( + topic=topic_name, partition=partition_metadata.id + ) + offsets += f",{partition_metadata.id}:{consumer.get_watermark_offsets(partition)[tuple_value]}" + consumer.close() + + return f" -initialCheckPointString {topic_name + offsets}" + return "" + + +def kafka_produce( + producer: Producer, + key: str, + encoded_row: bytes, + topic_name: str, + headers: Dict[str, bytes], + acked: callable, + debug_kafka: bool = False, +) -> None: + while True: + # if BufferError is thrown, we can be sure, message hasn't been send so we retry + try: + # produce + producer.produce( + topic=topic_name, + key=key, + value=encoded_row, + callback=acked, + headers=headers, + ) + + # Trigger internal callbacks to empty op queue + producer.poll(0) + break + except BufferError as e: + if debug_kafka: + print("Caught: {}".format(e)) + # backoff for 1 second + producer.poll(1) + + +def encode_complex_features( + feature_writers: Dict[str, callable], row: Dict[str, Any] +) -> Dict[str, Any]: + for feature_name, writer in feature_writers.items(): + with BytesIO() as outf: + writer(row[feature_name], outf) + row[feature_name] = outf.getvalue() + return row + + +def get_encoder_func(writer_schema: str) -> callable: + if HAS_FAST_AVRO: + schema = json.loads(writer_schema) + parsed_schema = parse_schema(schema) + return lambda record, outf: schemaless_writer(outf, parsed_schema, record) + + parsed_schema = avro.schema.parse(writer_schema) + writer = avro.io.DatumWriter(parsed_schema) + return lambda record, outf: writer.write(record, avro.io.BinaryEncoder(outf)) + + +def get_kafka_config( + feature_store_id: int, + write_options: Optional[Dict[str, Any]] = None, + engine: Literal["spark", "confluent"] = "confluent", +) -> Dict[str, Any]: + if write_options is None: + write_options = {} + external = not ( + isinstance(client.get_instance(), hopsworks.Client) + or write_options.get("internal_kafka", False) + ) + + storage_connector = storage_connector_api.StorageConnectorApi().get_kafka_connector( + feature_store_id, external + ) + + if engine == "spark": + config = storage_connector.spark_options() + config.update(write_options) + elif engine == "confluent": + config = storage_connector.confluent_options() + config.update(write_options.get("kafka_producer_config", {})) + return config + + +def build_ack_callback_and_optional_progress_bar( + n_rows: int, is_multi_part_insert: bool, offline_write_options: Dict[str, Any] +) -> Tuple[Callable, Optional[tqdm]]: + if not is_multi_part_insert: + progress_bar = tqdm( + total=n_rows, + bar_format="{desc}: {percentage:.2f}% |{bar}| Rows {n_fmt}/{total_fmt} | " + "Elapsed Time: {elapsed} | Remaining Time: {remaining}", + desc="Uploading Dataframe", + mininterval=1, + ) + else: + progress_bar = None + + def acked(err: Exception, msg: Any) -> None: + if err is not None: + if offline_write_options.get("debug_kafka", False): + print("Failed to deliver message: %s: %s" % (str(msg), str(err))) + if err.code() in [ + KafkaError.TOPIC_AUTHORIZATION_FAILED, + KafkaError._MSG_TIMED_OUT, + ]: + progress_bar.colour = "RED" + raise err # Stop producing and show error + # update progress bar for each msg + if not is_multi_part_insert: + progress_bar.update() + + return acked, progress_bar diff --git a/python/hsfs/engine/python.py b/python/hsfs/engine/python.py index 4e094a377e..5e2b9699e8 100644 --- a/python/hsfs/engine/python.py +++ b/python/hsfs/engine/python.py @@ -32,7 +32,6 @@ from 
typing import ( TYPE_CHECKING, Any, - Callable, Dict, List, Literal, @@ -45,7 +44,6 @@ if TYPE_CHECKING: import great_expectations -import avro import boto3 import hsfs import numpy as np @@ -54,7 +52,6 @@ import pyarrow as pa import pytz from botocore.response import StreamingBody -from confluent_kafka import Consumer, KafkaError, Producer, TopicPartition from hsfs import ( client, feature, @@ -64,7 +61,6 @@ util, ) from hsfs import storage_connector as sc -from hsfs.client import hopsworks from hsfs.client.exceptions import FeatureStoreException from hsfs.constructor import query from hsfs.core import ( @@ -75,6 +71,7 @@ ingestion_job_conf, job, job_api, + kafka_engine, statistics_api, storage_connector_api, training_dataset_api, @@ -89,19 +86,8 @@ from hsfs.training_dataset import TrainingDataset from hsfs.training_dataset_feature import TrainingDatasetFeature from hsfs.training_dataset_split import TrainingDatasetSplit -from sqlalchemy import sql -from tqdm.auto import tqdm -HAS_FAST = False -try: - from fastavro import schemaless_writer - from fastavro.schema import parse_schema - - HAS_FAST = True -except ImportError: - pass - if HAS_GREAT_EXPECTATIONS: import great_expectations @@ -1239,52 +1225,6 @@ def get_unique_values( ) -> np.ndarray: return feature_dataframe[feature_name].unique() - def _init_kafka_producer( - self, - feature_group: Union[FeatureGroup, ExternalFeatureGroup], - offline_write_options: Dict[str, Any], - ) -> Producer: - # setup kafka producer - return Producer( - self._get_kafka_config( - feature_group.feature_store_id, offline_write_options - ) - ) - - def _init_kafka_consumer( - self, - feature_group: Union[FeatureGroup, ExternalFeatureGroup], - offline_write_options: Dict[str, Any], - ) -> Consumer: - # setup kafka consumer - consumer_config = self._get_kafka_config( - feature_group.feature_store_id, offline_write_options - ) - if "group.id" not in consumer_config: - consumer_config["group.id"] = "hsfs_consumer_group" - - return Consumer(consumer_config) - - def _init_kafka_resources( - self, - feature_group: Union[FeatureGroup, ExternalFeatureGroup], - offline_write_options: Dict[str, Any], - ) -> Tuple[Producer, Dict[str, Callable], Callable]: - # setup kafka producer - producer = self._init_kafka_producer(feature_group, offline_write_options) - - # setup complex feature writers - feature_writers = { - feature: self._get_encoder_func( - feature_group._get_feature_avro_schema(feature) - ) - for feature in feature_group.get_complex_features() - } - - # setup row writer function - writer = self._get_encoder_func(feature_group._get_encoded_avro_schema()) - return producer, feature_writers, writer - def _write_dataframe_kafka( self, feature_group: Union[FeatureGroup, ExternalFeatureGroup], @@ -1292,50 +1232,25 @@ def _write_dataframe_kafka( offline_write_options: Dict[str, Any], ) -> Optional[job.Job]: initial_check_point = "" - if feature_group._multi_part_insert: - if feature_group._kafka_producer is None: - producer, feature_writers, writer = self._init_kafka_resources( - feature_group, offline_write_options - ) - feature_group._kafka_producer = producer - feature_group._feature_writers = feature_writers - feature_group._writer = writer - else: - producer = feature_group._kafka_producer - feature_writers = feature_group._feature_writers - writer = feature_group._writer - else: - producer, feature_writers, writer = self._init_kafka_resources( - feature_group, offline_write_options - ) - - # initialize progress bar - progress_bar = tqdm( - 
total=dataframe.shape[0], - bar_format="{desc}: {percentage:.2f}% |{bar}| Rows {n_fmt}/{total_fmt} | " - "Elapsed Time: {elapsed} | Remaining Time: {remaining}", - desc="Uploading Dataframe", - mininterval=1, - ) - + producer, headers, feature_writers, writer = kafka_engine.init_kafka_resources( + feature_group, + offline_write_options, + project_id=client.get_instance().project_id, + ) + if not feature_group._multi_part_insert: # set initial_check_point to the current offset - initial_check_point = self._kafka_get_offsets( - feature_group, offline_write_options, True + initial_check_point = kafka_engine.kafka_get_offsets( + topic_name=feature_group._online_topic_name, + feature_store_id=feature_group.feature_store_id, + offline_write_options=offline_write_options, + high=True, ) - def acked(err: Exception, msg: Any) -> None: - if err is not None: - if offline_write_options.get("debug_kafka", False): - print("Failed to deliver message: %s: %s" % (str(msg), str(err))) - if err.code() in [ - KafkaError.TOPIC_AUTHORIZATION_FAILED, - KafkaError._MSG_TIMED_OUT, - ]: - progress_bar.colour = "RED" - raise err # Stop producing and show error - # update progress bar for each msg - if not feature_group._multi_part_insert: - progress_bar.update() + acked, progress_bar = kafka_engine.build_ack_callback_and_optional_progress_bar( + n_rows=dataframe.shape[0], + is_multi_part_insert=feature_group._multi_part_insert, + offline_write_options=offline_write_options, + ) if isinstance(dataframe, pd.DataFrame): row_iterator = dataframe.itertuples(index=False) @@ -1365,7 +1280,7 @@ def acked(err: Exception, msg: Any) -> None: row[k] = None # encode complex features - row = self._encode_complex_features(feature_writers, row) + row = kafka_engine.encode_complex_features(feature_writers, row) # encode feature row with BytesIO() as outf: @@ -1375,8 +1290,14 @@ def acked(err: Exception, msg: Any) -> None: # assemble key key = "".join([str(row[pk]) for pk in sorted(feature_group.primary_key)]) - self._kafka_produce( - producer, feature_group, key, encoded_row, acked, offline_write_options + kafka_engine.kafka_produce( + producer=producer, + key=key, + encoded_row=encoded_row, + topic_name=feature_group._online_topic_name, + headers=headers, + acked=acked, + debug_kafka=offline_write_options.get("debug_kafka", False), ) # make sure producer blocks and everything is delivered @@ -1384,13 +1305,11 @@ def acked(err: Exception, msg: Any) -> None: producer.flush() progress_bar.close() - # start materialization job + # start materialization job if not an external feature group, otherwise return None + if isinstance(feature_group, ExternalFeatureGroup): + return None # if topic didn't exist, always run the materialization job to reset the offsets except if it's a multi insert - if ( - not isinstance(feature_group, ExternalFeatureGroup) - and not initial_check_point - and not feature_group._multi_part_insert - ): + if not initial_check_point and not feature_group._multi_part_insert: if self._start_offline_materialization(offline_write_options): warnings.warn( "This is the first ingestion after an upgrade or backup/restore, running materialization job even though `start_offline_materialization` was set to `False`.", @@ -1398,17 +1317,18 @@ def acked(err: Exception, msg: Any) -> None: stacklevel=1, ) # set the initial_check_point to the lowest offset (it was not set previously due to topic not existing) - initial_check_point = self._kafka_get_offsets( - feature_group, offline_write_options, False + initial_check_point = 
kafka_engine.kafka_get_offsets( + topic_name=feature_group._online_topic_name, + feature_store_id=feature_group.feature_store_id, + offline_write_options=offline_write_options, + high=False, ) feature_group.materialization_job.run( args=feature_group.materialization_job.config.get("defaultArgs", "") + initial_check_point, await_termination=offline_write_options.get("wait_for_job", False), ) - elif not isinstance( - feature_group, ExternalFeatureGroup - ) and self._start_offline_materialization(offline_write_options): + elif self._start_offline_materialization(offline_write_options): if not offline_write_options.get( "skip_offsets", False ) and self._job_api.last_execution( @@ -1422,110 +1342,8 @@ def acked(err: Exception, msg: Any) -> None: + initial_check_point, await_termination=offline_write_options.get("wait_for_job", False), ) - if isinstance(feature_group, ExternalFeatureGroup): - return None return feature_group.materialization_job - def _kafka_get_offsets( - self, - feature_group: Union[FeatureGroup, ExternalFeatureGroup], - offline_write_options: Dict[str, Any], - high: bool, - ) -> str: - topic_name = feature_group._online_topic_name - consumer = self._init_kafka_consumer(feature_group, offline_write_options) - topics = consumer.list_topics( - timeout=offline_write_options.get("kafka_timeout", 6) - ).topics - if topic_name in topics.keys(): - # topic exists - offsets = "" - tuple_value = int(high) - for partition_metadata in topics.get(topic_name).partitions.values(): - partition = TopicPartition( - topic=topic_name, partition=partition_metadata.id - ) - offsets += f",{partition_metadata.id}:{consumer.get_watermark_offsets(partition)[tuple_value]}" - consumer.close() - - return f" -initialCheckPointString {topic_name + offsets}" - return "" - - def _kafka_produce( - self, - producer: Producer, - feature_group: Union[FeatureGroup, ExternalFeatureGroup], - key: str, - encoded_row: bytes, - acked: callable, - offline_write_options: Dict[str, Any], - ) -> None: - while True: - # if BufferError is thrown, we can be sure, message hasn't been send so we retry - try: - # produce - header = { - "projectId": str(feature_group.feature_store.project_id).encode( - "utf8" - ), - "featureGroupId": str(feature_group._id).encode("utf8"), - "subjectId": str(feature_group.subject["id"]).encode("utf8"), - } - - producer.produce( - topic=feature_group._online_topic_name, - key=key, - value=encoded_row, - callback=acked, - headers=header, - ) - - # Trigger internal callbacks to empty op queue - producer.poll(0) - break - except BufferError as e: - if offline_write_options.get("debug_kafka", False): - print("Caught: {}".format(e)) - # backoff for 1 second - producer.poll(1) - - def _encode_complex_features( - self, feature_writers: Dict[str, callable], row: Dict[str, Any] - ) -> Dict[str, Any]: - for feature_name, writer in feature_writers.items(): - with BytesIO() as outf: - writer(row[feature_name], outf) - row[feature_name] = outf.getvalue() - return row - - def _get_encoder_func(self, writer_schema: str) -> callable: - if HAS_FAST: - schema = json.loads(writer_schema) - parsed_schema = parse_schema(schema) - return lambda record, outf: schemaless_writer(outf, parsed_schema, record) - - parsed_schema = avro.schema.parse(writer_schema) - writer = avro.io.DatumWriter(parsed_schema) - return lambda record, outf: writer.write(record, avro.io.BinaryEncoder(outf)) - - def _get_kafka_config( - self, feature_store_id: int, write_options: Optional[Dict[str, Any]] = None - ) -> Dict[str, Any]: - if 
write_options is None: - write_options = {} - external = not ( - isinstance(client.get_instance(), hopsworks.Client) - or write_options.get("internal_kafka", False) - ) - - storage_connector = self._storage_connector_api.get_kafka_connector( - feature_store_id, external - ) - - config = storage_connector.confluent_options() - config.update(write_options.get("kafka_producer_config", {})) - return config - @staticmethod def _convert_pandas_dtype_to_offline_type(arrow_type: str) -> str: # This is a simple type conversion between pandas dtypes and pyspark (hive) types, diff --git a/python/hsfs/engine/spark.py b/python/hsfs/engine/spark.py index 4a48b80e08..1a9fcd3872 100644 --- a/python/hsfs/engine/spark.py +++ b/python/hsfs/engine/spark.py @@ -23,13 +23,14 @@ import shutil import warnings from datetime import date, datetime, timezone -from typing import TYPE_CHECKING, Any, List, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar, Union if TYPE_CHECKING: import great_expectations + from pyspark.rdd import RDD + from pyspark.sql import DataFrame -import avro import numpy as np import pandas as pd import tzlocal @@ -82,17 +83,16 @@ def iteritems(self): from hsfs import client, feature, training_dataset_feature, util from hsfs import feature_group as fg_mod -from hsfs.client import hopsworks from hsfs.client.exceptions import FeatureStoreException from hsfs.constructor import query from hsfs.core import ( dataset_api, delta_engine, hudi_engine, - storage_connector_api, + kafka_engine, transformation_function_engine, ) -from hsfs.core.constants import HAS_GREAT_EXPECTATIONS +from hsfs.core.constants import HAS_AVRO, HAS_GREAT_EXPECTATIONS from hsfs.decorators import uses_great_expectations from hsfs.storage_connector import StorageConnector from hsfs.training_dataset_split import TrainingDatasetSplit @@ -101,6 +101,9 @@ def iteritems(self): if HAS_GREAT_EXPECTATIONS: import great_expectations +if HAS_AVRO: + import avro + class Engine: HIVE_FORMAT = "hive" @@ -123,7 +126,6 @@ def __init__(self): if importlib.util.find_spec("pydoop"): # If we are on Databricks don't setup Pydoop as it's not available and cannot be easily installed. 
util.setup_pydoop() - self._storage_connector_api = storage_connector_api.StorageConnectorApi() self._dataset_api = dataset_api.DatasetApi() def sql( @@ -382,17 +384,17 @@ def save_dataframe( def save_stream_dataframe( self, - feature_group, + feature_group: Union[fg_mod.FeatureGroup, fg_mod.ExternalFeatureGroup], dataframe, query_name, output_mode, - await_termination, + await_termination: bool, timeout, - checkpoint_dir, - write_options, + checkpoint_dir: Optional[str], + write_options: Optional[Dict[str, Any]], ): - write_options = self._get_kafka_config( - feature_group.feature_store_id, write_options + write_options = kafka_engine.get_kafka_config( + feature_group.feature_store_id, write_options, engine="spark" ) serialized_df = self._online_fg_to_avro( feature_group, self._encode_complex_features(feature_group, dataframe) @@ -485,8 +487,8 @@ def _save_offline_dataframe( ).saveAsTable(feature_group._get_table_name()) def _save_online_dataframe(self, feature_group, dataframe, write_options): - write_options = self._get_kafka_config( - feature_group.feature_store_id, write_options + write_options = kafka_engine.get_kafka_config( + feature_group.feature_store_id, write_options, engine="spark" ) serialized_df = self._online_fg_to_avro( @@ -511,7 +513,11 @@ def _save_online_dataframe(self, feature_group, dataframe, write_options): "topic", feature_group._online_topic_name ).save() - def _encode_complex_features(self, feature_group, dataframe): + def _encode_complex_features( + self, + feature_group: Union[fg_mod.FeatureGroup, fg_mod.ExternalFeatureGroup], + dataframe: Union[RDD, DataFrame], + ): """Encodes all complex type features to binary using their avro type as schema.""" return dataframe.select( [ @@ -524,7 +530,11 @@ def _encode_complex_features(self, feature_group, dataframe): ] ) - def _online_fg_to_avro(self, feature_group, dataframe): + def _online_fg_to_avro( + self, + feature_group: Union[fg_mod.FeatureGroup, fg_mod.ExternalFeatureGroup], + dataframe: Union[DataFrame, RDD], + ): """Packs all features into named struct to be serialized to single avro/binary column. And packs primary key into arry to be serialized for partitioning. 
""" @@ -976,7 +986,7 @@ def profile( @uses_great_expectations def validate_with_great_expectations( self, - dataframe: TypeVar("pyspark.sql.DataFrame"), # noqa: F821 + dataframe: DataFrame, # noqa: F821 expectation_suite: great_expectations.core.ExpectationSuite, # noqa: F821 ge_validate_kwargs: Optional[dict], ): @@ -1388,24 +1398,6 @@ def cast_columns(df, schema, online=False): df = df.withColumn(_feat, col(_feat).cast(pyspark_schema[_feat])) return df - def _get_kafka_config( - self, feature_store_id: int, write_options: dict = None - ) -> dict: - if write_options is None: - write_options = {} - external = not ( - isinstance(client.get_instance(), hopsworks.Client) - or write_options.get("internal_kafka", False) - ) - - storage_connector = self._storage_connector_api.get_kafka_connector( - feature_store_id, external - ) - - config = storage_connector.spark_options() - config.update(write_options) - return config - @staticmethod def is_connector_type_supported(type): return True diff --git a/python/hsfs/feature_group.py b/python/hsfs/feature_group.py index 605c2950e7..bbd92c2f18 100644 --- a/python/hsfs/feature_group.py +++ b/python/hsfs/feature_group.py @@ -2188,6 +2188,7 @@ def __init__( self._kafka_producer: Optional["confluent_kafka.Producer"] = None self._feature_writers: Optional[Dict[str, callable]] = None self._writer: Optional[callable] = None + self._kafka_headers: Optional[Dict[str, bytes]] = None def read( self, @@ -2907,6 +2908,7 @@ def finalize_multi_part_insert(self) -> None: self._kafka_producer = None self._feature_writers = None self._writer = None + self._kafka_headers = None self._multi_part_insert = False def insert_stream( diff --git a/python/tests/core/test_kafka_engine.py b/python/tests/core/test_kafka_engine.py new file mode 100644 index 0000000000..61c85da7cb --- /dev/null +++ b/python/tests/core/test_kafka_engine.py @@ -0,0 +1,527 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import importlib + +from hsfs import storage_connector +from hsfs.core import constants, kafka_engine + + +if constants.HAS_CONFLUENT_KAFKA: + from confluent_kafka.admin import PartitionMetadata, TopicMetadata + + +class TestKafkaEngine: + def test_kafka_produce(self, mocker): + # Arrange + producer = mocker.Mock() + + # Act + kafka_engine.kafka_produce( + producer=producer, + topic_name="test_topic", + headers={}, + key=None, + encoded_row=None, + acked=None, + debug_kafka=False, + ) + + # Assert + assert producer.produce.call_count == 1 + assert producer.poll.call_count == 1 + + def test_kafka_produce_buffer_error(self, mocker): + # Arrange + mocker.patch("hsfs.client.get_instance") + mock_print = mocker.patch("builtins.print") + + producer = mocker.Mock() + producer.produce.side_effect = [BufferError("test_error"), None] + # Act + kafka_engine.kafka_produce( + producer=producer, + topic_name="test_topic", + headers={}, + key=None, + encoded_row=None, + acked=None, + debug_kafka=True, + ) + + # Assert + assert producer.produce.call_count == 2 + assert producer.poll.call_count == 2 + assert mock_print.call_count == 1 + assert mock_print.call_args[0][0] == "Caught: test_error" + + def test_encode_complex_features(self): + # Arrange + def test_utf(value, bytes_io): + bytes_io.write(bytes(value, "utf-8")) + + # Act + result = kafka_engine.encode_complex_features( + feature_writers={"one": test_utf, "two": test_utf}, + row={"one": "1", "two": "2"}, + ) + + # Assert + assert len(result) == 2 + assert result == {"one": b"1", "two": b"2"} + + def test_get_encoder_func(self, mocker): + # Arrange + mock_json_loads = mocker.patch("json.loads") + mock_avro_schema_parse = mocker.patch("avro.schema.parse") + constants.HAS_AVRO = True + constants.HAS_FAST_AVRO = False + importlib.reload(kafka_engine) + + # Act + result = kafka_engine.get_encoder_func( + writer_schema='{"type" : "record",' + '"namespace" : "Tutorialspoint",' + '"name" : "Employee",' + '"fields" : [{ "name" : "Name" , "type" : "string" },' + '{ "name" : "Age" , "type" : "int" }]}' + ) + + # Assert + assert result is not None + assert mock_json_loads.call_count == 0 + assert mock_avro_schema_parse.call_count == 1 + + def test_get_encoder_func_fast(self, mocker): + # Arrange + mock_avro_schema_parse = mocker.patch("avro.schema.parse") + mock_json_loads = mocker.patch( + "json.loads", + return_value={ + "type": "record", + "namespace": "Tutorialspoint", + "name": "Employee", + "fields": [ + {"name": "Name", "type": "string"}, + {"name": "Age", "type": "int"}, + ], + }, + ) + constants.HAS_AVRO = False + constants.HAS_FAST_AVRO = True + importlib.reload(kafka_engine) + + # Act + result = kafka_engine.get_encoder_func( + writer_schema='{"type" : "record",' + '"namespace" : "Tutorialspoint",' + '"name" : "Employee",' + '"fields" : [{ "name" : "Name" , "type" : "string" },' + '{ "name" : "Age" , "type" : "int" }]}' + ) + + # Assert + assert result is not None + assert mock_json_loads.call_count == 1 + assert mock_avro_schema_parse.call_count == 0 + + def test_get_kafka_config(self, mocker, backend_fixtures): + # Arrange + mocker.patch("hsfs.engine.get_instance") + mock_storage_connector_api = mocker.patch( + "hsfs.core.storage_connector_api.StorageConnectorApi" + ) + + json = backend_fixtures["storage_connector"]["get_kafka"]["response"] + sc = storage_connector.StorageConnector.from_response_json(json) + mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc + + mocker.patch("hsfs.core.kafka_engine.isinstance", 
return_value=True) + + mock_client = mocker.patch("hsfs.client.get_instance") + mock_client.return_value._write_pem.return_value = ( + "test_ssl_ca_location", + "test_ssl_certificate_location", + "test_ssl_key_location", + ) + + # Act + result = kafka_engine.get_kafka_config( + 1, + write_options={ + "kafka_producer_config": {"test_name_1": "test_value_1"}, + }, + ) + + # Assert + assert ( + mock_storage_connector_api.return_value.get_kafka_connector.call_args[0][1] + is False + ) + assert result == { + "bootstrap.servers": "test_bootstrap_servers", + "security.protocol": "test_security_protocol", + "ssl.endpoint.identification.algorithm": "test_ssl_endpoint_identification_algorithm", + "ssl.ca.location": "test_ssl_ca_location", + "ssl.certificate.location": "test_ssl_certificate_location", + "ssl.key.location": "test_ssl_key_location", + "test_name_1": "test_value_1", + } + + def test_get_kafka_config_external_client(self, mocker, backend_fixtures): + # Arrange + mocker.patch("hsfs.engine.get_instance") + mock_storage_connector_api = mocker.patch( + "hsfs.core.storage_connector_api.StorageConnectorApi" + ) + + json = backend_fixtures["storage_connector"]["get_kafka"]["response"] + sc = storage_connector.StorageConnector.from_response_json(json) + mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc + + mocker.patch("hsfs.engine.python.isinstance", return_value=False) + + mock_client = mocker.patch("hsfs.client.get_instance") + mock_client.return_value._write_pem.return_value = ( + "test_ssl_ca_location", + "test_ssl_certificate_location", + "test_ssl_key_location", + ) + + # Act + result = kafka_engine.get_kafka_config( + 1, + write_options={ + "kafka_producer_config": {"test_name_1": "test_value_1"}, + }, + ) + + # Assert + assert ( + mock_storage_connector_api.return_value.get_kafka_connector.call_args[0][1] + is True + ) + assert result == { + "bootstrap.servers": "test_bootstrap_servers", + "security.protocol": "test_security_protocol", + "ssl.endpoint.identification.algorithm": "test_ssl_endpoint_identification_algorithm", + "ssl.ca.location": "test_ssl_ca_location", + "ssl.certificate.location": "test_ssl_certificate_location", + "ssl.key.location": "test_ssl_key_location", + "test_name_1": "test_value_1", + } + + def test_get_kafka_config_internal_kafka(self, mocker, backend_fixtures): + # Arrange + mocker.patch("hsfs.engine.get_instance") + mock_storage_connector_api = mocker.patch( + "hsfs.core.storage_connector_api.StorageConnectorApi" + ) + + json = backend_fixtures["storage_connector"]["get_kafka"]["response"] + sc = storage_connector.StorageConnector.from_response_json(json) + mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc + + mocker.patch("hsfs.engine.python.isinstance", return_value=True) + + mock_client = mocker.patch("hsfs.client.get_instance") + mock_client.return_value._write_pem.return_value = ( + "test_ssl_ca_location", + "test_ssl_certificate_location", + "test_ssl_key_location", + ) + + # Act + result = kafka_engine.get_kafka_config( + 1, + write_options={ + "kafka_producer_config": {"test_name_1": "test_value_1"}, + "internal_kafka": True, + }, + ) + + # Assert + assert ( + mock_storage_connector_api.return_value.get_kafka_connector.call_args[0][1] + is False + ) + assert result == { + "bootstrap.servers": "test_bootstrap_servers", + "security.protocol": "test_security_protocol", + "ssl.endpoint.identification.algorithm": "test_ssl_endpoint_identification_algorithm", + "ssl.ca.location": 
"test_ssl_ca_location", + "ssl.certificate.location": "test_ssl_certificate_location", + "ssl.key.location": "test_ssl_key_location", + "test_name_1": "test_value_1", + } + + def test_get_kafka_config_external_client_internal_kafka( + self, mocker, backend_fixtures + ): + # Arrange + mocker.patch("hsfs.engine.get_instance") + mock_storage_connector_api = mocker.patch( + "hsfs.core.storage_connector_api.StorageConnectorApi" + ) + + json = backend_fixtures["storage_connector"]["get_kafka"]["response"] + sc = storage_connector.StorageConnector.from_response_json(json) + mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc + + mocker.patch("hsfs.engine.python.isinstance", return_value=False) + + mock_client = mocker.patch("hsfs.client.get_instance") + mock_client.return_value._write_pem.return_value = ( + "test_ssl_ca_location", + "test_ssl_certificate_location", + "test_ssl_key_location", + ) + + # Act + result = kafka_engine.get_kafka_config( + 1, + write_options={ + "kafka_producer_config": {"test_name_1": "test_value_1"}, + "internal_kafka": True, + }, + ) + + # Assert + assert ( + mock_storage_connector_api.return_value.get_kafka_connector.call_args[0][1] + is False + ) + assert result == { + "bootstrap.servers": "test_bootstrap_servers", + "security.protocol": "test_security_protocol", + "ssl.endpoint.identification.algorithm": "test_ssl_endpoint_identification_algorithm", + "ssl.ca.location": "test_ssl_ca_location", + "ssl.certificate.location": "test_ssl_certificate_location", + "ssl.key.location": "test_ssl_key_location", + "test_name_1": "test_value_1", + } + + def test_kafka_get_offsets_high(self, mocker): + # Arrange + feature_store_id = 99 + topic_name = "test_topic" + partition_metadata = PartitionMetadata() + partition_metadata.id = 0 + topic_metadata = TopicMetadata() + topic_metadata.partitions = {partition_metadata.id: partition_metadata} + topic_mock = mocker.MagicMock() + + # return no topics and one commit, so it should start the job with the extra arg + topic_mock.topics = {topic_name: topic_metadata} + + consumer = mocker.MagicMock() + consumer.list_topics = mocker.MagicMock(return_value=topic_mock) + consumer.get_watermark_offsets = mocker.MagicMock(return_value=(0, 11)) + mocker.patch( + "hsfs.core.kafka_engine.init_kafka_consumer", + return_value=consumer, + ) + + # Act + result = kafka_engine.kafka_get_offsets( + topic_name=topic_name, + feature_store_id=feature_store_id, + offline_write_options={}, + high=True, + ) + + # Assert + assert result == f" -initialCheckPointString {topic_name},0:11" + + def test_kafka_get_offsets_low(self, mocker): + # Arrange + feature_store_id = 99 + topic_name = "test_topic" + partition_metadata = PartitionMetadata() + partition_metadata.id = 0 + topic_metadata = TopicMetadata() + topic_metadata.partitions = {partition_metadata.id: partition_metadata} + topic_mock = mocker.MagicMock() + + # return no topics and one commit, so it should start the job with the extra arg + topic_mock.topics = {topic_name: topic_metadata} + + consumer = mocker.MagicMock() + consumer.list_topics = mocker.MagicMock(return_value=topic_mock) + consumer.get_watermark_offsets = mocker.MagicMock(return_value=(0, 11)) + mocker.patch( + "hsfs.core.kafka_engine.init_kafka_consumer", + return_value=consumer, + ) + + # Act + result = kafka_engine.kafka_get_offsets( + feature_store_id=feature_store_id, + topic_name=topic_name, + offline_write_options={}, + high=False, + ) + + # Assert + assert result == f" -initialCheckPointString 
{topic_name},0:0" + + def test_kafka_get_offsets_no_topic(self, mocker): + # Arrange + topic_name = "test_topic" + topic_mock = mocker.MagicMock() + + # return no topics and one commit, so it should start the job with the extra arg + topic_mock.topics = {} + + consumer = mocker.MagicMock() + consumer.list_topics = mocker.MagicMock(return_value=topic_mock) + consumer.get_watermark_offsets = mocker.MagicMock(return_value=(0, 11)) + mocker.patch( + "hsfs.core.kafka_engine.init_kafka_consumer", + return_value=consumer, + ) + # Act + result = kafka_engine.kafka_get_offsets( + topic_name=topic_name, + feature_store_id=99, + offline_write_options={}, + high=True, + ) + + # Assert + assert result == "" + + def test_spark_get_kafka_config(self, mocker, backend_fixtures): + # Arrange + mocker.patch("hsfs.client.get_instance") + mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance") + mock_engine_get_instance.return_value.add_file.return_value = ( + "result_from_add_file" + ) + mock_storage_connector_api = mocker.patch( + "hsfs.core.storage_connector_api.StorageConnectorApi" + ) + json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"] + sc = storage_connector.StorageConnector.from_response_json(json) + mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc + + mocker.patch("hsfs.core.kafka_engine.isinstance", return_value=True) + + # Act + results = kafka_engine.get_kafka_config( + 1, write_options={"user_opt": "ABC"}, engine="spark" + ) + + # Assert + assert results == { + "kafka.bootstrap.servers": "test_bootstrap_servers", + "kafka.security.protocol": "test_security_protocol", + "kafka.ssl.endpoint.identification.algorithm": "test_ssl_endpoint_identification_algorithm", + "kafka.ssl.key.password": "test_ssl_key_password", + "kafka.ssl.keystore.location": "result_from_add_file", + "kafka.ssl.keystore.password": "test_ssl_keystore_password", + "kafka.ssl.truststore.location": "result_from_add_file", + "kafka.ssl.truststore.password": "test_ssl_truststore_password", + "kafka.test_option_name": "test_option_value", + "user_opt": "ABC", + } + assert ( + mock_storage_connector_api.return_value.get_kafka_connector.call_count == 1 + ) + assert ( + mock_storage_connector_api.return_value.get_kafka_connector.call_args[0][1] + is False + ) + + def test_spark_get_kafka_config_external_client(self, mocker, backend_fixtures): + # Arrange + mocker.patch("hsfs.client.get_instance") + mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance") + mock_engine_get_instance.return_value.add_file.return_value = ( + "result_from_add_file" + ) + mock_storage_connector_api = mocker.patch( + "hsfs.core.storage_connector_api.StorageConnectorApi" + ) + json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"] + sc = storage_connector.StorageConnector.from_response_json(json) + mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc + + # Act + results = kafka_engine.get_kafka_config( + 1, write_options={"user_opt": "ABC"}, engine="spark" + ) + + # Assert + assert results == { + "kafka.bootstrap.servers": "test_bootstrap_servers", + "kafka.security.protocol": "test_security_protocol", + "kafka.ssl.endpoint.identification.algorithm": "test_ssl_endpoint_identification_algorithm", + "kafka.ssl.key.password": "test_ssl_key_password", + "kafka.ssl.keystore.location": "result_from_add_file", + "kafka.ssl.keystore.password": "test_ssl_keystore_password", + "kafka.ssl.truststore.location": "result_from_add_file", + 
"kafka.ssl.truststore.password": "test_ssl_truststore_password", + "kafka.test_option_name": "test_option_value", + "user_opt": "ABC", + } + assert ( + mock_storage_connector_api.return_value.get_kafka_connector.call_count == 1 + ) + assert ( + mock_storage_connector_api.return_value.get_kafka_connector.call_args[0][1] + is True + ) + + def test_spark_get_kafka_config_internal_kafka(self, mocker, backend_fixtures): + # Arrange + mocker.patch("hsfs.client.get_instance") + mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance") + mock_engine_get_instance.return_value.add_file.return_value = ( + "result_from_add_file" + ) + mock_storage_connector_api = mocker.patch( + "hsfs.core.storage_connector_api.StorageConnectorApi" + ) + json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"] + sc = storage_connector.StorageConnector.from_response_json(json) + mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc + + # Act + results = kafka_engine.get_kafka_config( + 1, write_options={"user_opt": "ABC", "internal_kafka": True}, engine="spark" + ) + + # Assert + assert results == { + "kafka.bootstrap.servers": "test_bootstrap_servers", + "kafka.security.protocol": "test_security_protocol", + "kafka.ssl.endpoint.identification.algorithm": "test_ssl_endpoint_identification_algorithm", + "kafka.ssl.key.password": "test_ssl_key_password", + "kafka.ssl.keystore.location": "result_from_add_file", + "kafka.ssl.keystore.password": "test_ssl_keystore_password", + "kafka.ssl.truststore.location": "result_from_add_file", + "kafka.ssl.truststore.password": "test_ssl_truststore_password", + "kafka.test_option_name": "test_option_value", + "user_opt": "ABC", + "internal_kafka": True, + } + assert ( + mock_storage_connector_api.return_value.get_kafka_connector.call_count == 1 + ) + assert ( + mock_storage_connector_api.return_value.get_kafka_connector.call_args[0][1] + is False + ) diff --git a/python/tests/engine/test_python.py b/python/tests/engine/test_python.py index a6740b71e6..6ec54beab8 100644 --- a/python/tests/engine/test_python.py +++ b/python/tests/engine/test_python.py @@ -21,7 +21,6 @@ import polars as pl import pyarrow as pa import pytest -from confluent_kafka.admin import PartitionMetadata, TopicMetadata from hsfs import ( feature, feature_group, @@ -2956,6 +2955,8 @@ def test_write_training_dataset_query_td(self, mocker, backend_fixtures): mock_td_api.return_value.compute.return_value = mock_job mocker.patch("hsfs.util.get_job_url") + mocker.patch("hsfs.client.get_instance") + python_engine = python.Engine() fg = feature_group.FeatureGroup.from_response_json( @@ -3304,344 +3305,18 @@ def test_get_unique_values(self): assert 2 in result assert 3 in result - def test_kafka_produce(self, mocker): - # Arrange - mocker.patch("hsfs.client.get_instance") - - python_engine = python.Engine() - - producer = mocker.Mock() - - fg = feature_group.FeatureGroup( - name="test", - version=1, - featurestore_id=99, - primary_key=[], - partition_key=[], - id=10, - stream=False, - ) - fg.feature_store = mocker.Mock() - - # Act - python_engine._kafka_produce( - producer=producer, - feature_group=fg, - key=None, - encoded_row=None, - acked=None, - offline_write_options={}, - ) - - # Assert - assert producer.produce.call_count == 1 - assert producer.poll.call_count == 1 - - def test_kafka_produce_buffer_error(self, mocker): - # Arrange - mocker.patch("hsfs.client.get_instance") - mock_print = mocker.patch("builtins.print") - - python_engine = python.Engine() - - 
producer = mocker.Mock() - producer.produce.side_effect = [BufferError("test_error"), None] - - fg = feature_group.FeatureGroup( - name="test", - version=1, - featurestore_id=99, - primary_key=[], - partition_key=[], - id=10, - stream=False, - ) - fg.feature_store = mocker.Mock() - - # Act - python_engine._kafka_produce( - producer=producer, - feature_group=fg, - key=None, - encoded_row=None, - acked=None, - offline_write_options={"debug_kafka": True}, - ) - - # Assert - assert producer.produce.call_count == 2 - assert producer.poll.call_count == 2 - assert mock_print.call_count == 1 - assert mock_print.call_args[0][0] == "Caught: test_error" - - def test_encode_complex_features(self): - # Arrange - python_engine = python.Engine() - - def test_utf(value, bytes_io): - bytes_io.write(bytes(value, "utf-8")) - - # Act - result = python_engine._encode_complex_features( - feature_writers={"one": test_utf, "two": test_utf}, - row={"one": "1", "two": "2"}, - ) - - # Assert - assert len(result) == 2 - assert result == {"one": b"1", "two": b"2"} - - def test_get_encoder_func(self, mocker): - # Arrange - mock_json_loads = mocker.patch("json.loads") - mock_avro_schema_parse = mocker.patch("avro.schema.parse") - - python_engine = python.Engine() - python.HAS_FAST = False - - # Act - result = python_engine._get_encoder_func( - writer_schema='{"type" : "record",' - '"namespace" : "Tutorialspoint",' - '"name" : "Employee",' - '"fields" : [{ "name" : "Name" , "type" : "string" },' - '{ "name" : "Age" , "type" : "int" }]}' - ) - - # Assert - assert result is not None - assert mock_json_loads.call_count == 0 - assert mock_avro_schema_parse.call_count == 1 - - def test_get_encoder_func_fast(self, mocker): - # Arrange - mock_json_loads = mocker.patch( - "json.loads", - return_value={ - "type": "record", - "namespace": "Tutorialspoint", - "name": "Employee", - "fields": [ - {"name": "Name", "type": "string"}, - {"name": "Age", "type": "int"}, - ], - }, - ) - mock_avro_schema_parse = mocker.patch("avro.schema.parse") - - python_engine = python.Engine() - python.HAS_FAST = True - - # Act - result = python_engine._get_encoder_func( - writer_schema='{"type" : "record",' - '"namespace" : "Tutorialspoint",' - '"name" : "Employee",' - '"fields" : [{ "name" : "Name" , "type" : "string" },' - '{ "name" : "Age" , "type" : "int" }]}' - ) - - # Assert - assert result is not None - assert mock_json_loads.call_count == 1 - assert mock_avro_schema_parse.call_count == 0 - - def test_get_kafka_config(self, mocker, backend_fixtures): - # Arrange - mocker.patch("hsfs.engine.get_instance") - mock_storage_connector_api = mocker.patch( - "hsfs.core.storage_connector_api.StorageConnectorApi" - ) - - json = backend_fixtures["storage_connector"]["get_kafka"]["response"] - sc = storage_connector.StorageConnector.from_response_json(json) - mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc - - mocker.patch("hsfs.engine.python.isinstance", return_value=True) - - mock_client = mocker.patch("hsfs.client.get_instance") - mock_client.return_value._write_pem.return_value = ( - "test_ssl_ca_location", - "test_ssl_certificate_location", - "test_ssl_key_location", - ) - - python_engine = python.Engine() - - # Act - result = python_engine._get_kafka_config( - 1, - write_options={ - "kafka_producer_config": {"test_name_1": "test_value_1"}, - }, - ) - - # Assert - assert ( - mock_storage_connector_api.return_value.get_kafka_connector.call_args[0][1] - is False - ) - assert result == { - "bootstrap.servers": 
"test_bootstrap_servers", - "security.protocol": "test_security_protocol", - "ssl.endpoint.identification.algorithm": "test_ssl_endpoint_identification_algorithm", - "ssl.ca.location": "test_ssl_ca_location", - "ssl.certificate.location": "test_ssl_certificate_location", - "ssl.key.location": "test_ssl_key_location", - "test_name_1": "test_value_1", - } - - def test_get_kafka_config_external_client(self, mocker, backend_fixtures): - # Arrange - mocker.patch("hsfs.engine.get_instance") - mock_storage_connector_api = mocker.patch( - "hsfs.core.storage_connector_api.StorageConnectorApi" - ) - - json = backend_fixtures["storage_connector"]["get_kafka"]["response"] - sc = storage_connector.StorageConnector.from_response_json(json) - mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc - - mocker.patch("hsfs.engine.python.isinstance", return_value=False) - - mock_client = mocker.patch("hsfs.client.get_instance") - mock_client.return_value._write_pem.return_value = ( - "test_ssl_ca_location", - "test_ssl_certificate_location", - "test_ssl_key_location", - ) - - python_engine = python.Engine() - - # Act - result = python_engine._get_kafka_config( - 1, - write_options={ - "kafka_producer_config": {"test_name_1": "test_value_1"}, - }, - ) - - # Assert - assert ( - mock_storage_connector_api.return_value.get_kafka_connector.call_args[0][1] - is True - ) - assert result == { - "bootstrap.servers": "test_bootstrap_servers", - "security.protocol": "test_security_protocol", - "ssl.endpoint.identification.algorithm": "test_ssl_endpoint_identification_algorithm", - "ssl.ca.location": "test_ssl_ca_location", - "ssl.certificate.location": "test_ssl_certificate_location", - "ssl.key.location": "test_ssl_key_location", - "test_name_1": "test_value_1", - } - - def test_get_kafka_config_internal_kafka(self, mocker, backend_fixtures): - # Arrange - mocker.patch("hsfs.engine.get_instance") - mock_storage_connector_api = mocker.patch( - "hsfs.core.storage_connector_api.StorageConnectorApi" - ) - - json = backend_fixtures["storage_connector"]["get_kafka"]["response"] - sc = storage_connector.StorageConnector.from_response_json(json) - mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc - - mocker.patch("hsfs.engine.python.isinstance", return_value=True) - - mock_client = mocker.patch("hsfs.client.get_instance") - mock_client.return_value._write_pem.return_value = ( - "test_ssl_ca_location", - "test_ssl_certificate_location", - "test_ssl_key_location", - ) - - python_engine = python.Engine() - - # Act - result = python_engine._get_kafka_config( - 1, - write_options={ - "kafka_producer_config": {"test_name_1": "test_value_1"}, - "internal_kafka": True, - }, - ) - - # Assert - assert ( - mock_storage_connector_api.return_value.get_kafka_connector.call_args[0][1] - is False - ) - assert result == { - "bootstrap.servers": "test_bootstrap_servers", - "security.protocol": "test_security_protocol", - "ssl.endpoint.identification.algorithm": "test_ssl_endpoint_identification_algorithm", - "ssl.ca.location": "test_ssl_ca_location", - "ssl.certificate.location": "test_ssl_certificate_location", - "ssl.key.location": "test_ssl_key_location", - "test_name_1": "test_value_1", - } - - def test_get_kafka_config_external_client_internal_kafka( - self, mocker, backend_fixtures - ): - # Arrange - mocker.patch("hsfs.engine.get_instance") - mock_storage_connector_api = mocker.patch( - "hsfs.core.storage_connector_api.StorageConnectorApi" - ) - - json = 
backend_fixtures["storage_connector"]["get_kafka"]["response"] - sc = storage_connector.StorageConnector.from_response_json(json) - mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc - - mocker.patch("hsfs.engine.python.isinstance", return_value=False) - - mock_client = mocker.patch("hsfs.client.get_instance") - mock_client.return_value._write_pem.return_value = ( - "test_ssl_ca_location", - "test_ssl_certificate_location", - "test_ssl_key_location", - ) - - python_engine = python.Engine() - - # Act - result = python_engine._get_kafka_config( - 1, - write_options={ - "kafka_producer_config": {"test_name_1": "test_value_1"}, - "internal_kafka": True, - }, - ) - - # Assert - assert ( - mock_storage_connector_api.return_value.get_kafka_connector.call_args[0][1] - is False - ) - assert result == { - "bootstrap.servers": "test_bootstrap_servers", - "security.protocol": "test_security_protocol", - "ssl.endpoint.identification.algorithm": "test_ssl_endpoint_identification_algorithm", - "ssl.ca.location": "test_ssl_ca_location", - "ssl.certificate.location": "test_ssl_certificate_location", - "ssl.key.location": "test_ssl_key_location", - "test_name_1": "test_value_1", - } - def test_materialization_kafka(self, mocker): # Arrange - mocker.patch("hsfs.engine.python.Engine._get_kafka_config", return_value={}) + mocker.patch("hsfs.core.kafka_engine.get_kafka_config", return_value={}) mocker.patch("hsfs.feature_group.FeatureGroup._get_encoded_avro_schema") - mocker.patch("hsfs.engine.python.Engine._get_encoder_func") - mocker.patch("hsfs.engine.python.Engine._encode_complex_features") + mocker.patch("hsfs.core.kafka_engine.get_encoder_func") + mocker.patch("hsfs.core.kafka_engine.encode_complex_features") mock_python_engine_kafka_produce = mocker.patch( - "hsfs.engine.python.Engine._kafka_produce" + "hsfs.core.kafka_engine.kafka_produce" ) mocker.patch("hsfs.util.get_job_url") mocker.patch( - "hsfs.engine.python.Engine._kafka_get_offsets", + "hsfs.core.kafka_engine.kafka_get_offsets", return_value=" tests_offsets", ) mocker.patch( @@ -3649,6 +3324,8 @@ def test_materialization_kafka(self, mocker): return_value=["", ""], ) + mocker.patch("hsfs.client.get_instance") + python_engine = python.Engine() fg = feature_group.FeatureGroup( @@ -3687,16 +3364,16 @@ def test_materialization_kafka(self, mocker): def test_materialization_kafka_first_job_execution(self, mocker): # Arrange - mocker.patch("hsfs.engine.python.Engine._get_kafka_config", return_value={}) + mocker.patch("hsfs.core.kafka_engine.get_kafka_config", return_value={}) mocker.patch("hsfs.feature_group.FeatureGroup._get_encoded_avro_schema") - mocker.patch("hsfs.engine.python.Engine._get_encoder_func") - mocker.patch("hsfs.engine.python.Engine._encode_complex_features") + mocker.patch("hsfs.core.kafka_engine.get_encoder_func") + mocker.patch("hsfs.core.kafka_engine.encode_complex_features") mock_python_engine_kafka_produce = mocker.patch( - "hsfs.engine.python.Engine._kafka_produce" + "hsfs.core.kafka_engine.kafka_produce" ) mocker.patch("hsfs.util.get_job_url") mocker.patch( - "hsfs.engine.python.Engine._kafka_get_offsets", + "hsfs.core.kafka_engine.kafka_get_offsets", return_value=" tests_offsets", ) mocker.patch( @@ -3704,6 +3381,8 @@ def test_materialization_kafka_first_job_execution(self, mocker): return_value=[], ) + mocker.patch("hsfs.client.get_instance") + python_engine = python.Engine() fg = feature_group.FeatureGroup( @@ -3742,19 +3421,21 @@ def test_materialization_kafka_first_job_execution(self, mocker): def 
test_materialization_kafka_skip_offsets(self, mocker): # Arrange - mocker.patch("hsfs.engine.python.Engine._get_kafka_config", return_value={}) + mocker.patch("hsfs.core.kafka_engine.get_kafka_config", return_value={}) mocker.patch("hsfs.feature_group.FeatureGroup._get_encoded_avro_schema") - mocker.patch("hsfs.engine.python.Engine._get_encoder_func") - mocker.patch("hsfs.engine.python.Engine._encode_complex_features") + mocker.patch("hsfs.core.kafka_engine.get_encoder_func") + mocker.patch("hsfs.core.kafka_engine.encode_complex_features") mock_python_engine_kafka_produce = mocker.patch( - "hsfs.engine.python.Engine._kafka_produce" + "hsfs.core.kafka_engine.kafka_produce" ) mocker.patch("hsfs.util.get_job_url") mocker.patch( - "hsfs.engine.python.Engine._kafka_get_offsets", + "hsfs.core.kafka_engine.kafka_get_offsets", return_value=" tests_offsets", ) + mocker.patch("hsfs.client.get_instance") + python_engine = python.Engine() fg = feature_group.FeatureGroup( @@ -3796,19 +3477,21 @@ def test_materialization_kafka_skip_offsets(self, mocker): def test_materialization_kafka_topic_doesnt_exist(self, mocker): # Arrange - mocker.patch("hsfs.engine.python.Engine._get_kafka_config", return_value={}) + mocker.patch("hsfs.core.kafka_engine.get_kafka_config", return_value={}) mocker.patch("hsfs.feature_group.FeatureGroup._get_encoded_avro_schema") - mocker.patch("hsfs.engine.python.Engine._get_encoder_func") - mocker.patch("hsfs.engine.python.Engine._encode_complex_features") + mocker.patch("hsfs.core.kafka_engine.get_encoder_func") + mocker.patch("hsfs.core.kafka_engine.encode_complex_features") mock_python_engine_kafka_produce = mocker.patch( - "hsfs.engine.python.Engine._kafka_produce" + "hsfs.core.kafka_engine.kafka_produce" ) mocker.patch("hsfs.util.get_job_url") mocker.patch( - "hsfs.engine.python.Engine._kafka_get_offsets", + "hsfs.core.kafka_engine.kafka_get_offsets", side_effect=["", " tests_offsets"], ) + mocker.patch("hsfs.client.get_instance") + python_engine = python.Engine() fg = feature_group.FeatureGroup( @@ -3845,134 +3528,6 @@ def test_materialization_kafka_topic_doesnt_exist(self, mocker): await_termination=False, ) - def test_kafka_get_offsets_high(self, mocker): - # Arrange - topic_name = "test_topic" - partition_metadata = PartitionMetadata() - partition_metadata.id = 0 - topic_metadata = TopicMetadata() - topic_metadata.partitions = {partition_metadata.id: partition_metadata} - topic_mock = mocker.MagicMock() - - # return no topics and one commit, so it should start the job with the extra arg - topic_mock.topics = {topic_name: topic_metadata} - - consumer = mocker.MagicMock() - consumer.list_topics = mocker.MagicMock(return_value=topic_mock) - consumer.get_watermark_offsets = mocker.MagicMock(return_value=(0, 11)) - mocker.patch( - "hsfs.engine.python.Engine._init_kafka_consumer", - return_value=consumer, - ) - - python_engine = python.Engine() - - fg = feature_group.FeatureGroup( - name="test", - version=1, - featurestore_id=99, - primary_key=[], - partition_key=[], - id=10, - stream=False, - time_travel_format="HUDI", - ) - fg._online_topic_name = topic_name - - # Act - result = python_engine._kafka_get_offsets( - feature_group=fg, - offline_write_options={}, - high=True, - ) - - # Assert - assert result == f" -initialCheckPointString {topic_name},0:11" - - def test_kafka_get_offsets_low(self, mocker): - # Arrange - topic_name = "test_topic" - partition_metadata = PartitionMetadata() - partition_metadata.id = 0 - topic_metadata = TopicMetadata() - topic_metadata.partitions = 
{partition_metadata.id: partition_metadata} - topic_mock = mocker.MagicMock() - - # return no topics and one commit, so it should start the job with the extra arg - topic_mock.topics = {topic_name: topic_metadata} - - consumer = mocker.MagicMock() - consumer.list_topics = mocker.MagicMock(return_value=topic_mock) - consumer.get_watermark_offsets = mocker.MagicMock(return_value=(0, 11)) - mocker.patch( - "hsfs.engine.python.Engine._init_kafka_consumer", - return_value=consumer, - ) - - python_engine = python.Engine() - - fg = feature_group.FeatureGroup( - name="test", - version=1, - featurestore_id=99, - primary_key=[], - partition_key=[], - id=10, - stream=False, - time_travel_format="HUDI", - ) - fg._online_topic_name = topic_name - - # Act - result = python_engine._kafka_get_offsets( - feature_group=fg, - offline_write_options={}, - high=False, - ) - - # Assert - assert result == f" -initialCheckPointString {topic_name},0:0" - - def test_kafka_get_offsets_no_topic(self, mocker): - # Arrange - topic_name = "test_topic" - topic_mock = mocker.MagicMock() - - # return no topics and one commit, so it should start the job with the extra arg - topic_mock.topics = {} - - consumer = mocker.MagicMock() - consumer.list_topics = mocker.MagicMock(return_value=topic_mock) - consumer.get_watermark_offsets = mocker.MagicMock(return_value=(0, 11)) - mocker.patch( - "hsfs.engine.python.Engine._init_kafka_consumer", - return_value=consumer, - ) - - python_engine = python.Engine() - - fg = feature_group.FeatureGroup( - name="test", - version=1, - featurestore_id=99, - primary_key=[], - partition_key=[], - id=10, - stream=False, - time_travel_format="HUDI", - ) - fg._online_topic_name = topic_name - - # Act - result = python_engine._kafka_get_offsets( - feature_group=fg, - offline_write_options={}, - high=True, - ) - - # Assert - assert result == "" - def test_test(self, mocker): fg = feature_group.FeatureGroup( name="test", diff --git a/python/tests/engine/test_python_writer.py b/python/tests/engine/test_python_writer.py index 0dd0533156..2c6a8fd3d1 100644 --- a/python/tests/engine/test_python_writer.py +++ b/python/tests/engine/test_python_writer.py @@ -27,7 +27,8 @@ class TestPythonWriter: def test_write_dataframe_kafka(self, mocker, dataframe_fixture_times): # Arrange - mocker.patch("hsfs.engine.python.Engine._get_kafka_config", return_value={}) + mocker.patch("hsfs.client.get_instance") + mocker.patch("hsfs.core.kafka_engine.get_kafka_config", return_value={}) avro_schema_mock = mocker.patch( "hsfs.feature_group.FeatureGroup._get_encoded_avro_schema" ) @@ -45,7 +46,7 @@ def test_write_dataframe_kafka(self, mocker, dataframe_fixture_times): ) avro_schema_mock.side_effect = [avro_schema] mock_python_engine_kafka_produce = mocker.patch( - "hsfs.engine.python.Engine._kafka_produce" + "hsfs.core.kafka_engine.kafka_produce" ) mocker.patch("hsfs.core.job_api.JobApi") # get, launch mocker.patch("hsfs.util.get_job_url") @@ -55,11 +56,7 @@ def test_write_dataframe_kafka(self, mocker, dataframe_fixture_times): topic_mock.topics = {topic_name: topic_metadata} consumer = mocker.MagicMock() consumer.list_topics = mocker.MagicMock(return_value=topic_mock) - mocker.patch( - "hsfs.engine.python.Consumer", - return_value=consumer, - ) - mocker.patch("hsfs.engine.python.Producer") + mocker.patch("hsfs.core.kafka_engine.Consumer", return_value=consumer) python_engine = python.Engine() fg = feature_group.FeatureGroup( @@ -83,7 +80,8 @@ def test_write_dataframe_kafka(self, mocker, dataframe_fixture_times): ) # Assert - 
encoded_row = mock_python_engine_kafka_produce.call_args[0][3] + print(mock_python_engine_kafka_produce.call_args) + encoded_row = mock_python_engine_kafka_produce.call_args[1]["encoded_row"] print("Value" + str(encoded_row)) parsed_schema = fastavro.parse_schema(json.loads(avro_schema)) with BytesIO() as outf: diff --git a/python/tests/engine/test_spark.py b/python/tests/engine/test_spark.py index 8b98f521c6..5e7699b534 100644 --- a/python/tests/engine/test_spark.py +++ b/python/tests/engine/test_spark.py @@ -4487,187 +4487,3 @@ def test_create_empty_df(self): # Assert assert result.schema == spark_df.schema assert result.collect() == [] - - def test_get_kafka_config(self, mocker, backend_fixtures): - # Arrange - mocker.patch("hsfs.engine.get_type") - mocker.patch("hsfs.client.get_instance") - mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance") - mock_engine_get_instance.return_value.add_file.return_value = ( - "result_from_add_file" - ) - - mocker.patch("hsfs.engine.spark.isinstance", return_value=True) - - mock_storage_connector_api = mocker.patch( - "hsfs.core.storage_connector_api.StorageConnectorApi" - ) - json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"] - sc = storage_connector.StorageConnector.from_response_json(json) - mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc - - spark_engine = spark.Engine() - - # Act - results = spark_engine._get_kafka_config(1, write_options={"user_opt": "ABC"}) - - # Assert - assert results == { - "kafka.bootstrap.servers": "test_bootstrap_servers", - "kafka.security.protocol": "test_security_protocol", - "kafka.ssl.endpoint.identification.algorithm": "test_ssl_endpoint_identification_algorithm", - "kafka.ssl.key.password": "test_ssl_key_password", - "kafka.ssl.keystore.location": "result_from_add_file", - "kafka.ssl.keystore.password": "test_ssl_keystore_password", - "kafka.ssl.truststore.location": "result_from_add_file", - "kafka.ssl.truststore.password": "test_ssl_truststore_password", - "kafka.test_option_name": "test_option_value", - "user_opt": "ABC", - } - assert ( - mock_storage_connector_api.return_value.get_kafka_connector.call_count == 1 - ) - assert ( - mock_storage_connector_api.return_value.get_kafka_connector.call_args[0][1] - is False - ) - - def test_get_kafka_config_external_client(self, mocker, backend_fixtures): - # Arrange - mocker.patch("hsfs.engine.get_type") - mocker.patch("hsfs.client.get_instance") - mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance") - mock_engine_get_instance.return_value.add_file.return_value = ( - "result_from_add_file" - ) - - mocker.patch("hsfs.engine.spark.isinstance", return_value=False) - - mock_storage_connector_api = mocker.patch( - "hsfs.core.storage_connector_api.StorageConnectorApi" - ) - json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"] - sc = storage_connector.StorageConnector.from_response_json(json) - mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc - - spark_engine = spark.Engine() - - # Act - results = spark_engine._get_kafka_config(1, write_options={"user_opt": "ABC"}) - - # Assert - assert results == { - "kafka.bootstrap.servers": "test_bootstrap_servers", - "kafka.security.protocol": "test_security_protocol", - "kafka.ssl.endpoint.identification.algorithm": "test_ssl_endpoint_identification_algorithm", - "kafka.ssl.key.password": "test_ssl_key_password", - "kafka.ssl.keystore.location": "result_from_add_file", - 
"kafka.ssl.keystore.password": "test_ssl_keystore_password", - "kafka.ssl.truststore.location": "result_from_add_file", - "kafka.ssl.truststore.password": "test_ssl_truststore_password", - "kafka.test_option_name": "test_option_value", - "user_opt": "ABC", - } - assert ( - mock_storage_connector_api.return_value.get_kafka_connector.call_count == 1 - ) - assert ( - mock_storage_connector_api.return_value.get_kafka_connector.call_args[0][1] - is True - ) - - def test_get_kafka_config_internal_kafka(self, mocker, backend_fixtures): - # Arrange - mocker.patch("hsfs.engine.get_type") - mocker.patch("hsfs.client.get_instance") - mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance") - mock_engine_get_instance.return_value.add_file.return_value = ( - "result_from_add_file" - ) - - mocker.patch("hsfs.engine.spark.isinstance", return_value=True) - - mock_storage_connector_api = mocker.patch( - "hsfs.core.storage_connector_api.StorageConnectorApi" - ) - json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"] - sc = storage_connector.StorageConnector.from_response_json(json) - mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc - - spark_engine = spark.Engine() - - # Act - results = spark_engine._get_kafka_config( - 1, write_options={"user_opt": "ABC", "internal_kafka": True} - ) - - # Assert - assert results == { - "kafka.bootstrap.servers": "test_bootstrap_servers", - "kafka.security.protocol": "test_security_protocol", - "kafka.ssl.endpoint.identification.algorithm": "test_ssl_endpoint_identification_algorithm", - "kafka.ssl.key.password": "test_ssl_key_password", - "kafka.ssl.keystore.location": "result_from_add_file", - "kafka.ssl.keystore.password": "test_ssl_keystore_password", - "kafka.ssl.truststore.location": "result_from_add_file", - "kafka.ssl.truststore.password": "test_ssl_truststore_password", - "kafka.test_option_name": "test_option_value", - "user_opt": "ABC", - "internal_kafka": True, - } - assert ( - mock_storage_connector_api.return_value.get_kafka_connector.call_count == 1 - ) - assert ( - mock_storage_connector_api.return_value.get_kafka_connector.call_args[0][1] - is False - ) - - def test_get_kafka_config_external_client_internal_kafka( - self, mocker, backend_fixtures - ): - # Arrange - mocker.patch("hsfs.engine.get_type") - mocker.patch("hsfs.client.get_instance") - mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance") - mock_engine_get_instance.return_value.add_file.return_value = ( - "result_from_add_file" - ) - - mocker.patch("hsfs.engine.spark.isinstance", return_value=False) - - mock_storage_connector_api = mocker.patch( - "hsfs.core.storage_connector_api.StorageConnectorApi" - ) - json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"] - sc = storage_connector.StorageConnector.from_response_json(json) - mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc - - spark_engine = spark.Engine() - - # Act - results = spark_engine._get_kafka_config( - 1, write_options={"user_opt": "ABC", "internal_kafka": True} - ) - - # Assert - assert results == { - "kafka.bootstrap.servers": "test_bootstrap_servers", - "kafka.security.protocol": "test_security_protocol", - "kafka.ssl.endpoint.identification.algorithm": "test_ssl_endpoint_identification_algorithm", - "kafka.ssl.key.password": "test_ssl_key_password", - "kafka.ssl.keystore.location": "result_from_add_file", - "kafka.ssl.keystore.password": "test_ssl_keystore_password", - "kafka.ssl.truststore.location": 
"result_from_add_file", - "kafka.ssl.truststore.password": "test_ssl_truststore_password", - "kafka.test_option_name": "test_option_value", - "user_opt": "ABC", - "internal_kafka": True, - } - assert ( - mock_storage_connector_api.return_value.get_kafka_connector.call_count == 1 - ) - assert ( - mock_storage_connector_api.return_value.get_kafka_connector.call_args[0][1] - is False - ) diff --git a/python/tests/test_feature_group_writer.py b/python/tests/test_feature_group_writer.py index f4d980a99d..861022e5c9 100644 --- a/python/tests/test_feature_group_writer.py +++ b/python/tests/test_feature_group_writer.py @@ -56,11 +56,16 @@ def test_fg_writer_cache_management(self, mocker, dataframe_fixture_basic): mocker.MagicMock(), mocker.MagicMock(), ) + headers = { + "projectId": str(99).encode("utf8"), + "featureGroupId": str(32).encode("utf8"), + "subjectId": str(12).encode("utf8"), + } mock_init_kafka_resources = mocker.patch( - "hsfs.engine.python.Engine._init_kafka_resources", - return_value=(producer, feature_writers, writer_m), + "hsfs.core.kafka_engine._init_kafka_resources", + return_value=(producer, headers, feature_writers, writer_m), ) - mocker.patch("hsfs.engine.python.Engine._encode_complex_features") + mocker.patch("hsfs.core.kafka_engine.encode_complex_features") mocker.patch("hsfs.core.job.Job") mocker.patch("hsfs.engine.get_type", return_value="python") @@ -86,6 +91,7 @@ def test_fg_writer_cache_management(self, mocker, dataframe_fixture_basic): assert writer._feature_group._kafka_producer == producer assert writer._feature_group._feature_writers == feature_writers assert writer._feature_group._writer == writer_m + assert writer._feature_group._kafka_headers == headers writer.insert(dataframe_fixture_basic) # after second insert should have been called only once @@ -96,6 +102,7 @@ def test_fg_writer_cache_management(self, mocker, dataframe_fixture_basic): assert fg._multi_part_insert is False assert fg._kafka_producer is None assert fg._feature_writers is None + assert fg._kafka_headers is None assert fg._writer is None def test_fg_writer_without_context_manager(self, mocker, dataframe_fixture_basic): @@ -107,11 +114,16 @@ def test_fg_writer_without_context_manager(self, mocker, dataframe_fixture_basic mocker.MagicMock(), mocker.MagicMock(), ) + headers = { + "projectId": str(99).encode("utf8"), + "featureGroupId": str(32).encode("utf8"), + "subjectId": str(12).encode("utf8"), + } mock_init_kafka_resources = mocker.patch( - "hsfs.engine.python.Engine._init_kafka_resources", - return_value=(producer, feature_writers, writer_m), + "hsfs.core.kafka_engine._init_kafka_resources", + return_value=(producer, headers, feature_writers, writer_m), ) - mocker.patch("hsfs.engine.python.Engine._encode_complex_features") + mocker.patch("hsfs.core.kafka_engine.encode_complex_features") mocker.patch("hsfs.core.job.Job") mocker.patch("hsfs.engine.get_type", return_value="python") @@ -133,6 +145,7 @@ def test_fg_writer_without_context_manager(self, mocker, dataframe_fixture_basic assert fg._multi_part_insert is True assert fg._kafka_producer == producer assert fg._feature_writers == feature_writers + assert fg._kafka_headers == headers assert fg._writer == writer_m fg.multi_part_insert(dataframe_fixture_basic) @@ -145,4 +158,5 @@ def test_fg_writer_without_context_manager(self, mocker, dataframe_fixture_basic assert fg._multi_part_insert is False assert fg._kafka_producer is None assert fg._feature_writers is None + assert fg._kafka_headers is None assert fg._writer is None