Skip to content

Commit

Permalink
[core] Introduce userDefineSeqComparator for MergeSorter (apache#2936)
Browse files Browse the repository at this point in the history
  • Loading branch information
JingsongLi authored Mar 5, 2024
1 parent 5e2fcfc commit 2e0e236
Show file tree
Hide file tree
Showing 21 changed files with 294 additions and 129 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,20 @@ public GeneratedClass<Projection> generateProjection(

@Override
public GeneratedClass<NormalizedKeyComputer> generateNormalizedKeyComputer(
List<DataType> fieldTypes, String name) {
List<DataType> inputTypes, int[] sortFields, String name) {
return new SortCodeGenerator(
RowType.builder().fields(fieldTypes).build(),
getAscendingSortSpec(fieldTypes.size()))
RowType.builder().fields(inputTypes).build(),
getAscendingSortSpec(sortFields))
.generateNormalizedKeyComputer(name);
}

@Override
public GeneratedClass<RecordComparator> generateRecordComparator(
List<DataType> fieldTypes, String name) {
List<DataType> inputTypes, int[] sortFields, String name) {
return ComparatorCodeGenerator.gen(
name,
RowType.builder().fields(fieldTypes).build(),
getAscendingSortSpec(fieldTypes.size()));
RowType.builder().fields(inputTypes).build(),
getAscendingSortSpec(sortFields));
}

/** Generate a {@link RecordEqualiser}. */
Expand All @@ -61,10 +61,10 @@ public GeneratedClass<RecordEqualiser> generateRecordEqualiser(
.generateRecordEqualiser(name);
}

private SortSpec getAscendingSortSpec(int numFields) {
private SortSpec getAscendingSortSpec(int[] sortFields) {
SortSpec.SortSpecBuilder builder = SortSpec.builder();
for (int i = 0; i < numFields; i++) {
builder.addField(i, true, false);
for (int sortField : sortFields) {
builder.addField(sortField, true, false);
}
return builder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,22 @@ GeneratedClass<Projection> generateProjection(
/**
* Generate a {@link NormalizedKeyComputer}.
*
* @param fieldTypes Both the input row field types and the sort key field types. Records are
* compared by the first field, then the second field, then the third field and so on. All
* fields are compared in ascending order.
* @param inputTypes input types.
* @param sortFields the sort key fields. Records are compared by the first field, then the
* second field, then the third field and so on. All fields are compared in ascending order.
*/
GeneratedClass<NormalizedKeyComputer> generateNormalizedKeyComputer(
List<DataType> fieldTypes, String name);
List<DataType> inputTypes, int[] sortFields, String name);

/**
* Generate a {@link RecordComparator}.
*
* @param fieldTypes Both the input row field types and the sort key field types. Records are *
* compared by the first field, then the second field, then the third field and so on. All *
* fields are compared in ascending order.
* @param inputTypes input types.
* @param sortFields the sort key fields. Records are compared by the first field, then the
* second field, then the third field and so on. All fields are compared in ascending order.
*/
GeneratedClass<RecordComparator> generateRecordComparator(
List<DataType> fieldTypes, String name);
List<DataType> inputTypes, int[] sortFields, String name);

/**
* Generate a {@link RecordEqualiser}.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.paimon.utils;

import org.apache.paimon.data.InternalRow;

import java.util.Comparator;

/**
 * A {@link Comparator} to compare fields for {@link InternalRow}.
 *
 * <p>In addition to the ordering contract inherited from {@link Comparator}, an implementation
 * reports which fields it compares via {@link #compareFields()}.
 */
public interface FieldsComparator extends Comparator<InternalRow> {

/**
 * Returns the fields compared by this comparator.
 *
 * <p>NOTE(review): presumably these are field positions within the {@link InternalRow}, listed
 * in comparison order - confirm against the implementations and callers.
 */
int[] compareFields();
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;

import java.io.UnsupportedEncodingException;
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import static java.nio.charset.StandardCharsets.UTF_8;
import static org.apache.paimon.data.BinaryString.EMPTY_UTF8;
import static org.apache.paimon.data.BinaryString.blankString;
import static org.apache.paimon.data.BinaryString.fromBytes;
import static org.apache.paimon.utils.DecimalUtils.castFrom;
Expand All @@ -47,14 +47,13 @@
@ExtendWith(ParameterizedTestExtension.class)
public class BinaryStringTest {

private BinaryString empty = fromString("");

private final Mode mode;

public BinaryStringTest(Mode mode) {
this.mode = mode;
}

@SuppressWarnings("unused")
@Parameters(name = "{0}")
public static List<Mode> getVarSeg() {
return Arrays.asList(Mode.ONE_SEG, Mode.MULTI_SEGS, Mode.STRING, Mode.RANDOM);
Expand All @@ -78,7 +77,7 @@ private BinaryString fromString(String str) {
mode = Mode.ONE_SEG;
} else if (rnd == 1) {
mode = Mode.MULTI_SEGS;
} else if (rnd == 2) {
} else {
mode = Mode.STRING;
}
}
Expand Down Expand Up @@ -143,10 +142,10 @@ public void basicTest() {

@TestTemplate
public void emptyStringTest() {
assertThat(fromString("")).isEqualTo(empty);
assertThat(fromBytes(new byte[0])).isEqualTo(empty);
assertThat(empty.numChars()).isEqualTo(0);
assertThat(empty.getSizeInBytes()).isEqualTo(0);
assertThat(fromString("")).isEqualTo(EMPTY_UTF8);
assertThat(fromBytes(new byte[0])).isEqualTo(EMPTY_UTF8);
assertThat(EMPTY_UTF8.numChars()).isEqualTo(0);
assertThat(EMPTY_UTF8.getSizeInBytes()).isEqualTo(0);
}

@TestTemplate
Expand Down Expand Up @@ -224,7 +223,7 @@ public void testMultiSegments() {

@TestTemplate
public void contains() {
assertThat(empty.contains(empty)).isTrue();
assertThat(EMPTY_UTF8.contains(EMPTY_UTF8)).isTrue();
assertThat(fromString("hello").contains(fromString("ello"))).isTrue();
assertThat(fromString("hello").contains(fromString("vello"))).isFalse();
assertThat(fromString("hello").contains(fromString("hellooo"))).isFalse();
Expand All @@ -235,7 +234,7 @@ public void contains() {

@TestTemplate
public void startsWith() {
assertThat(empty.startsWith(empty)).isTrue();
assertThat(EMPTY_UTF8.startsWith(EMPTY_UTF8)).isTrue();
assertThat(fromString("hello").startsWith(fromString("hell"))).isTrue();
assertThat(fromString("hello").startsWith(fromString("ell"))).isFalse();
assertThat(fromString("hello").startsWith(fromString("hellooo"))).isFalse();
Expand All @@ -246,7 +245,7 @@ public void startsWith() {

@TestTemplate
public void endsWith() {
assertThat(empty.endsWith(empty)).isTrue();
assertThat(EMPTY_UTF8.endsWith(EMPTY_UTF8)).isTrue();
assertThat(fromString("hello").endsWith(fromString("ello"))).isTrue();
assertThat(fromString("hello").endsWith(fromString("ellov"))).isFalse();
assertThat(fromString("hello").endsWith(fromString("hhhello"))).isFalse();
Expand All @@ -257,7 +256,7 @@ public void endsWith() {

@TestTemplate
public void substring() {
assertThat(fromString("hello").substring(0, 0)).isEqualTo(empty);
assertThat(fromString("hello").substring(0, 0)).isEqualTo(EMPTY_UTF8);
assertThat(fromString("hello").substring(1, 3)).isEqualTo(fromString("el"));
assertThat(fromString("数据砖头").substring(0, 1)).isEqualTo(fromString("数"));
assertThat(fromString("数据砖头").substring(1, 3)).isEqualTo(fromString("据砖"));
Expand All @@ -267,9 +266,9 @@ public void substring() {

@TestTemplate
public void indexOf() {
assertThat(empty.indexOf(empty, 0)).isEqualTo(0);
assertThat(empty.indexOf(fromString("l"), 0)).isEqualTo(-1);
assertThat(fromString("hello").indexOf(empty, 0)).isEqualTo(0);
assertThat(EMPTY_UTF8.indexOf(EMPTY_UTF8, 0)).isEqualTo(0);
assertThat(EMPTY_UTF8.indexOf(fromString("l"), 0)).isEqualTo(-1);
assertThat(fromString("hello").indexOf(EMPTY_UTF8, 0)).isEqualTo(0);
assertThat(fromString("hello").indexOf(fromString("l"), 0)).isEqualTo(2);
assertThat(fromString("hello").indexOf(fromString("l"), 3)).isEqualTo(3);
assertThat(fromString("hello").indexOf(fromString("a"), 0)).isEqualTo(-1);
Expand Down Expand Up @@ -300,24 +299,21 @@ public void testToUpperLowerCase() {
writer.writeString(5, BinaryString.fromString("!@#$%^*"));
writer.complete();

assertThat(((BinaryString) row.getString(0)).toUpperCase()).isEqualTo(fromString("A"));
assertThat(((BinaryString) row.getString(1)).toUpperCase()).isEqualTo(fromString("我是中国人"));
assertThat(((BinaryString) row.getString(1)).toLowerCase()).isEqualTo(fromString("我是中国人"));
assertThat(((BinaryString) row.getString(3)).toUpperCase())
.isEqualTo(fromString("ABCDEFG"));
assertThat(((BinaryString) row.getString(3)).toLowerCase())
.isEqualTo(fromString("abcdefg"));
assertThat(((BinaryString) row.getString(5)).toUpperCase())
.isEqualTo(fromString("!@#$%^*"));
assertThat(((BinaryString) row.getString(5)).toLowerCase())
.isEqualTo(fromString("!@#$%^*"));
assertThat(row.getString(0).toUpperCase()).isEqualTo(fromString("A"));
assertThat(row.getString(1).toUpperCase()).isEqualTo(fromString("我是中国人"));
assertThat(row.getString(1).toLowerCase()).isEqualTo(fromString("我是中国人"));
assertThat(row.getString(3).toUpperCase()).isEqualTo(fromString("ABCDEFG"));
assertThat(row.getString(3).toLowerCase()).isEqualTo(fromString("abcdefg"));
assertThat(row.getString(5).toUpperCase()).isEqualTo(fromString("!@#$%^*"));
assertThat(row.getString(5).toLowerCase()).isEqualTo(fromString("!@#$%^*"));
}

@TestTemplate
public void testcastFrom() {
public void testCastFrom() {
class DecimalTestData {
private String str;
private int precision, scale;
private final String str;
private final int precision;
private final int scale;

private DecimalTestData(String str, int precision, int scale) {
this.str = str;
Expand Down Expand Up @@ -391,7 +387,7 @@ private DecimalTestData(String str, int precision, int scale) {
writer.complete();
for (int i = 0; i < data.length; i++) {
DecimalTestData d = data[i];
assertThat(castFrom((BinaryString) row.getString(i), d.precision, d.scale))
assertThat(castFrom(row.getString(i), d.precision, d.scale))
.isEqualTo(Decimal.fromBigDecimal(new BigDecimal(d.str), d.precision, d.scale));
}
}
Expand All @@ -407,21 +403,21 @@ public void testEmptyString() {
str3 = BinaryString.fromAddress(segments, 15, 0);
}

assertThat(BinaryString.EMPTY_UTF8.compareTo(str2)).isLessThan(0);
assertThat(str2.compareTo(BinaryString.EMPTY_UTF8)).isGreaterThan(0);
assertThat(EMPTY_UTF8.compareTo(str2)).isLessThan(0);
assertThat(str2.compareTo(EMPTY_UTF8)).isGreaterThan(0);

assertThat(BinaryString.EMPTY_UTF8.compareTo(str3)).isEqualTo(0);
assertThat(str3.compareTo(BinaryString.EMPTY_UTF8)).isEqualTo(0);
assertThat(EMPTY_UTF8.compareTo(str3)).isEqualTo(0);
assertThat(str3.compareTo(EMPTY_UTF8)).isEqualTo(0);

assertThat(str2).isNotEqualTo(BinaryString.EMPTY_UTF8);
assertThat(BinaryString.EMPTY_UTF8).isNotEqualTo(str2);
assertThat(str2).isNotEqualTo(EMPTY_UTF8);
assertThat(EMPTY_UTF8).isNotEqualTo(str2);

assertThat(str3).isEqualTo(BinaryString.EMPTY_UTF8);
assertThat(BinaryString.EMPTY_UTF8).isEqualTo(str3);
assertThat(str3).isEqualTo(EMPTY_UTF8);
assertThat(EMPTY_UTF8).isEqualTo(str3);
}

@TestTemplate
public void testEncodeWithIllegalCharacter() throws UnsupportedEncodingException {
public void testEncodeWithIllegalCharacter() {

// This char array has some illegal characters, such as 55357;
// the JDK ignores these characters and casts them to '?'
Expand All @@ -434,11 +430,11 @@ public void testEncodeWithIllegalCharacter() throws UnsupportedEncodingException

String str = new String(chars);

assertThat(BinaryString.encodeUTF8(str)).isEqualTo(str.getBytes("UTF-8"));
assertThat(BinaryString.encodeUTF8(str)).isEqualTo(str.getBytes(UTF_8));
}

@TestTemplate
public void testDecodeWithIllegalUtf8Bytes() throws UnsupportedEncodingException {
public void testDecodeWithIllegalUtf8Bytes() {

// illegal utf-8 bytes
byte[] bytes =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.paimon.types.RowType;

import java.util.List;
import java.util.stream.IntStream;

/** Utils for code generations. */
public class CodeGenUtils {
Expand All @@ -46,24 +47,32 @@ public static Projection newProjection(RowType inputType, int[] mapping) {
}

public static NormalizedKeyComputer newNormalizedKeyComputer(
List<DataType> fieldTypes, String name) {
List<DataType> inputTypes, int[] sortFields, String name) {
return CodeGenLoader.getCodeGenerator()
.generateNormalizedKeyComputer(fieldTypes, name)
.generateNormalizedKeyComputer(inputTypes, sortFields, name)
.newInstance(CodeGenUtils.class.getClassLoader());
}

public static GeneratedClass<RecordComparator> generateRecordComparator(
List<DataType> fieldTypes, String name) {
return CodeGenLoader.getCodeGenerator().generateRecordComparator(fieldTypes, name);
List<DataType> inputTypes, int[] sortFields, String name) {
return CodeGenLoader.getCodeGenerator()
.generateRecordComparator(inputTypes, sortFields, name);
}

public static GeneratedClass<RecordEqualiser> generateRecordEqualiser(
List<DataType> fieldTypes, String name) {
return CodeGenLoader.getCodeGenerator().generateRecordEqualiser(fieldTypes, name);
}

public static RecordComparator newRecordComparator(List<DataType> fieldTypes, String name) {
return generateRecordComparator(fieldTypes, name)
public static RecordComparator newRecordComparator(
List<DataType> inputTypes, int[] sortFields, String name) {
return generateRecordComparator(inputTypes, sortFields, name)
.newInstance(CodeGenUtils.class.getClassLoader());
}

public static RecordComparator newRecordComparator(List<DataType> inputTypes, String name) {
return generateRecordComparator(
inputTypes, IntStream.range(0, inputTypes.size()).toArray(), name)
.newInstance(CodeGenUtils.class.getClassLoader());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.stream.IntStream;

import static org.apache.paimon.lookup.RocksDBOptions.BLOCK_CACHE_SIZE;
import static org.apache.paimon.utils.Preconditions.checkArgument;
Expand Down Expand Up @@ -285,8 +286,8 @@ private void bulkLoadBootstrapRecords() {
BinaryExternalSortBuffer keyIdBuffer =
BinaryExternalSortBuffer.create(
ioManager,
keyWithIdType,
keyWithRowType,
IntStream.range(0, keyWithIdType.getFieldCount()).toArray(),
coreOptions.writeBufferSize() / 2,
coreOptions.pageSize(),
coreOptions.localSortMaxNumFileHandles(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ public static BinaryExternalSortBuffer createBulkLoadSorter(
IOManager ioManager, CoreOptions options) {
return BinaryExternalSortBuffer.create(
ioManager,
RowType.of(DataTypes.BYTES()),
RowType.of(DataTypes.BYTES(), DataTypes.BYTES()),
new int[] {0},
options.writeBufferSize() / 2,
options.pageSize(),
options.localSortMaxNumFileHandles(),
Expand Down
Loading

0 comments on commit 2e0e236

Please sign in to comment.