diff --git a/adatest/_topic_model.py b/adatest/_topic_model.py index 752280d..e46f242 100644 --- a/adatest/_topic_model.py +++ b/adatest/_topic_model.py @@ -19,7 +19,8 @@ def predict_prob(self, embeddings): class CVModel(): def __init__(self, embeddings, labels): - self.inner_model = RidgeClassifierCV(class_weight={"pass": 1, "fail": 1}) + class_weight = {label: 1 for label in labels} + self.inner_model = RidgeClassifierCV(class_weight=class_weight) self.inner_model.fit(embeddings, labels) def predict_prob(self, embeddings): @@ -157,8 +158,7 @@ def __init__(self, topic, test_tree): else: # we are in a highly overparametrized situation, so we use a linear SVC to get "max-margin" based generalization - self.model = CVModel() - self.model.fit(embeddings, labels) + self.model = CVModel(embeddings, labels) def __call__(self, input): embeddings = adatest.embed([input])[0]