From a38ed1706ea84bee030fbc6c904bf3d606a4332c Mon Sep 17 00:00:00 2001 From: Charlie Meyers Date: Tue, 13 Aug 2024 18:34:13 +0000 Subject: [PATCH] fixed hashing bug --- deckard/base/model/sklearn_pipeline.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/deckard/base/model/sklearn_pipeline.py b/deckard/base/model/sklearn_pipeline.py index 6d436a9d..f303bae9 100644 --- a/deckard/base/model/sklearn_pipeline.py +++ b/deckard/base/model/sklearn_pipeline.py @@ -23,7 +23,7 @@ ) -from ..utils import Hashable +from ..utils import my_hash __all__ = ["SklearnModelPipelineStage", "SklearnModelPipeline"] logger = logging.getLogger(__name__) @@ -140,7 +140,9 @@ def __len__(self): def __iter__(self): return iter(self.pipeline) - + def __hash__(self): + return int(my_hash(self), 16) + def __call__(self, model): params = deepcopy(asdict(self)) pipeline = params.pop("pipeline") @@ -206,6 +208,9 @@ class SklearnModelInitializer(Hashable): pipeline: SklearnModelPipeline = field(default_factory=None) kwargs: Union[dict, None] = field(default_factory=dict) + def __hash__(self): + return int(my_hash(self), 16) + def __init__(self, data, model=None, library="sklearn", pipeline={}, **kwargs): self.data = data self.model = model