Skip to content
This repository has been archived by the owner on Aug 25, 2024. It is now read-only.

Commit

Permalink
"tune function and CLI command"
Browse files Browse the repository at this point in the history
  • Loading branch information
seraphimstreets committed Jun 24, 2022
1 parent a9bdd58 commit 4a7de3a
Show file tree
Hide file tree
Showing 12 changed files with 243 additions and 16 deletions.
1 change: 1 addition & 0 deletions dffml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class DuplicateName(Exception):
"train": "high_level.ml",
"predict": "high_level.ml",
"score": "high_level.ml",
"tune": "high_level.ml",
"load": "high_level.source",
"save": "high_level.source",
"run": "high_level.dataflow",
Expand Down
3 changes: 2 additions & 1 deletion dffml/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

from .dataflow import Dataflow
from .config import Config
from .ml import Train, Accuracy, Predict
from .ml import Train, Accuracy, Predict, Tune
from .list import List

version = VERSION
Expand Down Expand Up @@ -366,6 +366,7 @@ class CLI(CMD):
train = Train
accuracy = Accuracy
predict = Predict
tune = Tune
service = services()
dataflow = Dataflow
config = Config
38 changes: 37 additions & 1 deletion dffml/cli/ml.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import inspect

from ..model.model import Model
from ..tuner.tuner import Tuner
from ..source.source import Sources, SubsetSources
from ..util.cli.cmd import CMD, CMDOutputOverride
from ..high_level.ml import train, predict, score
from ..high_level.ml import train, predict, score, tune
from ..util.config.fields import FIELD_SOURCES
from ..util.cli.cmds import (
SourcesCMD,
Expand All @@ -15,6 +16,7 @@
)
from ..base import config, field
from ..accuracy import AccuracyScorer

from ..feature import Features


Expand Down Expand Up @@ -118,3 +120,37 @@ class Predict(CMD):

record = PredictRecord
_all = PredictAll


@config
class TuneCMDConfig:
model: Model = field("Model used for ML", required=True)
tuner: Tuner = field("Tuner to optimize hyperparameters", required=True)
scorer: AccuracyScorer = field(
"Method to use to score accuracy", required=True
)
features: Features = field("Predict Feature(s)", default=Features())
sources: Sources = FIELD_SOURCES


class Tune(MLCMD):
"""Optimize hyperparameters of model with given sources"""

CONFIG = TuneCMDConfig

async def run(self):
# Instantiate the accuracy scorer class if for some reason it is a class
# at this point rather than an instance.
if inspect.isclass(self.scorer):
self.scorer = self.scorer.withconfig(self.extra_config)
if inspect.isclass(self.tuner):
self.tuner = self.tuner.withconfig(self.extra_config)

return await tune(
self.model,
self.tuner,
self.scorer,
self.features,
[self.sources[0]],
[self.sources[1]],
)
148 changes: 148 additions & 0 deletions dffml/high_level/ml.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import contextlib
from typing import Union, Dict, Any, List


from ..record import Record
from ..source.source import BaseSource
from ..feature import Feature, Features
from ..model import Model, ModelContext
from ..util.internal import records_to_sources, list_records_to_dict
from ..accuracy.accuracy import AccuracyScorer, AccuracyContext
from ..tuner import Tuner, TunerContext


async def train(model, *args: Union[BaseSource, Record, Dict[str, Any], List]):
Expand Down Expand Up @@ -293,3 +295,149 @@ async def predict(
)
if update:
await sctx.update(record)

async def tune(
model,
tuner: Union[Tuner, TunerContext],
accuracy_scorer: Union[AccuracyScorer, AccuracyContext],
features: Union[Feature, Features],
train_ds: Union[BaseSource, Record, Dict[str, Any], List],
valid_ds: Union[BaseSource, Record, Dict[str, Any], List],
) -> float:

"""
Tune the hyperparameters of a model with a given tuner.
Parameters
----------
model : Model
Machine Learning model to use. See :doc:`/plugins/dffml_model` for
models options.
tuner: Tuner
Hyperparameter tuning method to use. See :doc:`/plugins/dffml_tuner` for
tuner options.
train_ds : list
Input data for training. Could be a ``dict``, :py:class:`Record`,
filename, one of the data :doc:`/plugins/dffml_source`, or a filename
with the extension being one of the data sources.
valid_ds : list
Validation data for testing. Could be a ``dict``, :py:class:`Record`,
filename, one of the data :doc:`/plugins/dffml_source`, or a filename
with the extension being one of the data sources.
Returns
-------
float
A decimal value representing the result of the accuracy scorer on the given
test set. For instance, ClassificationAccuracy represents the percentage of correct
classifications made by the model.
Examples
--------
>>> import asyncio
>>> from dffml import *
>>>
>>> model = SLRModel(
... features=Features(
... Feature("Years", int, 1),
... ),
... predict=Feature("Salary", int, 1),
... location="tempdir",
... )
>>>
>>> async def main():
... score = await tune(
... model,
... ParameterGrid(objective="min"),
... MeanSquaredErrorAccuracy(),
... Features(
... Feature("Years", float, 1),
... ),
... [
... {"Years": 0, "Salary": 10},
... {"Years": 1, "Salary": 20},
... {"Years": 2, "Salary": 30},
... {"Years": 3, "Salary": 40}
... ],
... [
... {"Years": 6, "Salary": 70},
... {"Years": 7, "Salary": 80}
... ]
...
... )
... print(f"Tuner score: {score}")
...
>>> asyncio.run(main())
Tuner score: 0.0
"""

if not isinstance(features, (Feature, Features)):
raise TypeError(
f"features was {type(features)}: {features!r}. Should have been Feature or Features"
)
if isinstance(features, Feature):
features = Features(features)
if hasattr(model.config, "predict"):
if isinstance(model.config.predict, Features):
predict_feature = [
feature.name for feature in model.config.predict
]
else:
predict_feature = [model.config.predict.name]

if hasattr(model.config, "features") and any(
isinstance(td, list) for td in train_ds
):
train_ds = list_records_to_dict(
[feature.name for feature in model.config.features]
+ predict_feature,
*train_ds,
model=model,
)
if hasattr(model.config, "features") and any(
isinstance(td, list) for td in valid_ds
):
valid_ds = list_records_to_dict(
[feature.name for feature in model.config.features]
+ predict_feature,
*valid_ds,
model=model,
)

async with contextlib.AsyncExitStack() as astack:
# Open sources
train = await astack.enter_async_context(records_to_sources(*train_ds))
test = await astack.enter_async_context(records_to_sources(*valid_ds))
# Allow for keep models open
if isinstance(model, Model):
model = await astack.enter_async_context(model)
mctx = await astack.enter_async_context(model())
elif isinstance(model, ModelContext):
mctx = model

# Allow for keep models open
if isinstance(accuracy_scorer, AccuracyScorer):
accuracy_scorer = await astack.enter_async_context(accuracy_scorer)
actx = await astack.enter_async_context(accuracy_scorer())
elif isinstance(accuracy_scorer, AccuracyContext):
actx = accuracy_scorer
else:
# TODO Replace this with static type checking and maybe dynamic
# through something like pydantic. See issue #36
raise TypeError(f"{accuracy_scorer} is not an AccuracyScorer")

if isinstance(tuner, Tuner):
tuner = await astack.enter_async_context(tuner)
tctx = await astack.enter_async_context(tuner())
elif isinstance(tuner, TunerContext):
tctx = tuner
else:
raise TypeError(f"{tuner} is not an Tuner")

return float(
await tctx.optimize(mctx, model.config.predict, actx, train, test)
)

16 changes: 16 additions & 0 deletions dffml/noasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
train as high_level_train,
score as high_level_score,
predict as high_level_predict,
tune as high_level_tune,
)


Expand All @@ -24,6 +25,21 @@ def train(*args, **kwargs):
)
)

def tune(*args, **kwargs):
return asyncio.run(high_level_tune(*args, **kwargs))


tune.__doc__ = (
high_level_tune.__doc__.replace("await ", "")
.replace("async ", "")
.replace("asyncio.run(main())", "main()")
.replace(" >>> import asyncio\n", "")
.replace(
" >>> from dffml import *\n",
" >>> from dffml import *\n >>> from dffml.noasync import tune\n",
)
)


def score(*args, **kwargs):
return asyncio.run(high_level_score(*args, **kwargs))
Expand Down
1 change: 0 additions & 1 deletion dffml/skel/config/README.rst

This file was deleted.

1 change: 1 addition & 0 deletions dffml/skel/config/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
../common/README.rst
1 change: 0 additions & 1 deletion dffml/skel/model/README.rst

This file was deleted.

1 change: 1 addition & 0 deletions dffml/skel/model/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
../common/README.rst
1 change: 0 additions & 1 deletion dffml/skel/operations/README.rst

This file was deleted.

1 change: 1 addition & 0 deletions dffml/skel/operations/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
../common/README.rst
1 change: 0 additions & 1 deletion dffml/skel/service/README.rst

This file was deleted.

1 change: 1 addition & 0 deletions dffml/skel/service/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
../common/README.rst
1 change: 0 additions & 1 deletion dffml/skel/source/README.rst

This file was deleted.

1 change: 1 addition & 0 deletions dffml/skel/source/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
../common/README.rst
1 change: 0 additions & 1 deletion dffml/tuner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,3 @@
TunerContext,
Tuner,
)
from .parameter_grid import ParameterGrid
Loading

0 comments on commit 4a7de3a

Please sign in to comment.