diff --git a/src/smexperiments/tracker.py b/src/smexperiments/tracker.py index 493a129..8e78b39 100644 --- a/src/smexperiments/tracker.py +++ b/src/smexperiments/tracker.py @@ -19,6 +19,8 @@ import logging import botocore import json +from math import isnan, isinf +from numbers import Number from smexperiments._utils import get_module from os.path import join @@ -231,7 +233,8 @@ def log_parameter(self, name, value): name (str): The name of the parameter value (str or numbers.Number): The value of the parameter """ - self.trial_component.parameters[name] = value + if self._is_input_valid("parameter", name, value): + self.trial_component.parameters[name] = value def log_parameters(self, parameters): """Record a collection of parameter values for this trial component. @@ -245,7 +248,10 @@ def log_parameters(self, parameters): Args: parameters (dict[str, str or numbers.Number]): The parameters to record. """ - self.trial_component.parameters.update(parameters) + filtered_parameters = { + key: value for (key, value) in parameters.items() if self._is_input_valid("parameter", key, value) + } + self.trial_component.parameters.update(filtered_parameters) def log_input(self, name, value, media_type=None): """Record a single input artifact for this trial component. @@ -402,7 +408,8 @@ def log_metric(self, metric_name, value, timestamp=None, iteration_number=None): AttributeError: If the metrics writer is not initialized. """ try: - self._metrics_writer.log_metric(metric_name, value, timestamp, iteration_number) + if self._is_input_valid("metric", metric_name, value): + self._metrics_writer.log_metric(metric_name, value, timestamp, iteration_number) except AttributeError: if not self._metrics_writer: if not self._warned_on_metrics: @@ -654,6 +661,12 @@ def _log_graph_artifact(self, name, data, graph_type, output_artifact): else: self._lineage_artifact_tracker.add_input_artifact(artifact_name, s3_uri, etag, graph_type) + def _is_input_valid(self, input_type, field_name, field_value): + if isinstance(field_value, Number) and (isnan(field_value) or isinf(field_value)): + logging.warning(f"Failed to log {input_type} {field_name}. Received invalid value: {field_value}.") + return False + return True + def __enter__(self): """Updates the start time of the tracked trial component. diff --git a/tests/unit/test_tracker.py b/tests/unit/test_tracker.py index 8fe30ec..94ffbb3 100644 --- a/tests/unit/test_tracker.py +++ b/tests/unit/test_tracker.py @@ -16,6 +16,8 @@ import tempfile import os import datetime +from math import nan, inf +import numpy as np from smexperiments import api_types, tracker, trial_component, _utils, _environment import pandas as pd @@ -171,6 +173,11 @@ def test_log_parameter(under_test): assert under_test.trial_component.parameters["whizz"] == 1 +def test_log_parameter_skip_invalid_value(under_test): + under_test.log_parameter("key", nan) + assert "key" not in under_test.trial_component.parameters + + def test_enter(under_test): under_test.__enter__() assert isinstance(under_test.trial_component.start_time, datetime.datetime) @@ -213,6 +220,11 @@ def test_log_parameters(under_test): assert under_test.trial_component.parameters == {"a": "b", "c": "d", "e": 5} +def test_log_parameters_skip_invalid_values(under_test): + under_test.log_parameters({"a": "b", "c": "d", "e": 5, "f": nan}) + assert under_test.trial_component.parameters == {"a": "b", "c": "d", "e": 5} + + def test_log_input(under_test): under_test.log_input("foo", "baz", "text/text") assert under_test.trial_component.input_artifacts == { @@ -233,6 +245,11 @@ def test_log_metric(under_test): under_test._metrics_writer.log_metric.assert_called_with("foo", 1.0, 1, now) +def test_log_metric_skip_invalid_value(under_test): + under_test.log_metric(None, nan, None, None) + assert not under_test._metrics_writer.log_metric.called + + def test_log_metric_attribute_error(under_test): now = datetime.datetime.now() @@ -630,3 +647,19 @@ def test_log_roc_curve(under_test): ) under_test._lineage_artifact_tracker.add_input_artifact("TestROCCurve", "s3uri_value", "etag_value", "ROCCurve") + + +@pytest.mark.parametrize( + "metric_value", + [1.3, "nan", "inf", "-inf", None], +) +def test_is_input_valid(under_test, metric_value): + assert under_test._is_input_valid("metric", "Name", metric_value) + + +@pytest.mark.parametrize( + "metric_value", + [nan, inf, -inf], +) +def test__is_input_valid_false(under_test, metric_value): + assert not under_test._is_input_valid("parameter", "Name", metric_value)