From 06e145c8da4d5cc5e5f5d9b128c1f2d597d316ec Mon Sep 17 00:00:00 2001 From: itlubber <1830611168@qq.com> Date: Wed, 18 Sep 2024 15:35:52 +0800 Subject: [PATCH] fix lift selector --- requirements.txt | 1 + scorecardpipeline/feature_selection.py | 6 ++++++ scorecardpipeline/utils.py | 29 ++++++++++++++++++++++---- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index 06e6f25..b3bf6fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,3 +28,4 @@ sweetviz numexpr cvxpy>=1.4.1 protobuf>=5.27.0 +dill diff --git a/scorecardpipeline/feature_selection.py b/scorecardpipeline/feature_selection.py index c76a95f..3c6ecaa 100644 --- a/scorecardpipeline/feature_selection.py +++ b/scorecardpipeline/feature_selection.py @@ -357,6 +357,7 @@ def __init__(self, target="target", threshold=3.0, n_jobs=None, methods=None, co :param target: target :param threshold: float or str (default=3.0). Feature which has a lift score greater than `threshold` will be kept. :param n_jobs: int or None, (default=None). Number of parallel. + :param combiner: Combiner :param methods: Combiner's methods """ super().__init__() @@ -387,6 +388,11 @@ def fit(self, x: pd.DataFrame, y=None, **fit_params): else: xt = x.copy() + # _lift = {} + # for c in tqdm(xt.columns): + # _lift[c] = LIFT(xt[c], y) + # self.scores_ = pd.Series(_lift) + self.scores_ = pd.Series(Parallel(n_jobs=self.n_jobs)(delayed(LIFT)(xt[c], y) for c in xt.columns), index=xt.columns) self.threshold = _calculate_threshold(self, self.scores_, self.threshold) self.select_columns = list(set((self.scores_[self.scores_ >= self.threshold]).index.tolist() + [self.target])) diff --git a/scorecardpipeline/utils.py b/scorecardpipeline/utils.py index 40b2112..649d165 100644 --- a/scorecardpipeline/utils.py +++ b/scorecardpipeline/utils.py @@ -11,6 +11,7 @@ import os import re import six +import pickle import random import joblib import warnings @@ -93,7 +94,7 @@ def init_setting(font_path=None, seed=None, freeze_torch=False, logger=False, ** return init_logger(**kwargs) -def load_pickle(file): +def load_pickle(file, engine="joblib"): """ 导入 pickle 文件 @@ -101,17 +102,37 @@ def load_pickle(file): :return: pickle 文件的内容 """ - return joblib.load(file) + if engine == "joblib": + return joblib.load(file) + elif engine == "dill": + import dill + with open(file, "rb") as f: + return dill.load(f) + elif engine == "pickle": + with open(file, "rb") as f: + return pickle.load(f) + else: + raise ValueError(f"engine 目前只支持 [joblib, dill, pickle], 不支持 {engine}") -def save_pickle(obj, file): +def save_pickle(obj, file, engine="joblib"): """ 保持数据至 pickle 文件 :param obj: 需要保存的数据 :param file: 文件路径 """ - joblib.dump(obj, file) + if engine == "joblib": + return joblib.dump(obj, file) + elif engine == "dill": + import dill + with open(file, "wb") as f: + return dill.dump(obj, f) + elif engine == "pickle": + with open(file, "wb") as f: + return pickle.dump(obj, f) + else: + raise ValueError(f"engine 目前只支持 [joblib, dill, pickle], 不支持 {engine}") def germancredit():