Skip to content

Commit

Permalink
Merge pull request #3351 from CliMA/gb/automatic_restart
Browse files Browse the repository at this point in the history
Reorganize restart tests
  • Loading branch information
Sbozzolo authored Oct 3, 2024
2 parents df2b7ee + 17c0c81 commit d4cb583
Showing 1 changed file with 143 additions and 102 deletions.
245 changes: 143 additions & 102 deletions test/restart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import ClimaCore.Spaces: AbstractSpace
import ClimaComms
pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends
import Logging
import NCDatasets
using Test

import Random
Expand All @@ -35,10 +36,6 @@ ClimaComms.init(comms_ctx)
# different.
#
# For this reason, we don't use Test but just print to screen the differences.
# However, we still have to return an exit code with failure in case of the
# comparison fails. So, we have this global `SUCCESS` bool that is updated by
# the result of tests.
const SUCCESS::Base.RefValue{Bool} = Ref(true)

"""
_error(arr1::AbstractArray, arr2::AbstractArray; ABS_TOL = 100eps(eltype(arr1)))
Expand Down Expand Up @@ -130,6 +127,17 @@ function _compare(v1::T, v2::T; name, ignore) where {T <: Number}
return print_maybe(v1 === v2, "$name differs: $v1 vs $v2")
end

# We ignore NCDatasets. They contain a lot of state-ful information
function _compare(
pass,
v1::T,
v2::T;
name,
ignore,
) where {T <: NCDatasets.NCDataset}
return pass
end

function _compare(
v1::T,
v2::T;
Expand All @@ -143,6 +151,17 @@ function _compare(pass, v1::T, v2::T; name, ignore) where {T <: AbstractData}
return pass && _compare(parent(v1), parent(v2); name, ignore)
end

# Handle views
function _compare(
pass,
v1::SubArray{FT},
v2::SubArray{FT};
name,
ignore,
) where {FT <: AbstractFloat}
return pass && _compare(collect(v1), collect(v2); name, ignore)
end

function _compare(
v1::AbstractArray{FT},
v2::AbstractArray{FT};
Expand All @@ -167,6 +186,118 @@ end
# Disable all the @info statements that are produced when creating a simulation
Logging.disable_logging(Logging.Info)


"""
test_restart(test_dict; job_id, comms_ctx, more_ignore = Symbol[])
Test if the restarts are consistent for a simulation defined by the `test_dict` config.
`more_ignore` is a Vector of Symbols that identifies config-specific keys that
have to be ignored when reading a simulation.
"""
function test_restart(test_dict; job_id, comms_ctx, more_ignore = Symbol[])
println("job_id = $(job_id)")

local_success = Ref(true)

config = CA.AtmosConfig(test_dict; job_id, comms_ctx)

simulation = CA.get_simulation(config)
CA.solve_atmos!(simulation)

# Check re-importing the same state
restart_dir = simulation.output_dir
@test isfile(joinpath(restart_dir), "day0.3.hdf5")

# Reset random seed for RRTMGP
Random.seed!(1234)

ClimaComms.iamroot(comms_ctx) && println(" just reading data")
config_should_be_same = CA.AtmosConfig(
merge(test_dict, Dict("detect_restart_file" => true));
job_id,
comms_ctx,
)

simulation_restarted = CA.get_simulation(config_should_be_same)

local_success[] &= compare(
simulation.integrator.u,
simulation_restarted.integrator.u;
name = "integrator.u",
)
local_success[] &= compare(
axes(simulation.integrator.u.c),
axes(simulation_restarted.integrator.u.c);
name = "space",
)
local_success[] &= compare(
simulation.integrator.p,
simulation_restarted.integrator.p;
name = "integrator.p",
ignore = Set([
:ghost_buffer,
:hyperdiffusion_ghost_buffer,
:scratch,
:output_dir,
:ghost_buffer,
# Computed in tendencies (which are not computed in this case)
:hyperdiff,
:precipitation,
# rc is some CUDA/CuArray internal object that we don't care about
:rc,
# DataHandlers contains caches, so they are stateful
:data_handler,
# Config-specific
more_ignore...,
]),
)

# Check re-importing from previous state and advancing one step
ClimaComms.iamroot(comms_ctx) && println(" reading and simulating")
# Reset random seed for RRTMGP
Random.seed!(1234)

restart_file = joinpath(simulation.output_dir, "day0.2.hdf5")
@test isfile(joinpath(restart_dir), "day0.2.hdf5")
# Restart from specific file
config2 = CA.AtmosConfig(
merge(test_dict, Dict("restart_file" => restart_file));
job_id,
comms_ctx,
)

simulation_restarted2 = CA.get_simulation(config2)
CA.fill_with_nans!(simulation_restarted2.integrator.p)

CA.solve_atmos!(simulation_restarted2)
local_success[] &= compare(
simulation.integrator.u,
simulation_restarted2.integrator.u;
name = "integrator.u",
)
local_success[] &= compare(
simulation.integrator.p,
simulation_restarted2.integrator.p;
name = "integrator.p",
ignore = Set([
:scratch,
:output_dir,
:ghost_buffer,
:hyperdiffusion_ghost_buffer,
:data_handler,
:rc,
]),
)

return local_success[]
end

# Let's prepare the test_dicts. TESTING is a Vector of NamedTuples, each element
# has a test_dict, a job_id, and a more_ignore

TESTING = Any[]

if comms_ctx isa ClimaComms.SingletonCommsContext
configurations = ["sphere", "box", "column"]
else
Expand All @@ -180,13 +311,13 @@ for configuration in configurations
topography = "Earth"
turbconv_models = [nothing, "diagnostic_edmfx"]
# turbconv_models = ["prognostic_edmfx"]
radiations = [nothing]
radiations = [nothing, "gray"]
else
moistures = ["equil"]
precips = ["1M"]
topography = "NoWarp"
turbconv_models = ["diagnostic_edmfx"]
radiations = [nothing]
radiations = [nothing, "gray"]
end

for turbconv_mode in turbconv_models
Expand All @@ -200,9 +331,6 @@ for configuration in configurations
end
end

println(
"config = $configuration $moisture $precip $topography $radiation",
)
# The `enable_bubble` case is broken for ClimaCore < 0.14.6, so we
# hard-code this to be always false for those versions
bubble = pkgversion(ClimaCore) > v"0.14.5"
Expand All @@ -211,9 +339,8 @@ for configuration in configurations
output_loc =
ClimaComms.iamroot(comms_ctx) ? mktempdir(pwd()) : ""
output_loc = ClimaComms.bcast(comms_ctx, output_loc)
ClimaComms.barrier(comms_ctx)

job_id = "restart"
job_id = "$(configuration)_$(moisture)_$(precip)_$(topography)_$(radiation)"
test_dict = Dict(
"test_dycore_consistency" => true, # We will add NaNs to the cache, just to make sure
"check_nan_every" => 3,
Expand All @@ -240,103 +367,17 @@ for configuration in configurations
)
more_ignore = Symbol[]

config = CA.AtmosConfig(test_dict; job_id, comms_ctx)

simulation = CA.get_simulation(config)
CA.solve_atmos!(simulation)

# Check re-importing the same state
restart_dir = simulation.output_dir
@test isfile(joinpath(restart_dir), "day0.3.hdf5")

# Reset random seed for RRTMGP
Random.seed!(1234)

println(" just reading data")
if turbconv_mode == "prognostic_edmf"
more_ignore = [:ᶠnh_pressure₃ʲs]
end

config_should_be_same = CA.AtmosConfig(
merge(test_dict, Dict("detect_restart_file" => true));
job_id,
comms_ctx,
)

simulation_restarted =
CA.get_simulation(config_should_be_same)

SUCCESS[] &= compare(
simulation.integrator.u,
simulation_restarted.integrator.u;
name = "integrator.u",
)
SUCCESS[] &= compare(
axes(simulation.integrator.u.c),
axes(simulation_restarted.integrator.u.c);
name = "space",
)
SUCCESS[] &= compare(
simulation.integrator.p,
simulation_restarted.integrator.p;
name = "integrator.p",
ignore = Set([
:ghost_buffer,
:hyperdiffusion_ghost_buffer,
:scratch,
:output_dir,
:ghost_buffer,
# Computed in tendencies (which are not computed in this case)
:hyperdiff,
:precipitation,
# rc is some CUDA/CuArray internal object that we don't care about
:rc,
# Config-specific
more_ignore...,
]),
)

# Check re-importing from previous state and advancing one step
println(" reading and simulating")
# Reset random seed for RRTMGP
Random.seed!(1234)

restart_file =
joinpath(simulation.output_dir, "day0.2.hdf5")
@test isfile(joinpath(restart_dir), "day0.2.hdf5")
# Restart from specific file
config2 = CA.AtmosConfig(
merge(test_dict, Dict("restart_file" => restart_file));
job_id,
comms_ctx,
)

simulation_restarted2 = CA.get_simulation(config2)
CA.fill_with_nans!(simulation_restarted2.integrator.p)

CA.solve_atmos!(simulation_restarted2)
SUCCESS[] &= compare(
simulation.integrator.u,
simulation_restarted2.integrator.u;
name = "integrator.u",
)
SUCCESS[] &= compare(
simulation.integrator.p,
simulation_restarted2.integrator.p;
name = "integrator.p",
ignore = Set([
:scratch,
:output_dir,
:ghost_buffer,
:hyperdiffusion_ghost_buffer,
:rc,
]),
)
push!(TESTING, (; test_dict, job_id, more_ignore))
end
end
end
end
end

# Ensure that we have the correct exit code
@test SUCCESS[]
@test all(
@time test_restart(t.test_dict; comms_ctx, t.job_id, t.more_ignore) for
t in TESTING
)

0 comments on commit d4cb583

Please sign in to comment.