Skip to content

Commit

Permalink
Merge pull request #95 from PALEOtoolkit/user_data
Browse files Browse the repository at this point in the history
OutputMemory add user_data Dict
  • Loading branch information
sjdaines authored May 21, 2024
2 parents 1f4722e + 40326fe commit 9109f57
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 13 deletions.
55 changes: 43 additions & 12 deletions src/OutputWriters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -480,30 +480,43 @@ end
# OutputMemory
##########################################

const UserDataTypes = Union{Float64, Int64, String, Vector{Float64}, Vector{Int64}, Vector{String}}

"""
OutputMemory
OutputMemory(; user_data=Dict{String, UserDataTypes}())
In-memory container for model output, organized by model Domains.
Implements the [`PALEOmodel.AbstractOutputWriter`](@ref) interface, with additional methods
[`save_netcdf`](@ref) and [`load_netcdf!`](@ref) to save and load data.
# Implementation
Field `domains::Dict{String, OutputMemoryDomain}` contains per-Domain model output.
- Field `domains::Dict{String, OutputMemoryDomain}` contains per-Domain model output.
- Field `user_data::Dict{String, UserDataTypes}` contains optional user data
NB:
- available types are restricted to those that are compatible with NetCDF attribute types,
ie Float64, Int64, String, Vector{Float64}, Vector{Int64}, Vector{String}
- Vectors with a single element are read back from netcdf as scalars,
see https://alexander-barth.github.io/NCDatasets.jl/dev/issues/#Corner-cases
"""
struct OutputMemory <: PALEOmodel.AbstractOutputWriter

domains::Dict{String, OutputMemoryDomain}
user_data::Dict{String, UserDataTypes}
end

function OutputMemory()
return OutputMemory(Dict{String, OutputMemoryDomain}())

const default_user_data=Dict{String, UserDataTypes}(
"title"=>"PALEO (exo)Earth system model output",
"source"=>"PALEOmodel https://github.com/PALEOtoolkit/PALEOmodel.jl",
)

function OutputMemory(; user_data=default_user_data)
return OutputMemory(Dict{String, OutputMemoryDomain}(), user_data)
end

"create from collection of OutputMemoryDomain"
function OutputMemory(output_memory_domains::Union{Vector, Tuple})
om = OutputMemory(Dict(om.name => om for om in output_memory_domains))
function OutputMemory(output_memory_domains::Union{Vector, Tuple}; user_data=default_user_data)
om = OutputMemory(Dict(om.name => om for om in output_memory_domains), user_data)
return om
end

Expand Down Expand Up @@ -754,7 +767,7 @@ end

"compact form"
function Base.show(io::IO, output::OutputMemory)
print(io, "OutputMemory(domains=", keys(output.domains), ")")
print(io, "OutputMemory(domains=", keys(output.domains), ", user_data=", output.user_data, ")")
end


Expand Down Expand Up @@ -801,7 +814,6 @@ Save to `filename` in netcdf4 format (NB: filename must either have no extension
"""
function save_netcdf(
output::OutputMemory, filename;
additional_attributes::AbstractVector{<:Pair{<:AbstractString, <:Any}} = Pair{String, String}[],
check_ext::Bool=true,
add_coordinates::Bool=false,
)
Expand All @@ -819,11 +831,17 @@ function save_netcdf(
@info "saving to $filename ..."

NCDatasets.NCDataset(filename, "c") do nc_dataset
nc_dataset.attrib["title"] = "PALEO (exo)Earth system model output"
nc_dataset.attrib["source"] = "PALEOmodel https://github.com/PALEOtoolkit/PALEOmodel.jl"
nc_dataset.attrib["PALEO_netcdf_version"] = "0.1.0"
nc_dataset.attrib["PALEO_domains"] = join([k for (k, v) in output.domains], " ")

for (k, v) in output.user_data
if k in ("PALEO_netcdf_version", "PALEO_domains")
@warn "ignoring reserved user_data key $k"
continue
end
nc_dataset.attrib[k] = v
end

for (domname, dom) in output.domains

dsg = NCDatasets.defGroup(nc_dataset, dom.name; attrib=[])
Expand Down Expand Up @@ -935,6 +953,19 @@ function load_netcdf!(output::OutputMemory, filename; check_ext=true)
!ismissing(paleo_netcdf_version) || error("not a PALEO netcdf output file ? (key PALEO_netcdf_version not present)")
paleo_netcdf_version == "0.1.0" || error("unsupported PALEO_netcdf_version $paleo_netcdf_version")

for (k, v) in nc_dataset.attrib
if k in ("PALEO_netcdf_version", "PALEO_domains")
continue
end
# workaround for https://github.com/Alexander-Barth/NCDatasets.jl/issues/258
if v isa Int32
v = Int64(v)
elseif v isa Vector{Int32}
v = Int64.(v)
end
output.user_data[k] = v
end

for (domainname, dsg) in nc_dataset.group
coords_record = dsg.attrib["coords_record"]
nrecs = dsg.dim[coords_record]
Expand Down
25 changes: 24 additions & 1 deletion test/runoutputwritertests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,37 @@ end
all_values.global.O .= [4e19]
PALEOmodel.OutputWriters.add_record!(output, model, modeldata, 0.0)

tmpfile = tempname(; cleanup=true)
tmpfile = tempname(; cleanup=true)

output.user_data["testString"] = "hello"
output.user_data["testInt64"] = 42
output.user_data["testFloat64"] = 42.0
output.user_data["testVecString"] = ["hello", "world"]
output.user_data["testVecInt64"] = [42, 43]
output.user_data["testVecFloat64_1"] = [42.0]
output.user_data["testVecFloat64"] = [42.0, 43.0]

PALEOmodel.OutputWriters.save_netcdf(output, tmpfile)

load_output = PALEOmodel.OutputWriters.load_netcdf!(PALEOmodel.OutputWriters.OutputMemory(), tmpfile)

O_array = PALEOmodel.get_array(load_output, "global.O")
@test O_array.values == [2e19, 4e19]

function test_user_key_type_value(k, v)
lv = load_output.user_data[k]
@test typeof(lv) == typeof(v)
@test lv == v
end

test_user_key_type_value("testString", "hello")
test_user_key_type_value("testInt64", 42)
test_user_key_type_value("testFloat64", 42.0)
test_user_key_type_value("testVecString", ["hello", "world"])
test_user_key_type_value("testVecInt64", [42, 43])
@test_broken load_output.user_data["testVecFloat64_1"] == [42.0] # returned as a scalar
test_user_key_type_value("testVecFloat64", [42.0, 43.0])

end

@testset "DataFrameCreate" begin
Expand Down

0 comments on commit 9109f57

Please sign in to comment.