Skip to content

Commit

Permalink
Add some more unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jkl1337 committed Feb 26, 2024
1 parent 2b76ba9 commit f68c24d
Show file tree
Hide file tree
Showing 6 changed files with 491 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/busted.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ jobs:
LD_PRELOAD: ${{ matrix.os == 'ubuntu-latest' && '/usr/lib/x86_64-linux-gnu/libasan.so.6:/usr/lib/x86_64-linux-gnu/libstdc++.so.6:/usr/lib/x86_64-linux-gnu/libubsan.so.1' || '' }}

- name: Run gcov
if: success() && ${{ startsWith(matrix.os, 'ubuntu-') }}
if: success() && startsWith(matrix.os, 'ubuntu-')
run: |
gcov -p -b -s"$(pwd)" -r *.gcda
Expand Down
2 changes: 1 addition & 1 deletion ckiwi/ckiwi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ bool kiwi_solver_has_constraint(const KiwiSolver* s, KiwiConstraint* constraint)
}

const KiwiErr* kiwi_solver_add_edit_var(KiwiSolver* s, KiwiVar* var, double strength) {
return wrap_err(s, var, [strength](auto& s, auto&& v) {
return wrap_err(s, var, [strength](auto&& s, auto&& v) {
s.addEditVariable(Variable(v), strength);
});
}
Expand Down
4 changes: 2 additions & 2 deletions luakiwi/luakiwi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ KiwiTerm* term_new(lua_State* L) {

inline KiwiExpression* expr_new(lua_State* L, int nterms) {
auto* expr = static_cast<KiwiExpression*>(lua_newuserdata(L, KiwiExpression::sz(nterms)));
expr->term_count = 0;
expr->owner = nullptr;
push_type(L, EXPR);
lua_setmetatable(L, -2);
Expand Down Expand Up @@ -929,9 +930,8 @@ int lkiwi_expr_new(lua_State* L) {

auto* expr = expr_new(L, nterms);
expr->constant = constant;
expr->term_count = nterms;

for (int i = 0; i < nterms; i++) {
for (int i = 0; i < nterms; ++i, ++expr->term_count) {
const auto* term = get_term(L, i + 2);
expr->terms[i].var = retain_unmanaged(term->var);
expr->terms[i].coefficient = term->coefficient;
Expand Down
240 changes: 240 additions & 0 deletions spec/expression_spec.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
expose("module", function()
require("kiwi")
end)

describe("Expression", function()
local kiwi = require("kiwi")
local LUA_VERSION = tonumber(_VERSION:match("%d+%.%d+"))

it("construction", function()
local v = kiwi.Var("foo")
local v2 = kiwi.Var("bar")
local v3 = kiwi.Var("aux")
local e1 = kiwi.Expression(0, v * 1, v2 * 2, v3 * 3)
local e2 = kiwi.Expression(10, v * 1, v2 * 2, v3 * 3)

local constants = { 0, 10 }
for i, e in ipairs({ e1, e2 }) do
assert.equal(constants[i], e.constant)
local terms = e:terms()
assert.equal(3, #terms)
assert.equal(v, terms[1].var)
assert.equal(1.0, terms[1].coefficient)
assert.equal(v2, terms[2].var)
assert.equal(2.0, terms[2].coefficient)
assert.equal(v3, terms[3].var)
assert.equal(3.0, terms[3].coefficient)
end

if LUA_VERSION <= 5.2 then
assert.equal("1 foo + 2 bar + 3 aux + 10", tostring(e2))
else
assert.equal("1.0 foo + 2.0 bar + 3.0 aux + 10.0", tostring(e2))
end

assert.error(function()
kiwi.Expression(0, 0, v2 * 2, v3 * 3)
end)
end)

describe("method", function()
local v, t, e
before_each(function()
v = kiwi.Var("foo")
v:set(42)
t = kiwi.Term(v, 10)
e = t + 5
end)

it("has value", function()
v:set(42)
assert.equal(425, e:value())
v:set(87)
assert.equal(875, e:value())
end)

it("can be copied", function()
local e2 = e:copy()
assert.equal(e.constant, e2.constant)
local t1, t2 = e:terms(), e2:terms()
assert.equal(#t1, #t2)
for i = 1, #t1 do
assert.equal(t1[i].var, t2[i].var)
assert.equal(t1[i].coefficient, t2[i].coefficient)
end
end)

it("neg", function()
local neg = -e --[[@as kiwi.Expression]]
assert.True(kiwi.is_expression(neg))
local terms = neg:terms()
assert.equal(1, #terms)
assert.equal(v, terms[1].var)
assert.equal(-10.0, terms[1].coefficient)
assert.equal(-5, neg.constant)
end)

describe("bin op", function()
local v2, t2, e2
before_each(function()
v2 = kiwi.Var("bar")
t2 = kiwi.Term(v2)
e2 = v2 - 10
end)

it("mul", function()
for _, prod in ipairs({ e * 2.0, 2 * e }) do
assert.True(kiwi.is_expression(prod))
local terms = prod:terms()
assert.equal(1, #terms)
assert.equal(v, terms[1].var)
assert.equal(20.0, terms[1].coefficient)
assert.equal(10, prod.constant)
end

assert.error(function()
local _ = e * v
end)
end)

it("div", function()
local quot = e / 2.0
assert.True(kiwi.is_expression(quot))
local terms = quot:terms()
assert.equal(1, #terms)
assert.equal(v, terms[1].var)
assert.equal(5.0, terms[1].coefficient)
assert.equal(2.5, quot.constant)

assert.error(function()
local _ = e / v2
end)
end)

it("add", function()
for _, sum in ipairs({ e + 2.0, 2 + e }) do
assert.True(kiwi.is_expression(sum))
assert.equal(7.0, sum.constant)

local terms = sum:terms()
assert.equal(1, #terms)
assert.equal(10.0, terms[1].coefficient)
assert.equal(v, terms[1].var)
end

local sum = e + v2
assert.True(kiwi.is_expression(sum))
assert.equal(5, sum.constant)
local terms = sum:terms()
assert.equal(2, #terms)
assert.equal(v, terms[1].var)
assert.equal(10.0, terms[1].coefficient)
assert.equal(v2, terms[2].var)
assert.equal(1.0, terms[2].coefficient)

sum = e + t2
assert.True(kiwi.is_expression(sum))
assert.equal(5, sum.constant)
terms = sum:terms()
assert.equal(2, #terms)
assert.equal(v, terms[1].var)
assert.equal(10.0, terms[1].coefficient)
assert.equal(v2, terms[2].var)
assert.equal(1.0, terms[2].coefficient)

sum = e + e2
assert.True(kiwi.is_expression(sum))
assert.equal(-5, sum.constant)
terms = sum:terms()
assert.equal(2, #terms)
assert.equal(v, terms[1].var)
assert.equal(10.0, terms[1].coefficient)
assert.equal(v2, terms[2].var)
assert.equal(1.0, terms[2].coefficient)

assert.error(function()
local _ = t + "foo"
end)
assert.error(function()
local _ = t + {}
end)
end)

it("sub", function()
local constants = { 3, -3 }
for i, diff in ipairs({ e - 2.0, 2 - e }) do
local constant = constants[i]
assert.True(kiwi.is_expression(diff))
assert.equal(constant, diff.constant)

local terms = diff:terms()
assert.equal(1, #terms)
assert.equal(v, terms[1].var)
assert.equal(constant < 0 and -10.0 or 10.0, terms[1].coefficient)
end

local diff = e - v2
assert.True(kiwi.is_expression(diff))
assert.equal(5, diff.constant)
local terms = diff:terms()
assert.equal(2, #terms)
assert.equal(v, terms[1].var)
assert.equal(10.0, terms[1].coefficient)
assert.equal(v2, terms[2].var)
assert.equal(-1.0, terms[2].coefficient)

diff = e - t2
assert.True(kiwi.is_expression(diff))
assert.equal(5, diff.constant)
terms = diff:terms()
assert.equal(2, #terms)
assert.equal(v, terms[1].var)
assert.equal(10.0, terms[1].coefficient)
assert.equal(v2, terms[2].var)
assert.equal(-1.0, terms[2].coefficient)

diff = e - e2
assert.True(kiwi.is_expression(diff))
assert.equal(15, diff.constant)
terms = diff:terms()
assert.equal(2, #terms)
assert.equal(v, terms[1].var)
assert.equal(10.0, terms[1].coefficient)
assert.equal(v2, terms[2].var)
assert.equal(-1.0, terms[2].coefficient)

assert.error(function()
local _ = e - "foo"
end)
assert.error(function()
local _ = e - {}
end)
end)

it("constraint expr op expr", function()
local ops = { "LE", "EQ", "GE" }
for i, meth in ipairs({ "le", "eq", "ge" }) do
local c = e[meth](e, e2)
assert.True(kiwi.is_constraint(c))

local expr = c:expression()
local terms = expr:terms()
assert.equal(2, #terms)

-- order can be randomized due to use of map
if terms[1].var ~= v then
terms[1], terms[2] = terms[2], terms[1]
end
assert.equal(v, terms[1].var)
assert.equal(10.0, terms[1].coefficient)
assert.equal(v2, terms[2].var)
assert.equal(-1.0, terms[2].coefficient)

assert.equal(15, expr.constant)
assert.equal(ops[i], c:op())
assert.equal(kiwi.strength.REQUIRED, c:strength())
end
end)
end)
end)
end)
Loading

0 comments on commit f68c24d

Please sign in to comment.