Skip to content

Commit

Permalink
test: update job test to use decorator (amazon-braket#122)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajberdy authored Dec 6, 2023
1 parent 168096c commit f3e53b2
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 77 deletions.
25 changes: 11 additions & 14 deletions test/braket_tests/base/test_jobs_qaoa.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,21 @@
# language governing permissions and limitations under the License.

import time

from ..common.braket_jobs_util import job_test


def test_qaoa_circuit(account, role, s3_bucket, image_list):
assert len(image_list) > 0, "Unable to find images for testing"
create_job_args = {
"source_module": "./test/resources/",
"entry_point": "resources.qaoa_entry_point",
"hyperparameters": {
"p": "2",
"seed": "1967",
"max_parallel": "10",
"num_iterations": "5",
"stepsize": "0.1",
"shots": "100",
"interface": "autograd",
"start_time": time.time(),
}
job_args = {
"p": 2,
"seed": 1967,
"max_parallel": 10,
"num_iterations": 5,
"stepsize": 0.1,
"shots": 100,
"pl_interface": "autograd",
"start_time": time.time(),
}
for image_path in image_list:
job_test(account, role, s3_bucket, image_path, "base-qaoa", **create_job_args)
job_test(account, role, s3_bucket, image_path, "base-qaoa", job_args)
25 changes: 16 additions & 9 deletions test/braket_tests/common/braket_jobs_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,43 @@
# language governing permissions and limitations under the License.

import io
import os
import sys
import time
from contextlib import redirect_stdout

import boto3
from braket.aws import AwsSession
from braket.jobs import hybrid_job
from braket.devices import Devices

from braket.aws import AwsQuantumJob, AwsSession
from ...resources.qaoa_entry_point import entry_point


def job_test(account, role, s3_bucket, image_path, job_type, **kwargs):
def job_test(account, role, s3_bucket, image_path, job_type, job_args):
job_output = io.StringIO()
with redirect_stdout(job_output):
try:
create_job(account, role, s3_bucket, image_path, job_type, **kwargs)
create_job(account, role, s3_bucket, image_path, job_type, job_args)
except Exception as e:
print(e)
output = job_output.getvalue()
print(output)
assert output.find("Braket Container Run Success") > 0, "Container did not run successfully"


def create_job(account, role, s3_bucket, image_path, job_type, **kwargs):
def create_job(account, role, s3_bucket, image_path, job_type, job_args):
aws_session = AwsSession(default_bucket=s3_bucket)
job_name = f"ContainerTest-{job_type}-{int(time.time())}"
AwsQuantumJob.create(

@hybrid_job(
aws_session=aws_session,
job_name=job_name,
device="arn:aws:braket:::device/quantum-simulator/amazon/sv1",
device=Devices.Amazon.SV1,
role_arn=f"arn:aws:iam::{account}:role/{role}",
image_uri=image_path,
wait_until_complete=True,
**kwargs
include_modules="test.resources",
)
def decorator_job(*args, **kwargs):
return entry_point(*args, **kwargs)

decorator_job(**job_args)
24 changes: 10 additions & 14 deletions test/braket_tests/pytorch/test_jobs_qaoa.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,15 @@

def test_qaoa_circuit(account, role, s3_bucket, image_list):
assert len(image_list) > 0, "Unable to find images for testing"
create_job_args = {
"source_module": "./test/resources/",
"entry_point": "resources.qaoa_entry_point",
"hyperparameters": {
"p": "2",
"seed": "1967",
"max_parallel": "10",
"num_iterations": "5",
"stepsize": "0.1",
"shots": "100",
"interface": "torch",
"start_time": time.time(),
}
job_args = {
"p": 2,
"seed": 1967,
"max_parallel": 10,
"num_iterations": 5,
"stepsize": 0.1,
"shots": 100,
"pl_interface": "torch",
"start_time": time.time(),
}
for image_path in image_list:
job_test(account, role, s3_bucket, image_path, "pytorch-qaoa", **create_job_args)
job_test(account, role, s3_bucket, image_path, "pytorch-qaoa", job_args)
24 changes: 10 additions & 14 deletions test/braket_tests/tensorflow/test_jobs_qaoa.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,15 @@

def test_qaoa_circuit(account, role, s3_bucket, image_list):
assert len(image_list) > 0, "Unable to find images for testing"
create_job_args = {
"source_module": "./test/resources/",
"entry_point": "resources.qaoa_entry_point",
"hyperparameters": {
"p": "2",
"seed": "1967",
"max_parallel": "10",
"num_iterations": "5",
"stepsize": "0.1",
"shots": "100",
"interface": "tf",
"start_time": time.time(),
}
job_args = {
"p": 2,
"seed": 1967,
"max_parallel": 10,
"num_iterations": 5,
"stepsize": 0.1,
"shots": 100,
"pl_interface": "torch",
"start_time": time.time(),
}
for image_path in image_list:
job_test(account, role, s3_bucket, image_path, "tf-qaoa", **create_job_args)
job_test(account, role, s3_bucket, image_path, "tf-qaoa", job_args)
1 change: 1 addition & 0 deletions test/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ boto3==1.28.53
certifi>=2023.7.22
invoke==2.2.0
mock
pennylane
pytest==7.2.2
pytest-xdist
sagemaker
40 changes: 14 additions & 26 deletions test/resources/qaoa_entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,14 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import json
import os
import time
import boto3

import networkx as nx
from pennylane import numpy as np
import pennylane as qml

from braket.jobs import save_job_checkpoint, save_job_result
from braket.jobs import get_job_device_arn, save_job_checkpoint, save_job_result
from braket.jobs.metrics import log_metric

from . import qaoa_utils
Expand All @@ -34,7 +32,7 @@ def record_test_metrics(metric, start_time, interface):
'Dimensions': [
{
'Name': 'TYPE',
'Value': 'braket_tests'
'Value': 'braket_container_tests'
},
{
'Name': 'INTERFACE',
Expand All @@ -44,7 +42,7 @@ def record_test_metrics(metric, start_time, interface):
'Unit': 'Seconds',
'Value': time.time() - start_time
}],
Namespace='braket-container-metrics'
Namespace='/aws/braket'
)


Expand All @@ -60,22 +58,16 @@ def init_pl_device(device_arn, num_nodes, shots, max_parallel):
)


def start_function():
# Read the hyperparameters
hp_file = os.environ["AMZN_BRAKET_HP_FILE"]
with open(hp_file, "r") as f:
hyperparams = json.load(f)
print(hyperparams)

p = int(hyperparams["p"])
seed = int(hyperparams["seed"])
max_parallel = int(hyperparams["max_parallel"])
num_iterations = int(hyperparams["num_iterations"])
stepsize = float(hyperparams["stepsize"])
shots = int(hyperparams["shots"])
pl_interface = hyperparams["interface"]
start_time = float(hyperparams["start_time"])

def entry_point(
p: int,
seed: int,
max_parallel: int,
num_iterations: int,
stepsize: float,
shots: int,
pl_interface: str,
start_time: float,
):
record_test_metrics('Startup', start_time, pl_interface)

interface = qaoa_utils.QAOAInterface.get_interface(pl_interface)
Expand All @@ -95,7 +87,7 @@ def circuit(params, **kwargs):
qml.Hadamard(wires=i)
qml.layer(qaoa_layer, p, params[0], params[1])

device_arn = os.environ["AMZN_BRAKET_DEVICE_ARN"]
device_arn = get_job_device_arn()
dev = init_pl_device(device_arn, num_nodes, shots, max_parallel)

np.random.seed(seed)
Expand Down Expand Up @@ -158,7 +150,3 @@ def cost_function(params):

record_test_metrics('Total', start_time, pl_interface)
print("Braket Container Run Success")


if __name__ == "__main__":
start_function()

0 comments on commit f3e53b2

Please sign in to comment.