From 4585e7db8216d34efab6ec30e039deb20f41e011 Mon Sep 17 00:00:00 2001 From: bensku Date: Sat, 19 Oct 2024 21:09:31 +0300 Subject: [PATCH] lua4jvm: Implement VARIABLE_TRACING pass and re-enable upvalue typing This should make local functions efficient enough for numerical code! While testing said numerical code, it turned out that lua4jvm's support for using ints for numerical code was, especially mixed with doubles, quite buggy. The bugs are now fixed. --- .../java/fi/benjami/code4jvm/lua/LuaVm.java | 4 +- .../code4jvm/lua/compiler/CompilerPass.java | 10 +++-- .../lua/compiler/FunctionCompiler.java | 1 - .../code4jvm/lua/compiler/IrCompiler.java | 3 +- .../code4jvm/lua/compiler/LuaContext.java | 5 +++ .../code4jvm/lua/compiler/VariableFlag.java | 2 +- .../code4jvm/lua/ffi/JavaFunction.java | 21 ++++++++-- .../code4jvm/lua/ir/DebugInfoNode.java | 5 +++ .../fi/benjami/code4jvm/lua/ir/IrNode.java | 4 ++ .../fi/benjami/code4jvm/lua/ir/LuaBlock.java | 7 ++++ .../fi/benjami/code4jvm/lua/ir/LuaType.java | 6 ++- .../code4jvm/lua/ir/UpvalueTemplate.java | 8 +++- .../code4jvm/lua/ir/expr/ArithmeticExpr.java | 23 ++++++++--- .../lua/ir/expr/FunctionDeclExpr.java | 4 +- .../code4jvm/lua/ir/expr/NegateExpr.java | 16 ++++++-- .../code4jvm/lua/ir/stmt/IfBlockStmt.java | 7 ++++ .../code4jvm/lua/ir/stmt/IteratorForStmt.java | 12 ++++++ .../code4jvm/lua/ir/stmt/LoopStmt.java | 6 +++ .../lua/ir/stmt/SetVariablesStmt.java | 10 +++-- .../code4jvm/lua/linker/LuaLinker.java | 5 +-- .../code4jvm/lua/test/BasicLibTest.java | 30 ++++++++++---- .../code4jvm/lua/test/BinaryOpTest.java | 22 +++++++---- .../benjami/code4jvm/lua/test/BinderTest.java | 12 ++++-- .../code4jvm/lua/test/FunctionTest.java | 4 +- .../benjami/code4jvm/lua/test/LinkerTest.java | 39 ++++++++++++++++++- .../benjami/code4jvm/lua/test/LuaVmTest.java | 15 +++---- .../code4jvm/lua/test/MultiValTest.java | 4 +- .../code4jvm/lua/test/UnaryOpTest.java | 15 +++++-- 28 files changed, 236 insertions(+), 64 deletions(-) diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/LuaVm.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/LuaVm.java index d6945d1..a9c00d9 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/LuaVm.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/LuaVm.java @@ -101,8 +101,8 @@ public LuaModule compile(String chunk) { public LuaFunction load(LuaModule module, LuaTable env) { // Instantiate the module var type = LuaType.function( - // TODO _ENV mutability tracking - List.of(new UpvalueTemplate(module.env(), LuaType.TABLE)), + // TODO _ENV mutability tracking - we need LuaContext, which is a bit tricky here... + List.of(new UpvalueTemplate(module.env(), LuaType.UNKNOWN, true)), List.of(), module.root(), module.name(), diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/CompilerPass.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/CompilerPass.java index f166152..16e0ae5 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/CompilerPass.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/CompilerPass.java @@ -8,13 +8,17 @@ public enum CompilerPass { IR_GEN, /** - * In this phase, variables are traced to determine their mutability and - * other properties. TODO not yet implemented + * Return tracking, to determine if lua4jvm has to insert empty returns. + */ + RETURN_TRACKING, + + /** + * Variable flagging, based on e.g. their mutability. */ VARIABLE_TRACING, /** - * In analysis phase, types that can be statically inferred are inferred + * In analysis pass, types that can be statically inferred are inferred * to generate better code later. */ TYPE_ANALYSIS, diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/FunctionCompiler.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/FunctionCompiler.java index 8dc5c1b..3dded25 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/FunctionCompiler.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/FunctionCompiler.java @@ -63,7 +63,6 @@ public static MethodHandle callTarget(LuaType[] argTypes, LuaFunction function, // Compile and load the function code, or use something that is already cached var compiledFunc = function.type().specializations().computeIfAbsent(cacheKey, t -> { - CompilerPass.setCurrent(CompilerPass.TYPE_ANALYSIS); var ctx = LuaContext.forFunction(function.owner(), function.type(), truncateReturn, argTypes); CompilerPass.setCurrent(CompilerPass.CODEGEN); diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/IrCompiler.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/IrCompiler.java index 86d9f2d..9245e4d 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/IrCompiler.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/IrCompiler.java @@ -448,7 +448,8 @@ public IrNode visitStringConcat(StringConcatContext ctx) { @Override public IrNode visitNumberLiteral(NumberLiteralContext ctx) { var value = Double.valueOf(ctx.Numeral().getText()); - return new LuaConstant(value.intValue() == value ? value.intValue() : value); + // Use Math.rint() to handle very large doubles safely + return Math.rint(value) == value ? new LuaConstant(value.intValue()) : new LuaConstant(value); } @Override diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/LuaContext.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/LuaContext.java index 1855bd6..57e6705 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/LuaContext.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/LuaContext.java @@ -45,8 +45,13 @@ public static LuaContext forFunction(LuaVm vm, LuaType.Function type, boolean tr ctx.setFlag(arg, VariableFlag.ASSIGNED); // JVM assigns arguments to these } + // Do variable flagging BEFORE type analysis, we need that mutability information + type.body().flagVariables(ctx); + // Compute types of local variables and the return type + CompilerPass.setCurrent(CompilerPass.TYPE_ANALYSIS); type.body().outputType(ctx); + CompilerPass.setCurrent(null); return ctx; } diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/VariableFlag.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/VariableFlag.java index 9a174ab..df62ee7 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/VariableFlag.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/VariableFlag.java @@ -17,7 +17,7 @@ public enum VariableFlag { /** * Variable is mutable; that is, it is assigned to at least twice. */ - MUTABLE(CompilerPass.TYPE_ANALYSIS) + MUTABLE(CompilerPass.VARIABLE_TRACING) ; /** diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ffi/JavaFunction.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ffi/JavaFunction.java index b28984b..56117bd 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ffi/JavaFunction.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ffi/JavaFunction.java @@ -82,7 +82,9 @@ public Target matchToArgs(LuaType[] argTypes, String intrinsicId) { if (target.intrinsicId != null && !target.intrinsicId.equals(intrinsicId)) { continue; // Intrinsic not allowed by caller } - if (checkArgs(target, argTypes) == MatchResult.SUCCESS) { + var result = checkArgs(target, argTypes); + if (result == MatchResult.SUCCESS || result == MatchResult.INT_DOUBLE_CAST_NEEDED) { + // Linker calls MethodHandle#cast(...), which casts ints to doubles if needed return target; } } @@ -99,7 +101,8 @@ private enum MatchResult { SUCCESS, TOO_FEW_ARGS, ARG_TYPE_MISMATCH, - VARARGS_TYPE_MISMATCH + VARARGS_TYPE_MISMATCH, + INT_DOUBLE_CAST_NEEDED } private MatchResult checkArgs(Target target, LuaType[] argTypes) { @@ -113,12 +116,18 @@ private MatchResult checkArgs(Target target, LuaType[] argTypes) { } // Check types of arguments + var intDoubleCast = false; for (var i = 0; i < requiredArgs; i++) { var arg = target.arguments.get(i); if (!arg.type.isAssignableFrom(argTypes[i])) { // Allow nil instead of expected type if nullability is allowed - if (!arg.nullable ||!argTypes[i].equals(LuaType.NIL)) { - return MatchResult.ARG_TYPE_MISMATCH; + if (!arg.nullable || !argTypes[i].equals(LuaType.NIL)) { + if (argTypes[i].equals(LuaType.INTEGER) && arg.type.equals(LuaType.FLOAT)) { + // We'll need to cast ints to doubles using MethodHandle magic + intDoubleCast = true; + } else { + return MatchResult.ARG_TYPE_MISMATCH; + } } } } @@ -133,6 +142,10 @@ private MatchResult checkArgs(Target target, LuaType[] argTypes) { } } + if (intDoubleCast) { + return MatchResult.INT_DOUBLE_CAST_NEEDED; + } + return MatchResult.SUCCESS; } } diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/DebugInfoNode.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/DebugInfoNode.java index cf1e60e..83cb221 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/DebugInfoNode.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/DebugInfoNode.java @@ -31,6 +31,11 @@ public boolean hasReturn() { return node.hasReturn(); } + @Override + public void flagVariables(LuaContext ctx) { + node.flagVariables(ctx); + } + @Override public IrNode concreteNode() { return node.concreteNode(); diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/IrNode.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/IrNode.java index 95cd77e..49168bf 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/IrNode.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/IrNode.java @@ -14,6 +14,10 @@ default boolean hasReturn() { return false; } + default void flagVariables(LuaContext ctx) { + // No-op + } + default IrNode concreteNode() { return this; } diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/LuaBlock.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/LuaBlock.java index 7a3d7a5..c7eb466 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/LuaBlock.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/LuaBlock.java @@ -39,4 +39,11 @@ public boolean hasReturn() { return false; } + @Override + public void flagVariables(LuaContext ctx) { + for (var node : nodes) { + node.flagVariables(ctx); + } + } + } diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/LuaType.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/LuaType.java index c2581f4..26550db 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/LuaType.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/LuaType.java @@ -11,6 +11,7 @@ import fi.benjami.code4jvm.Value; import fi.benjami.code4jvm.lua.compiler.CompiledFunction; import fi.benjami.code4jvm.lua.compiler.CompiledShape; +import fi.benjami.code4jvm.lua.compiler.CompilerPass; import fi.benjami.code4jvm.lua.compiler.FunctionCompiler; import fi.benjami.code4jvm.lua.compiler.ShapeTypes; import fi.benjami.code4jvm.lua.ir.stmt.ReturnStmt; @@ -244,7 +245,10 @@ public static Tuple tuple(LuaType... types) { public static Function function(List upvalues, List args, LuaBlock body, String moduleName, String name) { - if (!body.hasReturn()) { + CompilerPass.setCurrent(CompilerPass.RETURN_TRACKING); + var hasReturn = body.hasReturn(); + CompilerPass.setCurrent(null); + if (!hasReturn) { // If the function doesn't always return, insert return nil at end var nodes = new ArrayList<>(body.nodes()); nodes.add(new ReturnStmt(List.of())); diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/UpvalueTemplate.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/UpvalueTemplate.java index c8ffc52..3317b40 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/UpvalueTemplate.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/UpvalueTemplate.java @@ -15,5 +15,11 @@ public record UpvalueTemplate( * {@link LuaFunction#upvalueTypes final types} that are known after * the function has been instantiated, this may be unknown. */ - LuaType type + LuaType type, + + /** + * Whether or not the upvalue variable is assigned to after its initial + * assignment. + */ + boolean mutable ) {} diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/ArithmeticExpr.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/ArithmeticExpr.java index 9c908de..92901b5 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/ArithmeticExpr.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/ArithmeticExpr.java @@ -32,22 +32,29 @@ public record ArithmeticExpr( ) implements IrNode { private static final CallTarget MATH_POW = CallTarget.staticMethod(Type.of(Math.class), Type.DOUBLE, "pow", Type.DOUBLE, Type.DOUBLE); - private static final CallTarget MATH_ABS = CallTarget.staticMethod(Type.of(Math.class), Type.DOUBLE, "abs", Type.DOUBLE); - private static final CallTarget FLOOR_DIV = CallTarget.staticMethod(Type.of(ArithmeticExpr.class), Type.DOUBLE, "floorDivide", Type.DOUBLE, Type.DOUBLE); + private static final CallTarget MATH_ABS_INT = CallTarget.staticMethod(Type.of(Math.class), Type.INT, "abs", Type.INT); + private static final CallTarget MATH_ABS_DOUBLE = CallTarget.staticMethod(Type.of(Math.class), Type.DOUBLE, "abs", Type.DOUBLE); + private static final CallTarget FLOOR_DIV_INTS = CallTarget.staticMethod(Type.of(ArithmeticExpr.class), Type.INT, "floorDivide", Type.INT, Type.INT); + private static final CallTarget FLOOR_DIV_DOUBLES = CallTarget.staticMethod(Type.of(ArithmeticExpr.class), Type.DOUBLE, "floorDivide", Type.DOUBLE, Type.DOUBLE); private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup(); public enum Kind { - POWER(MATH_POW::call, "power", "__pow"), + POWER((lhs, rhs) -> { + // Math.pow() does not have integer variant + return MATH_POW.call(lhs.cast(Type.DOUBLE), rhs.cast(Type.DOUBLE)); + }, "power", "__pow"), MULTIPLY(Arithmetic::multiply, "multiply", "__mul"), DIVIDE((lhs, rhs) -> { // Lua uses float division unless integer division is explicitly request (see below) return Arithmetic.divide(lhs.cast(Type.DOUBLE), rhs.cast(Type.DOUBLE)); }, "divide", "__div"), - FLOOR_DIVIDE(FLOOR_DIV::call, "floorDivide", "__idiv"), + FLOOR_DIVIDE((lhs, rhs) + -> lhs.type().equals(Type.INT) ? FLOOR_DIV_INTS.call(lhs, rhs) : FLOOR_DIV_DOUBLES.call(lhs, rhs), + "floorDivide", "__idiv"), MODULO((lhs, rhs) -> (block -> { // Lua expects modulo to be always positive; Java's remainder can return negative values var remainder = block.add(Arithmetic.remainder(lhs, rhs)); - return block.add(MATH_ABS.call(remainder)); + return block.add(remainder.type().equals(Type.INT) ? MATH_ABS_INT.call(remainder) : MATH_ABS_DOUBLE.call(remainder)); }), "modulo", "__mod"), ADD(Arithmetic::add, "add", "__add"), SUBTRACT(Arithmetic::subtract, "subtract", "__sub"); @@ -156,6 +163,12 @@ public Value emit(LuaContext ctx, Block block) { var rhsValue = rhs.emit(ctx, block); if (outputType(ctx).isNumber()) { // Both arguments are known to be numbers; emit arithmetic operation directly + // Just make sure that if either side is double, the other side is too + if (lhsValue.type().equals(Type.INT) && rhsValue.type().equals(Type.DOUBLE)) { + lhsValue = lhsValue.cast(Type.DOUBLE); + } else if (rhsValue.type().equals(Type.INT) && lhsValue.type().equals(Type.DOUBLE)) { + rhsValue = rhsValue.cast(Type.DOUBLE); + } return block.add(kind.directEmitter.apply(lhsValue, rhsValue)); } else { // Types are unknown compile-time; use invokedynamic diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/FunctionDeclExpr.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/FunctionDeclExpr.java index 6f4c42d..1f8e249 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/FunctionDeclExpr.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/FunctionDeclExpr.java @@ -7,6 +7,7 @@ import fi.benjami.code4jvm.Value; import fi.benjami.code4jvm.block.Block; import fi.benjami.code4jvm.lua.compiler.LuaContext; +import fi.benjami.code4jvm.lua.compiler.VariableFlag; import fi.benjami.code4jvm.lua.ir.IrNode; import fi.benjami.code4jvm.lua.ir.LuaBlock; import fi.benjami.code4jvm.lua.ir.LuaLocalVar; @@ -59,7 +60,8 @@ public Value emit(LuaContext ctx, Block block) { public LuaType.Function outputType(LuaContext ctx) { // Upvalue template has the variable INSIDE declared function, with type of OUTSIDE variable var upvalueTemplates = upvalues.stream() - .map(upvalue -> new UpvalueTemplate(upvalue, ctx.variableType(upvalue))) + .map(upvalue -> new UpvalueTemplate(upvalue, ctx.hasFlag(upvalue, VariableFlag.MUTABLE) + ? LuaType.UNKNOWN : ctx.variableType(upvalue), ctx.hasFlag(upvalue, VariableFlag.MUTABLE))) .toList(); return ctx.cached(this, LuaType.function(upvalueTemplates, arguments, body, moduleName, name)); } diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/NegateExpr.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/NegateExpr.java index 9d075ed..d3abec2 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/NegateExpr.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/NegateExpr.java @@ -18,19 +18,22 @@ public record NegateExpr(IrNode expr) implements IrNode { - private static final MethodHandle NEGATE; + private static final MethodHandle NEGATE_DOUBLE, NEGATE_INT; private static final DynamicTarget TARGET; static { var lookup = MethodHandles.lookup(); try { - NEGATE = lookup.findStatic(NegateExpr.class, "negate", MethodType.methodType(double.class, Object.class, double.class)); + NEGATE_DOUBLE = lookup.findStatic(NegateExpr.class, "negate", MethodType.methodType(double.class, Object.class, double.class)); + NEGATE_INT = lookup.findStatic(NegateExpr.class, "negate", MethodType.methodType(int.class, Object.class, int.class)); } catch (NoSuchMethodException | IllegalAccessException e) { throw new AssertionError(e); } - TARGET = UnaryOp.newTarget(new UnaryOp.Path[] {new UnaryOp.Path(Double.class, NEGATE)}, "__unm", - (val) -> new LuaException("attempted to negate a non-number value")); + TARGET = UnaryOp.newTarget(new UnaryOp.Path[] { + new UnaryOp.Path(Double.class, NEGATE_DOUBLE), + new UnaryOp.Path(Integer.class, NEGATE_INT) + }, "__unm", (val) -> new LuaException("attempted to negate a non-number value")); } @SuppressWarnings("unused") // MethodHandle @@ -38,6 +41,11 @@ private static double negate(Object callable, double value) { return -value; } + @SuppressWarnings("unused") // MethodHandle + private static int negate(Object callable, int value) { + return -value; + } + @Override public Value emit(LuaContext ctx, Block block) { var value = expr.emit(ctx, block); diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/IfBlockStmt.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/IfBlockStmt.java index 2975d88..3e51800 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/IfBlockStmt.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/IfBlockStmt.java @@ -64,5 +64,12 @@ public boolean hasReturn() { // If we don't have fallback, we might not return return fallback != null ? fallback.hasReturn() : false; } + + @Override + public void flagVariables(LuaContext ctx) { + for (var branch : branches) { + branch.body.flagVariables(ctx); + } + } } diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/IteratorForStmt.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/IteratorForStmt.java index 87b75db..58db3b8 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/IteratorForStmt.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/IteratorForStmt.java @@ -11,6 +11,7 @@ import fi.benjami.code4jvm.call.CallTarget; import fi.benjami.code4jvm.lua.compiler.LoopRef; import fi.benjami.code4jvm.lua.compiler.LuaContext; +import fi.benjami.code4jvm.lua.compiler.VariableFlag; import fi.benjami.code4jvm.lua.ir.IrNode; import fi.benjami.code4jvm.lua.ir.LuaBlock; import fi.benjami.code4jvm.lua.ir.LuaLocalVar; @@ -172,4 +173,15 @@ public LuaType outputType(LuaContext ctx) { public boolean hasReturn() { return false; // The loop might run for zero iterations } + + @Override + public void flagVariables(LuaContext ctx) { + for (var loopVar : loopVars) { + ctx.setFlag(loopVar, VariableFlag.MUTABLE); + } + for (var it : iterable) { + it.flagVariables(ctx); + } + body.flagVariables(ctx); + } } diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/LoopStmt.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/LoopStmt.java index ab2e801..d75b9ef 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/LoopStmt.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/LoopStmt.java @@ -61,4 +61,10 @@ public boolean hasReturn() { // If there is condition before first run, it might not return return kind != Kind.REPEAT_UNTIL ? false : body.hasReturn(); } + + @Override + public void flagVariables(LuaContext ctx) { + condition.flagVariables(ctx); + body.flagVariables(ctx); + } } diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/SetVariablesStmt.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/SetVariablesStmt.java index 95dd403..9306081 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/SetVariablesStmt.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/SetVariablesStmt.java @@ -130,7 +130,6 @@ public LuaType outputType(LuaContext ctx) { for (var i = 0; i < Math.min(normalSources, targets.size()); i++) { var target = targets.get(i); ctx.recordType(target, sources.get(i).outputType(ctx)); - markMutable(ctx, target); } if (spread) { @@ -145,19 +144,24 @@ public LuaType outputType(LuaContext ctx) { // anything else -> first multiValType, rest NIL var target = targets.get(i); ctx.recordType(target, LuaType.UNKNOWN); - markMutable(ctx, target); } } else { // If there are leftover targets, set them to nil for (var i = normalSources; i < targets.size(); i++) { var target = targets.get(i); ctx.recordType(target, LuaType.NIL); - markMutable(ctx, target); } } return LuaType.NIL; } + @Override + public void flagVariables(LuaContext ctx) { + for (var target : targets) { + markMutable(ctx, target); + } + } + private void markMutable(LuaContext ctx, LuaVariable target) { if (target instanceof LuaLocalVar localVar) { if (ctx.hasFlag(localVar, VariableFlag.ASSIGNED)) { diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/linker/LuaLinker.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/linker/LuaLinker.java index ff7cdb5..664c190 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/linker/LuaLinker.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/linker/LuaLinker.java @@ -116,12 +116,9 @@ public static LuaCallTarget linkCall(LuaCallSite meta, Object callable, Object.. specializedTypes = Arrays.copyOf(specializedTypes, function.type().acceptedArgs().size()); Arrays.fill(specializedTypes, compiledTypes.length, specializedTypes.length, LuaType.UNKNOWN); } - - // FIXME upvalue typing is incorrect for mutable upvalues until VARIABLE_TRACING pass is implemented - var useUpvalueTypes = false; // checkTarget // Truncate multival return if site doesn't want to spread - target = FunctionCompiler.callTarget(specializedTypes, function, useUpvalueTypes, + target = FunctionCompiler.callTarget(specializedTypes, function, checkTarget, !meta.options.spreadResults()); guard = checkTarget ? TARGET_HAS_CHANGED.bindTo(function) : PROTOTYPE_HAS_CHANGED.bindTo(function.type()); diff --git a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/BasicLibTest.java b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/BasicLibTest.java index 1e55bb5..6de31f1 100644 --- a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/BasicLibTest.java +++ b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/BasicLibTest.java @@ -51,7 +51,8 @@ public void print() throws Throwable { var vm = new LuaVm(VmOptions.builder().stdOut(out).build()); vm.execute("print(1, 2, 3)"); - assertEquals("1.0\t2.0\t3.0\n", new String(bas.toByteArray())); + vm.execute("print(1.1, 2.2, 3.3)"); + assertEquals("1\t2\t3\n1.1\t2.2\t3.3\n", new String(bas.toByteArray())); } @Test @@ -121,9 +122,9 @@ public void pairs() throws Throwable { end """); var out = (LuaTable) vm.globals().get("out"); - assertEquals(1d, out.get("foo")); - assertEquals(2d, out.get("bar")); - assertEquals(3d, out.get("baz")); + assertEquals(1, out.get("foo")); + assertEquals(2, out.get("bar")); + assertEquals(3, out.get("baz")); } @Test @@ -136,9 +137,24 @@ public void pairs2() throws Throwable { end """); var out = (LuaTable) vm.globals().get("out"); - assertEquals(1d, out.get("foo")); - assertEquals(2d, out.get("bar")); - assertEquals(3d, out.get("baz")); + assertEquals(1, out.get("foo")); + assertEquals(2, out.get("bar")); + assertEquals(3, out.get("baz")); + } + + @Test + public void pairs3() throws Throwable { + vm.execute(""" + tbl = {"foo", "bar", "baz", foo = 1.1, bar = 2.2, baz = 3.3} + out = {} + for k,v in pairs(tbl) do + out[k] = v + end + """); + var out = (LuaTable) vm.globals().get("out"); + assertEquals(1.1, out.get("foo")); + assertEquals(2.2, out.get("bar")); + assertEquals(3.3, out.get("baz")); } @Test diff --git a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/BinaryOpTest.java b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/BinaryOpTest.java index 2ce71bc..40b835f 100644 --- a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/BinaryOpTest.java +++ b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/BinaryOpTest.java @@ -7,6 +7,7 @@ import org.junit.jupiter.api.Test; +import fi.benjami.code4jvm.internal.DebugOptions; import fi.benjami.code4jvm.lua.LuaVm; import fi.benjami.code4jvm.lua.runtime.LuaFunction; import fi.benjami.code4jvm.lua.runtime.LuaTable; @@ -55,18 +56,25 @@ public void stringConcatMetatables() throws Throwable { @Test public void simpleMath() throws Throwable { assertEquals(125d, vm.execute("return 5 ^ 3")); - assertEquals(15d, vm.execute("return 5 * 3")); + assertEquals(15, vm.execute("return 5 * 3")); + assertEquals(17.5, vm.execute("return 5 * 3.5")); assertEquals(2.5, vm.execute("return 5 / 2")); - assertEquals(2d, vm.execute("return 5 // 2")); - assertEquals(2d, vm.execute("return 10 % 4")); - assertEquals(2d, vm.execute("return -10 % 4")); - assertEquals(8d, vm.execute("return 5 + 3")); - assertEquals(2d, vm.execute("return 5 - 3")); + assertEquals(2, vm.execute("return 5 // 2")); + assertEquals(2, vm.execute("return 10 % 4")); + assertEquals(2, vm.execute("return -10 % 4")); + assertEquals(8, vm.execute("return 5 + 3")); + assertEquals(8.4, vm.execute("return 5.4 + 3")); + assertEquals(2, vm.execute("return 5 - 3")); + assertEquals(1.1, vm.execute("return 5 - 3.9")); - assertEquals(2d, vm.execute(""" + assertEquals(2, vm.execute(""" lhs = 5 return lhs - 3 """)); + assertEquals(2.5, vm.execute(""" + lhs = 5.5 + return lhs - 3 + """)); } @Test diff --git a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/BinderTest.java b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/BinderTest.java index 4e9ebdc..e2aa65a 100644 --- a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/BinderTest.java +++ b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/BinderTest.java @@ -148,7 +148,7 @@ public void bindFunctions() throws Throwable { private static class LuaVarargs { @LuaExport("checkArgs") public static boolean checkArgs(Object... args) { - return Arrays.equals(args, new Object[] {"foo", "bar", 3d, "baz"}); + return Arrays.equals(args, new Object[] {"foo", "bar", 3, "baz"}); } @LuaExport("returnArgs") @@ -175,11 +175,17 @@ public void varargs() throws Throwable { assertTrue((boolean) vm.execute(""" return checkArgs("foo", "bar", 3, "baz") """)); - assertArrayEquals(new Object[] {"foo", "bar", 3d, "baz"}, (Object[]) vm.execute(""" + assertArrayEquals(new Object[] {"foo", "bar", 3, "baz"}, (Object[]) vm.execute(""" return returnArgs("foo", "bar", 3, "baz") """)); - assertArrayEquals(new Object[] {"foo", "bar", 3d, "baz"}, (Object[]) vm.execute(""" + assertArrayEquals(new Object[] {"foo", "bar", 3.1, "baz"}, (Object[]) vm.execute(""" + return returnArgs("foo", "bar", 3.1, "baz") + """)); + assertArrayEquals(new Object[] {"foo", "bar", 3, "baz"}, (Object[]) vm.execute(""" return returnArgs2(123, "foo", "bar", 3, "baz") """)); + assertArrayEquals(new Object[] {"foo", "bar", 3.1, "baz"}, (Object[]) vm.execute(""" + return returnArgs2(123, "foo", "bar", 3.1, "baz") + """)); } } diff --git a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/FunctionTest.java b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/FunctionTest.java index c67df85..b5d1199 100644 --- a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/FunctionTest.java +++ b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/FunctionTest.java @@ -59,7 +59,7 @@ public void upvalues() throws Throwable { var a = new LuaLocalVar("a"); var b = new LuaLocalVar("b"); var type = LuaType.function( - List.of(new UpvalueTemplate(a, LuaType.FLOAT)), + List.of(new UpvalueTemplate(a, LuaType.FLOAT, false)), List.of(b), new LuaBlock(List.of(new ReturnStmt(List.of( new ArithmeticExpr(new VariableExpr(a), ArithmeticExpr.Kind.ADD, new VariableExpr(b)) @@ -109,7 +109,7 @@ public void declareFunction() throws Throwable { var insideB = new LuaLocalVar("b"); var c = new LuaLocalVar("c"); var type = LuaType.function( - List.of(new UpvalueTemplate(a, LuaType.FLOAT)), + List.of(new UpvalueTemplate(a, LuaType.FLOAT, false)), List.of(b), new LuaBlock(List.of( new ReturnStmt(List.of(new FunctionDeclExpr( diff --git a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/LinkerTest.java b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/LinkerTest.java index b7711c8..ab04ca9 100644 --- a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/LinkerTest.java +++ b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/LinkerTest.java @@ -1,6 +1,7 @@ package fi.benjami.code4jvm.lua.test; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodType; @@ -42,9 +43,9 @@ public void knownTypes() throws Throwable { local function f(a, b) return a + b end - return f(1, 2) + return f(1.2, 2.1) """); - assertEquals(3d, result); + assertEquals(3.3, result); assertEquals(1, trace.metadata.linkageCount); assertEquals(MethodType.methodType(double.class, Object.class, double.class, double.class), getOnlySpecialization().bindTo(null).type()); @@ -121,6 +122,40 @@ local function g(a, b) assertEquals(1, trace.stableTargets); } + @Test + public void unboxedMath() throws Throwable { + var func = (LuaFunction) vm.execute(""" + local function sum(a, b) + return a + b + end + return function (a, b, c, d) + return sum(a, b) + sum(c, d) + end + """); + assertEquals(10, func.call(1, 2, 3, 4)); + assertEquals(2, trace.stableTargets); + assertFalse(trace.metadata.hasUnknownTypes); + } + + @Test + public void unboxedMath2() throws Throwable { + var func = (LuaFunction) vm.execute(""" + local function mul(a, b) + return a * b + end + + local function sum(a, b) + return mul(a, 3) + b + end + return function (a, b, c, d) + return sum(a, b) + sum(c, d) + end + """); + assertEquals(18, func.call(1, 2, 3, 4)); + assertEquals(3, trace.stableTargets); + assertFalse(trace.metadata.hasUnknownTypes); + } + @AfterEach public void cleanup() { LuaDebugOptions.linkerTrace = null; diff --git a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/LuaVmTest.java b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/LuaVmTest.java index 6640167..01d1afd 100644 --- a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/LuaVmTest.java +++ b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/LuaVmTest.java @@ -31,7 +31,8 @@ public void emptyModule() throws Throwable { public void constantReturns() throws Throwable { assertEquals("abc", vm.execute("return 'abc'")); assertEquals("abc", vm.execute("return \"abc\"")); - assertEquals(10d, vm.execute("return 10")); + assertEquals(10, vm.execute("return 10")); + assertEquals(10.5, vm.execute("return 10.5")); assertEquals(true, vm.execute("return true")); assertEquals(false, vm.execute("return false")); } @@ -64,9 +65,9 @@ public void createTable() throws Throwable { var list = (LuaTable) vm.execute(""" return {1, 2, 3} """); - assertEquals(1d, list.get(1d)); - assertEquals(2d, list.get(2d)); - assertEquals(3d, list.get(3d)); + assertEquals(1, list.get(1d)); + assertEquals(2, list.get(2d)); + assertEquals(3, list.get(3d)); var foo = (LuaTable) vm.execute(""" return {foo = "bar"} @@ -193,14 +194,14 @@ public void ifBlock() throws Throwable { @Test public void conditionalLoops() throws Throwable { - assertEquals(10d, vm.execute(""" + assertEquals(10, vm.execute(""" local a = 0 while a < 10 do a = a + 1 end return a """)); - assertEquals(10d, vm.execute(""" + assertEquals(10, vm.execute(""" local a = 0 repeat a = a + 1 @@ -298,6 +299,6 @@ local function count() end return count(), count(), count() """); - assertArrayEquals(new Object[] {1d, 2d, 3d}, result); + assertArrayEquals(new Object[] {1, 2, 3}, result); } } diff --git a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/MultiValTest.java b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/MultiValTest.java index 9f6af3c..f938f9b 100644 --- a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/MultiValTest.java +++ b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/MultiValTest.java @@ -13,7 +13,7 @@ public class MultiValTest { @Test public void returnToJava() throws Throwable { var vm = new LuaVm(); - assertArrayEquals(new Object[] {"foo", 3d, "bar", "baz"}, (Object[]) vm.execute(""" + assertArrayEquals(new Object[] {"foo", 3, "bar", "baz"}, (Object[]) vm.execute(""" return "foo", 3, "bar", "baz" """)); } @@ -28,7 +28,7 @@ local function stuff() a, b, c, d = stuff() -- assign to globals """); assertEquals("foo", vm.globals().get("a")); - assertEquals(3d, vm.globals().get("b")); + assertEquals(3, vm.globals().get("b")); assertEquals("bar", vm.globals().get("c")); assertEquals("baz", vm.globals().get("d")); } diff --git a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/UnaryOpTest.java b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/UnaryOpTest.java index eed7db2..f0b58ca 100644 --- a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/UnaryOpTest.java +++ b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/UnaryOpTest.java @@ -15,9 +15,18 @@ public class UnaryOpTest { @Test public void negateNumbers() throws Throwable { - assertEquals(-10d, vm.execute("return -10")); - assertEquals(10d, vm.execute("return -(-10)")); - assertEquals(-10d, vm.execute(""" + // Integers + assertEquals(-10.5, vm.execute("return -10.5")); + assertEquals(10.5, vm.execute("return -(-10.5)")); + assertEquals(-10.5, vm.execute(""" + ten = 10.5 + return -ten + """)); + + // Doubles + assertEquals(-10, vm.execute("return -10")); + assertEquals(10, vm.execute("return -(-10)")); + assertEquals(-10, vm.execute(""" ten = 10 return -ten """));