From 59ae5f428eb0d3aa09b9aa26f72141db3df2544a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20R=C5=BCysko?= Date: Mon, 27 Nov 2023 19:13:26 +0100 Subject: [PATCH] Add support for 512-bit vectors in utf-8 validator (#32) --- .../org/simdjson/CharactersClassifier.java | 8 ++-- .../java/org/simdjson/JsonStringScanner.java | 4 +- src/main/java/org/simdjson/StringParser.java | 4 +- .../java/org/simdjson/StructuralIndexer.java | 40 +++++++++++++------ src/main/java/org/simdjson/Utf8Validator.java | 11 ++--- src/test/java/org/simdjson/TestUtils.java | 2 +- .../java/org/simdjson/Utf8ValidatorTest.java | 4 +- 7 files changed, 44 insertions(+), 29 deletions(-) diff --git a/src/main/java/org/simdjson/CharactersClassifier.java b/src/main/java/org/simdjson/CharactersClassifier.java index 09d6515..68b685c 100644 --- a/src/main/java/org/simdjson/CharactersClassifier.java +++ b/src/main/java/org/simdjson/CharactersClassifier.java @@ -9,14 +9,14 @@ class CharactersClassifier { private static final ByteVector WHITESPACE_TABLE = ByteVector.fromArray( - StructuralIndexer.SPECIES, - repeat(new byte[]{' ', 100, 100, 100, 17, 100, 113, 2, 100, '\t', '\n', 112, 100, '\r', 100, 100}, StructuralIndexer.SPECIES.vectorByteSize() / 4), + StructuralIndexer.BYTE_SPECIES, + repeat(new byte[]{' ', 100, 100, 100, 17, 100, 113, 2, 100, '\t', '\n', 112, 100, '\r', 100, 100}, StructuralIndexer.BYTE_SPECIES.vectorByteSize() / 4), 0); private static final ByteVector OP_TABLE = ByteVector.fromArray( - StructuralIndexer.SPECIES, - repeat(new byte[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ':', '{', ',', '}', 0, 0}, StructuralIndexer.SPECIES.vectorByteSize() / 4), + StructuralIndexer.BYTE_SPECIES, + repeat(new byte[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ':', '{', ',', '}', 0, 0}, StructuralIndexer.BYTE_SPECIES.vectorByteSize() / 4), 0); private static byte[] repeat(byte[] array, int n) { diff --git a/src/main/java/org/simdjson/JsonStringScanner.java b/src/main/java/org/simdjson/JsonStringScanner.java index f8e9a12..6d856ac 100644 --- a/src/main/java/org/simdjson/JsonStringScanner.java +++ b/src/main/java/org/simdjson/JsonStringScanner.java @@ -14,8 +14,8 @@ class JsonStringScanner { private long prevEscaped = 0; JsonStringScanner() { - this.backslashMask = ByteVector.broadcast(StructuralIndexer.SPECIES, (byte) '\\'); - this.quoteMask = ByteVector.broadcast(StructuralIndexer.SPECIES, (byte) '"'); + this.backslashMask = ByteVector.broadcast(StructuralIndexer.BYTE_SPECIES, (byte) '\\'); + this.quoteMask = ByteVector.broadcast(StructuralIndexer.BYTE_SPECIES, (byte) '"'); } JsonStringBlock next(ByteVector chunk0) { diff --git a/src/main/java/org/simdjson/StringParser.java b/src/main/java/org/simdjson/StringParser.java index 074a3db..11fb7fd 100644 --- a/src/main/java/org/simdjson/StringParser.java +++ b/src/main/java/org/simdjson/StringParser.java @@ -10,7 +10,7 @@ class StringParser { private static final byte BACKSLASH = '\\'; private static final byte QUOTE = '"'; - private static final int BYTES_PROCESSED = StructuralIndexer.SPECIES.vectorByteSize(); + private static final int BYTES_PROCESSED = StructuralIndexer.BYTE_SPECIES.vectorByteSize(); private static final int MIN_HIGH_SURROGATE = 0xD800; private static final int MAX_HIGH_SURROGATE = 0xDBFF; private static final int MIN_LOW_SURROGATE = 0xDC00; @@ -31,7 +31,7 @@ void parseString(byte[] buffer, int idx) { int src = idx + 1; int dst = stringBufferIdx + Integer.BYTES; while (true) { - ByteVector srcVec = ByteVector.fromArray(StructuralIndexer.SPECIES, buffer, src); + ByteVector srcVec = ByteVector.fromArray(StructuralIndexer.BYTE_SPECIES, buffer, src); srcVec.intoArray(stringBuffer, dst); long backslashBits = srcVec.eq(BACKSLASH).toLong(); long quoteBits = srcVec.eq(QUOTE).toLong(); diff --git a/src/main/java/org/simdjson/StructuralIndexer.java b/src/main/java/org/simdjson/StructuralIndexer.java index c0eb4b0..43ec952 100644 --- a/src/main/java/org/simdjson/StructuralIndexer.java +++ b/src/main/java/org/simdjson/StructuralIndexer.java @@ -1,27 +1,43 @@ package org.simdjson; import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.IntVector; +import jdk.incubator.vector.VectorShape; import jdk.incubator.vector.VectorSpecies; -import java.lang.invoke.MethodType; import static jdk.incubator.vector.VectorOperators.UNSIGNED_LE; class StructuralIndexer { - static final VectorSpecies SPECIES; + static final VectorSpecies INT_SPECIES; + static final VectorSpecies BYTE_SPECIES; static final int N_CHUNKS; static { String species = System.getProperty("org.simdjson.species", "preferred"); - SPECIES = switch(species) { - case "preferred" -> ByteVector.SPECIES_PREFERRED; - case "512" -> ByteVector.SPECIES_512; - case "256" -> ByteVector.SPECIES_256; + switch (species) { + case "preferred" -> { + BYTE_SPECIES = ByteVector.SPECIES_PREFERRED; + INT_SPECIES = IntVector.SPECIES_PREFERRED; + } + case "512" -> { + BYTE_SPECIES = ByteVector.SPECIES_512; + INT_SPECIES = IntVector.SPECIES_512; + } + case "256" -> { + BYTE_SPECIES = ByteVector.SPECIES_256; + INT_SPECIES = IntVector.SPECIES_256; + } default -> throw new IllegalArgumentException("Unsupported vector species: " + species); - }; - N_CHUNKS = 64 / SPECIES.vectorByteSize(); - if (SPECIES != ByteVector.SPECIES_256 && SPECIES != ByteVector.SPECIES_512) { - throw new IllegalArgumentException("Unsupported vector species: " + SPECIES); + } + N_CHUNKS = 64 / BYTE_SPECIES.vectorByteSize(); + assertSupportForSpecies(BYTE_SPECIES); + assertSupportForSpecies(INT_SPECIES); + } + + private static void assertSupportForSpecies(VectorSpecies species) { + if (species.vectorShape() != VectorShape.S_256_BIT && species.vectorShape() != VectorShape.S_512_BIT) { + throw new IllegalArgumentException("Unsupported vector species: " + species); } } @@ -48,7 +64,7 @@ void step(byte[] buffer, int offset, int blockIndex) { } private void step1(byte[] buffer, int offset, int blockIndex) { - ByteVector chunk0 = ByteVector.fromArray(ByteVector.SPECIES_512, buffer, offset); + ByteVector chunk0 = ByteVector.fromArray(ByteVector.SPECIES_512, buffer, offset); JsonStringBlock strings = stringScanner.next(chunk0); JsonCharacterBlock characters = classifier.classify(chunk0); long unescaped = lteq(chunk0, (byte) 0x1F); @@ -75,7 +91,7 @@ private void finishStep(JsonCharacterBlock characters, JsonStringBlock strings, bitIndexes.write(blockIndex, prevStructurals); prevStructurals = potentialStructuralStart & ~strings.stringTail(); unescapedCharsError |= strings.nonQuoteInsideString(unescaped); - } + } private long lteq(ByteVector chunk0, byte scalar) { long r = chunk0.compare(UNSIGNED_LE, scalar).toLong(); diff --git a/src/main/java/org/simdjson/Utf8Validator.java b/src/main/java/org/simdjson/Utf8Validator.java index 2838d76..e4d9c63 100644 --- a/src/main/java/org/simdjson/Utf8Validator.java +++ b/src/main/java/org/simdjson/Utf8Validator.java @@ -4,11 +4,12 @@ import java.util.Arrays; -public class Utf8Validator { - private static final VectorSpecies VECTOR_SPECIES = ByteVector.SPECIES_256; +class Utf8Validator { + + private static final VectorSpecies VECTOR_SPECIES = StructuralIndexer.BYTE_SPECIES; private static final ByteVector INCOMPLETE_CHECK = getIncompleteCheck(); - private static final VectorShuffle SHIFT_FOUR_BYTES_FORWARD = VectorShuffle.iota(IntVector.SPECIES_256, - IntVector.SPECIES_256.elementSize() - 1, 1, true); + private static final VectorShuffle SHIFT_FOUR_BYTES_FORWARD = VectorShuffle.iota(StructuralIndexer.INT_SPECIES, + StructuralIndexer.INT_SPECIES.elementSize() - 1, 1, true); private static final ByteVector LOW_NIBBLE_MASK = ByteVector.broadcast(VECTOR_SPECIES, 0b0000_1111); private static final ByteVector ALL_ASCII_MASK = ByteVector.broadcast(VECTOR_SPECIES, (byte) 0b1000_0000); @@ -39,7 +40,7 @@ static void validate(byte[] inputBytes) { errors |= secondCheck.compare(VectorOperators.NE, 0).toLong(); } - previousFourUtf8Bytes = utf8Vector.reinterpretAsInts().lane(IntVector.SPECIES_256.length() - 1); + previousFourUtf8Bytes = utf8Vector.reinterpretAsInts().lane(StructuralIndexer.INT_SPECIES.length() - 1); } // if the input file doesn't align with the vector width, pad the missing bytes with zero diff --git a/src/test/java/org/simdjson/TestUtils.java b/src/test/java/org/simdjson/TestUtils.java index 0fee084..8d63221 100644 --- a/src/test/java/org/simdjson/TestUtils.java +++ b/src/test/java/org/simdjson/TestUtils.java @@ -19,7 +19,7 @@ static String padWithSpaces(String str) { } static ByteVector chunk(String str, int n) { - return ByteVector.fromArray(StructuralIndexer.SPECIES, str.getBytes(UTF_8), n * StructuralIndexer.SPECIES.vectorByteSize()); + return ByteVector.fromArray(StructuralIndexer.BYTE_SPECIES, str.getBytes(UTF_8), n * StructuralIndexer.BYTE_SPECIES.vectorByteSize()); } static byte[] toUtf8(String str) { diff --git a/src/test/java/org/simdjson/Utf8ValidatorTest.java b/src/test/java/org/simdjson/Utf8ValidatorTest.java index b129e86..995323b 100644 --- a/src/test/java/org/simdjson/Utf8ValidatorTest.java +++ b/src/test/java/org/simdjson/Utf8ValidatorTest.java @@ -1,6 +1,5 @@ package org.simdjson; -import jdk.incubator.vector.ByteVector; import jdk.incubator.vector.VectorSpecies; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -8,12 +7,11 @@ import java.io.IOException; import java.util.Arrays; -import java.util.Objects; import static org.assertj.core.api.Assertions.*; class Utf8ValidatorTest { - private static final VectorSpecies VECTOR_SPECIES = StructuralIndexer.SPECIES; + private static final VectorSpecies VECTOR_SPECIES = StructuralIndexer.BYTE_SPECIES; /* ASCII / 1 BYTE TESTS */