Skip to content

Commit

Permalink
lowercase source arn before listing trial components
Browse files Browse the repository at this point in the history
  • Loading branch information
danabens committed Nov 23, 2022
1 parent 8553d23 commit 52800e4
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
2 changes: 1 addition & 1 deletion scripts/release.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def recent_changes_to_src(last_version):
stdout = check_output(["git", "log", "{}..HEAD".format(last_version), "--name-only", "--pretty=format: main"])
stdout = stdout.decode("utf-8")
lines = stdout.splitlines()
src_lines = list(filter(lambda l: l.startswith("src"), lines))
src_lines = list(filter(lambda line: line.startswith("src"), lines))
print(f"{len(src_lines)} src files changed since {last_version}")
return src_lines

Expand Down
2 changes: 1 addition & 1 deletion src/smexperiments/_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get_trial_component(self, sagemaker_boto_client):
while time.time() - start < 300:
summaries = list(
trial_component.TrialComponent.list(
source_arn=self.source_arn, sagemaker_boto_client=sagemaker_boto_client
source_arn=self.source_arn.lower(), sagemaker_boto_client=sagemaker_boto_client
)
)
if summaries:
Expand Down
13 changes: 9 additions & 4 deletions tests/unit/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import tempfile
import unittest

# https://github.com/coala/coala-bears/issues/2862
from unittest.mock import patch

import pytest

from smexperiments import _environment
Expand All @@ -36,7 +39,7 @@ def sagemaker_boto_client():
@pytest.fixture
def training_job_env():
old_value = os.environ.get("TRAINING_JOB_ARN")
os.environ["TRAINING_JOB_ARN"] = "arn:1234"
os.environ["TRAINING_JOB_ARN"] = "arn:1234aBcDe"
yield os.environ
del os.environ["TRAINING_JOB_ARN"]
if old_value:
Expand All @@ -46,17 +49,17 @@ def training_job_env():
def test_processing_job_environment(tempdir):
config_path = os.path.join(tempdir, "config.json")
with open(config_path, "w") as f:
f.write(json.dumps({"ProcessingJobArn": "arn:1234"}))
f.write(json.dumps({"ProcessingJobArn": "arn:1234aBcDe"}))
environment = _environment.TrialComponentEnvironment.load(processing_job_config_path=config_path)

assert _environment.EnvironmentType.SageMakerProcessingJob == environment.environment_type
assert "arn:1234" == environment.source_arn
assert "arn:1234aBcDe" == environment.source_arn


def test_training_job_environment(training_job_env):
environment = _environment.TrialComponentEnvironment.load()
assert _environment.EnvironmentType.SageMakerTrainingJob == environment.environment_type
assert "arn:1234" == environment.source_arn
assert "arn:1234aBcDe" == environment.source_arn


def test_no_environment():
Expand All @@ -70,9 +73,11 @@ def test_resolve_trial_component(training_job_env, sagemaker_boto_client):
}
sagemaker_boto_client.describe_trial_component.return_value = {"TrialComponentName": trial_component_name}
environment = _environment.TrialComponentEnvironment.load()

tc = environment.get_trial_component(sagemaker_boto_client)

assert trial_component_name == tc.trial_component_name
sagemaker_boto_client.list_trial_components.assert_called_with(SourceArn="arn:1234abcde")
sagemaker_boto_client.describe_trial_component.assert_called_with(TrialComponentName=trial_component_name)


Expand Down

0 comments on commit 52800e4

Please sign in to comment.