Skip to content

Commit

Permalink
Merge pull request #11 from danabens/list-trial-components
Browse files Browse the repository at this point in the history
add list trial components method on trial object and trial/experiment name filter params
  • Loading branch information
danabens authored Jan 9, 2020
2 parents 1533617 + fed6774 commit d548937
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 7 deletions.
33 changes: 33 additions & 0 deletions src/smexperiments/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,36 @@ def remove_trial_component(self, tc):
self.sagemaker_boto_client.disassociate_trial_component(
TrialName=self.trial_name, TrialComponentName=trial_component_name
)

def list_trial_components(
self,
created_before=None,
created_after=None,
sort_by=None,
sort_order=None,
max_results=None,
next_token=None):
"""List trial components in this trial matching the specified criteria.
Args:
created_before (datetime.datetime, optional): Return trials created before this instant.
created_after (datetime.datetime, optional): Return trials created after this instant.
sort_by (str, optional): Which property to sort results by. One of 'Name',
'CreationTime'.
sort_order (str, optional): One of 'Ascending', or 'Descending'.
max_results (int, optional): maximum number of trial components to retrieve
next_token (str, optional): token for next page of results
Returns:
collections.Iterator[smexperiments.api_types.TrialComponentSummary] : An iterator over
trials matching the criteria.
"""
return trial_component.TrialComponent.list(
trial_name=self.trial_name,
created_before=created_before,
created_after=created_after,
sort_by=sort_by,
sort_order=sort_order,
max_results=max_results,
next_token=next_token,
sagemaker_boto_client=self.sagemaker_boto_client,
)
25 changes: 22 additions & 3 deletions src/smexperiments/trial_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,18 @@ def create(cls, trial_component_name, display_name=None, sagemaker_boto_client=N
sagemaker_boto_client=sagemaker_boto_client)

@classmethod
def list(cls, source_arn=None, created_before=None, created_after=None,
sort_by=None, sort_order=None, sagemaker_boto_client=None):
def list(
cls,
source_arn=None,
created_before=None,
created_after=None,
sort_by=None,
sort_order=None,
sagemaker_boto_client=None,
trial_name=None,
experiment_name=None,
max_results=None,
next_token=None):
"""
Return a list of trial component summaries.
Expand All @@ -124,6 +134,11 @@ def list(cls, source_arn=None, created_before=None, created_after=None,
sort_order (str, optional): One of 'Ascending', or 'Descending'.
sagemaker_boto_client (SageMaker.Client, optional) : Boto3 client for SageMaker.
If not supplied, a default boto3 client will be created and used.
trial_name (str, optional): Name of a Trial
experiment_name (str, optional): Name of an Experiment
max_results (int, optional): maximum number of trial components to retrieve
next_token (str, optional): token for next page of results
Returns:
collections.Iterator[smexperiments.api_types.TrialComponentSummary]: An iterator
over ``TrialComponentSummary`` objects.
Expand All @@ -137,4 +152,8 @@ def list(cls, source_arn=None, created_before=None, created_after=None,
created_after=created_after,
sort_by=sort_by,
sort_order=sort_order,
sagemaker_boto_client=sagemaker_boto_client)
sagemaker_boto_client=sagemaker_boto_client,
trial_name=trial_name,
experiment_name=experiment_name,
max_results=max_results,
next_token=next_token)
1 change: 0 additions & 1 deletion tests/integ-jobs/test_track_from_training_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# language governing permissions and limitations under the License.

import sys

import boto3

from tests.helpers import *
Expand Down
181 changes: 181 additions & 0 deletions tests/unit/test_list_trial_components_from_trial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import pytest
import unittest.mock
import datetime

from smexperiments import trial, api_types


@pytest.fixture
def sagemaker_boto_client():
return unittest.mock.Mock()


@pytest.fixture
def datetime_obj():
return datetime.datetime(2017, 6, 16, 15, 55, 0)

def test_list_trial_components(sagemaker_boto_client, datetime_obj):
sagemaker_boto_client.list_trial_components.return_value = {
"TrialComponentSummaries": [
{
"TrialComponentName": "trial-component-1",
"CreationTime": datetime_obj,
"LastModifiedTime": datetime_obj,
},
{
"TrialComponentName": "trial-component-2",
"CreationTime": datetime_obj,
"LastModifiedTime": datetime_obj,
},
]
}
expected = [
api_types.TrialComponentSummary(
trial_component_name="trial-component-1",
creation_time=datetime_obj,
last_modified_time=datetime_obj,
),
api_types.TrialComponentSummary(
trial_component_name="trial-component-2",
creation_time=datetime_obj,
last_modified_time=datetime_obj,
),
]

trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client)

assert expected == list(trial_obj.list_trial_components())


def test_list_trial_components_empty(sagemaker_boto_client):
sagemaker_boto_client.list_trial_components.return_value = {"TrialComponentSummaries": []}
trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client)
assert list(trial_obj.list_trial_components()) == []


def test_list_trial_components_single(sagemaker_boto_client, datetime_obj):
trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client)
sagemaker_boto_client.list_trial_components.return_value = {
"TrialComponentSummaries": [
{
"TrialComponentName": "trial-component-foo",
"CreationTime": datetime_obj,
"LastModifiedTime": datetime_obj
}
]
}

assert list(trial_obj.list_trial_components()) == [
api_types.TrialComponentSummary(
trial_component_name="trial-component-foo",
creation_time=datetime_obj,
last_modified_time=datetime_obj
)
]


def test_list_trial_components_two_values(sagemaker_boto_client, datetime_obj):
trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client)
sagemaker_boto_client.list_trial_components.return_value = {
"TrialComponentSummaries": [
{"TrialComponentName": "trial-component-foo-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj},
{"TrialComponentName": "trial-component-foo-2", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj},
]
}

assert list(trial_obj.list_trial_components()) == [
api_types.TrialComponentSummary(
trial_component_name="trial-component-foo-1",
creation_time=datetime_obj,
last_modified_time=datetime_obj
),
api_types.TrialComponentSummary(
trial_component_name="trial-component-foo-2",
creation_time=datetime_obj,
last_modified_time=datetime_obj
),
]


def test_next_token(sagemaker_boto_client, datetime_obj):
trial_obj = trial.Trial(sagemaker_boto_client)
sagemaker_boto_client.list_trial_components.side_effect = [
{
"TrialComponentSummaries": [
{
"TrialComponentName": "trial-component-foo-1",
"CreationTime": datetime_obj,
"LastModifiedTime": datetime_obj,
},
{
"TrialComponentName": "trial-component-foo-2",
"CreationTime": datetime_obj,
"LastModifiedTime": datetime_obj,
},
],
"NextToken": "foo",
},
{
"TrialComponentSummaries": [
{
"TrialComponentName": "trial-component-foo-3",
"CreationTime": datetime_obj,
"LastModifiedTime": datetime_obj,
}
]
},
]

assert list(trial_obj.list_trial_components()) == [
api_types.TrialComponentSummary(
trial_component_name="trial-component-foo-1", creation_time=datetime_obj, last_modified_time=datetime_obj
),
api_types.TrialComponentSummary(
trial_component_name="trial-component-foo-2", creation_time=datetime_obj, last_modified_time=datetime_obj
),
api_types.TrialComponentSummary(
trial_component_name="trial-component-foo-3", creation_time=datetime_obj, last_modified_time=datetime_obj
),
]

sagemaker_boto_client.list_trial_components.assert_any_call(**{})
sagemaker_boto_client.list_trial_components.assert_any_call(NextToken="foo")


def test_list_trial_components_call_args(sagemaker_boto_client):
created_before = datetime.datetime(1999, 10, 12, 0, 0, 0)
created_after = datetime.datetime(1990, 10, 12, 0, 0, 0)
trial_name = 'foo-trial'
next_token = 'thetoken'
max_results = 99

trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client)
trial_obj.trial_name=trial_name

sagemaker_boto_client.list_trial_components.return_value = {}
assert [] == list(
trial_obj.list_trial_components(
created_after=created_after,
created_before=created_before,
next_token=next_token,
max_results=max_results)
)
sagemaker_boto_client.list_trial_components.assert_called_with(
CreatedBefore=created_before,
CreatedAfter=created_after,
TrialName=trial_name,
NextToken=next_token,
MaxResults=max_results,
)
44 changes: 41 additions & 3 deletions tests/unit/test_trial_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,12 @@ def test_list(sagemaker_boto_client):
last_modified_by={}
) for i in range(20)
]
result = list(trial_component.TrialComponent.list(sagemaker_boto_client=sagemaker_boto_client,
source_arn='foo', sort_by='CreationTime',
sort_order='Ascending'))
result = list(trial_component.TrialComponent.list(
sagemaker_boto_client=sagemaker_boto_client,
source_arn='foo',
sort_by='CreationTime',
sort_order='Ascending'))

assert expected == result
expected_calls= [unittest.mock.call(SortBy='CreationTime', SortOrder='Ascending', SourceArn='foo'),
unittest.mock.call(NextToken='100', SortBy='CreationTime', SortOrder='Ascending', SourceArn='foo')]
Expand All @@ -181,6 +184,41 @@ def test_list_empty(sagemaker_boto_client):
assert [] == list(trial_component.TrialComponent.list(sagemaker_boto_client=sagemaker_boto_client))


def test_list_trial_components_call_args(sagemaker_boto_client):
created_before = datetime.datetime(1999, 10, 12, 0, 0, 0)
created_after = datetime.datetime(1990, 10, 12, 0, 0, 0)
trial_name = 'foo-trial'
experiment_name = 'foo-experiment'
next_token = 'thetoken'
max_results = 99

sagemaker_boto_client.list_trial_components.return_value = {}
assert [] == list(
trial_component.TrialComponent.list(
sagemaker_boto_client=sagemaker_boto_client,
trial_name=trial_name,
experiment_name=experiment_name,
created_before=created_before,
created_after=created_after,
next_token=next_token,
max_results=max_results,
sort_by='CreationTime',
sort_order='Ascending')
)

expected_calls = [unittest.mock.call(
TrialName='foo-trial',
ExperimentName='foo-experiment',
CreatedBefore=created_before,
CreatedAfter=created_after,
SortBy='CreationTime',
SortOrder='Ascending',
NextToken='thetoken',
MaxResults=99,
)]
assert expected_calls == sagemaker_boto_client.list_trial_components.mock_calls


def test_save(sagemaker_boto_client):
obj = trial_component.TrialComponent(sagemaker_boto_client, trial_component_name='foo', display_name='bar')
sagemaker_boto_client.update_trial_component.return_value = {}
Expand Down

0 comments on commit d548937

Please sign in to comment.