Skip to content

Commit

Permalink
lua4jvm: Implement VARIABLE_TRACING pass and re-enable upvalue typing
Browse files Browse the repository at this point in the history
This should make local functions efficient enough for numerical code!

While testing said numerical code, it turned out that lua4jvm's support for
using ints for numerical code was, especially mixed with doubles, quite
buggy. The bugs are now fixed.
  • Loading branch information
bensku committed Oct 19, 2024
1 parent 88033d3 commit 4585e7d
Show file tree
Hide file tree
Showing 28 changed files with 236 additions and 64 deletions.
4 changes: 2 additions & 2 deletions lua4jvm/src/main/java/fi/benjami/code4jvm/lua/LuaVm.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ public LuaModule compile(String chunk) {
public LuaFunction load(LuaModule module, LuaTable env) {
// Instantiate the module
var type = LuaType.function(
// TODO _ENV mutability tracking
List.of(new UpvalueTemplate(module.env(), LuaType.TABLE)),
// TODO _ENV mutability tracking - we need LuaContext, which is a bit tricky here...
List.of(new UpvalueTemplate(module.env(), LuaType.UNKNOWN, true)),
List.of(),
module.root(),
module.name(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@ public enum CompilerPass {
IR_GEN,

/**
* In this phase, variables are traced to determine their mutability and
* other properties. TODO not yet implemented
* Return tracking, to determine if lua4jvm has to insert empty returns.
*/
RETURN_TRACKING,

/**
* Variable flagging, based on e.g. their mutability.
*/
VARIABLE_TRACING,

/**
* In analysis phase, types that can be statically inferred are inferred
* In analysis pass, types that can be statically inferred are inferred
* to generate better code later.
*/
TYPE_ANALYSIS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ public static MethodHandle callTarget(LuaType[] argTypes, LuaFunction function,

// Compile and load the function code, or use something that is already cached
var compiledFunc = function.type().specializations().computeIfAbsent(cacheKey, t -> {
CompilerPass.setCurrent(CompilerPass.TYPE_ANALYSIS);
var ctx = LuaContext.forFunction(function.owner(), function.type(), truncateReturn, argTypes);

CompilerPass.setCurrent(CompilerPass.CODEGEN);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,8 @@ public IrNode visitStringConcat(StringConcatContext ctx) {
@Override
public IrNode visitNumberLiteral(NumberLiteralContext ctx) {
var value = Double.valueOf(ctx.Numeral().getText());
return new LuaConstant(value.intValue() == value ? value.intValue() : value);
// Use Math.rint() to handle very large doubles safely
return Math.rint(value) == value ? new LuaConstant(value.intValue()) : new LuaConstant(value);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,13 @@ public static LuaContext forFunction(LuaVm vm, LuaType.Function type, boolean tr
ctx.setFlag(arg, VariableFlag.ASSIGNED); // JVM assigns arguments to these
}

// Do variable flagging BEFORE type analysis, we need that mutability information
type.body().flagVariables(ctx);

// Compute types of local variables and the return type
CompilerPass.setCurrent(CompilerPass.TYPE_ANALYSIS);
type.body().outputType(ctx);
CompilerPass.setCurrent(null);
return ctx;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public enum VariableFlag {
/**
* Variable is mutable; that is, it is assigned to at least twice.
*/
MUTABLE(CompilerPass.TYPE_ANALYSIS)
MUTABLE(CompilerPass.VARIABLE_TRACING)
;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ public Target matchToArgs(LuaType[] argTypes, String intrinsicId) {
if (target.intrinsicId != null && !target.intrinsicId.equals(intrinsicId)) {
continue; // Intrinsic not allowed by caller
}
if (checkArgs(target, argTypes) == MatchResult.SUCCESS) {
var result = checkArgs(target, argTypes);
if (result == MatchResult.SUCCESS || result == MatchResult.INT_DOUBLE_CAST_NEEDED) {
// Linker calls MethodHandle#cast(...), which casts ints to doubles if needed
return target;
}
}
Expand All @@ -99,7 +101,8 @@ private enum MatchResult {
SUCCESS,
TOO_FEW_ARGS,
ARG_TYPE_MISMATCH,
VARARGS_TYPE_MISMATCH
VARARGS_TYPE_MISMATCH,
INT_DOUBLE_CAST_NEEDED
}

private MatchResult checkArgs(Target target, LuaType[] argTypes) {
Expand All @@ -113,12 +116,18 @@ private MatchResult checkArgs(Target target, LuaType[] argTypes) {
}

// Check types of arguments
var intDoubleCast = false;
for (var i = 0; i < requiredArgs; i++) {
var arg = target.arguments.get(i);
if (!arg.type.isAssignableFrom(argTypes[i])) {
// Allow nil instead of expected type if nullability is allowed
if (!arg.nullable ||!argTypes[i].equals(LuaType.NIL)) {
return MatchResult.ARG_TYPE_MISMATCH;
if (!arg.nullable || !argTypes[i].equals(LuaType.NIL)) {
if (argTypes[i].equals(LuaType.INTEGER) && arg.type.equals(LuaType.FLOAT)) {
// We'll need to cast ints to doubles using MethodHandle magic
intDoubleCast = true;
} else {
return MatchResult.ARG_TYPE_MISMATCH;
}
}
}
}
Expand All @@ -133,6 +142,10 @@ private MatchResult checkArgs(Target target, LuaType[] argTypes) {
}
}

if (intDoubleCast) {
return MatchResult.INT_DOUBLE_CAST_NEEDED;
}

return MatchResult.SUCCESS;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public boolean hasReturn() {
return node.hasReturn();
}

@Override
public void flagVariables(LuaContext ctx) {
node.flagVariables(ctx);
}

@Override
public IrNode concreteNode() {
return node.concreteNode();
Expand Down
4 changes: 4 additions & 0 deletions lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/IrNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ default boolean hasReturn() {
return false;
}

default void flagVariables(LuaContext ctx) {
// No-op
}

default IrNode concreteNode() {
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,11 @@ public boolean hasReturn() {
return false;
}

@Override
public void flagVariables(LuaContext ctx) {
for (var node : nodes) {
node.flagVariables(ctx);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import fi.benjami.code4jvm.Value;
import fi.benjami.code4jvm.lua.compiler.CompiledFunction;
import fi.benjami.code4jvm.lua.compiler.CompiledShape;
import fi.benjami.code4jvm.lua.compiler.CompilerPass;
import fi.benjami.code4jvm.lua.compiler.FunctionCompiler;
import fi.benjami.code4jvm.lua.compiler.ShapeTypes;
import fi.benjami.code4jvm.lua.ir.stmt.ReturnStmt;
Expand Down Expand Up @@ -244,7 +245,10 @@ public static Tuple tuple(LuaType... types) {

public static Function function(List<UpvalueTemplate> upvalues, List<LuaLocalVar> args, LuaBlock body,
String moduleName, String name) {
if (!body.hasReturn()) {
CompilerPass.setCurrent(CompilerPass.RETURN_TRACKING);
var hasReturn = body.hasReturn();
CompilerPass.setCurrent(null);
if (!hasReturn) {
// If the function doesn't always return, insert return nil at end
var nodes = new ArrayList<>(body.nodes());
nodes.add(new ReturnStmt(List.of()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,11 @@ public record UpvalueTemplate(
* {@link LuaFunction#upvalueTypes final types} that are known after
* the function has been instantiated, this may be unknown.
*/
LuaType type
LuaType type,

/**
* Whether or not the upvalue variable is assigned to after its initial
* assignment.
*/
boolean mutable
) {}
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,29 @@ public record ArithmeticExpr(
) implements IrNode {

private static final CallTarget MATH_POW = CallTarget.staticMethod(Type.of(Math.class), Type.DOUBLE, "pow", Type.DOUBLE, Type.DOUBLE);
private static final CallTarget MATH_ABS = CallTarget.staticMethod(Type.of(Math.class), Type.DOUBLE, "abs", Type.DOUBLE);
private static final CallTarget FLOOR_DIV = CallTarget.staticMethod(Type.of(ArithmeticExpr.class), Type.DOUBLE, "floorDivide", Type.DOUBLE, Type.DOUBLE);
private static final CallTarget MATH_ABS_INT = CallTarget.staticMethod(Type.of(Math.class), Type.INT, "abs", Type.INT);
private static final CallTarget MATH_ABS_DOUBLE = CallTarget.staticMethod(Type.of(Math.class), Type.DOUBLE, "abs", Type.DOUBLE);
private static final CallTarget FLOOR_DIV_INTS = CallTarget.staticMethod(Type.of(ArithmeticExpr.class), Type.INT, "floorDivide", Type.INT, Type.INT);
private static final CallTarget FLOOR_DIV_DOUBLES = CallTarget.staticMethod(Type.of(ArithmeticExpr.class), Type.DOUBLE, "floorDivide", Type.DOUBLE, Type.DOUBLE);
private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup();

public enum Kind {
POWER(MATH_POW::call, "power", "__pow"),
POWER((lhs, rhs) -> {
// Math.pow() does not have integer variant
return MATH_POW.call(lhs.cast(Type.DOUBLE), rhs.cast(Type.DOUBLE));
}, "power", "__pow"),
MULTIPLY(Arithmetic::multiply, "multiply", "__mul"),
DIVIDE((lhs, rhs) -> {
// Lua uses float division unless integer division is explicitly request (see below)
return Arithmetic.divide(lhs.cast(Type.DOUBLE), rhs.cast(Type.DOUBLE));
}, "divide", "__div"),
FLOOR_DIVIDE(FLOOR_DIV::call, "floorDivide", "__idiv"),
FLOOR_DIVIDE((lhs, rhs)
-> lhs.type().equals(Type.INT) ? FLOOR_DIV_INTS.call(lhs, rhs) : FLOOR_DIV_DOUBLES.call(lhs, rhs),
"floorDivide", "__idiv"),
MODULO((lhs, rhs) -> (block -> {
// Lua expects modulo to be always positive; Java's remainder can return negative values
var remainder = block.add(Arithmetic.remainder(lhs, rhs));
return block.add(MATH_ABS.call(remainder));
return block.add(remainder.type().equals(Type.INT) ? MATH_ABS_INT.call(remainder) : MATH_ABS_DOUBLE.call(remainder));
}), "modulo", "__mod"),
ADD(Arithmetic::add, "add", "__add"),
SUBTRACT(Arithmetic::subtract, "subtract", "__sub");
Expand Down Expand Up @@ -156,6 +163,12 @@ public Value emit(LuaContext ctx, Block block) {
var rhsValue = rhs.emit(ctx, block);
if (outputType(ctx).isNumber()) {
// Both arguments are known to be numbers; emit arithmetic operation directly
// Just make sure that if either side is double, the other side is too
if (lhsValue.type().equals(Type.INT) && rhsValue.type().equals(Type.DOUBLE)) {
lhsValue = lhsValue.cast(Type.DOUBLE);
} else if (rhsValue.type().equals(Type.INT) && lhsValue.type().equals(Type.DOUBLE)) {
rhsValue = rhsValue.cast(Type.DOUBLE);
}
return block.add(kind.directEmitter.apply(lhsValue, rhsValue));
} else {
// Types are unknown compile-time; use invokedynamic
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import fi.benjami.code4jvm.Value;
import fi.benjami.code4jvm.block.Block;
import fi.benjami.code4jvm.lua.compiler.LuaContext;
import fi.benjami.code4jvm.lua.compiler.VariableFlag;
import fi.benjami.code4jvm.lua.ir.IrNode;
import fi.benjami.code4jvm.lua.ir.LuaBlock;
import fi.benjami.code4jvm.lua.ir.LuaLocalVar;
Expand Down Expand Up @@ -59,7 +60,8 @@ public Value emit(LuaContext ctx, Block block) {
public LuaType.Function outputType(LuaContext ctx) {
// Upvalue template has the variable INSIDE declared function, with type of OUTSIDE variable
var upvalueTemplates = upvalues.stream()
.map(upvalue -> new UpvalueTemplate(upvalue, ctx.variableType(upvalue)))
.map(upvalue -> new UpvalueTemplate(upvalue, ctx.hasFlag(upvalue, VariableFlag.MUTABLE)
? LuaType.UNKNOWN : ctx.variableType(upvalue), ctx.hasFlag(upvalue, VariableFlag.MUTABLE)))
.toList();
return ctx.cached(this, LuaType.function(upvalueTemplates, arguments, body, moduleName, name));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,34 @@

public record NegateExpr(IrNode expr) implements IrNode {

private static final MethodHandle NEGATE;
private static final MethodHandle NEGATE_DOUBLE, NEGATE_INT;
private static final DynamicTarget TARGET;

static {
var lookup = MethodHandles.lookup();
try {
NEGATE = lookup.findStatic(NegateExpr.class, "negate", MethodType.methodType(double.class, Object.class, double.class));
NEGATE_DOUBLE = lookup.findStatic(NegateExpr.class, "negate", MethodType.methodType(double.class, Object.class, double.class));
NEGATE_INT = lookup.findStatic(NegateExpr.class, "negate", MethodType.methodType(int.class, Object.class, int.class));
} catch (NoSuchMethodException | IllegalAccessException e) {
throw new AssertionError(e);
}

TARGET = UnaryOp.newTarget(new UnaryOp.Path[] {new UnaryOp.Path(Double.class, NEGATE)}, "__unm",
(val) -> new LuaException("attempted to negate a non-number value"));
TARGET = UnaryOp.newTarget(new UnaryOp.Path[] {
new UnaryOp.Path(Double.class, NEGATE_DOUBLE),
new UnaryOp.Path(Integer.class, NEGATE_INT)
}, "__unm", (val) -> new LuaException("attempted to negate a non-number value"));
}

@SuppressWarnings("unused") // MethodHandle
private static double negate(Object callable, double value) {
return -value;
}

@SuppressWarnings("unused") // MethodHandle
private static int negate(Object callable, int value) {
return -value;
}

@Override
public Value emit(LuaContext ctx, Block block) {
var value = expr.emit(ctx, block);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,12 @@ public boolean hasReturn() {
// If we don't have fallback, we might not return
return fallback != null ? fallback.hasReturn() : false;
}

@Override
public void flagVariables(LuaContext ctx) {
for (var branch : branches) {
branch.body.flagVariables(ctx);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import fi.benjami.code4jvm.call.CallTarget;
import fi.benjami.code4jvm.lua.compiler.LoopRef;
import fi.benjami.code4jvm.lua.compiler.LuaContext;
import fi.benjami.code4jvm.lua.compiler.VariableFlag;
import fi.benjami.code4jvm.lua.ir.IrNode;
import fi.benjami.code4jvm.lua.ir.LuaBlock;
import fi.benjami.code4jvm.lua.ir.LuaLocalVar;
Expand Down Expand Up @@ -172,4 +173,15 @@ public LuaType outputType(LuaContext ctx) {
public boolean hasReturn() {
return false; // The loop might run for zero iterations
}

@Override
public void flagVariables(LuaContext ctx) {
for (var loopVar : loopVars) {
ctx.setFlag(loopVar, VariableFlag.MUTABLE);
}
for (var it : iterable) {
it.flagVariables(ctx);
}
body.flagVariables(ctx);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,10 @@ public boolean hasReturn() {
// If there is condition before first run, it might not return
return kind != Kind.REPEAT_UNTIL ? false : body.hasReturn();
}

@Override
public void flagVariables(LuaContext ctx) {
condition.flagVariables(ctx);
body.flagVariables(ctx);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ public LuaType outputType(LuaContext ctx) {
for (var i = 0; i < Math.min(normalSources, targets.size()); i++) {
var target = targets.get(i);
ctx.recordType(target, sources.get(i).outputType(ctx));
markMutable(ctx, target);
}

if (spread) {
Expand All @@ -145,19 +144,24 @@ public LuaType outputType(LuaContext ctx) {
// anything else -> first multiValType, rest NIL
var target = targets.get(i);
ctx.recordType(target, LuaType.UNKNOWN);
markMutable(ctx, target);
}
} else {
// If there are leftover targets, set them to nil
for (var i = normalSources; i < targets.size(); i++) {
var target = targets.get(i);
ctx.recordType(target, LuaType.NIL);
markMutable(ctx, target);
}
}
return LuaType.NIL;
}

@Override
public void flagVariables(LuaContext ctx) {
for (var target : targets) {
markMutable(ctx, target);
}
}

private void markMutable(LuaContext ctx, LuaVariable target) {
if (target instanceof LuaLocalVar localVar) {
if (ctx.hasFlag(localVar, VariableFlag.ASSIGNED)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,9 @@ public static LuaCallTarget linkCall(LuaCallSite meta, Object callable, Object..
specializedTypes = Arrays.copyOf(specializedTypes, function.type().acceptedArgs().size());
Arrays.fill(specializedTypes, compiledTypes.length, specializedTypes.length, LuaType.UNKNOWN);
}

// FIXME upvalue typing is incorrect for mutable upvalues until VARIABLE_TRACING pass is implemented
var useUpvalueTypes = false; // checkTarget

// Truncate multival return if site doesn't want to spread
target = FunctionCompiler.callTarget(specializedTypes, function, useUpvalueTypes,
target = FunctionCompiler.callTarget(specializedTypes, function, checkTarget,
!meta.options.spreadResults());
guard = checkTarget ? TARGET_HAS_CHANGED.bindTo(function)
: PROTOTYPE_HAS_CHANGED.bindTo(function.type());
Expand Down
Loading

0 comments on commit 4585e7d

Please sign in to comment.