Skip to content

Commit

Permalink
Merge pull request #80 from fjebaker/fergus/fix-79
Browse files Browse the repository at this point in the history
Fix `mul!` problems and bump Julia CI version
  • Loading branch information
fjebaker authored Feb 22, 2024
2 parents 7198b88 + cf2918f commit 0dc8ea7
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: ['1.9']
julia-version: ['1.10']
os: [ubuntu-latest]

steps:
Expand Down
2 changes: 1 addition & 1 deletion src/abstract-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ function make_diff_parameter_cache(

# embed current parameter values inside of the dual cache
# else all frozens will be zero
get_tmp(diffcache, ForwardDiff.Dual(1.0)) .= vals
get_tmp(diffcache, ForwardDiff.Dual(one(eltype(vals)))) .= vals

ParameterCache(free_mask, diffcache)
end
27 changes: 22 additions & 5 deletions src/fitting/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ _invoke_and_transform!(cache::AbstractFittingCache, domain, params) =
error("Not implemented for $(typeof(cache))")

# one of these for each (mulit)model / data pair
struct SpectralCache{M,O,T,P,TransformerType} <: AbstractFittingCache
struct SpectralCache{M,O,T,K,P,TransformerType} <: AbstractFittingCache
model::M
model_output::O
calculated_objective::T
output_cache::K
parameter_cache::P
transfomer!!::TransformerType
function SpectralCache(
Expand All @@ -19,15 +20,29 @@ struct SpectralCache{M,O,T,P,TransformerType} <: AbstractFittingCache
param_diff_cache_size = nothing,
) where {M,XfmT}
model_output = DiffCache(construct_objective_cache(layout, model, domain))
calc_obj = similar(objective)
calc_obj .= 0
# fix for https://github.com/fjebaker/SpectralFitting.jl/issues/79
# output must be a vector but can only give matrix to `mul!`, so we need to
# unfortunately duplicate the array to ensure we have both types
calc_obj = zeros(eltype(objective), (length(objective), 1))
calc_obj_cache = DiffCache(calc_obj)
# vector chache
output = similar(objective)
output .= 0
output_cache = DiffCache(output)
param_cache =
make_diff_parameter_cache(model; param_diff_cache_size = param_diff_cache_size)
new{M,typeof(model_output),typeof(calc_obj_cache),typeof(param_cache),XfmT}(
new{
M,
typeof(model_output),
typeof(calc_obj_cache),
typeof(output_cache),
typeof(param_cache),
XfmT,
}(
model,
model_output,
calc_obj_cache,
output_cache,
param_cache,
transformer,
)
Expand Down Expand Up @@ -60,7 +75,9 @@ function _invoke_and_transform!(cache::SpectralCache, domain, params)
output = invokemodel!(model_output, domain, cache.model, parameters)
cache.transfomer!!(calc_obj, domain, output)

calc_obj
output_vector = get_tmp(cache.output_cache, params)
output_vector .= calc_obj
output_vector
end

struct FittingConfig{ImplType,CacheType,P,D,O}
Expand Down

0 comments on commit 0dc8ea7

Please sign in to comment.