Skip to content

Commit

Permalink
fix: fixing search device when don't have access to a region. (#708)
Browse files Browse the repository at this point in the history
  • Loading branch information
krneta authored Sep 26, 2023
1 parent f015dd0 commit 1acf476
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 16 deletions.
42 changes: 26 additions & 16 deletions src/braket/aws/aws_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import json
import os
import urllib.request
import warnings
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -602,23 +603,32 @@ def get_devices(
types_for_region = sorted(
types if region == session_region else types - {AwsDeviceType.SIMULATOR}
)
region_device_arns = [
result["deviceArn"]
for result in session_for_region.search_devices(
arns=arns,
names=names,
types=types_for_region,
statuses=statuses,
provider_names=provider_names,
try:
region_device_arns = [
result["deviceArn"]
for result in session_for_region.search_devices(
arns=arns,
names=names,
types=types_for_region,
statuses=statuses,
provider_names=provider_names,
)
]
device_map.update(
{
arn: AwsDevice(arn, session_for_region)
for arn in region_device_arns
if arn not in device_map
}
)
]
device_map.update(
{
arn: AwsDevice(arn, session_for_region)
for arn in region_device_arns
if arn not in device_map
}
)
except ClientError as e:
error_code = e.response["Error"]["Code"]
warnings.warn(
f"{error_code}: Unable to search region '{region}' for devices."
" Please check your settings or try again later."
f" Continuing without devices in '{region}'."
)

devices = list(device_map.values())
devices.sort(key=lambda x: getattr(x, order_by))
return devices
Expand Down
68 changes: 68 additions & 0 deletions test/unit_tests/braket/aws/test_aws_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -1673,6 +1673,74 @@ def test_get_devices_simulators_only(mock_copy_session, aws_session):
assert [result.name for result in results] == ["SV1"]


@patch("braket.aws.aws_device.AwsSession.copy_session")
def test_get_devices_with_error_in_region(mock_copy_session, aws_session):
aws_session.search_devices.side_effect = [
# us-west-1
[
{
"deviceArn": SV1_ARN,
"deviceName": "SV1",
"deviceType": "SIMULATOR",
"deviceStatus": "ONLINE",
"providerName": "Amazon Braket",
}
],
ValueError("should not be reachable"),
]
aws_session.get_device.side_effect = [
MOCK_GATE_MODEL_SIMULATOR,
ValueError("should not be reachable"),
]
session_for_region = Mock()
session_for_region.search_devices.side_effect = [
# us-east-1
[
{
"deviceArn": IONQ_ARN,
"deviceName": "IonQ Device",
"deviceType": "QPU",
"deviceStatus": "ONLINE",
"providerName": "IonQ",
},
],
# us-west-2
ClientError(
{
"Error": {
"Code": "Test Code",
"Message": "Test Message",
}
},
"Test Operation",
),
# eu-west-2
[
{
"deviceArn": OQC_ARN,
"deviceName": "Lucy",
"deviceType": "QPU",
"deviceStatus": "ONLINE",
"providerName": "OQC",
}
],
# Only two regions to search outside of current
ValueError("should not be reachable"),
]
session_for_region.get_device.side_effect = [
MOCK_GATE_MODEL_QPU_2,
MOCK_GATE_MODEL_QPU_3,
ValueError("should not be reachable"),
]
mock_copy_session.return_value = session_for_region
# Search order: us-east-1, us-west-1, us-west-2, eu-west-2
results = AwsDevice.get_devices(
statuses=["ONLINE"],
aws_session=aws_session,
)
assert [result.name for result in results] == ["Blah", "Lucy", "SV1"]


@pytest.mark.xfail(raises=ValueError)
def test_get_devices_invalid_order_by():
AwsDevice.get_devices(order_by="foo")
Expand Down

0 comments on commit 1acf476

Please sign in to comment.