Skip to content

Commit

Permalink
feature: Add support for job decorator macro (#67)
Browse files Browse the repository at this point in the history
* feature: Initial implementation of macro for hybrid job launch
  • Loading branch information
kshyatt-aws authored Nov 16, 2023
1 parent de79455 commit db0e405
Show file tree
Hide file tree
Showing 24 changed files with 739 additions and 173 deletions.
19 changes: 18 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ AWSS3 = "1c724243-ef5b-51ab-93f4-b0a88ac62a95"
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CodeTracking = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand All @@ -18,6 +19,7 @@ Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Expand All @@ -32,22 +34,37 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[compat]
AWS = "=1.90.3"
AWSS3 = "=0.11.2"
Aqua = "=0.6"
Aqua = "=0.8"
Base64 = "1.6"
AxisArrays = "=0.4.7"
CSV = "=0.10.11"
CodeTracking = "=1.3.5"
Compat = "=4.10.0"
DataStructures = "=0.18.15"
Dates = "1.6"
DecFP = "=1.3.2"
Distributions = "=0.25.76"
Distributed = "1.6"
Downloads = "1"
Graphs = "=1.9.0"
HTTP = "=1.10.0"
InteractiveUtils = "1.6"
JLD2 = "=0.4.38"
JSON3 = "=1.13.2"
LinearAlgebra = "1.6"
Logging = "1.6"
Markdown = "=0.7.5"
Mocking = "=0.7.6"
NamedTupleTools = "=0.14.3"
OrderedCollections = "=1.6.2"
Pkg = "1.6"
Random = "1.6"
Statistics = "1.6"
SparseArrays = "1.6"
StructTypes = "=1.10.0"
Tar = "1.9.3"
Test = "1.6"
UUIDs = "1.6"
julia = "1.6"

[extras]
Expand Down
4 changes: 3 additions & 1 deletion PyBraket/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"

[compat]
Aqua = "=0.6"
Aqua = "=0.8"
Braket = "=0.7.6"
CondaPkg = "=0.2.21"
DataStructures = "=0.18.15"
LinearAlgebra = "1.6"
PythonCall = "=0.9.14"
Statistics = "1"
StructTypes = "=1.10.0"
Test = "1.6"
julia = "1.6"

[extras]
Expand Down
6 changes: 2 additions & 4 deletions PyBraket/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
using Test, Aqua, Braket, Braket.AWS, PyBraket

withenv("JULIA_CONDAPKG_VERBOSITY"=>"-1") do
Aqua.test_all(PyBraket, ambiguities=false, unbound_args=false, piracy=false)
Aqua.test_ambiguities(PyBraket)
end
Aqua.test_all(PyBraket, ambiguities=false, unbound_args=false, piracies=false, persistent_tasks=false)
Aqua.test_ambiguities(PyBraket)

function set_aws_creds(test_type)
if test_type == "unit"
Expand Down
1 change: 1 addition & 0 deletions docs/src/device.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
```@docs
Device
AwsDevice
Braket.BraketDevice
isavailable
search_devices
get_devices
Expand Down
1 change: 1 addition & 0 deletions docs/src/jobs.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ log_metric
metrics
logs
download_result
@hybrid_job
```
3 changes: 2 additions & 1 deletion src/Braket.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ export Tracker, simulator_tasks_cost, qpu_tasks_cost
export arn, cancel, state, result, results, name, download_result, id, ir, isavailable, search_devices, get_devices
export provider_name, properties, type
export apply_gate_noise!, apply
export logs, log_metric, metrics
export logs, log_metric, metrics, @hybrid_job
export depth, qubit_count, qubits, ir, IRType, OpenQASMSerializationProperties
export OpenQasmProgram
export QueueDepthInfo, QueueType, Normal, Priority, queue_depth, queue_position
Expand Down Expand Up @@ -133,6 +133,7 @@ include("device.jl")
include("gate_applicators.jl")
include("noise_applicators.jl")
include("jobs.jl")
include("job_macro.jl")
include("aws_jobs.jl")
include("local_jobs.jl")
include("task.jl")
Expand Down
136 changes: 52 additions & 84 deletions src/aws_jobs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,15 @@ function _parse_query_results(results::Vector, metric_type::String=TIMESTAMP, st
end
mergewith!(vcat, metrics_dict, parsed_metrics)
end
@debug "results: $results, metrics_dict: $metrics_dict"
p = sortperm(metrics_dict[sortby])
expanded_metrics_dict = Dict{String, Vector}(sortby=>metrics_dict[sortby][p])
for k in filter(k->k!=sortby, keys(metrics_dict))
expanded_metrics_dict[k] = []
for ind in expanded_metrics_dict[sortby]
ix = findfirst(local_ind -> local_ind == ind, first.(metrics_dict[k]))
push!(expanded_metrics_dict[k], isnothing(ix) ? missing : last(metrics_dict[k][ix]))
val = popat!(metrics_dict[k], ix, (missing,))
push!(expanded_metrics_dict[k], last(val))
end
end
return expanded_metrics_dict
Expand Down Expand Up @@ -309,12 +311,6 @@ function download_result(j::AwsQuantumJob; extract_to=pwd(), poll_timeout_second
return ""
end

# Convert persisted job result values back into Julia objects.
# Only the `plaintext` format is handled; any other format raises, since
# unpickling of serialized results is not implemented yet.
function deserialize_values(data_dictionary::Dict{String, Any}, data_format::PersistedJobDataFormat)
    if data_format == plaintext
        return data_dictionary
    else
        throw(ArgumentError("unpickling results not yet supported!"))
    end
end
# String form of the format: resolve it to the enum value, then dispatch.
deserialize_values(data_dictionary::Dict{String, Any}, data_format::String) = deserialize_values(data_dictionary, PersistedJobDataFormatDict[data_format])

function _read_and_deserialize_results(j::AwsQuantumJob, loc::String)
try
open(joinpath(loc, RESULTS_FILENAME), "r") do f
Expand Down Expand Up @@ -409,19 +405,23 @@ function _tar_and_upload(source_module_path::String, code_location::String)
end

function _process_input_data(input_data::Dict, job_name::String)
@debug "_process_input_data Input data: $input_data"
processed_input_data = Dict{String, Any}()
for (k, v) in filter(x->!(x.second isa S3DataSourceConfig), input_data)
input_data[k] = _process_channel(v, job_name, k)
processed_input_data[k] = _process_channel(v, job_name, k)
end
return [merge(Dict("channelName"=>k), v.config) for (k,v) in input_data]
return [merge(Dict("channelName"=>k), v.config) for (k,v) in processed_input_data]
end
_process_input_data(input_data, job_name::String) = _process_input_data(Dict{String, Any}("input_data"=>input_data), job_name)

function _process_channel(loc::String, job_name::String, channel_name::String)
is_s3_uri(loc) && return S3DataSourceConfig(loc)
loc_name = splitdir(loc)[2]
s3_prefix = construct_s3_uri(default_bucket(), "jobs", job_name, "data", channel_name, loc_name)
@debug "Uploading input data for channel $channel_name from $loc to s3 $s3_prefix with loc_name $loc_name"
upload_local_data(loc, s3_prefix)
return S3DataSourceConfig(s3_prefix)
suffixed_prefix = isdir(loc) ? s3_prefix * "/" : s3_prefix
return S3DataSourceConfig(suffixed_prefix)
end

function _get_default_jobs_role()
Expand All @@ -430,70 +430,72 @@ function _get_default_jobs_role()
response = IAM.list_roles(params, aws_config=AWS.AWSConfig(creds=global_conf.credentials, region="us-east-1", output=global_conf.output))
roles = response["ListRolesResult"]["Roles"]["member"]
for role in roles
startswith("AmazonBraketJobsExecutionRole", role["RoleName"]) && return role["Arn"]
startswith(role["RoleName"], "AmazonBraketJobsExecutionRole") && return role["Arn"]
end
throw(ErrorException("No default jobs roles found. Please create a role using the Amazon Braket console or supply a custom role."))
end

function prepare_quantum_job(
device::String,
source_module::String,
entry_point::String,
image_uri::String,
job_name::String,
code_location::String,
role_arn::String,
hyperparameters::Dict{String, <:Any},
input_data::Union{String, Dict},
instance_config::InstanceConfig,
distribution::String,
stopping_condition::StoppingCondition,
output_data_config::OutputDataConfig,
copy_checkpoints_from_job::String,
checkpoint_config::CheckpointConfig,
tags::Dict{String, String},
)
hyperparams = Dict(zip(keys(hyperparameters), map(string, values(hyperparameters))))
input_data_list = _process_input_data(input_data, job_name)
function prepare_quantum_job(device::String, source_module::String, j_opts::JobsOptions)
hyperparams = Dict(zip(keys(j_opts.hyperparameters), map(string, values(j_opts.hyperparameters))))
@debug "Job input data: $(j_opts.input_data)"
@debug "\n\n"
input_data_list = _process_input_data(j_opts.input_data, j_opts.job_name)
entry_point = j_opts.entry_point
if is_s3_uri(source_module)
_process_s3_source_module(source_module, entry_point, code_location)
_process_s3_source_module(source_module, j_opts.entry_point, j_opts.code_location)
else
entry_point = _process_local_source_module(source_module, entry_point, code_location)
entry_point = _process_local_source_module(source_module, j_opts.entry_point, j_opts.code_location)
end
algo_spec = Dict("scriptModeConfig"=>OrderedDict("entryPoint"=>entry_point,
"s3Uri"=>code_location*"/source.tar.gz",
"s3Uri"=>j_opts.code_location*"/source.tar.gz",
"compressionType"=>"GZIP"))

!isempty(image_uri) && setindex!(algo_spec, Dict("uri"=>image_uri), "containerImage")
if !isempty(copy_checkpoints_from_job)
checkpoints_to_copy = get_job(copy_checkpoints_from_job)["checkpointConfig"]["s3Uri"]
copy_s3_directory(checkpoints_to_copy, checkpoint_config.s3Uri)
!isempty(j_opts.image_uri) && setindex!(algo_spec, Dict("uri"=>j_opts.image_uri), "containerImage")
if !isempty(j_opts.copy_checkpoints_from_job)
checkpoints_to_copy = get_job(j_opts.copy_checkpoints_from_job)["checkpointConfig"]["s3Uri"]
copy_s3_directory(checkpoints_to_copy, j_opts.checkpoint_config.s3Uri)
end
if distribution == "data_parallel"
if j_opts.distribution == "data_parallel"
merge!(hyperparams, Dict("sagemaker_distributed_dataparallel_enabled"=>"true",
"sagemaker_instance_type"=>instance_config.instanceType))
"sagemaker_instance_type"=>j_opts.instance_config.instanceType))
end

params = OrderedDict(
"checkpointConfig"=>Dict(checkpoint_config),
"checkpointConfig"=>Dict(j_opts.checkpoint_config),
"hyperParameters"=>hyperparams,
"inputDataConfig"=>input_data_list,
"stoppingCondition"=>Dict(stopping_condition),
"tags"=>tags,
"stoppingCondition"=>Dict(j_opts.stopping_condition),
"tags"=>j_opts.tags,
)
token = string(uuid1())
dev_conf = Dict(DeviceConfig(device))
inst_conf = Dict(instance_config)
out_conf = Dict(output_data_config)
return (algo_spec=algo_spec, token=token, dev_conf=dev_conf, inst_conf=inst_conf, job_name=job_name, out_conf=out_conf, role_arn=role_arn, params=params)
inst_conf = Dict(j_opts.instance_config)
out_conf = Dict(j_opts.output_data_config)
return (algo_spec=algo_spec, token=token, dev_conf=dev_conf, inst_conf=inst_conf, job_name=j_opts.job_name, out_conf=out_conf, role_arn=j_opts.role_arn, params=params)
end

# Assemble the full `create_job` payload from `job_opts` via
# `prepare_quantum_job`, submit it to the Braket service, and wrap the
# returned job ARN in an `AwsQuantumJob`. When `job_opts.wait_until_complete`
# is set, block streaming the job's logs before returning.
function AwsQuantumJob(device::String, source_module::String, job_opts::JobsOptions)
    prepared = prepare_quantum_job(device, source_module, job_opts)
    response = BRAKET.create_job(prepared[:algo_spec],
                                 prepared[:token],
                                 prepared[:dev_conf],
                                 prepared[:inst_conf],
                                 prepared[:job_name],
                                 prepared[:out_conf],
                                 prepared[:role_arn],
                                 prepared[:params])
    job = AwsQuantumJob(response["jobArn"])
    job_opts.wait_until_complete && logs(job, wait=true)
    return job
end

"""
AwsQuantumJob(device::String, source_module::String; kwargs...)
AwsQuantumJob(device::Union{String, BraketDevice}, source_module::String; kwargs...)
Create and launch an `AwsQuantumJob` which will use device `device` (a managed simulator, a QPU, or an [embedded simulator](https://docs.aws.amazon.com/braket/latest/developerguide/pennylane-embedded-simulators.html))
and will run the code (either a single file, or a Julia package, or a Python module) located at `source_module`. The keyword arguments
`kwargs` control the launch configuration of the job.
`kwargs` control the launch configuration of the job. `device` can be either the device's ARN as a `String`, or a [`BraketDevice`](@ref).
# Keyword Arguments
- `entry_point::String` - the function to run in `source_module` if `source_module` is a Python module/Julia package. Defaults to an empty string, in which case
Expand Down Expand Up @@ -531,39 +533,5 @@ and will run the code (either a single file, or a Julia package, or a Python mod
The default is `CheckpointConfig("/opt/jobs/checkpoints", "s3://{default_bucket_name}/jobs/{job_name}/checkpoints")`.
- `tags::Dict{String, String}` - specifies the key-value pairs for tagging this job.
"""
function AwsQuantumJob(
device::String,
source_module::String;
entry_point::String="",
image_uri::String="",
job_name::String=_generate_default_job_name(image_uri),
code_location::String=construct_s3_uri(default_bucket(), "jobs", job_name, "script"),
role_arn::String=get(ENV, "BRAKET_JOBS_ROLE_ARN", _get_default_jobs_role()),
wait_until_complete::Bool=false,
hyperparameters::Dict{String, <:Any}=Dict{String, Any}(),
input_data::Union{String, Dict} = Dict(),
instance_config::InstanceConfig = InstanceConfig(),
distribution::String="",
stopping_condition::StoppingCondition = StoppingCondition(),
output_data_config::OutputDataConfig = OutputDataConfig(job_name=job_name),
copy_checkpoints_from_job::String="",
checkpoint_config::CheckpointConfig = CheckpointConfig(job_name),
tags::Dict{String, String}=Dict{String, String}(),
)
args = prepare_quantum_job(device, source_module, entry_point, image_uri, job_name, code_location,
role_arn, hyperparameters, input_data, instance_config, distribution,
stopping_condition, output_data_config, copy_checkpoints_from_job, checkpoint_config, tags)
algo_spec = args[:algo_spec]
token = args[:token]
dev_conf = args[:dev_conf]
inst_conf = args[:inst_conf]
job_name = args[:job_name]
out_conf = args[:out_conf]
role_arn = args[:role_arn]
params = args[:params]
response = BRAKET.create_job(algo_spec, token, dev_conf, inst_conf, job_name, out_conf, role_arn, params)
job = AwsQuantumJob(response["jobArn"])
wait_until_complete && logs(job, wait=true)
return job
end
AwsQuantumJob(device::BraketDevice, source_module::String; kwargs...) = AwsQuantumJob(convert(String, device), source_module; kwargs...)
# Keyword-argument entry points: pack `kwargs` into a `JobsOptions` and
# forward to the three-argument method above.
function AwsQuantumJob(device::String, source_module::String; kwargs...)
    return AwsQuantumJob(device, source_module, JobsOptions(; kwargs...))
end
# A `BraketDevice` is accepted by first converting it to its ARN string.
function AwsQuantumJob(device::BraketDevice, source_module::String; kwargs...)
    return AwsQuantumJob(convert(String, device), source_module, JobsOptions(; kwargs...))
end
15 changes: 15 additions & 0 deletions src/device.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,27 @@ const _GET_DEVICES_ORDER_BY_KEYS = Set(("arn", "name", "type", "provider_name",
@enum AwsDeviceType SIMULATOR QPU
const AwsDeviceTypeDict = Dict("SIMULATOR"=>SIMULATOR, "QPU"=>QPU)

"""
BraketDevice
An abstract type representing one of the devices available on Amazon Braket, which will automatically
generate its ARN when passed to the appropriate function.
# Examples
```jldoctest
julia> d = Braket.SV1()
julia> arn(d)
"arn:aws:braket:::device/quantum-simulator/amazon/sv1"
```
"""
abstract type BraketDevice end
for provider in (:AmazonDevice, :_XanaduDevice, :_DWaveDevice, :OQCDevice, :QuEraDevice, :IonQDevice, :RigettiDevice)
@eval begin
abstract type $provider <: BraketDevice end
end
end
arn(d::BraketDevice) = convert(String, d)

for (d, d_arn) in zip((:SV1, :DM1, :TN1), ("sv1", "dm1", "tn1"))
@eval begin
Expand Down
Loading

0 comments on commit db0e405

Please sign in to comment.