Return NaN for negative ModeResult variance estimates #2471

Open: wants to merge 1 commit into master

frankier

Here's a modified example that yields negative variance estimates for some parameters (here, coefficients_versicolor[3]):

using Turing
using RDatasets
using StatsPlots
using MLDataUtils: shuffleobs, splitobs, rescale!
using NNlib: softmax
using FillArrays
using LinearAlgebra
using Random
Random.seed!(0);

using Optim
using StatsBase

data = RDatasets.dataset("datasets", "iris");
data[rand(1:size(data, 1), 20), :]
species = ["setosa", "versicolor", "virginica"]
data[!, :Species_index] = indexin(data[!, :Species], species)
data[rand(1:size(data, 1), 20), [:Species, :Species_index]]
trainset, testset = splitobs(shuffleobs(data), 0.5)
features = [:SepalLength, :SepalWidth, :PetalLength, :PetalWidth]
target = :Species_index

train_features = Matrix(trainset[!, features])
test_features = Matrix(testset[!, features])
train_target = trainset[!, target]
test_target = testset[!, target]

μ, σ = rescale!(train_features; obsdim=1)
rescale!(test_features, μ, σ; obsdim=1);

@model function logistic_regression(x, y, σ)
    n = size(x, 1)
    length(y) == n ||
        throw(DimensionMismatch("number of observations in `x` and `y` is not equal"))

    # Priors of intercepts and coefficients.
    intercept_versicolor ~ Normal(0, σ)
    intercept_virginica ~ Normal(0, σ)
    coefficients_versicolor ~ MvNormal(Zeros(4), σ^2 * I)
    coefficients_virginica ~ MvNormal(Zeros(4), σ^2 * I)

    # Compute the likelihood of the observations.
    values_versicolor = intercept_versicolor .+ x * coefficients_versicolor
    values_virginica = intercept_virginica .+ x * coefficients_virginica
    for i in 1:n
        # the 0 corresponds to the base category `setosa`
        v = softmax([0, values_versicolor[i], values_virginica[i]])
        y[i] ~ Categorical(v)
    end
end;

model = logistic_regression(train_features, train_target, 1)
mle_estimate = Optim.optimize(model, MLE())
println(coeftable(mle_estimate))

Without this PR, this throws a DomainError in coeftable when computing the stderr of coefficients_versicolor[3].
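
For illustration, here is a minimal sketch of the failure and of the NaN-propagating alternative. It assumes the standard errors are the square roots of the diagonal of the estimated covariance matrix; the matrix values and the helper `nan_sqrt` are made up for this example and are not part of Turing.jl or of this PR's diff.

```julia
using LinearAlgebra

# Hypothetical helper (not a Turing.jl function): like sqrt, but returns NaN
# for negative input instead of throwing a DomainError.
nan_sqrt(x::Real) = x < 0 ? oftype(float(x), NaN) : sqrt(x)

# Made-up covariance estimate with one negative diagonal entry, standing in
# for what an unreliable optimisation result might produce.
vcov_example = [0.04 0.0; 0.0 -0.01]

# Plain sqrt throws a DomainError on the negative variance.
plain_sqrt_failed = try
    sqrt.(diag(vcov_example))
    false
catch err
    err isa DomainError
end

# The NaN-propagating version still returns a usable vector.
stderrs = nan_sqrt.(diag(vcov_example))
```

With the second approach the well-behaved entries keep their values and only the problematic parameter shows up as NaN.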

@frankier
Author

This is related to #2048

I don't fully agree with the conclusion that there is nothing to fix in Turing.jl here.

Propagating a NaN makes it easier to inspect the coeftable and see that something has gone wrong with the optimisation process.

codecov bot commented Jan 20, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 85.02%. Comparing base (24d5556) to head (d0050f8).

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #2471   +/-   ##
=======================================
  Coverage   85.01%   85.02%           
=======================================
  Files          21       21           
  Lines        1582     1583    +1     
=======================================
+ Hits         1345     1346    +1     
  Misses        237      237           

@coveralls

Pull Request Test Coverage Report for Build 12867285375

Details

  • 1 of 1 (100.0%) changed or added relevant line in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+0.02%) to 76.474%

Totals Coverage Status
Change from base Build 12786420948: 0.02%
Covered Lines: 1206
Relevant Lines: 1577

Member

@yebai left a comment

Thanks, @frankier. It looks like a nice improvement!

@mhauru
Member

mhauru commented Jan 21, 2025

Thanks @frankier, I agree that the current situation where coeftable fails with DomainError is not optimal, and this is an improvement. I wonder if we could be even more explicit though, and in our method for coeftable, catch the DomainError, print out a warning explaining that the solution seems to have negative variance and thus you should be very suspicious of your result, and then return a table with stderr as NaN. @frankier, as a user, do you think that would be helpful? This would also save us introducing a new dependency that we only use on a single line.
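
A rough sketch of what this suggestion could look like, written against a plain covariance matrix rather than a ModeResult. The function name `safe_stderror` and the warning text are assumptions made for this example; the actual coeftable method in Turing.jl may be structured differently.

```julia
using LinearAlgebra

# Hypothetical helper, not Turing.jl's implementation: compute standard errors
# from a covariance matrix, falling back to NaN where the variance is negative.
function safe_stderror(vcov::AbstractMatrix{<:Real})
    variances = diag(vcov)
    try
        return sqrt.(variances)
    catch err
        # Only handle the negative-variance case; anything else is unexpected.
        err isa DomainError || rethrow()
        @warn "Negative variance estimate encountered; the optimisation result is suspect. Reporting NaN for the affected standard errors."
        return [v < 0 ? NaN : sqrt(v) for v in variances]
    end
end

se_ok = safe_stderror([4.0 0.0; 0.0 9.0])       # no warning
se_bad = safe_stderror([0.04 0.0; 0.0 -0.01])   # warns, second entry is NaN
```

This needs no extra dependency, since the fallback is an ordinary comparison rather than a NaN-aware sqrt from another package.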

@frankier
Author

Yes, I think on balance that would be better.

I think there is also the possibility of getting a SingularException in inv, which I guess also indicates model identifiability (and thus optimization) problems. So I guess it's better to catch these, aggregate them and report how it failed alongside the table.
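
As a sketch of the aggregation idea, assuming the covariance is obtained by inverting an information matrix with `inv`: the function below collects a description of each failure and returns it alongside the standard errors. The name `stderror_with_diagnostics` and the message strings are hypothetical, chosen only for this example.

```julia
using LinearAlgebra

# Hypothetical helper: returns (standard errors, list of problems found).
function stderror_with_diagnostics(info::AbstractMatrix{<:Real})
    problems = String[]
    n = size(info, 1)

    # Step 1: invert the information matrix, recording a singular matrix
    # instead of letting the exception escape.
    covariance = try
        inv(info)
    catch err
        err isa SingularException || rethrow()
        push!(problems, "information matrix is singular")
        return fill(NaN, n), problems
    end

    # Step 2: take square roots, recording any negative variances.
    variances = diag(covariance)
    negative = findall(<(0), variances)
    if !isempty(negative)
        push!(problems, "negative variance estimate at indices $(negative)")
    end
    stderrs = [v < 0 ? NaN : sqrt(v) for v in variances]
    return stderrs, problems
end

# Made-up inputs: a singular matrix and an indefinite one.
se_singular, problems_singular = stderror_with_diagnostics([1.0 2.0; 2.0 4.0])
se_negative, problems_negative = stderror_with_diagnostics([2.0 0.0; 0.0 -4.0])
```

A caller could then print the table as usual and emit one warning per entry in the returned list.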

I'll update this PR to work this way soon.
