From db0e405a1046573d2226f0152587d1822969bc4d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt <67932820+kshyatt-aws@users.noreply.github.com> Date: Thu, 16 Nov 2023 11:30:32 -0500 Subject: [PATCH] feature: Add support for job decorator macro (#67) * feature: Initial implementation of macro for hybrid job launch --- Project.toml | 19 +- PyBraket/Project.toml | 4 +- PyBraket/test/runtests.jl | 6 +- docs/src/device.md | 1 + docs/src/jobs.md | 1 + src/Braket.jl | 3 +- src/aws_jobs.jl | 136 +++---- src/device.jl | 15 + src/job_macro.jl | 363 ++++++++++++++++++ src/jobs.jl | 64 +++ src/local_jobs.jl | 106 ++--- src/utils.jl | 20 +- test/integ_tests/JobProject.toml | 3 + test/integ_tests/create_quantum_job.jl | 2 +- test/integ_tests/job_macro.jl | 50 +++ .../job_test_submodule/Project.toml | 2 + .../job_test_submodule_file.jl | 8 + .../job_test_submodule/requirements.txt | 1 + test/integ_tests/job_test_script.jl | 4 + test/integ_tests/requirements.txt | 2 + test/integ_tests/runtests.jl | 1 + test/job_macro.jl | 46 +++ test/local_jobs.jl | 50 ++- test/runtests.jl | 5 +- 24 files changed, 739 insertions(+), 173 deletions(-) create mode 100644 src/job_macro.jl create mode 100644 test/integ_tests/JobProject.toml create mode 100644 test/integ_tests/job_macro.jl create mode 100644 test/integ_tests/job_test_module/job_test_submodule/Project.toml create mode 100644 test/integ_tests/job_test_module/job_test_submodule/job_test_submodule_file.jl create mode 100644 test/integ_tests/job_test_module/job_test_submodule/requirements.txt create mode 100644 test/integ_tests/job_test_script.jl create mode 100644 test/integ_tests/requirements.txt create mode 100644 test/job_macro.jl diff --git a/Project.toml b/Project.toml index 0b5a6aaa..613a2340 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ AWSS3 = "1c724243-ef5b-51ab-93f4-b0a88ac62a95" AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +CodeTracking = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" @@ -18,6 +19,7 @@ Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -32,22 +34,37 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [compat] AWS = "=1.90.3" AWSS3 = "=0.11.2" -Aqua = "=0.6" +Aqua = "=0.8" +Base64 = "1.6" AxisArrays = "=0.4.7" CSV = "=0.10.11" +CodeTracking = "=1.3.5" Compat = "=4.10.0" DataStructures = "=0.18.15" +Dates = "1.6" DecFP = "=1.3.2" Distributions = "=0.25.76" +Distributed = "1.6" +Downloads = "1" Graphs = "=1.9.0" HTTP = "=1.10.0" +InteractiveUtils = "1.6" +JLD2 = "=0.4.38" JSON3 = "=1.13.2" +LinearAlgebra = "1.6" +Logging = "1.6" Markdown = "=0.7.5" Mocking = "=0.7.6" NamedTupleTools = "=0.14.3" OrderedCollections = "=1.6.2" +Pkg = "1.6" +Random = "1.6" +Statistics = "1.6" +SparseArrays = "1.6" StructTypes = "=1.10.0" Tar = "1.9.3" +Test = "1.6" +UUIDs = "1.6" julia = "1.6" [extras] diff --git a/PyBraket/Project.toml b/PyBraket/Project.toml index a14fd15e..186c94a2 100644 --- a/PyBraket/Project.toml +++ b/PyBraket/Project.toml @@ -13,13 +13,15 @@ 
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" [compat] -Aqua = "=0.6" +Aqua = "=0.8" Braket = "=0.7.6" CondaPkg = "=0.2.21" DataStructures = "=0.18.15" +LinearAlgebra = "1.6" PythonCall = "=0.9.14" Statistics = "1" StructTypes = "=1.10.0" +Test = "1.6" julia = "1.6" [extras] diff --git a/PyBraket/test/runtests.jl b/PyBraket/test/runtests.jl index 1b95d3eb..ef624576 100644 --- a/PyBraket/test/runtests.jl +++ b/PyBraket/test/runtests.jl @@ -1,9 +1,7 @@ using Test, Aqua, Braket, Braket.AWS, PyBraket -withenv("JULIA_CONDAPKG_VERBOSITY"=>"-1") do - Aqua.test_all(PyBraket, ambiguities=false, unbound_args=false, piracy=false) - Aqua.test_ambiguities(PyBraket) -end +Aqua.test_all(PyBraket, ambiguities=false, unbound_args=false, piracies=false, persistent_tasks=false) +Aqua.test_ambiguities(PyBraket) function set_aws_creds(test_type) if test_type == "unit" diff --git a/docs/src/device.md b/docs/src/device.md index ff3a8617..f1bab3a6 100644 --- a/docs/src/device.md +++ b/docs/src/device.md @@ -3,6 +3,7 @@ ```@docs Device AwsDevice +Braket.BraketDevice isavailable search_devices get_devices diff --git a/docs/src/jobs.md b/docs/src/jobs.md index e3b8e837..6905b716 100644 --- a/docs/src/jobs.md +++ b/docs/src/jobs.md @@ -14,4 +14,5 @@ log_metric metrics logs download_result +@hybrid_job ``` diff --git a/src/Braket.jl b/src/Braket.jl index 607a7ec5..a56e810a 100644 --- a/src/Braket.jl +++ b/src/Braket.jl @@ -6,7 +6,7 @@ export Tracker, simulator_tasks_cost, qpu_tasks_cost export arn, cancel, state, result, results, name, download_result, id, ir, isavailable, search_devices, get_devices export provider_name, properties, type export apply_gate_noise!, apply -export logs, log_metric, metrics +export logs, log_metric, metrics, @hybrid_job export depth, qubit_count, qubits, ir, IRType, OpenQASMSerializationProperties export OpenQasmProgram export QueueDepthInfo, QueueType, Normal, Priority, queue_depth, queue_position @@ -133,6 +133,7 @@ include("device.jl") include("gate_applicators.jl") include("noise_applicators.jl") include("jobs.jl") +include("job_macro.jl") include("aws_jobs.jl") include("local_jobs.jl") include("task.jl") diff --git a/src/aws_jobs.jl b/src/aws_jobs.jl index e67cb453..2e3bc9e1 100644 --- a/src/aws_jobs.jl +++ b/src/aws_jobs.jl @@ -186,13 +186,15 @@ function _parse_query_results(results::Vector, metric_type::String=TIMESTAMP, st end mergewith!(vcat, metrics_dict, parsed_metrics) end + @debug "results: $results, metrics_dict: $metrics_dict" p = sortperm(metrics_dict[sortby]) expanded_metrics_dict = Dict{String, Vector}(sortby=>metrics_dict[sortby][p]) for k in filter(k->k!=sortby, keys(metrics_dict)) expanded_metrics_dict[k] = [] for ind in expanded_metrics_dict[sortby] ix = findfirst(local_ind -> local_ind == ind, first.(metrics_dict[k])) - push!(expanded_metrics_dict[k], isnothing(ix) ? 
missing : last(metrics_dict[k][ix])) + val = popat!(metrics_dict[k], ix, (missing,)) + push!(expanded_metrics_dict[k], last(val)) end end return expanded_metrics_dict @@ -309,12 +311,6 @@ function download_result(j::AwsQuantumJob; extract_to=pwd(), poll_timeout_second return "" end -function deserialize_values(data_dictionary::Dict{String, Any}, data_format::PersistedJobDataFormat) - data_format == plaintext && return data_dictionary - throw(ArgumentError("unpickling results not yet supported!")) -end -deserialize_values(data_dictionary::Dict{String, Any}, data_format::String) = deserialize_values(data_dictionary, PersistedJobDataFormatDict[data_format]) - function _read_and_deserialize_results(j::AwsQuantumJob, loc::String) try open(joinpath(loc, RESULTS_FILENAME), "r") do f @@ -409,10 +405,12 @@ function _tar_and_upload(source_module_path::String, code_location::String) end function _process_input_data(input_data::Dict, job_name::String) + @debug "_process_input_data Input data: $input_data" + processed_input_data = Dict{String, Any}() for (k, v) in filter(x->!(x.second isa S3DataSourceConfig), input_data) - input_data[k] = _process_channel(v, job_name, k) + processed_input_data[k] = _process_channel(v, job_name, k) end - return [merge(Dict("channelName"=>k), v.config) for (k,v) in input_data] + return [merge(Dict("channelName"=>k), v.config) for (k,v) in processed_input_data] end _process_input_data(input_data, job_name::String) = _process_input_data(Dict{String, Any}("input_data"=>input_data), job_name) @@ -420,8 +418,10 @@ function _process_channel(loc::String, job_name::String, channel_name::String) is_s3_uri(loc) && return S3DataSourceConfig(loc) loc_name = splitdir(loc)[2] s3_prefix = construct_s3_uri(default_bucket(), "jobs", job_name, "data", channel_name, loc_name) + @debug "Uploading input data for channel $channel_name from $loc to s3 $s3_prefix with loc_name $loc_name" upload_local_data(loc, s3_prefix) - return S3DataSourceConfig(s3_prefix) + suffixed_prefix = isdir(loc) ? s3_prefix * "/" : s3_prefix + return S3DataSourceConfig(suffixed_prefix) end function _get_default_jobs_role() @@ -430,70 +430,72 @@ function _get_default_jobs_role() response = IAM.list_roles(params, aws_config=AWS.AWSConfig(creds=global_conf.credentials, region="us-east-1", output=global_conf.output)) roles = response["ListRolesResult"]["Roles"]["member"] for role in roles - startswith("AmazonBraketJobsExecutionRole", role["RoleName"]) && return role["Arn"] + startswith(role["RoleName"], "AmazonBraketJobsExecutionRole") && return role["Arn"] end throw(ErrorException("No default jobs roles found. 
Please create a role using the Amazon Braket console or supply a custom role.")) end -function prepare_quantum_job( - device::String, - source_module::String, - entry_point::String, - image_uri::String, - job_name::String, - code_location::String, - role_arn::String, - hyperparameters::Dict{String, <:Any}, - input_data::Union{String, Dict}, - instance_config::InstanceConfig, - distribution::String, - stopping_condition::StoppingCondition, - output_data_config::OutputDataConfig, - copy_checkpoints_from_job::String, - checkpoint_config::CheckpointConfig, - tags::Dict{String, String}, - ) - hyperparams = Dict(zip(keys(hyperparameters), map(string, values(hyperparameters)))) - input_data_list = _process_input_data(input_data, job_name) +function prepare_quantum_job(device::String, source_module::String, j_opts::JobsOptions) + hyperparams = Dict(zip(keys(j_opts.hyperparameters), map(string, values(j_opts.hyperparameters)))) + @debug "Job input data: $(j_opts.input_data)" + @debug "\n\n" + input_data_list = _process_input_data(j_opts.input_data, j_opts.job_name) + entry_point = j_opts.entry_point if is_s3_uri(source_module) - _process_s3_source_module(source_module, entry_point, code_location) + _process_s3_source_module(source_module, j_opts.entry_point, j_opts.code_location) else - entry_point = _process_local_source_module(source_module, entry_point, code_location) + entry_point = _process_local_source_module(source_module, j_opts.entry_point, j_opts.code_location) end algo_spec = Dict("scriptModeConfig"=>OrderedDict("entryPoint"=>entry_point, - "s3Uri"=>code_location*"/source.tar.gz", + "s3Uri"=>j_opts.code_location*"/source.tar.gz", "compressionType"=>"GZIP")) - !isempty(image_uri) && setindex!(algo_spec, Dict("uri"=>image_uri), "containerImage") - if !isempty(copy_checkpoints_from_job) - checkpoints_to_copy = get_job(copy_checkpoints_from_job)["checkpointConfig"]["s3Uri"] - copy_s3_directory(checkpoints_to_copy, checkpoint_config.s3Uri) + !isempty(j_opts.image_uri) && setindex!(algo_spec, Dict("uri"=>j_opts.image_uri), "containerImage") + if !isempty(j_opts.copy_checkpoints_from_job) + checkpoints_to_copy = get_job(j_opts.copy_checkpoints_from_job)["checkpointConfig"]["s3Uri"] + copy_s3_directory(checkpoints_to_copy, j_opts.checkpoint_config.s3Uri) end - if distribution == "data_parallel" + if j_opts.distribution == "data_parallel" merge!(hyperparams, Dict("sagemaker_distributed_dataparallel_enabled"=>"true", - "sagemaker_instance_type"=>instance_config.instanceType)) + "sagemaker_instance_type"=>j_opts.instance_config.instanceType)) end params = OrderedDict( - "checkpointConfig"=>Dict(checkpoint_config), + "checkpointConfig"=>Dict(j_opts.checkpoint_config), "hyperParameters"=>hyperparams, "inputDataConfig"=>input_data_list, - "stoppingCondition"=>Dict(stopping_condition), - "tags"=>tags, + "stoppingCondition"=>Dict(j_opts.stopping_condition), + "tags"=>j_opts.tags, ) token = string(uuid1()) dev_conf = Dict(DeviceConfig(device)) - inst_conf = Dict(instance_config) - out_conf = Dict(output_data_config) - return (algo_spec=algo_spec, token=token, dev_conf=dev_conf, inst_conf=inst_conf, job_name=job_name, out_conf=out_conf, role_arn=role_arn, params=params) + inst_conf = Dict(j_opts.instance_config) + out_conf = Dict(j_opts.output_data_config) + return (algo_spec=algo_spec, token=token, dev_conf=dev_conf, inst_conf=inst_conf, job_name=j_opts.job_name, out_conf=out_conf, role_arn=j_opts.role_arn, params=params) +end + +function AwsQuantumJob(device::String, source_module::String, 
job_opts::JobsOptions) + args = prepare_quantum_job(device, source_module, job_opts) + algo_spec = args[:algo_spec] + token = args[:token] + dev_conf = args[:dev_conf] + inst_conf = args[:inst_conf] + job_name = args[:job_name] + out_conf = args[:out_conf] + role_arn = args[:role_arn] + params = args[:params] + response = BRAKET.create_job(algo_spec, token, dev_conf, inst_conf, job_name, out_conf, role_arn, params) + job = AwsQuantumJob(response["jobArn"]) + job_opts.wait_until_complete && logs(job, wait=true) + return job end """ - AwsQuantumJob(device::String, source_module::String; kwargs...) + AwsQuantumJob(device::Union{String, BraketDevice}, source_module::String; kwargs...) Create and launch an `AwsQuantumJob` which will use device `device` (a managed simulator, a QPU, or an [embedded simulator](https://docs.aws.amazon.com/braket/latest/developerguide/pennylane-embedded-simulators.html)) and will run the code (either a single file, or a Julia package, or a Python module) located at `source_module`. The keyword arguments -`kwargs` control the launch configuration of the job. +`kwargs` control the launch configuration of the job. `device` can be either the device's ARN as a `String`, or a [`BraketDevice`](@ref). # Keyword Arguments - `entry_point::String` - the function to run in `source_module` if `source_module` is a Python module/Julia package. Defaults to an empty string, in which case @@ -531,39 +533,5 @@ and will run the code (either a single file, or a Julia package, or a Python mod The default is `CheckpointConfig("/opt/jobs/checkpoints", "s3://{default_bucket_name}/jobs/{job_name}/checkpoints")`. - `tags::Dict{String, String}` - specifies the key-value pairs for tagging this job. """ -function AwsQuantumJob( - device::String, - source_module::String; - entry_point::String="", - image_uri::String="", - job_name::String=_generate_default_job_name(image_uri), - code_location::String=construct_s3_uri(default_bucket(), "jobs", job_name, "script"), - role_arn::String=get(ENV, "BRAKET_JOBS_ROLE_ARN", _get_default_jobs_role()), - wait_until_complete::Bool=false, - hyperparameters::Dict{String, <:Any}=Dict{String, Any}(), - input_data::Union{String, Dict} = Dict(), - instance_config::InstanceConfig = InstanceConfig(), - distribution::String="", - stopping_condition::StoppingCondition = StoppingCondition(), - output_data_config::OutputDataConfig = OutputDataConfig(job_name=job_name), - copy_checkpoints_from_job::String="", - checkpoint_config::CheckpointConfig = CheckpointConfig(job_name), - tags::Dict{String, String}=Dict{String, String}(), - ) - args = prepare_quantum_job(device, source_module, entry_point, image_uri, job_name, code_location, - role_arn, hyperparameters, input_data, instance_config, distribution, - stopping_condition, output_data_config, copy_checkpoints_from_job, checkpoint_config, tags) - algo_spec = args[:algo_spec] - token = args[:token] - dev_conf = args[:dev_conf] - inst_conf = args[:inst_conf] - job_name = args[:job_name] - out_conf = args[:out_conf] - role_arn = args[:role_arn] - params = args[:params] - response = BRAKET.create_job(algo_spec, token, dev_conf, inst_conf, job_name, out_conf, role_arn, params) - job = AwsQuantumJob(response["jobArn"]) - wait_until_complete && logs(job, wait=true) - return job -end -AwsQuantumJob(device::BraketDevice, source_module::String; kwargs...) = AwsQuantumJob(convert(String, device), source_module; kwargs...) +AwsQuantumJob(device::String, source_module::String; kwargs...) 
= AwsQuantumJob(device, source_module, JobsOptions(; kwargs...)) +AwsQuantumJob(device::BraketDevice, source_module::String; kwargs...) = AwsQuantumJob(convert(String, device), source_module, JobsOptions(; kwargs...)) diff --git a/src/device.jl b/src/device.jl index 0b87687f..3b402bce 100644 --- a/src/device.jl +++ b/src/device.jl @@ -8,12 +8,27 @@ const _GET_DEVICES_ORDER_BY_KEYS = Set(("arn", "name", "type", "provider_name", @enum AwsDeviceType SIMULATOR QPU const AwsDeviceTypeDict = Dict("SIMULATOR"=>SIMULATOR, "QPU"=>QPU) +""" + BraketDevice + +An abstract type representing one of the devices available on Amazon Braket, which will automatically +generate its ARN when passed to the appropriate function. + +# Examples +```jldoctest +julia> d = Braket.SV1() + +julia> arn(d) +"arn:aws:braket:::device/quantum-simulator/amazon/sv1" +``` +""" abstract type BraketDevice end for provider in (:AmazonDevice, :_XanaduDevice, :_DWaveDevice, :OQCDevice, :QuEraDevice, :IonQDevice, :RigettiDevice) @eval begin abstract type $provider <: BraketDevice end end end +arn(d::BraketDevice) = convert(String, d) for (d, d_arn) in zip((:SV1, :DM1, :TN1), ("sv1", "dm1", "tn1")) @eval begin diff --git a/src/job_macro.jl b/src/job_macro.jl new file mode 100644 index 00000000..41229c05 --- /dev/null +++ b/src/job_macro.jl @@ -0,0 +1,363 @@ +using CodeTracking, JLD2 + +"""Sanitize forbidden characters from hyperparameter strings""" +function _sanitize(hyperparameter::String) + # replace forbidden characters with close matches + # , not technically forbidden, but to avoid mismatched parens + if VERSION >= v"1.7" + sanitized = replace(hyperparameter, "\n"=>" ", "\$"=>"?", "("=>"{", "&"=>"+", "`"=>"'", ")"=>"}") + else + sanitized = deepcopy(hyperparameter) + for pat in ["\n"=>" ", "\$"=>"?", "("=>"{", "&"=>"+", "`"=>"'", ")"=>"}"] + sanitized = replace(sanitized, pat) + end + end + # max allowed length for a hyperparameter is 2500 + # show as much as possible, including the final 20 characters + length(sanitized) > 2500 && return "$(sanitized[1:2500-23])...$(sanitized[end-19:end])" + return sanitized +end + + +function _kwarg_to_string(kw) + val = kw[2] isa String ? "\""*kw[2]*"\"" : string(kw[2]) + return string(kw[1])*"="*val +end + +""" +Captures the arguments and keyword arguments of +`entry_point` as hyperparameters for the Job. +""" +function _log_hyperparameters(f_args, f_kwargs) + # dummy function, as we are now using the input data JLD2 file for this + hyperparams = Dict{String, String}() + #sanitized_hyperparameters = Dict{String, String}(name=>_sanitize(param) for (name, param) in hyperparams) + return hyperparams +end + +function matches(prefix::String) + possible_paths = readdir(dirname(abspath(prefix))) + possible_paths = isabspath(prefix) ? 
[joinpath(dirname(abspath(prefix)), path) for path in possible_paths] : possible_paths + prefixed_paths = filter(path->startswith(path, prefix), possible_paths) + @debug "Possible paths: $possible_paths" + @debug "Prefix: $prefix, Prefixed paths: $prefixed_paths" + @debug "\n\n" + return prefixed_paths +end +is_prefix(path::String) = length(matches(path)) > 1 || !ispath(path) +is_prefix(path) = false +function _process_input_data(input_data::Union{String, Dict}) + isempty(input_data) && (input_data = Dict()) + input_data isa Dict || (input_data = Dict("input"=>input_data)) + prefix_channels = Set{String}() + directory_channels = Set{String}() + file_channels = Set{String}() + @debug "All input data channels: $(collect(keys(input_data)))" + for (channel, data) in input_data + @debug "Channel $channel, data $data" + if is_s3_uri(string(data)) + channel_arg = channel != "input" ? "channel=$channel" : "" + @warn "Input data channels mapped to an S3 source will not be available in the working directory. Use `get_input_data_dir($channel_arg)` to read input data from S3 source inside the job container." + elseif is_prefix(data) + @debug "prefix channel" + union!(prefix_channels, [channel]) + elseif isdir(data) + @debug "dir channel" + union!(directory_channels, [channel]) + else + @debug "file channel" + union!(file_channels, [channel]) + end + end + @debug "Generating prefix matches" + for channel in prefix_channels + @debug "Channel: $channel" + @debug "Data: $(input_data[channel])" + @debug "Matches: $( matches(input_data[channel]) )" + end + prefix_matches = Dict(channel=>matches(input_data[channel]) for channel in prefix_channels) + prefix_matches_str = "{" * join(["$k: $v" for (k,v) in prefix_matches] , ", ") * "}" + @debug "Prefix matches: $prefix_matches" + @debug "Prefix matches string: $prefix_matches_str" + + prefix_channels_str = "{" * join(prefix_channels, ", ") * "}" + directory_channels_str = "{" * join(["\"" * d * "\"" for d in directory_channels], ", ") * "}" + @debug "Directory channels: $directory_channels" + @debug "Directory channels string: $directory_channels_str" + input_data_items = [(channel, relpath(input_data[channel])) for channel in filter(ch->ch∈union(prefix_channels, directory_channels, file_channels), collect(keys(input_data)))] + @debug "Input data items: $input_data_items" + @debug "\n\n" + input_data_items_str = "[" * join(string.(input_data_items), ", ") * "]" + return """ + from pathlib import Path + from braket.jobs import get_input_data_dir + + + \"\"\"Create symlink from input_link_path to input_data_path.\"\"\" + def make_link(input_link_path, input_data_path, links): + input_link_path.parent.mkdir(parents=True, exist_ok=True) + input_link_path.symlink_to(input_data_path) + print(input_link_path, '->', input_data_path) + links[input_link_path] = input_data_path + + + def link_input(): + links = {} + dirs = set() + # map of data sources to lists of matched local files + prefix_matches = $prefix_matches_str + + for channel, data in $input_data_items_str: + + if channel in $prefix_channels_str: + # link all matched files + for input_link_name in prefix_matches[channel]: + input_link_path = Path(input_link_name) + input_data_path = Path(get_input_data_dir(channel)) / input_link_path.name + make_link(input_link_path, input_data_path, links) + + else: + input_link_path = Path(data) + if channel in $directory_channels_str: + # link directory source directly to input channel directory + input_data_path = Path(get_input_data_dir(channel)) + else: + # link file 
source to file within input channel directory + input_data_path = Path(get_input_data_dir(channel), Path(data).name) + + make_link(input_link_path, input_data_path, links) + + return links + + + def clean_links(links): + for link, target in links.items(): + if link.is_symlink and link.readlink() == target: + link.unlink() + + if link.is_relative_to(Path()): + for dir in link.parents[:-1]: + try: + dir.rmdir() + except: + # directory not empty + pass + """ +end + +function _serialize_function(f_name::String, f_source::String, included_pkgs::Union{String, Vector{String}}="", included_jl_files::Union{String, Vector{String}}="") + using_list = isempty(included_pkgs) ? "JLD2, Braket" : join(vcat(included_pkgs, ["JLD2", "Braket"]), ", ") + included_jl_files = included_jl_files isa String ? [included_jl_files] : included_jl_files + return """ +import os +import json +from juliacall import Main as jl +from juliacall import Pkg as jlPkg +from braket.jobs import get_results_dir, save_job_result +from braket.jobs_data import PersistedJobDataFormat + +jlPkg.activate(".") +jlPkg.instantiate() +jl.seval(f'using $using_list') + +input_file_dir = get_input_data_dir("jl_include_files") +for fi in os.listdir(input_file_dir): + full_path = input_file_dir + '/' + fi + jl.seval(f'include("{full_path}")') + +# set working directory to results dir +results_dir = get_results_dir() +os.chdir(results_dir) + +# create symlinks to input data +links = link_input() + +def main(): + result = None + # load and run serialized entry point + hyperparams = {} + hp_file = os.environ.get("AMZN_BRAKET_HP_FILE") + if hp_file: + with open(hp_file, "r") as f: + hyperparams = json.load(f) + hyperparams = hyperparams or {} + try: + jl.seval('$f_source\\n') + load_str_loc = get_input_data_dir("jl_args") + '/job_f_args.jld2' + load_str = f'j_args = load("{load_str_loc}")' + jl.seval(load_str) + jl_func_str = f'$f_name(j_args["jl_args"]...; j_args["jl_kwargs"]...)' + result = jl.seval(jl_func_str) + except Exception as e: + print('An exception occured running the Julia code: ', e, flush=True) + raise e + finally: + clean_links(links) + if result is not None: + save_job_result(result, data_format=PersistedJobDataFormat.PICKLED_V4) + return result + +if __name__ == "__main__": + main() +""" +end + +function parse_macro_args(args) + has_device = length(args) != 0 && occursin("=", string(args[1])) + device = has_device ? string(args[1]) : "" + raw_kwargs = has_device ? args : args[2:end] + return device, raw_kwargs +end + +function _process_call_args(args) + n_arguments = length(args) + code = quote end + + splatted_args = [Meta.isexpr(arg, :(...)) for arg in args] + new_args = [splatted_args[a_ix] ? args[a_ix].args[1] : args[a_ix] for a_ix in 1:n_arguments] + new_kwargs = filter(arg->Meta.isexpr(arg, :kw), new_args) + new_args = filter(arg->!Meta.isexpr(arg, :kw), new_args) + # handle kwargs + new_kwargs = [(arg.args[1], arg.args[2]) for arg in new_kwargs] + + # match arguments with variables + vars = [gensym() for v_ix in 1:length(new_args)] + for v_ix in 1:length(new_args) + push!(code.args, :($(vars[v_ix]) = $(new_args[v_ix]))) + end + kw_vars = [gensym() for v_ix in 1:length(new_kwargs)] + for v_ix in 1:length(new_kwargs) + push!(code.args, :($(kw_vars[v_ix]) = $(new_kwargs[v_ix]))) + end + # convert the arguments + # while keeping the original arguments alive + var_expressions = [splatted_args[v_ix] ? 
Expr(:(...), vars[v_ix]) : vars[v_ix] for v_ix in 1:length(new_args)] + kw_var_expressions = [kw_vars[v_ix] for v_ix in 1:length(new_kwargs)] + return code, vars, var_expressions, kw_var_expressions +end + +function jobify_f(f, job_f_types, job_f_arguments, job_f_kwargs, device; jl_dependencies="", py_dependencies="", as_local=false, include_modules="", using_jl_pkgs="", include_jl_files="", job_opts_kwargs...) + mktempdir(pwd(), prefix="decorator_job_") do temp_path + j_opts = Braket.JobsOptions(; job_opts_kwargs...) + entry_point_file = joinpath(temp_path, "entry_point.py") + # create JLD2 file with function arguments and kwargs + save(joinpath(temp_path, "job_f_args.jld2"), "jl_args", job_f_arguments, "jl_kwargs", Dict(kw[1]=>kw[2] for kw in job_f_kwargs)) + + included_jl_files_vec = include_jl_files isa String ? [include_jl_files] : include_jl_files + jl_files_dict = Dict{String, String}() + if !isempty(include_jl_files) + mkdir(joinpath(temp_path, "jl_input_files")) + for fi in included_jl_files_vec + dest_fi = isabspath(fi) ? basename(fi) : fi + cp(fi, joinpath(temp_path, "jl_input_files", dest_fi)) + end + jl_files_dict["jl_include_files"] = joinpath(temp_path, "jl_input_files") + end + raw_input_data = Dict{String, Any}() + if j_opts.input_data isa String + raw_input_data = Dict("input"=>j_opts.input_data, "jl_args"=>joinpath(temp_path, "job_f_args.jld2")) + else + raw_input_data = merge(j_opts.input_data, Dict("jl_args"=>joinpath(temp_path, "job_f_args.jld2"))) + end + merge!(raw_input_data , jl_files_dict) + j_opts.input_data = raw_input_data + input_data = _process_input_data(raw_input_data) + + f_source = code_string(f, job_f_types) + if isempty(f_source) + t = precompile(f, job_f_types) + f_source = code_string(f, job_f_types) + end + isempty(f_source) && error("no method instance for $f found with types $job_f_types") + f_source = String(escape_string(f_source)) + serialized_f = _serialize_function(string(Symbol(f)), f_source, using_jl_pkgs, include_jl_files) + file_contents = join((input_data, serialized_f), "\n") + write(entry_point_file, file_contents) + + if !isempty(py_dependencies) + cp(py_dependencies, joinpath(temp_path, "requirements.txt")) + else + write(joinpath(temp_path, "requirements.txt"), "juliacall") + end + if !isempty(jl_dependencies) + cp(jl_dependencies, joinpath(temp_path, "Project.toml")) + else + write(joinpath(temp_path, "Project.toml"), "[deps]\nBraket = \"19504a0f-b47d-4348-9127-acc6cc69ef67\"\nJLD2 = \"033835bb-8acc-5ee8-8aae-3f567f8a3819\"\n") + end + device = isempty(device) ? "local:none/none" : string(device) + hyperparams = _log_hyperparameters(job_f_arguments, job_f_kwargs) + j_opts.hyperparameters = hyperparams + j_opts.entry_point = "$(relpath(temp_path)).entry_point" + T = as_local ? LocalQuantumJob : AwsQuantumJob + return T(device, relpath(temp_path), j_opts) + end +end + +""" + @hybrid_job [device] [job_creation_kwargs] job_function(args...; kwargs..) + +Run `job_function` inside an [Amazon Braket Job](https://docs.aws.amazon.com/braket/latest/developerguide/braket-jobs.html), launching +the job with creation arguments defined by `job_creation_kwargs`, and reserving device `device` (may be empty, in which case +`local:local/none` is used). `device` should be either a [valid AWS device ARN](https://docs.aws.amazon.com/braket/latest/developerguide/braket-devices.html) +or use the format `local:/` (see the developer guide on [embedded simulators](https://docs.aws.amazon.com/braket/latest/developerguide/pennylane-embedded-simulators.html)). 
+ +Valid job creation keyword arguments are: + - `jl_dependencies::String`: a path to a `Project.toml` containing the Julia packages needed to run `job_function`. Can be `""` (default). + - `py_dependencies::String`: a path to a `requirements.txt` containing the Python packages needed to run `job_function`. Can be `""` (default). + - `as_local::Bool`: whether to run the job in [local mode](https://docs.aws.amazon.com/braket/latest/developerguide/braket-jobs-local-mode.html). Default is `false`, running as a Hybrid, non-local Job. + - `include_modules`: unused but reserved argument. + - `using_jl_pkgs::Union{String, Vector{String}}`: Julia packages to load with `using [pkgs]` before the `job_function` is called within the job. + - `include_jl_files::Union{String, Vector{String}}`: path(s) to Julia file(s) to load with `include(file)` before `job_function` is called within the job. + - creation arguments for [`AwsQuantumJob`](@ref) + +Currently, `args` and `kwargs` to `job_function` must be serializable by `JLD2.jl`. `job_function` must be a Julia function, not Python. + +!!! note + The paths to include files and dependencies are resolved from the *call location* of this macro - to ensure your paths will resolve correctly, use absolute, not relative, paths. + +# Examples +```julia +function my_job_func(a, b::Int; c=0, d::Float64=1.0, kwargs...) + Braket.save_job_result(job_helper()) + py_reqs = read(joinpath(Braket.get_input_data_dir(), "requirements.txt"), String) + hyperparameters = Braket.get_hyperparameters() + write("test/output_file.txt", "hello") + return 0 +end + +py_deps = joinpath(@__DIR__, "requirements.txt") +jl_deps = joinpath(@__DIR__, "JobProject.toml") +input_data = joinpath(@__DIR__, "requirements") +include_jl_files = joinpath(@__DIR__, "job_test_script.jl") + +j = @hybrid_job Braket.SV1() wait_until_complete=true as_local=false include_modules="job_test_script" using_jl_pkgs="LinearAlgebra" include_jl_files=include_jl_files py_dependencies=py_deps jl_dependencies=jl_deps input_data=input_data my_job_func(MyStruct(), 2, d=5.0, extra_kwarg="extra_value") +``` +""" +macro hybrid_job(args...) + # peel apart `args` + entry_point = args[end] + Meta.isexpr(entry_point, :call) || throw(ArgumentError("final argument to @hybrid_job must be a function call")) + device, jobify_kwargs = parse_macro_args(args[1:end-1]) + f = entry_point.args[1] + f_args = entry_point.args[2:end] + # need to transform f to launch as a Job + # and transform its arguments to be properly passed to the new call + code, vars, var_expressions, kw_var_expressions = _process_call_args(f_args) + @gensym job_f job_f_args job_f_kwargs job_f_types wrapped_f + # now build up the actual call + push!(code.args, + quote + $job_f_args = ($(var_expressions...),) + $job_f_kwargs = ($(kw_var_expressions...),) + $job_f_types = tuple(map(Core.Typeof, $job_f_args)...) + $wrapped_f = $jobify_f($f, $job_f_types, $job_f_args, $job_f_kwargs, $device; $(jobify_kwargs...)) + $wrapped_f + end + ) + # use this let block to avoid leaking out of scope + return esc(quote + let + $code + end + end) +end diff --git a/src/jobs.jl b/src/jobs.jl index 5ce35dbc..e8bca6eb 100644 --- a/src/jobs.jl +++ b/src/jobs.jl @@ -52,3 +52,67 @@ function retrieve_image(f::Framework, config::AWSConfig) end return string(registry) * ".dkr.ecr.$aws_region.amazonaws.com/$tag" end + +function get_input_data_dir(channel::String="input") + input_dir = get(ENV, "AMZN_BRAKET_INPUT_DIR", ".") + return input_dir == "." ? 
input_dir : joinpath(input_dir, channel) +end +get_job_name() = get(ENV, "AMZN_BRAKET_JOB_NAME", "") +get_job_device_arn() = get(ENV, "AMZN_BRAKET_DEVICE_ARN", "local:none/none") +get_results_dir() = get(ENV, "AMZN_BRAKET_JOB_RESULTS_DIR", ".") +get_checkpoint_dir() = get(ENV, "AMZN_BRAKET_CHECKPOINT_DIR", ".") +function get_hyperparameters() + haskey(ENV, "AMZN_BRAKET_HP_FILE") || return Dict{String, Any}() + return JSON3.read(read(ENV["AMZN_BRAKET_HP_FILE"], String), Dict{String, Any}) +end + +function serialize_values(data_dictionary::Dict{String, Any}, data_format::PersistedJobDataFormat) + data_format == pickled_v4 && throw(ArgumentError("pickling data not yet supported!")) + return data_dictionary +end + +function deserialize_values(data_dictionary::Dict{String, Any}, data_format::PersistedJobDataFormat) + data_format == plaintext && return data_dictionary + throw(ArgumentError("unpickling results not yet supported!")) +end +deserialize_values(data_dictionary::Dict{String, Any}, data_format::String) = deserialize_values(data_dictionary, PersistedJobDataFormatDict[data_format]) + +function _load_persisted_data(filename::String="") + isempty(filename) && (filename = joinpath(get_results_dir(), "results.json")) + try + return JSON3.read(read(filename, String), Dict{String, Any}) + catch + return PersistedJobData(header_dict[PersistedJobData], Dict{String, Any}(), plaintext) + end +end + +function save_job_result(result_data::Dict{String, Any}, data_format::PersistedJobDataFormat=plaintext) + # can't handle pickled data yet + current_persisted_data = _load_persisted_data() + current_results = deserialize_values(current_persisted_data.dataDictionary, current_persisted_data.dataFormat) + updated_results = merge(current_results, result_data) + result_path = joinpath(get_results_dir(), "results.json") + serialized_data = serialize_values(updated_results, data_format) + persisted_data = PersistedJobData(header_dict[PersistedJobData], serialized_data, data_format) + write(result_path, JSON3.write(persisted_data)) + return +end +save_job_result(result_data, data_format::PersistedJobDataFormat=plaintext) = save_job_result(Dict{String, Any}("result"=>result_data), data_format) + +Base.@kwdef mutable struct JobsOptions + entry_point::String="" + image_uri::String="" + job_name::String=_generate_default_job_name(image_uri) + code_location::String=construct_s3_uri(default_bucket(), "jobs", job_name, "script") + role_arn::String=get(ENV, "BRAKET_JOBS_ROLE_ARN", _get_default_jobs_role()) + wait_until_complete::Bool=false + hyperparameters::Dict{String, <:Any}=Dict{String, Any}() + input_data::Union{String, Dict} = Dict() + instance_config::InstanceConfig = InstanceConfig() + distribution::String="" + stopping_condition::StoppingCondition = StoppingCondition() + output_data_config::OutputDataConfig = OutputDataConfig(job_name=job_name) + copy_checkpoints_from_job::String="" + checkpoint_config::CheckpointConfig = CheckpointConfig(job_name) + tags::Dict{String, String}=Dict{String, String}() +end diff --git a/src/local_jobs.jl b/src/local_jobs.jl index 7436ec13..e1208573 100644 --- a/src/local_jobs.jl +++ b/src/local_jobs.jl @@ -70,6 +70,7 @@ function download_input_data(config::AWSConfig, download_dir, input_data::Dict{S s3_uri_prefix = input_data["dataSource"]["s3DataSource"]["s3Uri"] bucket, prefix = parse_s3_uri(s3_uri_prefix) s3_keys = collect(s3_list_keys(bucket, prefix)) + @debug "Channel name: $channel_name, S3 URI: $s3_uri_prefix, S3 keys: $s3_keys, bucket: $bucket, prefix: $prefix, all bucket 
keys: $(collect(s3_list_keys(bucket)))" top_level = is_s3_dir(prefix, s3_keys) ? prefix : dirname(prefix) top_level = isempty(top_level) ? prefix : top_level found_item = false @@ -80,10 +81,12 @@ function download_input_data(config::AWSConfig, download_dir, input_data::Dict{S end for s3_key in s3_keys relative_key = relpath(s3_key, top_level) + relative_key = relative_key == "." ? basename(prefix) : relative_key download_path = joinpath(download_dir, channel_name, relative_key) if !endswith(s3_key, "/") + @debug "Getting file from S3: bucket $bucket, s3_key $s3_key, top level $top_level, relative_key $relative_key, download path $download_path" mkpath(dirname(download_path)) - s3_get_file(config, bucket, s3_key, joinpath(download_path, s3_key)) + s3_get_file(config, bucket, s3_key, download_path) found_item = true end end @@ -94,26 +97,33 @@ end function copy_input_data_list(c::LocalJobContainer, args) haskey(args[:params], "inputDataConfig") || return false input_data_list = args[:params]["inputDataConfig"] + @debug "Input data list for copy: $input_data_list" mktempdir() do temp_dir foreach(input_data->download_input_data(c.config, temp_dir, input_data), input_data_list) - copy_to_container!(c, temp_dir, "/opt/ml/input/data/") + # add dot to copy temp_dir's CONTENTS + copy_to_container!(c, temp_dir * "/.", "/opt/ml/input/data/") end return !isempty(input_data_list) end function setup_container!(c::LocalJobContainer, create_job_args) + @debug "Setting up container..." c_name = c.container_name # create expected paths for a Braket job to run + @debug "Setting up container: creating expected paths" proc_out, proc_err, code = capture_docker_cmd(`docker exec $c_name mkdir -p /opt/ml/model`) local_path = create_job_args[:params]["checkpointConfig"]["localPath"] proc_out, proc_err, code = capture_docker_cmd(`docker exec $c_name mkdir -p $local_path`) + @debug "Setting up container: creating environment variables" env_vars = Dict{String, String}() merge!(env_vars, get_env_creds(c.config)) script_mode = create_job_args[:algo_spec]["scriptModeConfig"] merge!(env_vars, get_env_script_mode_config(script_mode)) merge!(env_vars, get_env_defaults(c.config, create_job_args)) + @debug "Setting up container: copying hyperparameters" copy_hyperparameters(c, create_job_args) && merge!(env_vars, get_env_hyperparameters()) + @debug "Setting up container: copying input data list" copy_input_data_list(c, create_job_args) && merge!(env_vars, get_env_input_data()) c.env = env_vars return c @@ -122,17 +132,19 @@ end function run_local_job!(c::LocalJobContainer) code_path = c.container_code_path c_name = c.container_name + @debug "Running local job: capturing entry point command" entry_point_cmd = `docker exec $c_name printenv SAGEMAKER_PROGRAM` entry_program, err, code = capture_docker_cmd(entry_point_cmd) (isnothing(entry_program) || isempty(entry_program)) && throw(ErrorException("Start program not found. The specified container is not setup to run Braket Jobs. 
Please see setup instructions for creating your own containers.")) env_list = String.(reduce(vcat, ["-e", k*"="*v] for (k,v) in c.env)) cmd = Cmd(["docker", "exec", "-w", String(code_path), env_list..., String(c_name), "python", String(entry_program)]) + @debug "Running local job: running full entry point command" proc_out, proc_err, code = capture_docker_cmd(cmd) if code == 0 c.run_log *= proc_out else err_str = "Run local job process exited with code: $code" - println(proc_err) + c.run_log *= proc_out c.run_log *= err_str * proc_err end return c @@ -179,6 +191,7 @@ end function start_container!(c::LocalJobContainer, force_update::Bool) image_uri = c.image_uri get_image_name(image_uri) = capture_docker_cmd(`docker images -q $image_uri`)[1] + @debug "Acquiring docker image for container start" image_name = get_image_name(image_uri) if isempty(image_name) || isnothing(image_name) try @@ -196,6 +209,7 @@ function start_container!(c::LocalJobContainer, force_update::Bool) @warn "Unable to update $(c.image_uri) with error $e" end end + @debug "Launching container with docker run" container_name, err, code = capture_docker_cmd(`docker run -d --rm $image_name tail -f /dev/null`) code == 0 || throw(ErrorException(err)) c.container_name = container_name @@ -269,11 +283,43 @@ mutable struct LocalQuantumJob <: Job end end +function LocalQuantumJob( + device::String, + source_module::String, + j_opts::JobsOptions; + force_update::Bool=false, + config::AWSConfig=global_aws_config() + ) + image_uri = isempty(j_opts.image_uri) ? retrieve_image(BASE, config) : j_opts.image_uri + args = prepare_quantum_job(device, source_module, j_opts) + algo_spec = args[:algo_spec] + job_name = args[:job_name] + ispath(job_name) && throw(ErrorException("a local directory called $job_name already exists. Please use a different job name.")) + image_uri = haskey(algo_spec, "containerImage") ? algo_spec["containerImage"]["uri"] : retrieve_image(BASE, config) + + run_log = "" + let local_job_container=LocalJobContainer(image_uri, args, force_update=force_update) + local_job_container = run_local_job!(local_job_container) + # copy results out + copy_from_container!(local_job_container, "/opt/ml/model", job_name) + !ispath(job_name) && mkdir(job_name) + write(joinpath(job_name, "log.txt"), local_job_container.run_log) + if haskey(args, :params) && haskey(args[:params], "checkpointConfig") && haskey(args[:params]["checkpointConfig"], "localPath") + checkpoint_path = args[:params]["checkpointConfig"]["localPath"] + copy_from_container!(local_job_container, checkpoint_path, joinpath(job_name, "checkpoints")) + end + run_log = local_job_container.run_log + stop_container!(local_job_container) + end + return LocalQuantumJob("local:job/$job_name", run_log=run_log) +end + """ - LocalQuantumJob(device::String, source_module::String; kwargs...) + LocalQuantumJob(device::Union{String, BraketDevice}, source_module::String; kwargs...) Create and launch a `LocalQuantumJob` which will use device `device` (a managed simulator, a QPU, or an [embedded simulator](https://docs.aws.amazon.com/braket/latest/developerguide/pennylane-embedded-simulators.html)) -and will run the code (either a single file, or a Julia package, or a Python module) located at `source_module`. A *local* job +and will run the code (either a single file, or a Julia package, or a Python module) located at `source_module`. `device` can be either the device's ARN as a `String`, or a [`BraketDevice`](@ref). 
+A *local* job runs *locally* on your computational resource by launching the Job container locally using `docker`. The job will block until it completes, replicating the `wait_until_complete` behavior of [`AwsQuantumJob`](@ref). @@ -311,52 +357,7 @@ The keyword arguments `kwargs` control the launch configuration of the job. The default is `CheckpointConfig("/opt/jobs/checkpoints", "s3://{default_bucket_name}/jobs/{job_name}/checkpoints")`. - `tags::Dict{String, String}` - specifies the key-value pairs for tagging this job. """ -function LocalQuantumJob( - device::String, - source_module::String; - entry_point::String="", - image_uri::String="", - job_name::String=_generate_default_job_name(image_uri), - code_location::String=construct_s3_uri(default_bucket(), "jobs", job_name, "script"), - role_arn::String="", - wait_until_complete::Bool=false, - hyperparameters::Dict{String, <:Any}=Dict{String, Any}(), - input_data::Union{String, Dict} = Dict(), - instance_config::InstanceConfig = InstanceConfig(), - distribution::String="", - stopping_condition::StoppingCondition = StoppingCondition(), - output_data_config::OutputDataConfig = OutputDataConfig(job_name=job_name), - copy_checkpoints_from_job::String="", - checkpoint_config::CheckpointConfig = CheckpointConfig(job_name), - tags::Dict{String, String}=Dict{String, String}(), - force_update::Bool=false, - config::AWSConfig=global_aws_config() - ) - image_uri = isempty(image_uri) ? retrieve_image(BASE, config) : image_uri - args = prepare_quantum_job(device, source_module, entry_point, image_uri, job_name, code_location, - role_arn, hyperparameters, input_data, instance_config, distribution, - stopping_condition, output_data_config, copy_checkpoints_from_job, checkpoint_config, tags) - algo_spec = args[:algo_spec] - job_name = args[:job_name] - ispath(job_name) && throw(ErrorException("a local directory called $job_name already exists. Please use a different job name.")) - image_uri = haskey(algo_spec, "containerImage") ? algo_spec["containerImage"]["uri"] : retrieve_image(BASE, config) - - run_log = "" - let local_job_container=LocalJobContainer(image_uri, args, force_update=force_update) - local_job_container = run_local_job!(local_job_container) - # copy results out - copy_from_container!(local_job_container, "/opt/ml/model", job_name) - !ispath(job_name) && mkdir(job_name) - write(joinpath(job_name, "log.txt"), local_job_container.run_log) - if haskey(args, :params) && haskey(args[:params], "checkpointConfig") && haskey(args[:params]["checkpointConfig"], "localPath") - checkpoint_path = args[:params]["checkpointConfig"]["localPath"] - copy_from_container!(local_job_container, checkpoint_path, joinpath(job_name, "checkpoints")) - end - run_log = local_job_container.run_log - stop_container!(local_job_container) - end - return LocalQuantumJob("local:job/$job_name", run_log=run_log) -end +LocalQuantumJob(device::String, source_module::String; force_update::Bool=false, config::AWSConfig=global_aws_config(), kwargs...) = LocalQuantumJob(device, source_module, JobsOptions(; kwargs...); force_update=force_update, config=config) LocalQuantumJob(device::BraketDevice, source_module::String; kwargs...) = LocalQuantumJob(convert(String, device), source_module; kwargs...) """ @@ -417,7 +418,8 @@ Copy, extract, and deserialize the results of local job `j`. function result(j::LocalQuantumJob; kwargs...) 
try raw = read(joinpath(name(j), "results.json"), String) - persisted_data = parse_raw_schema(raw) + persisted_data = parse_raw_schema(raw) + @debug "Persisted data format: $(persisted_data.dataFormat)" deserialized_data = deserialize_values(persisted_data.dataDictionary, persisted_data.dataFormat) return deserialized_data catch diff --git a/src/utils.jl b/src/utils.jl index 237441e1..2d092742 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -42,23 +42,29 @@ end function upload_local_data(local_prefix::String, s3_prefix::String) base_dir = isabspath(local_prefix) ? joinpath(splitpath(local_prefix)[1:end-1]...) : pwd() relative_prefix = isabspath(local_prefix) ? relpath(local_prefix, base_dir) : local_prefix + @debug "Uploading local data with relative prefix $relative_prefix and base_dir $base_dir" isfile(relative_prefix) && return upload_to_s3(relative_prefix, s3_prefix) isdir(base_dir) || throw(ErrorException("uploading data $local_prefix to $s3_prefix failed!")) for (root, dirs, files) in walkdir(base_dir) - fns = String[] + fn_to_uri = Dict{String, String}() if root == base_dir fns = filter(x->startswith(x, relative_prefix), files) + for fn in fns + fn_to_uri[joinpath(base_dir, fn)] = replace(fn, relative_prefix=>s3_prefix) + end elseif startswith(relpath(root, base_dir), relative_prefix) fns = map(fn->joinpath(relpath(root, base_dir), fn), files) + for fn in fns + fn_to_uri[joinpath(base_dir, fn)] = replace(fn, relative_prefix=>s3_prefix) + end end - # need to fix s3 URIs on Windows - foreach(fns) do fn + for (fn, uri) in fn_to_uri + @debug "$fn, is file? $(isfile(fn))" if !Sys.iswindows() - upload_to_s3(fn, replace(fn, relative_prefix=>s3_prefix)) + @debug "Uploading $fn to S3 URI $uri" + upload_to_s3(fn, uri) else - s3_uri = replace(fn, relative_prefix=>s3_prefix) - s3_uri = replace(s3_uri, "\\"=>"/") - upload_to_s3(fn, s3_uri) + upload_to_s3(fn, replace(uri, "\\"=>"/")) end end end diff --git a/test/integ_tests/JobProject.toml b/test/integ_tests/JobProject.toml new file mode 100644 index 00000000..f342e61e --- /dev/null +++ b/test/integ_tests/JobProject.toml @@ -0,0 +1,3 @@ +[deps] +Braket = "19504a0f-b47d-4348-9127-acc6cc69ef67" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" diff --git a/test/integ_tests/create_quantum_job.jl b/test/integ_tests/create_quantum_job.jl index 62b90876..158fcc7c 100644 --- a/test/integ_tests/create_quantum_job.jl +++ b/test/integ_tests/create_quantum_job.jl @@ -54,7 +54,7 @@ using AWS, AWSS3, Braket, JSON3, Test @test occursin(data, log_data) end end - @test startswith(metadata(job)["failureReason"], "AlgorithmError: Job at job_test_script:start_here") + @test startswith(metadata(job)["failureReason"], "AlgorithmError: AssertionError, exit code: 1") end # Asserts the job is completed with the output, checkpoints, tasks and # script folder created in S3 for respective job. Validate the results are diff --git a/test/integ_tests/job_macro.jl b/test/integ_tests/job_macro.jl new file mode 100644 index 00000000..a9fec912 --- /dev/null +++ b/test/integ_tests/job_macro.jl @@ -0,0 +1,50 @@ +using AWS, AWSS3, Braket, JSON3, Test + +struct MyStruct + attribute::String +end +MyStruct() = MyStruct("value") +Base.show(io::IO, s::MyStruct) = print(io, "MyStruct($(s.attribute))") + +@testset "Job creation macro" begin + @testset "Local" for local_mode ∈ (true, false) + function my_job_func(a, b::Int; c=0, d::Float64=1.0, kwargs...) 
+ Braket.save_job_result(job_helper()) + py_reqs = read(joinpath(Braket.get_input_data_dir(), "requirements.txt"), String) + @assert occursin("pytest", py_reqs) + @assert a.attribute == "value" + @assert b == 2 + @assert c == 0 + @assert d == 5 + @assert kwargs[:extra_kwarg] == "extra_value" + # ensure using LinearAlgebra worked + @assert size(triu(rand(2,2))) == (2,2) + + hyperparameters = Braket.get_hyperparameters() + @assert isempty(hyperparameters) + write("test/output_file.txt", "hello") + return 0 + end + + py_deps = joinpath(@__DIR__, "requirements.txt") + jl_deps = joinpath(@__DIR__, "JobProject.toml") + input_data = joinpath(@__DIR__, "requirements") + include_jl_files = joinpath(@__DIR__, "job_test_script.jl") + j = @hybrid_job Braket.SV1() wait_until_complete=true as_local=local_mode include_modules="job_test_script" using_jl_pkgs="LinearAlgebra" include_jl_files=include_jl_files py_dependencies=py_deps jl_dependencies=jl_deps input_data=input_data my_job_func(MyStruct(), 2, d=5.0, extra_kwarg="extra_value") + logs(j) + @test_broken result(j)["status"] == "SUCCESS" + try + j_res = download_result(j) + @test isdir(name(j)) + cd(name(j)) do + res_str = read(joinpath("test", "output_file.txt"), String) + @test res_str == "hello" + @test ispath("results.json") + @test ispath("test") + @test_broken ispath(joinpath("test", "integ_tests")) + end + finally + rm(name(j), recursive=true) + end + end +end diff --git a/test/integ_tests/job_test_module/job_test_submodule/Project.toml b/test/integ_tests/job_test_module/job_test_submodule/Project.toml new file mode 100644 index 00000000..fcc8998f --- /dev/null +++ b/test/integ_tests/job_test_module/job_test_submodule/Project.toml @@ -0,0 +1,2 @@ +[deps] +Braket = "19504a0f-b47d-4348-9127-acc6cc69ef67" diff --git a/test/integ_tests/job_test_module/job_test_submodule/job_test_submodule_file.jl b/test/integ_tests/job_test_module/job_test_submodule/job_test_submodule_file.jl new file mode 100644 index 00000000..aab5c8ba --- /dev/null +++ b/test/integ_tests/job_test_module/job_test_submodule/job_test_submodule_file.jl @@ -0,0 +1,8 @@ +module SubmoduleHelper + +function submodule_helper() + println("import successful!") + return Dict("status"=>"SUCCESS") +end + +end diff --git a/test/integ_tests/job_test_module/job_test_submodule/requirements.txt b/test/integ_tests/job_test_module/job_test_submodule/requirements.txt new file mode 100644 index 00000000..e079f8a6 --- /dev/null +++ b/test/integ_tests/job_test_module/job_test_submodule/requirements.txt @@ -0,0 +1 @@ +pytest diff --git a/test/integ_tests/job_test_script.jl b/test/integ_tests/job_test_script.jl new file mode 100644 index 00000000..96478a74 --- /dev/null +++ b/test/integ_tests/job_test_script.jl @@ -0,0 +1,4 @@ +function job_helper() + println("we did it!") + return Dict("status"=>"SUCCESS") +end diff --git a/test/integ_tests/requirements.txt b/test/integ_tests/requirements.txt new file mode 100644 index 00000000..1b563452 --- /dev/null +++ b/test/integ_tests/requirements.txt @@ -0,0 +1,2 @@ +pytest +juliacall diff --git a/test/integ_tests/runtests.jl b/test/integ_tests/runtests.jl index 86a5a47f..2897462a 100644 --- a/test/integ_tests/runtests.jl +++ b/test/integ_tests/runtests.jl @@ -6,6 +6,7 @@ s3_destination_folder = Braket.default_task_bucket() include("adjoint_gradient.jl") include("create_local_quantum_job.jl") include("create_quantum_job.jl") +include("job_macro.jl") include("cost_tracking.jl") include("device_creation.jl") include("queue_information.jl") diff --git 
a/test/job_macro.jl b/test/job_macro.jl new file mode 100644 index 00000000..d9460b17 --- /dev/null +++ b/test/job_macro.jl @@ -0,0 +1,46 @@ +using Braket, Test, Random, Dates, Tar, JSON3 +using Mocking +Mocking.activate() + +@testset "Job macro" begin + @testset "Macro defaults" begin + resp_dict = Dict("jobArn"=>"arn:job/fake", "GetCallerIdentityResult"=>Dict("Arn"=>"fake_arn", "Account"=>"000000"), "ListRolesResult"=>Dict("Roles"=>Dict("member"=>[Dict("RoleName"=>"AmazonBraketJobsExecutionRoleFake", "Arn"=>"fake_arn")]))) + function f(http_backend, request, response_stream) + if request.service == "s3" + xml_str = """ + + + + 2000:01:01T00:00:00 + "amazon-braket-fake_region-000000" + + + + "fake_name" + 000000 + + + """ + return Braket.AWS.Response(Braket.HTTP.Response(200, ["Content-Type"=>"application/xml"]), IOBuffer(xml_str)) + else + return Braket.AWS.Response(Braket.HTTP.Response(200, ["Content-Type"=>"application/json"]), IOBuffer(JSON3.write(resp_dict))) + end + end + req_patch = @patch Braket.AWS._http_request(http_backend, request::Braket.AWS.Request, response_stream::IO; kwargs...) = f(http_backend, request, response_stream) + apply(req_patch) do + function my_job_func(a, b; c) + println(2) + return 0 + end + ENV["AMZN_BRAKET_OUT_S3_BUCKET"] = "fake_bucket" + j = @hybrid_job my_job_func(0, 1, c=1) + delete!(ENV, "AMZN_BRAKET_OUT_S3_BUCKET") + @test arn(j) == "arn:job/fake" + end + end + @testset "Hyperparameter sanitization" for (hyperparameter, expected) in (("with\nnewline", "with newline"), + ("with weird chars: (&\$`)", "with weird chars: {+?'}"), + (repeat('?', 2600), repeat('?', 2477)*"..."*repeat('?', 20))) + @test Braket._sanitize(hyperparameter) == expected + end +end diff --git a/test/local_jobs.jl b/test/local_jobs.jl index f2d3200d..d6e08983 100644 --- a/test/local_jobs.jl +++ b/test/local_jobs.jl @@ -45,7 +45,12 @@ script_mode_dict = OrderedDict("s3Uri"=>"fake_uri", "entryPoint"=>"fake_entry", end @testset "Successful input data download" begin args = (algo_spec=Dict("scriptModeConfig"=>script_mode_dict), - params=Dict("hyperParameters"=>Dict("cool"=>"beans"), "checkpointConfig"=>Dict("localPath"=>"fake_local_path"), "inputDataConfig"=>[Dict("channelName"=>"fake_channel", "dataSource"=>Dict("s3DataSource"=>Dict("s3Uri"=>"s3://fake_bucket/fake_input")))]), + params=Dict( + "hyperParameters"=>Dict("cool"=>"beans"), + "checkpointConfig"=>Dict("localPath"=>"fake_local_path"), + "inputDataConfig"=>[Dict("channelName"=>"fake_channel", + "dataSource"=>Dict("s3DataSource"=>Dict("s3Uri"=>"s3://fake_bucket/fake_input")))] + ), job_name=job_name, out_conf=Dict("s3Path"=>"s3://fake_s3_bucket/fake_s3_path"), dev_conf=Dict("device"=>"fake_device") @@ -194,24 +199,29 @@ script_mode_dict = OrderedDict("s3Uri"=>"fake_uri", "entryPoint"=>"fake_entry", end @testset "run a facsimile LocalQuantumJob" begin function f(http_backend, request, response_stream) - xml_str = """ - - - fake_bucket - fake_channel - 205 - 1000 - false - - fake_input - 2009-10-12T17:50:30.000Z - "fba9dede5f27731c9771645a39863328" - 434234 - STANDARD - - - """ - return Braket.AWS.Response(Braket.HTTP.Response(200, ["Content-Type"=>"application/xml"]), IOBuffer(xml_str)) + resp_dict = Dict("ListRolesResult"=>Dict("Roles"=>Dict("member"=>[Dict("RoleName"=>"AmazonBraketJobsExecutionRolenoMemberHere", "Arn"=>"fake_job_arn")]))) + if request.service == "s3" + xml_str = """ + + + fake_bucket + fake_channel + 205 + 1000 + false + + fake_input + 2009-10-12T17:50:30.000Z + "fba9dede5f27731c9771645a39863328" + 434234 + 
STANDARD + + + """ + return Braket.AWS.Response(Braket.HTTP.Response(200, ["Content-Type"=>"application/xml"]), IOBuffer(xml_str)) + else + return Braket.AWS.Response(Braket.HTTP.Response(200, ["Content-Type"=>"application/json"]), IOBuffer(JSON3.write(resp_dict))) + end end req_patch = @patch Braket.AWS._http_request(http_backend, request::Braket.AWS.Request, response_stream::IO) = f(http_backend, request, response_stream) apply(req_patch) do @@ -321,7 +331,7 @@ script_mode_dict = OrderedDict("s3Uri"=>"fake_uri", "entryPoint"=>"fake_entry", Braket.capture_docker_cmd(c::Cmd) = ("braket_container.py", "sadness", 1) apply(req_patch) do ljc = Braket.run_local_job!(ljc) - @test ljc.run_log == "successsuccesssuccesssuccessRun local job process exited with code: 1sadness" + @test ljc.run_log == "successsuccesssuccesssuccessbraket_container.pyRun local job process exited with code: 1sadness" end @testset "errors in copying to/from container" begin ljc.run_log = "" diff --git a/test/runtests.jl b/test/runtests.jl index a770f1b0..d558e546 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,9 +1,9 @@ using Pkg, Test, Aqua, Braket in_ci = tryparse(Bool, get(ENV, "BRAKET_CI", "false")) -Aqua.test_all(Braket, ambiguities=false, unbound_args=false, piracy=false, stale_deps=!in_ci, deps_compat=!in_ci) +Aqua.test_all(Braket, ambiguities=false, unbound_args=false, piracies=false, stale_deps=!in_ci, deps_compat=!in_ci, persistent_tasks=false) Aqua.test_ambiguities(Braket) -Aqua.test_piracy(Braket, treat_as_own=[Braket.DecFP.Dec128]) +Aqua.test_piracies(Braket, treat_as_own=[Braket.DecFP.Dec128]) const GROUP = get(ENV, "GROUP", "Braket-unit") @@ -62,6 +62,7 @@ for group in groups include("task.jl") include("task_batch.jl") include("local_jobs.jl") + include("job_macro.jl") include("jobs.jl") elseif test_type == "integ" include(joinpath(@__DIR__, "integ_tests", "runtests.jl"))
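
For orientation, a minimal sketch of the two launch paths this patch introduces: the `@hybrid_job` macro and the new `JobsOptions`-based `AwsQuantumJob` method. The job body, module path, and parameter values below are illustrative placeholders; as the docstring notes, the wrapped function's arguments and keyword arguments must be JLD2-serializable.

```julia
using Braket, LinearAlgebra

# Illustrative job body. LinearAlgebra is made available inside the job
# container by passing it through `using_jl_pkgs` at launch time.
function eigen_job(n::Int; scale::Float64=1.0)
    A = Symmetric(scale .* rand(n, n))
    # persist a result so it appears in the job's results.json
    Braket.save_job_result(Dict("largest_eigval" => maximum(eigvals(A))))
    return 0
end

# Macro form: device first, then job-creation keywords, then the call to run inside the job.
j = @hybrid_job Braket.SV1() wait_until_complete=true using_jl_pkgs="LinearAlgebra" eigen_job(4, scale=2.0)

# Equivalent explicit form for a source module already on disk, using the new
# three-argument method that takes a JobsOptions struct (paths are hypothetical).
opts = Braket.JobsOptions(entry_point="my_module.my_entry", wait_until_complete=true)
job  = AwsQuantumJob("arn:aws:braket:::device/quantum-simulator/amazon/sv1", "my_module", opts)
```

The keyword method `AwsQuantumJob(device, source_module; kwargs...)` still works unchanged; it now simply forwards its keywords into a `JobsOptions` and calls the three-argument method shown above.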