Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mlflow fix #406

Closed
wants to merge 2 commits
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions numalogic/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

import numpy.typing as npt
import pytorch_lightning as pl
from sklearn.base import TransformerMixin, OutlierMixin
from sklearn.base import TransformerMixin, OutlierMixin, BaseEstimator


class BaseTransformer(TransformerMixin):
class BaseTransformer(BaseEstimator, TransformerMixin):
"""Base class for all transformer classes."""

pass
Expand Down Expand Up @@ -47,7 +47,7 @@ class TorchModel(pl.LightningModule, metaclass=ABCMeta):
pass


class BaseThresholdModel(OutlierMixin):
class BaseThresholdModel(BaseEstimator, OutlierMixin):
"""Base class for all threshold models."""

pass
1 change: 1 addition & 0 deletions numalogic/udfs/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class TrainerPayload(_BasePayload):

metrics: list[str]
header: Header = Header.TRAIN_REQUEST
force_train_req: bool = False

def to_json(self):
return orjson.dumps(self)
Expand Down
2 changes: 1 addition & 1 deletion numalogic/udfs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
metrics=payload.metrics,
)
# Send training request if inference fails
msgs = Messages(get_trainer_message(keys, _stream_conf, payload))
msgs = Messages(get_trainer_message(keys, _stream_conf, payload, _force_train=True))
if _conf.numalogic_conf.score.adjust:
msgs.append(get_static_thresh_message(keys, payload))
return msgs
Expand Down
2 changes: 1 addition & 1 deletion numalogic/udfs/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
payload_metrics=payload.metrics,
)
# Send training request if postprocess fails
msgs = Messages(get_trainer_message(keys, _stream_conf, payload))
msgs = Messages(get_trainer_message(keys, _stream_conf, payload, _force_train=True))
if _conf.numalogic_conf.score.adjust:
msgs.append(get_static_thresh_message(keys, payload))
return msgs
Expand Down
8 changes: 7 additions & 1 deletion numalogic/udfs/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,13 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
status=Status.RUNTIME_ERROR,
)
msgs = Messages(
get_trainer_message(keys, _stream_conf, payload, **_metric_label_values),
get_trainer_message(
keys=keys,
stream_conf=_stream_conf,
payload=payload,
_force_train=True,
**_metric_label_values
),
)
if _conf.numalogic_conf.score.adjust:
msgs.append(get_static_thresh_message(keys, payload))
Expand Down
9 changes: 8 additions & 1 deletion numalogic/udfs/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pandas import DataFrame
from pynumaflow.mapper import Message


from numalogic.registry import ArtifactManager, ArtifactData
from numalogic.tools.exceptions import RedisRegistryError
from numalogic.tools.types import KEYS, redis_client_t
Expand Down Expand Up @@ -285,6 +284,7 @@ def ack_read(
retry: int = 600,
min_train_records: int = 180,
data_freq: int = 60,
_force_train: bool = False,
) -> bool:
"""
Acknowledge the read message. Return True when the msg has to be trained.
Expand All @@ -295,6 +295,7 @@ def ack_read(
retry: Time difference(in secs) between triggering retraining and msg read_ack.
min_train_records: minimum number of records required for training.
data_freq: data granularity/frequency in secs.
_force_train: force training for the key.

Returns
-------
Expand Down Expand Up @@ -332,6 +333,10 @@ def ack_read(
logger.debug("Model with key is being trained by another process")
return False

if _force_train:
logger.debug("Forcing training for the key")
return True

# This check is needed if there is backpressure in the pipeline
if _msg_train_ts and time.time() - float(_msg_train_ts) < retrain_freq * 60 * 60:
logger.debug(
Expand Down Expand Up @@ -374,6 +379,7 @@ def get_trainer_message(
keys: list[str],
stream_conf: StreamConf,
payload: StreamPayload,
_force_train: bool = False,
**metric_values: dict,
) -> Message:
"""
Expand All @@ -397,6 +403,7 @@ def get_trainer_message(
metrics=payload.metrics,
config_id=payload.config_id,
pipeline_id=payload.pipeline_id,
force_train_req=_force_train,
)
if metric_values:
_increment_counter(
Expand Down
1 change: 1 addition & 0 deletions numalogic/udfs/trainer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
retry=retry_ts,
min_train_records=_conf.numalogic_conf.trainer.min_train_size,
data_freq=_conf.numalogic_conf.trainer.data_freq_sec,
_force_train=payload.force_train_req,
):
_increment_counter(
counter="MSG_DROPPED_COUNTER",
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.12.4"
version = "0.13.0"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
Loading