From 2e0e2366bb49ccb7f6d11a5a8036d600c09c2546 Mon Sep 17 00:00:00 2001 From: Jingsong Lee Date: Tue, 5 Mar 2024 10:40:35 +0800 Subject: [PATCH] [core] Introduce userDefineSeqComparator for MergeSorter (#2936) --- .../paimon/codegen/CodeGeneratorImpl.java | 18 +-- .../apache/paimon/codegen/CodeGenerator.java | 16 +- .../apache/paimon/utils/FieldsComparator.java | 29 ++++ .../apache/paimon/data/BinaryStringTest.java | 78 +++++----- .../apache/paimon/codegen/CodeGenUtils.java | 21 ++- .../crosspartition/GlobalIndexAssigner.java | 3 +- .../apache/paimon/lookup/RocksDBState.java | 2 +- .../apache/paimon/mergetree/MergeSorter.java | 33 ++-- .../paimon/mergetree/MergeTreeReaders.java | 8 +- .../mergetree/SortBufferWriteBuffer.java | 7 +- .../compact/ChangelogMergeTreeRewriter.java | 1 + .../mergetree/compact/SortMergeReader.java | 8 +- .../compact/SortMergeReaderWithLoserTree.java | 19 ++- .../compact/SortMergeReaderWithMinHeap.java | 10 ++ .../apache/paimon/operation/DiffReader.java | 3 + .../operation/KeyValueFileStoreRead.java | 2 + .../paimon/sort/BinaryExternalSortBuffer.java | 10 +- .../paimon/utils/KeyComparatorSupplier.java | 6 +- .../paimon/mergetree/MergeSorterTest.java | 144 +++++++++++++----- .../compact/SortMergeReaderTestBase.java | 1 + .../paimon/flink/sorter/SortOperator.java | 4 +- 21 files changed, 294 insertions(+), 129 deletions(-) create mode 100644 paimon-common/src/main/java/org/apache/paimon/utils/FieldsComparator.java diff --git a/paimon-codegen/src/main/java/org/apache/paimon/codegen/CodeGeneratorImpl.java b/paimon-codegen/src/main/java/org/apache/paimon/codegen/CodeGeneratorImpl.java index 8cb44666434e..c4189542219e 100644 --- a/paimon-codegen/src/main/java/org/apache/paimon/codegen/CodeGeneratorImpl.java +++ b/paimon-codegen/src/main/java/org/apache/paimon/codegen/CodeGeneratorImpl.java @@ -37,20 +37,20 @@ public GeneratedClass generateProjection( @Override public GeneratedClass generateNormalizedKeyComputer( - List fieldTypes, String name) { + List inputTypes, int[] sortFields, String name) { return new SortCodeGenerator( - RowType.builder().fields(fieldTypes).build(), - getAscendingSortSpec(fieldTypes.size())) + RowType.builder().fields(inputTypes).build(), + getAscendingSortSpec(sortFields)) .generateNormalizedKeyComputer(name); } @Override public GeneratedClass generateRecordComparator( - List fieldTypes, String name) { + List inputTypes, int[] sortFields, String name) { return ComparatorCodeGenerator.gen( name, - RowType.builder().fields(fieldTypes).build(), - getAscendingSortSpec(fieldTypes.size())); + RowType.builder().fields(inputTypes).build(), + getAscendingSortSpec(sortFields)); } /** Generate a {@link RecordEqualiser}. */ @@ -61,10 +61,10 @@ public GeneratedClass generateRecordEqualiser( .generateRecordEqualiser(name); } - private SortSpec getAscendingSortSpec(int numFields) { + private SortSpec getAscendingSortSpec(int[] sortFields) { SortSpec.SortSpecBuilder builder = SortSpec.builder(); - for (int i = 0; i < numFields; i++) { - builder.addField(i, true, false); + for (int sortField : sortFields) { + builder.addField(sortField, true, false); } return builder.build(); } diff --git a/paimon-common/src/main/java/org/apache/paimon/codegen/CodeGenerator.java b/paimon-common/src/main/java/org/apache/paimon/codegen/CodeGenerator.java index 5ee64a0eccd6..842098dbcbed 100644 --- a/paimon-common/src/main/java/org/apache/paimon/codegen/CodeGenerator.java +++ b/paimon-common/src/main/java/org/apache/paimon/codegen/CodeGenerator.java @@ -32,22 +32,22 @@ GeneratedClass generateProjection( /** * Generate a {@link NormalizedKeyComputer}. * - * @param fieldTypes Both the input row field types and the sort key field types. Records are - * compared by the first field, then the second field, then the third field and so on. All - * fields are compared in ascending order. + * @param inputTypes input types. + * @param sortFields the sort key fields. Records are compared by the first field, then the + * second field, then the third field and so on. All fields are compared in ascending order. */ GeneratedClass generateNormalizedKeyComputer( - List fieldTypes, String name); + List inputTypes, int[] sortFields, String name); /** * Generate a {@link RecordComparator}. * - * @param fieldTypes Both the input row field types and the sort key field types. Records are * - * compared by the first field, then the second field, then the third field and so on. All * - * fields are compared in ascending order. + * @param inputTypes input types. + * @param sortFields the sort key fields. Records are compared by the first field, then the + * second field, then the third field and so on. All fields are compared in ascending order. */ GeneratedClass generateRecordComparator( - List fieldTypes, String name); + List inputTypes, int[] sortFields, String name); /** * Generate a {@link RecordEqualiser}. diff --git a/paimon-common/src/main/java/org/apache/paimon/utils/FieldsComparator.java b/paimon-common/src/main/java/org/apache/paimon/utils/FieldsComparator.java new file mode 100644 index 000000000000..85140d26b457 --- /dev/null +++ b/paimon-common/src/main/java/org/apache/paimon/utils/FieldsComparator.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.utils; + +import org.apache.paimon.data.InternalRow; + +import java.util.Comparator; + +/** A {@link Comparator} to compare fields for {@link InternalRow}. */ +public interface FieldsComparator extends Comparator { + + int[] compareFields(); +} diff --git a/paimon-common/src/test/java/org/apache/paimon/data/BinaryStringTest.java b/paimon-common/src/test/java/org/apache/paimon/data/BinaryStringTest.java index a5052c7303b4..c574ae8beae7 100644 --- a/paimon-common/src/test/java/org/apache/paimon/data/BinaryStringTest.java +++ b/paimon-common/src/test/java/org/apache/paimon/data/BinaryStringTest.java @@ -26,7 +26,6 @@ import org.junit.jupiter.api.TestTemplate; import org.junit.jupiter.api.extension.ExtendWith; -import java.io.UnsupportedEncodingException; import java.math.BigDecimal; import java.nio.charset.StandardCharsets; import java.util.Arrays; @@ -34,6 +33,7 @@ import java.util.Random; import static java.nio.charset.StandardCharsets.UTF_8; +import static org.apache.paimon.data.BinaryString.EMPTY_UTF8; import static org.apache.paimon.data.BinaryString.blankString; import static org.apache.paimon.data.BinaryString.fromBytes; import static org.apache.paimon.utils.DecimalUtils.castFrom; @@ -47,14 +47,13 @@ @ExtendWith(ParameterizedTestExtension.class) public class BinaryStringTest { - private BinaryString empty = fromString(""); - private final Mode mode; public BinaryStringTest(Mode mode) { this.mode = mode; } + @SuppressWarnings("unused") @Parameters(name = "{0}") public static List getVarSeg() { return Arrays.asList(Mode.ONE_SEG, Mode.MULTI_SEGS, Mode.STRING, Mode.RANDOM); @@ -78,7 +77,7 @@ private BinaryString fromString(String str) { mode = Mode.ONE_SEG; } else if (rnd == 1) { mode = Mode.MULTI_SEGS; - } else if (rnd == 2) { + } else { mode = Mode.STRING; } } @@ -143,10 +142,10 @@ public void basicTest() { @TestTemplate public void emptyStringTest() { - assertThat(fromString("")).isEqualTo(empty); - assertThat(fromBytes(new byte[0])).isEqualTo(empty); - assertThat(empty.numChars()).isEqualTo(0); - assertThat(empty.getSizeInBytes()).isEqualTo(0); + assertThat(fromString("")).isEqualTo(EMPTY_UTF8); + assertThat(fromBytes(new byte[0])).isEqualTo(EMPTY_UTF8); + assertThat(EMPTY_UTF8.numChars()).isEqualTo(0); + assertThat(EMPTY_UTF8.getSizeInBytes()).isEqualTo(0); } @TestTemplate @@ -224,7 +223,7 @@ public void testMultiSegments() { @TestTemplate public void contains() { - assertThat(empty.contains(empty)).isTrue(); + assertThat(EMPTY_UTF8.contains(EMPTY_UTF8)).isTrue(); assertThat(fromString("hello").contains(fromString("ello"))).isTrue(); assertThat(fromString("hello").contains(fromString("vello"))).isFalse(); assertThat(fromString("hello").contains(fromString("hellooo"))).isFalse(); @@ -235,7 +234,7 @@ public void contains() { @TestTemplate public void startsWith() { - assertThat(empty.startsWith(empty)).isTrue(); + assertThat(EMPTY_UTF8.startsWith(EMPTY_UTF8)).isTrue(); assertThat(fromString("hello").startsWith(fromString("hell"))).isTrue(); assertThat(fromString("hello").startsWith(fromString("ell"))).isFalse(); assertThat(fromString("hello").startsWith(fromString("hellooo"))).isFalse(); @@ -246,7 +245,7 @@ public void startsWith() { @TestTemplate public void endsWith() { - assertThat(empty.endsWith(empty)).isTrue(); + assertThat(EMPTY_UTF8.endsWith(EMPTY_UTF8)).isTrue(); assertThat(fromString("hello").endsWith(fromString("ello"))).isTrue(); assertThat(fromString("hello").endsWith(fromString("ellov"))).isFalse(); assertThat(fromString("hello").endsWith(fromString("hhhello"))).isFalse(); @@ -257,7 +256,7 @@ public void endsWith() { @TestTemplate public void substring() { - assertThat(fromString("hello").substring(0, 0)).isEqualTo(empty); + assertThat(fromString("hello").substring(0, 0)).isEqualTo(EMPTY_UTF8); assertThat(fromString("hello").substring(1, 3)).isEqualTo(fromString("el")); assertThat(fromString("数据砖头").substring(0, 1)).isEqualTo(fromString("数")); assertThat(fromString("数据砖头").substring(1, 3)).isEqualTo(fromString("据砖")); @@ -267,9 +266,9 @@ public void substring() { @TestTemplate public void indexOf() { - assertThat(empty.indexOf(empty, 0)).isEqualTo(0); - assertThat(empty.indexOf(fromString("l"), 0)).isEqualTo(-1); - assertThat(fromString("hello").indexOf(empty, 0)).isEqualTo(0); + assertThat(EMPTY_UTF8.indexOf(EMPTY_UTF8, 0)).isEqualTo(0); + assertThat(EMPTY_UTF8.indexOf(fromString("l"), 0)).isEqualTo(-1); + assertThat(fromString("hello").indexOf(EMPTY_UTF8, 0)).isEqualTo(0); assertThat(fromString("hello").indexOf(fromString("l"), 0)).isEqualTo(2); assertThat(fromString("hello").indexOf(fromString("l"), 3)).isEqualTo(3); assertThat(fromString("hello").indexOf(fromString("a"), 0)).isEqualTo(-1); @@ -300,24 +299,21 @@ public void testToUpperLowerCase() { writer.writeString(5, BinaryString.fromString("!@#$%^*")); writer.complete(); - assertThat(((BinaryString) row.getString(0)).toUpperCase()).isEqualTo(fromString("A")); - assertThat(((BinaryString) row.getString(1)).toUpperCase()).isEqualTo(fromString("我是中国人")); - assertThat(((BinaryString) row.getString(1)).toLowerCase()).isEqualTo(fromString("我是中国人")); - assertThat(((BinaryString) row.getString(3)).toUpperCase()) - .isEqualTo(fromString("ABCDEFG")); - assertThat(((BinaryString) row.getString(3)).toLowerCase()) - .isEqualTo(fromString("abcdefg")); - assertThat(((BinaryString) row.getString(5)).toUpperCase()) - .isEqualTo(fromString("!@#$%^*")); - assertThat(((BinaryString) row.getString(5)).toLowerCase()) - .isEqualTo(fromString("!@#$%^*")); + assertThat(row.getString(0).toUpperCase()).isEqualTo(fromString("A")); + assertThat(row.getString(1).toUpperCase()).isEqualTo(fromString("我是中国人")); + assertThat(row.getString(1).toLowerCase()).isEqualTo(fromString("我是中国人")); + assertThat(row.getString(3).toUpperCase()).isEqualTo(fromString("ABCDEFG")); + assertThat(row.getString(3).toLowerCase()).isEqualTo(fromString("abcdefg")); + assertThat(row.getString(5).toUpperCase()).isEqualTo(fromString("!@#$%^*")); + assertThat(row.getString(5).toLowerCase()).isEqualTo(fromString("!@#$%^*")); } @TestTemplate - public void testcastFrom() { + public void testCastFrom() { class DecimalTestData { - private String str; - private int precision, scale; + private final String str; + private final int precision; + private final int scale; private DecimalTestData(String str, int precision, int scale) { this.str = str; @@ -391,7 +387,7 @@ private DecimalTestData(String str, int precision, int scale) { writer.complete(); for (int i = 0; i < data.length; i++) { DecimalTestData d = data[i]; - assertThat(castFrom((BinaryString) row.getString(i), d.precision, d.scale)) + assertThat(castFrom(row.getString(i), d.precision, d.scale)) .isEqualTo(Decimal.fromBigDecimal(new BigDecimal(d.str), d.precision, d.scale)); } } @@ -407,21 +403,21 @@ public void testEmptyString() { str3 = BinaryString.fromAddress(segments, 15, 0); } - assertThat(BinaryString.EMPTY_UTF8.compareTo(str2)).isLessThan(0); - assertThat(str2.compareTo(BinaryString.EMPTY_UTF8)).isGreaterThan(0); + assertThat(EMPTY_UTF8.compareTo(str2)).isLessThan(0); + assertThat(str2.compareTo(EMPTY_UTF8)).isGreaterThan(0); - assertThat(BinaryString.EMPTY_UTF8.compareTo(str3)).isEqualTo(0); - assertThat(str3.compareTo(BinaryString.EMPTY_UTF8)).isEqualTo(0); + assertThat(EMPTY_UTF8.compareTo(str3)).isEqualTo(0); + assertThat(str3.compareTo(EMPTY_UTF8)).isEqualTo(0); - assertThat(str2).isNotEqualTo(BinaryString.EMPTY_UTF8); - assertThat(BinaryString.EMPTY_UTF8).isNotEqualTo(str2); + assertThat(str2).isNotEqualTo(EMPTY_UTF8); + assertThat(EMPTY_UTF8).isNotEqualTo(str2); - assertThat(str3).isEqualTo(BinaryString.EMPTY_UTF8); - assertThat(BinaryString.EMPTY_UTF8).isEqualTo(str3); + assertThat(str3).isEqualTo(EMPTY_UTF8); + assertThat(EMPTY_UTF8).isEqualTo(str3); } @TestTemplate - public void testEncodeWithIllegalCharacter() throws UnsupportedEncodingException { + public void testEncodeWithIllegalCharacter() { // Tis char array has some illegal character, such as 55357 // the jdk ignores theses character and cast them to '?' @@ -434,11 +430,11 @@ public void testEncodeWithIllegalCharacter() throws UnsupportedEncodingException String str = new String(chars); - assertThat(BinaryString.encodeUTF8(str)).isEqualTo(str.getBytes("UTF-8")); + assertThat(BinaryString.encodeUTF8(str)).isEqualTo(str.getBytes(UTF_8)); } @TestTemplate - public void testDecodeWithIllegalUtf8Bytes() throws UnsupportedEncodingException { + public void testDecodeWithIllegalUtf8Bytes() { // illegal utf-8 bytes byte[] bytes = diff --git a/paimon-core/src/main/java/org/apache/paimon/codegen/CodeGenUtils.java b/paimon-core/src/main/java/org/apache/paimon/codegen/CodeGenUtils.java index 868734b0f804..57abd915a96d 100644 --- a/paimon-core/src/main/java/org/apache/paimon/codegen/CodeGenUtils.java +++ b/paimon-core/src/main/java/org/apache/paimon/codegen/CodeGenUtils.java @@ -23,6 +23,7 @@ import org.apache.paimon.types.RowType; import java.util.List; +import java.util.stream.IntStream; /** Utils for code generations. */ public class CodeGenUtils { @@ -46,15 +47,16 @@ public static Projection newProjection(RowType inputType, int[] mapping) { } public static NormalizedKeyComputer newNormalizedKeyComputer( - List fieldTypes, String name) { + List inputTypes, int[] sortFields, String name) { return CodeGenLoader.getCodeGenerator() - .generateNormalizedKeyComputer(fieldTypes, name) + .generateNormalizedKeyComputer(inputTypes, sortFields, name) .newInstance(CodeGenUtils.class.getClassLoader()); } public static GeneratedClass generateRecordComparator( - List fieldTypes, String name) { - return CodeGenLoader.getCodeGenerator().generateRecordComparator(fieldTypes, name); + List inputTypes, int[] sortFields, String name) { + return CodeGenLoader.getCodeGenerator() + .generateRecordComparator(inputTypes, sortFields, name); } public static GeneratedClass generateRecordEqualiser( @@ -62,8 +64,15 @@ public static GeneratedClass generateRecordEqualiser( return CodeGenLoader.getCodeGenerator().generateRecordEqualiser(fieldTypes, name); } - public static RecordComparator newRecordComparator(List fieldTypes, String name) { - return generateRecordComparator(fieldTypes, name) + public static RecordComparator newRecordComparator( + List inputTypes, int[] sortFields, String name) { + return generateRecordComparator(inputTypes, sortFields, name) + .newInstance(CodeGenUtils.class.getClassLoader()); + } + + public static RecordComparator newRecordComparator(List inputTypes, String name) { + return generateRecordComparator( + inputTypes, IntStream.range(0, inputTypes.size()).toArray(), name) .newInstance(CodeGenUtils.class.getClassLoader()); } } diff --git a/paimon-core/src/main/java/org/apache/paimon/crosspartition/GlobalIndexAssigner.java b/paimon-core/src/main/java/org/apache/paimon/crosspartition/GlobalIndexAssigner.java index d4383d4c2c99..787fc9e0f95c 100644 --- a/paimon-core/src/main/java/org/apache/paimon/crosspartition/GlobalIndexAssigner.java +++ b/paimon-core/src/main/java/org/apache/paimon/crosspartition/GlobalIndexAssigner.java @@ -67,6 +67,7 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.function.BiConsumer; import java.util.function.Function; +import java.util.stream.IntStream; import static org.apache.paimon.lookup.RocksDBOptions.BLOCK_CACHE_SIZE; import static org.apache.paimon.utils.Preconditions.checkArgument; @@ -285,8 +286,8 @@ private void bulkLoadBootstrapRecords() { BinaryExternalSortBuffer keyIdBuffer = BinaryExternalSortBuffer.create( ioManager, - keyWithIdType, keyWithRowType, + IntStream.range(0, keyWithIdType.getFieldCount()).toArray(), coreOptions.writeBufferSize() / 2, coreOptions.pageSize(), coreOptions.localSortMaxNumFileHandles(), diff --git a/paimon-core/src/main/java/org/apache/paimon/lookup/RocksDBState.java b/paimon-core/src/main/java/org/apache/paimon/lookup/RocksDBState.java index 5ffcaf1463cc..2e10acb0e146 100644 --- a/paimon-core/src/main/java/org/apache/paimon/lookup/RocksDBState.java +++ b/paimon-core/src/main/java/org/apache/paimon/lookup/RocksDBState.java @@ -107,8 +107,8 @@ public static BinaryExternalSortBuffer createBulkLoadSorter( IOManager ioManager, CoreOptions options) { return BinaryExternalSortBuffer.create( ioManager, - RowType.of(DataTypes.BYTES()), RowType.of(DataTypes.BYTES(), DataTypes.BYTES()), + new int[] {0}, options.writeBufferSize() / 2, options.pageSize(), options.localSortMaxNumFileHandles(), diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/MergeSorter.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/MergeSorter.java index 26a6d7fc6a21..5e30d16aa440 100644 --- a/paimon-core/src/main/java/org/apache/paimon/mergetree/MergeSorter.java +++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/MergeSorter.java @@ -37,12 +37,11 @@ import org.apache.paimon.sort.SortBuffer; import org.apache.paimon.types.BigIntType; import org.apache.paimon.types.DataField; -import org.apache.paimon.types.DataType; -import org.apache.paimon.types.DataTypes; import org.apache.paimon.types.IntType; import org.apache.paimon.types.RowKind; import org.apache.paimon.types.RowType; import org.apache.paimon.types.TinyIntType; +import org.apache.paimon.utils.FieldsComparator; import org.apache.paimon.utils.IOUtils; import org.apache.paimon.utils.MutableObjectIterator; import org.apache.paimon.utils.OffsetRow; @@ -53,6 +52,7 @@ import java.util.ArrayList; import java.util.Comparator; import java.util.List; +import java.util.stream.IntStream; import static org.apache.paimon.schema.SystemColumns.SEQUENCE_NUMBER; import static org.apache.paimon.schema.SystemColumns.VALUE_KIND; @@ -103,10 +103,12 @@ public void setProjectedValueType(RowType projectedType) { public RecordReader mergeSort( List> lazyReaders, Comparator keyComparator, + @Nullable FieldsComparator userDefinedSeqComparator, MergeFunctionWrapper mergeFunction) throws IOException { if (ioManager != null && lazyReaders.size() > spillThreshold) { - return spillMergeSort(lazyReaders, keyComparator, mergeFunction); + return spillMergeSort( + lazyReaders, keyComparator, userDefinedSeqComparator, mergeFunction); } List> readers = new ArrayList<>(lazyReaders.size()); @@ -121,15 +123,16 @@ public RecordReader mergeSort( } return SortMergeReader.createSortMergeReader( - readers, keyComparator, mergeFunction, sortEngine); + readers, keyComparator, userDefinedSeqComparator, mergeFunction, sortEngine); } private RecordReader spillMergeSort( List> readers, Comparator keyComparator, + @Nullable FieldsComparator userDefinedSeqComparator, MergeFunctionWrapper mergeFunction) throws IOException { - ExternalSorterWithLevel sorter = new ExternalSorterWithLevel(); + ExternalSorterWithLevel sorter = new ExternalSorterWithLevel(userDefinedSeqComparator); ConcatRecordReader.create(readers).forIOEachRemaining(sorter::put); sorter.flushMemory(); @@ -176,15 +179,25 @@ private class ExternalSorterWithLevel { private final SortBuffer buffer; - public ExternalSorterWithLevel() { + public ExternalSorterWithLevel(@Nullable FieldsComparator userDefinedSeqComparator) { if (memoryPool.freePages() < 3) { throw new IllegalArgumentException( "Write buffer requires a minimum of 3 page memory, please increase write buffer memory size."); } - // user key + sequenceNumber - List sortKeyTypes = new ArrayList<>(keyType.getFieldTypes()); - sortKeyTypes.add(new BigIntType(false)); + // key fields + IntStream sortFields = IntStream.range(0, keyType.getFieldCount()); + + // user define sequence fields + if (userDefinedSeqComparator != null) { + IntStream udsFields = + IntStream.of(userDefinedSeqComparator.compareFields()) + .map(operand -> operand + keyType.getFieldCount() + 3); + sortFields = IntStream.concat(sortFields, udsFields); + } + + // sequence field + sortFields = IntStream.concat(sortFields, IntStream.of(keyType.getFieldCount())); // row type List fields = new ArrayList<>(keyType.getFields()); @@ -196,8 +209,8 @@ public ExternalSorterWithLevel() { this.buffer = BinaryExternalSortBuffer.create( ioManager, - DataTypes.ROW(sortKeyTypes.toArray(new DataType[0])), new RowType(fields), + sortFields.toArray(), memoryPool, spillSortMaxNumFiles, compression); diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/MergeTreeReaders.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/MergeTreeReaders.java index 3fe28c073686..186371b81742 100644 --- a/paimon-core/src/main/java/org/apache/paimon/mergetree/MergeTreeReaders.java +++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/MergeTreeReaders.java @@ -28,6 +28,9 @@ import org.apache.paimon.mergetree.compact.MergeFunctionWrapper; import org.apache.paimon.mergetree.compact.ReducerMergeFunctionWrapper; import org.apache.paimon.reader.RecordReader; +import org.apache.paimon.utils.FieldsComparator; + +import javax.annotation.Nullable; import java.io.IOException; import java.util.ArrayList; @@ -55,6 +58,7 @@ public static RecordReader readerForMergeTree( section, readerFactory, userKeyComparator, + null, new ReducerMergeFunctionWrapper(mergeFunction), mergeSorter)); } @@ -69,6 +73,7 @@ public static RecordReader readerForSection( List section, KeyValueFileReaderFactory readerFactory, Comparator userKeyComparator, + @Nullable FieldsComparator userDefinedSeqComparator, MergeFunctionWrapper mergeFunctionWrapper, MergeSorter mergeSorter) throws IOException { @@ -76,7 +81,8 @@ public static RecordReader readerForSection( for (SortedRun run : section) { readers.add(() -> readerForRun(run, readerFactory)); } - return mergeSorter.mergeSort(readers, userKeyComparator, mergeFunctionWrapper); + return mergeSorter.mergeSort( + readers, userKeyComparator, userDefinedSeqComparator, mergeFunctionWrapper); } public static RecordReader readerForRun( diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/SortBufferWriteBuffer.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/SortBufferWriteBuffer.java index 744c4f31af03..ae877f5cb826 100644 --- a/paimon-core/src/main/java/org/apache/paimon/mergetree/SortBufferWriteBuffer.java +++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/SortBufferWriteBuffer.java @@ -48,6 +48,7 @@ import java.util.ArrayList; import java.util.Comparator; import java.util.List; +import java.util.stream.IntStream; /** A {@link WriteBuffer} which stores records in {@link BinaryInMemorySortBuffer}. */ public class SortBufferWriteBuffer implements WriteBuffer { @@ -74,10 +75,12 @@ public SortBufferWriteBuffer( sortKeyTypes.add(new BigIntType(false)); // for sort binary buffer + int[] sortFields = IntStream.range(0, sortKeyTypes.size()).toArray(); NormalizedKeyComputer normalizedKeyComputer = - CodeGenUtils.newNormalizedKeyComputer(sortKeyTypes, "MemTableKeyComputer"); + CodeGenUtils.newNormalizedKeyComputer( + sortKeyTypes, sortFields, "MemTableKeyComputer"); RecordComparator keyComparator = - CodeGenUtils.newRecordComparator(sortKeyTypes, "MemTableComparator"); + CodeGenUtils.newRecordComparator(sortKeyTypes, sortFields, "MemTableComparator"); if (memoryPool.freePages() < 3) { throw new IllegalArgumentException( diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/ChangelogMergeTreeRewriter.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/ChangelogMergeTreeRewriter.java index f056fafaae95..99f5f052d13b 100644 --- a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/ChangelogMergeTreeRewriter.java +++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/ChangelogMergeTreeRewriter.java @@ -112,6 +112,7 @@ private CompactResult rewriteChangelogCompaction( section, readerFactory, keyComparator, + null, createMergeWrapper(outputLevel), mergeSorter)); } diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReader.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReader.java index d545efd35bdf..598ca2aa692c 100644 --- a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReader.java +++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReader.java @@ -22,6 +22,9 @@ import org.apache.paimon.KeyValue; import org.apache.paimon.data.InternalRow; import org.apache.paimon.reader.RecordReader; +import org.apache.paimon.utils.FieldsComparator; + +import javax.annotation.Nullable; import java.util.Comparator; import java.util.List; @@ -38,15 +41,16 @@ public interface SortMergeReader extends RecordReader { static SortMergeReader createSortMergeReader( List> readers, Comparator userKeyComparator, + @Nullable FieldsComparator userDefinedSeqComparator, MergeFunctionWrapper mergeFunctionWrapper, SortEngine sortEngine) { switch (sortEngine) { case MIN_HEAP: return new SortMergeReaderWithMinHeap<>( - readers, userKeyComparator, mergeFunctionWrapper); + readers, userKeyComparator, userDefinedSeqComparator, mergeFunctionWrapper); case LOSER_TREE: return new SortMergeReaderWithLoserTree<>( - readers, userKeyComparator, mergeFunctionWrapper); + readers, userKeyComparator, userDefinedSeqComparator, mergeFunctionWrapper); default: throw new UnsupportedOperationException("Unsupported sort engine: " + sortEngine); } diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReaderWithLoserTree.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReaderWithLoserTree.java index f7a20ee671dc..3ca3d288e05b 100644 --- a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReaderWithLoserTree.java +++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReaderWithLoserTree.java @@ -21,6 +21,7 @@ import org.apache.paimon.KeyValue; import org.apache.paimon.data.InternalRow; import org.apache.paimon.reader.RecordReader; +import org.apache.paimon.utils.FieldsComparator; import org.apache.paimon.utils.Preconditions; import javax.annotation.Nullable; @@ -38,13 +39,29 @@ public class SortMergeReaderWithLoserTree implements SortMergeReader { public SortMergeReaderWithLoserTree( List> readers, Comparator userKeyComparator, + @Nullable FieldsComparator userDefinedSeqComparator, MergeFunctionWrapper mergeFunctionWrapper) { this.mergeFunctionWrapper = mergeFunctionWrapper; this.loserTree = new LoserTree<>( readers, (e1, e2) -> userKeyComparator.compare(e2.key(), e1.key()), - (e1, e2) -> Long.compare(e2.sequenceNumber(), e1.sequenceNumber())); + createSequenceComparator(userDefinedSeqComparator)); + } + + private Comparator createSequenceComparator( + @Nullable FieldsComparator userDefinedSeqComparator) { + if (userDefinedSeqComparator == null) { + return (e1, e2) -> Long.compare(e2.sequenceNumber(), e1.sequenceNumber()); + } + + return (o1, o2) -> { + int result = userDefinedSeqComparator.compare(o2.value(), o1.value()); + if (result != 0) { + return result; + } + return Long.compare(o2.sequenceNumber(), o1.sequenceNumber()); + }; } /** Compared with heapsort, {@link LoserTree} will only produce one batch. */ diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReaderWithMinHeap.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReaderWithMinHeap.java index adca53fdbb68..a78ef334f071 100644 --- a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReaderWithMinHeap.java +++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReaderWithMinHeap.java @@ -21,6 +21,7 @@ import org.apache.paimon.KeyValue; import org.apache.paimon.data.InternalRow; import org.apache.paimon.reader.RecordReader; +import org.apache.paimon.utils.FieldsComparator; import org.apache.paimon.utils.Preconditions; import javax.annotation.Nullable; @@ -44,6 +45,7 @@ public class SortMergeReaderWithMinHeap implements SortMergeReader { public SortMergeReaderWithMinHeap( List> readers, Comparator userKeyComparator, + @Nullable FieldsComparator userDefinedSeqComparator, MergeFunctionWrapper mergeFunctionWrapper) { this.nextBatchReaders = new ArrayList<>(readers); this.userKeyComparator = userKeyComparator; @@ -56,6 +58,14 @@ public SortMergeReaderWithMinHeap( if (result != 0) { return result; } + if (userDefinedSeqComparator != null) { + result = + userDefinedSeqComparator.compare( + e1.kv.value(), e2.kv.value()); + if (result != 0) { + return result; + } + } return Long.compare(e1.kv.sequenceNumber(), e2.kv.sequenceNumber()); }); this.polled = new ArrayList<>(); diff --git a/paimon-core/src/main/java/org/apache/paimon/operation/DiffReader.java b/paimon-core/src/main/java/org/apache/paimon/operation/DiffReader.java index 324bff9351d1..bc5153600cde 100644 --- a/paimon-core/src/main/java/org/apache/paimon/operation/DiffReader.java +++ b/paimon-core/src/main/java/org/apache/paimon/operation/DiffReader.java @@ -24,6 +24,7 @@ import org.apache.paimon.mergetree.compact.MergeFunctionWrapper; import org.apache.paimon.reader.RecordReader; import org.apache.paimon.types.RowKind; +import org.apache.paimon.utils.FieldsComparator; import javax.annotation.Nullable; @@ -43,6 +44,7 @@ public static RecordReader readDiff( RecordReader beforeReader, RecordReader afterReader, Comparator keyComparator, + @Nullable FieldsComparator userDefinedSeqComparator, MergeSorter sorter, boolean keepDelete) throws IOException { @@ -51,6 +53,7 @@ public static RecordReader readDiff( () -> wrapLevelToReader(beforeReader, BEFORE_LEVEL), () -> wrapLevelToReader(afterReader, AFTER_LEVEL)), keyComparator, + userDefinedSeqComparator, new DiffMerger(keepDelete)); } diff --git a/paimon-core/src/main/java/org/apache/paimon/operation/KeyValueFileStoreRead.java b/paimon-core/src/main/java/org/apache/paimon/operation/KeyValueFileStoreRead.java index 09b9cb81b633..e8accda5b62f 100644 --- a/paimon-core/src/main/java/org/apache/paimon/operation/KeyValueFileStoreRead.java +++ b/paimon-core/src/main/java/org/apache/paimon/operation/KeyValueFileStoreRead.java @@ -227,6 +227,7 @@ private RecordReader createReaderWithoutOuterProjection(DataSplit spli batchMergeRead( split.partition(), split.bucket(), split.dataFiles(), false), keyComparator, + null, mergeSorter, forceKeepDelete); } @@ -255,6 +256,7 @@ private RecordReader batchMergeRead( ? overlappedSectionFactory : nonOverlappedSectionFactory, keyComparator, + null, mergeFuncWrapper, mergeSorter)); } diff --git a/paimon-core/src/main/java/org/apache/paimon/sort/BinaryExternalSortBuffer.java b/paimon-core/src/main/java/org/apache/paimon/sort/BinaryExternalSortBuffer.java index 7d9522f612c1..4de25a4ea9cd 100644 --- a/paimon-core/src/main/java/org/apache/paimon/sort/BinaryExternalSortBuffer.java +++ b/paimon-core/src/main/java/org/apache/paimon/sort/BinaryExternalSortBuffer.java @@ -91,16 +91,16 @@ public BinaryExternalSortBuffer( public static BinaryExternalSortBuffer create( IOManager ioManager, - RowType keyType, RowType rowType, + int[] keyFields, long bufferSize, int pageSize, int maxNumFileHandles, String compression) { return create( ioManager, - keyType, rowType, + keyFields, new HeapMemorySegmentPool(bufferSize, pageSize), maxNumFileHandles, compression); @@ -108,17 +108,17 @@ public static BinaryExternalSortBuffer create( public static BinaryExternalSortBuffer create( IOManager ioManager, - RowType keyType, RowType rowType, + int[] keyFields, MemorySegmentPool pool, int maxNumFileHandles, String compression) { RecordComparator comparator = - newRecordComparator(keyType.getFieldTypes(), "ExternalSort_comparator"); + newRecordComparator(rowType.getFieldTypes(), keyFields, "ExternalSort_comparator"); BinaryInMemorySortBuffer sortBuffer = BinaryInMemorySortBuffer.createBuffer( newNormalizedKeyComputer( - keyType.getFieldTypes(), "ExternalSort_normalized_key"), + rowType.getFieldTypes(), keyFields, "ExternalSort_normalized_key"), new InternalRowSerializer(rowType), comparator, pool); diff --git a/paimon-core/src/main/java/org/apache/paimon/utils/KeyComparatorSupplier.java b/paimon-core/src/main/java/org/apache/paimon/utils/KeyComparatorSupplier.java index 88589cc8137d..30d3ab32874e 100644 --- a/paimon-core/src/main/java/org/apache/paimon/utils/KeyComparatorSupplier.java +++ b/paimon-core/src/main/java/org/apache/paimon/utils/KeyComparatorSupplier.java @@ -26,6 +26,7 @@ import java.util.Comparator; import java.util.function.Supplier; +import java.util.stream.IntStream; /** A {@link Supplier} that returns the comparator for the file store key. */ public class KeyComparatorSupplier implements SerializableSupplier> { @@ -36,7 +37,10 @@ public class KeyComparatorSupplier implements SerializableSupplier getVarSeg() { + return Arrays.asList(SortEngine.LOSER_TREE, SortEngine.MIN_HEAP); + } + @BeforeEach public void beforeTest() { ioManager = IOManager.create(tempDir.toString()); Options options = new Options(); options.set(CoreOptions.SORT_SPILL_BUFFER_SIZE, new MemorySize(MEMORY_SIZE)); + options.set(CoreOptions.SORT_ENGINE, sortEngine); sorter = new MergeSorter(new CoreOptions(options), keyType, valueType, ioManager); totalPages = sorter.memoryPool().freePages(); } @@ -86,60 +107,82 @@ public void afterTest() throws Exception { this.ioManager.close(); } - @Test + @TestTemplate public void testSortAndMerge() throws Exception { + innerTest(null); + } + + @TestTemplate + public void testWithUserDefineSequence() throws Exception { + innerTest( + new FieldsComparator() { + @Override + public int[] compareFields() { + return new int[] {0}; + } + + @Override + public int compare(InternalRow o1, InternalRow o2) { + return Integer.compare(o1.getInt(0), o2.getInt(0)); + } + }); + } + + private void innerTest(FieldsComparator userDefinedSeqComparator) throws Exception { + Comparator comparator = + Comparator.comparingInt((KeyValue o) -> o.key().getInt(0)); + if (userDefinedSeqComparator != null) { + comparator = + comparator.thenComparing( + (o1, o2) -> userDefinedSeqComparator.compare(o1.value(), o2.value())); + } + comparator = comparator.thenComparingLong(KeyValue::sequenceNumber); + List> readers = new ArrayList<>(); Random rnd = new Random(); List expectedKvs = new ArrayList<>(); Set distinctSeq = new HashSet<>(); - for (int i = 0; i < 10; i++) { + for (int i = 0; i < rnd.nextInt(10) + 3; i++) { List kvs = new ArrayList<>(); + Set distinctKeys = new HashSet<>(); for (int j = 0; j < 100; j++) { - long seq = rnd.nextLong(); - while (distinctSeq.contains(seq)) { - rnd.nextLong(); + while (true) { + int key = rnd.nextInt(1000); + if (distinctKeys.contains(key)) { + continue; + } + + long seq = rnd.nextLong(); + while (distinctSeq.contains(seq)) { + seq = rnd.nextLong(); + } + distinctSeq.add(seq); + kvs.add( + new KeyValue() + .replace( + GenericRow.of(key), + seq, + RowKind.fromByteValue((byte) rnd.nextInt(4)), + GenericRow.of(rnd.nextInt(1000))) + .setLevel(rnd.nextInt(100))); + distinctKeys.add(key); + break; } - distinctSeq.add(seq); - kvs.add( - new KeyValue() - .replace( - GenericRow.of(rnd.nextInt(100)), - seq, - RowKind.fromByteValue((byte) rnd.nextInt(4)), - GenericRow.of(rnd.nextInt())) - .setLevel(rnd.nextInt(100))); } expectedKvs.addAll(kvs); + kvs.sort(comparator); readers.add(() -> new IteratorRecordReader<>(kvs.iterator())); } - expectedKvs.sort( - Comparator.comparingInt((KeyValue o) -> o.key().getInt(0)) - .thenComparingLong(KeyValue::sequenceNumber)); + expectedKvs.sort(comparator); - MergeFunctionWrapper> collectFunc = - new MergeFunctionWrapper>() { - - private List result; - - @Override - public void reset() { - result = new ArrayList<>(); - } - - @Override - public void add(KeyValue kv) { - result.add(kv); - } - - @Nullable - @Override - public List getResult() { - return result; - } - }; + TestMergeFunctionWrapper collectFunc = new TestMergeFunctionWrapper(); List all = new ArrayList<>(); - sorter.mergeSort(readers, Comparator.comparingInt(o -> o.getInt(0)), collectFunc) + sorter.mergeSort( + readers, + Comparator.comparingInt(o -> o.getInt(0)), + userDefinedSeqComparator, + collectFunc) .forEachRemaining(all::addAll); assertThat(toString(all)).containsExactlyElementsOf(toString(expectedKvs)); @@ -148,4 +191,25 @@ public List getResult() { private List toString(List kvs) { return kvs.stream().map(kv -> kv.toString(keyType, valueType)).collect(Collectors.toList()); } + + private static class TestMergeFunctionWrapper implements MergeFunctionWrapper> { + + private List result; + + @Override + public void reset() { + result = new ArrayList<>(); + } + + @Override + public void add(KeyValue kv) { + result.add(kv); + } + + @Nullable + @Override + public List getResult() { + return result; + } + } } diff --git a/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/SortMergeReaderTestBase.java b/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/SortMergeReaderTestBase.java index 6895815cdcdc..81e49b81d2a6 100644 --- a/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/SortMergeReaderTestBase.java +++ b/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/SortMergeReaderTestBase.java @@ -48,6 +48,7 @@ protected RecordReader createRecordReader( return SortMergeReader.createSortMergeReader( new ArrayList<>(readers), KEY_COMPARATOR, + null, new ReducerMergeFunctionWrapper(createMergeFunction()), sortEngine); } diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sorter/SortOperator.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sorter/SortOperator.java index 4272dc53e34a..1292c7ce31ce 100644 --- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sorter/SortOperator.java +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sorter/SortOperator.java @@ -31,6 +31,8 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.table.runtime.operators.TableStreamOperator; +import java.util.stream.IntStream; + /** SortOperator to sort the `InternalRow`s by the `KeyType`. */ public class SortOperator extends TableStreamOperator implements OneInputStreamOperator, BoundedOneInput { @@ -87,8 +89,8 @@ void initBuffer() { buffer = BinaryExternalSortBuffer.create( ioManager, - keyType, rowType, + IntStream.range(0, keyType.getFieldCount()).toArray(), maxMemory, pageSize, spillSortMaxNumFiles,