From 0e1b354e4c041d4a2f44fbd10c2b70b5abcc563d Mon Sep 17 00:00:00 2001 From: Alexey Taymanov Date: Thu, 16 Nov 2023 10:50:13 -0500 Subject: [PATCH 1/4] unified role --- src/learn_to_pick/base.py | 53 ++++++++++++---------------------- src/learn_to_pick/pick_best.py | 4 +-- 2 files changed, 20 insertions(+), 37 deletions(-) diff --git a/src/learn_to_pick/base.py b/src/learn_to_pick/base.py index 5430c39..2abfaf6 100644 --- a/src/learn_to_pick/base.py +++ b/src/learn_to_pick/base.py @@ -20,6 +20,7 @@ from learn_to_pick.model_repository import ModelRepository from learn_to_pick.vw_logger import VwLogger from learn_to_pick.features import Featurized, DenseFeatures, SparseFeatures +from enum import Enum if TYPE_CHECKING: import vowpal_wabbit_next as vw @@ -27,34 +28,30 @@ logger = logging.getLogger(__name__) -class _BasedOn: - def __init__(self, value: Any): - self.value = value - - def __str__(self) -> str: - return str(self.value) - - __repr__ = __str__ - +class Role(Enum): + CONTEXT = 1 + ACTIONS = 2 -def BasedOn(anything: Any) -> _BasedOn: - return _BasedOn(anything) - -class _ToSelectFrom: - def __init__(self, value: Any): +class _Roled: + def __init__(self, value: Any, role: Role): self.value = value + self.role = role def __str__(self) -> str: return str(self.value) - + __repr__ = __str__ -def ToSelectFrom(anything: Any) -> _ToSelectFrom: +def BasedOn(anything: Any) -> _Roled: + return _Roled(anything, Role.CONTEXT) + + +def ToSelectFrom(anything: Any) -> _Roled: if not isinstance(anything, list): raise ValueError("ToSelectFrom must be a list to select from") - return _ToSelectFrom(anything) + return _Roled(anything, Role.ACTIONS) class _Embed: @@ -69,10 +66,8 @@ def __str__(self) -> str: def Embed(anything: Any, keep: bool = False) -> Any: - if isinstance(anything, _ToSelectFrom): - return ToSelectFrom(Embed(anything.value, keep=keep)) - elif isinstance(anything, _BasedOn): - return BasedOn(Embed(anything.value, keep=keep)) + if isinstance(anything, _Roled): + return _Roled(Embed(anything.value, keep=keep), anything.role) if isinstance(anything, list): return [Embed(v, keep=keep) for v in anything] elif isinstance(anything, dict): @@ -93,20 +88,8 @@ def _parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Exam return [parser.parse_line(line) for line in input_str.split("\n")] -def get_based_on(inputs: Dict[str, Any]) -> Dict: - return { - k: inputs[k].value if isinstance(inputs[k].value, list) else inputs[k].value - for k in inputs.keys() - if isinstance(inputs[k], _BasedOn) - } - - -def get_to_select_from(inputs: Dict[str, Any]) -> Dict: - return { - k: inputs[k].value - for k in inputs.keys() - if isinstance(inputs[k], _ToSelectFrom) - } +def filter_inputs(inputs: Dict[str, Any], role: Role) -> Dict[str, Any]: + return {k: v.value for k, v in inputs.items() if isinstance(v, _Roled) and v.role == role} # end helper functions diff --git a/src/learn_to_pick/pick_best.py b/src/learn_to_pick/pick_best.py index c85eefa..f51aedc 100644 --- a/src/learn_to_pick/pick_best.py +++ b/src/learn_to_pick/pick_best.py @@ -38,8 +38,8 @@ def __init__( selected: Optional[PickBestSelected] = None, ): super().__init__(inputs=inputs, selected=selected or PickBestSelected()) - self.to_select_from = base.get_to_select_from(inputs) - self.based_on = base.get_based_on(inputs) + self.to_select_from = base.filter_inputs(inputs, base.Role.ACTIONS) + self.based_on = base.filter_inputs(inputs, base.Role.CONTEXT) if not self.to_select_from: raise ValueError( "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." From 60c40b4c1ea8efe8e6b82a7733285e17287d7c7c Mon Sep 17 00:00:00 2001 From: Alexey Taymanov Date: Thu, 16 Nov 2023 11:11:20 -0500 Subject: [PATCH 2/4] _Embed -> _Featurize --- src/learn_to_pick/base.py | 40 +++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/src/learn_to_pick/base.py b/src/learn_to_pick/base.py index 2abfaf6..46c0e8f 100644 --- a/src/learn_to_pick/base.py +++ b/src/learn_to_pick/base.py @@ -53,29 +53,32 @@ def ToSelectFrom(anything: Any) -> _Roled: raise ValueError("ToSelectFrom must be a list to select from") return _Roled(anything, Role.ACTIONS) - -class _Embed: - def __init__(self, value: Any, keep: bool = False): +class _Featurize: + def __init__(self, value: Any, keep: bool = True, embed: bool = False): self.value = value self.keep = keep + self.embed = embed def __str__(self) -> str: return str(self.value) + + @staticmethod + def create(value: Any, *args, **kwargs): + if isinstance(value, _Roled): + return _Roled(_Featurize.create(value.value, *args, **kwargs), value.role) + if isinstance(value, list): + return [_Featurize.create(v, *args, **kwargs) for v in value] + if isinstance(value, dict): + return {k: _Featurize.create(v, *args, **kwargs) for k, v in value.items()} + if isinstance(value, _Featurize): # should we swap? it will allow overwriting + return value + return _Featurize(value, *args, **kwargs) __repr__ = __str__ def Embed(anything: Any, keep: bool = False) -> Any: - if isinstance(anything, _Roled): - return _Roled(Embed(anything.value, keep=keep), anything.role) - if isinstance(anything, list): - return [Embed(v, keep=keep) for v in anything] - elif isinstance(anything, dict): - return {k: Embed(v, keep=keep) for k, v in anything.items()} - elif isinstance(anything, _Embed): - return anything - return _Embed(anything, keep=keep) - + return _Featurize.create(anything, keep=keep, embed=True) def EmbedAndKeep(anything: Any) -> Any: return Embed(anything, keep=True) @@ -463,14 +466,15 @@ def run(self, *args, **kwargs) -> Dict[str, Any]: def _embed_string_type( - item: Union[str, _Embed], model: Any, namespace: str + item: Union[str, _Featurize], model: Any, namespace: str ) -> Featurized: """Helper function to embed a string or an _Embed object.""" import re result = Featurized() - if isinstance(item, _Embed): - result[namespace] = DenseFeatures(model.encode(item.value)) + if isinstance(item, _Featurize): + if item.embed: + result[namespace] = DenseFeatures(model.encode(item.value)) if item.keep: keep_str = item.value.replace(" ", "_") result[namespace] = {"default_ft": re.sub(r"[\t\n\r\f\v]+", " ", keep_str)} @@ -512,7 +516,7 @@ def _embed_list_type( def embed( - to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]], + to_embed: Union[Union[str, _Featurize], Dict, List[Union[str, _Featurize]], List[Dict]], model: Any, namespace: Optional[str] = None, ) -> Union[Featurized, List[Featurized]]: @@ -526,7 +530,7 @@ def embed( Returns: List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value """ - if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance( + if (isinstance(to_embed, _Featurize) and isinstance(to_embed.value, str)) or isinstance( to_embed, str ): return _embed_string_type(to_embed, model, namespace) From 589fa656996678069e1332e8ea6d3e0104f6c6a2 Mon Sep 17 00:00:00 2001 From: Alexey Taymanov Date: Mon, 20 Nov 2023 10:11:40 -0500 Subject: [PATCH 3/4] _Featurize -> _Input --- src/learn_to_pick/base.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/learn_to_pick/base.py b/src/learn_to_pick/base.py index 46c0e8f..74418bb 100644 --- a/src/learn_to_pick/base.py +++ b/src/learn_to_pick/base.py @@ -53,7 +53,7 @@ def ToSelectFrom(anything: Any) -> _Roled: raise ValueError("ToSelectFrom must be a list to select from") return _Roled(anything, Role.ACTIONS) -class _Featurize: +class _Input: def __init__(self, value: Any, keep: bool = True, embed: bool = False): self.value = value self.keep = keep @@ -65,20 +65,20 @@ def __str__(self) -> str: @staticmethod def create(value: Any, *args, **kwargs): if isinstance(value, _Roled): - return _Roled(_Featurize.create(value.value, *args, **kwargs), value.role) + return _Roled(_Input.create(value.value, *args, **kwargs), value.role) if isinstance(value, list): - return [_Featurize.create(v, *args, **kwargs) for v in value] + return [_Input.create(v, *args, **kwargs) for v in value] if isinstance(value, dict): - return {k: _Featurize.create(v, *args, **kwargs) for k, v in value.items()} - if isinstance(value, _Featurize): # should we swap? it will allow overwriting + return {k: _Input.create(v, *args, **kwargs) for k, v in value.items()} + if isinstance(value, _Input): # should we swap? it will allow overwriting return value - return _Featurize(value, *args, **kwargs) + return _Input(value, *args, **kwargs) __repr__ = __str__ def Embed(anything: Any, keep: bool = False) -> Any: - return _Featurize.create(anything, keep=keep, embed=True) + return _Input.create(anything, keep=keep, embed=True) def EmbedAndKeep(anything: Any) -> Any: return Embed(anything, keep=True) @@ -466,13 +466,13 @@ def run(self, *args, **kwargs) -> Dict[str, Any]: def _embed_string_type( - item: Union[str, _Featurize], model: Any, namespace: str + item: Union[str, _Input], model: Any, namespace: str ) -> Featurized: """Helper function to embed a string or an _Embed object.""" import re result = Featurized() - if isinstance(item, _Featurize): + if isinstance(item, _Input): if item.embed: result[namespace] = DenseFeatures(model.encode(item.value)) if item.keep: @@ -516,7 +516,7 @@ def _embed_list_type( def embed( - to_embed: Union[Union[str, _Featurize], Dict, List[Union[str, _Featurize]], List[Dict]], + to_embed: Union[Union[str, _Input], Dict, List[Union[str, _Input]], List[Dict]], model: Any, namespace: Optional[str] = None, ) -> Union[Featurized, List[Featurized]]: @@ -530,7 +530,7 @@ def embed( Returns: List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value """ - if (isinstance(to_embed, _Featurize) and isinstance(to_embed.value, str)) or isinstance( + if (isinstance(to_embed, _Input) and isinstance(to_embed.value, str)) or isinstance( to_embed, str ): return _embed_string_type(to_embed, model, namespace) From a3e5e1b050ef205e301adfc2c7ae62d770b5b29c Mon Sep 17 00:00:00 2001 From: Alexey Taymanov Date: Mon, 20 Nov 2023 10:14:19 -0500 Subject: [PATCH 4/4] black --- src/learn_to_pick/base.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/learn_to_pick/base.py b/src/learn_to_pick/base.py index 74418bb..73cd9f8 100644 --- a/src/learn_to_pick/base.py +++ b/src/learn_to_pick/base.py @@ -40,7 +40,7 @@ def __init__(self, value: Any, role: Role): def __str__(self) -> str: return str(self.value) - + __repr__ = __str__ @@ -53,6 +53,7 @@ def ToSelectFrom(anything: Any) -> _Roled: raise ValueError("ToSelectFrom must be a list to select from") return _Roled(anything, Role.ACTIONS) + class _Input: def __init__(self, value: Any, keep: bool = True, embed: bool = False): self.value = value @@ -61,7 +62,7 @@ def __init__(self, value: Any, keep: bool = True, embed: bool = False): def __str__(self) -> str: return str(self.value) - + @staticmethod def create(value: Any, *args, **kwargs): if isinstance(value, _Roled): @@ -69,9 +70,9 @@ def create(value: Any, *args, **kwargs): if isinstance(value, list): return [_Input.create(v, *args, **kwargs) for v in value] if isinstance(value, dict): - return {k: _Input.create(v, *args, **kwargs) for k, v in value.items()} - if isinstance(value, _Input): # should we swap? it will allow overwriting - return value + return {k: _Input.create(v, *args, **kwargs) for k, v in value.items()} + if isinstance(value, _Input): # should we swap? it will allow overwriting + return value return _Input(value, *args, **kwargs) __repr__ = __str__ @@ -80,6 +81,7 @@ def create(value: Any, *args, **kwargs): def Embed(anything: Any, keep: bool = False) -> Any: return _Input.create(anything, keep=keep, embed=True) + def EmbedAndKeep(anything: Any) -> Any: return Embed(anything, keep=True) @@ -92,7 +94,11 @@ def _parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Exam def filter_inputs(inputs: Dict[str, Any], role: Role) -> Dict[str, Any]: - return {k: v.value for k, v in inputs.items() if isinstance(v, _Roled) and v.role == role} + return { + k: v.value + for k, v in inputs.items() + if isinstance(v, _Roled) and v.role == role + } # end helper functions