Tune function and CLI command #1397
@@ -1,12 +1,14 @@
import contextlib
from typing import Union, Dict, Any, List


from ..record import Record
from ..source.source import BaseSource
from ..feature import Feature, Features
from ..model import Model, ModelContext
from ..util.internal import records_to_sources, list_records_to_dict
from ..accuracy.accuracy import AccuracyScorer, AccuracyContext
from ..tuner import Tuner, TunerContext


async def train(model, *args: Union[BaseSource, Record, Dict[str, Any], List]):
@@ -293,3 +295,149 @@ async def predict(
            )
            if update:
                await sctx.update(record)

async def tune(
    model,
    tuner: Union[Tuner, TunerContext],
    accuracy_scorer: Union[AccuracyScorer, AccuracyContext],
    features: Union[Feature, Features],
    train_ds: Union[BaseSource, Record, Dict[str, Any], List],
    valid_ds: Union[BaseSource, Record, Dict[str, Any], List],
) -> float:
    """
    Tune the hyperparameters of a model with a given tuner.

    Parameters
    ----------
    model : Model
        Machine Learning model to use. See :doc:`/plugins/dffml_model` for
        model options.
    tuner : Tuner
        Hyperparameter tuning method to use. See :doc:`/plugins/dffml_tuner`
        for tuner options.
    accuracy_scorer : AccuracyScorer
        Accuracy scorer used to evaluate candidate hyperparameter
        configurations.
    features : Features
        Features to use when scoring the model during tuning.
    train_ds : list
        Input data for training. Could be a ``dict``, :py:class:`Record`,
        filename, one of the data :doc:`/plugins/dffml_source`, or a filename
        with the extension being one of the data sources.
    valid_ds : list
        Validation data for testing. Could be a ``dict``, :py:class:`Record`,
        filename, one of the data :doc:`/plugins/dffml_source`, or a filename
        with the extension being one of the data sources.

    Returns
    -------
    float
        A decimal value representing the result of the accuracy scorer on the
        given test set. For instance, ClassificationAccuracy represents the
        percentage of correct classifications made by the model.

    Examples
    --------

    >>> import asyncio
    >>> from dffml import *
    >>>
    >>> model = SLRModel(
    ...     features=Features(
    ...         Feature("Years", int, 1),
    ...     ),
    ...     predict=Feature("Salary", int, 1),
    ...     location="tempdir",
    ... )
    >>>
    >>> async def main():
    ...     score = await tune(
    ...         model,
    ...         ParameterGrid(objective="min"),
    ...         MeanSquaredErrorAccuracy(),
    ...         Features(
    ...             Feature("Years", float, 1),
    ...         ),
    ...         [
    ...             {"Years": 0, "Salary": 10},
    ...             {"Years": 1, "Salary": 20},
    ...             {"Years": 2, "Salary": 30},
    ...             {"Years": 3, "Salary": 40},
    ...         ],
    ...         [
    ...             {"Years": 6, "Salary": 70},
    ...             {"Years": 7, "Salary": 80},
    ...         ],
    ...     )
    ...     print(f"Tuner score: {score}")
    ...
    >>> asyncio.run(main())
    Tuner score: 0.0
    """
    if not isinstance(features, (Feature, Features)):
        raise TypeError(
            f"features was {type(features)}: {features!r}. Should have been Feature or Features"
        )
    # Normalize a single Feature into a Features collection
    if isinstance(features, Feature):
        features = Features(features)
    # Collect the names of the features the model predicts
    if hasattr(model.config, "predict"):
        if isinstance(model.config.predict, Features):
            predict_feature = [
                feature.name for feature in model.config.predict
            ]
        else:
            predict_feature = [model.config.predict.name]

    # Convert list-style records into dicts keyed by feature name
    if hasattr(model.config, "features") and any(
        isinstance(td, list) for td in train_ds
    ):
        train_ds = list_records_to_dict(
            [feature.name for feature in model.config.features]
            + predict_feature,
            *train_ds,
            model=model,
        )
    if hasattr(model.config, "features") and any(
        isinstance(td, list) for td in valid_ds
    ):
        valid_ds = list_records_to_dict(
            [feature.name for feature in model.config.features]
            + predict_feature,
            *valid_ds,
            model=model,
        )
Review comment: avoid repetition of code.
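One way to act on this comment, sketched purely as an illustration (the helper name _maybe_records_to_dict is hypothetical and not part of this PR), would be to factor the duplicated conversion into a single function and call it once per dataset:

def _maybe_records_to_dict(dataset, model, predict_feature):
    # Hypothetical helper mirroring the two duplicated blocks above: convert
    # list-style records to feature dicts only when the model exposes its
    # feature config and the dataset contains raw lists.
    if hasattr(model.config, "features") and any(
        isinstance(record, list) for record in dataset
    ):
        return list_records_to_dict(
            [feature.name for feature in model.config.features]
            + predict_feature,
            *dataset,
            model=model,
        )
    return dataset

The two blocks above would then reduce to train_ds = _maybe_records_to_dict(train_ds, model, predict_feature) and the corresponding call for valid_ds.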

    async with contextlib.AsyncExitStack() as astack:
        # Open sources
        train = await astack.enter_async_context(records_to_sources(*train_ds))
        test = await astack.enter_async_context(records_to_sources(*valid_ds))
        # Allow for keep models open
        if isinstance(model, Model):
            model = await astack.enter_async_context(model)
            mctx = await astack.enter_async_context(model())
        elif isinstance(model, ModelContext):
            mctx = model

        # Allow for keep models open
Review comment: this comment should read "# Allow scorers to be kept open".
        if isinstance(accuracy_scorer, AccuracyScorer):
            accuracy_scorer = await astack.enter_async_context(accuracy_scorer)
            actx = await astack.enter_async_context(accuracy_scorer())
        elif isinstance(accuracy_scorer, AccuracyContext):
            actx = accuracy_scorer
        else:
            # TODO Replace this with static type checking and maybe dynamic
            # through something like pydantic. See issue #36
            raise TypeError(f"{accuracy_scorer} is not an AccuracyScorer")

        if isinstance(tuner, Tuner):
            tuner = await astack.enter_async_context(tuner)
            tctx = await astack.enter_async_context(tuner())
        elif isinstance(tuner, TunerContext):
            tctx = tuner
        else:
            raise TypeError(f"{tuner} is not a Tuner")

        return float(
            await tctx.optimize(mctx, features, actx, train, test)
        )

@@ -51,6 +51,7 @@ def inpath(binary):
("operations", "nlp"), | ||
("service", "http"), | ||
("source", "mysql"), | ||
("tuner", "bayes_opt_gp"), | ||
Review comment: let's have a simpler, more understandable entrypoint.
]
This file was deleted.
@@ -0,0 +1 @@
../common/README.rst

This file was deleted.
@@ -0,0 +1 @@
../common/README.rst

This file was deleted.
@@ -0,0 +1 @@
../common/README.rst

This file was deleted.
@@ -0,0 +1 @@
../common/README.rst

This file was deleted.
@@ -0,0 +1 @@
../common/README.rst

@@ -8,4 +8,3 @@
    TunerContext,
    Tuner,
)
from .parameter_grid import ParameterGrid
Review comment: so we want the train and test sets to be passed in as keyword arguments like this:
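The reviewer's actual example is not captured in this page. As a purely hypothetical illustration of what such a call could look like with the signature shown above (the keyword names train_ds and valid_ds come from the tune() definition, and the data mirrors the docstring example):

# Hypothetical keyword-argument call shape; not the reviewer's actual example.
score = await tune(
    model,
    ParameterGrid(objective="min"),
    MeanSquaredErrorAccuracy(),
    Features(Feature("Years", float, 1)),
    train_ds=[{"Years": 0, "Salary": 10}, {"Years": 1, "Salary": 20}],
    valid_ds=[{"Years": 6, "Salary": 70}, {"Years": 7, "Salary": 80}],
)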