Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

featurization cleanup intermediate changes #36

Merged
merged 4 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 39 additions & 46 deletions src/learn_to_pick/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,66 +20,66 @@
from learn_to_pick.model_repository import ModelRepository
from learn_to_pick.vw_logger import VwLogger
from learn_to_pick.features import Featurized, DenseFeatures, SparseFeatures
from enum import Enum

if TYPE_CHECKING:
import vowpal_wabbit_next as vw

logger = logging.getLogger(__name__)


class _BasedOn:
def __init__(self, value: Any):
self.value = value

def __str__(self) -> str:
return str(self.value)

__repr__ = __str__


def BasedOn(anything: Any) -> _BasedOn:
return _BasedOn(anything)
class Role(Enum):
CONTEXT = 1
ACTIONS = 2


class _ToSelectFrom:
def __init__(self, value: Any):
class _Roled:
def __init__(self, value: Any, role: Role):
self.value = value
self.role = role

def __str__(self) -> str:
return str(self.value)

__repr__ = __str__


def ToSelectFrom(anything: Any) -> _ToSelectFrom:
def BasedOn(anything: Any) -> _Roled:
return _Roled(anything, Role.CONTEXT)


def ToSelectFrom(anything: Any) -> _Roled:
if not isinstance(anything, list):
raise ValueError("ToSelectFrom must be a list to select from")
return _ToSelectFrom(anything)
return _Roled(anything, Role.ACTIONS)


class _Embed:
def __init__(self, value: Any, keep: bool = False):
class _Input:
def __init__(self, value: Any, keep: bool = True, embed: bool = False):
self.value = value
self.keep = keep
self.embed = embed

def __str__(self) -> str:
return str(self.value)

@staticmethod
def create(value: Any, *args, **kwargs):
if isinstance(value, _Roled):
return _Roled(_Input.create(value.value, *args, **kwargs), value.role)
if isinstance(value, list):
return [_Input.create(v, *args, **kwargs) for v in value]
if isinstance(value, dict):
return {k: _Input.create(v, *args, **kwargs) for k, v in value.items()}
if isinstance(value, _Input): # should we swap? it will allow overwriting
return value
return _Input(value, *args, **kwargs)

__repr__ = __str__


def Embed(anything: Any, keep: bool = False) -> Any:
if isinstance(anything, _ToSelectFrom):
return ToSelectFrom(Embed(anything.value, keep=keep))
elif isinstance(anything, _BasedOn):
return BasedOn(Embed(anything.value, keep=keep))
if isinstance(anything, list):
return [Embed(v, keep=keep) for v in anything]
elif isinstance(anything, dict):
return {k: Embed(v, keep=keep) for k, v in anything.items()}
elif isinstance(anything, _Embed):
return anything
return _Embed(anything, keep=keep)
return _Input.create(anything, keep=keep, embed=True)


def EmbedAndKeep(anything: Any) -> Any:
Expand All @@ -93,19 +93,11 @@ def _parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Exam
return [parser.parse_line(line) for line in input_str.split("\n")]


def get_based_on(inputs: Dict[str, Any]) -> Dict:
return {
k: inputs[k].value if isinstance(inputs[k].value, list) else inputs[k].value
for k in inputs.keys()
if isinstance(inputs[k], _BasedOn)
}


def get_to_select_from(inputs: Dict[str, Any]) -> Dict:
def filter_inputs(inputs: Dict[str, Any], role: Role) -> Dict[str, Any]:
return {
k: inputs[k].value
for k in inputs.keys()
if isinstance(inputs[k], _ToSelectFrom)
k: v.value
for k, v in inputs.items()
if isinstance(v, _Roled) and v.role == role
}


Expand Down Expand Up @@ -480,14 +472,15 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:


def _embed_string_type(
item: Union[str, _Embed], model: Any, namespace: str
item: Union[str, _Input], model: Any, namespace: str
) -> Featurized:
"""Helper function to embed a string or an _Embed object."""
import re

result = Featurized()
if isinstance(item, _Embed):
result[namespace] = DenseFeatures(model.encode(item.value))
if isinstance(item, _Input):
if item.embed:
result[namespace] = DenseFeatures(model.encode(item.value))
if item.keep:
keep_str = item.value.replace(" ", "_")
result[namespace] = {"default_ft": re.sub(r"[\t\n\r\f\v]+", " ", keep_str)}
Expand Down Expand Up @@ -529,7 +522,7 @@ def _embed_list_type(


def embed(
to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]],
to_embed: Union[Union[str, _Input], Dict, List[Union[str, _Input]], List[Dict]],
model: Any,
namespace: Optional[str] = None,
) -> Union[Featurized, List[Featurized]]:
Expand All @@ -543,7 +536,7 @@ def embed(
Returns:
List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value
"""
if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance(
if (isinstance(to_embed, _Input) and isinstance(to_embed.value, str)) or isinstance(
to_embed, str
):
return _embed_string_type(to_embed, model, namespace)
Expand Down
4 changes: 2 additions & 2 deletions src/learn_to_pick/pick_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def __init__(
selected: Optional[PickBestSelected] = None,
):
super().__init__(inputs=inputs, selected=selected or PickBestSelected())
self.to_select_from = base.get_to_select_from(inputs)
self.based_on = base.get_based_on(inputs)
self.to_select_from = base.filter_inputs(inputs, base.Role.ACTIONS)
self.based_on = base.filter_inputs(inputs, base.Role.CONTEXT)
if not self.to_select_from:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
Expand Down
Loading