-
Notifications
You must be signed in to change notification settings - Fork 421
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add analytics logging to MosaicMLLogger
#3106
base: main
Are you sure you want to change the base?
Changes from all commits
f2e5537
87f3d09
55e738f
da1a179
e0b559d
44cb283
681e166
362f9ba
21cce59
be30004
77d2c25
82d771a
b9aa219
250cfff
d841895
67fd0d8
4a9da31
8a5d9df
b8032c5
bca595b
52db068
f44ac72
b499ab4
45b1f8b
06cd615
53ec76c
a08bc51
192fb4b
0c6439e
6abd957
f10e8a9
cee3876
11eb853
51a0dd0
59dea8f
5037ffb
ea2eabc
8c867bc
5753020
a51d5d9
1083ae6
25e3ceb
bf02105
3be10ed
ba907b7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,17 +14,21 @@ | |
import time | ||
import warnings | ||
from concurrent.futures import wait | ||
from dataclasses import dataclass | ||
from functools import reduce | ||
from typing import TYPE_CHECKING, Any, Dict, List, Optional | ||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union | ||
|
||
import mcli | ||
import torch | ||
import torch.utils.data | ||
|
||
from composer.core.time import TimeUnit | ||
from composer.core.event import Event | ||
from composer.core.time import Time, TimeUnit | ||
from composer.loggers import Logger | ||
from composer.loggers.logger_destination import LoggerDestination | ||
from composer.loggers.wandb_logger import WandBLogger | ||
from composer.utils import dist | ||
from composer.utils.file_helpers import parse_uri | ||
|
||
if TYPE_CHECKING: | ||
from composer.core import State | ||
|
@@ -46,6 +50,18 @@ class MosaicMLLogger(LoggerDestination): | |
Logs metrics to the MosaicML platform. Logging only happens on rank 0 every ``log_interval`` | ||
seconds to avoid performance issues. | ||
|
||
Additionally, the following metrics are logged upon ``INIT``: | ||
- ``composer/autoresume``: Whether or not the run can be stopped / resumed during training. | ||
- ``composer/precision``: The precision to use for training. | ||
- ``composer/eval_loaders``: A list containing the labels of each evaluation dataloader. | ||
- ``composer/optimizers``: A list of dictionaries containing information about each optimizer. | ||
- ``composer/algorithms``: A list containing the names of the algorithms used for training. | ||
- ``composer/loggers``: A list containing the loggers used in the ``Trainer``. | ||
- ``composer/cloud_provided_load_path``: The cloud provider for the load path. | ||
- ``composer/cloud_provided_save_folder``: The cloud provider for the save folder. | ||
- ``composer/save_interval``: The save interval for the run. | ||
- ``composer/fsdp_config``: The FSDP config used for training. | ||
|
||
When running on the MosaicML platform, the logger is automatically enabled by Trainer. To disable, | ||
the environment variable 'MOSAICML_PLATFORM' can be set to False. | ||
|
||
|
@@ -62,17 +78,20 @@ class MosaicMLLogger(LoggerDestination): | |
|
||
(default: ``None``) | ||
ignore_exceptions: Flag to disable logging exceptions. Defaults to False. | ||
analytics_data (Dict[str, Any], optional): A dictionary containing variables used to log analytics. Defaults to ``None``. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
log_interval: int = 60, | ||
ignore_keys: Optional[List[str]] = None, | ||
ignore_exceptions: bool = False, | ||
analytics_data: Optional[MosaicAnalyticsData] = None, | ||
) -> None: | ||
self.log_interval = log_interval | ||
self.ignore_keys = ignore_keys | ||
self.ignore_exceptions = ignore_exceptions | ||
self.analytics_data = analytics_data | ||
self._enabled = dist.get_global_rank() == 0 | ||
if self._enabled: | ||
self.time_last_logged = 0 | ||
|
@@ -96,10 +115,57 @@ def log_hyperparameters(self, hyperparameters: Dict[str, Any]): | |
def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None: | ||
self._log_metadata(metrics) | ||
|
||
def log_analytics(self, state: State, loggers: Tuple[LoggerDestination, ...]) -> None: | ||
if self.analytics_data is None: | ||
return | ||
|
||
metrics: Dict[str, Any] = { | ||
'composer/autoresume': self.analytics_data.autoresume, | ||
'composer/precision': state.precision, | ||
} | ||
metrics['composer/eval_loaders'] = [evaluator.label for evaluator in state.evaluators] | ||
metrics['composer/optimizers'] = [{ | ||
optimizer.__class__.__name__: optimizer.defaults, | ||
} for optimizer in state.optimizers] | ||
metrics['composer/algorithms'] = [algorithm.__class__.__name__ for algorithm in state.algorithms] | ||
metrics['composer/loggers'] = [logger.__class__.__name__ for logger in loggers] | ||
|
||
# Take the service provider out of the URI and log it to metadata. If no service provider | ||
# is found (i.e. backend = ''), then we assume 'local' for the cloud provider. | ||
if self.analytics_data.load_path is not None: | ||
backend, _, _ = parse_uri(self.analytics_data.load_path) | ||
metrics['composer/cloud_provided_load_path'] = backend if backend else 'local' | ||
if self.analytics_data.save_folder is not None: | ||
backend, _, _ = parse_uri(self.analytics_data.save_folder) | ||
metrics['composer/cloud_provided_save_folder'] = backend if backend else 'local' | ||
|
||
# Save interval can be passed in w/ multiple types. If the type is a function, then | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there some idea for utility of analytics on save interval? Nothing is coming to mind for me There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this was requested in a Slack thread a while back. this included some less-helpful metrics (i.e. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah i think just drop this one |
||
# we log 'callable' as the save_interval value for analytics. | ||
if isinstance(self.analytics_data.save_interval, (str, int)): | ||
save_interval_str = str(self.analytics_data.save_interval) | ||
elif isinstance(self.analytics_data.save_interval, Time): | ||
save_interval_str = f'{self.analytics_data.save_interval._value}{self.analytics_data.save_interval._unit}' | ||
else: | ||
save_interval_str = 'callable' | ||
metrics['composer/save_interval'] = save_interval_str | ||
|
||
if state.fsdp_config: | ||
# Keys need to be sorted so they can be parsed consistently in SQL queries | ||
metrics['composer/fsdp_config'] = json.dumps(state.fsdp_config, sort_keys=True) | ||
|
||
self.log_metrics(metrics) | ||
self._flush_metadata(force_flush=True) | ||
|
||
def log_exception(self, exception: Exception): | ||
self._log_metadata({'exception': exception_to_json_serializable_dict(exception)}) | ||
self._flush_metadata(force_flush=True) | ||
|
||
def init(self, state: State, logger: Logger) -> None: | ||
try: | ||
self.log_analytics(state, logger.destinations) | ||
except Exception: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we log the exception here? We should also probably only log once. if its failing consistently, we don't want to spam the run with warnings |
||
warnings.warn('Failed to log analytics data to MosaicML. Continuing without logging analytics data.') | ||
|
||
def after_load(self, state: State, logger: Logger) -> None: | ||
# Log model data downloaded and initialized for run events | ||
log.debug(f'Logging model initialized time to metadata') | ||
|
@@ -229,6 +295,14 @@ def _get_training_progress_metrics(self, state: State) -> Dict[str, Any]: | |
return training_progress_metrics | ||
|
||
|
||
@dataclass(frozen=True) | ||
class MosaicAnalyticsData: | ||
autoresume: bool | ||
save_interval: Union[str, int, Time, Callable[[State, Event], bool]] | ||
load_path: Union[str, None] | ||
save_folder: Union[str, None] | ||
|
||
|
||
def format_data_to_json_serializable(data: Any): | ||
"""Recursively formats data to be JSON serializable. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be a callback instead? It seems like all loggers and experiment tracking tools would benefit from having these extra information for navigation or reproducibility purposes.