Skip to content

Commit

Permalink
fetch best trial model (#21)
Browse files Browse the repository at this point in the history
* add best trial

* update

* add license

* remove logging.py

* update

* update changelog action
  • Loading branch information
aniketmaurya authored Aug 26, 2021
1 parent 9d72f72 commit f4e86de
Show file tree
Hide file tree
Showing 22 changed files with 319 additions and 20 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/latest-changes.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,5 @@ jobs:
with:
token: ${{ secrets.GITHUB_TOKEN }}
latest_changes_file: docs/CHANGELOG.md
template_file: ./.github/workflows/release-notes.jinja2
latest_changes_header: '## 0.0.1\n'
latest_changes_header: '## 0.0.2\n\n'
debug_logs: true
1 change: 0 additions & 1 deletion .github/workflows/release-notes.jinja2

This file was deleted.

4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
build-docs:
cp README.md docs/index.md

docs-serve:
mkdocs serve
docsserve:
mkdocs serve --dirtyreload

test:
python tests/__init__.py
Expand Down
14 changes: 8 additions & 6 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Release Notes

## 0.0.1
* This changed: 📝 update example and documentation. Done by [ aniketmaurya](https://github.com/aniketmaurya). Check the [Pull Request 20 with the changes and stuff](https://github.com/gradsflow/gradsflow/pull/20). now back to code. 🤓\n
* This changed: :tada::sparkles: First Release - v0.0.1 - Refactor API & tested Python 3.7+. Done by [ aniketmaurya](https://github.com/aniketmaurya). Check the [Pull Request 18 with the changes and stuff](https://github.com/gradsflow/gradsflow/pull/18). now back to code. 🤓
* This changed: Adding example notebook for AutoSummarization. Done by [the GitHub user gagan3012](https://github.com/gagan3012). Check the [Pull Request 19 with the changes and stuff](https://github.com/gradsflow/gradsflow/pull/19). now back to code. 🤓
* This changed: Adding text summarisation. Done by [the GitHub user gagan3012](https://github.com/gagan3012). Check the [Pull Request 14 with the changes and stuff](https://github.com/gradsflow/gradsflow/pull/14). now back to code. 🤓
* This changed: add codecov CI. Done by [the GitHub user aniketmaurya](https://github.com/aniketmaurya). Check the [Pull Request 15 with the changes and stuff](https://github.com/gradsflow/gradsflow/pull/15). now back to code. 🤓
* This changed: 📚 update documentation - added citation, acknowledgments, docstrings automation. Done by [the GitHub user aniketmaurya](https://github.com/aniketmaurya). Check the [Pull Request 13 with the changes and stuff](https://github.com/gradsflow/gradsflow/pull/13). now back to code. 🤓
* 📝 update example and documentation. Done by [ aniketmaurya](https://github.com/aniketmaurya). Check the [Pull Request 20 with the changes and stuff](https://github.com/gradsflow/gradsflow/pull/20).
* :tada::sparkles: First Release - v0.0.1 - Refactor API & tested Python 3.7+. Done by [ aniketmaurya](https://github.com/aniketmaurya). Check the [Pull Request 18 with the changes and stuff](https://github.com/gradsflow/gradsflow/pull/18).
* Adding example notebook for AutoSummarization. Done by [the GitHub user gagan3012](https://github.com/gagan3012). Check the [Pull Request 19 with the changes and stuff](https://github.com/gradsflow/gradsflow/pull/19).
* Adding text summarisation. Done by [the GitHub user gagan3012](https://github.com/gagan3012). Check the [Pull Request 14 with the changes and stuff](https://github.com/gradsflow/gradsflow/pull/14).
* add codecov CI. Done by [the GitHub user aniketmaurya](https://github.com/aniketmaurya). Check the [Pull Request 15 with the changes and stuff](https://github.com/gradsflow/gradsflow/pull/15).
* 📚 update documentation - added citation, acknowledgments, docstrings automation. Done by [the GitHub user aniketmaurya](https://github.com/aniketmaurya). Check the [Pull Request 13 with the changes and stuff](https://github.com/gradsflow/gradsflow/pull/13).
* Refactor API Design, CI & Docs PR [#10](https://github.com/gradsflow/gradsflow/pull/10) by [@aniketmaurya](https://github.com/aniketmaurya).
* auto docstring. PR [#7](https://github.com/gradsflow/gradsflow/pull/7) by [@aniketmaurya](https://github.com/aniketmaurya).
* Add AutoImageClassifier. PR [#1](https://github.com/gradsflow/gradsflow/pull/1) by [@aniketmaurya](https://github.com/aniketmaurya).

## 0.0.2
16 changes: 15 additions & 1 deletion gradsflow/core/autoclassifier.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import Dict, List, Optional, Union

Expand All @@ -12,7 +26,7 @@


class AutoClassifier(AutoModel):
"""Base Class for Auto Classification Hyperparameter search"""
"""Implements `AutoModel` for classification tasks."""

DEFAULT_BACKBONES = []

Expand Down
55 changes: 52 additions & 3 deletions gradsflow/core/automodel.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,54 @@
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import Dict, Optional, Union

import optuna
import pytorch_lightning as pl
import torch
from flash import DataModule
from loguru import logger
from optuna.integration import PyTorchLightningPruningCallback

from gradsflow.logging import logger
from gradsflow.utility.common import module_to_cls_index
from gradsflow.utility.optuna import is_best_trial


class AutoModel:
"""
Creates Optuna instance, defines methods required for hparam search
Args:
datamodule [flash.DataModule]: DataModule from Flash or PyTorch Lightning
max_epochs [int]: Maximum number of epochs for which model will train
optimization_metric [str]: Value on which hyperparameter search will run.
By default, it is `val_accuracy`.
n_trials [int]: Number of trials for HPO
suggested_conf [Dict]: Any extra suggested configuration
timeout [int]: HPO will stop after timeout
prune [bool]: Whether to stop unpromising training.
optuna_confs [Dict]: Optuna configs
best_trial [bool]: If true model will be loaded with best weights from HPO otherwise
a best trial model without trained weights will be created.
"""

OPTIMIZER_INDEX = module_to_cls_index(torch.optim, True)
DEFAULT_OPTIMIZERS = ["adam", "sgd"]
DEFAULT_LR = (1e-5, 1e-1)
_BEST_MODEL = "best_model"
_CURRENT_MODEL = "current_model"

def __init__(
self,
Expand All @@ -30,13 +60,15 @@ def __init__(
timeout: int = 600,
prune: bool = True,
optuna_confs: Optional[Dict] = None,
best_trial: bool = True,
):

self._pruner: optuna.pruners.BasePruner = (
optuna.pruners.MedianPruner() if prune else optuna.pruners.NopPruner()
)
self.datamodule = datamodule
self.n_trials = n_trials
self.best_trial = best_trial
self.model: Union[torch.nn.Module, pl.LightningModule, None] = None
self.max_epochs = max_epochs
self.timeout = timeout
Expand Down Expand Up @@ -96,18 +128,35 @@ def _objective(
)
trial_confs = self._get_trial_hparams(trial)
model = self.build_model(**trial_confs)
trial.set_user_attr(key="current_model", value=model)
hparams = dict(model=model.hparams)
trainer.logger.log_hyperparams(hparams)
trainer.fit(model, datamodule=self.datamodule)

logger.debug(trainer.callback_metrics)
return trainer.callback_metrics[self.optimization_metric].item()

def callback_best_trial(self, study: optuna.Study, trial: optuna.Trial) -> None:
if is_best_trial(study, trial):
study.set_user_attr(
key=self._BEST_MODEL, value=trial.user_attrs[self._CURRENT_MODEL]
)

def hp_tune(self):
"""
Search Hyperparameter and builds model with the best params
"""
callbacks = []
if self.best_trial:
callbacks.append(self.callback_best_trial)
self._study.optimize(
self._objective, n_trials=self.n_trials, timeout=self.timeout
self._objective,
n_trials=self.n_trials,
timeout=self.timeout,
callbacks=callbacks,
)
self.model = self.build_model(**self._study.best_params)

if self.best_trial:
self.model = self._study.user_attrs[self._BEST_MODEL]
else:
self.model = self.build_model(**self._study.best_params)
1 change: 0 additions & 1 deletion gradsflow/logging.py

This file was deleted.

33 changes: 33 additions & 0 deletions gradsflow/utility/optuna.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import optuna


def is_best_trial(study: optuna.Study, trial: optuna.Trial) -> bool:
if study.best_trial.number == trial.number:
return True
return False
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ classifiers = [
]
requires-python = ">=3.7"
requires = [
"optuna==2.9.1",
"smart_open==5.1.0",
"optuna==2.9",
"smart_open==5.1",
"lightning-flash[all]==0.4.0",
"pytorch-lightning==1.4.0",
"loguru==0.5"
"loguru~=0.5"
]

[tool.flit.metadata.urls]
Expand Down
13 changes: 13 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 changes: 14 additions & 0 deletions tests/__main__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from flash.core.data.utils import download_data

download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")
13 changes: 13 additions & 0 deletions tests/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions tests/core/test_autoclassifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 changes: 14 additions & 0 deletions tests/core/test_automodel.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from flash.image import ImageClassificationData

Expand Down
13 changes: 13 additions & 0 deletions tests/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 changes: 14 additions & 0 deletions tests/tasks/test_image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from flash.image import ImageClassificationData, ImageClassifier
Expand Down
16 changes: 15 additions & 1 deletion tests/tasks/test_summarization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
from gradsflow.autotasks import AutoSummarization
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock

from gradsflow.autotasks import AutoSummarization


def test_build_model():
datamodule = MagicMock()
Expand Down
14 changes: 14 additions & 0 deletions tests/tasks/test_text.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock

from gradsflow.autotasks import AutoTextClassifier
Expand Down
13 changes: 13 additions & 0 deletions tests/utility/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading

0 comments on commit f4e86de

Please sign in to comment.