From dc0d80b00bf2eaa35c0acfdb1fafc95a4bd0e28e Mon Sep 17 00:00:00 2001 From: bensku Date: Sun, 2 Jun 2024 20:30:00 +0300 Subject: [PATCH] lua4jvm: Add support for negation and length unary operations --- .../code4jvm/lua/compiler/IrCompiler.java | 14 +- .../code4jvm/lua/ir/expr/LengthExpr.java | 56 ++++++++ .../code4jvm/lua/ir/expr/NegateExpr.java | 57 ++++++++ .../benjami/code4jvm/lua/linker/BinaryOp.java | 1 + .../benjami/code4jvm/lua/linker/UnaryOp.java | 134 ++++++++++++++++++ .../code4jvm/lua/test/BinaryOpTest.java | 1 - .../code4jvm/lua/test/UnaryOpTest.java | 76 ++++++++++ .../code4jvm/statement/Arithmetic.java | 8 ++ 8 files changed, 336 insertions(+), 11 deletions(-) create mode 100644 lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/LengthExpr.java create mode 100644 lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/NegateExpr.java create mode 100644 lua4jvm/src/main/java/fi/benjami/code4jvm/lua/linker/UnaryOp.java create mode 100644 lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/UnaryOpTest.java diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/IrCompiler.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/IrCompiler.java index 00912bf..550cc05 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/IrCompiler.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/compiler/IrCompiler.java @@ -19,8 +19,10 @@ import fi.benjami.code4jvm.lua.ir.expr.CompareExpr; import fi.benjami.code4jvm.lua.ir.expr.FunctionCallExpr; import fi.benjami.code4jvm.lua.ir.expr.FunctionDeclExpr; +import fi.benjami.code4jvm.lua.ir.expr.LengthExpr; import fi.benjami.code4jvm.lua.ir.expr.LogicalExpr; import fi.benjami.code4jvm.lua.ir.expr.LuaConstant; +import fi.benjami.code4jvm.lua.ir.expr.NegateExpr; import fi.benjami.code4jvm.lua.ir.expr.StringConcatExpr; import fi.benjami.code4jvm.lua.ir.expr.TableInitExpr; import fi.benjami.code4jvm.lua.ir.expr.VariableExpr; @@ -334,17 +336,9 @@ public IrNode visitTrueLiteral(TrueLiteralContext ctx) { @Override public IrNode visitUnaryOp(UnaryOpContext ctx) { return switch (ctx.unop().getText()) { - case "-" -> { - // FIXME proper unary operation support (though this might still make sense as optimization) - var value = visit(ctx.exp()); - if (value instanceof LuaConstant constant && constant.type().equals(LuaType.NUMBER)) { - yield new LuaConstant(-(double) constant.value()); - } else { - throw new UnsupportedOperationException(); - } - } + case "-" -> new NegateExpr(visit(ctx.exp())); case "not" -> throw new UnsupportedOperationException(); - case "#" -> throw new UnsupportedOperationException(); + case "#" -> new LengthExpr(visit(ctx.exp())); case "~" -> throw new UnsupportedOperationException(); default -> throw new AssertionError(); }; diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/LengthExpr.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/LengthExpr.java new file mode 100644 index 0000000..27fcced --- /dev/null +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/LengthExpr.java @@ -0,0 +1,56 @@ +package fi.benjami.code4jvm.lua.ir.expr; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; + +import fi.benjami.code4jvm.Value; +import fi.benjami.code4jvm.block.Block; +import fi.benjami.code4jvm.lua.compiler.LuaContext; +import fi.benjami.code4jvm.lua.ir.IrNode; +import fi.benjami.code4jvm.lua.ir.LuaType; +import fi.benjami.code4jvm.lua.linker.CallSiteOptions; +import fi.benjami.code4jvm.lua.linker.DynamicTarget; +import fi.benjami.code4jvm.lua.linker.LuaLinker; +import fi.benjami.code4jvm.lua.linker.UnaryOp; +import fi.benjami.code4jvm.lua.runtime.LuaTable; +import fi.benjami.code4jvm.lua.stdlib.LuaException; +import fi.benjami.code4jvm.statement.Arithmetic; + +public record LengthExpr(IrNode expr) implements IrNode { + + private static final MethodHandle TABLE_LENGTH, STRING_LENGTH; + private static final DynamicTarget TARGET; + + static { + var lookup = MethodHandles.lookup(); + try { + TABLE_LENGTH = MethodHandles.dropArguments(lookup.findVirtual(LuaTable.class, "arraySize", MethodType.methodType(int.class)) + .asType(MethodType.methodType(double.class, LuaTable.class)), 0, Object.class); + STRING_LENGTH = MethodHandles.dropArguments(lookup.findVirtual(String.class, "length", MethodType.methodType(int.class)) + .asType(MethodType.methodType(double.class, String.class)), 0, Object.class); + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new AssertionError(e); + } + + TARGET = UnaryOp.newTarget(new UnaryOp.Path[] { + new UnaryOp.Path(String.class, STRING_LENGTH), + new UnaryOp.Path(LuaTable.class, TABLE_LENGTH) + }, "__len", + (val) -> new LuaException("attempted to get length of non-string or table value")); + } + + @Override + public Value emit(LuaContext ctx, Block block) { + // TODO setup direct calls if static analysis has enough information? + var value = expr.emit(ctx, block); + return block.add(LuaLinker.setupCall(ctx, CallSiteOptions.nonFunction(LuaType.UNKNOWN, LuaType.UNKNOWN), TARGET, value)); + } + + @Override + public LuaType outputType(LuaContext ctx) { + // We can't do type analysis through metatables (yet) + return expr.outputType(ctx).equals(LuaType.STRING) ? LuaType.NUMBER : LuaType.UNKNOWN; + } + +} diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/NegateExpr.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/NegateExpr.java new file mode 100644 index 0000000..fa66d69 --- /dev/null +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/expr/NegateExpr.java @@ -0,0 +1,57 @@ +package fi.benjami.code4jvm.lua.ir.expr; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; + +import fi.benjami.code4jvm.Value; +import fi.benjami.code4jvm.block.Block; +import fi.benjami.code4jvm.lua.compiler.LuaContext; +import fi.benjami.code4jvm.lua.ir.IrNode; +import fi.benjami.code4jvm.lua.ir.LuaType; +import fi.benjami.code4jvm.lua.linker.CallSiteOptions; +import fi.benjami.code4jvm.lua.linker.DynamicTarget; +import fi.benjami.code4jvm.lua.linker.LuaLinker; +import fi.benjami.code4jvm.lua.linker.UnaryOp; +import fi.benjami.code4jvm.lua.stdlib.LuaException; +import fi.benjami.code4jvm.statement.Arithmetic; + +public record NegateExpr(IrNode expr) implements IrNode { + + private static final MethodHandle NEGATE; + private static final DynamicTarget TARGET; + + static { + var lookup = MethodHandles.lookup(); + try { + NEGATE = lookup.findStatic(NegateExpr.class, "negate", MethodType.methodType(double.class, Object.class, double.class)); + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new AssertionError(e); + } + + TARGET = UnaryOp.newTarget(new UnaryOp.Path[] {new UnaryOp.Path(Double.class, NEGATE)}, "__unm", + (val) -> new LuaException("attempted to negate a non-number value")); + } + + @SuppressWarnings("unused") // MethodHandle + private static double negate(Object callable, double value) { + return -value; + } + + @Override + public Value emit(LuaContext ctx, Block block) { + var value = expr.emit(ctx, block); + if (outputType(ctx).equals(LuaType.NUMBER)) { + return block.add(Arithmetic.negate(value)); + } else { + return block.add(LuaLinker.setupCall(ctx, CallSiteOptions.nonFunction(LuaType.UNKNOWN, LuaType.UNKNOWN), TARGET, value)); + } + } + + @Override + public LuaType outputType(LuaContext ctx) { + // We can't do type analysis through metatables (yet) + return expr.outputType(ctx).equals(LuaType.NUMBER) ? LuaType.NUMBER : LuaType.UNKNOWN; + } + +} diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/linker/BinaryOp.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/linker/BinaryOp.java index 9c71926..64b5efb 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/linker/BinaryOp.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/linker/BinaryOp.java @@ -59,6 +59,7 @@ private static final boolean checkLhsMetamethod(String metamethod, Object callab public static DynamicTarget newTarget(Class expectedType, MethodHandle fastPath, String metamethod, BiFunction errorHandler) { assert !expectedType.isPrimitive(); // LHS and RHS will be in their boxed forms + assert !expectedType.equals(LuaType.class); // This is currently unnecessary for Lua return (meta, args) -> { assert args.length == 2; var lhs = args[0]; diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/linker/UnaryOp.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/linker/UnaryOp.java new file mode 100644 index 0000000..4051787 --- /dev/null +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/linker/UnaryOp.java @@ -0,0 +1,134 @@ +package fi.benjami.code4jvm.lua.linker; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.util.function.Function; + +import fi.benjami.code4jvm.lua.ir.LuaType; +import fi.benjami.code4jvm.lua.runtime.LuaTable; +import fi.benjami.code4jvm.lua.runtime.TableAccess; +import fi.benjami.code4jvm.lua.stdlib.LuaException; + +/** + * Linker support for Lua's unary operations such as negation. + * Supports metamethods for operator overloading. + * @see BinaryOp + */ +public class UnaryOp { + + private static final boolean checkType(Class expected, Object callable, Object arg) { + return arg != null && arg.getClass() == expected; + } + + private static final MethodHandle CHECK_TYPE; + + static { + var lookup = MethodHandles.lookup(); + try { + CHECK_TYPE = lookup.findStatic(UnaryOp.class, "checkType", + MethodType.methodType(boolean.class, Class.class, Object.class, Object.class)); + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new AssertionError(e); + } + } + + public record Path(Class type, MethodHandle target) {} + + /** + * Produces a dynamic call target for an unary operation call site. + * @param fastPaths Fast paths to check in order. + * @param metamethod Name of the metamethod to call. When present, this + * takes precedence over the fast path! + * @param errorHandler Called when the argument has invalid type and + * metamethod cannot be found. Returns a Lua exception that is thrown. + * @return Call target. + */ + public static DynamicTarget newTarget(Path[] fastPaths, String metamethod, + Function errorHandler) { + return (meta, args) -> { + var arg = args[0]; + // Try all paths in order + for (var path : fastPaths) { + if (path.type.equals(LuaTable.class)) { + if (arg instanceof LuaTable table) { + if (table.metatable() == null) { + // Fast path: no metatable + var guard = TableAccess.CHECK_TABLE_SHAPE.bindTo(table.shape()); + return new LuaCallTarget(path.target, guard); + } else if (table.metatable().get(metamethod) == null) { + // Metatable, but no relevant metamethod + var guard = MethodHandles.insertArguments(TableAccess.CHECK_TABLE_AND_META_SHAPES, 0, + table.shape(), table.metatable().shape()); + return new LuaCallTarget(path.target, guard); + } else { + // Metamethod found; call it! + return useMetamethod(meta, table, metamethod, arg); + } + } + } else { + if (checkType(path.type, null, arg)) { + // Expected type; take the fast path until this changes + var guard = CHECK_TYPE.bindTo(path.type); + return new LuaCallTarget(path.target, guard); + } else if (arg instanceof LuaTable table + && table.metatable() != null + && table.metatable().get(metamethod) != null) { + // Unexpected type, but we can call the metamethod + return useMetamethod(meta, table, metamethod, arg); + } + } + } + throw errorHandler.apply(arg); + }; +// if (expectedType.equals(LuaTable.class)) { +// // Special case: fast path accepts tables that don't have the metamethod +// return (meta, args) -> { +// var arg = args[0]; +// if (arg instanceof LuaTable table) { +// if (table.metatable() == null) { +// // Fast path: no metatable +// var guard = TableAccess.CHECK_TABLE_SHAPE.bindTo(table.shape()); +// return new LuaCallTarget(fastPath, guard); +// } else if (table.metatable().get(metamethod) != null) { +// // Metatable, but no relevant metamethod +// var guard = MethodHandles.insertArguments(TableAccess.CHECK_TABLE_AND_META_SHAPES, 0, +// table.shape(), table.metatable().shape()); +// return new LuaCallTarget(fastPath, guard); +// } else { +// // Metamethod found; call it! +// return useMetamethod(meta, table, metamethod, arg); +// } +// } else { +// throw errorHandler.apply(arg); +// } +// }; +// } else { +// // Expected type is not table; tables are accepted only if they have metamethods +// return (meta, args) -> { +// var arg = args[0]; +// if (checkType(expectedType, null, arg)) { +// // Expected type; take the fast path until this changes +// var guard = CHECK_TYPE.bindTo(expectedType); +// return new LuaCallTarget(fastPath, guard); +// } else if (arg instanceof LuaTable table +// && table.metatable() != null +// && table.metatable().get(metamethod) != null) { +// // Unexpected type, but we can call the metamethod +// return useMetamethod(meta, table, metamethod, arg); +// } else { +// throw errorHandler.apply(arg); +// } +// }; +// } + } + + private static LuaCallTarget useMetamethod(LuaCallSite meta, LuaTable table, String metamethod, Object arg) { + var target = LuaLinker.linkCall(new LuaCallSite(meta.site, CallSiteOptions.nonFunction(LuaType.UNKNOWN)), + table.metatable().get(metamethod), arg); + var guard = MethodHandles.insertArguments(TableAccess.CHECK_TABLE_AND_META_SHAPES, 0, + table.shape(), table.metatable().shape()); + return target.withGuards(guard); + } + +} diff --git a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/BinaryOpTest.java b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/BinaryOpTest.java index 7ac3227..2ce71bc 100644 --- a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/BinaryOpTest.java +++ b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/BinaryOpTest.java @@ -1,6 +1,5 @@ package fi.benjami.code4jvm.lua.test; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; diff --git a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/UnaryOpTest.java b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/UnaryOpTest.java new file mode 100644 index 0000000..66c9758 --- /dev/null +++ b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/UnaryOpTest.java @@ -0,0 +1,76 @@ +package fi.benjami.code4jvm.lua.test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import org.junit.jupiter.api.Test; + +import fi.benjami.code4jvm.lua.LuaVm; +import fi.benjami.code4jvm.lua.runtime.LuaTable; +import fi.benjami.code4jvm.lua.stdlib.LuaException; + +public class UnaryOpTest { + + private final LuaVm vm = new LuaVm(); + + @Test + public void negateNumbers() throws Throwable { + assertEquals(-10d, vm.execute("return -10")); + assertEquals(10d, vm.execute("return -(-10)")); + assertEquals(-10d, vm.execute(""" + ten = 10 + return -ten + """)); + } + + @Test + public void negateMetatable() throws Throwable { + var metaTbl = new LuaTable(); + metaTbl.set("__unm", vm.execute(""" + return function (self) + return "nope!" + end + """)); + + var tbl = new LuaTable(); + tbl.metatable(metaTbl); + vm.globals().set("tbl", tbl); + + assertEquals("nope!", vm.execute("return -tbl")); + metaTbl.set("__unm", null); + assertThrows(LuaException.class, () -> vm.execute("return -tbl")); + } + + @Test + public void stringLength() throws Throwable { + assertEquals(5d, vm.execute("return #\"12345\"")); + assertEquals(5d, vm.execute(""" + str = "12345" + return #str + """)); + } + + @Test + public void tableLength() throws Throwable { + // Array length + assertEquals(0d, vm.execute("return #{}")); + assertEquals(5d, vm.execute("return #{1, 2, 3, false, true}")); + assertEquals(0d, vm.execute("return #{foo = 1}")); + + // Metatables + var metaTbl = new LuaTable(); + metaTbl.set("__len", vm.execute(""" + return function (self) + return "nope!" + end + """)); + + var tbl = new LuaTable(); + tbl.metatable(metaTbl); + vm.globals().set("tbl", tbl); + + assertEquals("nope!", vm.execute("return #tbl")); + metaTbl.set("__len", null); + assertEquals(0d, vm.execute("return #tbl")); + } +} diff --git a/src/main/java/fi/benjami/code4jvm/statement/Arithmetic.java b/src/main/java/fi/benjami/code4jvm/statement/Arithmetic.java index 65f1ad1..32d92f5 100644 --- a/src/main/java/fi/benjami/code4jvm/statement/Arithmetic.java +++ b/src/main/java/fi/benjami/code4jvm/statement/Arithmetic.java @@ -63,4 +63,12 @@ public static Expression remainder(Value lhs, Value rhs) { }, "remainder")); }; } + + public static Expression negate(Value value) { + return block -> { + return block.add(Bytecode.run(value.type(), new Value[] {value}, ctx -> { + ctx.asm().visitInsn(value.type().getOpcode(INEG, ctx)); + }, "negate")); + }; + } }