-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add package skeleton and a basic implementation of LinearRegression (#2)
- Loading branch information
Showing
14 changed files
with
336 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Configuration file for JuliaFormatter.jl | ||
# For more information, see: https://domluna.github.io/JuliaFormatter.jl/stable/config/ | ||
|
||
always_for_in = true | ||
always_use_return = true | ||
margin = 80 | ||
remove_extra_newlines = true | ||
separate_kwargs_with_semicolon = true | ||
short_to_long_function_def = true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
name: TagBot | ||
on: | ||
issue_comment: | ||
types: | ||
- created | ||
workflow_dispatch: | ||
jobs: | ||
TagBot: | ||
if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: JuliaRegistries/TagBot@v1 | ||
with: | ||
token: ${{ secrets.GITHUB_TOKEN }} | ||
ssh: ${{ secrets.DOCUMENTER_KEY }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
name: aqua-lint | ||
on: | ||
push: | ||
branches: | ||
- master | ||
pull_request: | ||
types: [opened, synchronize, reopened] | ||
jobs: | ||
build: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: julia-actions/setup-julia@latest | ||
with: | ||
version: '1' | ||
- uses: actions/checkout@v4 | ||
- name: Aqua | ||
shell: julia --color=yes {0} | ||
run: | | ||
using Pkg | ||
Pkg.add(PackageSpec(name="Aqua")) | ||
Pkg.develop(PackageSpec(path=pwd())) | ||
using Omelette, Aqua | ||
Aqua.test_all(Omelette) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
name: CI | ||
on: | ||
push: | ||
branches: | ||
- master | ||
- release-* | ||
pull_request: | ||
types: [opened, synchronize, reopened] | ||
# needed to allow julia-actions/cache to delete old caches that it has created | ||
permissions: | ||
actions: write | ||
contents: read | ||
jobs: | ||
test: | ||
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} | ||
runs-on: ${{ matrix.os }} | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
# Since Omelette doesn't have binary dependencies, only test on a subset of | ||
# possible platforms. | ||
include: | ||
- version: '1' # The latest point-release (Linux) | ||
os: ubuntu-latest | ||
arch: x64 | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: julia-actions/setup-julia@v2 | ||
with: | ||
version: ${{ matrix.version }} | ||
arch: ${{ matrix.arch }} | ||
- uses: julia-actions/cache@v1 | ||
- uses: julia-actions/julia-buildpkg@v1 | ||
- uses: julia-actions/julia-runtest@v1 | ||
with: | ||
depwarn: error | ||
- uses: julia-actions/julia-processcoverage@v1 | ||
- uses: codecov/codecov-action@v4 | ||
with: | ||
file: lcov.info |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
name: format-check | ||
on: | ||
push: | ||
branches: | ||
- master | ||
- release-* | ||
pull_request: | ||
types: [opened, synchronize, reopened] | ||
jobs: | ||
build: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: julia-actions/setup-julia@latest | ||
with: | ||
version: '1' | ||
- uses: actions/checkout@v4 | ||
- name: Format check | ||
shell: julia --color=yes {0} | ||
run: | | ||
using Pkg | ||
# If you update the version, also update the style guide docs. | ||
Pkg.add(PackageSpec(name="JuliaFormatter", version="1")) | ||
using JuliaFormatter | ||
format("."; verbose = true) | ||
out = String(read(Cmd(`git diff`))) | ||
if isempty(out) | ||
exit(0) | ||
end | ||
@error "Some files have not been formatted !!!" | ||
write(stdout, out) | ||
exit(1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
Manifest.toml | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
Copyright (c) 2024: Oscar Dowson and contributors | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
name = "Omelette" | ||
uuid = "e52c2cb8-508e-4e12-9dd2-9c4755b60e73" | ||
authors = ["odow <[email protected]>"] | ||
version = "0.1.0" | ||
|
||
[deps] | ||
JuMP = "4076af6c-e467-56ae-b986-b466b2749572" | ||
|
||
[compat] | ||
JuMP = "1" | ||
julia = "1.6" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,23 @@ | ||
# jump-ml | ||
# Omelette | ||
|
||
_If you can come up with a better name, please open an issue._ | ||
|
||
Omelette is a [JuMP](https://jump.dev) extension for embedding common types of | ||
AI, machine learning, and statistical learning models into a JuMP optimization | ||
model. | ||
|
||
## License | ||
|
||
Omelette.jl is licensed under the [MIT license](https://github.com/lanl-ansi/jump-ml/blob/main/LICENSE.md) | ||
|
||
## Getting help | ||
|
||
This package is under active development. For help, questions, comments, and | ||
suggestions, please open a GitHub issue. | ||
|
||
## Inspiration | ||
|
||
This project is inspired by two existing projects: | ||
|
||
* [OMLT](https://github.com/cog-imperial/OMLT) | ||
* [gurobi-machinelearning](https://github.com/Gurobi/gurobi-machinelearning) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Copyright (c) 2024: Oscar Dowson and contributors | ||
# | ||
# Use of this source code is governed by an MIT-style license that can be found | ||
# in the LICENSE.md file or at https://opensource.org/licenses/MIT. | ||
|
||
module Omelette | ||
|
||
import JuMP | ||
|
||
""" | ||
abstract type AbstractModel end | ||
## Methods | ||
All subtypes must implement: | ||
* `add_model_internal` | ||
* `Base.size` | ||
""" | ||
abstract type AbstractModel end | ||
|
||
""" | ||
add_model( | ||
opt_model::JuMP.Model, | ||
ml_model::AbstractModel, | ||
x::Vector{JuMP.VariableRef}, | ||
y::Vector{JuMP.VariableRef}, | ||
) | ||
Add the constraint `ml_model(x) == y` to the optimization model `opt_model`. | ||
## Input | ||
## Output | ||
* `::Nothing` | ||
## Examples | ||
TODO | ||
""" | ||
function add_model( | ||
opt_model::JuMP.Model, | ||
ml_model::AbstractModel, | ||
x::Vector{JuMP.VariableRef}, | ||
y::Vector{JuMP.VariableRef}, | ||
) | ||
output_n, input_n = size(ml_model) | ||
if length(x) != input_n | ||
msg = "Input vector x is length $(length(x)), expected $input_n" | ||
throw(DimensionMismatch(msg)) | ||
elseif length(y) != output_n | ||
msg = "Output vector y is length $(length(y)), expected $output_n" | ||
throw(DimensionMismatch(msg)) | ||
end | ||
_add_model_inner(opt_model, ml_model, x, y) | ||
return | ||
end | ||
|
||
for file in readdir(joinpath(@__DIR__, "models"); join = true) | ||
if endswith(file, ".jl") | ||
include(file) | ||
end | ||
end | ||
|
||
end # module Omelette |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Copyright (c) 2024: Oscar Dowson and contributors | ||
# | ||
# Use of this source code is governed by an MIT-style license that can be found | ||
# in the LICENSE.md file or at https://opensource.org/licenses/MIT. | ||
|
||
struct LinearRegression <: AbstractModel | ||
parameters::Matrix{Float64} | ||
end | ||
|
||
function LinearRegression(parameters::Vector{Float64}) | ||
return LinearRegression(reshape(parameters, 1, length(parameters))) | ||
end | ||
|
||
Base.size(f::LinearRegression) = size(f.parameters) | ||
|
||
function _add_model_inner( | ||
opt_model::JuMP.Model, | ||
ml_model::LinearRegression, | ||
x::Vector{JuMP.VariableRef}, | ||
y::Vector{JuMP.VariableRef}, | ||
) | ||
JuMP.@constraint(opt_model, ml_model.parameters * x .== y) | ||
return | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
[deps] | ||
JuMP = "4076af6c-e467-56ae-b986-b466b2749572" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
|
||
[compat] | ||
Test = "<0.0.1, 1.6" | ||
julia = "1.6" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright (c) 2024: Oscar Dowson and contributors | ||
# | ||
# Use of this source code is governed by an MIT-style license that can be found | ||
# in the LICENSE.md file or at https://opensource.org/licenses/MIT. | ||
|
||
module LinearRegressionTests | ||
|
||
using Test | ||
using JuMP | ||
import Omelette | ||
|
||
function runtests() | ||
for name in names(@__MODULE__; all = true) | ||
if startswith("$name", "test_") | ||
@testset "$name" begin | ||
getfield(@__MODULE__, name)() | ||
end | ||
end | ||
end | ||
return | ||
end | ||
|
||
function test_LinearRegression() | ||
model = Model() | ||
@variable(model, x[1:2]) | ||
@variable(model, y[1:1]) | ||
f = Omelette.LinearRegression([2.0, 3.0]) | ||
Omelette.add_model(model, f, x, y) | ||
cons = all_constraints(model; include_variable_in_set_constraints = false) | ||
obj = constraint_object(only(cons)) | ||
@test obj.set == MOI.EqualTo(0.0) | ||
@test isequal_canonical(obj.func, 2.0 * x[1] + 3.0 * x[2] - y[1]) | ||
return | ||
end | ||
|
||
function test_LinearRegression_dimension_mismatch() | ||
model = Model() | ||
@variable(model, x[1:3]) | ||
@variable(model, y[1:2]) | ||
f = Omelette.LinearRegression([2.0, 3.0]) | ||
@test size(f) == (1, 2) | ||
@test_throws DimensionMismatch Omelette.add_model(model, f, x, y[1:1]) | ||
@test_throws DimensionMismatch Omelette.add_model(model, f, x[1:2], y) | ||
g = Omelette.LinearRegression([2.0 3.0; 4.0 5.0; 6.0 7.0]) | ||
@test size(g) == (3, 2) | ||
@test_throws DimensionMismatch Omelette.add_model(model, g, x, y) | ||
return | ||
end | ||
|
||
end | ||
|
||
LinearRegressionTests.runtests() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Copyright (c) 2024: Oscar Dowson and contributors | ||
# | ||
# Use of this source code is governed by an MIT-style license that can be found | ||
# in the LICENSE.md file or at https://opensource.org/licenses/MIT. | ||
|
||
using Test | ||
|
||
for file in readdir(joinpath(@__DIR__, "models")) | ||
if startswith(file, "test_") && endswith(file, ".jl") | ||
@testset "$file" begin | ||
include(joinpath(@__DIR__, "models", file)) | ||
end | ||
end | ||
end |