
"tune function and CLI command"
seraphimstreets committed Jun 22, 2022
1 parent a9bdd58 commit 68c923e
Showing 7 changed files with 241 additions and 16 deletions.
1 change: 1 addition & 0 deletions dffml/__init__.py
@@ -57,6 +57,7 @@ class DuplicateName(Exception):
"train": "high_level.ml",
"predict": "high_level.ml",
"score": "high_level.ml",
"tune": "high_level.ml",
"load": "high_level.source",
"save": "high_level.source",
"run": "high_level.dataflow",
3 changes: 2 additions & 1 deletion dffml/cli/cli.py
@@ -39,7 +39,7 @@

from .dataflow import Dataflow
from .config import Config
from .ml import Train, Accuracy, Predict
from .ml import Train, Accuracy, Predict, Tune
from .list import List

version = VERSION
@@ -366,6 +366,7 @@ class CLI(CMD):
train = Train
accuracy = Accuracy
predict = Predict
tune = Tune
service = services()
dataflow = Dataflow
config = Config
38 changes: 37 additions & 1 deletion dffml/cli/ml.py
@@ -1,9 +1,10 @@
import inspect

from ..model.model import Model
from ..tuner.tuner import Tuner
from ..source.source import Sources, SubsetSources
from ..util.cli.cmd import CMD, CMDOutputOverride
from ..high_level.ml import train, predict, score
from ..high_level.ml import train, predict, score, tune
from ..util.config.fields import FIELD_SOURCES
from ..util.cli.cmds import (
SourcesCMD,
@@ -15,6 +16,7 @@
)
from ..base import config, field
from ..accuracy import AccuracyScorer

from ..feature import Features


@@ -118,3 +120,37 @@ class Predict(CMD):

record = PredictRecord
_all = PredictAll


@config
class TuneCMDConfig:
model: Model = field("Model used for ML", required=True)
tuner: Tuner = field("Tuner to optimize hyperparameters", required=True)
scorer: AccuracyScorer = field(
"Method to use to score accuracy", required=True
)
features: Features = field("Predict Feature(s)", default=Features())
sources: Sources = FIELD_SOURCES


class Tune(MLCMD):
"""Optimize hyperparameters of model with given sources"""

CONFIG = TuneCMDConfig

async def run(self):
        # Instantiate the accuracy scorer and tuner classes if for some reason
        # they are classes at this point rather than instances.
if inspect.isclass(self.scorer):
self.scorer = self.scorer.withconfig(self.extra_config)
if inspect.isclass(self.tuner):
self.tuner = self.tuner.withconfig(self.extra_config)

return await tune(
self.model,
self.tuner,
self.scorer,
self.features,
[self.sources[0]],
[self.sources[1]],
)
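For context, here is a hypothetical invocation of the new command, sketched by analogy with the existing dffml train/accuracy commands. The exact flag names depend on how the model, tuner, and scorer plugins register their configs and entrypoints; note that Tune.run() passes self.sources[0] as the training set and self.sources[1] as the validation set, so the order of the configured sources matters.

    # Hypothetical invocation; flag names assumed from existing train/accuracy
    # conventions and the registered plugin entrypoints.
    dffml tune \
        -model xgbclassifier \
        -model-location tempdir \
        -model-features SepalLength:float:1 SepalWidth:float:1 PetalLength:float:1 \
        -model-predict classification:int:1 \
        -tuner parameter_grid \
        -scorer mse \
        -features classification:int:1 \
        -sources train=csv test=csv \
        -source-train-filename iris_training.csv \
        -source-test-filename iris_test.csv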
149 changes: 149 additions & 0 deletions dffml/high_level/ml.py
@@ -1,12 +1,14 @@
import contextlib
from typing import Union, Dict, Any, List


from ..record import Record
from ..source.source import BaseSource
from ..feature import Feature, Features
from ..model import Model, ModelContext
from ..util.internal import records_to_sources, list_records_to_dict
from ..accuracy.accuracy import AccuracyScorer, AccuracyContext
from ..tuner import Tuner, TunerContext


async def train(model, *args: Union[BaseSource, Record, Dict[str, Any], List]):
@@ -293,3 +295,150 @@
)
if update:
await sctx.update(record)

async def tune(
model,
tuner: Union[Tuner, TunerContext],
accuracy_scorer: Union[AccuracyScorer, AccuracyContext],
features: Union[Feature, Features],
train_ds: Union[BaseSource, Record, Dict[str, Any], List],
valid_ds: Union[BaseSource, Record, Dict[str, Any], List],
) -> float:

"""
    Tune the hyperparameters of a model with a given tuner.

    Parameters
    ----------
    model : Model
        Machine Learning model to use. See :doc:`/plugins/dffml_model` for
        model options.
    tuner : Tuner
        Hyperparameter tuning method to use. See :doc:`/plugins/dffml_tuner` for
        tuner options.
    accuracy_scorer : AccuracyScorer
        Scorer used to evaluate each trained configuration of the model.
    features : Features
        Feature(s) whose values the model predicts, used when scoring.
train_ds : list
Input data for training. Could be a ``dict``, :py:class:`Record`,
filename, one of the data :doc:`/plugins/dffml_source`, or a filename
with the extension being one of the data sources.
valid_ds : list
Validation data for testing. Could be a ``dict``, :py:class:`Record`,
filename, one of the data :doc:`/plugins/dffml_source`, or a filename
with the extension being one of the data sources.

    Returns
-------
float
A decimal value representing the result of the accuracy scorer on the given
test set. For instance, ClassificationAccuracy represents the percentage of correct
classifications made by the model.

    Examples
--------
>>> import asyncio
>>> from dffml import *
>>> from dffml_model_xgboost.xgbclassifier import XGBClassifierModel
>>>
>>> model = XGBClassifierModel(
... features=Features(
... Feature("SepalLength", float, 1),
... Feature("SepalWidth", float, 1),
... Feature("PetalLength", float, 1),
... ),
... predict=Feature("classification", int, 1),
... location="tempdir",
... )
>>>
>>> async def main():
    ...     acc = await tune(
... model,
... ParameterGrid(
... parameters={
... "learning_rate": [0.01, 0.05, 0.1],
... "n_estimators": [20, 100, 200],
    ...                 "max_depth": [3, 5, 8]
... }
... ),
... MeanSquaredErrorAccuracy(),
... Features(
... Feature("SepalLength", float, 1),
... Feature("SepalWidth", float, 1),
... Feature("PetalLength", float, 1),
... ),
... [CSVSource(filename="iris_training.csv")],
... [CSVSource(filename="iris_test.csv")],
    ...     )
    ...     print("Accuracy:", acc)
>>>
>>> asyncio.run(main())
Accuracy: 0.0
"""

if not isinstance(features, (Feature, Features)):
raise TypeError(
f"features was {type(features)}: {features!r}. Should have been Feature or Features"
)
if isinstance(features, Feature):
features = Features(features)
if hasattr(model.config, "predict"):
if isinstance(model.config.predict, Features):
predict_feature = [
feature.name for feature in model.config.predict
]
else:
predict_feature = [model.config.predict.name]

if hasattr(model.config, "features") and any(
isinstance(td, list) for td in train_ds
):
train_ds = list_records_to_dict(
[feature.name for feature in model.config.features]
+ predict_feature,
*train_ds,
model=model,
)
if hasattr(model.config, "features") and any(
isinstance(td, list) for td in valid_ds
):
valid_ds = list_records_to_dict(
[feature.name for feature in model.config.features]
+ predict_feature,
*valid_ds,
model=model,
)

async with contextlib.AsyncExitStack() as astack:
# Open sources
train = await astack.enter_async_context(records_to_sources(*train_ds))
test = await astack.enter_async_context(records_to_sources(*valid_ds))
        # Allow models to be kept open between calls
if isinstance(model, Model):
model = await astack.enter_async_context(model)
mctx = await astack.enter_async_context(model())
elif isinstance(model, ModelContext):
mctx = model

        # Allow scorers to be kept open between calls
if isinstance(accuracy_scorer, AccuracyScorer):
accuracy_scorer = await astack.enter_async_context(accuracy_scorer)
actx = await astack.enter_async_context(accuracy_scorer())
elif isinstance(accuracy_scorer, AccuracyContext):
actx = accuracy_scorer
else:
# TODO Replace this with static type checking and maybe dynamic
# through something like pydantic. See issue #36
raise TypeError(f"{accuracy_scorer} is not an AccuracyScorer")

if isinstance(tuner, Tuner):
tuner = await astack.enter_async_context(tuner)
tctx = await astack.enter_async_context(tuner())
elif isinstance(tuner, TunerContext):
tctx = tuner
else:
raise TypeError(f"{tuner} is not an Tuner")

return float(
await tctx.optimize(mctx, model.config.predict, actx, train, test)
)
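Because tune() accepts either plugin instances or already-entered contexts, a caller that manages its own exit stack can pass a TunerContext directly. A minimal sketch, assuming a model built as in the docstring example above; the file names and parameter values are illustrative, not part of this commit:

    # Illustrative sketch; model, file names, and parameter values are assumptions.
    import contextlib

    from dffml import CSVSource, Feature, MeanSquaredErrorAccuracy, tune
    from dffml.tuner.parameter_grid import ParameterGrid

    async def tune_with_existing_context(model):
        async with contextlib.AsyncExitStack() as astack:
            # Enter the tuner here instead of letting tune() do it.
            tuner = await astack.enter_async_context(
                ParameterGrid(parameters={"max_depth": [3, 5]}, objective="min")
            )
            tctx = await astack.enter_async_context(tuner())
            return await tune(
                model,
                tctx,  # a TunerContext is used as-is
                MeanSquaredErrorAccuracy(),  # an instance is entered by tune()
                Feature("classification", int, 1),  # a single Feature is wrapped into Features
                [CSVSource(filename="iris_training.csv")],
                [CSVSource(filename="iris_test.csv")],
            )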

16 changes: 16 additions & 0 deletions dffml/noasync.py
@@ -6,6 +6,7 @@
train as high_level_train,
score as high_level_score,
predict as high_level_predict,
tune as high_level_tune,
)


@@ -24,6 +25,21 @@ def train(*args, **kwargs):
)
)

def tune(*args, **kwargs):
return asyncio.run(high_level_tune(*args, **kwargs))


tune.__doc__ = (
high_level_tune.__doc__.replace("await ", "")
.replace("async ", "")
.replace("asyncio.run(main())", "main()")
.replace(" >>> import asyncio\n", "")
.replace(
" >>> from dffml import *\n",
" >>> from dffml import *\n >>> from dffml.noasync import tune\n",
)
)


def score(*args, **kwargs):
return asyncio.run(high_level_score(*args, **kwargs))
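A synchronous usage sketch mirroring the async docstring example; the model, data files, and grid values are illustrative, not part of this commit:

    # Illustrative sketch; file names and parameter values are assumptions.
    from dffml import *
    from dffml.noasync import tune
    from dffml_model_xgboost.xgbclassifier import XGBClassifierModel

    model = XGBClassifierModel(
        features=Features(
            Feature("SepalLength", float, 1),
            Feature("SepalWidth", float, 1),
            Feature("PetalLength", float, 1),
        ),
        predict=Feature("classification", int, 1),
        location="tempdir",
    )

    print(
        "Accuracy:",
        tune(
            model,
            ParameterGrid(parameters={"max_depth": [3, 5, 8]}),
            MeanSquaredErrorAccuracy(),
            Features(
                Feature("SepalLength", float, 1),
                Feature("SepalWidth", float, 1),
                Feature("PetalLength", float, 1),
            ),
            [CSVSource(filename="iris_training.csv")],
            [CSVSource(filename="iris_test.csv")],
        ),
    )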
1 change: 0 additions & 1 deletion dffml/tuner/__init__.py
@@ -8,4 +8,3 @@
TunerContext,
Tuner,
)
from .parameter_grid import ParameterGrid
49 changes: 36 additions & 13 deletions dffml/tuner/parameter_grid.py
@@ -17,7 +17,8 @@

@config
class ParameterGridConfig:
parameters: dict = field("Parameters to be optimized")
    parameters: dict = field("Parameters to be optimized", default_factory=lambda: dict())
    objective: str = field("Whether to minimize ('min') or maximize ('max') the scorer", default="max")


class ParameterGridContext(TunerContext):
@@ -38,6 +39,8 @@ async def optimize(
        Uses a grid of hyperparameters in the form of a dictionary present in config,
        trains the model on each permutation of the grid, and compares scores.
        Sets the model to the best parameters found and returns the best score.
If no hyperparameters are provided, the model is simply trained using
default parameters.
Parameters
----------
@@ -59,33 +62,53 @@
Returns
-------
float
The highest score value
The best score value
"""
highest_acc = -1
# Score should be optimized based on objective
if self.parent.config.objective == "min":
highest_acc = float("inf")
elif self.parent.config.objective == "max":
highest_acc = -1

best_config = dict()
logging.info(
f"Optimizing model with parameter grid: {self.parent.config.parameters}"
)

names = list(self.parent.config.parameters.keys())
logging.info(names)
with model.config.no_enforce_immutable():

with model.parent.config.no_enforce_immutable():
for combination in itertools.product(
*list(self.parent.config.parameters.values())
):
logging.info(combination)

for i in range(len(combination)):
param = names[i]
setattr(model.config, names[i], combination[i])
await train(model, *train_data)
acc = await score(model, accuracy_scorer, feature, *test_data)
setattr(model.parent.config, names[i], combination[i])

await train(model.parent, *train_data)

acc = await score(
model.parent, accuracy_scorer, feature, *test_data
)

logging.info(f"Accuracy of the tuned model: {acc}")
if acc > highest_acc:
highest_acc = acc
for param in names:
best_config[param] = getattr(model.config, param)
                    # Track the best score and remember the parameters that
                    # produced it, for either objective.
                    if self.parent.config.objective == "min":
                        if acc < highest_acc:
                            highest_acc = acc
                            for param in names:
                                best_config[param] = getattr(
                                    model.parent.config, param
                                )

                    elif self.parent.config.objective == "max":
                        if acc > highest_acc:
                            highest_acc = acc
                            for param in names:
                                best_config[param] = getattr(
                                    model.parent.config, param
                                )
for param in names:
setattr(model.config, param, best_config[param])
await train(model, *train_data)
setattr(model.parent.config, param, best_config[param])
await train(model.parent, *train_data)
logging.info(f"\nOptimal Hyper-parameters: {best_config}")
logging.info(f"Accuracy of Optimized model: {highest_acc}")
return highest_acc
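An illustrative sketch of the new objective option: with an error-style scorer such as MeanSquaredErrorAccuracy, lower scores are better, so the grid search should minimize rather than maximize. The parameter values below are assumptions, not part of this commit:

    # Illustrative sketch; parameter values are assumptions.
    from dffml.tuner.parameter_grid import ParameterGrid

    tuner = ParameterGrid(
        parameters={
            "learning_rate": [0.01, 0.1],
            "max_depth": [3, 5],
        },
        # Keep the configuration with the lowest score; the default is "max".
        objective="min",
    )
    # Leaving parameters at its default of {} yields a single empty combination,
    # so the model is simply trained with its existing configuration and scored.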
