diff --git a/Makefile b/Makefile index 87236e3..fa8c069 100644 --- a/Makefile +++ b/Makefile @@ -5,14 +5,14 @@ CFLAGS += -Wall -I$(SRCDIR)/kiwi LIBFLAG := -shared LIB_EXT := so -ifeq ($(findstring gcc, $(CC)), gcc) - CXX := $(subst gcc, g++, $(CC)) +ifeq ($(findstring gcc,$(CC)),gcc) + CXX := $(subst gcc,g++,$(CC)) CXXFLAGS += -std=c++14 ifneq ($(SANITIZE),) CFLAGS += -fsanitize=undefined -fsanitize=address endif else -ifeq ($(CC), clang) +ifeq ($(CC),clang) CXX := clang++ CXXFLAGS += -std=c++14 ifneq ($(SANITIZE),) diff --git a/kiwi.lua b/kiwi.lua index 9a60ab0..3a9e674 100644 --- a/kiwi.lua +++ b/kiwi.lua @@ -1,6 +1,15 @@ -local kiwi = {} +-- kiwi.lua - LuaJIT FFI bindings with C API fallback to kiwi constraint solver. + +local ffi +do + local ffi_loader = package.preload["ffi"] + if ffi_loader == nil then + return require("ckiwi") + end + ffi = ffi_loader() --[[@as ffilib]] +end -local ffi = require("ffi") +local kiwi = {} local ckiwi do @@ -101,7 +110,7 @@ local ffi_copy, ffi_gc, ffi_istype, ffi_new, ffi_string = local concat = table.concat local has_table_new, new_tab = pcall(require, "table.new") if not has_table_new or type(new_tab) ~= "function" then - new_tab = function() + new_tab = function(_, _) return {} end end @@ -532,7 +541,7 @@ end do --- Expressions are a sum of terms with an added constant. ---@class kiwi.Expression: ffi.cdata* - ---@overload fun(terms: kiwi.Term[], constant: number?): kiwi.Expression + ---@overload fun(constant: number, ...: kiwi.Term): kiwi.Expression ---@field constant number ---@field package term_count number ---@field package terms_ ffi.cdata* @@ -610,7 +619,8 @@ do function Expression_cls:value() local sum = self.constant for i = 0, self.term_count - 1 do - sum = sum + self.terms_[i]:value() + local t = self.terms_[i] + sum = sum + t.var:value() * t.coefficient end return sum end @@ -636,17 +646,16 @@ do __index = Expression_cls, } - function Expression_mt.__new(T, terms, constant) - local term_count = terms and #terms or 0 + function Expression_mt.__new(T, constant, ...) + local term_count = select("#", ...) local e = ffi_gc(ffi_new(T, term_count), ckiwi.kiwi_expression_del_vars) --[[@as kiwi.Expression]] e.term_count = term_count - e.constant = constant or 0.0 - if terms then - for i, t in ipairs(terms) do - local dt = e.terms_[i - 1] --[[@as kiwi.Term]] - dt.var = ckiwi.kiwi_var_clone(t.var) - dt.coefficient = t.coefficient - end + e.constant = constant + for i = 1, term_count do + local t = select(i, ...) + local dt = e.terms_[i - 1] --[[@as kiwi.Term]] + dt.var = ckiwi.kiwi_var_clone(t.var) + dt.coefficient = t.coefficient end return e end @@ -729,7 +738,7 @@ do ---@return kiwi.Expression ---@nodiscard function Constraint_cls:expression() - local SZ = 8 + local SZ = 7 -- 2**7 bytes on x64 local expr = ffi_new(Expression, SZ) --[[@as kiwi.Expression]] local n = ckiwi.kiwi_constraint_expression(self, expr, SZ) if n > SZ then @@ -865,7 +874,7 @@ do --- Produce a custom error raise mask --- Error kinds specified in the mask will not cause a lua --- error to be raised. - ---@param kinds (kiwi.ErrKind|number)[] + ---@param kinds (kiwi.ErrKind|integer)[] ---@param invert boolean? ---@return integer function kiwi.error_mask(kinds, invert) @@ -948,7 +957,7 @@ do end ---@class kiwi.Solver: ffi.cdata* ---@field package error_mask_ integer - ---@overload fun(error_mask: (integer|(kiwi.ErrKind|number)[] )?): kiwi.Solver + ---@overload fun(error_mask: (integer|(kiwi.ErrKind|integer)[] )?): kiwi.Solver local Solver_cls = { --- Test whether a constraint is in the solver. ---@type fun(self: kiwi.Solver, constraint: kiwi.Constraint): boolean @@ -978,7 +987,7 @@ do } --- Sets the error mask for the solver. - ---@param mask integer|(kiwi.ErrKind|number)[] the mask value or an array of kinds + ---@param mask integer|(kiwi.ErrKind|integer)[] the mask value or an array of kinds ---@param invert boolean? whether to invert the mask if an array was passed for mask function Solver_cls:set_error_mask(mask, invert) if type(mask) == "table" then diff --git a/spec/constraint_spec.lua b/spec/constraint_spec.lua index 137a228..61144c0 100644 --- a/spec/constraint_spec.lua +++ b/spec/constraint_spec.lua @@ -42,7 +42,7 @@ describe("Constraint", function() c = kiwi.Constraint(lhs / 2, nil, "LE", kiwi.strength.MEDIUM) assert.equal("0.5 foo + 0.5 <= 0 | medium", tostring(c)) - c = kiwi.Constraint(lhs, kiwi.Expression(nil, 3), "GE", kiwi.strength.WEAK) + c = kiwi.Constraint(lhs, kiwi.Expression(3), "GE", kiwi.strength.WEAK) assert.equal("1 foo + -2 >= 0 | weak", tostring(c)) end) @@ -68,7 +68,7 @@ describe("Constraint", function() end) it("combines lhs and rhs", function() local v2 = kiwi.Var("bar") - local rhs = kiwi.Expression({ 5 * v2, 3 * v }, 3) + local rhs = kiwi.Expression(3, 5 * v2, 3 * v) local c = kiwi.Constraint(lhs, rhs) local e = c:expression()