Skip to content

Commit

Permalink
☄️ comet integration (#129)
Browse files — browse the repository at this point in the history
* add requirement decorator

* ☄️ comet integration

* fixes

* fixes

* reformat

* test comet logger

* api key from env variable

* fixes

* add comet to extras

* comet offline logging

* update

* test val_step

* callback_runner ordered_dict

* fixes

* merge refactor callback_runner

* remove docs requirement

* add step

* merge trainval steps

* fixes

* fixes

* update

* update

* update version

* clean Tracker

* load previous experiments

* format
  • Loading branch information
aniketmaurya authored Nov 12, 2021
1 parent d8e3019 commit 3874739
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 37 deletions.
2 changes: 1 addition & 1 deletion examples/src/models/hello_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
),
ModelCheckpoint(),
EmissionTrackerCallback(),
CometCallback(offline=True),
CometCallback(offline=False),
]

if __name__ == "__main__":
Expand Down
12 changes: 7 additions & 5 deletions gradsflow/callbacks/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from loguru import logger

from gradsflow.callbacks import Callback
from gradsflow.utility.imports import requires

Expand All @@ -35,13 +37,13 @@ def __init__(self, offline: bool = False, **kwargs):
from codecarbon import EmissionsTracker, OfflineEmissionsTracker

if offline:
self.tracker = OfflineEmissionsTracker(**kwargs)
self._emission_tracker = OfflineEmissionsTracker(**kwargs)
else:
self.tracker = EmissionsTracker(**kwargs)
self.tracker.start()
self._emission_tracker = EmissionsTracker(**kwargs)
self._emission_tracker.start()

super().__init__(model=None)

def on_fit_end(self):
emissions: float = self.tracker.stop()
print(f"Emissions: {emissions} kg")
emissions: float = self._emission_tracker.stop()
logger.info(f"Emissions: {emissions} kg")
75 changes: 56 additions & 19 deletions gradsflow/callbacks/logger/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,32 +34,69 @@ class CometCallback(Callback):
def __init__(
self,
project_name: str = "awesome-project",
workspace: Optional[str] = None,
experiment_id: Optional[str] = None,
api_key: Optional[str] = os.environ.get("COMET_API_KEY"),
code_file: str = CURRENT_FILE,
offline: bool = False,
**kwargs
**kwargs,
):
super().__init__(
model=None,
)
self._code_file = code_file
self.experiment = self._create_experiment(project_name=project_name, api_key=api_key, offline=offline, **kwargs)
self._experiment_id = experiment_id
self.experiment = self._create_experiment(
project_name=project_name,
workspace=workspace,
api_key=api_key,
offline=offline,
experiment_id=experiment_id,
**kwargs,
)
self._train_prefix = "train"
self._val_prefix = "val"

@requires("comet_ml", "CometCallback requires comet_ml to be installed!")
def _create_experiment(
self, project_name: str, offline: bool = False, api_key: Optional[str] = None, **kwargs
self,
project_name: str,
workspace: str,
offline: bool = False,
api_key: Optional[str] = None,
experiment_id: Optional[str] = None,
**kwargs,
) -> BaseExperiment:
from comet_ml import Experiment, OfflineExperiment
from comet_ml import (
ExistingExperiment,
ExistingOfflineExperiment,
Experiment,
OfflineExperiment,
)

if offline:
experiment = OfflineExperiment(project_name=project_name, **kwargs)
if experiment_id:
experiment = ExistingOfflineExperiment(
project_name=project_name, workspace=workspace, previous_experiment=experiment_id, **kwargs
)
else:
experiment = OfflineExperiment(project_name=project_name, workspace=workspace, **kwargs)
else:
experiment = Experiment(project_name=project_name, api_key=api_key, **kwargs)
if experiment_id:
experiment = ExistingExperiment(
project_name=project_name,
workspace=workspace,
api_key=api_key,
previous_experiment=experiment_id,
**kwargs,
)
else:
experiment = Experiment(project_name=project_name, workspace=workspace, api_key=api_key, **kwargs)
return experiment

def on_fit_start(self):
self.experiment.set_model_graph(self.model.learner)
self.experiment.set_code(self._code_file)
self.experiment.log_code(self._code_file)

def on_train_epoch_start(
self,
Expand All @@ -71,31 +108,31 @@ def on_val_epoch_start(
):
self.experiment.validate()

def on_train_step_end(self, *args, **kwargs):
def _step(self, prefix: str, *args, **kwargs):
step = self.model.tracker.mode(prefix).steps
outputs = kwargs["outputs"]
loss = outputs["loss"].item()
self.experiment.log_metrics(outputs.get("metrics", {}))
self.experiment.log_metric("train_step_loss", loss)
self.experiment.log_metrics(outputs.get("metrics", {}), step=step, prefix=prefix)
self.experiment.log_metric(f"{prefix}_step_loss", loss, step=step)

def on_train_step_end(self, *args, **kwargs):
self._step(*args, **kwargs, prefix=self._train_prefix)

def on_val_step_end(self, *args, **kwargs):
outputs = kwargs["outputs"]
loss = outputs["loss"].item()
self.experiment.log_metrics(outputs.get("metrics", {}))
self.experiment.log_metric("val_step_loss", loss)
self._step(*args, **kwargs, prefix=self._val_prefix)

def on_epoch_end(self):
step = self.model.tracker.current_step
epoch = self.model.tracker.current_epoch
train_loss = self.model.tracker.train_loss
train_metrics = self.model.tracker.train_metrics
val_loss = self.model.tracker.val_loss
val_metrics = self.model.tracker.val_metrics

self.experiment.train()
self.experiment.log_metric("epoch_loss", train_loss, step=step, epoch=epoch)
self.experiment.log_metrics(train_metrics, step=step, epoch=epoch)
self.experiment.log_metric("train_epoch_loss", train_loss, epoch=epoch)
self.experiment.log_metrics(train_metrics, epoch=epoch, prefix=self._train_prefix)

self.experiment.validate()
self.experiment.log_metric("epoch_loss", val_loss, step=step, epoch=epoch)
self.experiment.log_metrics(val_metrics, step=step, epoch=epoch)
self.experiment.log_metric("val_epoch_loss", val_loss, epoch=epoch)
self.experiment.log_metrics(val_metrics, epoch=epoch, prefix=self._val_prefix)
self.experiment.log_epoch_end(epoch)
1 change: 0 additions & 1 deletion gradsflow/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def reset_metrics(self):
class BaseTracker:
max_epochs: int = 0
current_epoch: int = 0 # current train current_epoch
current_step: int = 0 # current current_step
steps_per_epoch: Optional[int] = None
train: TrackingValues = TrackingValues()
val: TrackingValues = TrackingValues()
13 changes: 4 additions & 9 deletions gradsflow/models/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Dict, List

from rich import box
Expand Down Expand Up @@ -39,13 +38,10 @@ def mode(self, mode) -> TrackingValues:

raise NotImplementedError(f"mode {mode} is not implemented!")

def track(self, key, value, render=False):
"""Tracks values for each step"""
if render:
warnings.warn("render is deprecated!")
def track(self, key, value):
"""Tracks value"""
epoch = self.current_epoch
step = self.current_step
data = {"current_epoch": epoch, "current_step": step, key: to_item(value)}
data = {"current_epoch": epoch, key: to_item(value)}
self.logs.append(data)

def track_loss(self, loss: float, mode: str):
Expand All @@ -56,7 +52,7 @@ def track_loss(self, loss: float, mode: str):
self.track(key, loss)

def track_metrics(self, metric: Dict[str, float], mode: str):
"""Update `TrackingValues` metrics. mode can be train or val and will update logs if render is True"""
"""Update `TrackingValues` metrics. mode can be train or val"""
value_tracker = self.mode(mode)
# Track values that averages with epoch
for key, value in metric.items():
Expand Down Expand Up @@ -103,7 +99,6 @@ def create_table(self) -> Table:
def reset(self):
self.max_epochs = 0
self.current_epoch = 0
self.current_step = 0
self.steps_per_epoch = None
self.train = TrackingValues()
self.val = TrackingValues()
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def test_mode():


def test_track():
tracker.track("val", 0.9, render=True)
tracker.track("score", 0.5, render=False)
tracker.track("val", 0.9)
tracker.track("score", 0.5)


def test_create_table():
Expand Down

0 comments on commit 3874739

Please sign in to comment.