From 6ec178edc03818fe75e0aaead8bb12b3c25de120 Mon Sep 17 00:00:00 2001 From: "Mateusz \"Serafin\" Gajewski" Date: Sat, 11 May 2024 17:26:12 +0200 Subject: [PATCH] Move Slice to use MemorySegment --- .../airlift/slice/InputStreamSliceInput.java | 10 +- src/main/java/io/airlift/slice/JvmUtils.java | 82 ---- .../slice/OutputStreamSliceOutput.java | 8 +- src/main/java/io/airlift/slice/SizeOf.java | 22 + src/main/java/io/airlift/slice/Slice.java | 380 +++++------------- src/main/java/io/airlift/slice/SliceUtf8.java | 66 +-- src/main/java/io/airlift/slice/XxHash64.java | 113 +++--- .../io/airlift/slice/MemoryCopyBenchmark.java | 20 +- src/test/java/io/airlift/slice/TestSlice.java | 22 +- .../java/io/airlift/slice/TestSlices.java | 49 --- 10 files changed, 265 insertions(+), 507 deletions(-) delete mode 100644 src/main/java/io/airlift/slice/JvmUtils.java diff --git a/src/main/java/io/airlift/slice/InputStreamSliceInput.java b/src/main/java/io/airlift/slice/InputStreamSliceInput.java index 254a28fa..588efdfe 100644 --- a/src/main/java/io/airlift/slice/InputStreamSliceInput.java +++ b/src/main/java/io/airlift/slice/InputStreamSliceInput.java @@ -110,7 +110,7 @@ public boolean readBoolean() public byte readByte() { ensureAvailable(SIZE_OF_BYTE); - byte v = slice.getByteUnchecked(bufferPosition); + byte v = slice.getByte(bufferPosition); bufferPosition += SIZE_OF_BYTE; return v; } @@ -125,7 +125,7 @@ public int readUnsignedByte() public short readShort() { ensureAvailable(SIZE_OF_SHORT); - short v = slice.getShortUnchecked(bufferPosition); + short v = slice.getShort(bufferPosition); bufferPosition += SIZE_OF_SHORT; return v; } @@ -140,7 +140,7 @@ public int readUnsignedShort() public int readInt() { ensureAvailable(SIZE_OF_INT); - int v = slice.getIntUnchecked(bufferPosition); + int v = slice.getInt(bufferPosition); bufferPosition += SIZE_OF_INT; return v; } @@ -149,7 +149,7 @@ public int readInt() public long readLong() { ensureAvailable(SIZE_OF_LONG); - long v = slice.getLongUnchecked(bufferPosition); + long v = slice.getLong(bufferPosition); bufferPosition += SIZE_OF_LONG; return v; } @@ -174,7 +174,7 @@ public int read() } verify(availableBytes() > 0); - int v = slice.getByteUnchecked(bufferPosition) & 0xFF; + int v = slice.getByte(bufferPosition) & 0xFF; bufferPosition += SIZE_OF_BYTE; return v; } diff --git a/src/main/java/io/airlift/slice/JvmUtils.java b/src/main/java/io/airlift/slice/JvmUtils.java deleted file mode 100644 index 488e28fd..00000000 --- a/src/main/java/io/airlift/slice/JvmUtils.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.airlift.slice; - -import sun.misc.Unsafe; - -import java.lang.reflect.Field; -import java.nio.Buffer; -import java.nio.ByteOrder; - -import static io.airlift.slice.Preconditions.checkArgument; -import static sun.misc.Unsafe.ARRAY_BOOLEAN_INDEX_SCALE; -import static sun.misc.Unsafe.ARRAY_BYTE_INDEX_SCALE; -import static sun.misc.Unsafe.ARRAY_DOUBLE_INDEX_SCALE; -import static sun.misc.Unsafe.ARRAY_FLOAT_INDEX_SCALE; -import static sun.misc.Unsafe.ARRAY_INT_INDEX_SCALE; -import static sun.misc.Unsafe.ARRAY_LONG_INDEX_SCALE; -import static sun.misc.Unsafe.ARRAY_SHORT_INDEX_SCALE; - -final class JvmUtils -{ - static final Unsafe unsafe; - private static final long ADDRESS_OFFSET; - - static { - if (!ByteOrder.LITTLE_ENDIAN.equals(ByteOrder.nativeOrder())) { - throw new UnsupportedOperationException("Slice only supports little endian machines."); - } - - try { - // fetch theUnsafe object - Field field = Unsafe.class.getDeclaredField("theUnsafe"); - field.setAccessible(true); - unsafe = (Unsafe) field.get(null); - if (unsafe == null) { - throw new RuntimeException("Unsafe access not available"); - } - - // verify the stride of arrays matches the width of primitives - assertArrayIndexScale("Boolean", ARRAY_BOOLEAN_INDEX_SCALE, 1); - assertArrayIndexScale("Byte", ARRAY_BYTE_INDEX_SCALE, 1); - assertArrayIndexScale("Short", ARRAY_SHORT_INDEX_SCALE, 2); - assertArrayIndexScale("Int", ARRAY_INT_INDEX_SCALE, 4); - assertArrayIndexScale("Long", ARRAY_LONG_INDEX_SCALE, 8); - assertArrayIndexScale("Float", ARRAY_FLOAT_INDEX_SCALE, 4); - assertArrayIndexScale("Double", ARRAY_DOUBLE_INDEX_SCALE, 8); - - // fetch the address field for direct buffers - ADDRESS_OFFSET = unsafe.objectFieldOffset(Buffer.class.getDeclaredField("address")); - } - catch (ReflectiveOperationException e) { - throw new RuntimeException(e); - } - } - - private static void assertArrayIndexScale(String name, int actualIndexScale, int expectedIndexScale) - { - if (actualIndexScale != expectedIndexScale) { - throw new IllegalStateException(name + " array index scale must be " + expectedIndexScale + ", but is " + actualIndexScale); - } - } - - static long bufferAddress(Buffer buffer) - { - checkArgument(buffer.isDirect(), "buffer is not direct"); - - return unsafe.getLong(buffer, ADDRESS_OFFSET); - } - - private JvmUtils() {} -} diff --git a/src/main/java/io/airlift/slice/OutputStreamSliceOutput.java b/src/main/java/io/airlift/slice/OutputStreamSliceOutput.java index 90dadef1..0dd692bc 100644 --- a/src/main/java/io/airlift/slice/OutputStreamSliceOutput.java +++ b/src/main/java/io/airlift/slice/OutputStreamSliceOutput.java @@ -133,7 +133,7 @@ public boolean isWritable() public void writeByte(int value) { ensureWritableBytes(SIZE_OF_BYTE); - slice.setByteUnchecked(bufferPosition, value); + slice.setByte(bufferPosition, value); bufferPosition += SIZE_OF_BYTE; } @@ -141,7 +141,7 @@ public void writeByte(int value) public void writeShort(int value) { ensureWritableBytes(SIZE_OF_SHORT); - slice.setShortUnchecked(bufferPosition, value); + slice.setShort(bufferPosition, value); bufferPosition += SIZE_OF_SHORT; } @@ -149,7 +149,7 @@ public void writeShort(int value) public void writeInt(int value) { ensureWritableBytes(SIZE_OF_INT); - slice.setIntUnchecked(bufferPosition, value); + slice.setInt(bufferPosition, value); bufferPosition += SIZE_OF_INT; } @@ -157,7 +157,7 @@ public void writeInt(int value) public void writeLong(long value) { ensureWritableBytes(SIZE_OF_LONG); - slice.setLongUnchecked(bufferPosition, value); + slice.setLong(bufferPosition, value); bufferPosition += SIZE_OF_LONG; } diff --git a/src/main/java/io/airlift/slice/SizeOf.java b/src/main/java/io/airlift/slice/SizeOf.java index f32f7885..3affd58b 100644 --- a/src/main/java/io/airlift/slice/SizeOf.java +++ b/src/main/java/io/airlift/slice/SizeOf.java @@ -20,6 +20,7 @@ import org.openjdk.jol.vm.VM; import org.openjdk.jol.vm.VirtualMachine; +import java.lang.foreign.MemorySegment; import java.util.AbstractMap; import java.util.List; import java.util.Map; @@ -79,6 +80,27 @@ public final class SizeOf private static final int SIMPLE_ENTRY_INSTANCE_SIZE = instanceSize(AbstractMap.SimpleEntry.class); + private static final int MEMORY_SEGMENT_INSTANCE_SIZE = instanceSize(MemorySegment.ofArray(new byte[0]).getClass()); + + public static long sizeOf(MemorySegment segment) + { + if (segment.isNative()) { + return MEMORY_SEGMENT_INSTANCE_SIZE; + } + + return MEMORY_SEGMENT_INSTANCE_SIZE + segment.heapBase() // base + .map(value -> switch (value) { + case byte[] byteArray -> sizeOf(byteArray); + case short[] shortArray -> sizeOf(shortArray); + case int[] intArray -> sizeOf(intArray); + case long[] longArray -> sizeOf(longArray); + case float[] floatArray -> sizeOf(floatArray); + case double[] doubleArray -> sizeOf(doubleArray); + default -> throw new UnsupportedOperationException("Unsupported heap type: " + value.getClass()); + }) + .orElseThrow(); + } + public static long sizeOf(boolean[] array) { return (array == null) ? 0 : sizeOfBooleanArray(array.length); diff --git a/src/main/java/io/airlift/slice/Slice.java b/src/main/java/io/airlift/slice/Slice.java index fe7c3889..6fdbaa40 100644 --- a/src/main/java/io/airlift/slice/Slice.java +++ b/src/main/java/io/airlift/slice/Slice.java @@ -13,68 +13,44 @@ */ package io.airlift.slice; -import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; - import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.lang.invoke.VarHandle; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import static io.airlift.slice.JvmUtils.unsafe; import static io.airlift.slice.Preconditions.checkArgument; -import static io.airlift.slice.SizeOf.SIZE_OF_BYTE; -import static io.airlift.slice.SizeOf.SIZE_OF_DOUBLE; -import static io.airlift.slice.SizeOf.SIZE_OF_FLOAT; import static io.airlift.slice.SizeOf.SIZE_OF_INT; import static io.airlift.slice.SizeOf.SIZE_OF_LONG; -import static io.airlift.slice.SizeOf.SIZE_OF_SHORT; import static io.airlift.slice.SizeOf.instanceSize; -import static io.airlift.slice.SizeOf.sizeOf; -import static java.lang.invoke.MethodHandles.byteArrayViewVarHandle; +import static java.lang.Math.toIntExact; import static java.nio.ByteOrder.LITTLE_ENDIAN; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Objects.checkFromIndexSize; import static java.util.Objects.requireNonNull; -import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; -import static sun.misc.Unsafe.ARRAY_DOUBLE_BASE_OFFSET; -import static sun.misc.Unsafe.ARRAY_FLOAT_BASE_OFFSET; -import static sun.misc.Unsafe.ARRAY_INT_BASE_OFFSET; -import static sun.misc.Unsafe.ARRAY_LONG_BASE_OFFSET; -import static sun.misc.Unsafe.ARRAY_SHORT_BASE_OFFSET; public final class Slice implements Comparable { private static final int INSTANCE_SIZE = instanceSize(Slice.class); - private static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.allocate(0); - private static final VarHandle SHORT_HANDLE = byteArrayViewVarHandle(short[].class, LITTLE_ENDIAN); - private static final VarHandle INT_HANDLE = byteArrayViewVarHandle(int[].class, LITTLE_ENDIAN); - private static final VarHandle LONG_HANDLE = byteArrayViewVarHandle(long[].class, LITTLE_ENDIAN); - private static final VarHandle FLOAT_HANDLE = byteArrayViewVarHandle(float[].class, LITTLE_ENDIAN); - private static final VarHandle DOUBLE_HANDLE = byteArrayViewVarHandle(double[].class, LITTLE_ENDIAN); + private static final ValueLayout.OfByte BYTE = ValueLayout.JAVA_BYTE.withOrder(LITTLE_ENDIAN); + private static final ValueLayout.OfShort SHORT = ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(LITTLE_ENDIAN); + private static final ValueLayout.OfInt INT = ValueLayout.JAVA_INT_UNALIGNED.withOrder(LITTLE_ENDIAN); + private static final ValueLayout.OfLong LONG = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(LITTLE_ENDIAN); + private static final ValueLayout.OfFloat FLOAT = ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(LITTLE_ENDIAN); + private static final ValueLayout.OfDouble DOUBLE = ValueLayout.JAVA_DOUBLE_UNALIGNED.withOrder(LITTLE_ENDIAN); + private static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.allocate(0); // Do not move this field above the constants used in the empty constructor static final Slice EMPTY_SLICE = new Slice(); - private final byte[] base; - + private final MemorySegment segment; private final int baseOffset; - /** - * Size of the slice - */ - private final int size; - - /** - * Bytes retained by the slice - */ - private final long retainedSize; - private int hash; /** @@ -84,10 +60,8 @@ private Slice() { // Since this is used to create a constant in this class, be careful to not use // other uninitialized constants. - this.base = new byte[0]; + this.segment = MemorySegment.ofArray(new byte[0]); this.baseOffset = 0; - this.size = 0; - this.retainedSize = INSTANCE_SIZE; } /** @@ -95,14 +69,12 @@ private Slice() */ Slice(byte[] base) { - requireNonNull(base, "base is null"); - if (base.length == 0) { - throw new IllegalArgumentException("Empty array"); - } - this.base = base; - this.baseOffset = 0; - this.size = base.length; - this.retainedSize = INSTANCE_SIZE + sizeOf(base); + this(MemorySegment.ofArray(base)); + } + + Slice(MemorySegment segment) + { + this(segment, 0, toIntExact(segment.byteSize())); } /** @@ -113,34 +85,15 @@ private Slice() */ Slice(byte[] base, int offset, int length) { - requireNonNull(base, "base is null"); - if (base.length == 0) { - throw new IllegalArgumentException("Empty array"); - } - checkFromIndexSize(offset, length, base.length); - - this.base = base; - this.baseOffset = offset; - this.size = length; - this.retainedSize = INSTANCE_SIZE + sizeOf(base); + this(MemorySegment.ofArray(base), offset, length); } - /** - * Creates a slice for directly accessing the base object. - */ - Slice(byte[] base, int baseOffset, int size, long retainedSize) + Slice(MemorySegment segment, int baseOffset, int length) { - requireNonNull(base, "base is null"); - if (base.length == 0) { - throw new IllegalArgumentException("Empty array"); - } - checkFromIndexSize(baseOffset, size, base.length); - - this.base = requireNonNull(base, "base is null"); + requireNonNull(segment, "segment is null"); + checkArgument(segment.byteSize() > 0, "Empty memory segment: " + segment); + this.segment = segment.asSlice(baseOffset, length); this.baseOffset = baseOffset; - this.size = size; - // INSTANCE_SIZE is not included, as the caller is responsible for including it. - this.retainedSize = retainedSize; } /** @@ -148,7 +101,7 @@ private Slice() */ public int length() { - return size; + return toIntExact(segment.byteSize()); } /** @@ -156,7 +109,7 @@ public int length() */ public long getRetainedSize() { - return retainedSize; + return INSTANCE_SIZE + SizeOf.sizeOf(segment); } /** @@ -165,17 +118,21 @@ public long getRetainedSize() */ public boolean isCompact() { - return baseOffset == 0 && size == base.length; + return byteArray().length == segment.byteSize(); } /** * Returns the byte array wrapped by this Slice. Callers should also take care to use {@link Slice#byteArrayOffset()} * since the contents of this Slice may not start at array index 0. */ - @SuppressFBWarnings("EI_EXPOSE_REP") public byte[] byteArray() { - return base; + return (byte[]) segment.heapBase().orElseThrow(); + } + + public MemorySegment toSegment() + { + return segment; } /** @@ -191,7 +148,7 @@ public int byteArrayOffset() */ public void fill(byte value) { - Arrays.fill(base, baseOffset, baseOffset + size, value); + segment.fill(value); } /** @@ -199,12 +156,12 @@ public void fill(byte value) */ public void clear() { - clear(0, size); + clear(0, length()); } public void clear(int offset, int length) { - Arrays.fill(base, baseOffset, baseOffset + size, (byte) 0); + segment.asSlice(offset, length).fill((byte) 0); } /** @@ -215,13 +172,7 @@ public void clear(int offset, int length) */ public byte getByte(int index) { - checkFromIndexSize(index, SIZE_OF_BYTE, length()); - return getByteUnchecked(index); - } - - public byte getByteUnchecked(int index) - { - return base[baseOffset + index]; + return segment.get(BYTE, index); } /** @@ -245,13 +196,7 @@ public short getUnsignedByte(int index) */ public short getShort(int index) { - checkFromIndexSize(index, SIZE_OF_SHORT, length()); - return getShortUnchecked(index); - } - - public short getShortUnchecked(int index) - { - return (short) SHORT_HANDLE.get(base, baseOffset + index); + return segment.get(SHORT, index); } /** @@ -275,13 +220,7 @@ public int getUnsignedShort(int index) */ public int getInt(int index) { - checkFromIndexSize(index, SIZE_OF_INT, length()); - return getIntUnchecked(index); - } - - public int getIntUnchecked(int index) - { - return (int) INT_HANDLE.get(base, baseOffset + index); + return segment.get(INT, index); } /** @@ -305,13 +244,7 @@ public long getUnsignedInt(int index) */ public long getLong(int index) { - checkFromIndexSize(index, SIZE_OF_LONG, length()); - return getLongUnchecked(index); - } - - public long getLongUnchecked(int index) - { - return (long) LONG_HANDLE.get(base, baseOffset + index); + return segment.get(LONG, index); } /** @@ -323,13 +256,7 @@ public long getLongUnchecked(int index) */ public float getFloat(int index) { - checkFromIndexSize(index, SIZE_OF_FLOAT, length()); - return getFloatUnchecked(index); - } - - public float getFloatUnchecked(int index) - { - return (float) FLOAT_HANDLE.get(base, baseOffset + index); + return segment.get(FLOAT, index); } /** @@ -341,13 +268,7 @@ public float getFloatUnchecked(int index) */ public double getDouble(int index) { - checkFromIndexSize(index, SIZE_OF_DOUBLE, length()); - return getDoubleUnchecked(index); - } - - public double getDoubleUnchecked(int index) - { - return (double) DOUBLE_HANDLE.get(base, baseOffset + index); + return segment.get(DOUBLE, index); } /** @@ -377,10 +298,7 @@ public void getBytes(int index, Slice destination) */ public void getBytes(int index, Slice destination, int destinationIndex, int length) { - checkFromIndexSize(destinationIndex, length, destination.length()); - checkFromIndexSize(index, length, length()); - - System.arraycopy(base, baseOffset + index, destination.base, destination.baseOffset + destinationIndex, length); + MemorySegment.copy(segment, index, destination.segment, destinationIndex, length); } /** @@ -410,10 +328,7 @@ public void getBytes(int index, byte[] destination) */ public void getBytes(int index, byte[] destination, int destinationIndex, int length) { - checkFromIndexSize(index, length, length()); - checkFromIndexSize(destinationIndex, length, destination.length); - - System.arraycopy(base, baseOffset + index, destination, destinationIndex, length); + MemorySegment.copy(segment, BYTE, index, destination, destinationIndex, length); } /** @@ -498,10 +413,7 @@ public void getShorts(int index, short[] destination) */ public void getShorts(int index, short[] destination, int destinationIndex, int length) { - checkFromIndexSize(index, length * Short.BYTES, length()); - checkFromIndexSize(destinationIndex, length, destination.length); - - copyFromBase(index, destination, ARRAY_SHORT_BASE_OFFSET + ((long) destinationIndex * Short.BYTES), length * Short.BYTES); + MemorySegment.copy(segment, SHORT, index, destination, destinationIndex, length); } /** @@ -546,10 +458,7 @@ public void getInts(int index, int[] destination) */ public void getInts(int index, int[] destination, int destinationIndex, int length) { - checkFromIndexSize(index, length * Integer.BYTES, length()); - checkFromIndexSize(destinationIndex, length, destination.length); - - copyFromBase(index, destination, ARRAY_INT_BASE_OFFSET + ((long) destinationIndex * Integer.BYTES), length * Integer.BYTES); + MemorySegment.copy(segment, INT, index, destination, destinationIndex, length); } /** @@ -594,10 +503,7 @@ public void getLongs(int index, long[] destination) */ public void getLongs(int index, long[] destination, int destinationIndex, int length) { - checkFromIndexSize(index, length * Long.BYTES, length()); - checkFromIndexSize(destinationIndex, length, destination.length); - - copyFromBase(index, destination, ARRAY_LONG_BASE_OFFSET + ((long) destinationIndex * Long.BYTES), length * Long.BYTES); + MemorySegment.copy(segment, LONG, index, destination, destinationIndex, length); } /** @@ -642,10 +548,7 @@ public void getFloats(int index, float[] destination) */ public void getFloats(int index, float[] destination, int destinationIndex, int length) { - checkFromIndexSize(index, length * Float.BYTES, length()); - checkFromIndexSize(destinationIndex, length, destination.length); - - copyFromBase(index, destination, ARRAY_FLOAT_BASE_OFFSET + ((long) destinationIndex * Float.BYTES), length * Float.BYTES); + MemorySegment.copy(segment, FLOAT, index, destination, destinationIndex, length); } /** @@ -690,10 +593,7 @@ public void getDoubles(int index, double[] destination) */ public void getDoubles(int index, double[] destination, int destinationIndex, int length) { - checkFromIndexSize(index, length * Double.BYTES, length()); - checkFromIndexSize(destinationIndex, length, destination.length); - - copyFromBase(index, destination, ARRAY_DOUBLE_BASE_OFFSET + ((long) destinationIndex * Double.BYTES), length * Double.BYTES); + MemorySegment.copy(segment, DOUBLE, index, destination, destinationIndex, length); } /** @@ -705,13 +605,7 @@ public void getDoubles(int index, double[] destination, int destinationIndex, in */ public void setByte(int index, int value) { - checkFromIndexSize(index, SIZE_OF_BYTE, length()); - setByteUnchecked(index, value); - } - - void setByteUnchecked(int index, int value) - { - base[baseOffset + index] = (byte) (value & 0xFF); + segment.set(BYTE, index, (byte) (value & 0xFF)); } /** @@ -724,13 +618,7 @@ void setByteUnchecked(int index, int value) */ public void setShort(int index, int value) { - checkFromIndexSize(index, SIZE_OF_SHORT, length()); - setShortUnchecked(index, value); - } - - void setShortUnchecked(int index, int value) - { - SHORT_HANDLE.set(base, baseOffset + index, (short) (value & 0xFFFF)); + segment.set(SHORT, index, (short) (value & 0xFFFF)); } /** @@ -742,13 +630,7 @@ void setShortUnchecked(int index, int value) */ public void setInt(int index, int value) { - checkFromIndexSize(index, SIZE_OF_INT, length()); - setIntUnchecked(index, value); - } - - void setIntUnchecked(int index, int value) - { - INT_HANDLE.set(base, baseOffset + index, value); + segment.set(INT, index, value); } /** @@ -760,13 +642,7 @@ void setIntUnchecked(int index, int value) */ public void setLong(int index, long value) { - checkFromIndexSize(index, SIZE_OF_LONG, length()); - setLongUnchecked(index, value); - } - - void setLongUnchecked(int index, long value) - { - LONG_HANDLE.set(base, baseOffset + index, value); + segment.set(LONG, index, value); } /** @@ -778,8 +654,7 @@ void setLongUnchecked(int index, long value) */ public void setFloat(int index, float value) { - checkFromIndexSize(index, SIZE_OF_FLOAT, length()); - FLOAT_HANDLE.set(base, baseOffset + index, value); + segment.set(FLOAT, index, value); } /** @@ -791,8 +666,7 @@ public void setFloat(int index, float value) */ public void setDouble(int index, double value) { - checkFromIndexSize(index, SIZE_OF_DOUBLE, length()); - DOUBLE_HANDLE.set(base, baseOffset + index, value); + segment.set(DOUBLE, index, value); } /** @@ -822,10 +696,7 @@ public void setBytes(int index, Slice source) */ public void setBytes(int index, Slice source, int sourceIndex, int length) { - checkFromIndexSize(index, length, length()); - checkFromIndexSize(sourceIndex, length, source.length()); - - System.arraycopy(source.base, source.baseOffset + sourceIndex, base, baseOffset + index, length); + MemorySegment.copy(source.segment, sourceIndex, segment, index, length); } /** @@ -852,9 +723,7 @@ public void setBytes(int index, byte[] source) */ public void setBytes(int index, byte[] source, int sourceIndex, int length) { - checkFromIndexSize(index, length, length()); - checkFromIndexSize(sourceIndex, length, source.length); - System.arraycopy(source, sourceIndex, base, baseOffset + index, length); + MemorySegment.copy(source, sourceIndex, segment, BYTE, index, length); } /** @@ -904,9 +773,7 @@ public void setShorts(int index, short[] source) */ public void setShorts(int index, short[] source, int sourceIndex, int length) { - checkFromIndexSize(index, length, length()); - checkFromIndexSize(sourceIndex, length, source.length); - copyToBase(index, source, ARRAY_SHORT_BASE_OFFSET + ((long) sourceIndex * Short.BYTES), length * Short.BYTES); + MemorySegment.copy(source, sourceIndex, segment, SHORT, index, length); } /** @@ -933,9 +800,7 @@ public void setInts(int index, int[] source) */ public void setInts(int index, int[] source, int sourceIndex, int length) { - checkFromIndexSize(index, length, length()); - checkFromIndexSize(sourceIndex, length, source.length); - copyToBase(index, source, ARRAY_INT_BASE_OFFSET + ((long) sourceIndex * Integer.BYTES), length * Integer.BYTES); + MemorySegment.copy(source, sourceIndex, segment, INT, index, length); } /** @@ -962,9 +827,7 @@ public void setLongs(int index, long[] source) */ public void setLongs(int index, long[] source, int sourceIndex, int length) { - checkFromIndexSize(index, length, length()); - checkFromIndexSize(sourceIndex, length, source.length); - copyToBase(index, source, ARRAY_LONG_BASE_OFFSET + ((long) sourceIndex * Long.BYTES), length * Long.BYTES); + MemorySegment.copy(source, sourceIndex, segment, LONG, index, length); } /** @@ -991,9 +854,7 @@ public void setFloats(int index, float[] source) */ public void setFloats(int index, float[] source, int sourceIndex, int length) { - checkFromIndexSize(index, length, length()); - checkFromIndexSize(sourceIndex, length, source.length); - copyToBase(index, source, ARRAY_FLOAT_BASE_OFFSET + ((long) sourceIndex * Float.BYTES), length * Float.BYTES); + MemorySegment.copy(source, sourceIndex, segment, FLOAT, index, length); } /** @@ -1020,9 +881,7 @@ public void setDoubles(int index, double[] source) */ public void setDoubles(int index, double[] source, int sourceIndex, int length) { - checkFromIndexSize(index, length, length()); - checkFromIndexSize(sourceIndex, length, source.length); - copyToBase(index, source, ARRAY_DOUBLE_BASE_OFFSET + ((long) sourceIndex * Double.BYTES), length * Double.BYTES); + MemorySegment.copy(source, sourceIndex, segment, DOUBLE, index, length); } /** @@ -1039,7 +898,7 @@ public Slice slice(int index, int length) return Slices.EMPTY_SLICE; } - return new Slice(base, baseOffset + index, length, retainedSize); + return new Slice(byteArray(), baseOffset + index, length); } /** @@ -1048,10 +907,12 @@ public Slice slice(int index, int length) */ public Slice copy() { - if (size == 0) { + if (length() == 0) { return Slices.EMPTY_SLICE; } - return new Slice(Arrays.copyOfRange(base, baseOffset, baseOffset + size)); + MemorySegment copy = MemorySegment.ofArray(new byte[length()]); + MemorySegment.copy(segment, 0, copy, 0, length()); + return new Slice(copy); } /** @@ -1060,11 +921,13 @@ public Slice copy() */ public Slice copy(int index, int length) { - checkFromIndexSize(index, length, size); + checkFromIndexSize(index, length, length()); if (length == 0) { return Slices.EMPTY_SLICE; } - return new Slice(Arrays.copyOfRange(base, baseOffset + index, baseOffset + index + length)); + MemorySegment copy = MemorySegment.ofArray(new byte[length]); + MemorySegment.copy(segment, index, copy, 0, length); + return new Slice(copy); } public int indexOfByte(int b) @@ -1075,8 +938,8 @@ public int indexOfByte(int b) public int indexOfByte(byte b) { - for (int i = 0; i < size; i++) { - if (getByteUnchecked(i) == b) { + for (int i = 0; i < length(); i++) { + if (getByte(i) == b) { return i; } } @@ -1100,7 +963,7 @@ public int indexOf(Slice slice) */ public int indexOf(Slice pattern, int offset) { - if (size == 0 || offset >= size || offset < 0) { + if (length() == 0 || offset >= length() || offset < 0) { return -1; } @@ -1109,24 +972,24 @@ public int indexOf(Slice pattern, int offset) } // Do we have enough characters? - if (pattern.length() < SIZE_OF_INT || size < SIZE_OF_LONG) { + if (pattern.length() < SIZE_OF_INT || length() < SIZE_OF_LONG) { return indexOfBruteForce(pattern, offset); } // Using the first four bytes for faster search. We are not using eight bytes for long // because we want more strings to get use of fast search. - int head = pattern.getIntUnchecked(0); + int head = pattern.getInt(0); // Take the first byte of head for faster skipping int firstByteMask = head & 0xff; firstByteMask |= firstByteMask << 8; firstByteMask |= firstByteMask << 16; - int lastValidIndex = size - pattern.length(); + int lastValidIndex = length() - pattern.length(); int index = offset; while (index <= lastValidIndex) { // Read four bytes in sequence - int value = getIntUnchecked(index); + int value = getInt(index); // Compare all bytes of value with the first byte of search data // see https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord @@ -1152,7 +1015,7 @@ public int indexOf(Slice pattern, int offset) int indexOfBruteForce(Slice pattern, int offset) { - if (size == 0 || offset >= size || offset < 0) { + if (length() == 0 || offset >= length() || offset < 0) { return -1; } @@ -1160,12 +1023,12 @@ int indexOfBruteForce(Slice pattern, int offset) return offset; } - byte firstByte = pattern.getByteUnchecked(0); - int lastValidIndex = size - pattern.length(); + byte firstByte = pattern.getByte(0); + int lastValidIndex = length() - pattern.length(); int index = offset; while (true) { // seek to first byte match - while (index < lastValidIndex && getByteUnchecked(index) != firstByte) { + while (index < lastValidIndex && getByte(index) != firstByte) { index++; } if (index > lastValidIndex) { @@ -1194,7 +1057,7 @@ public int compareTo(Slice that) if (this == that) { return 0; } - return compareTo(0, size, that, 0, that.size); + return compareTo(0, length(), that, 0, that.length()); } /** @@ -1208,16 +1071,18 @@ public int compareTo(int offset, int length, Slice that, int otherOffset, int ot return 0; } - checkFromIndexSize(offset, length, length()); - checkFromIndexSize(otherOffset, otherLength, that.length()); - - return Arrays.compareUnsigned( - base, - baseOffset + offset, - baseOffset + offset + length, - that.base, - that.baseOffset + otherOffset, - that.baseOffset + otherOffset + otherLength); + // Find index of the first mismatched byte + long mismatch = MemorySegment.mismatch(segment, offset, offset + length, that.segment, otherOffset, otherOffset + otherLength); + if (mismatch == -1) { + return 0; + } + if (mismatch >= length) { + return -1; + } + if (mismatch >= otherLength) { + return 1; + } + return Byte.compareUnsigned(segment.get(BYTE, offset + mismatch), that.segment.get(BYTE, otherOffset + mismatch)); } /** @@ -1253,7 +1118,7 @@ public int hashCode() return hash; } - hash = hashCode(0, size); + hash = hashCode(0, length()); return hash; } @@ -1280,21 +1145,12 @@ public boolean equals(int offset, int length, Slice that, int otherOffset, int o return true; } - checkFromIndexSize(offset, length, length()); - checkFromIndexSize(otherOffset, otherLength, that.length()); - return equalsUnchecked(offset, that, otherOffset, length); } boolean equalsUnchecked(int offset, Slice that, int otherOffset, int length) { - return Arrays.equals( - base, - baseOffset + offset, - baseOffset + offset + length, - that.base, - that.baseOffset + otherOffset, - that.baseOffset + otherOffset + length); + return MemorySegment.mismatch(segment, offset, offset + length, that.segment, otherOffset, otherOffset + length) == -1; } /** @@ -1340,7 +1196,7 @@ public String toStringUtf8() */ public String toStringAscii() { - return toStringAscii(0, size); + return toStringAscii(0, length()); } public String toStringAscii(int index, int length) @@ -1350,7 +1206,7 @@ public String toStringAscii(int index, int length) return ""; } - return new String(byteArray(), byteArrayOffset() + index, length, StandardCharsets.US_ASCII); + return new String(byteArray(), baseOffset + index, length, StandardCharsets.US_ASCII); } /** @@ -1362,12 +1218,12 @@ public String toString(int index, int length, Charset charset) if (length == 0) { return ""; } - return new String(byteArray(), byteArrayOffset() + index, length, charset); + return new String(byteArray(), baseOffset + index, length, charset); } public ByteBuffer toByteBuffer() { - return toByteBuffer(0, size); + return toByteBuffer(0, length()); } public ByteBuffer toByteBuffer(int index, int length) @@ -1378,7 +1234,7 @@ public ByteBuffer toByteBuffer(int index, int length) return EMPTY_BYTE_BUFFER; } - return ByteBuffer.wrap(byteArray(), byteArrayOffset() + index, length).slice(); + return segment.asByteBuffer().slice(index, length); } /** @@ -1388,40 +1244,12 @@ public ByteBuffer toByteBuffer(int index, int length) public String toString() { StringBuilder builder = new StringBuilder("Slice{"); - builder.append("base=").append(identityToString(base)).append(", "); + builder.append("memorySegment=") + .append("{@heap:").append(segment.heapBase().orElseThrow()).append("}") + .append(", "); builder.append("baseOffset=").append(baseOffset); builder.append(", length=").append(length()); builder.append('}'); return builder.toString(); } - - private static String identityToString(Object o) - { - if (o == null) { - return null; - } - return o.getClass().getName() + "@" + Integer.toHexString(System.identityHashCode(o)); - } - - private void copyFromBase(int index, Object dest, long destAddress, int length) - { - int baseAddress = ARRAY_BYTE_BASE_OFFSET + baseOffset + index; - // The Unsafe Javadoc specifies that the transfer size is 8 iff length % 8 == 0 - // so ensure that we copy big chunks whenever possible, even at the expense of two separate copy operations - // todo the optimization only works if the baseOffset is is a multiple of 8 for both src and dest - int bytesToCopy = length - (length % 8); - unsafe.copyMemory(base, baseAddress, dest, destAddress, bytesToCopy); - unsafe.copyMemory(base, baseAddress + bytesToCopy, dest, destAddress + bytesToCopy, length - bytesToCopy); - } - - private void copyToBase(int index, Object src, long srcAddress, int length) - { - int baseAddress = ARRAY_BYTE_BASE_OFFSET + baseOffset + index; - // The Unsafe Javadoc specifies that the transfer size is 8 iff length % 8 == 0 - // so ensure that we copy big chunks whenever possible, even at the expense of two separate copy operations - // todo the optimization only works if the baseOffset is is a multiple of 8 for both src and dest - int bytesToCopy = length - (length % 8); - unsafe.copyMemory(src, srcAddress, base, baseAddress, bytesToCopy); - unsafe.copyMemory(src, srcAddress + bytesToCopy, base, baseAddress + bytesToCopy, length - bytesToCopy); - } } diff --git a/src/main/java/io/airlift/slice/SliceUtf8.java b/src/main/java/io/airlift/slice/SliceUtf8.java index 8ffe408b..9c38eed7 100644 --- a/src/main/java/io/airlift/slice/SliceUtf8.java +++ b/src/main/java/io/airlift/slice/SliceUtf8.java @@ -72,13 +72,13 @@ public static boolean isAscii(Slice utf8) // Length rounded to 8 bytes int length8 = length & 0x7FFF_FFF8; for (; offset < length8; offset += 8) { - if ((utf8.getLongUnchecked(offset) & TOP_MASK64) != 0) { + if ((utf8.getLong(offset) & TOP_MASK64) != 0) { return false; } } // Enough bytes left for 32 bits? if (offset + 4 < length) { - if ((utf8.getIntUnchecked(offset) & TOP_MASK32) != 0) { + if ((utf8.getInt(offset) & TOP_MASK32) != 0) { return false; } @@ -86,7 +86,7 @@ public static boolean isAscii(Slice utf8) } // Do the rest one by one for (; offset < length; offset++) { - if ((utf8.getByteUnchecked(offset) & 0x80) != 0) { + if ((utf8.getByte(offset) & 0x80) != 0) { return false; } } @@ -125,19 +125,19 @@ public static int countCodePoints(Slice utf8, int offset, int length) int length8 = length & 0x7FFF_FFF8; for (; offset < length8; offset += 8) { // Count bytes which are NOT the start of a code point - continuationBytesCount += countContinuationBytes(utf8.getLongUnchecked(offset)); + continuationBytesCount += countContinuationBytes(utf8.getLong(offset)); } // Enough bytes left for 32 bits? if (offset + 4 < length) { // Count bytes which are NOT the start of a code point - continuationBytesCount += countContinuationBytes(utf8.getIntUnchecked(offset)); + continuationBytesCount += countContinuationBytes(utf8.getInt(offset)); offset += 4; } // Do the rest one by one for (; offset < length; offset++) { // Count bytes which are NOT the start of a code point - continuationBytesCount += countContinuationBytes(utf8.getByteUnchecked(offset)); + continuationBytesCount += countContinuationBytes(utf8.getByte(offset)); } verify(continuationBytesCount <= length); @@ -343,20 +343,20 @@ private static Slice translateCodePoints(Slice utf8, int[] codePointTranslationM private static void copyUtf8SequenceUnsafe(Slice source, int sourcePosition, Slice destination, int destinationPosition, int length) { switch (length) { - case 1 -> destination.setByteUnchecked(destinationPosition, source.getByteUnchecked(sourcePosition)); - case 2 -> destination.setShortUnchecked(destinationPosition, source.getShortUnchecked(sourcePosition)); + case 1 -> destination.setByte(destinationPosition, source.getByte(sourcePosition)); + case 2 -> destination.setShort(destinationPosition, source.getShort(sourcePosition)); case 3 -> { - destination.setShortUnchecked(destinationPosition, source.getShortUnchecked(sourcePosition)); - destination.setByteUnchecked(destinationPosition + 2, source.getByteUnchecked(sourcePosition + 2)); + destination.setShort(destinationPosition, source.getShort(sourcePosition)); + destination.setByte(destinationPosition + 2, source.getByte(sourcePosition + 2)); } - case 4 -> destination.setIntUnchecked(destinationPosition, source.getIntUnchecked(sourcePosition)); + case 4 -> destination.setInt(destinationPosition, source.getInt(sourcePosition)); case 5 -> { - destination.setIntUnchecked(destinationPosition, source.getIntUnchecked(sourcePosition)); - destination.setByteUnchecked(destinationPosition + 4, source.getByteUnchecked(sourcePosition + 4)); + destination.setInt(destinationPosition, source.getInt(sourcePosition)); + destination.setByte(destinationPosition + 4, source.getByte(sourcePosition + 4)); } case 6 -> { - destination.setIntUnchecked(destinationPosition, source.getIntUnchecked(sourcePosition)); - destination.setShortUnchecked(destinationPosition + 4, source.getShortUnchecked(sourcePosition + 4)); + destination.setInt(destinationPosition, source.getInt(sourcePosition)); + destination.setShort(destinationPosition + 4, source.getShort(sourcePosition + 4)); } default -> throw new IllegalStateException("Invalid code point length " + length); } @@ -635,7 +635,7 @@ public static int tryGetCodePointAt(Slice utf8, int position) return -1; } - byte secondByte = utf8.getByteUnchecked(position + 1); + byte secondByte = utf8.getByte(position + 1); if (!isContinuationByte(secondByte)) { return -1; } @@ -654,7 +654,7 @@ public static int tryGetCodePointAt(Slice utf8, int position) return -2; } - byte thirdByte = utf8.getByteUnchecked(position + 2); + byte thirdByte = utf8.getByte(position + 2); if (!isContinuationByte(thirdByte)) { return -2; } @@ -679,7 +679,7 @@ public static int tryGetCodePointAt(Slice utf8, int position) return -3; } - byte forthByte = utf8.getByteUnchecked(position + 3); + byte forthByte = utf8.getByte(position + 3); if (!isContinuationByte(forthByte)) { return -3; } @@ -703,7 +703,7 @@ public static int tryGetCodePointAt(Slice utf8, int position) return -4; } - byte fifthByte = utf8.getByteUnchecked(position + 4); + byte fifthByte = utf8.getByte(position + 4); if (!isContinuationByte(fifthByte)) { return -4; } @@ -719,7 +719,7 @@ public static int tryGetCodePointAt(Slice utf8, int position) return -5; } - byte sixthByte = utf8.getByteUnchecked(position + 5); + byte sixthByte = utf8.getByte(position + 5); if (!isContinuationByte(sixthByte)) { return -5; } @@ -812,7 +812,7 @@ public static int offsetOfCodePoint(Slice utf8, int position, int codePointCount // is only called if there are at least 8 more code points needed while (position < length8 && correctIndex >= position + 8) { // Count bytes which are NOT the start of a code point - correctIndex += countContinuationBytes(utf8.getLongUnchecked(position)); + correctIndex += countContinuationBytes(utf8.getLong(position)); position += 8; } @@ -821,14 +821,14 @@ public static int offsetOfCodePoint(Slice utf8, int position, int codePointCount // While we have enough bytes left and we need at least 4 characters process 4 bytes at once while (position < length4 && correctIndex >= position + 4) { // Count bytes which are NOT the start of a code point - correctIndex += countContinuationBytes(utf8.getIntUnchecked(position)); + correctIndex += countContinuationBytes(utf8.getInt(position)); position += 4; } // Do the rest one by one, always check the last byte to find the end of the code point while (position < utf8.length()) { // Count bytes which are NOT the start of a code point - correctIndex += countContinuationBytes(utf8.getByteUnchecked(position)); + correctIndex += countContinuationBytes(utf8.getByte(position)); if (position == correctIndex) { break; } @@ -866,23 +866,23 @@ public static int lengthOfCodePointSafe(Slice utf8, int position) return -length; } - if (length == 1 || position + 1 >= utf8.length() || !isContinuationByte(utf8.getByteUnchecked(position + 1))) { + if (length == 1 || position + 1 >= utf8.length() || !isContinuationByte(utf8.getByte(position + 1))) { return 1; } - if (length == 2 || position + 2 >= utf8.length() || !isContinuationByte(utf8.getByteUnchecked(position + 2))) { + if (length == 2 || position + 2 >= utf8.length() || !isContinuationByte(utf8.getByte(position + 2))) { return 2; } - if (length == 3 || position + 3 >= utf8.length() || !isContinuationByte(utf8.getByteUnchecked(position + 3))) { + if (length == 3 || position + 3 >= utf8.length() || !isContinuationByte(utf8.getByte(position + 3))) { return 3; } - if (length == 4 || position + 4 >= utf8.length() || !isContinuationByte(utf8.getByteUnchecked(position + 4))) { + if (length == 4 || position + 4 >= utf8.length() || !isContinuationByte(utf8.getByte(position + 4))) { return 4; } - if (length == 5 || position + 5 >= utf8.length() || !isContinuationByte(utf8.getByteUnchecked(position + 5))) { + if (length == 5 || position + 5 >= utf8.length() || !isContinuationByte(utf8.getByte(position + 5))) { return 5; } @@ -989,8 +989,8 @@ public static int getCodePointAt(Slice utf8, int position) throw new InvalidUtf8Exception("UTF-8 sequence truncated"); } return ((unsignedStartByte & 0b0000_1111) << 12) | - ((utf8.getByteUnchecked(position + 1) & 0b0011_1111) << 6) | - (utf8.getByteUnchecked(position + 2) & 0b0011_1111); + ((utf8.getByte(position + 1) & 0b0011_1111) << 6) | + (utf8.getByte(position + 2) & 0b0011_1111); } if (unsignedStartByte < 0xf8) { // 1111_0xxx 10xx_xxxx 10xx_xxxx 10xx_xxxx @@ -998,9 +998,9 @@ public static int getCodePointAt(Slice utf8, int position) throw new InvalidUtf8Exception("UTF-8 sequence truncated"); } return ((unsignedStartByte & 0b0000_0111) << 18) | - ((utf8.getByteUnchecked(position + 1) & 0b0011_1111) << 12) | - ((utf8.getByteUnchecked(position + 2) & 0b0011_1111) << 6) | - (utf8.getByteUnchecked(position + 3) & 0b0011_1111); + ((utf8.getByte(position + 1) & 0b0011_1111) << 12) | + ((utf8.getByte(position + 2) & 0b0011_1111) << 6) | + (utf8.getByte(position + 3) & 0b0011_1111); } // Per RFC3629, UTF-8 is limited to 4 bytes, so more bytes are illegal throw new InvalidUtf8Exception("Illegal start 0x" + toHexString(unsignedStartByte).toUpperCase() + " of code point"); diff --git a/src/main/java/io/airlift/slice/XxHash64.java b/src/main/java/io/airlift/slice/XxHash64.java index 3d0b4eb5..be554bba 100644 --- a/src/main/java/io/airlift/slice/XxHash64.java +++ b/src/main/java/io/airlift/slice/XxHash64.java @@ -15,12 +15,13 @@ import java.io.IOException; import java.io.InputStream; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; -import static io.airlift.slice.JvmUtils.unsafe; import static java.lang.Long.rotateLeft; import static java.lang.Math.min; -import static java.util.Objects.checkFromIndexSize; -import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; +import static java.lang.Math.toIntExact; +import static java.nio.ByteOrder.LITTLE_ENDIAN; public final class XxHash64 { @@ -34,8 +35,12 @@ public final class XxHash64 private final long seed; - private static final long BUFFER_ADDRESS = ARRAY_BYTE_BASE_OFFSET; - private final byte[] buffer = new byte[32]; + private static final ValueLayout.OfByte BYTE = ValueLayout.JAVA_BYTE.withOrder(LITTLE_ENDIAN); + private static final ValueLayout.OfInt INT = ValueLayout.JAVA_INT_UNALIGNED.withOrder(LITTLE_ENDIAN); + private static final ValueLayout.OfLong LONG = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(LITTLE_ENDIAN); + + private final MemorySegment buffer = MemorySegment.ofArray(new byte[32]); + private int bufferSize; private long bodyLength; @@ -66,8 +71,7 @@ public XxHash64 update(byte[] data) public XxHash64 update(byte[] data, int offset, int length) { - checkFromIndexSize(offset, length, data.length); - updateHash(data, ARRAY_BYTE_BASE_OFFSET + offset, length); + updateHash(MemorySegment.ofArray(data), offset, length); return this; } @@ -76,10 +80,20 @@ public XxHash64 update(Slice data) return update(data, 0, data.length()); } + public XxHash64 update(MemorySegment segment) + { + return update(segment, 0, toIntExact(segment.byteSize())); + } + public XxHash64 update(Slice data, int offset, int length) { - checkFromIndexSize(offset, length, data.length()); - updateHash(data.byteArray(), (long) data.byteArrayOffset() + ARRAY_BYTE_BASE_OFFSET + offset, length); + updateHash(data.toSegment(), offset, length); + return this; + } + + public XxHash64 update(MemorySegment data, int offset, int length) + { + updateHash(data, offset, length); return this; } @@ -95,60 +109,58 @@ public long hash() hash += bodyLength + bufferSize; - return updateTail(hash, buffer, BUFFER_ADDRESS, 0, bufferSize); + return updateTail(hash, buffer, 0, 0, bufferSize); } private long computeBody() { long hash = rotateLeft(v1, 1) + rotateLeft(v2, 7) + rotateLeft(v3, 12) + rotateLeft(v4, 18); - hash = update(hash, v1); - hash = update(hash, v2); - hash = update(hash, v3); - hash = update(hash, v4); + hash = (hash ^ mix(0, v1)) * PRIME64_1 + PRIME64_4; + hash = (hash ^ mix(0, v2)) * PRIME64_1 + PRIME64_4; + hash = (hash ^ mix(0, v3)) * PRIME64_1 + PRIME64_4; + hash = (hash ^ mix(0, v4)) * PRIME64_1 + PRIME64_4; return hash; } - private void updateHash(byte[] base, long address, int length) + private void updateHash(MemorySegment base, int offset, int length) { if (bufferSize > 0) { int available = min(32 - bufferSize, length); - - unsafe.copyMemory(base, address, buffer, BUFFER_ADDRESS + bufferSize, available); - + MemorySegment.copy(base, offset, buffer, bufferSize, available); bufferSize += available; - address += available; + offset += available; length -= available; if (bufferSize == 32) { - updateBody(buffer, BUFFER_ADDRESS, bufferSize); + updateBody(buffer, 0, bufferSize); bufferSize = 0; } } if (length >= 32) { - int index = updateBody(base, address, length); - address += index; + int index = updateBody(base, offset, length); + offset += index; length -= index; } if (length > 0) { - unsafe.copyMemory(base, address, buffer, BUFFER_ADDRESS, length); + MemorySegment.copy(base, offset, buffer, 0, length); bufferSize = length; } } - private int updateBody(byte[] base, long address, int length) + private int updateBody(MemorySegment base, long offset, int length) { int remaining = length; while (remaining >= 32) { - v1 = mix(v1, unsafe.getLong(base, address)); - v2 = mix(v2, unsafe.getLong(base, address + 8)); - v3 = mix(v3, unsafe.getLong(base, address + 16)); - v4 = mix(v4, unsafe.getLong(base, address + 24)); + v1 = mix(v1, base.get(LONG, offset)); + v2 = mix(v2, base.get(LONG, offset + 8)); + v3 = mix(v3, base.get(LONG, offset + 16)); + v4 = mix(v4, base.get(LONG, offset + 24)); - address += 32; + offset += 32; remaining -= 32; } @@ -209,14 +221,9 @@ public static long hash(Slice data, int offset, int length) public static long hash(long seed, Slice data, int offset, int length) { - checkFromIndexSize(offset, length, data.length()); - - byte[] base = data.byteArray(); - final long address = (long) data.byteArrayOffset() + ARRAY_BYTE_BASE_OFFSET + offset; - long hash; if (length >= 32) { - hash = updateBody(seed, base, address, length); + hash = updateBody(seed, data.toSegment(), offset, length); } else { hash = seed + PRIME64_5; @@ -228,23 +235,23 @@ public static long hash(long seed, Slice data, int offset, int length) // this is the point up to which updateBody() processed int index = length & 0xFFFFFFE0; - return updateTail(hash, base, address, index, length); + return updateTail(hash, data.toSegment(), offset, index, length); } - private static long updateTail(long hash, byte[] base, long address, int index, int length) + private static long updateTail(long hash, MemorySegment base, int offset, int index, int length) { while (index <= length - 8) { - hash = updateTail(hash, unsafe.getLong(base, address + index)); + hash = updateTail(hash, base.get(LONG, offset + index)); index += 8; } if (index <= length - 4) { - hash = updateTail(hash, unsafe.getInt(base, address + index)); + hash = updateTail(hash, base.get(INT, offset + index)); index += 4; } while (index < length) { - hash = updateTail(hash, unsafe.getByte(base, address + index)); + hash = updateTail(hash, base.get(BYTE, offset + index)); index++; } @@ -253,7 +260,7 @@ private static long updateTail(long hash, byte[] base, long address, int index, return hash; } - private static long updateBody(long seed, byte[] base, long address, int length) + private static long updateBody(long seed, MemorySegment base, long offset, int length) { long v1 = seed + PRIME64_1 + PRIME64_2; long v2 = seed + PRIME64_2; @@ -262,21 +269,21 @@ private static long updateBody(long seed, byte[] base, long address, int length) int remaining = length; while (remaining >= 32) { - v1 = mix(v1, unsafe.getLong(base, address)); - v2 = mix(v2, unsafe.getLong(base, address + 8)); - v3 = mix(v3, unsafe.getLong(base, address + 16)); - v4 = mix(v4, unsafe.getLong(base, address + 24)); + v1 = mix(v1, base.get(LONG, offset)); + v2 = mix(v2, base.get(LONG, offset + 8)); + v3 = mix(v3, base.get(LONG, offset + 16)); + v4 = mix(v4, base.get(LONG, offset + 24)); - address += 32; + offset += 32; remaining -= 32; } long hash = rotateLeft(v1, 1) + rotateLeft(v2, 7) + rotateLeft(v3, 12) + rotateLeft(v4, 18); - hash = update(hash, v1); - hash = update(hash, v2); - hash = update(hash, v3); - hash = update(hash, v4); + hash = (hash ^ mix(0, v1)) * PRIME64_1 + PRIME64_4; + hash = (hash ^ mix(0, v2)) * PRIME64_1 + PRIME64_4; + hash = (hash ^ mix(0, v3)) * PRIME64_1 + PRIME64_4; + hash = (hash ^ mix(0, v4)) * PRIME64_1 + PRIME64_4; return hash; } @@ -286,12 +293,6 @@ private static long mix(long current, long value) return rotateLeft(current + value * PRIME64_2, 31) * PRIME64_1; } - private static long update(long hash, long value) - { - long temp = hash ^ mix(0, value); - return temp * PRIME64_1 + PRIME64_4; - } - private static long updateTail(long hash, long value) { long temp = hash ^ mix(0, value); diff --git a/src/test/java/io/airlift/slice/MemoryCopyBenchmark.java b/src/test/java/io/airlift/slice/MemoryCopyBenchmark.java index 72fa1451..d559e558 100644 --- a/src/test/java/io/airlift/slice/MemoryCopyBenchmark.java +++ b/src/test/java/io/airlift/slice/MemoryCopyBenchmark.java @@ -28,11 +28,12 @@ import org.openjdk.jmh.runner.options.Options; import org.openjdk.jmh.runner.options.OptionsBuilder; import org.openjdk.jmh.runner.options.VerboseMode; +import sun.misc.Unsafe; +import java.lang.reflect.Field; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; -import static io.airlift.slice.JvmUtils.unsafe; import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; @SuppressWarnings("restriction") @@ -42,10 +43,27 @@ @Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS) public class MemoryCopyBenchmark { + private static final Unsafe unsafe; + private static final int PAGE_SIZE = 4 * 1024; private static final int N_PAGES = 256 * 1024; private static final int ALLOC_SIZE = PAGE_SIZE * N_PAGES; + static { + try { + // fetch theUnsafe object + Field field = Unsafe.class.getDeclaredField("theUnsafe"); + field.setAccessible(true); + unsafe = (Unsafe) field.get(null); + if (unsafe == null) { + throw new RuntimeException("Unsafe access not available"); + } + } + catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + } + @State(Scope.Thread) public static class Buffers { diff --git a/src/test/java/io/airlift/slice/TestSlice.java b/src/test/java/io/airlift/slice/TestSlice.java index e53ee939..8aea81d9 100644 --- a/src/test/java/io/airlift/slice/TestSlice.java +++ b/src/test/java/io/airlift/slice/TestSlice.java @@ -17,6 +17,7 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; +import java.lang.foreign.MemorySegment; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; @@ -61,6 +62,24 @@ public void testFillAndClear() } } + @Test + public void testSlicing() + { + Slice slice = Slices.utf8Slice("ala ma kota"); + + Slice subSlice = slice.slice(4, slice.length() - 4); + assertThat(subSlice).isEqualTo(utf8Slice("ma kota")); + assertThat(subSlice.byteArray()).isEqualTo(slice.byteArray()); + + Slice subSubSlice = subSlice.slice(3, subSlice.length() - 3); + assertThat(subSubSlice).isEqualTo(utf8Slice("kota")); + assertThat(subSubSlice.byteArray()).isEqualTo(subSlice.byteArray()); + + Slice subSubSubSlice = subSubSlice.slice(3, subSubSlice.length() - 3); + assertThat(subSubSubSlice).isEqualTo(utf8Slice("a")); + assertThat(subSubSubSlice.byteArray()).isEqualTo(subSubSlice.byteArray()); + } + @Test public void testEqualsHashCodeCompare() { @@ -763,7 +782,8 @@ private static void assertBytesStreams(Slice slice, int index) public void testRetainedSize() throws Exception { - int sliceInstanceSize = instanceSize(Slice.class); + MemorySegment heapAllocatedSegment = MemorySegment.ofArray(new byte[0]); + int sliceInstanceSize = instanceSize(Slice.class) + instanceSize(heapAllocatedSegment.getClass()); Slice slice = Slices.allocate(10); assertThat(slice.getRetainedSize()).isEqualTo(sizeOfByteArray(10) + sliceInstanceSize); assertThat(slice.length()).isEqualTo(10); diff --git a/src/test/java/io/airlift/slice/TestSlices.java b/src/test/java/io/airlift/slice/TestSlices.java index d1e978ee..30b3bbdd 100644 --- a/src/test/java/io/airlift/slice/TestSlices.java +++ b/src/test/java/io/airlift/slice/TestSlices.java @@ -15,12 +15,9 @@ import org.junit.jupiter.api.Test; -import java.nio.ByteBuffer; import java.util.Random; import static io.airlift.slice.SizeOf.SIZE_OF_BYTE; -import static io.airlift.slice.SizeOf.instanceSize; -import static io.airlift.slice.SizeOf.sizeOf; import static io.airlift.slice.Slices.EMPTY_SLICE; import static io.airlift.slice.Slices.MAX_ARRAY_SIZE; import static io.airlift.slice.Slices.SLICE_ALLOC_THRESHOLD; @@ -28,7 +25,6 @@ import static io.airlift.slice.Slices.allocate; import static io.airlift.slice.Slices.ensureSize; import static io.airlift.slice.Slices.wrappedBuffer; -import static io.airlift.slice.Slices.wrappedHeapBuffer; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -42,51 +38,6 @@ public void testEmptySlice() assertThat(EMPTY_SLICE.byteArrayOffset()).isEqualTo(0); } - @Test - public void testWrapHeapBuffer() - { - ByteBuffer buffer = ByteBuffer.allocate(50); - // initialize buffer - for (int i = 0; i < 50; i++) { - buffer.put((byte) i); - } - - // test empty buffer - assertThat(wrappedHeapBuffer(buffer)).isSameAs(EMPTY_SLICE); - - // test full buffer - buffer.rewind(); - Slice slice = wrappedHeapBuffer(buffer); - assertThat(slice.length()).isEqualTo(50); - for (int i = 0; i < 50; i++) { - assertThat(slice.getByte(i)).isEqualTo((byte) i); - } - - // test limited buffer - buffer.position(10).limit(30); - slice = wrappedHeapBuffer(buffer); - assertThat(slice.length()).isEqualTo(20); - for (int i = 0; i < 20; i++) { - assertThat(slice.getByte(i)).isEqualTo((byte) (i + 10)); - } - - // test limited buffer after slicing - buffer = buffer.slice(); - slice = wrappedHeapBuffer(buffer); - assertThat(slice.length()).isEqualTo(20); - for (int i = 0; i < 20; i++) { - assertThat(slice.getByte(i)).isEqualTo((byte) (i + 10)); - } - } - - @Test - public void testWrapHeapBufferRetainedSize() - { - ByteBuffer heapByteBuffer = ByteBuffer.allocate(50); - Slice slice = wrappedHeapBuffer(heapByteBuffer); - assertThat(slice.getRetainedSize()).isEqualTo(instanceSize(Slice.class) + sizeOf(heapByteBuffer.array())); - } - @Test public void testWrapByteArray() {