diff --git a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/IteratorForStmt.java b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/IteratorForStmt.java index 2bdc2a3..d975e4b 100644 --- a/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/IteratorForStmt.java +++ b/lua4jvm/src/main/java/fi/benjami/code4jvm/lua/ir/stmt/IteratorForStmt.java @@ -1,12 +1,12 @@ package fi.benjami.code4jvm.lua.ir.stmt; -import java.util.Arrays; import java.util.List; import fi.benjami.code4jvm.Condition; import fi.benjami.code4jvm.Constant; import fi.benjami.code4jvm.Type; import fi.benjami.code4jvm.Value; +import fi.benjami.code4jvm.Variable; import fi.benjami.code4jvm.block.Block; import fi.benjami.code4jvm.call.CallTarget; import fi.benjami.code4jvm.lua.compiler.LoopRef; @@ -17,11 +17,10 @@ import fi.benjami.code4jvm.lua.ir.LuaType; import fi.benjami.code4jvm.lua.linker.CallSiteOptions; import fi.benjami.code4jvm.lua.linker.LuaLinker; -import fi.benjami.code4jvm.lua.runtime.MultiVals; import fi.benjami.code4jvm.lua.stdlib.LuaException; import fi.benjami.code4jvm.statement.ArrayAccess; -import fi.benjami.code4jvm.statement.Jump; import fi.benjami.code4jvm.statement.Instanceof; +import fi.benjami.code4jvm.statement.Jump; import fi.benjami.code4jvm.structure.IfBlock; public record IteratorForStmt( @@ -42,7 +41,7 @@ public Value emit(LuaContext ctx, Block block) { .toList(); var init = Block.create("iterator for init"); - Value next, state; + Variable next = Variable.create(Type.OBJECT), state = Variable.create(Type.OBJECT); var control = loopJvmVars.get(0); if (iterable.size() == 1) { // Before loop body, call the iterable to (hopefully) produce an array of: @@ -51,21 +50,42 @@ public Value emit(LuaContext ctx, Block block) { ctx.setAllowSpread(true); var iterator = iterable.get(0).emit(ctx, init); ctx.setAllowSpread(false); - // FIXME guard against too short array! - next = init.add(ArrayAccess.get(iterator, Constant.of(0))); - state = init.add(ArrayAccess.get(iterator, Constant.of(1))); - init.add(control, ArrayAccess.get(iterator, Constant.of(2))); + + // We might've gotten a multival of next, state, control or only some of those + // Set state, control to null as they are technically optional + init.add(state.set(Constant.nullValue(Type.OBJECT))); + init.add(control.set(Constant.nullValue(Type.OBJECT))); + + // Extract the values + var innerInit = new IfBlock(); + innerInit.branch(inner -> { + var isArray = inner.add(Instanceof.isInstance(iterator, Type.OBJECT.array(1))); + return Condition.isTrue(isArray); + }, inner -> { + // Multival, extract array elements + var array = iterator.cast(Type.OBJECT.array(1)); + var length = inner.add(ArrayAccess.length(array)); + inner.add(next, ArrayAccess.get(array, Constant.of(0))); // This must exist, but TODO improve error messages + inner.add(Jump.to(init, Jump.Target.END, Condition.equal(length, Constant.of(1)))); + inner.add(ArrayAccess.get(array, Constant.of(1))); + inner.add(control, ArrayAccess.get(array, Constant.of(2))); + }); + innerInit.fallback(inner -> { + // Only one value, the iterator function + inner.add(next.set(iterator)); + }); + init.add(innerInit); } else { // Of course, it doesn't strictly NEED to be one function... if (iterable.size() > 0) { - next = iterable.get(0).emit(ctx, init); + init.add(next.set(iterable.get(0).emit(ctx, init))); } else { throw new LuaException("for iterator is nil"); } if (iterable.size() > 1) { - state = iterable.get(1).emit(ctx, init); + init.add(state.set(iterable.get(1).emit(ctx, init))); } else { - state = Constant.nullValue(Type.OBJECT); + init.add(state.set(Constant.nullValue(Type.OBJECT))); } if (iterable.size() > 2) { var value = iterable.get(2).emit(ctx, init); @@ -93,8 +113,8 @@ public Value emit(LuaContext ctx, Block block) { // Assign whatever was returned to loop variables var assignVars = new IfBlock(); assignVars.branch(inner -> { - var result = inner.add(Instanceof.isInstance(results, Type.OBJECT.array(1))); - return Condition.isTrue(result); + var isArray = inner.add(Instanceof.isInstance(results, Type.OBJECT.array(1))); + return Condition.isTrue(isArray); }, inner -> { // Got multival for loop variables var array = results.cast(Type.OBJECT.array(1)); diff --git a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/LoopTest.java b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/LoopTest.java index b415b33..36e11ec 100644 --- a/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/LoopTest.java +++ b/lua4jvm/src/test/java/fi/benjami/code4jvm/lua/test/LoopTest.java @@ -54,4 +54,85 @@ local function testNext(state, ctrl) assertEquals(true, tbl.get(i)); } } + + @Test + public void iteratorFromFunction() throws Throwable { + vm.execute(""" + local function testNext(state, ctrl) + if ctrl == 10 then + return nil + else + return ctrl + 1 + end + end + + local function testIterator() + return testNext, nil, 0 + end + + iTbl = {} + for i in testIterator() do + iTbl[i] = true + end + """); + var tbl = (LuaTable) vm.globals().get("iTbl"); + for (double i = 1; i < 10; i++) { + assertEquals(true, tbl.get(i)); + } + } + + @Test + public void iteratorFromFunction2() throws Throwable { + // Almost same as above, but testIterator() does not return multival + vm.execute(""" + local function testNext(state, ctrl) + if ctrl == nil then + return 1 + elseif ctrl == 10 then + return nil + else + return ctrl + 1 + end + end + + local function testIterator() + return testNext + end + + iTbl = {} + for i in testIterator() do + iTbl[i] = true + end + """); + var tbl = (LuaTable) vm.globals().get("iTbl"); + for (double i = 1; i < 10; i++) { + assertEquals(true, tbl.get(i)); + } + } + + @Test + public void breakFor() throws Throwable { + vm.execute(""" + local function testNext(state, ctrl) + if ctrl == 10 then + return nil + else + return ctrl + 1 + end + end + + local function testIterator() + return testNext, nil, 0 + end + + iTbl = {} + for i in testIterator() do + iTbl[i] = true + break + end + """); + var tbl = (LuaTable) vm.globals().get("iTbl"); + assertEquals(true, tbl.get(1d)); + assertEquals(null, tbl.get(2d)); + } }