diff --git a/.github/workflows/busted.yml b/.github/workflows/busted.yml index bc57a55..efef0bd 100644 --- a/.github/workflows/busted.yml +++ b/.github/workflows/busted.yml @@ -53,7 +53,7 @@ jobs: LD_PRELOAD: ${{ matrix.os == 'ubuntu-latest' && '/usr/lib/x86_64-linux-gnu/libasan.so.6:/usr/lib/x86_64-linux-gnu/libstdc++.so.6:/usr/lib/x86_64-linux-gnu/libubsan.so.1' || '' }} - name: Run gcov - if: success() && ${{ startsWith(matrix.os, 'ubuntu-') }} + if: success() && startsWith(matrix.os, 'ubuntu-') run: | gcov -p -b -s"$(pwd)" -r *.gcda diff --git a/ckiwi/ckiwi.cpp b/ckiwi/ckiwi.cpp index be6c013..410656a 100644 --- a/ckiwi/ckiwi.cpp +++ b/ckiwi/ckiwi.cpp @@ -326,7 +326,7 @@ bool kiwi_solver_has_constraint(const KiwiSolver* s, KiwiConstraint* constraint) } const KiwiErr* kiwi_solver_add_edit_var(KiwiSolver* s, KiwiVar* var, double strength) { - return wrap_err(s, var, [strength](auto& s, auto&& v) { + return wrap_err(s, var, [strength](auto&& s, auto&& v) { s.addEditVariable(Variable(v), strength); }); } diff --git a/luakiwi/luakiwi.cpp b/luakiwi/luakiwi.cpp index bf63335..89ffa8a 100644 --- a/luakiwi/luakiwi.cpp +++ b/luakiwi/luakiwi.cpp @@ -248,6 +248,7 @@ KiwiTerm* term_new(lua_State* L) { inline KiwiExpression* expr_new(lua_State* L, int nterms) { auto* expr = static_cast(lua_newuserdata(L, KiwiExpression::sz(nterms))); + expr->term_count = 0; expr->owner = nullptr; push_type(L, EXPR); lua_setmetatable(L, -2); @@ -929,9 +930,8 @@ int lkiwi_expr_new(lua_State* L) { auto* expr = expr_new(L, nterms); expr->constant = constant; - expr->term_count = nterms; - for (int i = 0; i < nterms; i++) { + for (int i = 0; i < nterms; ++i, ++expr->term_count) { const auto* term = get_term(L, i + 2); expr->terms[i].var = retain_unmanaged(term->var); expr->terms[i].coefficient = term->coefficient; diff --git a/spec/expression_spec.lua b/spec/expression_spec.lua new file mode 100644 index 0000000..89812ac --- /dev/null +++ b/spec/expression_spec.lua @@ -0,0 +1,240 @@ +expose("module", function() + require("kiwi") +end) + +describe("Expression", function() + local kiwi = require("kiwi") + local LUA_VERSION = tonumber(_VERSION:match("%d+%.%d+")) + + it("construction", function() + local v = kiwi.Var("foo") + local v2 = kiwi.Var("bar") + local v3 = kiwi.Var("aux") + local e1 = kiwi.Expression(0, v * 1, v2 * 2, v3 * 3) + local e2 = kiwi.Expression(10, v * 1, v2 * 2, v3 * 3) + + local constants = { 0, 10 } + for i, e in ipairs({ e1, e2 }) do + assert.equal(constants[i], e.constant) + local terms = e:terms() + assert.equal(3, #terms) + assert.equal(v, terms[1].var) + assert.equal(1.0, terms[1].coefficient) + assert.equal(v2, terms[2].var) + assert.equal(2.0, terms[2].coefficient) + assert.equal(v3, terms[3].var) + assert.equal(3.0, terms[3].coefficient) + end + + if LUA_VERSION <= 5.2 then + assert.equal("1 foo + 2 bar + 3 aux + 10", tostring(e2)) + else + assert.equal("1.0 foo + 2.0 bar + 3.0 aux + 10.0", tostring(e2)) + end + + assert.error(function() + kiwi.Expression(0, 0, v2 * 2, v3 * 3) + end) + end) + + describe("method", function() + local v, t, e + before_each(function() + v = kiwi.Var("foo") + v:set(42) + t = kiwi.Term(v, 10) + e = t + 5 + end) + + it("has value", function() + v:set(42) + assert.equal(425, e:value()) + v:set(87) + assert.equal(875, e:value()) + end) + + it("can be copied", function() + local e2 = e:copy() + assert.equal(e.constant, e2.constant) + local t1, t2 = e:terms(), e2:terms() + assert.equal(#t1, #t2) + for i = 1, #t1 do + assert.equal(t1[i].var, t2[i].var) + assert.equal(t1[i].coefficient, t2[i].coefficient) + end + end) + + it("neg", function() + local neg = -e --[[@as kiwi.Expression]] + assert.True(kiwi.is_expression(neg)) + local terms = neg:terms() + assert.equal(1, #terms) + assert.equal(v, terms[1].var) + assert.equal(-10.0, terms[1].coefficient) + assert.equal(-5, neg.constant) + end) + + describe("bin op", function() + local v2, t2, e2 + before_each(function() + v2 = kiwi.Var("bar") + t2 = kiwi.Term(v2) + e2 = v2 - 10 + end) + + it("mul", function() + for _, prod in ipairs({ e * 2.0, 2 * e }) do + assert.True(kiwi.is_expression(prod)) + local terms = prod:terms() + assert.equal(1, #terms) + assert.equal(v, terms[1].var) + assert.equal(20.0, terms[1].coefficient) + assert.equal(10, prod.constant) + end + + assert.error(function() + local _ = e * v + end) + end) + + it("div", function() + local quot = e / 2.0 + assert.True(kiwi.is_expression(quot)) + local terms = quot:terms() + assert.equal(1, #terms) + assert.equal(v, terms[1].var) + assert.equal(5.0, terms[1].coefficient) + assert.equal(2.5, quot.constant) + + assert.error(function() + local _ = e / v2 + end) + end) + + it("add", function() + for _, sum in ipairs({ e + 2.0, 2 + e }) do + assert.True(kiwi.is_expression(sum)) + assert.equal(7.0, sum.constant) + + local terms = sum:terms() + assert.equal(1, #terms) + assert.equal(10.0, terms[1].coefficient) + assert.equal(v, terms[1].var) + end + + local sum = e + v2 + assert.True(kiwi.is_expression(sum)) + assert.equal(5, sum.constant) + local terms = sum:terms() + assert.equal(2, #terms) + assert.equal(v, terms[1].var) + assert.equal(10.0, terms[1].coefficient) + assert.equal(v2, terms[2].var) + assert.equal(1.0, terms[2].coefficient) + + sum = e + t2 + assert.True(kiwi.is_expression(sum)) + assert.equal(5, sum.constant) + terms = sum:terms() + assert.equal(2, #terms) + assert.equal(v, terms[1].var) + assert.equal(10.0, terms[1].coefficient) + assert.equal(v2, terms[2].var) + assert.equal(1.0, terms[2].coefficient) + + sum = e + e2 + assert.True(kiwi.is_expression(sum)) + assert.equal(-5, sum.constant) + terms = sum:terms() + assert.equal(2, #terms) + assert.equal(v, terms[1].var) + assert.equal(10.0, terms[1].coefficient) + assert.equal(v2, terms[2].var) + assert.equal(1.0, terms[2].coefficient) + + assert.error(function() + local _ = t + "foo" + end) + assert.error(function() + local _ = t + {} + end) + end) + + it("sub", function() + local constants = { 3, -3 } + for i, diff in ipairs({ e - 2.0, 2 - e }) do + local constant = constants[i] + assert.True(kiwi.is_expression(diff)) + assert.equal(constant, diff.constant) + + local terms = diff:terms() + assert.equal(1, #terms) + assert.equal(v, terms[1].var) + assert.equal(constant < 0 and -10.0 or 10.0, terms[1].coefficient) + end + + local diff = e - v2 + assert.True(kiwi.is_expression(diff)) + assert.equal(5, diff.constant) + local terms = diff:terms() + assert.equal(2, #terms) + assert.equal(v, terms[1].var) + assert.equal(10.0, terms[1].coefficient) + assert.equal(v2, terms[2].var) + assert.equal(-1.0, terms[2].coefficient) + + diff = e - t2 + assert.True(kiwi.is_expression(diff)) + assert.equal(5, diff.constant) + terms = diff:terms() + assert.equal(2, #terms) + assert.equal(v, terms[1].var) + assert.equal(10.0, terms[1].coefficient) + assert.equal(v2, terms[2].var) + assert.equal(-1.0, terms[2].coefficient) + + diff = e - e2 + assert.True(kiwi.is_expression(diff)) + assert.equal(15, diff.constant) + terms = diff:terms() + assert.equal(2, #terms) + assert.equal(v, terms[1].var) + assert.equal(10.0, terms[1].coefficient) + assert.equal(v2, terms[2].var) + assert.equal(-1.0, terms[2].coefficient) + + assert.error(function() + local _ = e - "foo" + end) + assert.error(function() + local _ = e - {} + end) + end) + + it("constraint expr op expr", function() + local ops = { "LE", "EQ", "GE" } + for i, meth in ipairs({ "le", "eq", "ge" }) do + local c = e[meth](e, e2) + assert.True(kiwi.is_constraint(c)) + + local expr = c:expression() + local terms = expr:terms() + assert.equal(2, #terms) + + -- order can be randomized due to use of map + if terms[1].var ~= v then + terms[1], terms[2] = terms[2], terms[1] + end + assert.equal(v, terms[1].var) + assert.equal(10.0, terms[1].coefficient) + assert.equal(v2, terms[2].var) + assert.equal(-1.0, terms[2].coefficient) + + assert.equal(15, expr.constant) + assert.equal(ops[i], c:op()) + assert.equal(kiwi.strength.REQUIRED, c:strength()) + end + end) + end) + end) +end) diff --git a/spec/term_spec.lua b/spec/term_spec.lua new file mode 100644 index 0000000..5919de2 --- /dev/null +++ b/spec/term_spec.lua @@ -0,0 +1,245 @@ +expose("module", function() + require("kiwi") +end) + +describe("Term", function() + local kiwi = require("kiwi") + local LUA_VERSION = tonumber(_VERSION:match("%d+%.%d+")) + + it("construction", function() + local v = kiwi.Var("foo") + local t = kiwi.Term(v) + assert.equal(v, t.var) + assert.equal(1.0, t.coefficient) + + t = kiwi.Term(v, 100) + assert.equal(v, t.var) + assert.equal(100, t.coefficient) + + if LUA_VERSION <= 5.2 then + assert.equal("100 foo", tostring(t)) + else + assert.equal("100.0 foo", tostring(t)) + end + + assert.error(function() + kiwi.Term("") + end) + end) + + describe("method", function() + local v, v2, t, t2 + + before_each(function() + v = kiwi.Var("foo") + t = kiwi.Term(v, 10) + end) + + it("has value", function() + v:set(42) + assert.equal(420, t:value()) + v:set(87) + assert.equal(870, t:value()) + end) + + it("has toexpr", function() + local e = t:toexpr() + assert.True(kiwi.is_expression(e)) + assert.equal(0, e.constant) + local terms = e:terms() + assert.equal(1, #terms) + assert.equal(v, terms[1].var) + assert.equal(10.0, terms[1].coefficient) + end) + + it("neg", function() + local neg = -t --[[@as kiwi.Term]] + assert.True(kiwi.is_term(neg)) + assert.equal(v, neg.var) + assert.equal(-10, neg.coefficient) + end) + + describe("bin op", function() + before_each(function() + v2 = kiwi.Var("bar") + t2 = kiwi.Term(v2) + end) + + it("mul", function() + for _, prod in ipairs({ t * 2.0, 2 * t }) do + assert.True(kiwi.is_term(prod)) + assert.equal(v, prod.var) + assert.equal(20, prod.coefficient) + end + + assert.error(function() + local _ = t * v + end) + end) + + it("div", function() + local quot = t / 2.0 + assert.True(kiwi.is_term(quot)) + assert.equal(v, quot.var) + assert.equal(5.0, quot.coefficient) + + assert.error(function() + local _ = v / v2 + end) + end) + + it("add", function() + for _, sum in ipairs({ t + 2.0, 2 + t }) do + assert.True(kiwi.is_expression(sum)) + assert.equal(2.0, sum.constant) + + local terms = sum:terms() + assert.equal(1, #terms) + assert.equal(10.0, terms[1].coefficient) + assert.equal(v, terms[1].var) + end + + local sum = t + v2 + assert.True(kiwi.is_expression(sum)) + assert.equal(0, sum.constant) + local terms = sum:terms() + assert.equal(2, #terms) + assert.equal(v, terms[1].var) + assert.equal(10.0, terms[1].coefficient) + assert.equal(v2, terms[2].var) + assert.equal(1.0, terms[2].coefficient) + + sum = t + t2 + assert.True(kiwi.is_expression(sum)) + assert.equal(0, sum.constant) + terms = sum:terms() + assert.equal(2, #terms) + assert.equal(v, terms[1].var) + assert.equal(10.0, terms[1].coefficient) + assert.equal(v2, terms[2].var) + assert.equal(1.0, terms[2].coefficient) + + local t3 = kiwi.Term(v2, 20) + sum = t3 + sum + assert.True(kiwi.is_expression(sum)) + assert.equal(0, sum.constant) + terms = sum:terms() + assert.equal(3, #terms) + assert.equal(v, terms[1].var) + assert.equal(10.0, terms[1].coefficient) + assert.equal(v2, terms[2].var) + assert.equal(1.0, terms[2].coefficient) + assert.equal(v2, terms[3].var) + assert.equal(20.0, terms[3].coefficient) + + assert.error(function() + local _ = t + "foo" + end) + assert.error(function() + local _ = t + {} + end) + end) + + it("sub", function() + local constants = { -2, 2 } + for i, diff in ipairs({ t - 2.0, 2 - t }) do + local constant = constants[i] + assert.True(kiwi.is_expression(diff)) + assert.equal(constant, diff.constant) + + local terms = diff:terms() + assert.equal(1, #terms) + assert.equal(v, terms[1].var) + assert.equal(constant < 0 and 10.0 or -10.0, terms[1].coefficient) + end + + local diff = t - v2 + assert.True(kiwi.is_expression(diff)) + assert.equal(0, diff.constant) + local terms = diff:terms() + assert.equal(2, #terms) + assert.equal(v, terms[1].var) + assert.equal(10.0, terms[1].coefficient) + assert.equal(v2, terms[2].var) + assert.equal(-1.0, terms[2].coefficient) + + diff = t - t2 + assert.True(kiwi.is_expression(diff)) + assert.equal(0, diff.constant) + terms = diff:terms() + assert.equal(2, #terms) + assert.equal(v, terms[1].var) + assert.equal(10.0, terms[1].coefficient) + assert.equal(v2, terms[2].var) + assert.equal(-1.0, terms[2].coefficient) + + local t3 = kiwi.Term(v2, 20) + diff = t3 - diff + assert.True(kiwi.is_expression(diff)) + assert.equal(0, diff.constant) + terms = diff:terms() + assert.equal(3, #terms) + assert.equal(v, terms[1].var) + assert.equal(-10.0, terms[1].coefficient) + assert.equal(v2, terms[2].var) + assert.equal(1.0, terms[2].coefficient) + assert.equal(v2, terms[3].var) + assert.equal(20.0, terms[3].coefficient) + + assert.error(function() + local _ = t - "foo" + end) + assert.error(function() + local _ = t - {} + end) + end) + + it("constraint term op expr", function() + local ops = { "LE", "EQ", "GE" } + for i, meth in ipairs({ "le", "eq", "ge" }) do + local c = t[meth](t, v2 + 1) + assert.True(kiwi.is_constraint(c)) + + local e = c:expression() + local terms = e:terms() + assert.equal(2, #terms) + + -- order can be randomized due to use of map + if terms[1].var ~= v then + terms[1], terms[2] = terms[2], terms[1] + end + assert.equal(v, terms[1].var) + assert.equal(10.0, terms[1].coefficient) + assert.equal(v2, terms[2].var) + assert.equal(-1.0, terms[2].coefficient) + + assert.equal(-1, e.constant) + assert.equal(ops[i], c:op()) + assert.equal(kiwi.strength.REQUIRED, c:strength()) + end + end) + + it("constraint term op term", function() + for i, meth in ipairs({ "le", "eq", "ge" }) do + local c = t[meth](t, t2) + assert.True(kiwi.is_constraint(c)) + + local e = c:expression() + local terms = e:terms() + assert.equal(2, #terms) + + -- order can be randomized due to use of map + if terms[1].var ~= v then + terms[1], terms[2] = terms[2], terms[1] + end + assert.equal(v, terms[1].var) + assert.equal(10.0, terms[1].coefficient) + assert.equal(v2, terms[2].var) + assert.equal(-1.0, terms[2].coefficient) + + assert.equal(0, e.constant) + end + end) + end) + end) +end) diff --git a/spec/var_spec.lua b/spec/var_spec.lua index 3bd693a..efc830f 100644 --- a/spec/var_spec.lua +++ b/spec/var_spec.lua @@ -127,6 +127,8 @@ describe("Var", function() assert.equal(v2, terms[2].var) assert.equal(-1.0, terms[2].coefficient) + -- TODO: terms and expressions + assert.error(function() local _ = v - "foo" end)