diff --git a/ckiwi/ckiwi.cpp b/ckiwi/ckiwi.cpp index 07bbe3f..56c232c 100644 --- a/ckiwi/ckiwi.cpp +++ b/ckiwi/ckiwi.cpp @@ -2,6 +2,7 @@ #include +#include #include #include #include @@ -181,7 +182,7 @@ void kiwi_var_set_value(KiwiVar* var, double value) { void kiwi_expression_retain(KiwiExpression* expr) { if (lk_unlikely(!expr)) return; - for (auto* t = expr->terms_; t != expr->terms_ + expr->term_count; ++t) { + for (auto* t = expr->terms_; t != expr->terms_ + std::max(expr->term_count, 0); ++t) { retain_unmanaged(t->var); } expr->owner = expr; @@ -192,7 +193,7 @@ void kiwi_expression_destroy(KiwiExpression* expr) { return; if (expr->owner == expr) { - for (auto* t = expr->terms_; t != expr->terms_ + expr->term_count; ++t) { + for (auto* t = expr->terms_; t != expr->terms_ + std::max(expr->term_count, 0); ++t) { release_unmanaged(t->var); } } else { @@ -200,6 +201,51 @@ void kiwi_expression_destroy(KiwiExpression* expr) { } } +void kiwi_expression_add_term( + const KiwiExpression* expr, + KiwiVar* var, + double coefficient, + KiwiExpression* out +) { + if (lk_unlikely(!expr || expr->term_count == INT_MAX || expr->term_count < 0)) { + out->term_count = 0; + return; + } + + out->owner = out; + out->term_count = expr->term_count + 1; + out->constant = expr->constant; + + auto* d = out->terms_; + for (auto* s = expr->terms_; s != expr->terms_ + expr->term_count; ++s, ++d) { + d->var = retain_unmanaged(s->var); + d->coefficient = s->coefficient; + } + d->var = retain_unmanaged(var); + d->coefficient = coefficient; +} + +void kiwi_expression_set_constant( + const KiwiExpression* expr, + double constant, + KiwiExpression* out +) { + if (lk_unlikely(!expr || expr->term_count < 0)) { + out->term_count = 0; + return; + } + + out->owner = out; + out->term_count = expr->term_count; + out->constant = constant; + + auto* d = out->terms_; + for (auto* s = expr->terms_; s != expr->terms_ + expr->term_count; ++s, ++d) { + d->var = retain_unmanaged(s->var); + d->coefficient = s->coefficient; + } +} + KiwiConstraint* kiwi_constraint_new( const KiwiExpression* lhs, const KiwiExpression* rhs, diff --git a/ckiwi/ckiwi.h b/ckiwi/ckiwi.h index 15d2213..e3e914e 100644 --- a/ckiwi/ckiwi.h +++ b/ckiwi/ckiwi.h @@ -91,6 +91,14 @@ LJKIWI_EXP void kiwi_var_set_value(KiwiVar* var, double value); LJKIWI_EXP void kiwi_expression_retain(KiwiExpression* expr); LJKIWI_EXP void kiwi_expression_destroy(KiwiExpression* expr); +LJKIWI_EXP void kiwi_expression_add_term( + const KiwiExpression* expr, + KiwiVar* var, + double coefficient, + KiwiExpression* out +); +LJKIWI_EXP void +kiwi_expression_set_constant(const KiwiExpression* expr, double constant, KiwiExpression* out); LJKIWI_EXP KiwiConstraint* kiwi_constraint_new( const KiwiExpression* lhs, diff --git a/kiwi.lua b/kiwi.lua index 7c03db2..97fb437 100644 --- a/kiwi.lua +++ b/kiwi.lua @@ -63,7 +63,8 @@ typedef struct KiwiExpression { void kiwi_expression_retain(KiwiExpression* expr); void kiwi_expression_destroy(KiwiExpression* expr); - +void kiwi_expression_add_term(const KiwiExpression* expr, KiwiVar* var, double coeff, KiwiExpression* out); +void kiwi_expression_set_constant(const KiwiExpression* expr, double constant, KiwiExpression* out); ]]) if RUST then @@ -95,9 +96,6 @@ struct KiwiVar { }; void kiwi_var_free(KiwiVar* var); - -void kiwi_expression_add_term(const KiwiExpression* expr, KiwiVar* var, double coeff, KiwiExpression* out); -void kiwi_expression_set_constant(const KiwiExpression* expr, double constant, KiwiExpression* out); ]]) else ffi.cdef([[ @@ -298,29 +296,14 @@ else end end -local add_expr_term -if RUST then - ---@param expr kiwi.Expression - ---@param var kiwi.Var - ---@param coeff number? - ---@nodiscard - function add_expr_term(expr, var, coeff) - local ret = ffi_new(Expression, expr.term_count + 1) - ljkiwi.kiwi_expression_add_term(expr, var, coeff or 1.0, ret) - return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] - end -else - function add_expr_term(expr, var, coeff) - local ret = ffi_new(Expression, expr.term_count + 1) --[[@as kiwi.Expression]] - ffi_copy(ret.terms_, expr.terms_, SIZEOF_TERM * expr.term_count) - local dt = ret.terms_[expr.term_count] - dt.var = var - dt.coefficient = coeff or 1.0 - ret.constant = expr.constant - ret.term_count = expr.term_count + 1 - ljkiwi.kiwi_expression_retain(ret) - return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] - end +---@param expr kiwi.Expression +---@param var kiwi.Var +---@param coeff number? +---@nodiscard +local function add_expr_term(expr, var, coeff) + local ret = ffi_new(Expression, expr.term_count + 1) + ljkiwi.kiwi_expression_add_term(expr, var, coeff or 1.0, ret) + return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] end ---@param constant number @@ -738,25 +721,13 @@ do return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] end - local new_expr_constant - if RUST then - ---@param expr kiwi.Expression - ---@param constant number - ---@nodiscard - function new_expr_constant(expr, constant) - local ret = ffi_new(Expression, expr.term_count) - ljkiwi.kiwi_expression_set_constant(expr, constant, ret) - return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] - end - else - function new_expr_constant(expr, constant) - local ret = ffi_new(Expression, expr.term_count) --[[@as kiwi.Expression]] - ffi_copy(ret.terms_, expr.terms_, SIZEOF_TERM * expr.term_count) - ret.constant = constant - ret.term_count = expr.term_count - ljkiwi.kiwi_expression_retain(ret) - return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] - end + ---@param expr kiwi.Expression + ---@param constant number + ---@nodiscard + local function new_expr_constant(expr, constant) + local ret = ffi_new(Expression, expr.term_count) + ljkiwi.kiwi_expression_set_constant(expr, constant, ret) + return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] end ---@return number diff --git a/rjkiwi/src/expr.rs b/rjkiwi/src/expr.rs index 839b524..3eb90cf 100644 --- a/rjkiwi/src/expr.rs +++ b/rjkiwi/src/expr.rs @@ -40,7 +40,7 @@ impl KiwiExpression { } unsafe fn try_from_raw<'a>(e: *const KiwiExpressionPtr) -> Option<&'a Self> { - if e.is_null() { + if e.is_null() || (*e).term_count < 0 { None } else { Some(Self::from_raw(e)) @@ -56,10 +56,15 @@ pub unsafe extern "C" fn kiwi_expression_add_term( out: *mut KiwiExpressionPtr, ) { let Some(e) = KiwiExpression::try_from_raw(e) else { + (*out).term_count = 0; return; }; + if e.terms_.len() >= c_int::MAX as usize { + (*out).term_count = 0; + return; + } - let n_terms = (e.terms_.len() + 1).min(c_int::MAX as usize); + let n_terms = e.terms_.len() + 1; let out = core::slice::from_raw_parts_mut(out as *mut (), n_terms) as *mut [()] as *mut KiwiExpression; @@ -85,6 +90,7 @@ pub unsafe extern "C" fn kiwi_expression_set_constant( out: *mut KiwiExpressionPtr, ) { let Some(e) = KiwiExpression::try_from_raw(e) else { + (*out).term_count = 0; return; }; diff --git a/rjkiwi/src/lib.rs b/rjkiwi/src/lib.rs index 71a8e3c..88ee656 100644 --- a/rjkiwi/src/lib.rs +++ b/rjkiwi/src/lib.rs @@ -7,7 +7,7 @@ pub mod solver; mod util; pub mod var; -mod mem { +pub mod mem { use std::{ ffi::{c_char, CString}, ptr::NonNull, @@ -20,3 +20,10 @@ mod mem { } } } + +pub mod ffi { + pub use crate::expr::*; + pub use crate::mem::*; + pub use crate::solver::*; + pub use crate::var::*; +}