Skip to content

Commit

Permalink
Add UnivariateNormalDistribution (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored May 27, 2024
1 parent 2788981 commit c3ef77a
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 0 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ authors = ["odow <[email protected]>"]
version = "0.1.0"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"

[weakdeps]
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
Expand All @@ -13,6 +15,8 @@ GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
OmeletteGLMExt = "GLM"

[compat]
Distributions = "0.25"
GLM = "1.9"
JuMP = "1"
MathOptInterface = "1"
julia = "1.9"
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,17 @@ X, Y = rand(10, 2), rand(Bool, 10)
model_glm = GLM.glm(X, Y, GLM.Bernoulli())
predictor = Omelette.LogisticRegression(model_glm)
```

## Other constraints

### UnivariateNormalDistribution
```julia
using JuMP, Omelette
model = Model();
@variable(model, 0 <= x <= 5);
f = Omelette.UnivariateNormalDistribution(;
mean = x -> only(x),
covariance = x -> 1.0,
);
Omelette.add_constraint(model, f, [x], MOI.Interval(0.5, Inf), 0.95);
```
2 changes: 2 additions & 0 deletions src/Omelette.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

module Omelette

import Distributions
import JuMP
import MathOptInterface as MOI

"""
abstract type AbstractPredictor end
Expand Down
105 changes: 105 additions & 0 deletions src/models/UnivariateNormalDistribution.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright (c) 2024: Oscar Dowson and contributors
#
# Use of this source code is governed by an MIT-style license that can be found
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.

"""
UnivariateNormalDistribution(; mean::Function, std_dev::Function)
A univariate Normal distribution, represented by the functions `mean(x::Vector)`
and `std_dev(x::Vector)`.
## Example
```jldoctest
julia> import Omelette
julia> Omelette.UnivariateNormalDistribution(;
mean = x -> only(x),
std_dev = x -> 1.0,
)
UnivariateNormalDistribution(mean, std_dev)
```
"""
struct UnivariateNormalDistribution{F,G}
mean::F
std_dev::G

function UnivariateNormalDistribution(; mean::Function, std_dev::Function)
return new{typeof(mean),typeof(std_dev)}(mean, std_dev)
end
end

function Base.show(io::IO, x::UnivariateNormalDistribution)
return print(io, "UnivariateNormalDistribution(mean, std_dev)")
end

"""
add_constraint(
model::JuMP.Model,
f::UnivariateNormalDistribution,
set::MOI.Interval,
β::Float64,
)
Add the constraint:
```math
\\mathbb{P}(f(x) \\in [l, u]) \\ge β
```
where \$f(x)~\\mathcal{N}(\\mu, \\sigma)\$ is a normally distributed random
variable given by the `UnivariateNormalDistribution`.
If both `l` and `u` are finite, then the probability mass is equally
distributed, so that each side of the constraint holds with `(1 + β) / 2`.
## Examples
```jldoctest
julia> using JuMP, Omelette
julia> model = Model();
julia> @variable(model, 0 <= x <= 5);
julia> f = Omelette.UnivariateNormalDistribution(;
mean = x -> only(x),
std_dev = x -> 1.0,
);
julia> Omelette.add_constraint(model, f, [x], MOI.Interval(0.5, Inf), 0.95);
julia> print(model)
Feasibility
Subject to
x ≥ 2.1448536269514715
x ≥ 0
x ≤ 5
```
"""
function add_constraint(
model::JuMP.Model,
N::UnivariateNormalDistribution,
x::Vector{JuMP.VariableRef},
set::MOI.Interval,
β::Float64,
)
@assert β >= 0.5
if isfinite(set.upper) && isfinite(set.lower)
# Dual-sided chance constraint. In this case, we want β to be the joint
# probabiltiy, so take an equal probabiltiy each side.
β = (1 + β) / 2
end
if isfinite(set.upper)
# P(f(x) ≤ u) ≥ β
# => μ(x) + Φ⁻¹(β) * σ <= u
λ = Distributions.invlogcdf(Distributions.Normal(0, 1), log(β))
JuMP.@constraint(model, N.mean(x) + λ * N.std_dev(x) <= set.upper)
end
if isfinite(set.lower)
# P(f(x) ≥ l) ≥ β
# => μ(x) + Φ⁻¹(1 - β) * σ >= l
λ = Distributions.invlogcdf(Distributions.Normal(0, 1), log(1 - β))
JuMP.@constraint(model, N.mean(x) + λ * N.std_dev(x) >= set.lower)
end
return
end
59 changes: 59 additions & 0 deletions test/test_UnivariateNormalDistribution.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) 2024: Oscar Dowson and contributors
#
# Use of this source code is governed by an MIT-style license that can be found
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.

module UnivariateNormalDistributionTests

using JuMP
using Test

import Ipopt
import Omelette

is_test(x) = startswith(string(x), "test_")

function runtests()
@testset "$name" for name in filter(is_test, names(@__MODULE__; all = true))
getfield(@__MODULE__, name)()
end
return
end

function test_normal_lower_limit()
model = Model(Ipopt.Optimizer)
set_silent(model)
@variable(model, 0 <= x <= 5)
@objective(model, Min, x)
f = Omelette.UnivariateNormalDistribution(;
mean = x -> only(x),
std_dev = x -> 1.0,
)
Omelette.add_constraint(model, f, [x], MOI.Interval(0.5, Inf), 0.95)
optimize!(model)
@test is_solved_and_feasible(model)
# μ: Distributions.invlogcdf(Distributions.Normal(μ, 1.0), log(0.05)) = 0.5
@test isapprox(value(x), 2.1448536; atol = 1e-4)
return
end

function test_normal_upper_limit()
model = Model(Ipopt.Optimizer)
@variable(model, -5 <= x <= 5)
@objective(model, Max, x)
f = Omelette.UnivariateNormalDistribution(;
mean = x -> only(x),
std_dev = x -> 1.0,
)
Omelette.add_constraint(model, f, [x], MOI.Interval(-Inf, 0.5), 0.95)
set_silent(model)
optimize!(model)
@test is_solved_and_feasible(model)
# μ: Distributions.invlogcdf(Distributions.Normal(μ, 1.0), log(0.95)) = 0.5
@test isapprox(value(x), -1.1448536; atol = 1e-4)
return
end

end

UnivariateNormalDistributionTests.runtests()

0 comments on commit c3ef77a

Please sign in to comment.