diff --git a/Common/src/main/java/at/petrak/hexcasting/api/casting/eval/vm/FrameForEach.kt b/Common/src/main/java/at/petrak/hexcasting/api/casting/eval/vm/FrameForEach.kt
index 0221c1f36d..d16c5d8845 100644
--- a/Common/src/main/java/at/petrak/hexcasting/api/casting/eval/vm/FrameForEach.kt
+++ b/Common/src/main/java/at/petrak/hexcasting/api/casting/eval/vm/FrameForEach.kt
@@ -9,6 +9,7 @@ import at.petrak.hexcasting.api.utils.NBTBuilder
 import at.petrak.hexcasting.api.utils.getList
 import at.petrak.hexcasting.api.utils.hasList
 import at.petrak.hexcasting.api.utils.serializeToNBT
+import at.petrak.hexcasting.api.utils.Vec
 import at.petrak.hexcasting.common.lib.hex.HexEvalSounds
 import at.petrak.hexcasting.common.lib.hex.HexIotaTypes
 import net.minecraft.nbt.CompoundTag
@@ -28,14 +29,13 @@ data class FrameForEach(
     val data: SpellList,
     val code: SpellList,
     val baseStack: List<Iota>?,
-    val acc: MutableList<Iota>
+    val acc: Vec<Iota>
 ) : ContinuationFrame {
     /** When halting, we add the stack state at halt to the stack accumulator, then return the original
      * pre-Thoth stack, plus the accumulator. */
     override fun breakDownwards(stack: List<Iota>): Pair<Boolean, List<Iota>> {
         val newStack = baseStack?.toMutableList() ?: mutableListOf()
-        acc.addAll(stack)
-        newStack.add(ListIota(acc))
+        newStack.add(ListIota(acc.appendAll(stack).toList()))
         return true to newStack
     }
 
@@ -46,13 +46,12 @@ data class FrameForEach(
         harness: CastingVM
     ): CastResult {
         // If this isn't the very first Thoth step (i.e. no Thoth computations run yet)...
-        val stack = if (baseStack == null) {
+        val (stack, nextAcc) = if (baseStack == null) {
             // init stack to the harness stack...
-            harness.image.stack.toList()
+            harness.image.stack.toList() to acc
         } else {
             // else save the stack to the accumulator and reuse the saved base stack.
-            acc.addAll(harness.image.stack)
-            baseStack
+            baseStack to acc.appendAll(harness.image.stack)
         }
 
         // If we still have data to process...
@@ -60,13 +59,13 @@ data class FrameForEach(
             // push the next datum to the top of the stack,
             val cont2 = continuation
                 // put the next Thoth object back on the stack for the next Thoth cycle,
-                .pushFrame(FrameForEach(data.cdr, code, stack, acc))
+                .pushFrame(FrameForEach(data.cdr, code, stack, nextAcc))
                 // and prep the Thoth'd code block for evaluation.
                 .pushFrame(FrameEvaluate(code, true))
             Triple(data.car, harness.image.withUsedOp(), cont2)
         } else {
             // Else, dump our final list onto the stack.
-            Triple(ListIota(acc), harness.image, continuation)
+            Triple(ListIota(acc.toList()), harness.image, continuation)
         }
         val tStack = stack.toMutableList()
         tStack.add(stackTop)
@@ -86,10 +85,10 @@ data class FrameForEach(
         "code" %= code.serializeToNBT()
         if (baseStack != null)
             "base" %= baseStack.serializeToNBT()
-        "accumulator" %= acc.serializeToNBT()
+        "accumulator" %= acc.toList().serializeToNBT()
     }
 
-    override fun size() = data.size() + code.size() + acc.size + (baseStack?.size ?: 0)
+    override fun size() = data.size() + code.size() + acc.length + (baseStack?.size ?: 0)
 
     override val type: ContinuationFrame.Type<*> = TYPE
 
@@ -104,10 +103,10 @@ data class FrameForEach(
                     HexIotaTypes.LIST.deserialize(tag.getList("base", Tag.TAG_COMPOUND), world)!!.list.toList()
                 else
                     null,
-                HexIotaTypes.LIST.deserialize(
+                Vec.ofIterable(HexIotaTypes.LIST.deserialize(
                     tag.getList("accumulator", Tag.TAG_COMPOUND),
                     world
-                )!!.list.toMutableList()
+                )!!.list)
             )
         }
diff --git a/Common/src/main/java/at/petrak/hexcasting/api/utils/Vec.java b/Common/src/main/java/at/petrak/hexcasting/api/utils/Vec.java
new file mode 100644
index 0000000000..abcb168022
--- /dev/null
+++ b/Common/src/main/java/at/petrak/hexcasting/api/utils/Vec.java
@@ -0,0 +1,319 @@
+package at.petrak.hexcasting.api.utils;
+
+import java.util.Arrays;
+import java.util.Optional;
+import java.util.Map;
+import java.util.Iterator;
+import java.util.RandomAccess;
+
+public class Vec<V> implements Iterable<V>, RandomAccess {
+
+    static int hash(int key) {
+        return Integer.hashCode(key);
+    }
+
+    // This is a persistent vector backed by a HAMT, see https://github.com/python/cpython/blob/main/Python/hamt.c.
+    // Currently just supports push/pop/random access. Will need to update in future if this becomes more user-facing.
+    sealed interface HamtNode<V> {
+        HamtNode<V> assoc(int hash, V val);
+        Optional<V> get(int hash);
+        HamtNode<V> dissoc(int hash);
+        int size();
+    }
+
+    // Array node: store children "densely" (when there are 16 or more children); size is the number of nonnull children
+    static record ArrayNode<V>(int size, HamtNode<V>[] children) implements HamtNode<V> {
+        @Override
+        public HamtNode<V> assoc(int hash, V val) {
+            int next = hash >>> 5;
+            hash &= 0x1f;
+            var child = children[hash];
+            if (child != null) {
+                var newChild = child.assoc(next, val);
+                if (newChild == child) {
+                    return this;
+                }
+                var newChildren = Arrays.copyOf(children, children.length);
+                newChildren[hash] = newChild;
+                return new ArrayNode<>(size, newChildren);
+            }
+            var newChildren = Arrays.copyOf(children, children.length);
+            newChildren[hash] = new SingleNode<>(next, val);
+            return new ArrayNode<>(size + 1, newChildren);
+        }
+
+        @Override
+        public Optional<V> get(int hash) {
+            int next = hash >>> 5;
+            var child = children[hash & 0x1f];
+            return child == null ? Optional.empty() : child.get(next);
+        }
+
+        @Override
+        public HamtNode<V> dissoc(int hash) {
+            int next = hash >>> 5;
+            hash &= 0x1f;
+            var child = children[hash];
+            if (child == null) {
+                return this;
+            }
+            var newChild = child.dissoc(next);
+            if (newChild == child) {
+                return this;
+            }
+            // TODO: if nchildren = 16 && newChild == null, downgrade?
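+            // When the removed child's slot empties out and this dense node held at most 16 children,
+            // collapse it back into a sparse bitmap HamNode: re-pack the surviving children and rebuild
+            // the pop bitmap as we go.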
+            if (size <= 16 && newChild == null) {
+                int pop = 0, index = 0;
+                @SuppressWarnings("unchecked")
+                var newChildren = (HamtNode<V>[]) new HamtNode[size - 1];
+
+                for (int i = 0; i < children.length; i++) {
+                    if (i != hash && children[i] != null) {
+                        pop |= 1 << i;
+                        newChildren[index++] = children[i];
+                    }
+                }
+                assert (size - 1 == index);
+                return new HamNode<>(pop, newChildren);
+            }
+            var newChildren = Arrays.copyOf(children, children.length);
+            newChildren[hash] = newChild;
+            return new ArrayNode<>(size - (newChild == null ? 1 : 0), newChildren);
+        }
+
+        @Override
+        public int size() {
+            int count = 0;
+            for (int i = 0; i < children.length; i++) {
+                if (children[i] != null) {
+                    count += children[i].size();
+                }
+            }
+            return count;
+        }
+
+        @Override public String toString() { return "A[" + Arrays.toString(children) + "]"; }
+    }
+
+    // Bitmap node: store children "sparsely" (fewer than 16 children); pop is a bitmap of the 32 children this can have
+    static record HamNode<V>(int pop, HamtNode<V>[] children) implements HamtNode<V> {
+        @Override
+        public HamtNode<V> assoc(int hash, V val) {
+            int next = hash >>> 5;
+            hash &= 0x1f;
+            int index = indexOf(pop, hash);
+            if (hasHash(pop, hash)) {
+                var child = children[index];
+                var newChild = child.assoc(next, val);
+                if (child == newChild) {
+                    return this;
+                }
+                var newChildren = Arrays.copyOf(children, children.length);
+                newChildren[index] = newChild;
+                return new HamNode<>(pop, newChildren);
+            }
+            if (children.length >= 15) {
+                @SuppressWarnings("unchecked")
+                var arrayEnts = (HamtNode<V>[]) new HamtNode[32];
+                int work = pop, inputPos = 0;
+                while (work != 0) {
+                    int outputPos = Integer.numberOfTrailingZeros(work);
+                    work &= work - 1; // remove lowest 1
+                    arrayEnts[outputPos] = children[inputPos++];
+                }
+                arrayEnts[hash] = new SingleNode<>(next, val);
+                return new ArrayNode<>(1 + inputPos, arrayEnts);
+            }
+            @SuppressWarnings("unchecked")
+            var newChildren = (HamtNode<V>[]) new HamtNode[children.length + 1];
+            System.arraycopy(children, 0, newChildren, 0, index);
+            System.arraycopy(children, index, newChildren, index + 1, children.length - index);
+            newChildren[index] = new SingleNode<>(next, val);
+            return new HamNode<>(pop | 1 << hash, newChildren);
+        }
+
+        @Override
+        public Optional<V> get(int hash) {
+            return hasHash(pop, hash & 0x1f) ?
+                children[indexOf(pop, hash & 0x1f)].get(hash >>> 5) : Optional.empty();
+        }
+
+        @Override
+        public HamtNode<V> dissoc(int hash) {
+            int next = hash >>> 5;
+            hash &= 0x1f;
+            if (!hasHash(pop, hash)) {
+                return this;
+            }
+            int index = indexOf(pop, hash);
+            var child = children[index];
+            var newChild = child.dissoc(next);
+            if (child == newChild) {
+                return this;
+            }
+            if (newChild != null) {
+                var newChildren = Arrays.copyOf(children, children.length);
+                newChildren[index] = newChild;
+                return new HamNode<>(pop, newChildren);
+            }
+            if (children.length == 1) {
+                return null;
+            }
+            int newPop = pop & ~(1 << hash);
+            if (children.length == 2) {
+                int remainingHash = Integer.numberOfTrailingZeros(newPop);
+                var childNode = children[indexOf(pop, remainingHash)];
+                if (childNode instanceof SingleNode<V> ln) {
+                    return ln.withNewHash(ln.tailHash() << 5 | remainingHash);
+                }
+            }
+            @SuppressWarnings("unchecked")
+            var newChildren = (HamtNode<V>[]) new HamtNode[children.length - 1];
+            System.arraycopy(children, 0, newChildren, 0, index);
+            System.arraycopy(children, index + 1, newChildren, index, children.length - index - 1);
+            return new HamNode<>(newPop, newChildren);
+        }
+
+        @Override
+        public int size() {
+            int count = 0;
+            for (int i = 0; i < children.length; i++) {
+                count += children[i].size();
+            }
+            return count;
+        }
+
+        static boolean hasHash(int pop, int hash) {
+            int offset = 1 << hash;
+            return (pop & offset) != 0;
+        }
+
+        static int indexOf(int pop, int hash) {
+            int offset = 1 << hash;
+            return Integer.bitCount(pop & (offset - 1));
+        }
+
+        @Override public String toString() { return "H[" + Integer.toString(pop, 2) + ", " + Arrays.toString(children) + "]"; }
+    }
+
+    static record SingleNode<V>(int tailHash, V value) implements HamtNode<V> {
+        public SingleNode<V> withNewHash(int newHash) {
+            return new SingleNode<>(newHash, value);
+        }
+
+        @Override
+        public HamtNode<V> assoc(int hash, V val) {
+            if (hash == tailHash) {
+                return new SingleNode<>(tailHash, val);
+            }
+            return assocRecursive(hash, tailHash, val);
+        }
+
+        @Override
+        public Optional<V> get(int hash) {
+            if (tailHash == hash) {
+                return Optional.of(value);
+            }
+            return Optional.empty();
+        }
+
+        @Override
+        public HamtNode<V> dissoc(int hash) {
+            if (tailHash == hash) {
+                return null;
+            }
+            return this;
+        }
+
+        @Override public int size() { return 1; }
+
+        private HamtNode<V> assocRecursive(int hash, int tailHash, V val) {
+            int nextHash = hash >>> 5;
+            int nextTailHash = tailHash >>> 5;
+            hash &= 0x1f;
+            tailHash &= 0x1f;
+            if (hash == tailHash) {
+                @SuppressWarnings("unchecked")
+                var child = (HamtNode<V>[]) new HamtNode[] {assocRecursive(nextHash, nextTailHash, val)};
+                return new HamNode<>(1 << hash, child);
+            }
+            var existingNode = withNewHash(nextTailHash);
+            var newNode = new SingleNode<>(nextHash, val);
+            var left = hash < tailHash ? newNode : existingNode;
+            var right = hash < tailHash ? existingNode : newNode;
+            @SuppressWarnings("unchecked")
+            var child = (HamtNode<V>[]) new HamtNode[] {left, right};
+            return new HamNode<>(1 << hash | 1 << tailHash, child);
+        }
+    }
+
+    private HamtNode<V> root;
+    public final int length;
+
+    private Vec(HamtNode<V> root, int length) {
+        this.root = root;
+        this.length = length;
+    }
+
+    @SuppressWarnings("unchecked")
+    public static <V> Vec<V> empty() {
+        return new Vec<>(null, 0);
+    }
+
+    public Vec<V> append(V value) {
+        int key = length;
+        return new Vec<>(root != null ?
+            root.assoc(Vec.hash(key), value) : new SingleNode<>(Vec.hash(key), value), length + 1);
+    }
+
+    public Vec<V> assoc(int pos, V value) {
+        if (0 <= pos && pos < length) {
+            return new Vec<>(root.assoc(Vec.hash(pos), value), length);
+        }
+        throw new IllegalArgumentException("Index " + pos + " out of bounds for vec of length " + length);
+    }
+
+    public Vec<V> pop() {
+        if (isEmpty()) {
+            throw new IllegalArgumentException("Can't pop from empty vec!");
+        }
+        return new Vec<>(root.dissoc(Vec.hash(length - 1)), length - 1);
+    }
+
+    public boolean isEmpty() {
+        return length == 0;
+    }
+
+    public V get(int pos) {
+        if (0 <= pos && pos < length) {
+            return root.get(Vec.hash(pos)).orElseThrow(IllegalStateException::new);
+        }
+        throw new IllegalArgumentException("Index " + pos + " out of bounds for vec of length " + length);
+    }
+
+    public int size() {
+        return length; // invariant: length == root.size()
+        // return root == null ? 0 : root.size();
+    }
+
+    public Vec<V> appendAll(Iterable<? extends V> values) {
+        var out = this;
+        for (var entry : values) {
+            out = out.append(entry);
+        }
+        return out;
+    }
+
+    public static <V> Vec<V> ofIterable(Iterable<V> values) {
+        return Vec.<V>empty().appendAll(values);
+    }
+
+    HamtNode<V> root() { return root; }
+
+    // iterator over a HAMT
+    @Override
+    public Iterator<V> iterator() {
+        return new Iterator<V>() {
+            private int next = 0;
+
+            @Override
+            public boolean hasNext() {
+                return next < length;
+            }
+
+            @Override
+            public V next() {
+                return get(next++);
+            }
+        };
+    }
+}
diff --git a/Common/src/main/java/at/petrak/hexcasting/common/casting/actions/eval/OpForEach.kt b/Common/src/main/java/at/petrak/hexcasting/common/casting/actions/eval/OpForEach.kt
index bdd4f76f7a..605ef260cd 100644
--- a/Common/src/main/java/at/petrak/hexcasting/common/casting/actions/eval/OpForEach.kt
+++ b/Common/src/main/java/at/petrak/hexcasting/common/casting/actions/eval/OpForEach.kt
@@ -8,6 +8,7 @@ import at.petrak.hexcasting.api.casting.eval.vm.FrameForEach
 import at.petrak.hexcasting.api.casting.eval.vm.SpellContinuation
 import at.petrak.hexcasting.api.casting.getList
 import at.petrak.hexcasting.api.casting.mishaps.MishapNotEnoughArgs
+import at.petrak.hexcasting.api.utils.Vec
 import at.petrak.hexcasting.common.lib.hex.HexEvalSounds
 
 object OpForEach : Action {
@@ -22,7 +23,7 @@
         stack.removeLastOrNull()
         stack.removeLastOrNull()
 
-        val frame = FrameForEach(datums, instrs, null, mutableListOf())
+        val frame = FrameForEach(datums, instrs, null, Vec.empty())
         val image2 = image.withUsedOp().copy(stack = stack)
         return OperationResult(image2, listOf(), continuation.pushFrame(frame), HexEvalSounds.THOTH)
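Rough usage sketch of the Vec added above (not part of the patch; the class and values below are made up for illustration): every update returns a new vector and leaves the old one untouched, which is what lets FrameForEach hand a distinct accumulator to the next Thoth frame instead of mutating one shared list.

// Usage sketch only; VecUsageSketch is a hypothetical class, not part of the change set.
import java.util.List;

import at.petrak.hexcasting.api.utils.Vec;

class VecUsageSketch {
    public static void main(String[] args) {
        Vec<String> empty = Vec.empty();
        Vec<String> one = empty.append("a");                   // "empty" is still length 0
        Vec<String> three = one.appendAll(List.of("b", "c"));  // "one" still only holds "a"
        System.out.println(three.get(1));                      // prints "b": random access by index
        Vec<String> two = three.pop();                         // drops the last element in the copy only
        System.out.println(three.size() + " " + two.size());   // prints "3 2"
    }
}

Because the HAMT shares structure between versions, each append or pop copies only a short path of nodes rather than the whole accumulator, which is presumably why the nextAcc handoff in FrameForEach.evaluate stays cheap even for long Thoth loops.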