Skip to content

Commit

Permalink
Merge pull request #71 from phajy/main
Browse files Browse the repository at this point in the history
Add data_mask to InjectiveData
  • Loading branch information
fjebaker authored Nov 24, 2023
2 parents 1941d43 + 56f0ba4 commit f851af6
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
24 changes: 17 additions & 7 deletions src/datasets/injectivedata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ struct InjectiveData{V} <: AbstractDataset
domain_variance::Union{Nothing,V}
codomain_variance::Union{Nothing,V}
name::String
data_mask::BitVector
end

function InjectiveData(
Expand All @@ -13,8 +14,9 @@ function InjectiveData(
domain_variance = nothing,
codomain_variance = nothing,
name = "[no-name]",
data_mask = BitVector(fill(true, size(codomain)))
)
InjectiveData(domain, codomain, domain_variance, codomain_variance, name)
InjectiveData(domain, codomain, domain_variance, codomain_variance, name, data_mask)
end

supports_contiguosly_binned(::Type{<:InjectiveData}) = true
Expand All @@ -27,24 +29,32 @@ function make_model_domain(::ContiguouslyBinned, dataset::InjectiveData)
push!(domain, domain[end] + Δ)
domain
end
make_objective(::ContiguouslyBinned, dataset::InjectiveData) = dataset.codomain
make_objective(::ContiguouslyBinned, dataset::InjectiveData) = dataset.codomain[dataset.data_mask]

make_model_domain(::OneToOne, dataset::InjectiveData) = dataset.domain
make_objective(::OneToOne, dataset::InjectiveData) = dataset.codomain
make_objective(::OneToOne, dataset::InjectiveData) = dataset.codomain[dataset.data_mask]

function make_objective_variance(
::AbstractDataLayout,
dataset::InjectiveData{V},
)::V where {V}
if !isnothing(dataset.domain_variance)
dataset.codomain_variance
if !isnothing(dataset.codomain_variance)
dataset.codomain_variance[dataset.data_mask]
else
# todo: i dunno just something
1e-8 .* dataset.codomain
1e-8 .* dataset.codomain[dataset.data_mask]
end
end

objective_transformer(::AbstractDataLayout, dataset::InjectiveData) = _DEFAULT_TRANSFORMER()
function objective_transformer(::AbstractDataLayout, dataset::InjectiveData)
function _transformer!!(domain, objective)
@views objective[dataset.data_mask]
end
function _transformer!!(output, domain, objective)
@. output = objective[dataset.data_mask]
end
_transformer!!
end

make_label(dataset::InjectiveData) = dataset.name

Expand Down
17 changes: 17 additions & 0 deletions test/fitting/test-fit-simple-dataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,20 @@ prob = FittingProblem(model => data)

result = fit(prob, LevenbergMarquadt())
@test result.u[2] 6.1 atol = 0.1

# fitting a contiguously binned dataset with some masked bins
x = 10 .^ collect(range(-1, 2, 10))
y = x .^ -2.0
y_err = 0.1 .* y
# introduce some bogus data points to ignore
y[2:5] .= 2.0
data = InjectiveData(x, y, codomain_variance=y_err)
# mask out the bogus data points
data.data_mask[2:5] .= false
model = XS_PowerLaw(K=FitParam(1.0E-5), a=FitParam(2.0))
prob = FittingProblem(model => data)
@test SpectralFitting.common_support(model, data) isa SpectralFitting.ContiguouslyBinned
result = fit(prob, LevenbergMarquadt())
@test result.u[1] 2.55 atol = 0.01
@test result.u[2] 3.0 atol = 0.05
# note best fit photon index, u[2] should be 3 not 2 becuase y contains bin integrated values not the density

0 comments on commit f851af6

Please sign in to comment.