From 9124935a0a33aa318295b0f28d5a874f5f29d842 Mon Sep 17 00:00:00 2001
From: Jake Luciani
Date: Sun, 3 Nov 2024 21:58:26 -0500
Subject: [PATCH] Add missing accumulate arm implementations

---
 .../operations/PanamaTensorOperations.java    | 63 +++++++++++++++++++
 1 file changed, 63 insertions(+)

diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java
index aef4139..f03ca84 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java
@@ -2179,6 +2179,9 @@ public void accumulate(AbstractTensor aBatch, AbstractTensor bBatch, int offset,
                 case AVX_256:
                     accumulateF32BF16_256((FloatBufferTensor) a, (BFloat16BufferTensor) b, offset, limit);
                     break;
+                case ARM_128:
+                    accumulateF32BF16_arm((FloatBufferTensor) a, (BFloat16BufferTensor) b, offset, limit);
+                    break;
                 default:
                     throw new UnsupportedOperationException();
             }
@@ -2197,6 +2200,9 @@ public void accumulate(AbstractTensor aBatch, AbstractTensor bBatch, int offset,
                 case AVX_256:
                     accumulateBF16_256((BFloat16BufferTensor) a, (BFloat16BufferTensor) b, offset, limit);
                     break;
+                case ARM_128:
+                    accumulateBF16_arm((BFloat16BufferTensor) a, (BFloat16BufferTensor) b, offset, limit);
+                    break;
                 default:
                     throw new UnsupportedOperationException();
             }
@@ -2343,6 +2349,63 @@ void accumulateF32BF16_256(FloatBufferTensor a, BFloat16BufferTensor b, int offs
         }
     }
 
+    void accumulateF32BF16_arm(FloatBufferTensor a, BFloat16BufferTensor b, int offset, int limit) {
+        int upperBound = offset + FloatVector.SPECIES_128.loopBound(limit);
+
+        int i = offset;
+        for (; i < upperBound; i += FloatVector.SPECIES_128.length()) {
+
+            // F32
+            var af = a.getVector(FloatVector.SPECIES_128, 0, i);
+
+            // Convert BF16 to F32
+            var bf = b.getVector(ShortVector.SPECIES_64, 0, i)
+                    .convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0)
+                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128)
+                    .reinterpretAsFloats();
+
+            var res = af.add(bf);
+            a.intoTensor(res, 0, i);
+        }
+
+        // tail
+        for (; i < offset + limit; i++) {
+            a.set(a.get(0, i) + b.get(0, i), 0, i);
+        }
+    }
+
+    void accumulateBF16_arm(BFloat16BufferTensor a, BFloat16BufferTensor b, int offset, int limit) {
+        int upperBound = offset + FloatVector.SPECIES_128.loopBound(limit);
+
+        int i = offset;
+        for (; i < upperBound; i += FloatVector.SPECIES_128.length()) {
+
+            // Convert BF16 to F32
+            var af = a.getVector(ShortVector.SPECIES_64, 0, i)
+                    .convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0)
+                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128)
+                    .reinterpretAsFloats();
+
+            // Convert BF16 to F32
+            var bf = b.getVector(ShortVector.SPECIES_64, 0, i)
+                    .convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0)
+                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128)
+                    .reinterpretAsFloats();
+
+            var res = af.add(bf)
+                    .reinterpretAsInts()
+                    .lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT_128)
+                    .convertShape(VectorOperators.I2S, ShortVector.SPECIES_64, 0);
+
+            a.intoTensor((ShortVector) res, 0, i);
+        }
+
+        // tail
+        for (; i < offset + limit; i++) {
+            a.set(a.get(0, i) + b.get(0, i), 0, i);
+        }
+    }
+
     void accumulateBF16_256(BFloat16BufferTensor a, BFloat16BufferTensor b, int offset, int limit) {
         int upperBound = offset + FloatVector.SPECIES_256.loopBound(limit);
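
Background on the widen step shared by both new methods: bfloat16 is just the high 16 bits of an IEEE-754 float32, so each lane is widened by sign-extending the short to an int (S2I), shifting left 16 bits, and reinterpreting the bits as floats. BF16_BYTE_SHIFT_128 is not shown in this hunk; it is presumably a broadcast of 16 over IntVector.SPECIES_128, matching the existing 256/512-bit shift constants. A minimal, self-contained sketch of the same 4-lane widen using only the JDK Vector API (run with --add-modules jdk.incubator.vector):

    import jdk.incubator.vector.*;

    public class Bf16WidenDemo {
        public static void main(String[] args) {
            // Four bfloat16 values, stored as raw shorts.
            short[] bf16 = { 0x3FC0 /* 1.5 */, 0x3E80 /* 0.25 */, (short) 0xBF80 /* -1.0 */, 0x4000 /* 2.0 */ };

            var widened = ShortVector.fromArray(ShortVector.SPECIES_64, bf16, 0)
                    .convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0) // sign-extend s16 -> i32
                    .lanewise(VectorOperators.LSHL, 16)                          // bf16 bits -> high half of f32
                    .reinterpretAsFloats();

            System.out.println(widened); // [1.5, 0.25, -1.0, 2.0]
        }
    }

Note that the sign-extension bits introduced by S2I occupy the low half of each int lane and are pushed out by the left shift, so the signed conversion is harmless here.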
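The store path of accumulateBF16_arm reverses the trick: the float sum is reinterpreted as ints, arithmetically shifted right 16, and narrowed i32 -> s16 (I2S). That keeps the high 16 bits of each float, i.e. a truncating conversion that simply drops the low mantissa bits rather than rounding to nearest-even. A scalar sketch of the same round trip; bf16ToF32/f32ToBf16 are illustrative helper names, not jlama API:

    public class Bf16RoundTripDemo {
        // Widen: place the 16 bf16 bits in the high half of a float32 bit pattern.
        static float bf16ToF32(short bf16) {
            return Float.intBitsToFloat((bf16 & 0xFFFF) << 16);
        }

        // Narrow by truncation: keep the high 16 bits, drop the low mantissa bits.
        // Equivalent to the ASHR(16) + I2S sequence in the patch, since I2S
        // discards everything above the low 16 bits of the shifted value.
        static short f32ToBf16(float f) {
            return (short) (Float.floatToRawIntBits(f) >>> 16);
        }

        public static void main(String[] args) {
            short a = f32ToBf16(1.5f);
            short b = f32ToBf16(0.25f);
            short sum = f32ToBf16(bf16ToF32(a) + bf16ToF32(b));
            System.out.println(bf16ToF32(sum)); // 1.75 (exactly representable in bf16)
        }
    }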
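Both methods also follow the usual Vector API main-loop/tail split: SPECIES.loopBound(limit) rounds limit down to a multiple of the lane count, the vector loop covers [offset, offset + loopBound), and the scalar tail finishes the remainder, so limits that are not a multiple of 4 are still handled correctly. For example:

    import jdk.incubator.vector.FloatVector;

    public class LoopBoundDemo {
        public static void main(String[] args) {
            var species = FloatVector.SPECIES_128;      // 4 float lanes in a 128-bit (NEON) vector
            System.out.println(species.length());       // 4
            System.out.println(species.loopBound(10));  // 8 = largest multiple of 4 <= 10
        }
    }

With limit = 10 the vector loop covers 8 elements and the scalar tail the last 2.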