Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OutputMemory add user_data Dict #95

Merged
merged 1 commit into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 43 additions & 12 deletions src/OutputWriters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -480,30 +480,43 @@ end
# OutputMemory
##########################################

# Value types permitted in `OutputMemory.user_data`: restricted to types that are
# compatible with NetCDF attribute types (scalars and Vectors of Float64, Int64, String).
const UserDataTypes = Union{Float64, Int64, String, Vector{Float64}, Vector{Int64}, Vector{String}}

"""
OutputMemory
OutputMemory(; user_data=Dict{String, UserDataTypes}())
In-memory container for model output, organized by model Domains.
Implements the [`PALEOmodel.AbstractOutputWriter`](@ref) interface, with additional methods
[`save_netcdf`](@ref) and [`load_netcdf!`](@ref) to save and load data.
# Implementation
Field `domains::Dict{String, OutputMemoryDomain}` contains per-Domain model output.
- Field `domains::Dict{String, OutputMemoryDomain}` contains per-Domain model output.
- Field `user_data::Dict{String, UserDataTypes}` contains optional user data
NB:
- available types are restricted to those that are compatible with NetCDF attribute types,
ie Float64, Int64, String, Vector{Float64}, Vector{Int64}, Vector{String}
- Vectors with a single element are read back from netcdf as scalars,
see https://alexander-barth.github.io/NCDatasets.jl/dev/issues/#Corner-cases
"""
struct OutputMemory <: PALEOmodel.AbstractOutputWriter

# per-Domain model output (the collection constructor keys this by each domain's `name`)
domains::Dict{String, OutputMemoryDomain}
# optional user-supplied metadata; `save_netcdf` writes each entry as a NetCDF dataset attribute
user_data::Dict{String, UserDataTypes}
end

function OutputMemory()
return OutputMemory(Dict{String, OutputMemoryDomain}())

# Default `user_data` entries used by the `OutputMemory` constructors when no
# `user_data` keyword argument is supplied.
const default_user_data=Dict{String, UserDataTypes}(
"title"=>"PALEO (exo)Earth system model output",
"source"=>"PALEOmodel https://github.com/PALEOtoolkit/PALEOmodel.jl",
)

"""
    OutputMemory(; user_data=copy(default_user_data))

Create an `OutputMemory` with no Domains.

# Keywords
- `user_data`: Dict of user data to attach to the output. Defaults to a fresh copy of
  `default_user_data`, so that each instance owns its own Dict.
"""
function OutputMemory(; user_data=copy(default_user_data))
    # NB: the default must be a copy. `user_data` is stored by reference and is mutated
    # in place (eg by `load_netcdf!`, or by callers assigning `output.user_data[k] = v`),
    # so storing the module-level `default_user_data` itself would leak entries into the
    # global default and into every other default-constructed OutputMemory.
    return OutputMemory(Dict{String, OutputMemoryDomain}(), user_data)
end

"create from collection of OutputMemoryDomain"
function OutputMemory(output_memory_domains::Union{Vector, Tuple})
om = OutputMemory(Dict(om.name => om for om in output_memory_domains))
function OutputMemory(
    output_memory_domains::Union{Vector, Tuple};
    user_data=copy(default_user_data),
)
    # `user_data` defaults to a fresh copy of `default_user_data`: the Dict is stored by
    # reference and mutated in place (eg by `load_netcdf!`), so sharing the module-level
    # default would leak entries between OutputMemory instances.

    # Key each OutputMemoryDomain by its `name`.
    domains = Dict{String, OutputMemoryDomain}(
        omd.name => omd for omd in output_memory_domains
    )
    return OutputMemory(domains, user_data)
end

Expand Down Expand Up @@ -754,7 +767,7 @@ end

"compact form"
function Base.show(io::IO, output::OutputMemory)
print(io, "OutputMemory(domains=", keys(output.domains), ")")
print(io, "OutputMemory(domains=", keys(output.domains), ", user_data=", output.user_data, ")")
end


Expand Down Expand Up @@ -801,7 +814,6 @@ Save to `filename` in netcdf4 format (NB: filename must either have no extension
"""
function save_netcdf(
output::OutputMemory, filename;
additional_attributes::AbstractVector{<:Pair{<:AbstractString, <:Any}} = Pair{String, String}[],
check_ext::Bool=true,
add_coordinates::Bool=false,
)
Expand All @@ -819,11 +831,17 @@ function save_netcdf(
@info "saving to $filename ..."

NCDatasets.NCDataset(filename, "c") do nc_dataset
nc_dataset.attrib["title"] = "PALEO (exo)Earth system model output"
nc_dataset.attrib["source"] = "PALEOmodel https://github.com/PALEOtoolkit/PALEOmodel.jl"
nc_dataset.attrib["PALEO_netcdf_version"] = "0.1.0"
nc_dataset.attrib["PALEO_domains"] = join([k for (k, v) in output.domains], " ")

for (k, v) in output.user_data
if k in ("PALEO_netcdf_version", "PALEO_domains")
@warn "ignoring reserved user_data key $k"
continue
end
nc_dataset.attrib[k] = v
end

for (domname, dom) in output.domains

dsg = NCDatasets.defGroup(nc_dataset, dom.name; attrib=[])
Expand Down Expand Up @@ -935,6 +953,19 @@ function load_netcdf!(output::OutputMemory, filename; check_ext=true)
!ismissing(paleo_netcdf_version) || error("not a PALEO netcdf output file ? (key PALEO_netcdf_version not present)")
paleo_netcdf_version == "0.1.0" || error("unsupported PALEO_netcdf_version $paleo_netcdf_version")

for (k, v) in nc_dataset.attrib
if k in ("PALEO_netcdf_version", "PALEO_domains")
continue
end
# workaround for https://github.com/Alexander-Barth/NCDatasets.jl/issues/258
if v isa Int32
v = Int64(v)
elseif v isa Vector{Int32}
v = Int64.(v)
end
output.user_data[k] = v
end

for (domainname, dsg) in nc_dataset.group
coords_record = dsg.attrib["coords_record"]
nrecs = dsg.dim[coords_record]
Expand Down
25 changes: 24 additions & 1 deletion test/runoutputwritertests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,37 @@ end
all_values.global.O .= [4e19]
PALEOmodel.OutputWriters.add_record!(output, model, modeldata, 0.0)

tmpfile = tempname(; cleanup=true)
tmpfile = tempname(; cleanup=true)

output.user_data["testString"] = "hello"
output.user_data["testInt64"] = 42
output.user_data["testFloat64"] = 42.0
output.user_data["testVecString"] = ["hello", "world"]
output.user_data["testVecInt64"] = [42, 43]
output.user_data["testVecFloat64_1"] = [42.0]
output.user_data["testVecFloat64"] = [42.0, 43.0]

PALEOmodel.OutputWriters.save_netcdf(output, tmpfile)

load_output = PALEOmodel.OutputWriters.load_netcdf!(PALEOmodel.OutputWriters.OutputMemory(), tmpfile)

O_array = PALEOmodel.get_array(load_output, "global.O")
@test O_array.values == [2e19, 4e19]

# Check that the `user_data` entry `k` of `load_output` (read back from the saved
# netcdf file) has both the same type and the same value as the expected value `v`.
function test_user_key_type_value(k, v)
lv = load_output.user_data[k]
@test typeof(lv) == typeof(v)
@test lv == v
end

test_user_key_type_value("testString", "hello")
test_user_key_type_value("testInt64", 42)
test_user_key_type_value("testFloat64", 42.0)
test_user_key_type_value("testVecString", ["hello", "world"])
test_user_key_type_value("testVecInt64", [42, 43])
@test_broken load_output.user_data["testVecFloat64_1"] == [42.0] # returned as a scalar
test_user_key_type_value("testVecFloat64", [42.0, 43.0])

end

@testset "DataFrameCreate" begin
Expand Down
Loading