diff --git a/README.md b/README.md
index 02a8350..62bb8b1 100644
--- a/README.md
+++ b/README.md
@@ -109,6 +109,8 @@ result = task.result()
 
 Read more about AutoQASM decorators like `@aq.main` [here](doc/decorators.md).
 
+Read more about using AutoQASM with Amazon Braket Hybrid Jobs [here](doc/hybrid_jobs.md).
+
 For more example usage of AutoQASM, visit the [example notebooks](examples).
 
 ## Architecture
diff --git a/doc/hybrid_jobs.md b/doc/hybrid_jobs.md
new file mode 100644
index 0000000..bc82bd2
--- /dev/null
+++ b/doc/hybrid_jobs.md
@@ -0,0 +1,80 @@
# AutoQASM with Amazon Braket Hybrid Jobs

Amazon Braket Hybrid Jobs provides a solution for executing hybrid quantum-classical algorithms that use both classical compute resources and Quantum Processing Units (QPUs). The service allocates the classical compute resources, executes your algorithm, and releases the resources when the algorithm completes, so you are charged only for the resources you actually use. It is well suited to long-running, iterative algorithms that combine classical and quantum computation.

## Using `AwsQuantumJob.create`

This [documentation page](https://docs.aws.amazon.com/braket/latest/developerguide/braket-jobs-first.html#braket-jobs-first-create) shows you how to create a hybrid job with `AwsQuantumJob.create`. To use AutoQASM in a hybrid job, simply use AutoQASM in your algorithm script. Because AutoQASM is currently not installed in the default job container, be sure to include the AutoQASM feature branch in the requirements.txt of your source module, or add AutoQASM as a dependency when you build your own container. Below is an example algorithm script to get you started.
```python
from braket.devices import LocalSimulator

import autoqasm as aq
from autoqasm.instructions import measure, h, cnot


def start_here():
    print("Test job started!")

    # Run the program on the AutoQASM local simulator inside the job container
    device = LocalSimulator("autoqasm")

    @aq.main
    def bell():
        h(0)
        cnot(0, 1)
        c = measure([0, 1])

    for count in range(5):
        task = device.run(bell, shots=100)
        print(task.result().measurements)

    print("Test job completed!")
```

Save this algorithm script as "algorithm_script.py", then run the code snippet below to create your first hybrid job with AutoQASM!
```python
from braket.aws import AwsQuantumJob

job = AwsQuantumJob.create(
    device="local:braket_simulator",
    source_module="algorithm_script.py",
    entry_point="algorithm_script:start_here",
    wait_until_complete=True,
)
```

## Using the `@aq.hybrid_job` decorator

Alternatively, you can use the `@aq.hybrid_job` decorator to create a hybrid job with AutoQASM. Because AutoQASM is currently not installed in the default job container, be sure to include AutoQASM in the `dependencies` keyword of the `@aq.hybrid_job` decorator, or add AutoQASM as a dependency when you build your own container.

One of the core mechanisms of AutoQASM is source code analysis. When an AutoQASM-decorated function is called, AutoGraph analyzes its source code and converts it into a transformed Python function. To support this inside a job, the source code of each function defined inside the `@aq.hybrid_job`-decorated function is saved separately as input data to the job; when [AutoQASM decorators](decorators.md) wrap these functions, the source code is retrieved from that input data. Because of this, applying an AutoQASM decorator to a function that is defined outside of the `@aq.hybrid_job`-decorated function may not work properly, as illustrated by the sketch below. If your application requires AutoQASM-decorated functions to be defined outside of the `@aq.hybrid_job`-decorated function, we recommend creating the hybrid job with the option described in "Using `AwsQuantumJob.create`" instead.
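For illustration, here is a hedged sketch (hypothetical code, not a working recipe) of the problematic pattern: `bell` is defined at module scope, so its source file is not shipped to the job as input data, and AutoGraph may be unable to retrieve its source inside the job container.
```python
import autoqasm as aq
from autoqasm.instructions import measure, h, cnot

# Anti-pattern: `bell` is defined outside of the job function.
@aq.main
def bell():
    h(0)
    cnot(0, 1)
    c = measure([0, 1])

@aq.hybrid_job(device="local:braket_simulator", dependencies=["autoqasm"])
def bell_job():
    from braket.devices import LocalSimulator

    device = LocalSimulator("autoqasm")
    # Inside the job container, AutoQASM may fail to locate the source of `bell`
    # because it was defined outside of the decorated function.
    return device.run(bell, shots=100)
```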
Below is a working example to create an AutoQASM job with the `@aq.hybrid_job` decorator.
```python
from braket.devices import LocalSimulator

import autoqasm as aq
from autoqasm.instructions import measure, h, cnot


@aq.hybrid_job(
    device="local:braket_simulator",
    dependencies=["autoqasm"],
)
def bell_circuit_job():
    @aq.main
    def bell():
        h(0)
        cnot(0, 1)
        c = measure([0, 1])

    device = LocalSimulator("autoqasm")
    for count in range(5):
        task = device.run(bell, shots=100)
        print(task.result().measurements)

bell_circuit_job()
```
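Calling the decorated function creates the hybrid job and returns the job object (an `AwsQuantumJob`, or a `LocalQuantumJob` when `local=True`). As a brief sketch of how you might inspect the returned job afterwards, using the standard job methods:
```python
job = bell_circuit_job()

print(job.state())   # e.g. "QUEUED", "RUNNING", or "COMPLETED"
print(job.result())  # waits for completion and returns any results saved by the job
```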
diff --git a/doc/index.rst b/doc/index.rst
index c9982d1..d8832a3 100644
--- a/doc/index.rst
+++ b/doc/index.rst
@@ -26,8 +26,12 @@ readme at https://github.com/amazon-braket/autoqasm/blob/main/README.md.
 
 AutoQASM Decorators
 *******************
 
-For details on usage of AutoQASM decorators such as `@aq.main`, see
-the decorators documentation at https://github.com/amazon-braket/autoqasm/blob/main/doc/decorators.md
+For details on usage of AutoQASM decorators such as `@aq.main`, `@aq.subroutine`, `@aq.gate`,
+and `@aq.gate_calibration`, see the decorators documentation at
+https://github.com/amazon-braket/autoqasm/blob/main/doc/decorators.md.
+
+For details on using AutoQASM with Amazon Braket Hybrid Jobs via the `@aq.hybrid_job` decorator,
+see the documentation at https://github.com/amazon-braket/autoqasm/blob/main/doc/hybrid_jobs.md.
 
 ********
diff --git a/src/autoqasm/__init__.py b/src/autoqasm/__init__.py
index 20ed5c0..113f36c 100644
--- a/src/autoqasm/__init__.py
+++ b/src/autoqasm/__init__.py
@@ -45,6 +45,7 @@ def my_program():
 """
 from . import errors, instructions, operators  # noqa: F401
 from .api import gate, gate_calibration, main, subroutine  # noqa: F401
+from .hybrid_job import hybrid_job  # noqa: F401
 from .instructions import QubitIdentifierType as Qubit  # noqa: F401
 from .program import Program, build_program, verbatim  # noqa: F401
 from .transpiler import transpiler  # noqa: F401
diff --git a/src/autoqasm/hybrid_job.py b/src/autoqasm/hybrid_job.py
new file mode 100644
index 0000000..10f7397
--- /dev/null
+++ b/src/autoqasm/hybrid_job.py
@@ -0,0 +1,340 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.

from __future__ import annotations

import functools
import inspect
import os
import tempfile
from collections.abc import Callable, Iterable, Iterator
from contextlib import contextmanager
from logging import Logger, getLogger
from pathlib import Path
from types import CodeType, ModuleType

from braket.aws.aws_session import AwsSession
from braket.jobs.config import (
    CheckpointConfig,
    InstanceConfig,
    OutputDataConfig,
    S3DataSourceConfig,
    StoppingCondition,
)

# Reuse the Braket SDK's internal helpers for building the job entry point.
from braket.jobs.hybrid_job import (
    _create_job,
    _IncludeModules,
    _log_hyperparameters,
    _process_dependencies,
    _process_input_data,
    _serialize_entry_point,
    _validate_python_version,
)
from braket.jobs.local.local_job_container_setup import _get_env_input_data
from braket.jobs.quantum_job_creation import _generate_default_job_name

DEFAULT_INPUT_CHANNEL = "input"
INNER_SOURCE_INPUT_CHANNEL = "_braket_job_decorator_inner_function_source"
INNER_SOURCE_INPUT_FOLDER = "_inner_function_source_folder"


def hybrid_job(
    *,
    device: str | None,
    include_modules: str | ModuleType | Iterable[str | ModuleType] | None = None,
    dependencies: str | Path | list[str] | None = None,
    local: bool = False,
    job_name: str | None = None,
    image_uri: str | None = None,
    input_data: str | dict | S3DataSourceConfig | None = None,
    wait_until_complete: bool = False,
    instance_config: InstanceConfig | None = None,
    distribution: str | None = None,
    copy_checkpoints_from_job: str | None = None,
    checkpoint_config: CheckpointConfig | None = None,
    role_arn: str | None = None,
    stopping_condition: StoppingCondition | None = None,
    output_data_config: OutputDataConfig | None = None,
    aws_session: AwsSession | None = None,
    tags: dict[str, str] | None = None,
    logger: Logger = getLogger(__name__),
    quiet: bool | None = None,
    reservation_arn: str | None = None,
) -> Callable:
    """Defines a hybrid job by decorating the entry point function. The job will be created
    when the decorated function is called.

    The job created will be a `LocalQuantumJob` when `local` is set to `True`, otherwise an
    `AwsQuantumJob`. The following parameters will be ignored when running a job with
    `local` set to `True`: `wait_until_complete`, `instance_config`, `distribution`,
    `copy_checkpoints_from_job`, `stopping_condition`, `tags`, `logger`, and `quiet`.

    Remarks:
        Hybrid jobs created using this decorator have limited access to the source code of
        functions defined outside of the decorated function. Functionality that depends on
        source code analysis may not work properly when referencing functions defined outside
        of the decorated function.

    Args:
        device (str | None): Device ARN of the QPU device that receives priority quantum
            task queueing once the hybrid job begins running. Each QPU has a separate hybrid
            jobs queue so that only one hybrid job is running at a time. The device string is
            accessible in the hybrid job instance as the environment variable
            "AMZN_BRAKET_DEVICE_ARN". When using embedded simulators, you may provide the
            device argument as a string of the form: "local:<provider>/<simulator_name>",
            or `None`.

        include_modules (str | ModuleType | Iterable[str | ModuleType] | None): Either a
            single module or module name or a list of module or module names referring to local
            modules to be included. Any references to members of these modules in the hybrid
            job algorithm code will be serialized as part of the algorithm code.
Default: `[]`

        dependencies (str | Path | list[str] | None): Path (absolute or relative) to a
            requirements.txt file, or alternatively a list of strings, with each string being a
            `requirement specifier <https://pip.pypa.io/en/stable/reference/requirement-specifiers/>`_,
            to be used for the hybrid job.

        local (bool): Whether to use local mode for the hybrid job. Default: `False`

        job_name (str | None): A string that specifies the name with which the job is created.
            Allowed pattern for job name: `^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,50}$`. Defaults to
            f'{decorated-function-name}-{timestamp}'.

        image_uri (str | None): A str that specifies the ECR image to use for executing the job.
            The `retrieve_image()` function may be used for retrieving the ECR image URIs
            for the containers supported by Braket. Default: `<Braket base image URI>`.

        input_data (str | dict | S3DataSourceConfig | None): Information about the training
            data. Dictionary maps channel names to local paths or S3 URIs. Contents found
            at any local paths will be uploaded to S3 at
            f's3://{default_bucket_name}/jobs/{job_name}/data/{channel_name}'. If a local
            path, S3 URI, or S3DataSourceConfig is provided, it will be given a default
            channel name "input".
            Default: {}.

        wait_until_complete (bool): `True` if we should wait until the job completes.
            This would tail the job logs as it waits. Otherwise `False`. Ignored if using
            local mode. Default: `False`.

        instance_config (InstanceConfig | None): Configuration of the instance(s) for running
            the classical code for the hybrid job. Default:
            `InstanceConfig(instanceType='ml.m5.large', instanceCount=1, volumeSizeInGB=30)`.

        distribution (str | None): A str that specifies how the job should be distributed.
            If set to "data_parallel", the hyperparameters for the job will be set to use data
            parallelism features for PyTorch or TensorFlow. Default: `None`.

        copy_checkpoints_from_job (str | None): A str that specifies the job ARN whose
            checkpoint you want to use in the current job. Specifying this value will copy
            over the checkpoint data from the specified job's checkpoint_config
            s3Uri to the current job's checkpoint_config s3Uri, making it available at
            checkpoint_config.localPath during the job execution. Default: `None`

        checkpoint_config (CheckpointConfig | None): Configuration that specifies the
            location where checkpoint data is stored.
            Default: `CheckpointConfig(localPath='/opt/jobs/checkpoints',
            s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints')`.

        role_arn (str | None): A str providing the IAM role ARN used to execute the
            script. Default: IAM role returned by AwsSession's `get_default_jobs_role()`.

        stopping_condition (StoppingCondition | None): The maximum length of time, in seconds,
            and the maximum number of tasks that a job can run before being forcefully stopped.
            Default: StoppingCondition(maxRuntimeInSeconds=5 * 24 * 60 * 60).

        output_data_config (OutputDataConfig | None): Specifies the location for the output of
            the job.
            Default: `OutputDataConfig(s3Path=f's3://{default_bucket_name}/jobs/{job_name}/data',
            kmsKeyId=None)`.

        aws_session (AwsSession | None): AwsSession for connecting to AWS Services.
            Default: AwsSession()

        tags (dict[str, str] | None): Dict specifying the key-value pairs for tagging this job.
            Default: {}.

        logger (Logger): Logger object with which to write logs, such as task statuses
            while waiting for task to be in a terminal state.
Default: `getLogger(__name__)`

        quiet (bool | None): Sets the verbosity of the logger to low and does not report queue
            position. Default is `False`.

        reservation_arn (str | None): The reservation window ARN provided by Braket
            Direct to reserve exclusive usage for the device to run the hybrid job on.
            Default: None.

    Returns:
        Callable: the callable for creating a Hybrid Job.
    """
    _validate_python_version(image_uri, aws_session)

    def _hybrid_job(entry_point: Callable) -> Callable:
        @functools.wraps(entry_point)
        def job_wrapper(*args, **kwargs):
            """The job wrapper.

            Returns:
                The created job: an `AwsQuantumJob`, or a `LocalQuantumJob` when `local`
                is set to `True`.
            """
            with (
                _IncludeModules(include_modules),
                tempfile.TemporaryDirectory(dir="", prefix="decorator_job_") as temp_dir,
                persist_inner_function_source(entry_point) as inner_source_input,
            ):
                # Merge the user-provided input data with the channel that carries the
                # persisted source code of inner functions.
                if input_data is None:
                    job_input_data = inner_source_input
                elif isinstance(input_data, dict):
                    if INNER_SOURCE_INPUT_CHANNEL in input_data:
                        raise ValueError(f"input channel cannot be {INNER_SOURCE_INPUT_CHANNEL}")
                    job_input_data = {**input_data, **inner_source_input}
                else:
                    job_input_data = {DEFAULT_INPUT_CHANNEL: input_data, **inner_source_input}

                temp_dir_path = Path(temp_dir)
                entry_point_file_path = Path("entry_point.py")
                # Write the generated entry point script into the temporary source module.
                with open(temp_dir_path / entry_point_file_path, "w") as entry_point_file:
                    template = "\n".join(
                        [
                            _process_input_data(input_data),
                            _serialize_entry_point(entry_point, args, kwargs),
                        ]
                    )
                    entry_point_file.write(template)

                if dependencies:
                    _process_dependencies(dependencies, temp_dir_path)

                job_args = {
                    "device": device or "local:none/none",
                    "source_module": temp_dir,
                    "entry_point": (
                        f"{temp_dir}.{entry_point_file_path.stem}:{entry_point.__name__}"
                    ),
                    "wait_until_complete": wait_until_complete,
                    "job_name": job_name or _generate_default_job_name(func=entry_point),
                    "hyperparameters": _log_hyperparameters(entry_point, args, kwargs),
                    "logger": logger,
                }
                optional_args = {
                    "image_uri": image_uri,
                    "input_data": job_input_data,
                    "instance_config": instance_config,
                    "distribution": distribution,
                    "checkpoint_config": checkpoint_config,
                    "copy_checkpoints_from_job": copy_checkpoints_from_job,
                    "role_arn": role_arn,
                    "stopping_condition": stopping_condition,
                    "output_data_config": output_data_config,
                    "aws_session": aws_session,
                    "tags": tags,
                    "quiet": quiet,
                    "reservation_arn": reservation_arn,
                }
                # Only forward optional arguments that were explicitly provided.
                for key, value in optional_args.items():
                    if value is not None:
                        job_args[key] = value

                job = _create_job(job_args, local)
                return job

        return job_wrapper

    return _hybrid_job


@contextmanager
def persist_inner_function_source(entry_point: Callable) -> Iterator[dict[str, str]]:
    """Persist the source code of functions defined inside the job-decorated function by saving
    it as input data for the job, and patch the entry point's code objects so that their
    recorded source file paths point to the locations those files will have in the job
    environment.

    Args:
        entry_point (Callable): The job-decorated function.

    Yields:
        dict[str, str]: Mapping of the inner-source input channel name to the local directory
            containing the saved source files.
    """
    inner_source_mapping = _get_inner_function_source(entry_point.__code__)

    with tempfile.TemporaryDirectory() as temp_dir:
        # Save the inner-function sources into a folder that is uploaded to the job as
        # input data, then remap the entry point's recorded source file paths.
        copy_dir = f"{temp_dir}/{INNER_SOURCE_INPUT_FOLDER}"
        os.mkdir(copy_dir)
        path_mapping = _save_inner_source_to_file(inner_source_mapping, copy_dir)
        entry_point.__code__ = _replace_inner_function_source_path(
            entry_point.__code__, path_mapping
        )
        yield {INNER_SOURCE_INPUT_CHANNEL: copy_dir}


def _replace_inner_function_source_path(
    code_object: CodeType, path_mapping: dict[str, str]
) -> CodeType:
    """Recursively replace the recorded source file path of the code object and of its child
    code objects.

    Args:
        code_object (CodeType): Code object whose source file path should be replaced.
        path_mapping (dict[str, str]): Mapping from local file paths to paths in the job
            environment.

    Returns:
        CodeType: Code object with the source file path replaced.
    """
    new_co_consts = []
    for const in code_object.co_consts:
        if inspect.iscode(const):
            # Point the child code object at the copy of its source file that will exist
            # in the job environment, then recurse into its own children.
            new_path = path_mapping[const.co_filename]
            const = const.replace(co_filename=new_path)
            const = _replace_inner_function_source_path(const, path_mapping)
        new_co_consts.append(const)

    code_object = code_object.replace(co_consts=tuple(new_co_consts))
    return code_object


def _save_inner_source_to_file(inner_source: dict[str, str], input_data_dir: str) -> dict[str, str]:
    """Saves the source code as input data for a job and returns a dictionary that maps the
    local source file path of a function to the path to be used in the job environment.

    Args:
        inner_source (dict[str, str]): Mapping between source file path and source code.
        input_data_dir (str): The path of the folder to be uploaded to the job as input data.

    Returns:
        dict[str, str]: Mapping from local file paths to paths in the job environment.
    """
    path_mapping = {}
    for i, (local_path, source_code) in enumerate(inner_source.items()):
        copy_file_name = f"source_{i}.py"
        with open(f"{input_data_dir}/{copy_file_name}", "w") as f:
            f.write(source_code)

        # The job container mounts input data under AMZN_BRAKET_INPUT_DIR.
        path_mapping[local_path] = os.path.join(
            _get_env_input_data()["AMZN_BRAKET_INPUT_DIR"],
            INNER_SOURCE_INPUT_CHANNEL,
            copy_file_name,
        )
    return path_mapping


def _get_inner_function_source(code_object: CodeType) -> dict[str, str]:
    """Returns a dictionary that maps the source file path to the source code for all source
    files used by the inner functions defined inside the job-decorated function.

    Args:
        code_object (CodeType): Code object whose inner functions' sources should be collected.

    Returns:
        dict[str, str]: Mapping between source file path and source code.
    """
    inner_source = {}
    for const in code_object.co_consts:
        if inspect.iscode(const):
            source_file_path = inspect.getfile(const)
            lines, _ = inspect.findsource(const)
            inner_source.update({source_file_path: "".join(lines)})
            inner_source.update(_get_inner_function_source(const))
    return inner_source
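As a standalone aside (a minimal sketch, not part of this change, using only standard-library behavior): a nested function's code object records the path of the file it was compiled from in `co_filename`, and `CodeType.replace` can rewrite that record, which is how `_replace_inner_function_source_path` redirects `inspect` to the source files saved in the job environment. The target path below is hypothetical.
```python
import inspect

def outer():
    def inner():
        return 42

    return inner

inner = outer()
print(inner.__code__.co_filename)  # the file `inner` was compiled from

# Rewrite the recorded path, as _replace_inner_function_source_path does.
inner.__code__ = inner.__code__.replace(
    co_filename="/opt/braket/input/data/source/source_0.py"  # hypothetical job path
)
print(inspect.getfile(inner))  # now reports the rewritten path
```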
diff --git a/test/unit_tests/autoqasm/test_hybrid_job.py b/test/unit_tests/autoqasm/test_hybrid_job.py
new file mode 100644
index 0000000..71d0eeb
--- /dev/null
+++ b/test/unit_tests/autoqasm/test_hybrid_job.py
@@ -0,0 +1,284 @@
+# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.

"""Tests for the hybrid_job module."""

import inspect
import sys
import tempfile
from logging import getLogger
from pathlib import Path
from unittest.mock import MagicMock, mock_open, patch

import pytest
from braket.aws import AwsQuantumJob
from braket.devices import Devices
from braket.jobs.local import LocalQuantumJob

import autoqasm as aq


@pytest.fixture
def aws_session():
    aws_session = MagicMock()
    python_version_str = f"py{sys.version_info.major}{sys.version_info.minor}"
    aws_session.get_full_image_tag.return_value = f"1.0-cpu-{python_version_str}-ubuntu22.04"
    aws_session.region = "us-west-2"
    return aws_session


@patch("builtins.open", new_callable=mock_open)
@patch.object(sys.modules["os"], "mkdir")
@patch.object(sys.modules["braket.jobs.hybrid_job"], "retrieve_image")
@patch("time.time", return_value=123.0)
@patch("tempfile.TemporaryDirectory")
@patch.object(LocalQuantumJob, "create")
def test_decorator_persist_inner_function_source(
    mock_create, mock_tempdir, mock_time, mock_retrieve, mock_mkdir, mock_file, aws_session
):
    from autoqasm.hybrid_job import INNER_SOURCE_INPUT_CHANNEL, INNER_SOURCE_INPUT_FOLDER

    mock_retrieve.return_value = "00000000.dkr.ecr.us-west-2.amazonaws.com/latest"

    def my_entry():
        def inner_function_1():
            def inner_function_2():
                return "my inner function 2"

            return "my inner function 1"

        return inner_function_1

    inner1 = my_entry()

    mock_tempdir_name = "job_temp_dir_00000"
    mock_tempdir.return_value.__enter__.return_value = mock_tempdir_name

    device = Devices.Amazon.SV1
    source_module = mock_tempdir_name
    entry_point = f"{mock_tempdir_name}.entry_point:my_entry"

    my_entry = aq.hybrid_job(device=Devices.Amazon.SV1, local=True, aws_session=aws_session)(
        my_entry
    )
    my_entry()

    expected_source = "".join(inspect.findsource(inner1)[0])
    assert mock_file().write.call_args_list[0][0][0] == expected_source

    expect_source_path = f"{mock_tempdir_name}/{INNER_SOURCE_INPUT_FOLDER}/source_0.py"
    assert mock_file.call_args_list[0][0][0] == expect_source_path

    mock_create.assert_called_with(
        device=device,
        source_module=source_module,
        entry_point=entry_point,
        job_name="my-entry-123000",
        hyperparameters={},
        aws_session=aws_session,
        input_data={INNER_SOURCE_INPUT_CHANNEL: f"{mock_tempdir_name}/{INNER_SOURCE_INPUT_FOLDER}"},
    )
    assert mock_tempdir.return_value.__exit__.called


@patch.object(sys.modules["autoqasm.hybrid_job"], "persist_inner_function_source")
@patch.object(sys.modules["braket.jobs.hybrid_job"], "retrieve_image")
@patch("time.time", return_value=123.0)
@patch("builtins.open")
@patch("tempfile.TemporaryDirectory")
@patch.object(AwsQuantumJob, "create")
def test_decorator_conflict_channel_name(
    mock_create,
    mock_tempdir,
    _mock_open,
    mock_time,
    mock_retrieve,
    mock_persist_source,
    aws_session,
):
    from autoqasm.hybrid_job import INNER_SOURCE_INPUT_CHANNEL

    mock_retrieve.return_value = "00000000.dkr.ecr.us-west-2.amazonaws.com/latest"

    @aq.hybrid_job(
        device=None, aws_session=aws_session, input_data={INNER_SOURCE_INPUT_CHANNEL: "foo-bar"}
    )
    def my_entry(c=0, d: float = 1.0, **extras):
        return "my entry return value"

    mock_tempdir_name = "job_temp_dir_00000"
    mock_tempdir.return_value.__enter__.return_value = mock_tempdir_name
    mock_persist_source.return_value.__enter__.return_value = {}

    expect_error_message = f"input channel cannot be {INNER_SOURCE_INPUT_CHANNEL}"
    with pytest.raises(ValueError, match=expect_error_message):
        my_entry()


@patch.object(sys.modules["autoqasm.hybrid_job"], "persist_inner_function_source")
@patch("braket.jobs.image_uris.retrieve_image")
@patch("sys.stdout")
@patch("time.time", return_value=123.0)
@patch("cloudpickle.register_pickle_by_value")
@patch("cloudpickle.unregister_pickle_by_value")
@patch("shutil.copy")
@patch("builtins.open")
@patch.object(AwsQuantumJob, "create")
def test_decorator_non_defaults(
    mock_create,
    _mock_open,
    mock_copy,
    mock_unregister,
    mock_register,
    mock_time,
    mock_stdout,
    mock_retrieve,
    mock_persist_source,
):
    # Mock parameters map to the @patch decorators bottom-up, so `mock_unregister`
    # (for unregister_pickle_by_value) precedes `mock_register` here.
    mock_retrieve.return_value = "should-not-be-used"
    dependencies = "my_requirements.txt"
    image_uri = "my_image.uri"
    distribution = "data_parallel"
    copy_checkpoints_from_job = "arn/other-job"
    role_arn = "role_arn"
    aws_session = MagicMock()
    tags = {"my_tag": "my_value"}
    reservation_arn = (
        "arn:aws:braket:us-west-2:123456789123:reservation/a1b123cd-45e6-789f-gh01-i234567jk8l9"
    )
    logger = getLogger(__name__)

    with tempfile.TemporaryDirectory() as tempdir:
        Path(tempdir, "temp_dir").mkdir()
        Path(tempdir, "temp_file").touch()

        input_data = {
            "my_prefix": "my_input_data",
            "my_dir": Path(tempdir, "temp_dir"),
            "my_file": Path(tempdir, "temp_file"),
            "my_s3_prefix": "s3://bucket/path/to/prefix",
        }

        @aq.hybrid_job(
            device=Devices.Amazon.SV1,
            dependencies=dependencies,
            image_uri=image_uri,
            input_data=input_data,
            wait_until_complete=True,
            distribution=distribution,
            copy_checkpoints_from_job=copy_checkpoints_from_job,
            role_arn=role_arn,
            aws_session=aws_session,
            tags=tags,
            reservation_arn=reservation_arn,
            logger=logger,
        )
        def my_entry(a, b: int, c=0, d: float = 1.0, **extras) -> str:
            return "my entry return value"

        mock_tempdir = MagicMock(spec=tempfile.TemporaryDirectory)
        mock_tempdir_name = "job_temp_dir_00000"
        mock_tempdir.__enter__.return_value = mock_tempdir_name
        mock_persist_source.return_value.__enter__.return_value = {}

        device = Devices.Amazon.SV1
        source_module = mock_tempdir_name
        entry_point = f"{mock_tempdir_name}.entry_point:my_entry"
        wait_until_complete = True

        s3_not_linked = (
            "Input data channels mapped to an S3 source will not be available in the working "
            'directory. Use `get_input_data_dir(channel="my_s3_prefix")` to read input data '
            "from S3 source inside the job container."
        )

        with patch("tempfile.TemporaryDirectory", return_value=mock_tempdir):
            my_entry("a", 2, 3, 4, extra_param="value", another=6)

        mock_create.assert_called_with(
            device=device,
            source_module=source_module,
            entry_point=entry_point,
            image_uri=image_uri,
            input_data=input_data,
            wait_until_complete=wait_until_complete,
            job_name="my-entry-123000",
            distribution=distribution,
            hyperparameters={
                "a": "a",
                "b": "2",
                "c": "3",
                "d": "4",
                "extra_param": "value",
                "another": "6",
            },
            copy_checkpoints_from_job=copy_checkpoints_from_job,
            role_arn=role_arn,
            aws_session=aws_session,
            tags=tags,
            logger=logger,
            reservation_arn=reservation_arn,
        )
        mock_copy.assert_called_with(
            Path("my_requirements.txt").resolve(), Path(mock_tempdir_name, "requirements.txt")
        )
        assert mock_tempdir.__exit__.called
        mock_stdout.write.assert_any_call(s3_not_linked)


@patch.object(sys.modules["autoqasm.hybrid_job"], "persist_inner_function_source")
@patch.object(sys.modules["braket.jobs.hybrid_job"], "retrieve_image")
@patch("time.time", return_value=123.0)
@patch("builtins.open")
@patch("tempfile.TemporaryDirectory")
@patch.object(AwsQuantumJob, "create")
def test_decorator_non_dict_input(
    mock_create,
    mock_tempdir,
    _mock_open,
    mock_time,
    mock_retrieve,
    mock_persist_source,
    aws_session,
):
    mock_retrieve.return_value = "00000000.dkr.ecr.us-west-2.amazonaws.com/latest"
    input_prefix = "my_input"

    @aq.hybrid_job(device=None, input_data=input_prefix, aws_session=aws_session)
    def my_entry():
        return "my entry return value"

    mock_tempdir_name = "job_temp_dir_00000"
    mock_tempdir.return_value.__enter__.return_value = mock_tempdir_name
    mock_persist_source.return_value.__enter__.return_value = {}

    source_module = mock_tempdir_name
    entry_point = f"{mock_tempdir_name}.entry_point:my_entry"
    wait_until_complete = False

    device = "local:none/none"

    my_entry()

    mock_create.assert_called_with(
        device=device,
        source_module=source_module,
        entry_point=entry_point,
        wait_until_complete=wait_until_complete,
        job_name="my-entry-123000",
        hyperparameters={},
        logger=getLogger("autoqasm.hybrid_job"),
        input_data={"input": input_prefix},
        aws_session=aws_session,
    )
    assert mock_tempdir.return_value.__exit__.called