Optuna hyperparameter optimization tutorial (#178)
* first commit for the addition of the TabDDPM plugin
* Add DDPM test script and update DDPM plugin
* add TabDDPM class and refactor
* handle discrete cols and label generation
* add hparam space and update tests of DDPM
* debug and test DDPM
* update TensorDataLoader and training loop
* clear bugs
* debug for regression tasks
* debug for regression tasks; ALL TESTS PASSED
* remove the official repo of TabDDPM
* passed all pre-commit checks
* convert assert to conditional AssertionErrors
* added an auto annotation tool
* update auto-anno and generate annotations
* remove auto-anno and flake8 noqa
* add python<3.9 compatible annotations
* remove star import
* replace builtin type annos to typing annos
* resolve py38 compatibility issue
* tests/plugins/generic/test_ddpm.py
* change TabDDPM method signatures
* remove Iterator subscription
* update AssertionErrors, add EarlyStop callback, removed additional MLP, update logging
* remove TensorDataLoader, update test_ddpm
* update EarlyStopping
* add TabDDPM tutorial, update TabDDPM plugin and encoders
* add TabDDPM tutorial
* major update of FeatureEncoder and TabularEncoder
* add LogDistribution and LogIntDistribution
* update DDPM to use TabularEncoder
* update test_tabular_encoder and debug
* debug and DDPM tutorial OK
* debug LogDistribution and LogIntDistribution
* change discrete encoding of BinEncoder to passthrough; passed all tests in test_tabular_encoder
* add tabnet to plugins/core/models
* add factory.py, let DDPM use TabNet, refactor
* update docstrings and refactor
* fix type annotation compatibility
* make SkipConnection serializable
* fix TabularEncoder.activation_layout
* remove unnecessary code
* fix minor bug and add more nn models in factory
* update pandas and torch version requirement
* update pandas and torch version requirement
* update ddpm tutorial
* restore setup.cfg
* restore setup.cfg
* replace LabelEncoder with OrdinalEncoder
* update setup.cfg
* update setup.cfg
* debug datetimeDistribution
* clean
* update setup.cfg and goggle test
* move DDPM tutorial to tutorials/plugins
* update tabnet.py reference
* update tab_ddpm
* update distribution, add optuna utils and tutorial
* update
* Fix plugin type of static_model of fflows
* update intlogdistribution and tutorial
* try fixing goggle
* add more activations
* minor fix
* update
* update
* update
* update
* Update tabular_encoder.py
* Update test_goggle.py
* Update tabular_encoder.py
* update
* update tutorial 8
* update
* default cat nonlin of goggle <- gumbel_softmax
* get_nonlin('softmax') <- GumbelSoftmax()
* remove debug logging
* update
* update
* fix merge
* fix merge
* update pip upgrade commands in workflows
* update pip upgrade commands in workflows
* keep pip version to 23.0.1 in workflows
* keep pip version to 23.0.1 in workflows
* update
* update
* update
* update
* update
* update
* fix distribution
* update
* move upgrading of wheel to prereq.txt
* update

---------

Co-authored-by: Bogdan Cebere <[email protected]>
Co-authored-by: Rob <[email protected]>
1 parent a4190e6 · commit 46f83ad · 10 changed files with 440 additions and 83 deletions.
prereq.txt:

```diff
@@ -1,3 +1,4 @@
 numpy
-torch<2.0
+torch>=1.10.0,<2.0
 tsai
+wheel>=0.40
```
New file (27 lines) — the Optuna sampling utilities added by this commit:

```python
# stdlib
from typing import Any, Dict, List

# third party
import optuna

# synthcity absolute
import synthcity.plugins.core.distribution as D


def suggest(trial: optuna.Trial, dist: D.Distribution) -> Any:
    if isinstance(dist, D.FloatDistribution):
        return trial.suggest_float(dist.name, dist.low, dist.high)
    elif isinstance(dist, D.LogDistribution):
        return trial.suggest_float(dist.name, dist.low, dist.high, log=True)
    elif isinstance(dist, D.IntegerDistribution):
        return trial.suggest_int(dist.name, dist.low, dist.high, dist.step)
    elif isinstance(dist, D.IntLogDistribution):
        return trial.suggest_int(dist.name, dist.low, dist.high, log=True)
    elif isinstance(dist, D.CategoricalDistribution):
        return trial.suggest_categorical(dist.name, dist.choices)
    else:
        raise ValueError(f"Unknown dist: {dist}")


def suggest_all(trial: optuna.Trial, distributions: List[D.Distribution]) -> Dict:
    return {dist.name: suggest(trial, dist) for dist in distributions}
```
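For context, here is a minimal sketch of how these helpers plug into an Optuna study, assuming `suggest` and `suggest_all` above are in scope. The distribution names and bounds below are illustrative assumptions, not any plugin's real search space, and the toy objective merely stands in for fitting a synthcity plugin and scoring its output:

```python
# third party
import optuna

# synthcity absolute
import synthcity.plugins.core.distribution as D

# Hypothetical search space built from synthcity Distribution objects;
# the names and bounds are made up for illustration.
space = [
    D.LogDistribution(name="lr", low=1e-5, high=1e-2),
    D.IntLogDistribution(name="batch_size", low=64, high=1024),
    D.IntegerDistribution(name="n_layers", low=1, high=8, step=1),
    D.CategoricalDistribution(name="nonlin", choices=["relu", "leaky_relu"]),
]


def objective(trial: optuna.Trial) -> float:
    # suggest_all maps every Distribution to the matching trial.suggest_* call.
    params = suggest_all(trial, space)
    # Toy stand-in for "train a plugin with **params and return a validation score".
    return params["lr"] * params["batch_size"] / params["n_layers"]


study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)
print(study.best_params)
```

In the tutorial this commit adds, the distribution list would presumably come from the plugin being tuned (its declared hyperparameter space) rather than being hand-built as above.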