Skip to content

Commit

Permalink
Optuna hyperparameter optimization tutorial (#178)
Browse files Browse the repository at this point in the history
* first commit for the addition of the TabDDPM plugin

* Add DDPM test script and update DDPM plugin

* add TabDDPM class and refactor

* handle discrete cols and label generation

* add hparam space and update tests of DDPM

* debug and test DDPM

* update TensorDataLoader and training loop

* clear bugs

* debug for regression tasks

* debug for regression tasks; ALL TESTS PASSED

* remove the official repo of TabDDPM

* passed all pre-commit checks

* convert assert to conditional AssertionErrors

* added an auto annotation tool

* update auto-anno and generate annotations

* remove auto-anno and flake8 noqa

* add python<3.9 compatible annotations

* remove star import

* replace builtin type annos to typing annos

* resolve py38 compatibility issue

* tests/plugins/generic/test_ddpm.py

* change TabDDPM method signatures

* remove Iterator subscription

* update AssertionErrors, add EarlyStop callback, removed additional MLP, update logging

* remove TensorDataLoader, update test_ddpm

* update EarlyStopping

* add TabDDPM tutorial, update TabDDPM plugin and encoders

* add TabDDPM tutorial

* major update of FeatureEncoder and TabularEncoder

* add LogDistribution and LogIntDistribution

* update DDPM to use TabularEncoder

* update test_tabular_encoder and debug

* debug and DDPM tutorial OK

* debug LogDistribution and LogIntDistribution

* change discrete encoding of BinEncoder to passthrough;  passed all tests in test_tabular_encoder

* add tabnet to plugins/core/models

* add factory.py, let DDPM use TabNet, refactor

* update docstrings and refactor

* fix type annotation compatibility

* make SkipConnection serializable

* fix TabularEncoder.activation_layout

* remove unnecessary code

* fix minor bug and add more nn models in factory

* update pandas and torch version requirement

* update pandas and torch version requirement

* update ddpm tutorial

* restore setup.cfg

* restore setup.cfg

* replace LabelEncoder with OrdinalEncoder

* update setup.cfg

* update setup.cfg

* debug datetimeDistribution

* clean

* update setup.cfg and goggle test

* move DDPM tutorial to tutorials/plugins

* update tabnet.py reference

* update tab_ddpm

* update distribution, add optuna utils and tutorial

* update

* Fix plugin type of static_model of fflows

* update intlogdistribution and tutorial

* try fixing goggle

* add more activations

* minor fix

* update

* update

* update

* update

* Update tabular_encoder.py

* Update test_goggle.py

* Update tabular_encoder.py

* update

* update tutorial 8

* update

* default cat nonlin of goggle <- gumbel_softmax

* get_nonlin('softmax') <- GumbelSoftmax()

* remove debug logging

* update

* update

* fix merge

* fix merge

* update pip upgrade commands in workflows

* update pip upgrade commands in workflows

* keep pip version to 23.0.1 in workflows

* keep pip version to 23.0.1 in workflows

* update

* update

* update

* update

* update

* update

* fix distribution

* update

* move upgrading of wheel to prereq.txt

* update

---------

Co-authored-by: Bogdan Cebere <[email protected]>
Co-authored-by: Rob <[email protected]>
  • Loading branch information
3 people authored Apr 25, 2023
1 parent a4190e6 commit 46f83ad
Show file tree
Hide file tree
Showing 10 changed files with 440 additions and 83 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
if: ${{ matrix.os == 'macos-latest' }}
- name: Install dependencies
run: |
pip install pip==23.0.1
python -m pip install -U pip
pip install -r prereq.txt
- name: Test Core
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
if: ${{ matrix.os == 'macos-latest' }}
- name: Install dependencies
run: |
pip install pip==23.0.1
python -m pip install -U pip
pip install -r prereq.txt
- name: Test Core
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_tutorials.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
if: ${{ matrix.os == 'macos-latest' }}
- name: Install dependencies
run: |
pip install pip==23.0.1
python -m pip install -U pip
pip install -r prereq.txt
pip install .[all]
Expand Down
3 changes: 2 additions & 1 deletion prereq.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
numpy
torch<2.0
torch>=1.10.0,<2.0
tsai
wheel>=0.40
125 changes: 54 additions & 71 deletions src/synthcity/plugins/core/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,25 @@ def as_constraint(self) -> Constraints:

@abstractmethod
def min(self) -> Any:
"Get the min value of the distribution"
"""Get the min value of the distribution."""
...

@abstractmethod
def max(self) -> Any:
"Get the max value of the distribution"
"""Get the max value of the distribution."""
...

@abstractmethod
def __eq__(self, other: Any) -> bool:
...
return type(self) == type(other) and self.get() == other.get()

def __contains__(self, item: Any) -> bool:
"""
Example:
>>> dist = CategoricalDistribution(name="foo", choices=["a", "b", "c"])
>>> "a" in dist
True
"""
return self.has(item)

@abstractmethod
def dtype(self) -> str:
Expand All @@ -146,7 +154,7 @@ def _validate_choices(cls: Any, v: List, values: Dict) -> List:
raise ValueError(
"Invalid choices for CategoricalDistribution. Provide data or choices params"
)
return v
return sorted(set(v))

def get(self) -> List[Any]:
return [self.name, self.choices]
Expand Down Expand Up @@ -176,12 +184,6 @@ def min(self) -> Any:
def max(self) -> Any:
return max(self.choices)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, CategoricalDistribution):
return False

return self.name == other.name and set(self.choices) == set(other.choices)

def dtype(self) -> str:
types = {
"object": 0,
Expand Down Expand Up @@ -259,33 +261,24 @@ def min(self) -> Any:
def max(self) -> Any:
return self.high

def __eq__(self, other: Any) -> bool:
if not isinstance(other, type(self)):
return False

return (
self.name == other.name
and self.low == other.low
and self.high == other.high
)

def dtype(self) -> str:
return "float"


class LogDistribution(FloatDistribution):
low: float = np.finfo(np.float64).tiny
high: float = np.finfo(np.float64).max
base: float = 2.0

def get(self) -> List[Any]:
return [self.name, self.low, self.high]

def sample(self, count: int = 1) -> Any:
np.random.seed(self.random_state)
msamples = self.sample_marginal(count)
if msamples is not None:
return msamples
lo = np.log2(self.low) / np.log2(self.base)
hi = np.log2(self.high) / np.log2(self.base)
return self.base ** np.random.uniform(lo, hi, count)
lo, hi = np.log2(self.low), np.log2(self.high)
return 2.0 ** np.random.uniform(lo, hi, count)


class IntegerDistribution(Distribution):
Expand Down Expand Up @@ -313,6 +306,12 @@ def _validate_high_thresh(cls: Any, v: int, values: Dict) -> int:
return int(values[mkey].index.max())
return v

@validator("step", always=True)
def _validate_step(cls: Any, v: int, values: Dict) -> int:
if v < 1:
raise ValueError("Step must be greater than 0")
return v

def get(self) -> List[Any]:
return [self.name, self.low, self.high, self.step]

Expand All @@ -322,9 +321,9 @@ def sample(self, count: int = 1) -> Any:
if msamples is not None:
return msamples

high = (self.high + 1 - self.low) // self.step
s = np.random.choice(high, count)
return s * self.step + self.low
steps = (self.high - self.low) // self.step
samples = np.random.choice(steps + 1, count)
return samples * self.step + self.low

def has(self, val: Any) -> bool:
return self.low <= val and val <= self.high
Expand All @@ -347,34 +346,31 @@ def min(self) -> Any:
def max(self) -> Any:
return self.high

def __eq__(self, other: Any) -> bool:
if not isinstance(other, IntegerDistribution):
return False

return (
self.name == other.name
and self.low == other.low
and self.high == other.high
)

def dtype(self) -> str:
return "int"


class LogIntDistribution(FloatDistribution):
low: float = 1.0
high: float = float(np.iinfo(np.int64).max)
base: float = 2.0
class IntLogDistribution(IntegerDistribution):
low: int = 1
high: int = np.iinfo(np.int64).max

@validator("step", always=True)
def _validate_step(cls: Any, v: int, values: Dict) -> int:
if v != 1:
raise ValueError("Step must be 1 for IntLogDistribution")
return v

def get(self) -> List[Any]:
return [self.name, self.low, self.high]

def sample(self, count: int = 1) -> Any:
np.random.seed(self.random_state)
msamples = self.sample_marginal(count)
if msamples is not None:
return msamples
lo = np.log2(self.low) / np.log2(self.base)
hi = np.log2(self.high) / np.log2(self.base)
s = self.base ** np.random.uniform(lo, hi, count)
return s.astype(int)
lo, hi = np.log2(self.low), np.log2(self.high)
samples = 2.0 ** np.random.uniform(lo, hi, count)
return samples.astype(int)


class DatetimeDistribution(Distribution):
Expand All @@ -383,49 +379,46 @@ class DatetimeDistribution(Distribution):
:parts: 1
"""

offset: int = 120
low: datetime = datetime.utcfromtimestamp(0)
high: datetime = datetime.now()

@validator("offset", always=True)
def _validate_offset(cls: Any, v: int) -> int:
if v < 0:
raise ValueError("offset must be greater than 0")
return v
step: timedelta = timedelta(microseconds=1)
offset: timedelta = timedelta(seconds=120)

@validator("low", always=True)
def _validate_low_thresh(cls: Any, v: datetime, values: Dict) -> datetime:
mkey = "marginal_distribution"
if mkey in values and values[mkey] is not None:
v = values[mkey].index.min()
return v - timedelta(seconds=values["offset"])
return v

@validator("high", always=True)
def _validate_high_thresh(cls: Any, v: datetime, values: Dict) -> datetime:
mkey = "marginal_distribution"
if mkey in values and values[mkey] is not None:
v = values[mkey].index.max()
return v + timedelta(seconds=values["offset"])
return v

def get(self) -> List[Any]:
return [self.name, self.low, self.high]
return [self.name, self.low, self.high, self.step, self.offset]

def sample(self, count: int = 1) -> Any:
np.random.seed(self.random_state)
msamples = self.sample_marginal(count)
if msamples is not None:
return msamples

delta = self.high - self.low
return self.low + delta * np.random.rand(count)
n = (self.high - self.low) // self.step + 1
samples = np.round(np.random.rand(count) * n - 0.5)
return self.low + samples * self.step

def has(self, val: datetime) -> bool:
return self.low <= val and val <= self.high

def includes(self, other: "Distribution") -> bool:
return self.min() - timedelta(
seconds=self.offset
) <= other.min() and other.max() <= self.max() + timedelta(seconds=self.offset)
return (
self.min() - self.offset <= other.min()
and other.max() <= self.max() + self.offset
)

def as_constraint(self) -> Constraints:
return Constraints(
Expand All @@ -442,16 +435,6 @@ def min(self) -> Any:
def max(self) -> Any:
return self.high

def __eq__(self, other: Any) -> bool:
if not isinstance(other, DatetimeDistribution):
return False

return (
self.name == other.name
and self.low == other.low
and self.high == other.high
)

def dtype(self) -> str:
return "datetime"

Expand Down
2 changes: 2 additions & 0 deletions src/synthcity/plugins/core/models/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
DatetimeEncoder,
FeatureEncoder,
GaussianQuantileTransformer,
LabelEncoder,
MinMaxScaler,
OneHotEncoder,
OrdinalEncoder,
Expand Down Expand Up @@ -75,6 +76,7 @@
datetime=DatetimeEncoder,
onehot=OneHotEncoder,
ordinal=OrdinalEncoder,
label=LabelEncoder,
standard=StandardScaler,
minmax=MinMaxScaler,
robust=RobustScaler,
Expand Down
8 changes: 4 additions & 4 deletions src/synthcity/plugins/generic/plugin_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from synthcity.plugins.core.distribution import (
Distribution,
IntegerDistribution,
IntLogDistribution,
LogDistribution,
LogIntDistribution,
)
from synthcity.plugins.core.models.tabular_ddpm import TabDDPM
from synthcity.plugins.core.models.tabular_encoder import TabularEncoder
Expand Down Expand Up @@ -180,11 +180,11 @@ def hyperparameter_space(**kwargs: Any) -> List[Distribution]:
"""
return [
LogDistribution(name="lr", low=1e-5, high=1e-1),
LogIntDistribution(name="batch_size", low=256, high=4096),
IntLogDistribution(name="batch_size", low=256, high=4096),
IntegerDistribution(name="num_timesteps", low=10, high=1000),
LogIntDistribution(name="n_iter", low=1000, high=10000),
IntLogDistribution(name="n_iter", low=1000, high=10000),
# IntegerDistribution(name="n_layers_hidden", low=2, high=8),
# LogIntDistribution(name="dim_hidden", low=128, high=1024),
# IntLogDistribution(name="dim_hidden", low=128, high=1024),
]

def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin":
Expand Down
6 changes: 2 additions & 4 deletions src/synthcity/plugins/time_series/plugin_fflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from fflows import FourierFlow

# synthcity absolute
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import DataLoader
from synthcity.plugins.core.distribution import (
CategoricalDistribution,
Expand All @@ -24,7 +25,6 @@
from synthcity.plugins.core.models.ts_model import TimeSeriesModel
from synthcity.plugins.core.plugin import Plugin
from synthcity.plugins.core.schema import Schema
from synthcity.plugins.generic import GenericPlugins
from synthcity.utils.constants import DEVICE


Expand Down Expand Up @@ -134,9 +134,7 @@ def __init__(
normalize=normalize,
).to(device)

self.static_model = GenericPlugins().get(
self.static_model_name, device=self.device
)
self.static_model = Plugins().get(self.static_model_name, device=self.device)

self.temporal_encoder = TimeSeriesTabularEncoder(
max_clusters=encoder_max_clusters
Expand Down
27 changes: 27 additions & 0 deletions src/synthcity/utils/optuna_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# stdlib
from typing import Any, Dict, List

# third party
import optuna

# synthcity absolute
import synthcity.plugins.core.distribution as D


def suggest(trial: optuna.Trial, dist: D.Distribution) -> Any:
if isinstance(dist, D.FloatDistribution):
return trial.suggest_float(dist.name, dist.low, dist.high)
elif isinstance(dist, D.LogDistribution):
return trial.suggest_float(dist.name, dist.low, dist.high, log=True)
elif isinstance(dist, D.IntegerDistribution):
return trial.suggest_int(dist.name, dist.low, dist.high, dist.step)
elif isinstance(dist, D.IntLogDistribution):
return trial.suggest_int(dist.name, dist.low, dist.high, log=True)
elif isinstance(dist, D.CategoricalDistribution):
return trial.suggest_categorical(dist.name, dist.choices)
else:
raise ValueError(f"Unknown dist: {dist}")


def suggest_all(trial: optuna.Trial, distributions: List[D.Distribution]) -> Dict:
return {dist.name: suggest(trial, dist) for dist in distributions}
Loading

0 comments on commit 46f83ad

Please sign in to comment.