From fed677468a3c9bc73753a34ed6020ee7b8230232 Mon Sep 17 00:00:00 2001 From: Dana Benson Date: Wed, 8 Jan 2020 10:52:53 -0800 Subject: [PATCH] add list trial components method on trial object add trial and experiment name filtering parameters to list trial components --- src/smexperiments/trial.py | 33 ++++ src/smexperiments/trial_component.py | 25 ++- .../test_track_from_training_job.py | 1 - tests/integ/__init__.py | 0 tests/unit/__init__.py | 0 .../test_list_trial_components_from_trial.py | 181 ++++++++++++++++++ tests/unit/test_trial_component.py | 44 ++++- 7 files changed, 277 insertions(+), 7 deletions(-) create mode 100644 tests/integ/__init__.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_list_trial_components_from_trial.py diff --git a/src/smexperiments/trial.py b/src/smexperiments/trial.py index 85b7395..4594b0c 100644 --- a/src/smexperiments/trial.py +++ b/src/smexperiments/trial.py @@ -173,3 +173,36 @@ def remove_trial_component(self, tc): self.sagemaker_boto_client.disassociate_trial_component( TrialName=self.trial_name, TrialComponentName=trial_component_name ) + + def list_trial_components( + self, + created_before=None, + created_after=None, + sort_by=None, + sort_order=None, + max_results=None, + next_token=None): + """List trial components in this trial matching the specified criteria. + + Args: + created_before (datetime.datetime, optional): Return trials created before this instant. + created_after (datetime.datetime, optional): Return trials created after this instant. + sort_by (str, optional): Which property to sort results by. One of 'Name', + 'CreationTime'. + sort_order (str, optional): One of 'Ascending', or 'Descending'. + max_results (int, optional): maximum number of trial components to retrieve + next_token (str, optional): token for next page of results + Returns: + collections.Iterator[smexperiments.api_types.TrialComponentSummary] : An iterator over + trials matching the criteria. + """ + return trial_component.TrialComponent.list( + trial_name=self.trial_name, + created_before=created_before, + created_after=created_after, + sort_by=sort_by, + sort_order=sort_order, + max_results=max_results, + next_token=next_token, + sagemaker_boto_client=self.sagemaker_boto_client, + ) diff --git a/src/smexperiments/trial_component.py b/src/smexperiments/trial_component.py index ce6c39c..8230f62 100644 --- a/src/smexperiments/trial_component.py +++ b/src/smexperiments/trial_component.py @@ -110,8 +110,18 @@ def create(cls, trial_component_name, display_name=None, sagemaker_boto_client=N sagemaker_boto_client=sagemaker_boto_client) @classmethod - def list(cls, source_arn=None, created_before=None, created_after=None, - sort_by=None, sort_order=None, sagemaker_boto_client=None): + def list( + cls, + source_arn=None, + created_before=None, + created_after=None, + sort_by=None, + sort_order=None, + sagemaker_boto_client=None, + trial_name=None, + experiment_name=None, + max_results=None, + next_token=None): """ Return a list of trial component summaries. @@ -124,6 +134,11 @@ def list(cls, source_arn=None, created_before=None, created_after=None, sort_order (str, optional): One of 'Ascending', or 'Descending'. sagemaker_boto_client (SageMaker.Client, optional) : Boto3 client for SageMaker. If not supplied, a default boto3 client will be created and used. + trial_name (str, optional): Name of a Trial + experiment_name (str, optional): Name of an Experiment + max_results (int, optional): maximum number of trial components to retrieve + next_token (str, optional): token for next page of results + Returns: collections.Iterator[smexperiments.api_types.TrialComponentSummary]: An iterator over ``TrialComponentSummary`` objects. @@ -137,4 +152,8 @@ def list(cls, source_arn=None, created_before=None, created_after=None, created_after=created_after, sort_by=sort_by, sort_order=sort_order, - sagemaker_boto_client=sagemaker_boto_client) + sagemaker_boto_client=sagemaker_boto_client, + trial_name=trial_name, + experiment_name=experiment_name, + max_results=max_results, + next_token=next_token) diff --git a/tests/integ-jobs/test_track_from_training_job.py b/tests/integ-jobs/test_track_from_training_job.py index 1b079c6..1f35d17 100644 --- a/tests/integ-jobs/test_track_from_training_job.py +++ b/tests/integ-jobs/test_track_from_training_job.py @@ -12,7 +12,6 @@ # language governing permissions and limitations under the License. import sys - import boto3 from tests.helpers import * diff --git a/tests/integ/__init__.py b/tests/integ/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_list_trial_components_from_trial.py b/tests/unit/test_list_trial_components_from_trial.py new file mode 100644 index 0000000..1acc51e --- /dev/null +++ b/tests/unit/test_list_trial_components_from_trial.py @@ -0,0 +1,181 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import pytest +import unittest.mock +import datetime + +from smexperiments import trial, api_types + + +@pytest.fixture +def sagemaker_boto_client(): + return unittest.mock.Mock() + + +@pytest.fixture +def datetime_obj(): + return datetime.datetime(2017, 6, 16, 15, 55, 0) + +def test_list_trial_components(sagemaker_boto_client, datetime_obj): + sagemaker_boto_client.list_trial_components.return_value = { + "TrialComponentSummaries": [ + { + "TrialComponentName": "trial-component-1", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + { + "TrialComponentName": "trial-component-2", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + ] + } + expected = [ + api_types.TrialComponentSummary( + trial_component_name="trial-component-1", + creation_time=datetime_obj, + last_modified_time=datetime_obj, + ), + api_types.TrialComponentSummary( + trial_component_name="trial-component-2", + creation_time=datetime_obj, + last_modified_time=datetime_obj, + ), + ] + + trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client) + + assert expected == list(trial_obj.list_trial_components()) + + +def test_list_trial_components_empty(sagemaker_boto_client): + sagemaker_boto_client.list_trial_components.return_value = {"TrialComponentSummaries": []} + trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client) + assert list(trial_obj.list_trial_components()) == [] + + +def test_list_trial_components_single(sagemaker_boto_client, datetime_obj): + trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client) + sagemaker_boto_client.list_trial_components.return_value = { + "TrialComponentSummaries": [ + { + "TrialComponentName": "trial-component-foo", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj + } + ] + } + + assert list(trial_obj.list_trial_components()) == [ + api_types.TrialComponentSummary( + trial_component_name="trial-component-foo", + creation_time=datetime_obj, + last_modified_time=datetime_obj + ) + ] + + +def test_list_trial_components_two_values(sagemaker_boto_client, datetime_obj): + trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client) + sagemaker_boto_client.list_trial_components.return_value = { + "TrialComponentSummaries": [ + {"TrialComponentName": "trial-component-foo-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}, + {"TrialComponentName": "trial-component-foo-2", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}, + ] + } + + assert list(trial_obj.list_trial_components()) == [ + api_types.TrialComponentSummary( + trial_component_name="trial-component-foo-1", + creation_time=datetime_obj, + last_modified_time=datetime_obj + ), + api_types.TrialComponentSummary( + trial_component_name="trial-component-foo-2", + creation_time=datetime_obj, + last_modified_time=datetime_obj + ), + ] + + +def test_next_token(sagemaker_boto_client, datetime_obj): + trial_obj = trial.Trial(sagemaker_boto_client) + sagemaker_boto_client.list_trial_components.side_effect = [ + { + "TrialComponentSummaries": [ + { + "TrialComponentName": "trial-component-foo-1", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + { + "TrialComponentName": "trial-component-foo-2", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + ], + "NextToken": "foo", + }, + { + "TrialComponentSummaries": [ + { + "TrialComponentName": "trial-component-foo-3", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + } + ] + }, + ] + + assert list(trial_obj.list_trial_components()) == [ + api_types.TrialComponentSummary( + trial_component_name="trial-component-foo-1", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + api_types.TrialComponentSummary( + trial_component_name="trial-component-foo-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + api_types.TrialComponentSummary( + trial_component_name="trial-component-foo-3", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + + sagemaker_boto_client.list_trial_components.assert_any_call(**{}) + sagemaker_boto_client.list_trial_components.assert_any_call(NextToken="foo") + + +def test_list_trial_components_call_args(sagemaker_boto_client): + created_before = datetime.datetime(1999, 10, 12, 0, 0, 0) + created_after = datetime.datetime(1990, 10, 12, 0, 0, 0) + trial_name = 'foo-trial' + next_token = 'thetoken' + max_results = 99 + + trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client) + trial_obj.trial_name=trial_name + + sagemaker_boto_client.list_trial_components.return_value = {} + assert [] == list( + trial_obj.list_trial_components( + created_after=created_after, + created_before=created_before, + next_token=next_token, + max_results=max_results) + ) + sagemaker_boto_client.list_trial_components.assert_called_with( + CreatedBefore=created_before, + CreatedAfter=created_after, + TrialName=trial_name, + NextToken=next_token, + MaxResults=max_results, + ) diff --git a/tests/unit/test_trial_component.py b/tests/unit/test_trial_component.py index d210047..7664ae5 100644 --- a/tests/unit/test_trial_component.py +++ b/tests/unit/test_trial_component.py @@ -165,9 +165,12 @@ def test_list(sagemaker_boto_client): last_modified_by={} ) for i in range(20) ] - result = list(trial_component.TrialComponent.list(sagemaker_boto_client=sagemaker_boto_client, - source_arn='foo', sort_by='CreationTime', - sort_order='Ascending')) + result = list(trial_component.TrialComponent.list( + sagemaker_boto_client=sagemaker_boto_client, + source_arn='foo', + sort_by='CreationTime', + sort_order='Ascending')) + assert expected == result expected_calls= [unittest.mock.call(SortBy='CreationTime', SortOrder='Ascending', SourceArn='foo'), unittest.mock.call(NextToken='100', SortBy='CreationTime', SortOrder='Ascending', SourceArn='foo')] @@ -181,6 +184,41 @@ def test_list_empty(sagemaker_boto_client): assert [] == list(trial_component.TrialComponent.list(sagemaker_boto_client=sagemaker_boto_client)) +def test_list_trial_components_call_args(sagemaker_boto_client): + created_before = datetime.datetime(1999, 10, 12, 0, 0, 0) + created_after = datetime.datetime(1990, 10, 12, 0, 0, 0) + trial_name = 'foo-trial' + experiment_name = 'foo-experiment' + next_token = 'thetoken' + max_results = 99 + + sagemaker_boto_client.list_trial_components.return_value = {} + assert [] == list( + trial_component.TrialComponent.list( + sagemaker_boto_client=sagemaker_boto_client, + trial_name=trial_name, + experiment_name=experiment_name, + created_before=created_before, + created_after=created_after, + next_token=next_token, + max_results=max_results, + sort_by='CreationTime', + sort_order='Ascending') + ) + + expected_calls = [unittest.mock.call( + TrialName='foo-trial', + ExperimentName='foo-experiment', + CreatedBefore=created_before, + CreatedAfter=created_after, + SortBy='CreationTime', + SortOrder='Ascending', + NextToken='thetoken', + MaxResults=99, + )] + assert expected_calls == sagemaker_boto_client.list_trial_components.mock_calls + + def test_save(sagemaker_boto_client): obj = trial_component.TrialComponent(sagemaker_boto_client, trial_component_name='foo', display_name='bar') sagemaker_boto_client.update_trial_component.return_value = {}