Skip to content

Commit

Permalink
Add missing accumulate arm implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
tjake committed Nov 4, 2024
1 parent 272c6cc commit 9124935
Showing 1 changed file with 63 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2179,6 +2179,9 @@ public void accumulate(AbstractTensor aBatch, AbstractTensor bBatch, int offset,
case AVX_256:
accumulateF32BF16_256((FloatBufferTensor) a, (BFloat16BufferTensor) b, offset, limit);
break;
case ARM_128:
accumulateF32BF16_arm((FloatBufferTensor) a, (BFloat16BufferTensor) b, offset, limit);
break;
default:
throw new UnsupportedOperationException();
}
Expand All @@ -2197,6 +2200,9 @@ public void accumulate(AbstractTensor aBatch, AbstractTensor bBatch, int offset,
case AVX_256:
accumulateBF16_256((BFloat16BufferTensor) a, (BFloat16BufferTensor) b, offset, limit);
break;
case ARM_128:
accumulateBF16_arm((BFloat16BufferTensor) a, (BFloat16BufferTensor) b, offset, limit);
break;
default:
throw new UnsupportedOperationException();
}
Expand Down Expand Up @@ -2343,6 +2349,63 @@ void accumulateF32BF16_256(FloatBufferTensor a, BFloat16BufferTensor b, int offs
}
}

void accumulateF32BF16_arm(FloatBufferTensor a, BFloat16BufferTensor b, int offset, int limit) {
int upperBound = offset + FloatVector.SPECIES_128.loopBound(limit);

int i = offset;
for (; i < upperBound; i += FloatVector.SPECIES_128.length()) {

// F32
var af = a.getVector(FloatVector.SPECIES_128, 0, i);

// Convert BF16 to F32
var bf = b.getVector(ShortVector.SPECIES_64, 0, i)
.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0)
.lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128)
.reinterpretAsFloats();

var res = af.add(bf);
a.intoTensor(res, 0, i);
}

// tail
for (; i < offset + limit; i++) {
a.set(a.get(0, i) + b.get(0, i), 0, i);
}
}

void accumulateBF16_arm(BFloat16BufferTensor a, BFloat16BufferTensor b, int offset, int limit) {
int upperBound = offset + FloatVector.SPECIES_128.loopBound(limit);

int i = offset;
for (; i < upperBound; i += FloatVector.SPECIES_128.length()) {

// Convert BF16 to F32
var af = a.getVector(ShortVector.SPECIES_64, 0, i)
.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0)
.lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128)
.reinterpretAsFloats();

// Convert BF16 to F32
var bf = b.getVector(ShortVector.SPECIES_64, 0, i)
.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0)
.lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128)
.reinterpretAsFloats();

var res = af.add(bf)
.reinterpretAsInts()
.lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT_128)
.convertShape(VectorOperators.I2S, ShortVector.SPECIES_64, 0);

a.intoTensor((ShortVector) res, 0, i);
}

// tail
for (; i < offset + limit; i++) {
a.set(a.get(0, i) + b.get(0, i), 0, i);
}
}

void accumulateBF16_256(BFloat16BufferTensor a, BFloat16BufferTensor b, int offset, int limit) {
int upperBound = offset + FloatVector.SPECIES_256.loopBound(limit);

Expand Down

0 comments on commit 9124935

Please sign in to comment.