diff --git a/ckiwi/.clang-format b/.clang-format similarity index 98% rename from ckiwi/.clang-format rename to .clang-format index 5357c2c..76a9a41 100644 --- a/ckiwi/.clang-format +++ b/.clang-format @@ -27,7 +27,7 @@ BreakBeforeTernaryOperators: true BreakConstructorInitializers: AfterColon BreakInheritanceList: AfterColon BreakStringLiterals: false -ColumnLimit: 90 +ColumnLimit: 98 CompactNamespaces: false ConstructorInitializerAllOnOneLineOrOnePerLine: true ConstructorInitializerIndentWidth: 4 @@ -65,7 +65,7 @@ PointerAlignment: Left ReferenceAlignment: Left # New in v13. int &name ==> int& name ReflowComments: false SeparateDefinitionBlocks: Always # New in v14. -SortIncludes: true +SortIncludes: false SortUsingDeclarations: true SpaceAfterCStyleCast: false SpaceAfterLogicalNot: false diff --git a/.editorconfig b/.editorconfig index 7e2158e..0c08aae 100644 --- a/.editorconfig +++ b/.editorconfig @@ -7,5 +7,4 @@ insert_final_newline = true [{*.lua,*.rockspec,.luacov}] indent_style = space indent_size = 3 -call_parentheses = nosingletable -max_line_length = 98 +max_line_length = 105 diff --git a/.github/workflows/busted.yml b/.github/workflows/busted.yml new file mode 100644 index 0000000..ae8e397 --- /dev/null +++ b/.github/workflows/busted.yml @@ -0,0 +1,37 @@ +name: Busted + +on: [push, pull_request] + +jobs: + busted: + strategy: + fail-fast: false + matrix: + lua_version: ["luajit-openresty", "luajit-2.1.0-beta3", "luajit-git"] + + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Setup ‘lua’ + uses: jkl1337/gh-actions-lua@master + with: + luaVersion: ${{ matrix.lua_version }} + - name: Setup ‘luarocks’ + uses: jkl1337/gh-actions-luarocks@master + - name: Setup dependencies + run: | + luarocks install busted + luarocks install luacov-coveralls + - name: Build C library + run: | + luarocks make --no-install + - name: Run busted tests + run: busted -c -v + - name: Report test coverage + if: success() + continue-on-error: true + run: luacov-coveralls -e .luarocks -e spec + env: + COVERALLS_REPO_TOKEN: ${{ github.token }} diff --git a/.gitignore b/.gitignore index 5b0714b..98b74ee 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,9 @@ /lua /lua_modules /.luarocks +*.pch +*.gch *.so *.o .cache/ +compile_commands.json diff --git a/.luarc.json b/.luarc.json index 5e855b1..73dc269 100644 --- a/.luarc.json +++ b/.luarc.json @@ -8,6 +8,6 @@ "lua_modules/share/lua/5.1/?.lua", "lua_modules/share/lua/5.1/?/init.lua" ], - "workspace.library": ["lua_modules/share/lua/5.1"], + "workspace.library": ["${3rd}/busted/library", "${3rd}/luassert/library"], "workspace.checkThirdParty": false } diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..7bcc92e --- /dev/null +++ b/LICENSE @@ -0,0 +1,7 @@ +Copyright 2024 John Luebs + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..b1a3035 --- /dev/null +++ b/Makefile @@ -0,0 +1,93 @@ +-include config.mk + +CC := $(CROSS)gcc +CP := cp +RM := rm +LIBFLAG := -shared +LIB_EXT := so +LUA_INCDIR := /usr/include + +SRCDIR := . + +OPTFLAG := -O2 +CCFLAGS += $(OPTFLAG) -fPIC -Wall -fvisibility=hidden -Wformat=2 -Wconversion -Wimplicit-fallthrough + +SANITIZE_FLAGS := -fstrict-flex-arrays -fsanitize=undefined -fsanitize=address +LTO_FLAGS := -flto=auto + +ifdef SANITIZE + CCFLAGS += $(SANITIZE_FLAGS) +endif +ifdef LTO + CCFLAGS += $(LTO_FLAGS) +endif + +override CPPFLAGS += -I$(SRCDIR) -I$(SRCDIR)/kiwi -I$(LUA_INCDIR) +override CXXFLAGS += -std=c++14 -fno-rtti $(CCFLAGS) +override CFLAGS += -std=c99 $(CCFLAGS) + +ifneq ($(filter %gcc,$(CC)),) + CXX := $(patsubst %gcc,%g++,$(CC)) + PCH := ljkiwi.hpp.gch +else + ifneq ($(filter %clang,$(CC)),) + CXX := $(patsubst %clang,%clang++,$(CC)) + override CXXFLAGS += -pedantic -Wno-c99-extensions + PCH := ljkiwi.hpp.pch +endif +endif + +ifdef LUA +LUA_VERSION ?= $(lastword $(shell $(LUA) -e "print(_VERSION)")) +endif + +ifndef LUA_VERSION +LJKIWI_CKIWI := 1 +else + ifeq ($(LUA_VERSION),5.1) + LJKIWI_CKIWI := 1 + endif +endif + +KIWI_LIB := AssocVector.h constraint.h debug.h errors.h expression.h kiwi.h maptype.h \ + row.h shareddata.h solver.h solverimpl.h strength.h symbol.h symbolics.h term.h \ + util.h variable.h version.h + +OBJS := luakiwi.o +ifdef LJKIWI_CKIWI + OBJS += ckiwi.o +endif + +vpath %.cpp $(SRCDIR)/ckiwi $(SRCDIR)/luakiwi +vpath %.h $(SRCDIR)/ckiwi $(SRCDIR)/luakiwi $(SRCDIR)/kiwi/kiwi + +all: ljkiwi.$(LIB_EXT) + +install: + $(CP) -f ljkiwi.$(LIB_EXT) $(INST_LIBDIR)/ljkiwi.$(LIB_EXT) + $(CP) -f kiwi.lua $(INST_LUADIR)/kiwi.lua + +clean: + $(RM) -f ljkiwi.$(LIB_EXT) $(OBJS) $(PCH) + + +ljkiwi.hpp.gch: $(KIWI_LIB) +ckiwi.o: $(PCH) ckiwi.cpp ckiwi.h $(KIWI_LIB) +luakiwi.o: $(PCH) luakiwi-int.h luacompat.h $(KIWI_LIB) + +ljkiwi.$(LIB_EXT): $(OBJS) + $(CXX) $(CCFLAGS) $(LIBFLAG) -o $@ $(OBJS) + +%.hpp.gch: %.hpp + $(CXX) $(CPPFLAGS) $(CXXFLAGS) -x c++-header -o $@ $< + +%.hpp.pch: %.hpp + $(CXX) $(CPPFLAGS) $(CXXFLAGS) -x c++-header -o $@ $< + +%.o: %.c + $(CC) $(CPPFLAGS) $(CFLAGS) -c -o $@ $< + +%.o: %.cpp + $(CXX) $(CPPFLAGS) $(CXXFLAGS) -c -o $@ $< + +.PHONY: all install clean diff --git a/README.md b/README.md new file mode 100644 index 0000000..8d2258f --- /dev/null +++ b/README.md @@ -0,0 +1,94 @@ +ljkiwi - Free LuaJIT FFI and Lua C API kiwi (Cassowary derived) constraint solver. + +[![CI](https://github.com/jkl1337/ljkiwi/actions/workflows/busted.yml/badge.svg)](https://github.com/jkl1337/ljkiwi/actions/workflows/busted.yml) +[![Coverage Status](https://coveralls.io/repos/github/jkl1337/ljkiwi/badge.svg?branch=master)](https://coveralls.io/github/jkl1337/ljkiwi?branch=master) +[![luarocks](https://img.shields.io/luarocks/v/jkl/kiwi)](https://luarocks.org/modules/jkl/kiwi) + +# Introduction + +Kiwi is a reasonably efficient C++ implementation of the Cassowary constraint solving algorithm. It is an implementation of the algorithm as described in the paper ["The Cassowary Linear Arithmetic Constraint Solving Algorithm"](http://www.cs.washington.edu/research/constraints/cassowary/techreports/cassowaryTR.pdf) by Greg J. Badros and Alan Borning. The Kiwi implementation is not based on the original C++ implementation, but is a ground-up reimplementation with performance 10x to 500x faster in typical use. +Cassowary constraint solving is a technique that is particularly well suited to user interface layout. It is the algorithm Apple uses for iOS and OS X Auto Layout. + +There are a few Lua implementations or attempts. The SILE typesetting system has a pure Lua implementation of the original Cassowary code, which appears to be correct but is quite slow. There are two extant Lua ports of Kiwi, one that is based on a C rewrite of Kiwi. However testing of these was not encouraging with either segfaults or incorrect results. +Since the C++ Kiwi library is well tested and widely used it was simpler to provide a LuaJIT FFI wrapper. There is also a Lua C API binding with support for 5.1 through 5.4. +This package has no dependencies other than a supported C++14 compiler to compile the included Kiwi library and a small C wrapper. + +The Lua API has a pure Lua expression builder. There is of course some overhead to this, however in most cases expression building is infrequent and the underlying structures can be reused. + +The wrapper is quite close to the Kiwi C++/Python port with a few naming changes. + +## Example + +```lua +local kiwi = require("kiwi") +local Var = kiwi.Var + +local Button = setmetatable({}, { + __call = function(_, identifier) + return setmetatable({ + left = Var(identifier .. " left"), + width = Var(identifier .. " width"), + }, { + __tostring = function(self) + return "Button(" .. self.left:value() .. ", " .. self.width:value() .. ")" + end, + }) + end, +}) + +local b1 = Button("b1") +local b2 = Button("b2") + +local left_edge = Var("left") +local right_edge = Var("width") + +local STRONG = kiwi.Strength.STRONG + +-- stylua: ignore start +local constraints = { + left_edge :eq(0.0), + -- two buttons are the same width + b1.width :eq(b2.width), + -- button1 starts 50 from the left margin + b1.left :eq(left_edge + 50), + -- button2 ends 50 from the right margin + right_edge :eq(b2.left + b2.width + 50), + -- button2 starts at least 100 from the end of button1. This is the "elastic" constraint + b2.left :ge(b1.left + b1.width + 100), + -- button1 has a minimum width of 87 + b1.width :ge(87), + -- button1 has a preferred width of 87 + b1.width :eq(87, STRONG), + -- button2 has minimum width of 113 + b2.width :ge(113), + -- button2 has a preferred width of 113 + b2.width :eq(113, STRONG), +} +-- stylua: ignore end + +local solver = kiwi.Solver() + +for _, c in ipairs(constraints) do + solver:add_constraint(c) +end + +solver:update_vars() + +print(b1) -- Button(50, 113) +print(b2) -- Button(263, 113) +print(left_edge:value()) -- 0 +print(right_edge:value()) -- 426 + +solver:add_edit_var(right_edge, STRONG) +solver:suggest_value(right_edge, 500) +solver:update_vars() +print(b1) -- Button(50, 113) +print(b2) -- Button(337, 113) +print(right_edge:value()) -- 500 + +``` + +In addition to the expression builder there is a convenience constraints submodule with: `pair_ratio`, `pair`, and `single` to allow efficient construction of the most common simple expression types for GUI layout. + +## Documentation +WIP - However the API is fully annotated and will work with lua-language-server. Documentation can also be generated with lua-language-server. diff --git a/ckiwi/ckiwi.cpp b/ckiwi/ckiwi.cpp new file mode 100644 index 0000000..13596df --- /dev/null +++ b/ckiwi/ckiwi.cpp @@ -0,0 +1,366 @@ +#include "ljkiwi.hpp" +#include "ckiwi.h" + +#include + +#include +#include +#include +#include + +#if defined(__GNUC__) && !defined(LJKIWI_NO_BUILTIN) + #define lk_likely(x) (__builtin_expect(((x) != 0), 1)) + #define lk_unlikely(x) (__builtin_expect(((x) != 0), 0)) +#else + #define lk_likely(x) (x) + #define lk_unlikely(x) (x) +#endif + +namespace { + +using namespace kiwi; + +const KiwiErr* new_error(const KiwiErr* base, const std::exception& ex) { + if (!std::strcmp(ex.what(), base->message)) + return base; + + const auto msg_n = std::strlen(ex.what()) + 1; + + auto* mem = static_cast(std::malloc(sizeof(KiwiErr) + msg_n)); + if (!mem) { + return base; + } + + const auto* err = new (mem) KiwiErr {base->kind, mem + sizeof(KiwiErr), true}; + std::memcpy(const_cast(err->message), ex.what(), msg_n); + return err; +} + +static const constexpr KiwiErr kKiwiErrUnhandledCxxException { + KiwiErrUnknown, + "An unhandled C++ exception occurred."}; + +static const constexpr KiwiErr kKiwiErrNullObjectArg0 { + KiwiErrNullObject, + "null object passed as argument #0 (self)"}; + +static const constexpr KiwiErr kKiwiErrNullObjectArg1 { + KiwiErrNullObject, + "null object passed as argument #1"}; + +template +const KiwiErr* wrap_err(F&& f) { + try { + f(); + } catch (const UnsatisfiableConstraint& ex) { + static const constexpr KiwiErr err { + KiwiErrUnsatisfiableConstraint, + "The constraint cannot be satisfied."}; + return &err; + } catch (const UnknownConstraint& ex) { + static const constexpr KiwiErr err { + KiwiErrUnknownConstraint, + "The constraint has not been added to the solver."}; + return &err; + + } catch (const DuplicateConstraint& ex) { + static const constexpr KiwiErr err { + KiwiErrDuplicateConstraint, + "The constraint has already been added to the solver."}; + return &err; + + } catch (const UnknownEditVariable& ex) { + static const constexpr KiwiErr err { + KiwiErrUnknownEditVariable, + "The edit variable has not been added to the solver."}; + return &err; + + } catch (const DuplicateEditVariable& ex) { + static const constexpr KiwiErr err { + KiwiErrDuplicateEditVariable, + "The edit variable has already been added to the solver."}; + return &err; + + } catch (const BadRequiredStrength& ex) { + static const constexpr KiwiErr err { + KiwiErrBadRequiredStrength, + "A required strength cannot be used in this context."}; + return &err; + + } catch (const InternalSolverError& ex) { + static const constexpr KiwiErr base { + KiwiErrInternalSolverError, + "An internal solver error occurred."}; + return new_error(&base, ex); + } catch (std::bad_alloc&) { + static const constexpr KiwiErr err {KiwiErrAlloc, "A memory allocation failed."}; + return &err; + } catch (const std::exception& ex) { + return new_error(&kKiwiErrUnhandledCxxException, ex); + } catch (...) { + return &kKiwiErrUnhandledCxxException; + } + return nullptr; +} + +template +const KiwiErr* wrap_err(P self, F&& f) { + if (lk_unlikely(!self)) { + return &kKiwiErrNullObjectArg0; + } + return wrap_err([&]() { f(self->solver); }); +} + +template +const KiwiErr* wrap_err(P* self, R* item, F&& f) { + if (lk_unlikely(!self)) { + return &kKiwiErrNullObjectArg0; + } else if (lk_unlikely(!item)) { + return &kKiwiErrNullObjectArg1; + } + return wrap_err([&]() { f(self->solver, item); }); +} + +template +T* make_unmanaged(Args... args) { + auto* o = new T(std::forward(args)...); + o->m_refcount = 1; + return o; +} + +template +void release_unmanaged(T* p) { + if (lk_likely(p)) { + if (--p->m_refcount == 0) + delete p; + } +} + +template +T* retain_unmanaged(T* p) { + if (lk_likely(p)) + p->m_refcount++; + return p; +} + +} // namespace + +extern "C" { + +KiwiVar* kiwi_var_construct(const char* name) { + return make_unmanaged(lk_likely(name) ? name : ""); +} + +void kiwi_var_release(KiwiVar* var) { + release_unmanaged(var); +} + +void kiwi_var_retain(KiwiVar* var) { + retain_unmanaged(var); +} + +const char* kiwi_var_name(const KiwiVar* var) { + return lk_likely(var) ? var->name().c_str() : "()"; +} + +void kiwi_var_set_name(KiwiVar* var, const char* name) { + if (lk_likely(var && name)) + var->setName(name); +} + +double kiwi_var_value(const KiwiVar* var) { + return lk_likely(var) ? var->value() : std::numeric_limits::quiet_NaN(); +} + +void kiwi_var_set_value(KiwiVar* var, double value) { + if (lk_likely(var)) + var->setValue(value); +} + +void kiwi_expression_retain(KiwiExpression* expr) { + if (lk_unlikely(!expr)) + return; + for (auto* t = expr->terms_; t != expr->terms_ + expr->term_count; ++t) { + retain_unmanaged(t->var); + } +} + +void kiwi_expression_destroy(KiwiExpression* expr) { + if (lk_unlikely(!expr)) + return; + + if (expr->owner) { + release_unmanaged(expr->owner); + } else { + for (auto* t = expr->terms_; t != expr->terms_ + expr->term_count; ++t) { + release_unmanaged(t->var); + } + } +} + +KiwiConstraint* kiwi_constraint_construct( + const KiwiExpression* lhs, + const KiwiExpression* rhs, + enum KiwiRelOp op, + double strength +) { + if (strength < 0.0) { + strength = kiwi::strength::required; + } + + std::vector terms; + terms.reserve(static_cast( + (lhs && lhs->term_count > 0 ? lhs->term_count : 0) + + (rhs && rhs->term_count > 0 ? rhs->term_count : 0) + )); + + if (lhs) { + for (int i = 0; i < lhs->term_count; ++i) { + const auto& t = lhs->terms_[i]; + if (t.var) + terms.emplace_back(Variable(t.var), t.coefficient); + } + } + if (rhs) { + for (int i = 0; i < rhs->term_count; ++i) { + const auto& t = rhs->terms_[i]; + if (t.var) + terms.emplace_back(Variable(t.var), -t.coefficient); + } + } + + return make_unmanaged( + Expression(std::move(terms), (lhs ? lhs->constant : 0.0) - (rhs ? rhs->constant : 0.0)), + static_cast(op), + strength + ); +} + +void kiwi_constraint_release(KiwiConstraint* c) { + release_unmanaged(c); +} + +void kiwi_constraint_retain(KiwiConstraint* c) { + retain_unmanaged(c); +} + +double kiwi_constraint_strength(const KiwiConstraint* c) { + return lk_likely(c) ? c->strength() : std::numeric_limits::quiet_NaN(); +} + +enum KiwiRelOp kiwi_constraint_op(const KiwiConstraint* c) { + return lk_likely(c) ? static_cast(c->op()) : KiwiRelOp::KIWI_OP_EQ; +} + +bool kiwi_constraint_violated(const KiwiConstraint* c) { + return lk_likely(c) ? c->violated() : false; +} + +int kiwi_constraint_expression(KiwiConstraint* c, KiwiExpression* out, int out_size) { + if (lk_unlikely(!c)) + return 0; + + const auto& expr = c->expression(); + const auto& terms = expr.terms(); + int n = terms.size() < INT_MAX ? static_cast(terms.size()) : INT_MAX; + if (!out || out_size < n) + return n; + + for (int i = 0; i < n; ++i) { + const auto& t = terms[static_cast(i)]; + out->terms_[i].var = const_cast(t.variable()).ptr(); + out->terms_[i].coefficient = t.coefficient(); + } + out->constant = expr.constant(); + out->term_count = n; + out->owner = retain_unmanaged(c); + + return n; +} + +struct KiwiSolver { + unsigned error_mask; + Solver solver; +}; + +KiwiSolver* kiwi_solver_construct(unsigned error_mask) { + return new KiwiSolver {error_mask}; +} + +void kiwi_solver_destroy(KiwiSolver* s) { + if (lk_likely(s)) + delete s; +} + +unsigned kiwi_solver_get_error_mask(const KiwiSolver* s) { + return lk_likely(s) ? s->error_mask : 0; +} + +void kiwi_solver_set_error_mask(KiwiSolver* s, unsigned mask) { + if (lk_likely(s)) + s->error_mask = mask; +} + +const KiwiErr* kiwi_solver_add_constraint(KiwiSolver* s, KiwiConstraint* constraint) { + return wrap_err(s, constraint, [](auto&& s, auto&& c) { s.addConstraint(Constraint(c)); }); +} + +const KiwiErr* kiwi_solver_remove_constraint(KiwiSolver* s, KiwiConstraint* constraint) { + return wrap_err(s, constraint, [](auto&& s, auto&& c) { s.removeConstraint(Constraint(c)); }); +} + +bool kiwi_solver_has_constraint(const KiwiSolver* s, KiwiConstraint* constraint) { + if (lk_unlikely(!s || !constraint)) + return 0; + return s->solver.hasConstraint(Constraint(constraint)); +} + +const KiwiErr* kiwi_solver_add_edit_var(KiwiSolver* s, KiwiVar* var, double strength) { + return wrap_err(s, var, [strength](auto& s, auto&& v) { + s.addEditVariable(Variable(v), strength); + }); +} + +const KiwiErr* kiwi_solver_remove_edit_var(KiwiSolver* s, KiwiVar* var) { + return wrap_err(s, var, [](auto&& s, auto&& v) { s.removeEditVariable(Variable(v)); }); +} + +bool kiwi_solver_has_edit_var(const KiwiSolver* s, KiwiVar* var) { + if (lk_unlikely(!s || !var)) + return 0; + return s->solver.hasEditVariable(Variable(var)); +} + +const KiwiErr* kiwi_solver_suggest_value(KiwiSolver* s, KiwiVar* var, double value) { + return wrap_err(s, var, [value](auto&& s, auto&& v) { s.suggestValue(Variable(v), value); }); +} + +void kiwi_solver_update_vars(KiwiSolver* s) { + if (lk_likely(s)) + s->solver.updateVariables(); +} + +void kiwi_solver_reset(KiwiSolver* s) { + if (lk_likely(s)) + s->solver.reset(); +} + +void kiwi_solver_dump(const KiwiSolver* s) { + if (lk_likely(s)) + s->solver.dump(); +} + +char* kiwi_solver_dumps(const KiwiSolver* s) { + if (lk_unlikely(!s)) + return nullptr; + + const auto& str = s->solver.dumps(); + const auto buf_size = str.size() + 1; + auto* buf = static_cast(std::malloc(buf_size)); + if (!buf) + return nullptr; + std::memcpy(buf, str.c_str(), str.size() + 1); + return buf; +} + +} // extern "C" diff --git a/ckiwi/ckiwi.h b/ckiwi/ckiwi.h new file mode 100644 index 0000000..80a61e3 --- /dev/null +++ b/ckiwi/ckiwi.h @@ -0,0 +1,133 @@ +#ifndef LJKIWI_CKIWI_H_ +#define LJKIWI_CKIWI_H_ + +#if !defined(_MSC_VER) || _MSC_VER >= 1900 + #undef LJKIWI_USE_FAM_1 +#else + #define LJKIWI_USE_FAM_1 +#endif + +#ifdef __cplusplus + +namespace kiwi { +class VariableData; +class Constraint; +} // namespace kiwi + +typedef kiwi::VariableData KiwiVar; +typedef kiwi::ConstraintData KiwiConstraint; + +extern "C" { + +#else +typedef struct KiwiVar KiwiVar; +typedef struct KiwiConstraint KiwiConstraint; + +#endif + +#if __GNUC__ + #pragma GCC visibility push(default) + #define LJKIWI_DATA_EXPORT __attribute__((visibility("default"))) +#endif + +// LuaJIT start +enum KiwiErrKind { + KiwiErrNone, + KiwiErrUnsatisfiableConstraint = 1, + KiwiErrUnknownConstraint, + KiwiErrDuplicateConstraint, + KiwiErrUnknownEditVariable, + KiwiErrDuplicateEditVariable, + KiwiErrBadRequiredStrength, + KiwiErrInternalSolverError, + KiwiErrAlloc, + KiwiErrNullObject, + KiwiErrUnknown, +}; + +enum KiwiRelOp { KIWI_OP_LE, KIWI_OP_GE, KIWI_OP_EQ }; + +typedef struct KiwiTerm { + KiwiVar* var; + double coefficient; +} KiwiTerm; + +typedef struct KiwiExpression { + double constant; + int term_count; + KiwiConstraint* owner; + +#if defined(LJKIWI_LUAJIT_DEF) + KiwiTerm terms_[?]; +#elif defined(LJKIWI_USE_FAM_1) + KiwiTerm terms_[1]; // LuaJIT: struct KiwiTerm terms_[?]; +#else + KiwiTerm terms_[]; +#endif + +} KiwiExpression; + +typedef struct KiwiErr { + enum KiwiErrKind kind; + const char* message; + bool must_free; +} KiwiErr; + +struct KiwiSolver; + +KiwiVar* kiwi_var_construct(const char* name); +void kiwi_var_release(KiwiVar* var); +void kiwi_var_retain(KiwiVar* var); + +const char* kiwi_var_name(const KiwiVar* var); +void kiwi_var_set_name(KiwiVar* var, const char* name); +double kiwi_var_value(const KiwiVar* var); +void kiwi_var_set_value(KiwiVar* var, double value); + +void kiwi_expression_retain(KiwiExpression* expr); +void kiwi_expression_destroy(KiwiExpression* expr); + +KiwiConstraint* kiwi_constraint_construct( + const KiwiExpression* lhs, + const KiwiExpression* rhs, + enum KiwiRelOp op, + double strength +); +void kiwi_constraint_release(KiwiConstraint* c); +void kiwi_constraint_retain(KiwiConstraint* c); + +double kiwi_constraint_strength(const KiwiConstraint* c); +enum KiwiRelOp kiwi_constraint_op(const KiwiConstraint* c); +bool kiwi_constraint_violated(const KiwiConstraint* c); +int kiwi_constraint_expression(KiwiConstraint* c, KiwiExpression* out, int out_size); + +KiwiSolver* kiwi_solver_construct(unsigned error_mask); +void kiwi_solver_destroy(KiwiSolver* s); +unsigned kiwi_solver_get_error_mask(const KiwiSolver* s); +void kiwi_solver_set_error_mask(KiwiSolver* s, unsigned mask); + +const KiwiErr* kiwi_solver_add_constraint(KiwiSolver* s, KiwiConstraint* constraint); +const KiwiErr* kiwi_solver_remove_constraint(KiwiSolver* s, KiwiConstraint* constraint); +bool kiwi_solver_has_constraint(const KiwiSolver* s, KiwiConstraint* constraint); +const KiwiErr* kiwi_solver_add_edit_var(KiwiSolver* s, KiwiVar* var, double strength); +const KiwiErr* kiwi_solver_remove_edit_var(KiwiSolver* s, KiwiVar* var); +bool kiwi_solver_has_edit_var(const KiwiSolver* s, KiwiVar* var); +const KiwiErr* kiwi_solver_suggest_value(KiwiSolver* s, KiwiVar* var, double value); +void kiwi_solver_update_vars(KiwiSolver* sp); +void kiwi_solver_reset(KiwiSolver* sp); +void kiwi_solver_dump(const KiwiSolver* sp); +char* kiwi_solver_dumps(const KiwiSolver* sp); +// LuaJIT end + +#if __GNUC__ + #pragma GCC visibility pop +#endif + +#ifdef __cplusplus +} // extern "C" +#endif + +// Local Variables: +// mode: c++ +// End: +#endif // LJKIWI_CKIWI_H_ diff --git a/kiwi-scm-1.rockspec b/kiwi-scm-1.rockspec new file mode 100644 index 0000000..749b688 --- /dev/null +++ b/kiwi-scm-1.rockspec @@ -0,0 +1,36 @@ +rockspec_format = "3.0" +package = "kiwi" +version = "scm-1" +source = { + url = "git+https://github.com/jkl1337/ljkiwi", +} +description = { + summary = "LuaJIT FFI and Lua binding for the Kiwi constraint solver.", + detailed = [[ + kiwi is a LuaJIT FFI and Lua binding for the Kiwi constraint solver. Kiwi is a fast + implementation of the Cassowary constraint solving algorithm. kiwi provides + reasonably efficient bindings using the LuaJIT FFI and convential Lua C bindings.]], + license = "MIT", + issues_url = "https://github.com/jkl1337/ljkiwi/issues", + maintainer = "John Luebs", +} +dependencies = { + "lua >= 5.1", +} + +build = { + type = "make", + build_variables = { + LUA = "$(LUA)", + CFLAGS = "$(CFLAGS)", + LUA_INCDIR = "$(LUA_INCDIR)", + LIBFLAG = "$(LIBFLAG)", + LIB_EXT = "$(LIB_EXTENSION)", + OBJ_EXT = "$(OBJ_EXTENSION)", + }, + install_variables = { + INST_LIBDIR = "$(LIBDIR)", + INST_LUADIR = "$(LUADIR)", + LIB_EXT = "$(LIB_EXTENSION)", + }, +} diff --git a/kiwi.lua b/kiwi.lua new file mode 100644 index 0000000..5c578aa --- /dev/null +++ b/kiwi.lua @@ -0,0 +1,1178 @@ +-- kiwi.lua - LuaJIT FFI bindings with C API fallback to kiwi constraint solver. + +local ffi +do + local ffi_loader = package.preload["ffi"] + if ffi_loader == nil then + return require("ljkiwi") + end + ffi = ffi_loader() --[[@as ffilib]] +end + +local kiwi = {} + +local ljkiwi +do + local cpath, err = package.searchpath("ljkiwi", package.cpath) + if cpath == nil then + error("kiwi dynamic library 'ljkiwi' not found\n" .. err) + end + ljkiwi = ffi.load(cpath) +end +kiwi.ljkiwi = ljkiwi + +ffi.cdef([[ +void free(void *); + +typedef struct KiwiVar KiwiVar; +typedef struct KiwiConstraint KiwiConstraint; +typedef struct KiwiSolver KiwiSolver;]]) + +ffi.cdef([[ +enum KiwiErrKind { + KiwiErrNone, + KiwiErrUnsatisfiableConstraint = 1, + KiwiErrUnknownConstraint, + KiwiErrDuplicateConstraint, + KiwiErrUnknownEditVariable, + KiwiErrDuplicateEditVariable, + KiwiErrBadRequiredStrength, + KiwiErrInternalSolverError, + KiwiErrAlloc, + KiwiErrNullObject, + KiwiErrUnknown, +}; + +enum KiwiRelOp { LE, GE, EQ }; + +typedef struct KiwiTerm { + KiwiVar* var; + double coefficient; +} KiwiTerm; + +typedef struct KiwiExpression { + double constant; + int term_count; + KiwiConstraint* owner; + + KiwiTerm terms_[?]; +} KiwiExpression; + +typedef struct KiwiErr { + enum KiwiErrKind kind; + const char* message; + bool must_free; +} KiwiErr; + +struct KiwiSolver; + +KiwiVar* kiwi_var_construct(const char* name); +void kiwi_var_release(KiwiVar* var); +void kiwi_var_retain(KiwiVar* var); + +const char* kiwi_var_name(const KiwiVar* var); +void kiwi_var_set_name(KiwiVar* var, const char* name); +double kiwi_var_value(const KiwiVar* var); +void kiwi_var_set_value(KiwiVar* var, double value); + +void kiwi_expression_retain(KiwiExpression* expr); +void kiwi_expression_destroy(KiwiExpression* expr); + +KiwiConstraint* kiwi_constraint_construct( + const KiwiExpression* lhs, + const KiwiExpression* rhs, + enum KiwiRelOp op, + double strength +); +void kiwi_constraint_release(KiwiConstraint* c); +void kiwi_constraint_retain(KiwiConstraint* c); + +double kiwi_constraint_strength(const KiwiConstraint* c); +enum KiwiRelOp kiwi_constraint_op(const KiwiConstraint* c); +bool kiwi_constraint_violated(const KiwiConstraint* c); +int kiwi_constraint_expression(KiwiConstraint* c, KiwiExpression* out, int out_size); + +KiwiSolver* kiwi_solver_construct(unsigned error_mask); +void kiwi_solver_destroy(KiwiSolver* s); +unsigned kiwi_solver_get_error_mask(const KiwiSolver* s); +void kiwi_solver_set_error_mask(KiwiSolver* s, unsigned mask); + +const KiwiErr* kiwi_solver_add_constraint(KiwiSolver* s, KiwiConstraint* constraint); +const KiwiErr* kiwi_solver_remove_constraint(KiwiSolver* s, KiwiConstraint* constraint); +bool kiwi_solver_has_constraint(const KiwiSolver* s, KiwiConstraint* constraint); +const KiwiErr* kiwi_solver_add_edit_var(KiwiSolver* s, KiwiVar* var, double strength); +const KiwiErr* kiwi_solver_remove_edit_var(KiwiSolver* s, KiwiVar* var); +bool kiwi_solver_has_edit_var(const KiwiSolver* s, KiwiVar* var); +const KiwiErr* kiwi_solver_suggest_value(KiwiSolver* s, KiwiVar* var, double value); +void kiwi_solver_update_vars(KiwiSolver* sp); +void kiwi_solver_reset(KiwiSolver* sp); +void kiwi_solver_dump(const KiwiSolver* sp); +char* kiwi_solver_dumps(const KiwiSolver* sp); +]]) + +local strformat = string.format +local ffi_copy, ffi_gc, ffi_istype, ffi_new, ffi_string = + ffi.copy, ffi.gc, ffi.istype, ffi.new, ffi.string + +local concat = table.concat +local has_table_new, new_tab = pcall(require, "table.new") +if not has_table_new or type(new_tab) ~= "function" then + new_tab = function(_, _) + return {} + end +end + +---@alias kiwi.ErrKind +---| '"KiwiErrNone"' # No error. +---| '"KiwiErrUnsatisfiableConstraint"' # The given constraint is required and cannot be satisfied. +---| '"KiwiErrUnknownConstraint"' # The given constraint has not been added to the solver. +---| '"KiwiErrDuplicateConstraint"' # The given constraint has already been added to the solver. +---| '"KiwiErrUnknownEditVariable"' # The given edit variable has not been added to the solver. +---| '"KiwiErrDuplicateEditVariable"' # The given edit variable has already been added to the solver. +---| '"KiwiErrBadRequiredStrength"' # The given strength is >= required. +---| '"KiwiErrInternalSolverError"' # An internal solver error occurred. +---| '"KiwiErrAlloc"' # A memory allocation error occurred. +---| '"KiwiErrNullObject"' # A method was invoked on a null or empty object. +---| '"KiwiErrUnknown"' # An unknown error occurred. +kiwi.ErrKind = ffi.typeof("enum KiwiErrKind") --[[@as kiwi.ErrKind]] + +---@alias kiwi.RelOp +---| '"LE"' # <= (less than or equal) +---| '"GE"' # >= (greater than or equal) +---| '"EQ"' # == (equal) +kiwi.RelOp = ffi.typeof("enum KiwiRelOp") + +kiwi.strength = { + REQUIRED = 1001001000.0, + STRONG = 1000000.0, + MEDIUM = 1000.0, + WEAK = 1.0, +} + +do + local function clamp(n) + return math.max(0, math.min(1000, n)) + end + + --- Create a custom constraint strength. + ---@param a number: Scale factor 1e6 + ---@param b number: Scale factor 1e3 + ---@param c number: Scale factor 1 + ---@param w? number: Weight + ---@return number + ---@nodiscard + function kiwi.strength.create(a, b, c, w) + w = w or 1.0 + return clamp(a * w) * 1000000.0 + clamp(b * w) * 1000.0 + clamp(c * w) + end +end + +local Var = ffi.typeof("struct KiwiVar") --[[@as kiwi.Var]] +kiwi.Var = Var + +function kiwi.is_var(o) + return ffi_istype(Var, o) +end + +local Term = ffi.typeof("struct KiwiTerm") --[[@as kiwi.Term]] +kiwi.Term = Term + +function kiwi.is_term(o) + return ffi_istype(Term, o) +end + +local Expression = ffi.typeof("struct KiwiExpression") --[[@as kiwi.Expression]] +kiwi.Expression = Expression + +function kiwi.is_expression(o) + return ffi_istype(Expression, o) +end + +local Constraint = ffi.typeof("struct KiwiConstraint") --[[@as kiwi.Constraint]] +kiwi.Constraint = Constraint + +function kiwi.is_constraint(o) + return ffi_istype(Constraint, o) +end + +---@param expr kiwi.Expression +---@param var kiwi.Var +---@param coeff number? +---@nodiscard +local function add_expr_term(expr, var, coeff) + local ret = ffi_gc(ffi_new(Expression, expr.term_count + 1), ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] + for i = 0, expr.term_count - 1 do + local st = expr.terms_[i] --[[@as kiwi.Term]] + local dt = ret.terms_[i] --[[@as kiwi.Term]] + dt.var = st.var + dt.coefficient = st.coefficient + end + local dt = ret.terms_[expr.term_count] + dt.var = var + dt.coefficient = coeff or 1.0 + ret.constant = expr.constant + ret.term_count = expr.term_count + 1 + ljkiwi.kiwi_expression_retain(ret) + return ret +end + +---@param constant number +---@param var kiwi.Var +---@param coeff number? +---@nodiscard +local function new_expr_one(constant, var, coeff) + local ret = ffi_gc(ffi_new(Expression, 1), ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] + local dt = ret.terms_[0] + dt.var = var + dt.coefficient = coeff or 1.0 + ret.constant = constant + ret.term_count = 1 + ljkiwi.kiwi_var_retain(var) + return ret +end + +---@param constant number +---@param var1 kiwi.Var +---@param var2 kiwi.Var +---@param coeff1 number? +---@param coeff2 number? +---@nodiscard +local function new_expr_pair(constant, var1, var2, coeff1, coeff2) + local ret = ffi_gc(ffi_new(Expression, 2), ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] + local dt = ret.terms_[0] + dt.var = var1 + dt.coefficient = coeff1 or 1.0 + dt = ret.terms_[1] + dt.var = var2 + dt.coefficient = coeff2 or 1.0 + ret.constant = constant + ret.term_count = 2 + ljkiwi.kiwi_expression_retain(ret) + return ret +end + +local function typename(o) + if ffi.istype(Var, o) then + return "Var" + elseif ffi.istype(Term, o) then + return "Term" + elseif ffi.istype(Expression, o) then + return "Expression" + elseif ffi.istype(Constraint, o) then + return "Constraint" + else + return type(o) + end +end + +local function op_error(a, b, op) + --stylua: ignore + -- level 3 works for arithmetic without TCO (no return), and for rel with TCO forced (explicit return) + error(strformat( + "invalid operand type for '%s' %.40s('%.99s') and %.40s('%.99s')", + op, typename(a), tostring(a), typename(b), tostring(b)), 3) +end + +local Strength = kiwi.strength +local REQUIRED = Strength.REQUIRED + +local OP_NAMES = { + LE = "<=", + GE = ">=", + EQ = "==", +} + +local SIZEOF_TERM = ffi.sizeof(Term) --[[@as integer]] + +local tmpexpr = ffi_new(Expression, 2) --[[@as kiwi.Expression]] +local tmpexpr_r = ffi_new(Expression, 1) --[[@as kiwi.Expression]] + +local function toexpr(o, temp) + if ffi_istype(Expression, o) then + return o --[[@as kiwi.Expression]] + elseif type(o) == "number" then + temp.constant = o + temp.term_count = 0 + return temp + end + temp.constant = 0 + temp.term_count = 1 + local t = temp.terms_[0] + + if ffi_istype(Var, o) then + t.var = o --[[@as kiwi.Var]] + t.coefficient = 1.0 + elseif ffi_istype(Term, o) then + ffi_copy(t, o, SIZEOF_TERM) + else + return nil + end + return temp +end + +---@param lhs kiwi.Expression|kiwi.Term|kiwi.Var|number +---@param rhs kiwi.Expression|kiwi.Term|kiwi.Var|number +---@param op kiwi.RelOp +---@param strength? number +---@nodiscard +local function rel(lhs, rhs, op, strength) + local el = toexpr(lhs, tmpexpr) + local er = toexpr(rhs, tmpexpr_r) + if el == nil or er == nil then + op_error(lhs, rhs, OP_NAMES[op]) + end + + return ffi_gc( + ljkiwi.kiwi_constraint_construct(el, er, op, strength or REQUIRED), + ljkiwi.kiwi_constraint_release + ) --[[@as kiwi.Constraint]] +end + +--- Define a constraint with expressions as `a <= b`. +---@param lhs kiwi.Expression|kiwi.Term|kiwi.Var|number +---@param rhs kiwi.Expression|kiwi.Term|kiwi.Var|number +---@param strength? number +---@nodiscard +function kiwi.le(lhs, rhs, strength) + return rel(lhs, rhs, "LE", strength) +end + +--- Define a constraint with expressions as `a >= b`. +---@param lhs kiwi.Expression|kiwi.Term|kiwi.Var|number +---@param rhs kiwi.Expression|kiwi.Term|kiwi.Var|number +---@param strength? number +---@nodiscard +function kiwi.ge(lhs, rhs, strength) + return rel(lhs, rhs, "GE", strength) +end + +--- Define a constraint with expressions as `a == b`. +---@param lhs kiwi.Expression|kiwi.Term|kiwi.Var|number +---@param rhs kiwi.Expression|kiwi.Term|kiwi.Var|number +---@param strength? number +---@nodiscard +function kiwi.eq(lhs, rhs, strength) + return rel(lhs, rhs, "EQ", strength) +end + +do + --- Variables are the values the constraint solver calculates. + ---@class kiwi.Var: ffi.cdata* + ---@overload fun(name: string?): kiwi.Var + ---@operator mul(number): kiwi.Term + ---@operator div(number): kiwi.Term + ---@operator unm: kiwi.Term + ---@operator add(kiwi.Expression|kiwi.Term|kiwi.Var|number): kiwi.Expression + ---@operator sub(kiwi.Expression|kiwi.Term|kiwi.Var|number): kiwi.Expression + local Var_cls = { + le = kiwi.le, + ge = kiwi.ge, + eq = kiwi.eq, + + --- Change the name of the variable. + ---@type fun(self: kiwi.Var, name: string) + set_name = ljkiwi.kiwi_var_set_name, + + --- Get the current value of the variable. + ---@type fun(self: kiwi.Var): number + value = ljkiwi.kiwi_var_value, + + --- Set the value of the variable. + ---@type fun(self: kiwi.Var, value: number) + set = ljkiwi.kiwi_var_set_value, + } + + --- Get the name of the variable. + ---@return string + ---@nodiscard + function Var_cls:name() + return ffi_string(ljkiwi.kiwi_var_name(self)) + end + + --- Create a term from this variable. + ---@param coefficient number? + ---@return kiwi.Term + ---@nodiscard + function Var_cls:toterm(coefficient) + return Term(self, coefficient) + end + + --- Create a term from this variable. + ---@param coefficient number? + ---@param constant number? + ---@return kiwi.Expression + ---@nodiscard + function Var_cls:toexpr(coefficient, constant) + return new_expr_one(constant or 0.0, self, coefficient) + end + + local Var_mt = { + __index = Var_cls, + } + + function Var_mt:__new(name) + return ffi_gc(ljkiwi.kiwi_var_construct(name), ljkiwi.kiwi_var_release) + end + + function Var_mt.__mul(a, b) + if type(a) == "number" then + return Term(b, a) + elseif type(b) == "number" then + return Term(a, b) + end + op_error(a, b, "*") + end + + function Var_mt.__div(a, b) + if type(b) ~= "number" then + op_error(a, b, "/") + end + return Term(a, 1.0 / b) + end + + function Var_mt:__unm() + return Term(self, -1.0) + end + + function Var_mt.__add(a, b) + if ffi_istype(Var, b) then + if type(a) == "number" then + return new_expr_one(a, b) + else + return new_expr_pair(0.0, a, b) + end + elseif ffi_istype(Term, b) then + return new_expr_pair(0.0, a, b.var, 1.0, b.coefficient) + elseif ffi_istype(Expression, b) then + return add_expr_term(b, a) + elseif type(b) == "number" then + return new_expr_one(b, a) + end + op_error(a, b, "+") + end + + function Var_mt.__sub(a, b) + return a + -b + end + + function Var_mt:__tostring() + return self:name() .. "(" .. self:value() .. ")" + end + + ffi.metatype(Var, Var_mt) +end + +do + --- Terms are the components of an expression. + --- Each term is a variable multiplied by a constant coefficient (default 1.0). + ---@class kiwi.Term: ffi.cdata* + ---@overload fun(var: kiwi.Var, coefficient: number?): kiwi.Term + ---@field var kiwi.Var + ---@field coefficient number + ---@operator mul(number): kiwi.Term + ---@operator div(number): kiwi.Term + ---@operator unm: kiwi.Term + ---@operator add(kiwi.Expression|kiwi.Term|kiwi.Var|number): kiwi.Expression + ---@operator sub(kiwi.Expression|kiwi.Term|kiwi.Var|number): kiwi.Expression + local Term_cls = { + le = kiwi.le, + ge = kiwi.ge, + eq = kiwi.eq, + } + + ---@return number + ---@nodiscard + function Term_cls:value() + return self.coefficient * self.var:value() + end + + --- Create an expression from this term. + ---@param constant number? + ---@return kiwi.Expression + function Term_cls:toexpr(constant) + return new_expr_one(constant or 0.0, self.var, self.coefficient) + end + + local Term_mt = { __index = Term_cls } + + local function term_gc(term) + ljkiwi.kiwi_var_release(term.var) + end + + function Term_mt.__new(T, var, coefficient) + local t = ffi_new(T, var, coefficient or 1.0) + ljkiwi.kiwi_var_retain(var) + return ffi_gc(t, term_gc) + end + + function Term_mt.__mul(a, b) + if type(b) == "number" then + return Term(a.var, a.coefficient * b) + elseif type(a) == "number" then + return Term(b.var, b.coefficient * a) + end + op_error(a, b, "*") + end + + function Term_mt.__div(a, b) + if type(b) ~= "number" then + op_error(a, b, "/") + end + return Term(a.var, a.coefficient / b) + end + + function Term_mt:__unm() + return Term(self.var, -self.coefficient) + end + + function Term_mt.__add(a, b) + if ffi_istype(Var, b) then + return new_expr_pair(0.0, a.var, b, a.coefficient) + elseif ffi_istype(Term, b) then + if type(a) == "number" then + return new_expr_one(a, b.var, b.coefficient) + else + return new_expr_pair(0.0, a.var, b.var, a.coefficient, b.coefficient) + end + elseif ffi_istype(Expression, b) then + return add_expr_term(b, a.var, a.coefficient) + elseif type(b) == "number" then + return new_expr_one(b, a.var, a.coefficient) + end + op_error(a, b, "+") + end + + function Term_mt.__sub(a, b) + return Term_mt.__add(a, -b) + end + + function Term_mt:__tostring() + return tostring(self.coefficient) .. " " .. self.var:name() + end + + ffi.metatype(Term, Term_mt) +end + +do + --- Expressions are a sum of terms with an added constant. + ---@class kiwi.Expression: ffi.cdata* + ---@overload fun(constant: number, ...: kiwi.Term): kiwi.Expression + ---@field constant number + ---@field package owner ffi.cdata* + ---@field package term_count number + ---@field package terms_ ffi.cdata* + ---@operator mul(number): kiwi.Expression + ---@operator div(number): kiwi.Expression + ---@operator unm: kiwi.Expression + ---@operator add(kiwi.Expression|kiwi.Term|kiwi.Var|number): kiwi.Expression + ---@operator sub(kiwi.Expression|kiwi.Term|kiwi.Var|number): kiwi.Expression + local Expression_cls = { + le = kiwi.le, + ge = kiwi.ge, + eq = kiwi.eq, + } + + ---@param expr kiwi.Expression + ---@param constant number + ---@nodiscard + local function mul_expr_coeff(expr, constant) + local ret = ffi_gc(ffi_new(Expression, expr.term_count), ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] + for i = 0, expr.term_count - 1 do + local st = expr.terms_[i] --[[@as kiwi.Term]] + local dt = ret.terms_[i] --[[@as kiwi.Term]] + dt.var = st.var + dt.coefficient = st.coefficient * constant + end + ret.constant = expr.constant * constant + ret.term_count = expr.term_count + ljkiwi.kiwi_expression_retain(ret) + return ret + end + + ---@param a kiwi.Expression + ---@param b kiwi.Expression + ---@nodiscard + local function add_expr_expr(a, b) + local a_count = a.term_count + local b_count = b.term_count + local ret = ffi_gc(ffi_new(Expression, a_count + b_count), ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] + + for i = 0, a_count - 1 do + local dt = ret.terms_[i] --[[@as kiwi.Term]] + local st = a.terms_[i] --[[@as kiwi.Term]] + dt.var = st.var + dt.coefficient = st.coefficient + end + for i = 0, b_count - 1 do + local dt = ret.terms_[a_count + i] --[[@as kiwi.Term]] + local st = b.terms_[i] --[[@as kiwi.Term]] + dt.var = st.var + dt.coefficient = st.coefficient + end + ret.constant = a.constant + b.constant + ret.term_count = a_count + b_count + ljkiwi.kiwi_expression_retain(ret) + return ret + end + + ---@param expr kiwi.Expression + ---@param constant number + ---@nodiscard + local function new_expr_constant(expr, constant) + local ret = ffi_gc(ffi_new(Expression, expr.term_count), ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] + + for i = 0, expr.term_count - 1 do + local dt = ret.terms_[i] --[[@as kiwi.Term]] + local st = expr.terms_[i] --[[@as kiwi.Term]] + dt.var = st.var + dt.coefficient = st.coefficient + end + ret.constant = constant + ret.term_count = expr.term_count + ljkiwi.kiwi_expression_retain(ret) + return ret + end + + ---@return number + ---@nodiscard + function Expression_cls:value() + local sum = self.constant + for i = 0, self.term_count - 1 do + local t = self.terms_[i] + sum = sum + t.var:value() * t.coefficient + end + return sum + end + + ---@return kiwi.Term[] + ---@nodiscard + function Expression_cls:terms() + local terms = new_tab(self.term_count, 0) + for i = 0, self.term_count - 1 do + local t = self.terms_[i] --[[@as kiwi.Term]] + terms[i + 1] = Term(t.var, t.coefficient) + end + return terms + end + + ---@return kiwi.Expression + ---@nodiscard + function Expression_cls:copy() + return new_expr_constant(self, self.constant) + end + + local Expression_mt = { + __index = Expression_cls, + } + + function Expression_mt:__new(constant, ...) + local term_count = select("#", ...) + local e = ffi_gc(ffi_new(self, term_count), ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] + e.term_count = term_count + e.constant = constant + for i = 1, term_count do + local t = select(i, ...) + local dt = e.terms_[i - 1] --[[@as kiwi.Term]] + dt.var = t.var + dt.coefficient = t.coefficient + end + ljkiwi.kiwi_expression_retain(e) + return e + end + + function Expression_mt.__mul(a, b) + if type(a) == "number" then + return mul_expr_coeff(b, a) + elseif type(b) == "number" then + return mul_expr_coeff(a, b) + end + op_error(a, b, "*") + end + + function Expression_mt.__div(a, b) + if type(b) ~= "number" then + op_error(a, b, "/") + end + return mul_expr_coeff(a, 1.0 / b) + end + + function Expression_mt:__unm() + return mul_expr_coeff(self, -1.0) + end + + function Expression_mt.__add(a, b) + if ffi_istype(Var, b) then + return add_expr_term(a, b) + elseif ffi_istype(Expression, b) then + if type(a) == "number" then + return new_expr_constant(b, a + b.constant) + else + return add_expr_expr(a, b) + end + elseif ffi_istype(Term, b) then + return add_expr_term(a, b.var, b.coefficient) + elseif type(b) == "number" then + return new_expr_constant(a, a.constant + b) + end + op_error(a, b, "+") + end + + function Expression_mt.__sub(a, b) + return Expression_mt.__add(a, -b) + end + + function Expression_mt:__tostring() + local tab = new_tab(self.term_count + 1, 0) + for i = 0, self.term_count - 1 do + local t = self.terms_[i] + tab[i + 1] = tostring(t.coefficient) .. " " .. t.var:name() + end + tab[self.term_count + 1] = self.constant + return concat(tab, " + ") + end + + ffi.metatype(Expression, Expression_mt) +end + +do + --- A constraint is a linear inequality or equality with associated strength. + --- Constraints can be built with arbitrary left and right hand expressions. But + --- ultimately they all have the form `expression [op] 0`. + ---@class kiwi.Constraint: ffi.cdata* + ---@overload fun(lhs: kiwi.Expression?, rhs: kiwi.Expression?, op: kiwi.RelOp?, strength: number?): kiwi.Constraint + local Constraint_cls = { + --- The strength of the constraint. + ---@type fun(self: kiwi.Constraint): number + strength = ljkiwi.kiwi_constraint_strength, + + --- The relational operator of the constraint. + ---@type fun(self: kiwi.Constraint): kiwi.RelOp + op = ljkiwi.kiwi_constraint_op, + + --- Whether the constraint is violated in the current solution. + ---@type fun(self: kiwi.Constraint): boolean + violated = ljkiwi.kiwi_constraint_violated, + } + + --- The reduced expression defining the constraint. + ---@return kiwi.Expression + ---@nodiscard + function Constraint_cls:expression() + local SZ = 7 + local expr = ffi_new(Expression, SZ) --[[@as kiwi.Expression]] + local n = ljkiwi.kiwi_constraint_expression(self, expr, SZ) + if n > SZ then + expr = ffi_new(Expression, n) --[[@as kiwi.Expression]] + n = ljkiwi.kiwi_constraint_expression(self, expr, n) + end + return ffi_gc(expr, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] + end + + --- Add the constraint to the solver. + --- Raises: + --- KiwiErrDuplicateConstraint: The given constraint has already been added to the solver. + --- KiwiErrUnsatisfiableConstraint: The given constraint is required and cannot be satisfied. + ---@param solver kiwi.Solver + ---@return kiwi.Constraint + function Constraint_cls:add_to(solver) + solver:add_constraint(self) + return self + end + + --- Remove the constraint from the solver. + --- Raises: + --- KiwiErrUnknownConstraint: The given constraint has not been added to the solver. + ---@param solver kiwi.Solver + ---@return kiwi.Constraint + function Constraint_cls:remove_from(solver) + solver:remove_constraint(self) + return self + end + + local Constraint_mt = { + __index = Constraint_cls, + } + + function Constraint_mt:__new(lhs, rhs, op, strength) + return ffi_gc( + ljkiwi.kiwi_constraint_construct(lhs, rhs, op or "EQ", strength or REQUIRED), + ljkiwi.kiwi_constraint_release + ) + end + + local OPS = { [0] = "<=", ">=", "==" } + local STRENGTH_NAMES = { + [Strength.REQUIRED] = "required", + [Strength.STRONG] = "strong", + [Strength.MEDIUM] = "medium", + [Strength.WEAK] = "weak", + } + + function Constraint_mt:__tostring() + local strength = self:strength() + local strength_str = STRENGTH_NAMES[strength] or tostring(strength) + local op = OPS[tonumber(self:op())] + return strformat("%s %s 0 | %s", tostring(self:expression()), op, strength_str) + end + + ffi.metatype(Constraint, Constraint_mt) +end + +do + local constraints = {} + kiwi.constraints = constraints + + --- Create a constraint between a pair of variables with ratio. + --- The constraint is of the form `left [op|==] coeff right + [constant|0.0]`. + ---@param left kiwi.Var + ---@param coeff number right side term coefficient + ---@param right kiwi.Var + ---@param constant number? constant (default 0.0) + ---@param op kiwi.RelOp? relational operator (default "EQ") + ---@param strength number? strength (default REQUIRED) + ---@return kiwi.Constraint + ---@nodiscard + function constraints.pair_ratio(left, coeff, right, constant, op, strength) + assert(ffi_istype(Var, left) and ffi_istype(Var, right)) + local dt = tmpexpr.terms_[0] + dt.var = left + dt.coefficient = 1.0 + dt = tmpexpr.terms_[1] + dt.var = right + dt.coefficient = -coeff + tmpexpr.constant = constant ~= nil and constant or 0 + tmpexpr.term_count = 2 + + return ffi_gc( + ljkiwi.kiwi_constraint_construct(tmpexpr, nil, op or "EQ", strength or REQUIRED), + ljkiwi.kiwi_constraint_release + ) --[[@as kiwi.Constraint]] + end + + local pair_ratio = constraints.pair_ratio + + --- Create a constraint between a pair of variables with ratio. + --- The constraint is of the form `left [op|==] right + [constant|0.0]`. + ---@param left kiwi.Var + ---@param right kiwi.Var + ---@param constant number? constant (default 0.0) + ---@param op kiwi.RelOp? relational operator (default "EQ") + ---@param strength number? strength (default REQUIRED) + ---@return kiwi.Constraint + ---@nodiscard + function constraints.pair(left, right, constant, op, strength) + return pair_ratio(left, 1.0, right, constant, op, strength) + end + + --- Create a single term constraint + --- The constraint is of the form `var [op|==] [constant|0.0]`. + ---@param var kiwi.Var + ---@param constant number? constant (default 0.0) + ---@param op kiwi.RelOp? relational operator (default "EQ") + ---@param strength number? strength (default REQUIRED) + ---@return kiwi.Constraint + ---@nodiscard + function constraints.single(var, constant, op, strength) + assert(ffi_istype(Var, var)) + tmpexpr.constant = -(constant or 0) + tmpexpr.term_count = 1 + local t = tmpexpr.terms_[0] + t.var = var + t.coefficient = 1.0 + + return ffi_gc( + ljkiwi.kiwi_constraint_construct(tmpexpr, nil, op or "EQ", strength or REQUIRED), + ljkiwi.kiwi_constraint_release + ) --[[@as kiwi.Constraint]] + end +end + +do + local bit = require("bit") + local band, bor, lshift = bit.band, bit.bor, bit.lshift + local C = ffi.C + + --- Produce a custom error raise mask + --- Error kinds specified in the mask will not cause a lua + --- error to be raised. + ---@param kinds (kiwi.ErrKind|integer)[] + ---@param invert boolean? + ---@return integer + function kiwi.error_mask(kinds, invert) + local mask = 0 + for _, k in ipairs(kinds) do + mask = bor(mask, lshift(1, kiwi.ErrKind(k))) + end + return invert and bit.bnot(mask) or mask + end + + kiwi.ERROR_MASK_ALL = 0xFFFF + --- an error mask that raises errors only for fatal conditions + kiwi.ERROR_MASK_NON_FATAL = bit.bnot(kiwi.error_mask({ + "KiwiErrInternalSolverError", + "KiwiErrAlloc", + "KiwiErrNullObject", + "KiwiErrUnknown", + })) + + ---@class kiwi.KiwiErr: ffi.cdata* + ---@field package kind kiwi.ErrKind + ---@field package message ffi.cdata* + ---@field package must_free boolean + ---@overload fun(): kiwi.KiwiErr + local KiwiErr = ffi.typeof("struct KiwiErr") --[[@as kiwi.KiwiErr]] + + local Error_mt = { + ---@param self kiwi.Error + ---@return string + __tostring = function(self) + return strformat("%s: (%s, %s)", self.message, tostring(self.solver), tostring(self.item)) + end, + } + + ---@class kiwi.Error + ---@field kind kiwi.ErrKind + ---@field message string + ---@field solver kiwi.Solver? + ---@field item any? + kiwi.Error = Error_mt + + function kiwi.is_error(o) + return type(o) == "table" and getmetatable(o) == Error_mt + end + + ---@param kind kiwi.ErrKind + ---@param message string + ---@param solver kiwi.Solver + ---@param item any + ---@return kiwi.Error + local function new_error(kind, message, solver, item) + return setmetatable({ + kind = kind, + message = message, + solver = solver, + item = item, + }, Error_mt) + end + + ---@generic T + ---@param f fun(solver: kiwi.Solver, item: T, ...): kiwi.KiwiErr? + ---@param solver kiwi.Solver + ---@param item T + ---@return T, kiwi.Error? + local function try_solver(f, solver, item, ...) + local err = f(solver, item, ...) + if err ~= nil then + local kind = err.kind + local message = err.message ~= nil and ffi_string(err.message) or "" + if err.must_free then + C.free(err) + end + local errdata = new_error(kind, message, solver, item) + local error_mask = ljkiwi.kiwi_solver_get_error_mask(solver) + return item, + band(error_mask, lshift(1, kind --[[@as integer]])) == 0 and error(errdata) + or errdata + end + return item + end + ---@class kiwi.Solver: ffi.cdata* + ---@overload fun(error_mask: (integer|(kiwi.ErrKind|integer)[] )?): kiwi.Solver + local Solver_cls = { + --- Test whether a constraint is in the solver. + ---@type fun(self: kiwi.Solver, constraint: kiwi.Constraint): boolean + has_constraint = ljkiwi.kiwi_solver_has_constraint, + + --- Test whether an edit variable has been added to the solver. + ---@type fun(self: kiwi.Solver, var: kiwi.Var): boolean + has_edit_var = ljkiwi.kiwi_solver_has_edit_var, + + --- Update the values of the external solver variables. + ---@type fun(self: kiwi.Solver) + update_vars = ljkiwi.kiwi_solver_update_vars, + + --- Reset the solver to the empty starting conditions. + --- + --- This method resets the internal solver state to the empty starting + --- condition, as if no constraints or edit variables have been added. + --- This can be faster than deleting the solver and creating a new one + --- when the entire system must change, since it can avoid unecessary + --- heap (de)allocations. + ---@type fun(self: kiwi.Solver) + reset = ljkiwi.kiwi_solver_reset, + + --- Dump a representation of the solver to stdout. + ---@type fun(self: kiwi.Solver) + dump = ljkiwi.kiwi_solver_dump, + } + + --- Sets the error mask for the solver. + ---@param mask integer|(kiwi.ErrKind|integer)[] the mask value or an array of kinds + ---@param invert boolean? whether to invert the mask if an array was passed for mask + function Solver_cls:set_error_mask(mask, invert) + if type(mask) == "table" then + mask = kiwi.error_mask(mask, invert) + end + ljkiwi.kiwi_solver_set_error_mask(self, mask) + end + + ---@generic T + ---@param solver kiwi.Solver + ---@param items T|T[] + ---@param f fun(solver: kiwi.Solver, item: T, ...): kiwi.KiwiErr? + ---@return T|T[], kiwi.Error? + local function add_remove_items(solver, items, f, ...) + for _, item in ipairs(items) do + local _, err = try_solver(f, solver, item, ...) + if err ~= nil then + return items, err + end + end + return items + end + + --- Add a constraint to the solver. + --- Errors: + --- KiwiErrDuplicateConstraint + --- KiwiErrUnsatisfiableConstraint + ---@param constraint kiwi.Constraint + ---@return kiwi.Constraint constraint, kiwi.Error? + function Solver_cls:add_constraint(constraint) + return try_solver(ljkiwi.kiwi_solver_add_constraint, self, constraint) + end + + --- Add constraints to the solver. + --- Errors: + --- KiwiErrDuplicateConstraint + --- KiwiErrUnsatisfiableConstraint + ---@param constraints kiwi.Constraint[] + ---@return kiwi.Constraint[] constraints, kiwi.Error? + function Solver_cls:add_constraints(constraints) + return add_remove_items(self, constraints, ljkiwi.kiwi_solver_add_constraint) + end + + --- Remove a constraint from the solver. + --- Errors: + --- KiwiErrUnknownConstraint + ---@param constraint kiwi.Constraint + ---@return kiwi.Constraint constraint, kiwi.Error? + function Solver_cls:remove_constraint(constraint) + return try_solver(ljkiwi.kiwi_solver_remove_constraint, self, constraint) + end + + --- Remove constraints from the solver. + --- Errors: + --- KiwiErrUnknownConstraint + ---@param constraints kiwi.Constraint[] + ---@return kiwi.Constraint[] constraints, kiwi.Error? + function Solver_cls:remove_constraints(constraints) + return add_remove_items(self, constraints, ljkiwi.kiwi_solver_remove_constraint) + end + + --- Add an edit variables to the solver. + --- + --- This method should be called before the `suggestValue` method is + --- used to supply a suggested value for the given edit variable. + --- Errors: + --- KiwiErrDuplicateEditVariable + --- KiwiErrBadRequiredStrength: The given strength is >= required. + ---@param var kiwi.Var the variable to add as an edit variable + ---@param strength number the strength of the edit variable (must be less than `Strength.REQUIRED`) + ---@return kiwi.Var var, kiwi.Error? + function Solver_cls:add_edit_var(var, strength) + return try_solver(ljkiwi.kiwi_solver_add_edit_var, self, var, strength) + end + + --- Add edit variables to the solver. + --- + --- This method should be called before the `suggestValue` method is + --- used to supply a suggested value for the given edit variable. + --- Errors: + --- KiwiErrDuplicateEditVariable + --- KiwiErrBadRequiredStrength: The given strength is >= required. + ---@param vars kiwi.Var[] the variables to add as an edit variable + ---@param strength number the strength of the edit variables (must be less than `Strength.REQUIRED`) + ---@return kiwi.Var[] vars, kiwi.Error? + function Solver_cls:add_edit_vars(vars, strength) + return add_remove_items(self, vars, ljkiwi.kiwi_solver_add_edit_var, strength) + end + + --- Remove an edit variable from the solver. + --- Raises: + --- KiwiErrUnknownEditVariable + ---@param var kiwi.Var the edit variable to remove + ---@return kiwi.Var var, kiwi.Error? + function Solver_cls:remove_edit_var(var) + return try_solver(ljkiwi.kiwi_solver_remove_edit_var, self, var) + end + + --- Removes edit variables from the solver. + --- Raises: + --- KiwiErrUnknownEditVariable + ---@param vars kiwi.Var[] the edit variables to remove + ---@return kiwi.Var[] vars, kiwi.Error? + function Solver_cls:remove_edit_vars(vars) + return add_remove_items(self, vars, ljkiwi.kiwi_solver_remove_edit_var) + end + + --- Suggest a value for the given edit variable. + --- This method should be used after an edit variable has been added to the solver in order + --- to suggest the value for that variable. After all suggestions have been made, + --- the `update_vars` methods can be used to update the values of the external solver variables. + --- Raises: + --- KiwiErrUnknownEditVariable + ---@param var kiwi.Var the edit variable to suggest a value for + ---@param value number the suggested value + ---@return kiwi.Var var, kiwi.Error? + function Solver_cls:suggest_value(var, value) + return try_solver(ljkiwi.kiwi_solver_suggest_value, self, var, value) + end + + --- Suggest values for the given edit variables. + --- Convenience wrapper of `suggest_value` that takes tables of `kiwi.Var` and number pairs. + --- Raises: + --- KiwiErrUnknownEditVariable: The given edit variable has not been added to the solver. + ---@param vars kiwi.Var[] edit variables to suggest + ---@param values number[] suggested values + ---@return kiwi.Var[] vars, number[] values, kiwi.Error? + function Solver_cls:suggest_values(vars, values) + for i, var in ipairs(vars) do + local _, err = try_solver(ljkiwi.kiwi_solver_suggest_value, self, var, values[i]) + if err ~= nil then + return vars, values, err + end + end + return vars, values + end + + --- Dump a representation of the solver to a string. + ---@return string + ---@nodiscard + function Solver_cls:dumps() + local cs = ljkiwi.kiwi_solver_dumps(self) + local s = ffi_string(cs) + C.free(cs) + return s + end + + local Solver_mt = { + __index = Solver_cls, + } + + function Solver_mt:__new(error_mask) + if type(error_mask) == "table" then + error_mask = kiwi.error_mask(error_mask) + end + + return ffi_gc(ljkiwi.kiwi_solver_construct(error_mask or 0), ljkiwi.kiwi_solver_destroy) --[[@as kiwi.Constraint]] + end + + local Solver = ffi.metatype(ffi.typeof("struct KiwiSolver"), Solver_mt) --[[@as kiwi.Solver]] + kiwi.Solver = Solver + + function kiwi.is_solver(s) + return ffi_istype(Solver, s) + end +end + +return kiwi diff --git a/kiwi/kiwi/solver.h b/kiwi/kiwi/solver.h index 8ff2dbb..3e510f9 100644 --- a/kiwi/kiwi/solver.h +++ b/kiwi/kiwi/solver.h @@ -145,7 +145,7 @@ class Solver /* Dump a representation of the solver internals to stdout. */ - void dump() + void dump() const { debug::dump( m_impl ); } @@ -153,7 +153,7 @@ class Solver /* Dump a representation of the solver internals to a stream. */ - void dump( std::ostream& out ) + void dump( std::ostream& out ) const { debug::dump( m_impl, out ); } @@ -161,7 +161,7 @@ class Solver /* Dump a representation of the solver internals to a string. */ - std::string dumps() + std::string dumps() const { return debug::dumps( m_impl ); } diff --git a/kiwi/kiwi/variable.h b/kiwi/kiwi/variable.h index 0120a4f..662831a 100644 --- a/kiwi/kiwi/variable.h +++ b/kiwi/kiwi/variable.h @@ -64,7 +64,7 @@ class Variable void setValue(double value) { m_data->setValue(value); } // operator== is used for symbolics - bool equals(const Variable &other) + bool equals(const Variable &other) const { return m_data == other.m_data; } diff --git a/ljkiwi.hpp b/ljkiwi.hpp new file mode 100644 index 0000000..d222c8d --- /dev/null +++ b/ljkiwi.hpp @@ -0,0 +1 @@ +#include diff --git a/luakiwi/luacompat.h b/luakiwi/luacompat.h new file mode 100644 index 0000000..aed16f4 --- /dev/null +++ b/luakiwi/luacompat.h @@ -0,0 +1,155 @@ +#ifndef LJKIWI_LUACOMPAT_H_ +#define LJKIWI_LUACOMPAT_H_ + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#include +#include +#include + +#if defined(LUA_VERSION_NUM) && LUA_VERSION_NUM == 501 && defined(__GNUC__) + #define LJKIWI_LJ_COMPAT_ATTR __attribute__((weak, visibility("default"))) +#else + #define LJKIWI_LJ_COMPAT_ATTR static +#endif + +#if !defined(LUA_VERSION_NUM) || LUA_VERSION_NUM == 501 + + #define LUA_OPADD 0 + #define LUA_OPSUB 1 + #define LUA_OPMUL 2 + #define LUA_OPDIV 3 + #define LUA_OPMOD 4 + #define LUA_OPPOW 5 + #define LUA_OPUNM 6 + +static int lua_absindex(lua_State* L, int i) { + if (i < 0 && i > LUA_REGISTRYINDEX) + i += lua_gettop(L) + 1; + return i; +} + +LJKIWI_LJ_COMPAT_ATTR lua_Number lua_tonumberx(lua_State* L, int i, int* isnum) { + lua_Number n = lua_tonumber(L, i); + if (isnum != NULL) { + *isnum = (n != 0 || lua_isnumber(L, i)); + } + return n; +} + +LJKIWI_LJ_COMPAT_ATTR lua_Integer lua_tointegerx(lua_State* L, int i, int* isnum) { + int ok = 0; + lua_Number n = lua_tonumberx(L, i, &ok); + if (ok) { + if (n == (lua_Integer)n) { + if (isnum) + *isnum = 1; + return (lua_Integer)n; + } + } + if (isnum) + *isnum = 0; + return 0; +} + +static const char* luaL_tolstring(lua_State* L, int idx, size_t* len) { + if (!luaL_callmeta(L, idx, "__tostring")) { + int t = lua_type(L, idx), tt = 0; + char const* name = NULL; + switch (t) { + case LUA_TNIL: + lua_pushliteral(L, "nil"); + break; + case LUA_TSTRING: + case LUA_TNUMBER: + lua_pushvalue(L, idx); + break; + case LUA_TBOOLEAN: + if (lua_toboolean(L, idx)) + lua_pushliteral(L, "true"); + else + lua_pushliteral(L, "false"); + break; + default: + tt = luaL_getmetafield(L, idx, "__name"); + name = (tt == LUA_TSTRING) ? lua_tostring(L, -1) : lua_typename(L, t); + lua_pushfstring(L, "%s: %p", name, lua_topointer(L, idx)); + if (tt != LUA_TNIL) + lua_replace(L, -2); + break; + } + } else { + if (!lua_isstring(L, -1)) + luaL_error(L, "'__tostring' must return a string"); + } + return lua_tolstring(L, -1, len); +} + +#endif /* LUA_VERSION_NUM == 501 */ + +#if defined(LUA_VERSION_NUM) && LUA_VERSION_NUM <= 502 + +static void compat_reverse(lua_State* L, int a, int b) { + for (; a < b; ++a, --b) { + lua_pushvalue(L, a); + lua_pushvalue(L, b); + lua_replace(L, a); + lua_replace(L, b); + } +} + +static void lua_rotate(lua_State* L, int idx, int n) { + int n_elems = 0; + idx = lua_absindex(L, idx); + n_elems = lua_gettop(L) - idx + 1; + if (n < 0) + n += n_elems; + if (n > 0 && n < n_elems) { + luaL_checkstack(L, 2, "not enough stack slots available"); + n = n_elems - n; + compat_reverse(L, idx, idx + n - 1); + compat_reverse(L, idx + n, idx + n_elems - 1); + compat_reverse(L, idx, idx + n_elems - 1); + } +} + +static int lua_geti(lua_State* L, int index, lua_Integer i) { + index = lua_absindex(L, index); + lua_pushinteger(L, i); + lua_gettable(L, index); + return lua_type(L, -1); +} + +#endif /* LUA_VERSION_NUM <= 502 */ + +#if defined(LUA_VERSION_NUM) && LUA_VERSION_NUM <= 503 +static int luaL_typeerror(lua_State* L, int arg, const char* tname) { + const char* msg; + const char* typearg; /* name for the type of the actual argument */ + if (luaL_getmetafield(L, arg, "__name") == LUA_TSTRING) + typearg = lua_tostring(L, -1); /* use the given type name */ + else if (lua_type(L, arg) == LUA_TLIGHTUSERDATA) + typearg = "light userdata"; /* special name for messages */ + else + typearg = luaL_typename(L, arg); /* standard name */ + msg = lua_pushfstring(L, "%s expected, got %s", tname, typearg); + return luaL_argerror(L, arg, msg); +} + +#endif /* LUA_VERSION_NUM <= 503 */ + +#if !defined(luaL_newlibtable) + #define luaL_newlibtable(L, l) lua_createtable(L, 0, sizeof(l) / sizeof((l)[0]) - 1) +#endif + +#if !defined(luaL_checkversion) + #define luaL_checkversion(L) ((void)0) +#endif + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // LJKIWI_LUACOMPAT_H_ diff --git a/luakiwi/luakiwi-int.h b/luakiwi/luakiwi-int.h new file mode 100644 index 0000000..09adfcd --- /dev/null +++ b/luakiwi/luakiwi-int.h @@ -0,0 +1,301 @@ +#ifndef LUAKIWI_INT_H_ +#define LUAKIWI_INT_H_ + +#include + +#include +#include + +#include "luacompat.h" + +#if defined(__GNUC__) && !defined(LJKIWI_NO_BUILTIN) + #define lk_likely(x) (__builtin_expect(((x) != 0), 1)) + #define lk_unlikely(x) (__builtin_expect(((x) != 0), 0)) +#else + #define lk_likely(x) (x) + #define lk_unlikely(x) (x) +#endif + +namespace { + +using namespace kiwi; + +// Lua 5.1 compatibility for missing lua_arith. +inline void compat_arith_unm(lua_State* L) { +#if defined(LUA_VERSION_NUM) && LUA_VERSION_NUM == 501 + int isnum; + lua_Number n = lua_tonumberx(L, -1, &isnum); + if (isnum) { + lua_pop(L, 1); + lua_pushnumber(L, -n); + } else { + if (!luaL_callmeta(L, -1, "__unm")) + luaL_error(L, "attempt to perform arithmetic on a %s value", luaL_typename(L, -1)); + lua_replace(L, -2); + } +#else + lua_arith(L, LUA_OPUNM); +#endif +} + +// This version supports placeholders. +inline void setfuncs(lua_State* L, const luaL_Reg* l, int nup) { + luaL_checkstack(L, nup, "too many upvalues"); + for (; l->name != NULL; l++) { /* fill the table with given functions */ + if (l->func == NULL) /* place holder? */ + lua_pushboolean(L, 0); + else { + for (int i = 0; i < nup; i++) /* copy upvalues to the top */ + lua_pushvalue(L, -nup); + lua_pushcclosure(L, l->func, nup); /* closure with those upvalues */ + } + lua_setfield(L, -(nup + 2), l->name); + } + lua_pop(L, nup); /* remove upvalues */ +} + +template +constexpr int array_count(T (&)[N]) { + return static_cast(N); +} + +inline void newlib(lua_State* L, const luaL_Reg* l) { + lua_newtable(L); + setfuncs(L, l, 0); +} + +enum KiwiErrKind { + KiwiErrNone, + KiwiErrUnsatisfiableConstraint = 1, + KiwiErrUnknownConstraint, + KiwiErrDuplicateConstraint, + KiwiErrUnknownEditVariable, + KiwiErrDuplicateEditVariable, + KiwiErrBadRequiredStrength, + KiwiErrInternalSolverError, + KiwiErrAlloc, + KiwiErrNullObject, + KiwiErrUnknown, +}; + +struct KiwiTerm { + VariableData* var; + double coefficient; +}; + +struct KiwiExpression { + double constant; + int term_count; + ConstraintData* owner; + +#if !defined(_MSC_VER) || _MSC_VER >= 1900 + KiwiTerm terms[]; + + static constexpr std::size_t sz(int count) { + return sizeof(KiwiExpression) + sizeof(KiwiTerm) * (count > 0 ? count : 0); + } +#else + KiwiTerm terms[1]; + + static constexpr std::size_t sz(int count) { + return sizeof(KiwiExpression) + sizeof(KiwiTerm) * (count > 1 ? count - 1 : 0); + } +#endif + + KiwiExpression() = delete; + KiwiExpression(const KiwiExpression&) = delete; + KiwiExpression& operator=(const KiwiExpression&) = delete; + ~KiwiExpression() = delete; +}; + +// This mechanism was initially designed for LuaJIT FFI. +struct KiwiErr { + enum KiwiErrKind kind; + const char* message; + bool must_delete; +}; + +struct KiwiSolver { + unsigned error_mask; + Solver solver; +}; + +inline const KiwiErr* new_error(const KiwiErr* base, const std::exception& ex) { + if (!std::strcmp(ex.what(), base->message)) + return base; + + const auto msg_n = std::strlen(ex.what()) + 1; + + auto* mem = static_cast(::operator new(sizeof(KiwiErr) + msg_n, std::nothrow)); + if (!mem) { + return base; + } + auto* msg = mem + sizeof(KiwiErr); + std::memcpy(msg, ex.what(), msg_n); + return new (mem) KiwiErr {base->kind, msg, true}; +} + +template +inline const KiwiErr* wrap_err(F&& f) { + static const constexpr KiwiErr kKiwiErrUnhandledCxxException { + KiwiErrUnknown, + "An unhandled C++ exception occurred."}; + + try { + f(); + } catch (const UnsatisfiableConstraint&) { + static const constexpr KiwiErr err { + KiwiErrUnsatisfiableConstraint, + "The constraint cannot be satisfied."}; + return &err; + } catch (const UnknownConstraint&) { + static const constexpr KiwiErr err { + KiwiErrUnknownConstraint, + "The constraint has not been added to the solver."}; + return &err; + + } catch (const DuplicateConstraint&) { + static const constexpr KiwiErr err { + KiwiErrDuplicateConstraint, + "The constraint has already been added to the solver."}; + return &err; + + } catch (const UnknownEditVariable&) { + static const constexpr KiwiErr err { + KiwiErrUnknownEditVariable, + "The edit variable has not been added to the solver."}; + return &err; + + } catch (const DuplicateEditVariable&) { + static const constexpr KiwiErr err { + KiwiErrDuplicateEditVariable, + "The edit variable has already been added to the solver."}; + return &err; + + } catch (const BadRequiredStrength&) { + static const constexpr KiwiErr err { + KiwiErrBadRequiredStrength, + "A required strength cannot be used in this context."}; + return &err; + + } catch (const InternalSolverError& ex) { + static const constexpr KiwiErr base { + KiwiErrInternalSolverError, + "An internal solver error occurred."}; + return new_error(&base, ex); + } catch (std::bad_alloc&) { + static const constexpr KiwiErr err {KiwiErrAlloc, "A memory allocation failed."}; + return &err; + } catch (const std::exception& ex) { + return new_error(&kKiwiErrUnhandledCxxException, ex); + } catch (...) { + return &kKiwiErrUnhandledCxxException; + } + return nullptr; +} + +template +inline const KiwiErr* wrap_err(P&& s, F&& f) { + return wrap_err([&]() { f(s); }); +} + +template +inline const KiwiErr* wrap_err(P&& s, R&& ref, F&& f) { + return wrap_err([&]() { f(s, ref); }); +} + +template +inline T* make_unmanaged(Args... args) { + auto* o = new T(std::forward(args)...); + o->m_refcount = 1; + return o; +} + +template +inline void release_unmanaged(T* p) { + if (lk_likely(p)) { + if (--p->m_refcount == 0) + delete p; + } +} + +template +inline T* retain_unmanaged(T* p) { + if (lk_likely(p)) + p->m_refcount++; + return p; +} + +inline ConstraintData* kiwi_constraint_new( + const KiwiExpression* lhs, + const KiwiExpression* rhs, + RelationalOperator op, + double strength +) { + if (strength < 0.0) { + strength = kiwi::strength::required; + } + + std::vector terms; + terms.reserve(static_cast( + (lhs && lhs->term_count > 0 ? lhs->term_count : 0) + + (rhs && rhs->term_count > 0 ? rhs->term_count : 0) + )); + + if (lhs) { + for (int i = 0; i < lhs->term_count; ++i) { + const auto& t = lhs->terms[i]; + terms.emplace_back(Variable(t.var), t.coefficient); + } + } + if (rhs) { + for (int i = 0; i < rhs->term_count; ++i) { + const auto& t = rhs->terms[i]; + terms.emplace_back(Variable(t.var), -t.coefficient); + } + } + + return make_unmanaged( + Expression(std::move(terms), (lhs ? lhs->constant : 0.0) - (rhs ? rhs->constant : 0.0)), + static_cast(op), + strength + ); +} + +inline const KiwiErr* kiwi_solver_add_constraint(Solver& s, ConstraintData* constraint) { + return wrap_err(s, constraint, [](auto&& solver, auto&& c) { + solver.addConstraint(Constraint(c)); + }); +} + +inline const KiwiErr* kiwi_solver_remove_constraint(Solver& s, ConstraintData* constraint) { + return wrap_err(s, constraint, [](auto&& solver, auto&& c) { + solver.removeConstraint(Constraint(c)); + }); +} + +inline const KiwiErr* kiwi_solver_add_edit_var(Solver& s, VariableData* var, double strength) { + return wrap_err(s, var, [strength](auto&& solver, auto&& v) { + solver.addEditVariable(Variable(v), strength); + }); +} + +inline const KiwiErr* kiwi_solver_remove_edit_var(Solver& s, VariableData* var) { + return wrap_err(s, var, [](auto&& solver, auto&& v) { + solver.removeEditVariable(Variable(v)); + }); +} + +inline const KiwiErr* kiwi_solver_suggest_value(Solver& s, VariableData* var, double value) { + return wrap_err(s, var, [value](auto&& solver, auto&& v) { + solver.suggestValue(Variable(v), value); + }); +} + +} // namespace + +// Local Variables: +// mode: c++ +// End: + +#endif // LUAKIWI_INT_H_ diff --git a/luakiwi/luakiwi.cpp b/luakiwi/luakiwi.cpp new file mode 100644 index 0000000..198b966 --- /dev/null +++ b/luakiwi/luakiwi.cpp @@ -0,0 +1,1688 @@ +#include "ljkiwi.hpp" +#include +#include +#include + +#include "luacompat.h" +#include "luakiwi-int.h" + +namespace { + +// Note some of the internal functions do not bother cleaning up the stack, they +// are marked with accordingly. + +enum TypeId { NOTYPE, VAR = 1, TERM, EXPR, CONSTRAINT, SOLVER, ERROR, NUMBER }; + +const int ERR_KIND_TAB = NUMBER + 1; +const int VAR_SUB_FN = ERR_KIND_TAB + 1; +const int CONTEXT_TAB_MAX = VAR_SUB_FN + 1; + +constexpr const char* const lkiwi_error_kinds[] = { + "KiwiErrNone", + "KiwiErrUnsatisfiableConstraint", + "KiwiErrUnknownConstraint", + "KiwiErrDuplicateConstraint", + "KiwiErrUnknownEditVariable", + "KiwiErrDuplicateEditVariable", + "KiwiErrBadRequiredStrength", + "KiwiErrInternalSolverError", + "KiwiErrAlloc", + "KiwiErrNullObject", + "KiwiErrUnknown", +}; + +const double STRENGTH_REQUIRED = 1001001000.0; +const double STRENGTH_STRONG = 1000000.0; +const double STRENGTH_MEDIUM = 1000.0; +const double STRENGTH_WEAK = 1.0; + +kiwi::RelationalOperator get_op_opt(lua_State* L, int idx) { + size_t opn; + const char* op = luaL_optlstring(L, idx, "EQ", &opn); + + if (opn == 2) { + if (op[0] == 'E' && op[1] == 'Q') { + return kiwi::OP_EQ; + } else if (op[0] == 'L' && op[1] == 'E') { + return kiwi::OP_LE; + } else if (op[0] == 'G' && op[1] == 'E') { + return kiwi::OP_GE; + } + } + luaL_argerror(L, idx, "invalid operator"); + return kiwi::OP_EQ; +} + +inline void push_type(lua_State* L, int type_id) { + lua_rawgeti(L, lua_upvalueindex(1), type_id); +} + +// stack disposition: dirty +inline int is_udata_obj(lua_State* L, int type_id) { + int result = 0; + if (lua_isuserdata(L, 1) && lua_getmetatable(L, 1)) { + push_type(L, type_id); + result = lua_rawequal(L, -1, -2); + } + lua_pushboolean(L, result); + return 1; +} + +// get typename, copy the stack string to tidx, helpful when using +// with buffers. +const char* lk_typename(lua_State* L, int idx, int tidx) { + const char* ret = 0; + if (lua_getmetatable(L, idx)) { + lua_getfield(L, -1, "__name"); + ret = lua_tolstring(L, -1, 0); + lua_replace(L, tidx); + lua_pop(L, 1); + } + + return ret ? ret : luaL_typename(L, idx); +} + +// never returns +int op_error(lua_State* L, const char* op, int lidx, int ridx) { + luaL_Buffer buf; + size_t len; + const char* str; + + // scratch space for strings + lua_pushnil(L); + int stridx = lua_gettop(L); + + luaL_buffinit(L, &buf); + lua_pushfstring(L, "invalid operand type for '%s' %s('", op, lk_typename(L, lidx, stridx)); + luaL_addvalue(&buf); + + str = luaL_tolstring(L, lidx, &len); + lua_replace(L, stridx); + luaL_addlstring(&buf, str, len < 100 ? len : 100); + + lua_pushfstring(L, "') and %s('", lk_typename(L, ridx, stridx)); + luaL_addvalue(&buf); + + str = luaL_tolstring(L, ridx, &len); + lua_replace(L, stridx); + luaL_addlstring(&buf, str, len < 100 ? len : 100); + + luaL_addstring(&buf, "')"); + luaL_pushresult(&buf); + lua_error(L); + return 0; +} + +void check_arg_error(lua_State* L, int idx, int have_mt) { + lua_pushstring(L, "__name"); + lua_rawget(L, -2); + // TODO: simplify this. This is a bit of a hack to deal with missing args. + // Also these error messages are funky when idx is negative. + int top = lua_gettop(L); + if (idx > 0 && top <= 2 + have_mt) { + lua_pushnil(L); + lua_replace(L, idx); + } + luaL_typeerror(L, idx < 0 ? top + idx - have_mt - 2 + 1 : idx, lua_tostring(L, -1)); +} + +inline void* check_arg(lua_State* L, int idx, int type_id) { + void* udp = lua_touserdata(L, idx); + int have_mt = lua_getmetatable(L, idx); + push_type(L, type_id); + + if (lk_unlikely(!udp || !have_mt || !lua_rawequal(L, -1, -2))) + check_arg_error(L, idx, have_mt); + + lua_pop(L, 2); + return udp; +} + +inline void* try_type(lua_State* L, int idx, TypeId type_id) { + void* p = lua_touserdata(L, idx); + if (!p || !lua_getmetatable(L, idx)) + return 0; + push_type(L, type_id); + return lua_rawequal(L, -1, -2) ? p : 0; +} + +inline VariableData* try_var(lua_State* L, int idx) { + return *static_cast(try_type(L, idx, VAR)); +} + +inline KiwiTerm* try_term(lua_State* L, int idx) { + return static_cast(try_type(L, idx, TERM)); +} + +inline KiwiExpression* try_expr(lua_State* L, int idx) { + return static_cast(try_type(L, idx, EXPR)); +} + +// method to test types for expression functions +// stack disposition: dirty +inline void* try_arg(lua_State* L, int idx, TypeId* type_id, double* num) { + void* p = lua_touserdata(L, idx); + if (!p || !lua_getmetatable(L, idx)) { + int isnum; + *num = lua_tonumberx(L, idx, &isnum); + if (isnum) { + *type_id = NUMBER; + } else + *type_id = NOTYPE; + return 0; + } + + push_type(L, EXPR); + if (lua_rawequal(L, -1, -2)) { + *type_id = EXPR; + return p; + } + push_type(L, VAR); + if (lua_rawequal(L, -1, -3)) { + *type_id = VAR; + return p; + } + push_type(L, TERM); + if (lua_rawequal(L, -1, -4)) { + *type_id = TERM; + return p; + } + *type_id = NOTYPE; + return 0; +} + +inline VariableData* get_var(lua_State* L, int idx) { + return *static_cast(check_arg(L, idx, VAR)); +} + +inline KiwiTerm* get_term(lua_State* L, int idx) { + return static_cast(check_arg(L, idx, TERM)); +} + +inline KiwiExpression* get_expr(lua_State* L, int idx) { + return static_cast(check_arg(L, idx, EXPR)); +} + +inline KiwiExpression* get_expr_opt(lua_State* L, int idx) { + if (lua_isnoneornil(L, idx)) { + return 0; + } + return static_cast(check_arg(L, idx, EXPR)); +} + +inline ConstraintData* get_constraint(lua_State* L, int idx) { + return *static_cast(check_arg(L, idx, CONSTRAINT)); +} + +inline KiwiSolver* get_solver(lua_State* L, int idx) { + return static_cast(check_arg(L, idx, SOLVER)); +} + +// note this expects the 2nd upvalue to have the variable weak table +VariableData* var_new(lua_State* L, VariableData* var) { + *static_cast(lua_newuserdata(L, sizeof(VariableData*))) = var; + + push_type(L, VAR); + lua_setmetatable(L, -2); + +#if defined(LUA_VERSION_NUM) && LUA_VERSION_NUM == 501 + // a true compatibility shim has performance implications here + lua_pushlightuserdata(L, var); + lua_pushvalue(L, -2); + lua_rawset(L, lua_upvalueindex(2)); +#else + lua_pushvalue(L, -1); + lua_rawsetp(L, lua_upvalueindex(2), var); +#endif + return var; +} + +KiwiTerm* term_new(lua_State* L) { + auto* term = static_cast(lua_newuserdata(L, sizeof(KiwiTerm))); + push_type(L, TERM); + lua_setmetatable(L, -2); + return term; +} + +inline KiwiExpression* expr_new(lua_State* L, int nterms) { + auto* expr = static_cast(lua_newuserdata(L, KiwiExpression::sz(nterms))); + expr->owner = nullptr; + push_type(L, EXPR); + lua_setmetatable(L, -2); + return expr; +} + +inline ConstraintData* constraint_new( + lua_State* L, + const KiwiExpression* lhs, + const KiwiExpression* rhs, + kiwi::RelationalOperator op, + double strength +) { + auto** c = static_cast(lua_newuserdata(L, sizeof(ConstraintData*))); + *c = kiwi_constraint_new(lhs, rhs, op, strength); + + push_type(L, CONSTRAINT); + lua_setmetatable(L, -2); + return *c; +} + +// stack disposition: dirty +KiwiExpression* toexpr(lua_State* L, int idx, KiwiExpression* temp) { + void* ud = lua_touserdata(L, idx); + + if (!ud) { + int isnum; + temp->constant = lua_tonumberx(L, idx, &isnum); + temp->term_count = 0; + return isnum ? temp : 0; + } + if (!lua_getmetatable(L, idx)) + return 0; + + push_type(L, EXPR); + if (lua_rawequal(L, -1, -2)) { + return static_cast(ud); + } + + temp->constant = 0; + temp->term_count = 1; + push_type(L, VAR); + if (lua_rawequal(L, -1, -3)) { + temp->terms[0].var = *static_cast(ud); + temp->terms[0].coefficient = 1.0; + return temp; + } + push_type(L, TERM); + if (lua_rawequal(L, -1, -4)) { + temp->terms[0] = *static_cast(ud); + return temp; + } + return 0; +} + +int relop(lua_State* L, kiwi::RelationalOperator op, const char opdisp[2]) { + alignas(KiwiExpression) unsigned char tmpl[KiwiExpression::sz(1)]; + alignas(KiwiExpression) unsigned char tmpr[KiwiExpression::sz(1)]; + double strength = luaL_optnumber(L, 3, STRENGTH_REQUIRED); + const auto* lhs = toexpr(L, 1, reinterpret_cast(tmpl)); + const auto* rhs = toexpr(L, 2, reinterpret_cast(tmpr)); + + if (!lhs || !rhs) { + op_error(L, opdisp, 1, 2); + } + + constraint_new(L, lhs, rhs, op, strength); + return 1; +} + +int lkiwi_eq(lua_State* L) { + return relop(L, kiwi::OP_EQ, "=="); +} + +int lkiwi_le(lua_State* L) { + return relop(L, kiwi::OP_LE, "<="); +} + +int lkiwi_ge(lua_State* L) { + return relop(L, kiwi::OP_GE, ">="); +} + +inline int push_expr_one(lua_State* L, double constant, const KiwiTerm* term) { + auto* expr = expr_new(L, 1); + expr->constant = constant; + expr->term_count = 1; + expr->terms[0].coefficient = term->coefficient; + expr->terms[0].var = retain_unmanaged(term->var); + return 1; +} + +inline int push_expr_pair(lua_State* L, double constant, const KiwiTerm* ta, const KiwiTerm* tb) { + auto* e = expr_new(L, 2); + e->constant = constant; + e->term_count = 2; + e->terms[0].coefficient = ta->coefficient; + e->terms[0].var = retain_unmanaged(ta->var); + e->terms[1].coefficient = tb->coefficient; + e->terms[1].var = retain_unmanaged(tb->var); + return 1; +} + +inline int +push_expr_var_term(lua_State* L, double constant, VariableData* var, const KiwiTerm* t) { + auto* e = expr_new(L, 2); + e->constant = constant; + e->term_count = 2; + e->terms[0].coefficient = 1.0; + e->terms[0].var = retain_unmanaged(var); + e->terms[1].coefficient = t->coefficient; + e->terms[1].var = retain_unmanaged(t->var); + return 1; +} + +int push_add_expr_term(lua_State* L, const KiwiExpression* expr, const KiwiTerm* t) { + auto* e = expr_new(L, expr->term_count + 1); + e->constant = expr->constant; + e->term_count = expr->term_count + 1; + int i = 0; + for (; i < expr->term_count; ++i) { + e->terms[i].coefficient = expr->terms[i].coefficient; + e->terms[i].var = retain_unmanaged(expr->terms[i].var); + } + e->terms[i].coefficient = t->coefficient; + e->terms[i].var = retain_unmanaged(t->var); + return 1; +} + +int lkiwi_var_m_add(lua_State* L) { + TypeId type_id_b; + double num = 0.0; + void* arg_b = try_arg(L, 2, &type_id_b, &num); + + if (type_id_b == VAR) { + int isnum_a; + num = lua_tonumberx(L, 1, &isnum_a); + if (isnum_a) { + const KiwiTerm t {*static_cast(arg_b), 1.0}; + return push_expr_one(L, num, &t); + } + } + + auto* var_a = try_var(L, 1); + if (var_a) { + switch (type_id_b) { + case VAR: { + const KiwiTerm ta {var_a, 1.0}, tb {*static_cast(arg_b), 1.0}; + return push_expr_pair(L, 0.0, &ta, &tb); + } + case TERM: + return push_expr_var_term(L, 0.0, var_a, (static_cast(arg_b))); + case EXPR: { + const KiwiTerm t {var_a, 1.0}; + return push_add_expr_term(L, static_cast(arg_b), &t); + } + case NUMBER: { + const KiwiTerm t {var_a, 1.0}; + return push_expr_one(L, num, &t); + } + default: + break; + } + } + return op_error(L, "+", 1, 2); +} + +int lkiwi_var_m_sub(lua_State* L) { + lua_settop(L, 2); +#if defined(LUA_VERSION_NUM) && LUA_VERSION_NUM == 501 + lua_rawgeti(L, lua_upvalueindex(1), VAR_SUB_FN); + lua_insert(L, 1); + lua_call(L, 2, 1); +#else + lua_arith(L, LUA_OPUNM); + lua_arith(L, LUA_OPADD); +#endif + return 1; +} + +int lkiwi_var_m_mul(lua_State* L) { + int isnum, varidx = 2; + double num = lua_tonumberx(L, 1, &isnum); + + if (!isnum) { + varidx = 1; + num = lua_tonumberx(L, 2, &isnum); + } + + if (isnum) { + auto* var = try_var(L, varidx); + if (var) { + auto* term = term_new(L); + term->var = retain_unmanaged(var); + term->coefficient = num; + return 1; + } + } + return op_error(L, "*", 1, 2); +} + +int lkiwi_var_m_div(lua_State* L) { + auto* var = try_var(L, 1); + int isnum; + double num = lua_tonumberx(L, 2, &isnum); + if (!var || !isnum) { + return op_error(L, "/", 1, 2); + } + auto* term = term_new(L); + term->var = retain_unmanaged(var); + term->coefficient = 1.0 / num; + return 1; +} + +int lkiwi_var_m_unm(lua_State* L) { + auto* term = term_new(L); + term->var = retain_unmanaged(get_var(L, 1)); + term->coefficient = -1.0; + return 1; +} + +int lkiwi_var_m_eq(lua_State* L) { + lua_pushboolean(L, get_var(L, 1) == get_var(L, 2)); + return 1; +} + +int lkiwi_var_m_tostring(lua_State* L) { + auto* var = get_var(L, 1); + lua_pushfstring(L, "%s(%f)", var->name().c_str(), var->value()); + return 1; +} + +int lkiwi_var_m_gc(lua_State* L) { + release_unmanaged(get_var(L, 1)); + return 0; +} + +int lkiwi_var_set_name(lua_State* L) { + auto* var = get_var(L, 1); + const char* name = luaL_checkstring(L, 2); + var->setName(name); + return 0; +} + +int lkiwi_var_name(lua_State* L) { + lua_pushstring(L, get_var(L, 1)->name().c_str()); + return 1; +} + +int lkiwi_var_set(lua_State* L) { + auto* var = get_var(L, 1); + const double value = luaL_checknumber(L, 2); + var->setValue(value); + return 0; +} + +int lkiwi_var_value(lua_State* L) { + lua_pushnumber(L, get_var(L, 1)->value()); + return 1; +} + +int lkiwi_var_toterm(lua_State* L) { + auto* var = get_var(L, 1); + double coefficient = luaL_optnumber(L, 2, 1.0); + auto* term = term_new(L); + + term->var = retain_unmanaged(var); + term->coefficient = coefficient; + + return 1; +} + +int lkiwi_var_toexpr(lua_State* L) { + const KiwiTerm t {get_var(L, 1), 1.0}; + return push_expr_one(L, 0.0, &t); +} + +constexpr const struct luaL_Reg kiwi_var_m[] = { + {"__add", lkiwi_var_m_add}, + {"__sub", lkiwi_var_m_sub}, + {"__mul", lkiwi_var_m_mul}, + {"__div", lkiwi_var_m_div}, + {"__unm", lkiwi_var_m_unm}, + {"__eq", lkiwi_var_m_eq}, + {"__tostring", lkiwi_var_m_tostring}, + {"__gc", lkiwi_var_m_gc}, + {"name", lkiwi_var_name}, + {"set_name", lkiwi_var_set_name}, + {"value", lkiwi_var_value}, + {"set", lkiwi_var_set}, + {"toterm", lkiwi_var_toterm}, + {"toexpr", lkiwi_var_toexpr}, + {"eq", lkiwi_eq}, + {"le", lkiwi_le}, + {"ge", lkiwi_ge}, + {0, 0}}; + +int lkiwi_var_new(lua_State* L) { + const char* name = luaL_optstring(L, 1, ""); + var_new(L, make_unmanaged(name)); + return 1; +} + +int lkiwi_term_m_add(lua_State* L) { + TypeId type_id_b; + double num = 0.0; + void* arg_b = try_arg(L, 2, &type_id_b, &num); + + if (type_id_b == TERM) { + int isnum_a; + num = lua_tonumberx(L, 1, &isnum_a); + if (isnum_a) { + return push_expr_one(L, num, (const KiwiTerm*)arg_b); + } + } + + const auto* term_a = try_term(L, 1); + if (term_a) { + switch (type_id_b) { + case TERM: + return push_expr_pair(L, 0.0, term_a, static_cast(arg_b)); + case VAR: { + const KiwiTerm term_b {*static_cast(arg_b), 1.0}; + return push_expr_pair(L, 0.0, term_a, &term_b); + } + case EXPR: + return push_add_expr_term(L, static_cast(arg_b), term_a); + case NUMBER: + return push_expr_one(L, num, term_a); + default: + break; + } + } + return op_error(L, "+", 1, 2); +} + +int lkiwi_term_m_sub(lua_State* L) { + lua_settop(L, 2); + compat_arith_unm(L); + lkiwi_term_m_add(L); + return 1; +} + +int lkiwi_term_m_mul(lua_State* L) { + int isnum, termidx = 2; + double num = lua_tonumberx(L, 1, &isnum); + + if (!isnum) { + termidx = 1; + num = lua_tonumberx(L, 2, &isnum); + } + + if (isnum) { + const auto* term = try_term(L, termidx); + if (term) { + auto* ret = term_new(L); + ret->var = retain_unmanaged(term->var); + ret->coefficient = term->coefficient * num; + return 1; + } + } + return op_error(L, "*", 1, 2); +} + +int lkiwi_term_m_div(lua_State* L) { + const KiwiTerm* term = try_term(L, 1); + int isnum; + double num = lua_tonumberx(L, 2, &isnum); + if (!term || !isnum) { + return op_error(L, "/", 1, 2); + } + auto* ret = term_new(L); + ret->var = retain_unmanaged(term->var); + ret->coefficient = term->coefficient / num; + return 1; +} + +int lkiwi_term_m_unm(lua_State* L) { + const auto* term = get_term(L, 1); + auto* ret = term_new(L); + ret->var = retain_unmanaged(term->var); + ret->coefficient = -term->coefficient; + return 1; +} + +int lkiwi_term_toexpr(lua_State* L) { + return push_expr_one(L, 0.0, get_term(L, 1)); +} + +int lkiwi_term_value(lua_State* L) { + const auto* term = get_term(L, 1); + lua_pushnumber(L, term->var->value() * term->coefficient); + return 1; +} + +int lkiwi_term_m_tostring(lua_State* L) { + const auto* term = get_term(L, 1); + lua_pushfstring(L, "%f %s", term->coefficient, term->var->name().c_str()); + return 1; +} + +int lkiwi_term_m_gc(lua_State* L) { + release_unmanaged(get_term(L, 1)->var); + return 0; +} + +int lkiwi_term_m_index(lua_State* L) { + const auto* term = get_term(L, 1); + size_t len; + const char* k = lua_tolstring(L, 2, &len); + if (len == 3 && memcmp("var", k, len) == 0) { +#if defined(LUA_VERSION_NUM) && LUA_VERSION_NUM == 501 + lua_pushlightuserdata(L, term->var); + lua_rawget(L, lua_upvalueindex(2)); +#else + lua_rawgetp(L, lua_upvalueindex(2), term->var); +#endif + if (lua_isnil(L, -1)) + var_new(L, term->var); + return 1; + } else if (len == 11 && memcmp("coefficient", k, len) == 0) { + lua_pushnumber(L, term->coefficient); + return 1; + } + lua_getmetatable(L, 1); + lua_pushvalue(L, 2); + lua_rawget(L, -2); + if (lua_isnil(L, -1)) { + luaL_error(L, "kiwi.Term has no member named '%s'", k); + } + return 1; +} + +constexpr const struct luaL_Reg kiwi_term_m[] = { + {"__add", lkiwi_term_m_add}, + {"__sub", lkiwi_term_m_sub}, + {"__mul", lkiwi_term_m_mul}, + {"__div", lkiwi_term_m_div}, + {"__unm", lkiwi_term_m_unm}, + {"__tostring", lkiwi_term_m_tostring}, + {"__gc", lkiwi_term_m_gc}, + {"__index", 0}, + {"toexpr", lkiwi_term_toexpr}, + {"value", lkiwi_term_value}, + {"eq", lkiwi_eq}, + {"le", lkiwi_le}, + {"ge", lkiwi_ge}, + {0, 0}}; + +int lkiwi_term_new(lua_State* L) { + auto* var = get_var(L, 1); + double coefficient = luaL_optnumber(L, 2, 1.0); + auto* term = term_new(L); + term->var = retain_unmanaged(var); + term->coefficient = coefficient; + return 1; +} + +int push_expr_constant(lua_State* L, const KiwiExpression* expr, double constant) { + auto* ne = expr_new(L, expr->term_count); + for (int i = 0; i < expr->term_count; i++) { + ne->terms[i].var = retain_unmanaged(expr->terms[i].var); + ne->terms[i].coefficient = expr->terms[i].coefficient; + } + ne->constant = constant; + ne->term_count = expr->term_count; + return 1; +} + +int push_mul_expr_coeff(lua_State* L, const KiwiExpression* expr, double coeff) { + auto* ne = expr_new(L, expr->term_count); + ne->constant = expr->constant * coeff; + ne->term_count = expr->term_count; + for (int i = 0; i < expr->term_count; i++) { + ne->terms[i].var = retain_unmanaged(expr->terms[i].var); + ne->terms[i].coefficient = expr->terms[i].coefficient * coeff; + } + return 1; +} + +int push_add_expr_expr(lua_State* L, const KiwiExpression* a, const KiwiExpression* b) { + int na = a->term_count, nb = b->term_count; + + auto* ne = expr_new(L, na + nb); + ne->constant = a->constant + b->constant; + ne->term_count = na + nb; + + for (int i = 0; i < na; i++) { + ne->terms[i].var = retain_unmanaged(a->terms[i].var); + ne->terms[i].coefficient = a->terms[i].coefficient; + } + for (int i = 0; i < nb; i++) { + ne->terms[i + na].var = retain_unmanaged(b->terms[i].var); + ne->terms[i + na].coefficient = b->terms[i].coefficient; + } + return 1; +} + +int lkiwi_expr_m_add(lua_State* L) { + TypeId type_id_b; + double num = 0.0; + void* arg_b = try_arg(L, 2, &type_id_b, &num); + + if (type_id_b == EXPR) { + int isnum_a; + num = lua_tonumberx(L, 1, &isnum_a); + if (isnum_a) { + auto* expr_b = static_cast(arg_b); + return push_expr_constant(L, expr_b, num + expr_b->constant); + } + } + + const auto* expr_a = try_expr(L, 1); + if (expr_a) { + switch (type_id_b) { + case EXPR: + return push_add_expr_expr(L, expr_a, static_cast(arg_b)); + case TERM: + return push_add_expr_term(L, expr_a, static_cast(arg_b)); + case VAR: { + const KiwiTerm term_b {*static_cast(arg_b), 1.0}; + return push_add_expr_term(L, expr_a, &term_b); + } + case NUMBER: + return push_expr_constant(L, expr_a, num + expr_a->constant); + default: + break; + } + } + return op_error(L, "+", 1, 2); +} + +int lkiwi_expr_m_sub(lua_State* L) { + lua_settop(L, 2); + compat_arith_unm(L); + lkiwi_expr_m_add(L); + return 1; +} + +int lkiwi_expr_m_mul(lua_State* L) { + int isnum, expridx = 2; + double num = lua_tonumberx(L, 1, &isnum); + + if (!isnum) { + expridx = 1; + num = lua_tonumberx(L, 2, &isnum); + } + + if (isnum) { + const auto* expr = try_expr(L, expridx); + if (expr) + return push_mul_expr_coeff(L, expr, num); + } + return op_error(L, "*", 1, 2); +} + +int lkiwi_expr_m_div(lua_State* L) { + const auto* expr = try_expr(L, 1); + int isnum; + double num = lua_tonumberx(L, 2, &isnum); + if (!expr || !isnum) { + return op_error(L, "/", 1, 2); + } + return push_mul_expr_coeff(L, expr, 1.0 / num); +} + +int lkiwi_expr_m_unm(lua_State* L) { + const auto* expr = get_expr(L, 1); + return push_mul_expr_coeff(L, expr, -1.0); +} + +int lkiwi_expr_value(lua_State* L) { + const auto* expr = get_expr(L, 1); + double sum = expr->constant; + for (int i = 0; i < expr->term_count; i++) { + const auto* t = &expr->terms[i]; + sum += t->var->value() * t->coefficient; + } + lua_pushnumber(L, sum); + return 1; +} + +int lkiwi_expr_terms(lua_State* L) { + const auto* expr = get_expr(L, 1); + lua_createtable(L, expr->term_count, 0); + for (int i = 0; i < expr->term_count; i++) { + const auto* t = &expr->terms[i]; + auto* new_term = term_new(L); + new_term->var = retain_unmanaged(t->var); + new_term->coefficient = t->coefficient; + lua_rawseti(L, -2, i + 1); + } + return 1; +} + +int lkiwi_expr_copy(lua_State* L) { + auto* expr = get_expr(L, 1); + return push_expr_constant(L, expr, expr->constant); +} + +int lkiwi_expr_m_tostring(lua_State* L) { + const auto* expr = get_expr(L, 1); + luaL_Buffer buf; + luaL_buffinit(L, &buf); + + for (int i = 0; i < expr->term_count; i++) { + const auto* t = &expr->terms[i]; + lua_pushfstring(L, "%f %s", t->coefficient, t->var->name().c_str()); + luaL_addvalue(&buf); + luaL_addstring(&buf, " + "); + } + + lua_pushfstring(L, "%f", expr->constant); + luaL_addvalue(&buf); + luaL_pushresult(&buf); + + return 1; +} + +int lkiwi_expr_m_gc(lua_State* L) { + const auto* expr = get_expr(L, 1); + if (expr->owner) { + release_unmanaged(expr->owner); + } else { + for (auto* t = expr->terms; t != expr->terms + expr->term_count; ++t) { + release_unmanaged(t->var); + } + } + return 0; +} + +int lkiwi_expr_m_index(lua_State* L) { + const auto* expr = get_expr(L, 1); + size_t len; + const char* k = lua_tolstring(L, 2, &len); + if (len == 8 && memcmp("constant", k, len) == 0) { + lua_pushnumber(L, expr->constant); + return 1; + } + lua_getmetatable(L, 1); + lua_pushvalue(L, 2); + lua_rawget(L, -2); + if (lua_isnil(L, -1)) { + luaL_error(L, "kiwi.Expression has no member named '%s'", k); + } + return 1; +} + +constexpr const struct luaL_Reg kiwi_expr_m[] = { + {"__add", lkiwi_expr_m_add}, + {"__sub", lkiwi_expr_m_sub}, + {"__mul", lkiwi_expr_m_mul}, + {"__div", lkiwi_expr_m_div}, + {"__unm", lkiwi_expr_m_unm}, + {"__tostring", lkiwi_expr_m_tostring}, + {"__gc", lkiwi_expr_m_gc}, + {"__index", lkiwi_expr_m_index}, + {"value", lkiwi_expr_value}, + {"terms", lkiwi_expr_terms}, + {"copy", lkiwi_expr_copy}, + {"eq", lkiwi_eq}, + {"le", lkiwi_le}, + {"ge", lkiwi_ge}, + {0, 0}}; + +int lkiwi_expr_new(lua_State* L) { + int nterms = lua_gettop(L) - 1; + lua_Number constant = luaL_checknumber(L, 1); + + auto* expr = expr_new(L, nterms); + expr->constant = constant; + expr->term_count = nterms; + + for (int i = 0; i < nterms; i++) { + const auto* term = get_term(L, i + 2); + expr->terms[i].var = retain_unmanaged(term->var); + expr->terms[i].coefficient = term->coefficient; + } + return 1; +} + +int lkiwi_constraint_strength(lua_State* L) { + lua_pushnumber(L, get_constraint(L, 1)->strength()); + return 1; +} + +int lkiwi_constraint_op(lua_State* L) { + auto op = get_constraint(L, 1)->op(); + const char* opstr = "??"; + switch (op) { + case kiwi::OP_LE: + opstr = "LE"; + break; + case kiwi::OP_GE: + opstr = "GE"; + break; + case kiwi::OP_EQ: + opstr = "EQ"; + break; + } + lua_pushlstring(L, opstr, 2); + return 1; +} + +int lkiwi_constraint_violated(lua_State* L) { + lua_pushboolean(L, get_constraint(L, 1)->violated()); + return 1; +} + +int lkiwi_constraint_expression(lua_State* L) { + auto* c = get_constraint(L, 1); + const auto& expr = c->expression(); + const auto& terms = expr.terms(); + const auto term_count = static_cast(terms.size() > INT_MAX ? INT_MAX : terms.size()); + + auto* ne = expr_new(L, term_count); + ne->owner = retain_unmanaged(c); + ne->constant = expr.constant(); + ne->term_count = term_count; + + for (int i = 0; i < term_count; ++i) { + const auto& t = terms[static_cast(i)]; + ne->terms[i].var = const_cast(t.variable()).ptr(); + ne->terms[i].coefficient = t.coefficient(); + } + return 1; +} + +int lkiwi_constraint_m_tostring(lua_State* L) { + const auto& c = *get_constraint(L, 1); + + luaL_Buffer buf; + luaL_buffinit(L, &buf); + const char* oppart = " ?? 0 | "; + switch (c.op()) { + case kiwi::OP_LE: + oppart = " <= 0 | "; + break; + case kiwi::OP_GE: + oppart = " >= 0 | "; + break; + case kiwi::OP_EQ: + oppart = " == 0 | "; + break; + } + + const auto& expr = c.expression(); + + for (const auto& t : expr.terms()) { + lua_pushfstring(L, "%f %s", t.coefficient(), t.variable().name().c_str()); + luaL_addvalue(&buf); + luaL_addstring(&buf, " + "); + } + + lua_pushfstring(L, "%f", expr.constant()); + luaL_addvalue(&buf); + + luaL_addlstring(&buf, oppart, 8); + const char* strength_name = 0; + const double strength = c.strength(); + + if (strength == STRENGTH_REQUIRED) { + strength_name = "required"; + } else if (strength == STRENGTH_STRONG) { + strength_name = "strong"; + } else if (strength == STRENGTH_MEDIUM) { + strength_name = "medium"; + } else if (strength == STRENGTH_WEAK) { + strength_name = "weak"; + } + + if (strength_name) { + luaL_addstring(&buf, strength_name); + } else { + lua_pushfstring(L, "%f", strength); + luaL_addvalue(&buf); + } + luaL_pushresult(&buf); + + return 1; +} + +int lkiwi_constraint_m_gc(lua_State* L) { + release_unmanaged(get_constraint(L, 1)); + return 0; +} + +int lkiwi_solver_add_constraint(lua_State* L); +int lkiwi_solver_remove_constraint(lua_State* L); + +int lkiwi_constraint_add_to(lua_State* L) { + lua_settop(L, 2); + lua_rotate(L, 1, 1); + lkiwi_solver_add_constraint(L); + lua_settop(L, 2); + return 1; +} + +int lkiwi_constraint_remove_from(lua_State* L) { + lua_settop(L, 2); + lua_rotate(L, 1, 1); + lkiwi_solver_remove_constraint(L); + lua_settop(L, 2); + return 1; +} + +constexpr const struct luaL_Reg kiwi_constraint_m[] = { + {"__tostring", lkiwi_constraint_m_tostring}, + {"__gc", lkiwi_constraint_m_gc}, + {"strength", lkiwi_constraint_strength}, + {"op", lkiwi_constraint_op}, + {"violated", lkiwi_constraint_violated}, + {"expression", lkiwi_constraint_expression}, + {"add_to", lkiwi_constraint_add_to}, + {"remove_from", lkiwi_constraint_remove_from}, + {0, 0}}; + +int lkiwi_constraint_new(lua_State* L) { + const auto* lhs = get_expr_opt(L, 1); + const auto* rhs = get_expr_opt(L, 2); + const auto op = get_op_opt(L, 3); + double strength = luaL_optnumber(L, 4, STRENGTH_REQUIRED); + + constraint_new(L, lhs, rhs, op, strength); + return 1; +} + +int push_pair_constraint( + lua_State* L, + VariableData* left, + double coeff, + VariableData* right, + double constant, + kiwi::RelationalOperator op, + double strength +) { + alignas(KiwiExpression) unsigned char expr_buf[KiwiExpression::sz(2)]; + auto* expr = reinterpret_cast(&expr_buf); + expr->constant = constant; + expr->term_count = 2; + expr->terms[0].var = left; + expr->terms[0].coefficient = 1.0; + expr->terms[1].var = right; + expr->terms[1].coefficient = -coeff; + constraint_new(L, expr, 0, op, strength); + return 1; +} + +int lkiwi_constraints_pair_ratio(lua_State* L) { + return push_pair_constraint( + L, + get_var(L, 1), + luaL_checknumber(L, 2), + get_var(L, 3), + luaL_optnumber(L, 4, 0.0), + get_op_opt(L, 5), + luaL_optnumber(L, 6, STRENGTH_REQUIRED) + ); +} + +int lkiwi_constraints_pair(lua_State* L) { + return push_pair_constraint( + L, + get_var(L, 1), + 1.0, + get_var(L, 2), + luaL_optnumber(L, 3, 0.0), + get_op_opt(L, 4), + luaL_optnumber(L, 4, STRENGTH_REQUIRED) + ); +} + +int lkiwi_constraints_single(lua_State* L) { + alignas(KiwiExpression) unsigned char expr_buf[KiwiExpression::sz(1)]; + auto* expr = reinterpret_cast(&expr_buf); + expr->term_count = 1; + expr->terms[0].var = get_var(L, 1); + expr->terms[0].coefficient = 1.0; + expr->constant = luaL_optnumber(L, 2, 0.0); + + constraint_new(L, expr, 0, get_op_opt(L, 3), luaL_optnumber(L, 4, STRENGTH_REQUIRED)); + return 1; +} + +constexpr const struct luaL_Reg lkiwi_constraints[] = { + {"pair_ratio", lkiwi_constraints_pair_ratio}, + {"pair", lkiwi_constraints_pair}, + {"single", lkiwi_constraints_single}, + {0, 0}}; + +void lkiwi_mod_constraints_new(lua_State* L, int ctx_i) { + luaL_newlibtable(L, lkiwi_constraints); + lua_pushvalue(L, ctx_i); + setfuncs(L, lkiwi_constraints, 1); +} + +/* kiwi.Error */ + +void error_new(lua_State* L, const KiwiErr* err, int solver_absi, int item_absi) { + lua_createtable(L, 0, 4); + push_type(L, ERROR); + lua_setmetatable(L, -2); + + lua_pushstring(L, lkiwi_error_kinds[err->kind < KiwiErrUnknown ? err->kind : KiwiErrUnknown]); + lua_setfield(L, -2, "kind"); + + lua_pushstring(L, err->message); + lua_setfield(L, -2, "message"); + + if (solver_absi) { + lua_pushvalue(L, solver_absi); + lua_setfield(L, -2, "solver"); + } + if (item_absi) { + lua_pushvalue(L, item_absi); + lua_setfield(L, -2, "item"); + } + + if (err->must_delete) { + delete const_cast(err); + } +} + +int lkiwi_error_m_tostring(lua_State* L) { + luaL_Buffer buf; + luaL_buffinit(L, &buf); + + lua_getfield(L, 1, "message"); + luaL_addvalue(&buf); + + lua_getfield(L, 1, "solver"); + lua_pushfstring(L, ": (kiwi.Solver(%p), ", get_solver(L, -1)); + lua_remove(L, -2); // remove solver + luaL_addvalue(&buf); + + lua_getfield(L, 1, "item"); + luaL_tolstring(L, -1, 0); + lua_remove(L, -2); // remove item + luaL_addvalue(&buf); + luaL_addstring(&buf, ")"); + luaL_pushresult(&buf); + + return 1; +} + +constexpr const struct luaL_Reg lkiwi_error_m[] = { + {"__tostring", lkiwi_error_m_tostring}, + {0, 0}}; + +int lkiwi_error_mask(lua_State* L) { + int invert = lua_toboolean(L, 2); + + if (lua_type(L, 1) == LUA_TSTRING) { + luaL_typeerror(L, 1, "indexable"); + } + + lua_rawgeti(L, lua_upvalueindex(1), ERR_KIND_TAB); + + unsigned mask = 0; + for (int n = 1; lua_geti(L, 1, n) != LUA_TNIL; ++n) { + int isnum; + auto shift = lua_tointegerx(L, -1, &isnum); + if (!isnum) { + lua_rawget(L, -2 /* err_kind table */); + shift = lua_tointegerx(L, -1, &isnum); + if (!isnum) { + luaL_error(L, "unknown error kind at index %d: %s", n, luaL_tolstring(L, -2, 0)); + } + } + mask |= 1 << shift; + lua_pop(L, 1); + } + lua_pushinteger(L, invert ? ~mask : mask); + return 1; +} + +int lkiwi_solver_handle_err(lua_State* L, const KiwiErr* err, const KiwiSolver* solver) { + /* This assumes solver is at index 1 */ + lua_settop(L, 2); + if (err) { + error_new(L, err, 1, 2); + unsigned error_mask = solver->error_mask; + if (error_mask & (1 << err->kind)) { + return 2; + } else { + lua_error(L); + } + } + return 1; +} + +int lkiwi_solver_add_constraint(lua_State* L) { + auto* self = get_solver(L, 1); + auto* c = get_constraint(L, 2); + auto* err = kiwi_solver_add_constraint(self->solver, c); + return lkiwi_solver_handle_err(L, err, self); +} + +int lkiwi_solver_remove_constraint(lua_State* L) { + auto* self = get_solver(L, 1); + auto* c = get_constraint(L, 2); + auto* err = kiwi_solver_remove_constraint(self->solver, c); + return lkiwi_solver_handle_err(L, err, self); +} + +int lkiwi_solver_add_edit_var(lua_State* L) { + auto* self = get_solver(L, 1); + auto* var = get_var(L, 2); + double strength = luaL_checknumber(L, 3); + auto* err = kiwi_solver_add_edit_var(self->solver, var, strength); + return lkiwi_solver_handle_err(L, err, self); +} + +int lkiwi_solver_remove_edit_var(lua_State* L) { + auto* self = get_solver(L, 1); + auto* var = get_var(L, 2); + auto* err = kiwi_solver_remove_edit_var(self->solver, var); + return lkiwi_solver_handle_err(L, err, self); +} + +int lkiwi_solver_suggest_value(lua_State* L) { + auto* self = get_solver(L, 1); + auto* var = get_var(L, 2); + double value = luaL_checknumber(L, 3); + auto* err = kiwi_solver_suggest_value(self->solver, var, value); + return lkiwi_solver_handle_err(L, err, self); +} + +int lkiwi_solver_update_vars(lua_State* L) { + get_solver(L, 1)->solver.updateVariables(); + return 0; +} + +int lkiwi_solver_reset(lua_State* L) { + get_solver(L, 1)->solver.reset(); + return 0; +} + +int lkiwi_solver_has_constraint(lua_State* L) { + auto* s = get_solver(L, 1); + auto* c = get_constraint(L, 2); + lua_pushboolean(L, s->solver.hasConstraint(Constraint(c))); + return 1; +} + +int lkiwi_solver_has_edit_var(lua_State* L) { + auto* s = get_solver(L, 1); + auto* var = get_var(L, 2); + lua_pushboolean(L, s->solver.hasEditVariable(Variable(var))); + return 1; +} + +int lkiwi_solver_dump(lua_State* L) { + get_solver(L, 1)->solver.dump(); + return 0; +} + +int lkiwi_solver_dumps(lua_State* L) { + const auto& s = get_solver(L, 1)->solver.dumps(); + lua_pushlstring(L, s.data(), s.length()); + return 1; +} + +template +int lkiwi_add_remove_tab(lua_State* L, F&& fn) { + auto* solver = get_solver(L, 1); + int narg = lua_gettop(L); + + // block this particularly obnoxious case which is always a bug + if (lua_type(L, 2) == LUA_TSTRING) { + luaL_typeerror(L, 2, "indexable"); + } + for (int i = 1; lua_geti(L, 2, i) != LUA_TNIL; ++i) { + const KiwiErr* err = fn(L, solver); + if (err) { + error_new(L, err, 1, narg + 1 /* item_absi */); + const auto error_mask = solver->error_mask; + if (error_mask & (1 << err->kind)) { + lua_replace(L, 3); + lua_settop(L, 3); + return 2; + } else { + lua_error(L); + } + } + lua_pop(L, 1); + } + lua_settop(L, 2); + return 1; +} + +int lkiwi_solver_add_constraints(lua_State* L) { + return lkiwi_add_remove_tab(L, [](lua_State* L, KiwiSolver* s) { + return kiwi_solver_add_constraint(s->solver, get_constraint(L, -1)); + }); +} + +int lkiwi_solver_remove_constraints(lua_State* L) { + return lkiwi_add_remove_tab(L, [](lua_State* L, KiwiSolver* s) { + return kiwi_solver_add_constraint(s->solver, get_constraint(L, -1)); + }); +} + +int lkiwi_solver_add_edit_vars(lua_State* L) { + double strength = luaL_checknumber(L, 3); + return lkiwi_add_remove_tab(L, [strength](lua_State* L, KiwiSolver* s) { + return kiwi_solver_add_edit_var(s->solver, get_var(L, -1), strength); + }); +} + +int lkiwi_solver_remove_edit_vars(lua_State* L) { + return lkiwi_add_remove_tab(L, [](lua_State* L, KiwiSolver* s) { + return kiwi_solver_remove_edit_var(s->solver, get_var(L, -1)); + }); +} + +int lkiwi_solver_suggest_values(lua_State* L) { + auto* self = get_solver(L, 1); + int narg = lua_gettop(L); + + // catch this obnoxious case which is always a bug + if (lua_type(L, 2) == LUA_TSTRING) { + luaL_typeerror(L, 2, "indexable"); + } + if (lua_type(L, 3) == LUA_TSTRING) { + luaL_typeerror(L, 3, "indexable"); + } + + for (int i = 1; lua_geti(L, 2, i) != LUA_TNIL; ++i) { + auto* var = get_var(L, -1); + + lua_geti(L, 3, i); + double value = luaL_checknumber(L, -1); + + const KiwiErr* err = kiwi_solver_suggest_value(self->solver, var, value); + if (err) { + error_new(L, err, 1, narg + 1 /* item_absi */); + unsigned error_mask = self->error_mask; + if (error_mask & (1 << err->kind)) { + lua_replace(L, 4); + lua_settop(L, 4); + return 3; + } else { + lua_error(L); + } + } + lua_pop(L, 2); + } + lua_settop(L, 3); + return 2; +} + +int lkiwi_solver_set_error_mask(lua_State* L) { + auto* solver = get_solver(L, 1); + + lua_Integer error_mask; + if (lua_istable(L, 2)) { + lua_settop(L, 3); + lua_rotate(L, 1, -1); + lkiwi_error_mask(L); + error_mask = lua_tointeger(L, -1); + } else { + error_mask = luaL_checkinteger(L, 2); + } + + solver->error_mask = static_cast(error_mask); + return 0; +} + +int lkiwi_solver_m_tostring(lua_State* L) { + lua_pushfstring(L, "kiwi.Solver(%p)", get_solver(L, 1)); + return 1; +} + +int lkiwi_solver_m_gc(lua_State* L) { + get_solver(L, 1)->~KiwiSolver(); + return 0; +} + +constexpr const struct luaL_Reg kiwi_solver_m[] = { + {"add_constraint", lkiwi_solver_add_constraint}, + {"add_constraints", lkiwi_solver_add_constraints}, + {"remove_constraint", lkiwi_solver_remove_constraint}, + {"remove_constraints", lkiwi_solver_remove_constraints}, + {"add_edit_var", lkiwi_solver_add_edit_var}, + {"add_edit_vars", lkiwi_solver_add_edit_vars}, + {"remove_edit_var", lkiwi_solver_remove_edit_var}, + {"remove_edit_vars", lkiwi_solver_remove_edit_vars}, + {"suggest_value", lkiwi_solver_suggest_value}, + {"suggest_values", lkiwi_solver_suggest_values}, + {"update_vars", lkiwi_solver_update_vars}, + {"reset", lkiwi_solver_reset}, + {"has_constraint", lkiwi_solver_has_constraint}, + {"has_edit_var", lkiwi_solver_has_edit_var}, + {"dump", lkiwi_solver_dump}, + {"dumps", lkiwi_solver_dumps}, + {"set_error_mask", lkiwi_solver_set_error_mask}, + {"__tostring", lkiwi_solver_m_tostring}, + {"__gc", lkiwi_solver_m_gc}, + {0, 0}}; + +int lkiwi_solver_new(lua_State* L) { + lua_Integer error_mask; + if (lua_istable(L, 1)) { + lkiwi_error_mask(L); + error_mask = lua_tointeger(L, -1); + } else { + error_mask = luaL_optinteger(L, 1, 0); + } + + new (lua_newuserdata(L, sizeof(KiwiSolver))) KiwiSolver {static_cast(error_mask)}; + push_type(L, SOLVER); + lua_setmetatable(L, -2); + return 1; +} + +inline double clamp(double n) { + return fmax(0.0, fmin(1000, n)); +} + +int lkiwi_strength_create(lua_State* L) { + const double a = luaL_checknumber(L, 1); + const double b = luaL_checknumber(L, 2); + const double c = luaL_checknumber(L, 3); + const double w = luaL_optnumber(L, 4, 1.0); + + const double result = clamp(a * w) * 1000000.0 + clamp(b * w) * 1000.0 + clamp(c * w); + lua_pushnumber(L, result); + return 1; +} + +constexpr const struct luaL_Reg lkiwi_strength[] = {{"create", lkiwi_strength_create}, {0, 0}}; + +void lkiwi_mod_strength_new(lua_State* L) { + newlib(L, lkiwi_strength); + + lua_pushnumber(L, STRENGTH_REQUIRED); + lua_setfield(L, -2, "REQUIRED"); + + lua_pushnumber(L, STRENGTH_STRONG); + lua_setfield(L, -2, "STRONG"); + + lua_pushnumber(L, STRENGTH_MEDIUM); + lua_setfield(L, -2, "MEDIUM"); + + lua_pushnumber(L, STRENGTH_WEAK); + lua_setfield(L, -2, "WEAK"); +} + +int lkiwi_is_var(lua_State* L) { + return is_udata_obj(L, VAR); +} + +int lkiwi_is_term(lua_State* L) { + return is_udata_obj(L, TERM); +} + +int lkiwi_is_expression(lua_State* L) { + return is_udata_obj(L, EXPR); +} + +int lkiwi_is_constraint(lua_State* L) { + return is_udata_obj(L, CONSTRAINT); +} + +int lkiwi_is_solver(lua_State* L) { + return is_udata_obj(L, SOLVER); +} + +int lkiwi_is_error(lua_State* L) { + int result = 0; + if (lua_getmetatable(L, 1)) { + push_type(L, ERROR); + result = lua_rawequal(L, -1, -2); + lua_pop(L, 2); + } + lua_pushboolean(L, result); + return 1; +} + +constexpr const struct luaL_Reg lkiwi[] = { + {"Var", 0}, + {"is_var", lkiwi_is_var}, + {"Term", lkiwi_term_new}, + {"is_term", lkiwi_is_term}, + {"Expression", lkiwi_expr_new}, + {"is_expression", lkiwi_is_expression}, + {"Constraint", lkiwi_constraint_new}, + {"is_constraint", lkiwi_is_constraint}, + {"Solver", lkiwi_solver_new}, + {"is_solver", lkiwi_is_solver}, + {"error_mask", lkiwi_error_mask}, + {"is_error", lkiwi_is_error}, + {"eq", lkiwi_eq}, + {"le", lkiwi_le}, + {"ge", lkiwi_ge}, + {0, 0}}; + +int no_member_mt_index(lua_State* L) { + luaL_error(L, "attempt to access non-existent member '%s'", lua_tostring(L, 2)); + return 0; +} + +void no_member_mt_new(lua_State* L) { + lua_createtable(L, 0, 1); + lua_pushcfunction(L, no_member_mt_index); + lua_setfield(L, -2, "__index"); +} + +void register_type_n( + lua_State* L, + const char* name, + int context_absi, + int type_id, + const luaL_Reg* m, + size_t mcnt +) { + lua_createtable(L, 0, static_cast(mcnt + 2)); + lua_pushvalue(L, -2); // no_member_mt + lua_setmetatable(L, -2); + lua_pushstring(L, name); + lua_setfield(L, -2, "__name"); + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + + /* set type_tab udata as upvalue */ + lua_pushvalue(L, context_absi); + setfuncs(L, m, 1); + + lua_rawseti(L, context_absi, type_id); +} + +template +constexpr inline void register_type( + lua_State* L, + const char* name, + int context_absi, + int type_id, + const luaL_Reg (&m)[N] +) { + register_type_n(L, name, context_absi, type_id, m, N); +} + +#if defined(LUA_VERSION_NUM) && LUA_VERSION_NUM == 501 +void compat_init(lua_State* L, int context_absi) { + static const char var_sub_code[] = + "local a,b=...\n" + "return a + -b"; + + if (luaL_loadbuffer(L, var_sub_code, sizeof(var_sub_code) - 1, "=kiwi internal")) + lua_error(L); + + lua_rawseti(L, context_absi, VAR_SUB_FN); +} +#else +void compat_init(lua_State*, int) {} +#endif /* Lua 5.1 */ + +} // namespace + +#if defined __GNUC__ && (!defined _WIN32 || defined __CYGWIN__) + #define LJKIWI_EXPORT __attribute__((__visibility__("default"))) +#endif + +extern "C" LJKIWI_EXPORT int luaopen_ljkiwi(lua_State* L) { + luaL_checkversion(L); + + /* context table */ + lua_createtable(L, 0, CONTEXT_TAB_MAX); + int ctx_i = lua_gettop(L); + + compat_init(L, ctx_i); + + no_member_mt_new(L); + register_type(L, "kiwi.Var", ctx_i, VAR, kiwi_var_m); + register_type(L, "kiwi.Term", ctx_i, TERM, kiwi_term_m); + register_type(L, "kiwi.Expression", ctx_i, EXPR, kiwi_expr_m); + register_type(L, "kiwi.Constraint", ctx_i, CONSTRAINT, kiwi_constraint_m); + register_type(L, "kiwi.Solver", ctx_i, SOLVER, kiwi_solver_m); + register_type(L, "kiwi.Error", ctx_i, ERROR, lkiwi_error_m); + + lua_createtable(L, 0, array_count(lkiwi) + 6); + lua_pushvalue(L, ctx_i); + setfuncs(L, lkiwi, 1); + + /* var weak table */ + /* set as upvalue for selected functions */ + lua_createtable(L, 0, 0); + lua_createtable(L, 0, 1); + lua_pushstring(L, "v"); + lua_setfield(L, -2, "__mode"); + lua_setmetatable(L, -2); + + lua_pushvalue(L, ctx_i); + lua_pushvalue(L, -2); + lua_pushcclosure(L, lkiwi_var_new, 2); + lua_setfield(L, -3, "Var"); + + lua_rawgeti(L, ctx_i, TERM); + lua_pushvalue(L, ctx_i); + lua_pushvalue(L, -3); + lua_pushcclosure(L, lkiwi_term_m_index, 2); + lua_setfield(L, -2, "__index"); + lua_pop(L, 2); // TERM mt and var weak table + + /* ErrKind table */ + /* TODO: implement __call metamethod for these */ + lua_createtable(L, array_count(lkiwi_error_kinds) + 1, array_count(lkiwi_error_kinds)); + for (int i = 0; i < array_count(lkiwi_error_kinds); i++) { + lua_pushstring(L, lkiwi_error_kinds[i]); + lua_pushvalue(L, -1); + lua_rawseti(L, -3, i); + lua_pushinteger(L, i); + lua_rawset(L, -3); + } + + lua_pushvalue(L, -1); + lua_rawseti(L, ctx_i, ERR_KIND_TAB); + lua_setfield(L, -2, "ErrKind"); + + lua_rawgeti(L, ctx_i, ERROR); + lua_setfield(L, -2, "Error"); + + lua_pushinteger(L, 0xFFFF); + lua_setfield(L, -2, "ERROR_MASK_ALL"); + + lua_pushinteger( + L, + ~((1 << KiwiErrInternalSolverError) | (1 << KiwiErrAlloc) | (1 << KiwiErrNullObject) + | (1 << KiwiErrUnknown)) + ); + lua_setfield(L, -2, "ERROR_MASK_NON_FATAL"); + + lkiwi_mod_strength_new(L); + lua_setfield(L, -2, "strength"); + + lkiwi_mod_constraints_new(L, ctx_i); + lua_setfield(L, -2, "constraints"); + + return 1; +} diff --git a/spec/constraint_spec.lua b/spec/constraint_spec.lua new file mode 100644 index 0000000..6b9f7a1 --- /dev/null +++ b/spec/constraint_spec.lua @@ -0,0 +1,125 @@ +expose("module", function() + require("kiwi") +end) + +describe("Constraint", function() + local kiwi = require("kiwi") + local LUA_VERSION = tonumber(_VERSION:match("%d+%.%d+")) + + describe("construction", function() + local v, lhs + before_each(function() + v = kiwi.Var("foo") + lhs = v + 1 + end) + + it("has correct type", function() + assert.True(kiwi.is_constraint(kiwi.Constraint())) + assert.False(kiwi.is_constraint(v)) + end) + + it("default op and strength", function() + local c = kiwi.Constraint(lhs) + assert.equal("EQ", c:op()) + assert.equal(kiwi.strength.REQUIRED, c:strength()) + end) + + it("configure op", function() + local c = kiwi.Constraint(lhs, nil, "LE") + assert.equal("LE", c:op()) + end) + it("configure strength", function() + local c = kiwi.Constraint(lhs, nil, "GE", kiwi.strength.STRONG) + assert.equal(kiwi.strength.STRONG, c:strength()) + end) + + -- TODO: standardize formatting + it("formats well", function() + local c = kiwi.Constraint(lhs) + if LUA_VERSION <= 5.2 then + assert.equal("1 foo + 1 == 0 | required", tostring(c)) + else + assert.equal("1.0 foo + 1.0 == 0 | required", tostring(c)) + end + + c = kiwi.Constraint(lhs * 2, nil, "GE", kiwi.strength.STRONG) + if LUA_VERSION <= 5.2 then + assert.equal("2 foo + 2 >= 0 | strong", tostring(c)) + else + assert.equal("2.0 foo + 2.0 >= 0 | strong", tostring(c)) + end + + c = kiwi.Constraint(lhs / 2, nil, "LE", kiwi.strength.MEDIUM) + assert.equal("0.5 foo + 0.5 <= 0 | medium", tostring(c)) + + c = kiwi.Constraint(lhs, kiwi.Expression(3), "GE", kiwi.strength.WEAK) + if LUA_VERSION <= 5.2 then + assert.equal("1 foo + -2 >= 0 | weak", tostring(c)) + else + assert.equal("1.0 foo + -2.0 >= 0 | weak", tostring(c)) + end + end) + + it("rejects invalid args", function() + assert.error(function() + local _ = kiwi.Constraint(1) + end) + assert.error(function() + local _ = kiwi.Constraint(lhs, 1) + end) + assert.error(function() + local _ = kiwi.Constraint("") + end) + assert.error(function() + local _ = kiwi.Constraint(lhs, "") + end) + assert.error(function() + local _ = kiwi.Constraint(lhs, nil, "foo") + end) + assert.error(function() + local _ = kiwi.Constraint(lhs, nil, "LE", "foo") + end) + end) + it("combines lhs and rhs", function() + local v2 = kiwi.Var("bar") + local rhs = kiwi.Expression(3, 5 * v2, 3 * v) + local c = kiwi.Constraint(lhs, rhs) + + local e = c:expression() + local t = e:terms() + assert.equal(2, #t) + if t[1].var ~= v then + t[1], t[2] = t[2], t[1] + end + assert.equal(v, t[1].var) + assert.equal(-2.0, t[1].coefficient) + assert.equal(v2, t[2].var) + assert.equal(-5.0, t[2].coefficient) + assert.equal(-2.0, e.constant) + end) + end) + + describe("method", function() + local c, v + + before_each(function() + v = kiwi.Var("foo") + c = kiwi.Constraint(2 * v + 1) + end) + + it("violated", function() + assert.True(c:violated()) + v:set(-0.5) + assert.False(c:violated()) + end) + + it("add/remove constraint", function() + local s = kiwi.Solver() + c:add_to(s) + assert.True(s:has_constraint(c)) + + c:remove_from(s) + assert.False(s:has_constraint(c)) + end) + end) +end) diff --git a/spec/solver_spec.lua b/spec/solver_spec.lua new file mode 100644 index 0000000..8b68847 --- /dev/null +++ b/spec/solver_spec.lua @@ -0,0 +1,335 @@ +expose("module", function() + require("kiwi") +end) + +describe("solver", function() + local kiwi = require("kiwi") + ---@type kiwi.Solver + local solver + + before_each(function() + solver = kiwi.Solver() + end) + + it("should create a solver", function() + assert.True(kiwi.is_solver(solver)) + assert.False(kiwi.is_solver(kiwi.Term(kiwi.Var("v1")))) + end) + + describe("edit variables", function() + local v1, v2, v3 + before_each(function() + v1 = kiwi.Var("foo") + v2 = kiwi.Var("bar") + v3 = kiwi.Var("baz") + end) + + describe("add_edit_var", function() + it("should add a variable", function() + solver:add_edit_var(v1, kiwi.strength.STRONG) + assert.True(solver:has_edit_var(v1)) + end) + + it("should return the argument", function() + assert.equal(v1, solver:add_edit_var(v1, kiwi.strength.STRONG)) + end) + + it("should error on incorrect type", function() + assert.error(function() + solver:add_edit_var("", kiwi.strength.STRONG) ---@diagnostic disable-line: param-type-mismatch + end) + assert.error(function() + solver:add_edit_var(v1, "") ---@diagnostic disable-line: param-type-mismatch + end) + end) + + it("should require a strength argument", function() + assert.error(function() + solver:add_edit_var(v1) ---@diagnostic disable-line: missing-parameter + end) + end) + + it("should error on duplicate variable", function() + solver:add_edit_var(v1, kiwi.strength.STRONG) + local _, err = pcall(function() + return solver:add_edit_var(v1, kiwi.strength.STRONG) + end) + assert.True(kiwi.is_error(err)) + assert.True(kiwi.is_solver(err.solver)) + assert.equal(v1, err.item) + assert.equal("KiwiErrDuplicateEditVariable", err.kind) + assert.equal("The edit variable has already been added to the solver.", err.message) + end) + + it("should error on invalid strength", function() + local _, err = pcall(function() + return solver:add_edit_var(v1, kiwi.strength.REQUIRED) + end) + assert.True(kiwi.is_error(err)) + assert.True(kiwi.is_solver(err.solver)) + assert.equal(v1, err.item) + assert.equal("KiwiErrBadRequiredStrength", err.kind) + assert.equal("A required strength cannot be used in this context.", err.message) + end) + + it("should return errors for duplicate variables", function() + solver:set_error_mask({ "KiwiErrDuplicateEditVariable", "KiwiErrBadRequiredStrength" }) + local ret, err = solver:add_edit_var(v1, kiwi.strength.STRONG) + assert.Nil(err) + + ret, err = solver:add_edit_var(v1, kiwi.strength.STRONG) + + assert.equal(v1, ret) + assert.True(kiwi.is_error(err)) + ---@diagnostic disable: need-check-nil + assert.True(kiwi.is_solver(err.solver)) + assert.equal(v1, err.item) + assert.equal("KiwiErrDuplicateEditVariable", err.kind) + assert.equal("The edit variable has already been added to the solver.", err.message) + ---@diagnostic enable: need-check-nil + end) + + it("should return errors for invalid strength", function() + solver:set_error_mask({ "KiwiErrDuplicateEditVariable", "KiwiErrBadRequiredStrength" }) + + ---@diagnostic disable: need-check-nil + local ret, err = solver:add_edit_var(v2, kiwi.strength.REQUIRED) + assert.equal(v2, ret) + assert.True(kiwi.is_error(err)) + assert.True(kiwi.is_solver(err.solver)) + assert.equal(v2, err.item) + assert.equal("KiwiErrBadRequiredStrength", err.kind) + assert.equal("A required strength cannot be used in this context.", err.message) + ---@diagnostic enable: need-check-nil + end) + + it("tolerates a nil self", function() + assert.error(function() + kiwi.Solver.add_edit_var(nil, v1, kiwi.strength.STRONG) ---@diagnostic disable-line: param-type-mismatch + end) + end) + + it("tolerates a nil var", function() + assert.error(function() + solver:add_edit_var(nil, kiwi.strength.STRONG) ---@diagnostic disable-line: param-type-mismatch + end) + end) + end) + + describe("add_edit_vars", function() + it("should add variables", function() + solver:add_edit_vars({ v1, v2 }, kiwi.strength.STRONG) + assert.True(solver:has_edit_var(v1)) + assert.True(solver:has_edit_var(v2)) + assert.False(solver:has_edit_var(v3)) + end) + + it("should return the argument", function() + local arg = { v1, v2, v3 } + assert.equal(arg, solver:add_edit_vars(arg, kiwi.strength.STRONG)) + end) + + it("should error on incorrect type", function() + assert.error(function() + solver:add_edit_vars(v1, kiwi.strength.STRONG) ---@diagnostic disable-line: param-type-mismatch + end) + assert.error(function() + solver:add_edit_vars("", kiwi.strength.STRONG) ---@diagnostic disable-line: param-type-mismatch + end) + assert.error(function() + solver:add_edit_vars(v1, "") ---@diagnostic disable-line: param-type-mismatch + end) + end) + + it("should require a strength argument", function() + assert.error(function() + solver:add_edit_vars({ v1, v2 }) ---@diagnostic disable-line: missing-parameter + end) + end) + + it("should error on duplicate variable", function() + local _, err = pcall(function() + return solver:add_edit_vars({ v1, v2, v3, v2, v3 }, kiwi.strength.STRONG) + end) + assert.True(kiwi.is_error(err)) + assert.True(kiwi.is_solver(err.solver)) + assert.equal(v2, err.item) + assert.equal("KiwiErrDuplicateEditVariable", err.kind) + assert.equal("The edit variable has already been added to the solver.", err.message) + end) + + it("should error on invalid strength", function() + local _, err = pcall(function() + return solver:add_edit_vars({ v1, v2 }, kiwi.strength.REQUIRED) + end) + assert.True(kiwi.is_error(err)) + assert.True(kiwi.is_solver(err.solver)) + assert.equal(v1, err.item) + assert.equal("KiwiErrBadRequiredStrength", err.kind) + assert.equal("A required strength cannot be used in this context.", err.message) + end) + + it("should return errors for duplicate variables", function() + solver:set_error_mask({ "KiwiErrDuplicateEditVariable", "KiwiErrBadRequiredStrength" }) + local ret, err = solver:add_edit_vars({ v1, v2, v3 }, kiwi.strength.STRONG) + assert.Nil(err) + + local arg = { v1, v2, v3 } + ret, err = solver:add_edit_vars(arg, kiwi.strength.STRONG) + assert.equal(arg, ret) + assert.True(kiwi.is_error(err)) + ---@diagnostic disable: need-check-nil + assert.True(kiwi.is_solver(err.solver)) + assert.equal(v1, err.item) + assert.equal("KiwiErrDuplicateEditVariable", err.kind) + assert.equal("The edit variable has already been added to the solver.", err.message) + ---@diagnostic enable: need-check-nil + end) + + it("should return errors for invalid strength", function() + solver:set_error_mask({ "KiwiErrDuplicateEditVariable", "KiwiErrBadRequiredStrength" }) + arg = { v2, v3 } + local ret, err = solver:add_edit_vars(arg, kiwi.strength.REQUIRED) + assert.equal(arg, ret) + assert.True(kiwi.is_error(err)) + ---@diagnostic disable: need-check-nil + assert.True(kiwi.is_solver(err.solver)) + assert.equal(v2, err.item) + assert.equal("KiwiErrBadRequiredStrength", err.kind) + assert.equal("A required strength cannot be used in this context.", err.message) + ---@diagnostic enable: need-check-nil + end) + + it("tolerates a nil self", function() + assert.has_error(function() + kiwi.Solver.add_edit_vars(nil, { v1, v2 }, kiwi.strength.STRONG) ---@diagnostic disable-line: param-type-mismatch + end) + end) + end) + + describe("remove_edit_var", function() + it("should remove a variable", function() + solver:add_edit_vars({ v1, v2, v3 }, kiwi.strength.STRONG) + assert.True(solver:has_edit_var(v2)) + solver:remove_edit_var(v2) + assert.True(solver:has_edit_var(v1)) + assert.False(solver:has_edit_var(v2)) + assert.True(solver:has_edit_var(v3)) + end) + + it("should return the argument", function() + solver:add_edit_var(v1, kiwi.strength.STRONG) + assert.equal(v1, solver:remove_edit_var(v1)) + end) + + it("should error on incorrect type", function() + assert.error(function() + solver:remove_edit_var("") ---@diagnostic disable-line: param-type-mismatch + end) + assert.error(function() + solver:remove_edit_var({ v1 }) ---@diagnostic disable-line: param-type-mismatch + end) + end) + + it("should error on unknown variable", function() + solver:add_edit_var(v1, kiwi.strength.STRONG) + local _, err = pcall(function() + return solver:remove_edit_var(v2) + end) + assert.True(kiwi.is_error(err)) + assert.True(kiwi.is_solver(err.solver)) + assert.equal(v2, err.item) + assert.equal("KiwiErrUnknownEditVariable", err.kind) + assert.equal("The edit variable has not been added to the solver.", err.message) + end) + + it("should return errors if requested", function() + solver:set_error_mask({ "KiwiErrDuplicateEditVariable", "KiwiErrUnknownEditVariable" }) + + local ret, err = solver:remove_edit_var(v1) + + assert.equal(v1, ret) + assert.True(kiwi.is_error(err)) + ---@diagnostic disable: need-check-nil + assert.True(kiwi.is_solver(err.solver)) + assert.equal(v1, err.item) + assert.equal("KiwiErrUnknownEditVariable", err.kind) + assert.equal("The edit variable has not been added to the solver.", err.message) + ---@diagnostic enable: need-check-nil + end) + + it("tolerates a nil self", function() + assert.has_error(function() + kiwi.Solver.remove_edit_var(nil, v1) ---@diagnostic disable-line: param-type-mismatch + end) + end) + + it("tolerates a nil var", function() + assert.has_error(function() + solver:remove_edit_var(nil) ---@diagnostic disable-line: param-type-mismatch + end) + end) + end) + + describe("remove_edit_vars", function() + it("should remove variables", function() + solver:add_edit_vars({ v1, v2, v3 }, kiwi.strength.STRONG) + assert.True(solver:has_edit_var(v2)) + assert.True(solver:has_edit_var(v3)) + + solver:remove_edit_vars({ v2, v3 }) + assert.False(solver:has_edit_var(v2)) + assert.False(solver:has_edit_var(v3)) + end) + + it("should return the argument", function() + local arg = { v1, v2, v3 } + solver:add_edit_vars(arg, kiwi.strength.STRONG) + assert.equal(arg, solver:remove_edit_vars(arg)) + end) + + it("should error on incorrect type", function() + assert.error(function() + solver:remove_edit_vars(v1) ---@diagnostic disable-line: param-type-mismatch + end) + assert.error(function() + solver:remove_edit_vars("") ---@diagnostic disable-line: param-type-mismatch + end) + end) + + it("should error on unknown variables", function() + local _, err = pcall(function() + return solver:remove_edit_vars({ v2, v1 }) + end) + assert.True(kiwi.is_error(err)) + assert.True(kiwi.is_solver(err.solver)) + assert.equal(v2, err.item) + assert.equal("KiwiErrUnknownEditVariable", err.kind) + assert.equal("The edit variable has not been added to the solver.", err.message) + end) + + it("should return errors for unknown variables", function() + solver:set_error_mask({ "KiwiErrDuplicateEditVariable", "KiwiErrUnknownEditVariable" }) + local ret, err = solver:add_edit_vars({ v1, v2 }, kiwi.strength.STRONG) + assert.Nil(err) + + local arg = { v1, v2, v3 } + ret, err = solver:remove_edit_vars(arg) + assert.equal(arg, ret) + assert.True(kiwi.is_error(err)) + ---@diagnostic disable: need-check-nil + assert.True(kiwi.is_solver(err.solver)) + assert.equal(v3, err.item) + assert.equal("KiwiErrUnknownEditVariable", err.kind) + assert.equal("The edit variable has not been added to the solver.", err.message) + ---@diagnostic enable: need-check-nil + end) + + it("tolerates a nil self", function() + assert.has_error(function() + kiwi.Solver.remove_edit_vars(nil, { v1, v2 }) ---@diagnostic disable-line: param-type-mismatch + end) + end) + end) + end) +end) diff --git a/spec/var_spec.lua b/spec/var_spec.lua new file mode 100644 index 0000000..3bd693a --- /dev/null +++ b/spec/var_spec.lua @@ -0,0 +1,186 @@ +expose("module", function() + require("kiwi") +end) + +describe("Var", function() + local kiwi = require("kiwi") + + it("construction", function() + assert.True(kiwi.is_var(kiwi.Var())) + assert.False(kiwi.is_var(kiwi.Constraint())) + + assert.error(function() + kiwi.Var({}) + end) + end) + + describe("method", function() + local v + + before_each(function() + v = kiwi.Var("goo") + end) + + it("has settable name", function() + assert.equal("goo", v:name()) + v:set_name("Δ") + assert.equal("Δ", v:name()) + assert.error(function() + v:set_name({}) + end) + end) + + it("has a initial value of 0.0", function() + assert.equal(0.0, v:value()) + end) + + it("has a settable value", function() + v:set(47.0) + assert.equal(47.0, v:value()) + end) + + it("neg", function() + local neg = -v --[[@as kiwi.Term]] + assert.True(kiwi.is_term(neg)) + assert.equal(v, neg.var) + assert.equal(-1.0, neg.coefficient) + end) + + describe("bin op", function() + local v2 + before_each(function() + v2 = kiwi.Var("foo") + end) + + it("mul", function() + for _, prod in ipairs({ v * 2.0, 2 * v }) do + assert.True(kiwi.is_term(prod)) + assert.equal(v, prod.var) + assert.equal(2.0, prod.coefficient) + end + + assert.error(function() + local _ = v * v2 + end) + end) + + it("div", function() + local quot = v / 2.0 + assert.True(kiwi.is_term(quot)) + assert.equal(v, quot.var) + assert.equal(0.5, quot.coefficient) + + assert.error(function() + local _ = v / v2 + end) + end) + + it("add", function() + for _, sum in ipairs({ v + 2.0, 2 + v }) do + assert.True(kiwi.is_expression(sum)) + assert.equal(2.0, sum.constant) + + local terms = sum:terms() + assert.equal(1, #terms) + assert.equal(1.0, terms[1].coefficient) + assert.equal(v, terms[1].var) + end + + local sum = v + v2 + assert.True(kiwi.is_expression(sum)) + assert.equal(0, sum.constant) + local terms = sum:terms() + assert.equal(2, #terms) + assert.equal(v, terms[1].var) + assert.equal(1.0, terms[1].coefficient) + assert.equal(v2, terms[2].var) + assert.equal(1.0, terms[2].coefficient) + + assert.error(function() + local _ = v + "foo" + end) + assert.error(function() + local _ = v + {} + end) + end) + + it("sub", function() + local constants = { -2, 2 } + for i, diff in ipairs({ v - 2.0, 2 - v }) do + local constant = constants[i] + assert.True(kiwi.is_expression(diff)) + assert.equal(constant, diff.constant) + + local terms = diff:terms() + assert.equal(1, #terms) + assert.equal(v, terms[1].var) + assert.equal(constant < 0 and 1 or -1, terms[1].coefficient) + end + + local diff = v - v2 + assert.True(kiwi.is_expression(diff)) + assert.equal(0, diff.constant) + local terms = diff:terms() + assert.equal(2, #terms) + assert.equal(v, terms[1].var) + assert.equal(1.0, terms[1].coefficient) + assert.equal(v2, terms[2].var) + assert.equal(-1.0, terms[2].coefficient) + + assert.error(function() + local _ = v - "foo" + end) + assert.error(function() + local _ = v - {} + end) + end) + + it("constraint var op expr", function() + local ops = { "LE", "EQ", "GE" } + for i, meth in ipairs({ "le", "eq", "ge" }) do + local c = v[meth](v, v2 + 1) + assert.True(kiwi.is_constraint(c)) + + local e = c:expression() + local t = e:terms() + assert.equal(2, #t) + + -- order can be randomized due to use of map + if t[1].var ~= v then + t[1], t[2] = t[2], t[1] + end + assert.equal(v, t[1].var) + assert.equal(1.0, t[1].coefficient) + assert.equal(v2, t[2].var) + assert.equal(-1.0, t[2].coefficient) + + assert.equal(-1, e.constant) + assert.equal(ops[i], c:op()) + assert.equal(kiwi.strength.REQUIRED, c:strength()) + end + end) + + it("constraint var op var", function() + for i, meth in ipairs({ "le", "eq", "ge" }) do + local c = v[meth](v, v2) + assert.True(kiwi.is_constraint(c)) + + local e = c:expression() + local t = e:terms() + assert.equal(2, #t) + + -- order can be randomized due to use of map + if t[1].var ~= v then + t[1], t[2] = t[2], t[1] + end + assert.equal(v, t[1].var) + assert.equal(1.0, t[1].coefficient) + assert.equal(v2, t[2].var) + assert.equal(-1.0, t[2].coefficient) + + assert.equal(0, e.constant) + end + end) + end) + end) +end)