def extract_google_priority_bin(
    self, article: str, cpd_model_val=1, cpd_val=1
) -> int:
    """Return the index of the priority bin that an article's score falls into.

    The score is a weighted sum of three terms:

    * ``0.5 * cpd_val``
    * ``0.25 * cpd_model_val``
    * ``0.25 * density``, where ``density`` is the total number of
      occurrences in ``article`` of every entry of ``self.police_words``,
      divided by twice the number of police words and capped at 1.

    Args:
        article: Text that is scanned for police-word substrings.
        cpd_model_val: Value weighted at 0.25. Presumably a model output in
            [0, 1] -- confirm against the caller.
        cpd_val: Value weighted at 0.5. Presumably in [0, 1] -- confirm
            against the caller.

    Returns:
        The index of the first entry of the module-level ``bins`` sequence
        that is strictly greater than the score. When no entry exceeds the
        score (for example a score of exactly 1.0 against a top bound of
        1.0), the index of the last bin is returned instead of failing.

    Raises:
        IndexError: If ``bins`` is empty.
    """
    # NOTE(review): ``bins`` is defined outside this block; the lookup below
    # assumes it holds ascending upper bounds -- confirm at its definition.
    if not bins:
        raise IndexError("bins is empty; cannot assign a priority bin")

    police_words = self.police_words
    if police_words:
        occurrences = sum(article.count(word) for word in police_words)
        density = min(occurrences / (2 * len(police_words)), 1.0)
    else:
        # No vocabulary means nothing can match; avoid dividing by zero.
        density = 0.0

    score = 0.5 * cpd_val + 0.25 * cpd_model_val + 0.25 * density

    for index, upper_bound in enumerate(bins):
        if upper_bound > score:
            return index

    # The score is at or above every bound: clamp into the top bin rather
    # than raising IndexError on an empty candidate list.
    return len(bins) - 1