Skip to content

Commit

Permalink
Configurators: Move parameter_values to _Sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
moi90 committed Dec 26, 2023
1 parent b92935f commit a8e7603
Show file tree
Hide file tree
Showing 9 changed files with 214 additions and 237 deletions.
8 changes: 4 additions & 4 deletions experitur/configurators/pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ def __init__(
"min_count": min_count,
}

@property
def parameter_values(self):
return {}

class _Sampler(ConfigurationSampler):
configurator: "Prune"

Expand All @@ -58,3 +54,7 @@ def sample(self, exclude: Optional[Set] = None) -> Iterator[Mapping]:
)

yield merge_dicts(parent_configuration, pruning_config=pruning_config)

@property
def parameter_values(self):
return self.parent.parameter_values
24 changes: 15 additions & 9 deletions experitur/configurators/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
from scipy.stats import distributions

from experitur import get_current_context
from experitur.core.configurators import ConfigurationSampler, Configurator
from experitur.core.configurators import (
ConfigurationSampler,
Configurator,
merge_parameter_values,
)
from experitur.helpers.merge_dicts import merge_dicts


Expand Down Expand Up @@ -73,18 +77,10 @@ def __init__(self, distributions: Dict[str, Union[List, Any]], n_iter: int):
self.distributions = distributions
self.n_iter = n_iter

@property
def parameter_values(self) -> Mapping[str, Container]:
return {
k: tuple(v) if isinstance(v, Iterable) else _DistWrapper(v)
for k, v in self.distributions.items()
}

class _Sampler(ConfigurationSampler):
configurator: "Random"

def sample(self, exclude: Optional[Set] = None) -> Iterator[Mapping]:

distributions, exclude = self.prepare_values_exclude(
self.configurator.distributions, exclude
)
Expand Down Expand Up @@ -137,3 +133,13 @@ def sample(self, exclude: Optional[Set] = None) -> Iterator[Mapping]:
break

yield merge_dicts(parent_configuration, parameters=params)

@property
def parameter_values(self) -> Mapping[str, Container]:
return merge_parameter_values(
self.parent.parameter_values,
{
k: tuple(v) if isinstance(v, Iterable) else _DistWrapper(v)
for k, v in self.configurator.distributions.items()
},
)
Loading

0 comments on commit a8e7603

Please sign in to comment.