Skip to content

Commit

Permalink
lua4jvm: Fix Lua upvalues
Browse files Browse the repository at this point in the history
Previously, upvalues were captured at function creation time. From language
design point of view, it is a sensible choice.
It also goes directly against the Lua spec. Oops...

The fix introduces mutability tracking for ALL local variables and boxes
mutable upvalues so that they work according to the spec. Upvalues that
do not change are unchanged, which is nice for performance reasons.

Next up: Try to get constant invokedynamic call sites for upvalue
functions that never change after their creation.
  • Loading branch information
bensku committed Oct 17, 2024
1 parent 3e5933e commit 805e0e9
Show file tree
Hide file tree
Showing 16 changed files with 200 additions and 42 deletions.
4 changes: 3 additions & 1 deletion lua4jvm/src/main/java/fi/benjami/code4jvm/lua/LuaVm.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,10 @@ public LuaModule compile(String chunk) {

public LuaFunction load(LuaModule module, LuaTable env) {
// Instantiate the module
module.env().markMutable(); // Initial assignment by VM
var type = LuaType.function(
List.of(new UpvalueTemplate(module.env(), LuaType.TABLE)),
// TODO _ENV mutability tracking
List.of(new UpvalueTemplate(module.env(), module.env().mutable() ? LuaType.UNKNOWN : LuaType.TABLE, module.env().mutable())),
List.of(),
module.root(),
module.name(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ private static byte[] generateCode(LuaContext ctx, LuaType.Function type,
var template = type.upvalues().get(i);
var value = method.add(template.variable().name(), method.self()
.getField(upvalueTypes[i].backingType(), template.variable().name()));
ctx.addFunctionArg(template.variable(), value);
ctx.addUpvalue(template.variable(), value);
}

// Emit Lua code as JVM bytecode
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package fi.benjami.code4jvm.lua.compiler;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -15,6 +14,7 @@
import fi.benjami.code4jvm.lua.ir.LuaVariable;
import fi.benjami.code4jvm.lua.ir.TableField;
import fi.benjami.code4jvm.lua.ir.expr.LuaConstant;
import fi.benjami.code4jvm.lua.runtime.LuaBox;

public class LuaContext {

Expand Down Expand Up @@ -57,6 +57,11 @@ public static LuaContext forFunction(LuaVm vm, LuaType.Function type, boolean tr
*/
private final Map<LuaLocalVar, Variable> variables;

/**
* Local variables that are, in fact, upvalues.
*/
private final Map<LuaLocalVar, Variable> upvalues;

/**
* Data given to JVM when the function is loaded as a hidden class.
* This is used for creating constants of arbitrary kind, which can then
Expand Down Expand Up @@ -85,6 +90,7 @@ public LuaContext(LuaVm owner, boolean truncateReturn) {
assert owner != null;
this.typeTable = new IdentityHashMap<>();
this.variables = new IdentityHashMap<>();
this.upvalues = new IdentityHashMap<>();
this.classData = new ArrayList<>();
this.cache = new IdentityHashMap<>();
this.truncateReturn = truncateReturn;
Expand Down Expand Up @@ -118,6 +124,15 @@ public void addFunctionArg(LuaLocalVar arg, Variable variable) {
variables.put(arg, variable);
}

public void addUpvalue(LuaLocalVar arg, Variable variable) {
variables.put(arg, variable);
upvalues.put(arg, variable);
}

public boolean isUpvalue(LuaLocalVar localVar) {
return upvalues.containsKey(localVar);
}

public LuaType variableType(LuaVariable variable) {
if (variable instanceof LuaLocalVar) {
assert typeTable.containsKey(variable) : variable;
Expand All @@ -133,11 +148,16 @@ public LuaType variableType(LuaVariable variable) {
}
}

public boolean hasBeenAssigned(LuaLocalVar variable) {
return variables.containsKey(variable);
}

public Variable resolveLocalVar(LuaLocalVar variable) {
var backingVar = variables.get(variable);
if (backingVar == null) {
var type = typeTable.get(variable);
backingVar = Variable.create(type.backingType(), variable.name());
var useBox = variable.upvalue() && variable.mutable();
backingVar = Variable.create(useBox ? LuaBox.TYPE : type.backingType(), variable.name());
variables.put(variable, backingVar);
}
return backingVar;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,21 @@
import fi.benjami.code4jvm.lua.ir.TableField;
import fi.benjami.code4jvm.lua.ir.expr.LuaConstant;
import fi.benjami.code4jvm.lua.ir.expr.VariableExpr;
import fi.benjami.code4jvm.lua.ir.expr.FunctionDeclExpr.Upvalue;
import fi.benjami.code4jvm.lua.ir.stmt.LoopStmt;

public class LuaScope {

public static LuaScope chunkRoot() {
var scope = new LuaScope(null, true);
var env = scope.declare("_ENV");
scope.upvalues.put("_ENV", new Upvalue(env, null));
scope.upvalues.put("_ENV", env);
return scope;
}

private final LuaScope parent;
private final boolean functionRoot;

private final Map<String, LuaLocalVar> locals;
private final Map<String, Upvalue> upvalues;
private final Map<String, LuaLocalVar> upvalues;

/**
* Reference to current loop or null, used for break'ing out of loop.
Expand Down Expand Up @@ -66,15 +64,11 @@ public LuaVariable resolve(String name) {
if (result != null) {
// Local variable or upvalue
if (result.isUpvalue()) {
// Upvalue: record it and create a local variable
var inside = new LuaLocalVar(name);
locals.put(name, inside);
var outside = result.variable();
upvalues.put(name, new Upvalue(inside, outside));
return inside;
} else {
return result.variable(); // Local variable
locals.put(name, result.variable());
upvalues.put(name, result.variable());
result.variable().markUpvalue();
}
return result.variable(); // Local variable
} else {
// Neither local variable or upvalue; take a look at _ENV table
var env = resolve("_ENV");
Expand All @@ -99,7 +93,7 @@ private ResolveResult resolveLocal(String name) {
}
}

public List<Upvalue> upvalues() {
public List<LuaLocalVar> upvalues() {
return new ArrayList<>(upvalues.values());
}

Expand Down
41 changes: 38 additions & 3 deletions lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/LuaLocalVar.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,47 @@

import fi.benjami.code4jvm.Type;

public record LuaLocalVar(
String name
) implements LuaVariable {
public final class LuaLocalVar implements LuaVariable {

public static final Type TYPE = Type.of(LuaLocalVar.class);

public static final LuaLocalVar VARARGS = new LuaLocalVar("...");

private final String name;
private int mutationSites;
private boolean upvalue;

public LuaLocalVar(String name) {
this.name = name;
}

public String name() {
return name;
}

@Override
public void markMutable() {
mutationSites++;
}

/**
* Whether or not this local variable is ever assigned to after its initial
* assignment. This includes mutations by blocks that inherit it as upvalue
* (to be precise, Lua upvalues are essentially external local variables).
*/
public boolean mutable() {
return mutationSites > 1;
}

public void markUpvalue() {
upvalue = true;
}

/**
* Whether or not this local variable is an upvalue for some block.
*/
public boolean upvalue() {
return upvalue;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@

public sealed interface LuaVariable permits LuaLocalVar, TableField {

void markMutable();
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,10 @@
public record TableField(
IrNode table,
IrNode field
) implements LuaVariable {}
) implements LuaVariable {

@Override
public void markMutable() {
// Do nothing, table fields are always mutable
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,11 @@ public record UpvalueTemplate(
* {@link LuaFunction#upvalueTypes final types} that are known after
* the function has been instantiated, this may be unknown.
*/
LuaType type
LuaType type,

/**
* Whether or not the upvalue variable is assigned to after its initial
* assignment.
*/
boolean mutable
) {}
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import fi.benjami.code4jvm.Value;
import fi.benjami.code4jvm.block.Block;
import fi.benjami.code4jvm.call.CallTarget;
import fi.benjami.code4jvm.call.FixedCallTarget;
import fi.benjami.code4jvm.lua.compiler.LuaContext;
import fi.benjami.code4jvm.lua.ir.IrNode;
import fi.benjami.code4jvm.lua.ir.LuaLocalVar;
import fi.benjami.code4jvm.lua.ir.LuaType;
import fi.benjami.code4jvm.lua.linker.CallSiteOptions;
import fi.benjami.code4jvm.lua.linker.LuaLinker;
Expand All @@ -33,6 +35,16 @@ public Value emit(LuaContext ctx, Block block, String intrinsicId) {
var returnType = cache.returnType();

// TODO constant bootstrap is broken due to upvalues
// FixedCallTarget bootstrap;
// if (function instanceof VariableExpr variable // function is a variable read
// && variable.source() instanceof LuaLocalVar localVar // from local variable
// && localVar.upvalue() && !localVar.mutable() // that will be stable between calls to this function
// && TODO we also need to check that 1) linker has used LuaType.TARGET_HAS_CHANGED (to prove upvalue's block hasn't been re-executed)
// && TODO 2) cache key includes identity of the upvalue, not just its type! (this is quite tricky)
// ) {
//
// }
//
var bootstrap = LuaLinker.BOOTSTRAP_DYNAMIC;
var lastMultiVal = !args.isEmpty() && MultiVals.canReturnMultiVal(args.get(args.size() - 1));
var options = new CallSiteOptions(ctx.owner(), argTypes, ctx.allowSpread(), lastMultiVal, intrinsicId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public record FunctionDeclExpr(
*/
String name,

List<Upvalue> upvalues,
List<LuaLocalVar> upvalues,

/**
* Arguments inside the new function.
Expand All @@ -38,11 +38,6 @@ public record FunctionDeclExpr(
*/
LuaBlock body
) implements IrNode {

public record Upvalue(
LuaLocalVar inside,
LuaLocalVar outside
) {}

@Override
public Value emit(LuaContext ctx, Block block) {
Expand All @@ -52,7 +47,7 @@ public Value emit(LuaContext ctx, Block block) {
// Copy local variables to upvalues array
var upvalueValues = block.add(Type.OBJECT.array(1).newInstance(Constant.of(upvalues.size())));
for (var i = 0; i < upvalues.size(); i++) {
var value = ctx.resolveLocalVar(upvalues.get(i).outside());
var value = ctx.resolveLocalVar(upvalues.get(i));
block.add(ArrayAccess.set(upvalueValues, Constant.of(i), value.cast(Type.OBJECT)));
}

Expand All @@ -64,7 +59,7 @@ public Value emit(LuaContext ctx, Block block) {
public LuaType.Function outputType(LuaContext ctx) {
// Upvalue template has the variable INSIDE declared function, with type of OUTSIDE variable
var upvalueTemplates = upvalues.stream()
.map(upvalue -> new UpvalueTemplate(upvalue.inside(), ctx.variableType(upvalue.outside())))
.map(upvalue -> new UpvalueTemplate(upvalue, upvalue.mutable() ? LuaType.UNKNOWN : ctx.variableType(upvalue), upvalue.mutable()))
.toList();
return ctx.cached(this, LuaType.function(upvalueTemplates, arguments, body, moduleName, name));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import fi.benjami.code4jvm.lua.ir.TableField;
import fi.benjami.code4jvm.lua.linker.CallSiteOptions;
import fi.benjami.code4jvm.lua.linker.LuaLinker;
import fi.benjami.code4jvm.lua.runtime.LuaBox;
import fi.benjami.code4jvm.lua.runtime.TableAccess;

/**
Expand All @@ -25,7 +26,14 @@ public record VariableExpr(
@Override
public Value emit(LuaContext ctx, Block block) {
if (source instanceof LuaLocalVar localVar) {
return ctx.resolveLocalVar(localVar);
if (localVar.upvalue() && localVar.mutable()) {
// Mutable upvalues have to use LuaBoxes
// TODO cast should not be needed - potential code4jvm bug
return block.add(ctx.resolveLocalVar(localVar).cast(LuaBox.TYPE).getField(outputType(ctx).backingType(), "value"));
} else {
// Normal JVM local variable
return ctx.resolveLocalVar(localVar);
}
} else if (source instanceof TableField tableField) {
var table = tableField.table().emit(ctx, block);
var field = tableField.field().emit(ctx, block);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import fi.benjami.code4jvm.lua.ir.TableField;
import fi.benjami.code4jvm.lua.ir.expr.FunctionCallExpr;
import fi.benjami.code4jvm.lua.ir.expr.VariableExpr;
import fi.benjami.code4jvm.lua.runtime.LuaBox;
import fi.benjami.code4jvm.lua.runtime.LuaTable;
import fi.benjami.code4jvm.lua.runtime.MultiVals;

Expand Down Expand Up @@ -97,8 +98,19 @@ public Value emit(LuaContext ctx, Block block) {
private Statement setVariable(LuaContext ctx, LuaVariable variable, Value value) {
return block -> {
if (variable instanceof LuaLocalVar localVar) {
var jvmVar = ctx.resolveLocalVar(localVar);
block.add(jvmVar.set(value.cast(jvmVar.type())));
if (localVar.upvalue() && localVar.mutable()) {
// Mutable upvalues need to be put to LuaBoxes
if (!ctx.hasBeenAssigned(localVar)) {
// First assignment? Initialize box!
var box = block.add(LuaBox.TYPE.newInstance());
block.add(ctx.resolveLocalVar(localVar).set(box));
}
block.add(ctx.resolveLocalVar(localVar).putField("value", value.cast(Type.OBJECT)));
} else {
// Normal local variable assignment
var jvmVar = ctx.resolveLocalVar(localVar);
block.add(jvmVar.set(value.cast(jvmVar.type())));
}
} else if (variable instanceof TableField tableField) {
// Just call the setter
// TODO invokedynamic to TableAccess.CONSTANT_SET once it has some optimizations
Expand All @@ -115,7 +127,9 @@ private Statement setVariable(LuaContext ctx, LuaVariable variable, Value value)
public LuaType outputType(LuaContext ctx) {
var normalSources = spread ? sources.size() - 1 : sources.size();
for (var i = 0; i < Math.min(normalSources, targets.size()); i++) {
ctx.recordType(targets.get(i), sources.get(i).outputType(ctx));
var target = targets.get(i);
ctx.recordType(target, sources.get(i).outputType(ctx));
target.markMutable();
}

if (spread) {
Expand All @@ -128,12 +142,16 @@ public LuaType outputType(LuaContext ctx) {
// Tuple -> types for individual variables
// UNKNOWN -> current behavior
// anything else -> first multiValType, rest NIL
ctx.recordType(targets.get(i), LuaType.UNKNOWN);
var target = targets.get(i);
ctx.recordType(target, LuaType.UNKNOWN);
target.markMutable();
}
} else {
// If there are leftover targets, set them to nil
for (var i = normalSources; i < targets.size(); i++) {
ctx.recordType(targets.get(i), LuaType.NIL);
var target = targets.get(i);
ctx.recordType(target, LuaType.NIL);
target.markMutable();
}
}
return LuaType.NIL;
Expand Down
Loading

0 comments on commit 805e0e9

Please sign in to comment.