Skip to content

Commit

Permalink
Put some Rust specific binding functions in C++
Browse files Browse the repository at this point in the history
Make rust module with re-exports to allow combining with parent crates.
  • Loading branch information
jkl1337 committed Mar 9, 2024
1 parent 9fb7799 commit 27ac5df
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 51 deletions.
50 changes: 48 additions & 2 deletions ckiwi/ckiwi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <kiwi/kiwi.h>

#include <algorithm>
#include <climits>
#include <cstdlib>
#include <cstring>
Expand Down Expand Up @@ -181,7 +182,7 @@ void kiwi_var_set_value(KiwiVar* var, double value) {
void kiwi_expression_retain(KiwiExpression* expr) {
if (lk_unlikely(!expr))
return;
for (auto* t = expr->terms_; t != expr->terms_ + expr->term_count; ++t) {
for (auto* t = expr->terms_; t != expr->terms_ + std::max(expr->term_count, 0); ++t) {
retain_unmanaged(t->var);
}
expr->owner = expr;
Expand All @@ -192,14 +193,59 @@ void kiwi_expression_destroy(KiwiExpression* expr) {
return;

if (expr->owner == expr) {
for (auto* t = expr->terms_; t != expr->terms_ + expr->term_count; ++t) {
for (auto* t = expr->terms_; t != expr->terms_ + std::max(expr->term_count, 0); ++t) {
release_unmanaged(t->var);
}
} else {
release_unmanaged(static_cast<ConstraintData*>(expr->owner));
}
}

void kiwi_expression_add_term(
const KiwiExpression* expr,
KiwiVar* var,
double coefficient,
KiwiExpression* out
) {
if (lk_unlikely(!expr || expr->term_count == INT_MAX || expr->term_count < 0)) {
out->term_count = 0;
return;
}

out->owner = out;
out->term_count = expr->term_count + 1;
out->constant = expr->constant;

auto* d = out->terms_;
for (auto* s = expr->terms_; s != expr->terms_ + expr->term_count; ++s, ++d) {
d->var = retain_unmanaged(s->var);
d->coefficient = s->coefficient;
}
d->var = retain_unmanaged(var);
d->coefficient = coefficient;
}

void kiwi_expression_set_constant(
const KiwiExpression* expr,
double constant,
KiwiExpression* out
) {
if (lk_unlikely(!expr || expr->term_count < 0)) {
out->term_count = 0;
return;
}

out->owner = out;
out->term_count = expr->term_count;
out->constant = constant;

auto* d = out->terms_;
for (auto* s = expr->terms_; s != expr->terms_ + expr->term_count; ++s, ++d) {
d->var = retain_unmanaged(s->var);
d->coefficient = s->coefficient;
}
}

KiwiConstraint* kiwi_constraint_new(
const KiwiExpression* lhs,
const KiwiExpression* rhs,
Expand Down
8 changes: 8 additions & 0 deletions ckiwi/ckiwi.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ LJKIWI_EXP void kiwi_var_set_value(KiwiVar* var, double value);

LJKIWI_EXP void kiwi_expression_retain(KiwiExpression* expr);
LJKIWI_EXP void kiwi_expression_destroy(KiwiExpression* expr);
LJKIWI_EXP void kiwi_expression_add_term(
const KiwiExpression* expr,
KiwiVar* var,
double coefficient,
KiwiExpression* out
);
LJKIWI_EXP void
kiwi_expression_set_constant(const KiwiExpression* expr, double constant, KiwiExpression* out);

LJKIWI_EXP KiwiConstraint* kiwi_constraint_new(
const KiwiExpression* lhs,
Expand Down
63 changes: 17 additions & 46 deletions kiwi.lua
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ typedef struct KiwiExpression {
void kiwi_expression_retain(KiwiExpression* expr);
void kiwi_expression_destroy(KiwiExpression* expr);
void kiwi_expression_add_term(const KiwiExpression* expr, KiwiVar* var, double coeff, KiwiExpression* out);
void kiwi_expression_set_constant(const KiwiExpression* expr, double constant, KiwiExpression* out);
]])

if RUST then
Expand Down Expand Up @@ -95,9 +96,6 @@ struct KiwiVar {
};
void kiwi_var_free(KiwiVar* var);
void kiwi_expression_add_term(const KiwiExpression* expr, KiwiVar* var, double coeff, KiwiExpression* out);
void kiwi_expression_set_constant(const KiwiExpression* expr, double constant, KiwiExpression* out);
]])
else
ffi.cdef([[
Expand Down Expand Up @@ -298,29 +296,14 @@ else
end
end

local add_expr_term
if RUST then
---@param expr kiwi.Expression
---@param var kiwi.Var
---@param coeff number?
---@nodiscard
function add_expr_term(expr, var, coeff)
local ret = ffi_new(Expression, expr.term_count + 1)
ljkiwi.kiwi_expression_add_term(expr, var, coeff or 1.0, ret)
return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]]
end
else
function add_expr_term(expr, var, coeff)
local ret = ffi_new(Expression, expr.term_count + 1) --[[@as kiwi.Expression]]
ffi_copy(ret.terms_, expr.terms_, SIZEOF_TERM * expr.term_count)
local dt = ret.terms_[expr.term_count]
dt.var = var
dt.coefficient = coeff or 1.0
ret.constant = expr.constant
ret.term_count = expr.term_count + 1
ljkiwi.kiwi_expression_retain(ret)
return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]]
end
---@param expr kiwi.Expression
---@param var kiwi.Var
---@param coeff number?
---@nodiscard
local function add_expr_term(expr, var, coeff)
local ret = ffi_new(Expression, expr.term_count + 1)
ljkiwi.kiwi_expression_add_term(expr, var, coeff or 1.0, ret)
return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]]
end

---@param constant number
Expand Down Expand Up @@ -738,25 +721,13 @@ do
return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]]
end

local new_expr_constant
if RUST then
---@param expr kiwi.Expression
---@param constant number
---@nodiscard
function new_expr_constant(expr, constant)
local ret = ffi_new(Expression, expr.term_count)
ljkiwi.kiwi_expression_set_constant(expr, constant, ret)
return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]]
end
else
function new_expr_constant(expr, constant)
local ret = ffi_new(Expression, expr.term_count) --[[@as kiwi.Expression]]
ffi_copy(ret.terms_, expr.terms_, SIZEOF_TERM * expr.term_count)
ret.constant = constant
ret.term_count = expr.term_count
ljkiwi.kiwi_expression_retain(ret)
return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]]
end
---@param expr kiwi.Expression
---@param constant number
---@nodiscard
local function new_expr_constant(expr, constant)
local ret = ffi_new(Expression, expr.term_count)
ljkiwi.kiwi_expression_set_constant(expr, constant, ret)
return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]]
end

---@return number
Expand Down
10 changes: 8 additions & 2 deletions rjkiwi/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl KiwiExpression {
}

unsafe fn try_from_raw<'a>(e: *const KiwiExpressionPtr) -> Option<&'a Self> {
if e.is_null() {
if e.is_null() || (*e).term_count < 0 {
None
} else {
Some(Self::from_raw(e))
Expand All @@ -56,10 +56,15 @@ pub unsafe extern "C" fn kiwi_expression_add_term(
out: *mut KiwiExpressionPtr,
) {
let Some(e) = KiwiExpression::try_from_raw(e) else {
(*out).term_count = 0;
return;
};
if e.terms_.len() >= c_int::MAX as usize {
(*out).term_count = 0;
return;
}

let n_terms = (e.terms_.len() + 1).min(c_int::MAX as usize);
let n_terms = e.terms_.len() + 1;

let out = core::slice::from_raw_parts_mut(out as *mut (), n_terms) as *mut [()]
as *mut KiwiExpression;
Expand All @@ -85,6 +90,7 @@ pub unsafe extern "C" fn kiwi_expression_set_constant(
out: *mut KiwiExpressionPtr,
) {
let Some(e) = KiwiExpression::try_from_raw(e) else {
(*out).term_count = 0;
return;
};

Expand Down
9 changes: 8 additions & 1 deletion rjkiwi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub mod solver;
mod util;
pub mod var;

mod mem {
pub mod mem {
use std::{
ffi::{c_char, CString},
ptr::NonNull,
Expand All @@ -20,3 +20,10 @@ mod mem {
}
}
}

pub mod ffi {
pub use crate::expr::*;
pub use crate::mem::*;
pub use crate::solver::*;
pub use crate::var::*;
}

0 comments on commit 27ac5df

Please sign in to comment.