Skip to content

Commit

Permalink
feature: support serde pickle v4 (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
WingCode authored Aug 28, 2024
1 parent 43df9a6 commit f5ec984
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 6 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down Expand Up @@ -58,6 +59,7 @@ Markdown = "=0.7.5"
Mocking = "=0.8.1"
NamedTupleTools = "=0.14.3"
OrderedCollections = "=1.6.3"
Pickle = "0.3.5"
Pkg = "1.6"
Random = "1.6"
SparseArrays = "1.6"
Expand Down
7 changes: 5 additions & 2 deletions src/jobs.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
using Base64
using Pickle

const JOB_DEFAULT_RESULTS_POLL_TIMEOUT = 864000
const JOB_DEFAULT_RESULTS_POLL_INTERVAL = 5
const JOB_TERMINAL_STATES = ["COMPLETED", "FAILED", "CANCELLED"]
Expand Down Expand Up @@ -67,13 +70,13 @@ function get_hyperparameters()
end

function serialize_values(data_dictionary::Dict{String, Any}, data_format::PersistedJobDataFormat)
data_format == pickled_v4 && throw(ArgumentError("pickling data not yet supported!"))
data_format == pickled_v4 && return Dict(k => base64encode(Pickle.stores(v)) for (k, v) in data_dictionary)
return data_dictionary
end

function deserialize_values(data_dictionary::Dict{String, Any}, data_format::PersistedJobDataFormat)
data_format == plaintext && return data_dictionary
throw(ArgumentError("unpickling results not yet supported!"))
return Dict(k => Pickle.loads(base64decode(v)) for (k, v) in data_dictionary)
end
deserialize_values(data_dictionary::Dict{String, Any}, data_format::String) = deserialize_values(data_dictionary, PersistedJobDataFormatDict[data_format])

Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Mocking = "78c3b35d-d492-501b-9361-3d52fe80e533"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
19 changes: 15 additions & 4 deletions test/jobs.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,30 @@
using Braket, Test, Mocking, Random, Dates, Tar, JSON3
using Braket, Pickle, Base64, Test, Mocking, Random, Dates, Tar, JSON3
Mocking.activate()

Base.parse(d::Dict) = d

dev_arn = "arn:aws:braket:::device/quantum-simulator/amazon/sv1"

@testset "Jobs" begin
@testset "deserialization errors" begin
@test_throws ArgumentError Braket.deserialize_values(Dict{String, Any}(), Braket.pickled_v4)
@testset "serialization pickled_v4" begin
data_dictionary = Dict{String, Any}("key1" => "value1", "key2" => "value2")
data_format = Braket.pickled_v4
result = Braket.serialize_values(data_dictionary, data_format)
@test result == Dict{String, Any}("key1" => base64encode(Pickle.stores("value1")), "key2" => base64encode(Pickle.stores("value2")))
end
@testset "deserialization pickled_v4" begin
data_dictionary = Dict{String, Any}("key1" => base64encode(Pickle.stores("value1")), "key2" => base64encode(Pickle.stores("value2")))
data_format = Braket.pickled_v4
result = Braket.deserialize_values(data_dictionary, data_format)
@test result == Dict{String, Any}("key1" => "value1", "key2" => "value2")
end
@testset "deserialization" begin
@test Dict{Any,Any}() == Braket.deserialize_values(Dict{String,Any}(), Braket.pickled_v4)
mktempdir() do d
job = Braket.AwsQuantumJob("arn:fake")
@test Braket._read_and_deserialize_results(job, d) == []
pjd = Braket.PersistedJobData(Braket.header_dict[Braket.PersistedJobData], Dict{String, Any}(), Braket.pickled_v4)
write(joinpath(d, Braket.RESULTS_FILENAME), JSON3.write(pjd))
@test_throws ArgumentError Braket._read_and_deserialize_results(job, d)
end
end
@testset "logs" begin
Expand Down

0 comments on commit f5ec984

Please sign in to comment.