diff --git a/quantulum3/classifier.py b/quantulum3/classifier.py
index 80a41ab..bb5be1a 100644
--- a/quantulum3/classifier.py
+++ b/quantulum3/classifier.py
@@ -6,9 +6,10 @@
 # Standard library
 import json
 import logging
-import pkg_resources
-import os
 import multiprocessing
+import os
+
+import pkg_resources
 
 # Semi-dependencies
 try:
@@ -27,9 +28,9 @@
     wikipedia = None
 
 # Quantulum
-from . import load
+from . import language, load
 from .load import cached
-from . import language
+
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -109,7 +110,6 @@ def _clean_text_lang(lang):
 def train_classifier(
     parameters=None, ngram_range=(1, 1), store=True, lang="en_US", n_jobs=None
 ):
-
     """
     Train the intent classifier
     TODO auto invoke if sklearn version is new or first install or sth
@@ -245,7 +245,42 @@ def disambiguate_unit(unit, text, lang="en_US"):
     """
     Resolve ambiguity between units with same names, symbols or abbreviations.
     """
+    new_unit = disambiguate_unit_by_score(unit, text, lang)
+    if len(new_unit) == 1:
+        return next(iter(new_unit))
+
+    try:
+        # Instead of picking a random one now, we first change the
+        # capitalization of the unit and see if we can improve.
+        unit_changed = unit[:-1] + unit[-1].swapcase()
+        text_changed = text.replace(unit, unit_changed)
+
+        new_unit_changed = disambiguate_unit_by_score(unit_changed, text_changed, lang)
+        if len(new_unit_changed) == 1:
+            return next(iter(new_unit_changed))
+
+        if 0 < len(new_unit_changed) < len(new_unit):
+            # See if we have improved, otherwise we stick with the old new_unit.
+            new_unit = new_unit_changed
+
+    except KeyError:
+        pass  # Attempt failed, we just pick a random from new_unit now.
+
+    _LOGGER.warning(
+        "Could not resolve ambiguous units: '{}'. For unit '{}' in text '{}'. "
+        "Taking a random.".format(", ".join(str(u) for u in new_unit), unit, text)
+    )
+    return next(iter(new_unit))
+
 
+def disambiguate_unit_by_score(unit, text, lang):
+    """
+    Look up the candidate units for a surface and rank them with the classifier.
+
+    Returns the candidates unchanged when there is only one of them or when
+    none of them is known to the classifier; otherwise a one-element list
+    holding the best scoring unit. Raises KeyError if no unit is found.
+    """
     new_unit = (
         load.units(lang).symbols.get(unit)
         or load.units(lang).surfaces.get(unit)
@@ -254,25 +289,28 @@ def disambiguate_unit(unit, text, lang="en_US"):
     )
     if not new_unit:
         raise KeyError('Could not find unit "%s" from "%s"' % (unit, text))
+    if len(new_unit) == 1:
+        return new_unit
 
-    if len(new_unit) > 1:
-        transformed = classifier(lang).tfidf_model.transform([clean_text(text, lang)])
-        scores = classifier(lang).classifier.predict_proba(transformed).tolist()[0]
-        scores = zip(scores, classifier(lang).target_names)
-
-        # Filter for possible names
-        names = [i.name for i in new_unit]
-        scores = [i for i in scores if i[1] in names]
-
-        # Sort by rank
-        scores = sorted(scores, key=lambda x: x[0], reverse=True)
-        try:
-            final = load.units(lang).names[scores[0][1]]
-            _LOGGER.debug('\tAmbiguity resolved for "%s" (%s)' % (unit, scores))
-        except IndexError:
-            _LOGGER.debug('\tAmbiguity not resolved for "%s"' % unit)
-            final = next(iter(new_unit))
-    else:
-        final = next(iter(new_unit))
-
-    return final
+    # Start scoring
+    transformed = classifier(lang).tfidf_model.transform(
+        [clean_text(text, lang)]
+    )
+    scores = classifier(lang).classifier.predict_proba(transformed).tolist()[0]
+    scores = zip(scores, classifier(lang).target_names)
+
+    # Filter for possible names
+    names = [i.name for i in new_unit]
+    scores = [i for i in scores if i[1] in names]
+
+    # Sort by rank
+    scores = sorted(scores, key=lambda x: x[0], reverse=True)
+    try:
+        best = load.units(lang).names[scores[0][1]]
+    except IndexError:
+        # No candidate name is known to the classifier: hand back all of them.
+        _LOGGER.debug('\tAmbiguity not resolved for "%s"' % unit)
+        return new_unit
+    # Log before returning; a statement placed after the return never runs.
+    _LOGGER.debug('\tAmbiguity resolved for "%s" (%s)' % (unit, scores))
+    return [best]