diff --git a/README.md b/README.md
index 0f27706..82da858 100644
--- a/README.md
+++ b/README.md
@@ -28,10 +28,6 @@ Use `add_predictor`:
 ```julia
 y = Omelette.add_predictor(model, predictor, x)
 ```
-or:
-```julia
-Omelette.add_predictor!(model, predictor, x, y)
-```
 
 ### LinearRegression
 
@@ -54,6 +50,7 @@ predictor = Omelette.LogisticRegression(model_glm)
 ## Other constraints
 
 ### UnivariateNormalDistribution
+
 ```julia
 using JuMP, Omelette
 model = Model();
diff --git a/src/Omelette.jl b/src/Omelette.jl
index 7ba5c0e..f6944a2 100644
--- a/src/Omelette.jl
+++ b/src/Omelette.jl
@@ -12,45 +12,16 @@ import MathOptInterface as MOI
 """
     abstract type AbstractPredictor end
 
+An abstract type representing different types of prediction models.
+
 ## Methods
 
 All subtypes must implement:
 
- * `_add_predictor_inner`
- * `Base.size`
+ * `add_predictor`
 """
 abstract type AbstractPredictor end
 
-Base.size(x::AbstractPredictor, i::Int) = size(x)[i]
-
-"""
-    add_predictor!(
-        model::JuMP.Model,
-        predictor::AbstractPredictor,
-        x::Vector{JuMP.VariableRef},
-        y::Vector{JuMP.VariableRef},
-    )::Nothing
-
-Add the constraint `predictor(x) .== y` to the optimization model `model`.
-"""
-function add_predictor!(
-    model::JuMP.Model,
-    predictor::AbstractPredictor,
-    x::Vector{JuMP.VariableRef},
-    y::Vector{JuMP.VariableRef},
-)
-    output_n, input_n = size(predictor)
-    if length(x) != input_n
-        msg = "Input vector x is length $(length(x)), expected $input_n"
-        throw(DimensionMismatch(msg))
-    elseif length(y) != output_n
-        msg = "Output vector y is length $(length(y)), expected $output_n"
-        throw(DimensionMismatch(msg))
-    end
-    _add_predictor_inner(model, predictor, x, y)
-    return nothing
-end
-
 """
     add_predictor(
         model::JuMP.Model,
@@ -58,18 +29,32 @@ end
         x::Vector{JuMP.VariableRef},
     )::Vector{JuMP.VariableRef}
 
-Return an expression for `predictor(x)` in terms of variables in the
-optimization model `model`.
+Return a `Vector{JuMP.VariableRef}` representing `y` such that
+`y = predictor(x)`.
+
+## Example
+
+```jldoctest
+julia> using JuMP, Omelette
+
+julia> model = Model();
+
+julia> @variable(model, x[1:2]);
+
+julia> f = Omelette.LinearRegression([2.0, 3.0])
+Omelette.LinearRegression([2.0 3.0])
+
+julia> y = Omelette.add_predictor(model, f, x)
+1-element Vector{VariableRef}:
+ omelette_y[1]
+
+julia> print(model)
+Feasibility
+Subject to
+ 2 x[1] + 3 x[2] - omelette_y[1] = 0
+```
 """
-function add_predictor(
-    model::JuMP.Model,
-    predictor::AbstractPredictor,
-    x::Vector{JuMP.VariableRef},
-)
-    y = JuMP.@variable(model, [1:size(predictor, 1)], base_name = "omelette_y")
-    add_predictor!(model, predictor, x, y)
-    return y
-end
+function add_predictor end
 
 for file in readdir(joinpath(@__DIR__, "models"); join = true)
     if endswith(file, ".jl")
diff --git a/src/models/LinearRegression.jl b/src/models/LinearRegression.jl
index 235f890..a96ff06 100644
--- a/src/models/LinearRegression.jl
+++ b/src/models/LinearRegression.jl
@@ -42,14 +42,13 @@ function LinearRegression(parameters::Vector{Float64})
     return LinearRegression(reshape(parameters, 1, length(parameters)))
 end
 
-Base.size(f::LinearRegression) = size(f.parameters)
-
-function _add_predictor_inner(
+function add_predictor(
     model::JuMP.Model,
     predictor::LinearRegression,
     x::Vector{JuMP.VariableRef},
-    y::Vector{JuMP.VariableRef},
 )
+    m = size(predictor.parameters, 1)
+    y = JuMP.@variable(model, [1:m], base_name = "omelette_y")
     JuMP.@constraint(model, predictor.parameters * x .== y)
-    return
+    return y
 end
diff --git a/src/models/LogisticRegression.jl b/src/models/LogisticRegression.jl
index 4998eae..9e8d86f 100644
--- a/src/models/LogisticRegression.jl
+++ b/src/models/LogisticRegression.jl
@@ -44,12 +44,13 @@ end
 
 Base.size(f::LogisticRegression) = size(f.parameters)
 
-function _add_predictor_inner(
+function add_predictor(
     model::JuMP.Model,
     predictor::LogisticRegression,
     x::Vector{JuMP.VariableRef},
-    y::Vector{JuMP.VariableRef},
 )
+    m = size(predictor.parameters, 1)
+    y = JuMP.@variable(model, [1:m], base_name = "omelette_y")
     JuMP.@constraint(model, 1 ./ (1 .+ exp.(-predictor.parameters * x)) .== y)
-    return
+    return y
 end
diff --git a/test/test_LinearRegression.jl b/test/test_LinearRegression.jl
index b19b075..bce234b 100644
--- a/test/test_LinearRegression.jl
+++ b/test/test_LinearRegression.jl
@@ -24,9 +24,8 @@ end
 function test_LinearRegression()
     model = Model()
     @variable(model, x[1:2])
-    @variable(model, y[1:1])
     f = Omelette.LinearRegression([2.0, 3.0])
-    Omelette.add_predictor!(model, f, x, y)
+    y = Omelette.add_predictor(model, f, x)
     cons = all_constraints(model; include_variable_in_set_constraints = false)
     obj = constraint_object(only(cons))
     @test obj.set == MOI.EqualTo(0.0)
@@ -34,20 +33,6 @@ function test_LinearRegression()
     return
 end
 
-function test_LinearRegression_dimension_mismatch()
-    model = Model()
-    @variable(model, x[1:3])
-    @variable(model, y[1:2])
-    f = Omelette.LinearRegression([2.0, 3.0])
-    @test size(f) == (1, 2)
-    @test_throws DimensionMismatch Omelette.add_predictor!(model, f, x, y[1:1])
-    @test_throws DimensionMismatch Omelette.add_predictor!(model, f, x[1:2], y)
-    g = Omelette.LinearRegression([2.0 3.0; 4.0 5.0; 6.0 7.0])
-    @test size(g) == (3, 2)
-    @test_throws DimensionMismatch Omelette.add_predictor!(model, g, x, y)
-    return
-end
-
 function test_LinearRegression_GLM()
     num_features = 2
     num_observations = 10
diff --git a/test/test_LogisticRegression.jl b/test/test_LogisticRegression.jl
index b976724..38d7315 100644
--- a/test/test_LogisticRegression.jl
+++ b/test/test_LogisticRegression.jl
@@ -24,9 +24,8 @@ end
 function test_LogisticRegression()
     model = Model()
     @variable(model, x[1:2])
-    @variable(model, y[1:1])
     f = Omelette.LogisticRegression([2.0, 3.0])
-    Omelette.add_predictor!(model, f, x, y)
+    y = Omelette.add_predictor(model, f, x)
     cons = all_constraints(model; include_variable_in_set_constraints = false)
     obj = constraint_object(only(cons))
     @test obj.set == MOI.EqualTo(0.0)
@@ -35,20 +34,6 @@ function test_LogisticRegression()
     return
 end
 
-function test_LogisticRegression_dimension_mismatch()
-    model = Model()
-    @variable(model, x[1:3])
-    @variable(model, y[1:2])
-    f = Omelette.LogisticRegression([2.0, 3.0])
-    @test size(f) == (1, 2)
-    @test_throws DimensionMismatch Omelette.add_predictor!(model, f, x, y[1:1])
-    @test_throws DimensionMismatch Omelette.add_predictor!(model, f, x[1:2], y)
-    g = Omelette.LogisticRegression([2.0 3.0; 4.0 5.0; 6.0 7.0])
-    @test size(g) == (3, 2)
-    @test_throws DimensionMismatch Omelette.add_predictor!(model, g, x, y)
-    return
-end
-
 function test_LogisticRegression_GLM()
     num_features = 2
     num_observations = 10
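
A minimal usage sketch of the reworked API after this diff, mirroring the updated tests above (not part of the patch; assumes `Omelette` at this commit):

```julia
using JuMP, Omelette

model = Model()
@variable(model, x[1:2])

# add_predictor now allocates and returns the output variables itself,
# replacing the removed add_predictor!(model, predictor, x, y) pattern.
f = Omelette.LogisticRegression([2.0, 3.0])
y = Omelette.add_predictor(model, f, x)  # 1-element Vector{VariableRef}
```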