From 40326fe8724d88488b3406ad0e1b48472d030aeb Mon Sep 17 00:00:00 2001 From: Stuart Daines Date: Tue, 21 May 2024 17:13:23 +0100 Subject: [PATCH] OutputMemory add user_data Dict PALEOmodel.OutputWriters.OutputMemory can now store key-value pairs in a 'user_dict' field, which is written and read from netcdf output. This is a Julia Dict, and can be used to store arbitrary metadata, eg to label file output to reconstruct a model grid. It can either be supplied when the OutputMemory is created, or modified later. NB: available types are restricted to those that are compatible with NetCDF attribute types, ie Float64, Int64, String, Vector{Float64}, Vector{Int64}, Vector{String} See OutputMemory docstring for details. --- src/OutputWriters.jl | 55 ++++++++++++++++++++++++++++-------- test/runoutputwritertests.jl | 25 +++++++++++++++- 2 files changed, 67 insertions(+), 13 deletions(-) diff --git a/src/OutputWriters.jl b/src/OutputWriters.jl index e97644b..e5f8df9 100644 --- a/src/OutputWriters.jl +++ b/src/OutputWriters.jl @@ -480,30 +480,43 @@ end # OutputMemory ########################################## +const UserDataTypes = Union{Float64, Int64, String, Vector{Float64}, Vector{Int64}, Vector{String}} + """ - OutputMemory + OutputMemory(; user_data=Dict{String, UserDataTypes}()) In-memory container for model output, organized by model Domains. Implements the [`PALEOmodel.AbstractOutputWriter`](@ref) interface, with additional methods [`save_netcdf`](@ref) and [`load_netcdf!`](@ref) to save and load data. - # Implementation -Field `domains::Dict{String, OutputMemoryDomain}` contains per-Domain model output. +- Field `domains::Dict{String, OutputMemoryDomain}` contains per-Domain model output. +- Field `user_data::Dict{String, UserDataTypes}` contains optional user data + NB: + - available types are restricted to those that are compatible with NetCDF attribute types, + ie Float64, Int64, String, Vector{Float64}, Vector{Int64}, Vector{String} + - Vectors with a single element are read back from netcdf as scalars, + see https://alexander-barth.github.io/NCDatasets.jl/dev/issues/#Corner-cases """ struct OutputMemory <: PALEOmodel.AbstractOutputWriter - domains::Dict{String, OutputMemoryDomain} + user_data::Dict{String, UserDataTypes} end -function OutputMemory() - return OutputMemory(Dict{String, OutputMemoryDomain}()) + +const default_user_data=Dict{String, UserDataTypes}( + "title"=>"PALEO (exo)Earth system model output", + "source"=>"PALEOmodel https://github.com/PALEOtoolkit/PALEOmodel.jl", +) + +function OutputMemory(; user_data=default_user_data) + return OutputMemory(Dict{String, OutputMemoryDomain}(), user_data) end "create from collection of OutputMemoryDomain" -function OutputMemory(output_memory_domains::Union{Vector, Tuple}) - om = OutputMemory(Dict(om.name => om for om in output_memory_domains)) +function OutputMemory(output_memory_domains::Union{Vector, Tuple}; user_data=default_user_data) + om = OutputMemory(Dict(om.name => om for om in output_memory_domains), user_data) return om end @@ -754,7 +767,7 @@ end "compact form" function Base.show(io::IO, output::OutputMemory) - print(io, "OutputMemory(domains=", keys(output.domains), ")") + print(io, "OutputMemory(domains=", keys(output.domains), ", user_data=", output.user_data, ")") end @@ -801,7 +814,6 @@ Save to `filename` in netcdf4 format (NB: filename must either have no extension """ function save_netcdf( output::OutputMemory, filename; - additional_attributes::AbstractVector{<:Pair{<:AbstractString, <:Any}} = Pair{String, String}[], check_ext::Bool=true, add_coordinates::Bool=false, ) @@ -819,11 +831,17 @@ function save_netcdf( @info "saving to $filename ..." NCDatasets.NCDataset(filename, "c") do nc_dataset - nc_dataset.attrib["title"] = "PALEO (exo)Earth system model output" - nc_dataset.attrib["source"] = "PALEOmodel https://github.com/PALEOtoolkit/PALEOmodel.jl" nc_dataset.attrib["PALEO_netcdf_version"] = "0.1.0" nc_dataset.attrib["PALEO_domains"] = join([k for (k, v) in output.domains], " ") + for (k, v) in output.user_data + if k in ("PALEO_netcdf_version", "PALEO_domains") + @warn "ignoring reserved user_data key $k" + continue + end + nc_dataset.attrib[k] = v + end + for (domname, dom) in output.domains dsg = NCDatasets.defGroup(nc_dataset, dom.name; attrib=[]) @@ -935,6 +953,19 @@ function load_netcdf!(output::OutputMemory, filename; check_ext=true) !ismissing(paleo_netcdf_version) || error("not a PALEO netcdf output file ? (key PALEO_netcdf_version not present)") paleo_netcdf_version == "0.1.0" || error("unsupported PALEO_netcdf_version $paleo_netcdf_version") + for (k, v) in nc_dataset.attrib + if k in ("PALEO_netcdf_version", "PALEO_domains") + continue + end + # workaround for https://github.com/Alexander-Barth/NCDatasets.jl/issues/258 + if v isa Int32 + v = Int64(v) + elseif v isa Vector{Int32} + v = Int64.(v) + end + output.user_data[k] = v + end + for (domainname, dsg) in nc_dataset.group coords_record = dsg.attrib["coords_record"] nrecs = dsg.dim[coords_record] diff --git a/test/runoutputwritertests.jl b/test/runoutputwritertests.jl index 31e252a..3bc7886 100644 --- a/test/runoutputwritertests.jl +++ b/test/runoutputwritertests.jl @@ -90,7 +90,16 @@ end all_values.global.O .= [4e19] PALEOmodel.OutputWriters.add_record!(output, model, modeldata, 0.0) - tmpfile = tempname(; cleanup=true) + tmpfile = tempname(; cleanup=true) + + output.user_data["testString"] = "hello" + output.user_data["testInt64"] = 42 + output.user_data["testFloat64"] = 42.0 + output.user_data["testVecString"] = ["hello", "world"] + output.user_data["testVecInt64"] = [42, 43] + output.user_data["testVecFloat64_1"] = [42.0] + output.user_data["testVecFloat64"] = [42.0, 43.0] + PALEOmodel.OutputWriters.save_netcdf(output, tmpfile) load_output = PALEOmodel.OutputWriters.load_netcdf!(PALEOmodel.OutputWriters.OutputMemory(), tmpfile) @@ -98,6 +107,20 @@ end O_array = PALEOmodel.get_array(load_output, "global.O") @test O_array.values == [2e19, 4e19] + function test_user_key_type_value(k, v) + lv = load_output.user_data[k] + @test typeof(lv) == typeof(v) + @test lv == v + end + + test_user_key_type_value("testString", "hello") + test_user_key_type_value("testInt64", 42) + test_user_key_type_value("testFloat64", 42.0) + test_user_key_type_value("testVecString", ["hello", "world"]) + test_user_key_type_value("testVecInt64", [42, 43]) + @test_broken load_output.user_data["testVecFloat64_1"] == [42.0] # returned as a scalar + test_user_key_type_value("testVecFloat64", [42.0, 43.0]) + end @testset "DataFrameCreate" begin