This repository has been archived by the owner on Aug 25, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 138
Tune function and CLI command #1397
Open
seraphimstreets
wants to merge
9
commits into
intel:main
Choose a base branch
from
seraphimstreets:tunecli
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 7 commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
68c923e
"tune function and CLI command"
seraphimstreets 4a7de3a
"tune function and CLI command"
seraphimstreets 5623a7d
Merge branch 'tunecli' of https://github.com/seraphimstreets/dffml in…
seraphimstreets cef4d3e
"unit tests for xgboost, pytorch, spacy"
seraphimstreets 41e4284
"unit test cleaning"
seraphimstreets 742be25
"random_search and bayes_opt_gp"
seraphimstreets d4ca3b2
Minor fixes and documentation
seraphimstreets 54d54d5
Added requested changes
seraphimstreets 5a05c86
"minor doctest edits"
seraphimstreets File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,14 @@ | ||
import contextlib | ||
from typing import Union, Dict, Any, List | ||
|
||
|
||
from ..record import Record | ||
from ..source.source import BaseSource | ||
from ..feature import Feature, Features | ||
from ..model import Model, ModelContext | ||
from ..util.internal import records_to_sources, list_records_to_dict | ||
from ..accuracy.accuracy import AccuracyScorer, AccuracyContext | ||
from ..tuner import Tuner, TunerContext | ||
|
||
|
||
async def train(model, *args: Union[BaseSource, Record, Dict[str, Any], List]): | ||
|
@@ -293,3 +295,146 @@ async def predict( | |
) | ||
if update: | ||
await sctx.update(record) | ||
|
||
async def tune( | ||
model, | ||
tuner: Union[Tuner, TunerContext], | ||
accuracy_scorer: Union[AccuracyScorer, AccuracyContext], | ||
features: Union[Feature, Features], | ||
train_ds: Union[BaseSource, Record, Dict[str, Any], List], | ||
valid_ds: Union[BaseSource, Record, Dict[str, Any], List], | ||
) -> float: | ||
|
||
""" | ||
Tune the hyperparameters of a model with a given tuner. | ||
|
||
|
||
Parameters | ||
---------- | ||
model : Model | ||
Machine Learning model to use. See :doc:`/plugins/dffml_model` for | ||
models options. | ||
tuner: Tuner | ||
Hyperparameter tuning method to use. See :doc:`/plugins/dffml_tuner` for | ||
tuner options. | ||
train_ds : list | ||
Input data for training. Could be a ``dict``, :py:class:`Record`, | ||
filename, one of the data :doc:`/plugins/dffml_source`, or a filename | ||
with the extension being one of the data sources. | ||
valid_ds : list | ||
Validation data for testing. Could be a ``dict``, :py:class:`Record`, | ||
filename, one of the data :doc:`/plugins/dffml_source`, or a filename | ||
with the extension being one of the data sources. | ||
|
||
|
||
Returns | ||
------- | ||
float | ||
A decimal value representing the result of the accuracy scorer on the given | ||
test set. For instance, ClassificationAccuracy represents the percentage of correct | ||
classifications made by the model. | ||
|
||
Examples | ||
-------- | ||
|
||
>>> import asyncio | ||
>>> from dffml import * | ||
>>> | ||
>>> model = SLRModel( | ||
... features=Features( | ||
... Feature("Years", int, 1), | ||
... ), | ||
... predict=Feature("Salary", int, 1), | ||
... location="tempdir", | ||
... ) | ||
>>> | ||
>>> async def main(): | ||
... score = await tune( | ||
... model, | ||
... ParameterGrid(objective="min"), | ||
... MeanSquaredErrorAccuracy(), | ||
... Features( | ||
... Feature("Years", float, 1), | ||
... ), | ||
... [ | ||
... {"Years": 0, "Salary": 10}, | ||
... {"Years": 1, "Salary": 20}, | ||
... {"Years": 2, "Salary": 30}, | ||
... {"Years": 3, "Salary": 40} | ||
... ], | ||
... [ | ||
... {"Years": 6, "Salary": 70}, | ||
... {"Years": 7, "Salary": 80} | ||
... ] | ||
... | ||
... ) | ||
... print(f"Tuner score: {score}") | ||
... | ||
>>> asyncio.run(main()) | ||
Tuner score: 0.0 | ||
""" | ||
|
||
if not isinstance(features, (Feature, Features)): | ||
raise TypeError( | ||
f"features was {type(features)}: {features!r}. Should have been Feature or Features" | ||
) | ||
if isinstance(features, Feature): | ||
features = Features(features) | ||
if hasattr(model.config, "predict"): | ||
if isinstance(model.config.predict, Features): | ||
predict_feature = [ | ||
feature.name for feature in model.config.predict | ||
] | ||
else: | ||
predict_feature = [model.config.predict.name] | ||
|
||
def records_to_dict_check(ds): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's pull this out into the global scope or |
||
if hasattr(model.config, "features") and any( | ||
isinstance(td, list) for td in ds | ||
): | ||
return list_records_to_dict( | ||
[feature.name for feature in model.config.features] | ||
+ predict_feature, | ||
*ds, | ||
model=model, | ||
) | ||
return ds | ||
|
||
train_ds = records_to_dict_check(train_ds) | ||
valid_ds = records_to_dict_check(valid_ds) | ||
|
||
|
||
async with contextlib.AsyncExitStack() as astack: | ||
# Open sources | ||
train = await astack.enter_async_context(records_to_sources(*train_ds)) | ||
test = await astack.enter_async_context(records_to_sources(*valid_ds)) | ||
# Allow for keep models open | ||
if isinstance(model, Model): | ||
model = await astack.enter_async_context(model) | ||
mctx = await astack.enter_async_context(model()) | ||
elif isinstance(model, ModelContext): | ||
mctx = model | ||
|
||
# Allow for scorers to be kept open | ||
if isinstance(accuracy_scorer, AccuracyScorer): | ||
accuracy_scorer = await astack.enter_async_context(accuracy_scorer) | ||
actx = await astack.enter_async_context(accuracy_scorer()) | ||
elif isinstance(accuracy_scorer, AccuracyContext): | ||
actx = accuracy_scorer | ||
else: | ||
# TODO Replace this with static type checking and maybe dynamic | ||
# through something like pydantic. See issue #36 | ||
raise TypeError(f"{accuracy_scorer} is not an AccuracyScorer") | ||
|
||
if isinstance(tuner, Tuner): | ||
tuner = await astack.enter_async_context(tuner) | ||
tctx = await astack.enter_async_context(tuner()) | ||
elif isinstance(tuner, TunerContext): | ||
tctx = tuner | ||
else: | ||
raise TypeError(f"{tuner} is not an Tuner") | ||
|
||
return float( | ||
await tctx.optimize(mctx, features, actx, train, test) | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,6 +51,7 @@ def inpath(binary): | |
("operations", "nlp"), | ||
("service", "http"), | ||
("source", "mysql"), | ||
("tuner", "bayes_opt_gp"), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. lets have a simpler, more understandable entrypoint |
||
] | ||
|
||
|
||
|
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../common/README.rst |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../common/README.rst |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../common/README.rst |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../common/README.rst |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../common/README.rst |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,4 +8,3 @@ | |
TunerContext, | ||
Tuner, | ||
) | ||
from .parameter_grid import ParameterGrid |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so we want the train and test sets to be passed in as keyword arguments like this: