diff --git a/src/datasets/injectivedata.jl b/src/datasets/injectivedata.jl index ddae0685..758866f3 100644 --- a/src/datasets/injectivedata.jl +++ b/src/datasets/injectivedata.jl @@ -32,7 +32,7 @@ end make_objective(::ContiguouslyBinned, dataset::InjectiveData) = dataset.codomain[dataset.data_mask] make_model_domain(::OneToOne, dataset::InjectiveData) = dataset.domain[dataset.data_mask] -make_objective(::OneToOne, dataset::InjectiveData) = dataset.codomain[data_mask] +make_objective(::OneToOne, dataset::InjectiveData) = dataset.codomain[dataset.data_mask] function make_objective_variance( ::AbstractDataLayout, @@ -46,7 +46,15 @@ function make_objective_variance( end end -objective_transformer(::AbstractDataLayout, dataset::InjectiveData) = _DEFAULT_TRANSFORMER() +function objective_transformer(::AbstractDataLayout, dataset::InjectiveData) + function _transformer!!(domain, objective) + @views objective[dataset.data_mask] + end + function _transformer!!(output, domain, objective) + @. output = objective[dataset.data_mask] + end + _transformer!! +end make_label(dataset::InjectiveData) = dataset.name