Skip to content

Commit

Permalink
[core] Fix thread safety bug in Partial-update with agg (#3777)
Browse files: browse the repository at this point in the history
  • Loading branch information
JingsongLi authored Jul 18, 2024
1 parent 36d0d92 commit e1f6528
Showing 1 changed file with 36 additions and 24 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -43,13 +43,15 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static org.apache.paimon.CoreOptions.FIELDS_PREFIX;
import static org.apache.paimon.CoreOptions.FIELDS_SEPARATOR;
import static org.apache.paimon.CoreOptions.PARTIAL_UPDATE_REMOVE_RECORD_ON_DELETE;
import static org.apache.paimon.mergetree.compact.aggregate.FieldAggregator.createFieldAggregator;
import static org.apache.paimon.utils.InternalRowUtils.createFieldGetters;

/**
Expand Down Expand Up @@ -268,9 +270,9 @@ private static class Factory implements MergeFunctionFactory<KeyValue> {

private final List<DataType> tableTypes;

private final Map<Integer, FieldsComparator> fieldSeqComparators;
private final Map<Integer, Supplier<FieldsComparator>> fieldSeqComparators;

private final Map<Integer, FieldAggregator> fieldAggregators;
private final Map<Integer, Supplier<FieldAggregator>> fieldAggregators;

private final boolean removeRecordOnDelete;

Expand All @@ -296,8 +298,8 @@ private Factory(Options options, RowType rowType, List<String> primaryKeys) {
.map(fieldName -> validateFieldName(fieldName, fieldNames))
.collect(Collectors.toList());

UserDefinedSeqComparator userDefinedSeqComparator =
UserDefinedSeqComparator.create(rowType, sequenceFields);
Supplier<FieldsComparator> userDefinedSeqComparator =
() -> UserDefinedSeqComparator.create(rowType, sequenceFields);
Arrays.stream(v.split(FIELDS_SEPARATOR))
.map(
fieldName ->
Expand Down Expand Up @@ -360,7 +362,8 @@ public MergeFunction<KeyValue> create(@Nullable int[][] projection) {
RowType newRowType = RowType.builder().fields(newDataTypes).build();

fieldSeqComparators.forEach(
(field, comparator) -> {
(field, comparatorSupplier) -> {
FieldsComparator comparator = comparatorSupplier.get();
int newField = indexMap.getOrDefault(field, -1);
if (newField != -1) {
int[] newSequenceFields =
Expand Down Expand Up @@ -390,7 +393,7 @@ public MergeFunction<KeyValue> create(@Nullable int[][] projection) {
});
for (int i = 0; i < projects.length; i++) {
if (fieldAggregators.containsKey(projects[i])) {
projectedAggregators.put(i, fieldAggregators.get(projects[i]));
projectedAggregators.put(i, fieldAggregators.get(projects[i]).get());
}
}

Expand All @@ -402,6 +405,12 @@ public MergeFunction<KeyValue> create(@Nullable int[][] projection) {
!fieldSeqComparators.isEmpty(),
removeRecordOnDelete);
} else {
Map<Integer, FieldsComparator> fieldSeqComparators = new HashMap<>();
this.fieldSeqComparators.forEach(
(f, supplier) -> fieldSeqComparators.put(f, supplier.get()));
Map<Integer, FieldAggregator> fieldAggregators = new HashMap<>();
this.fieldAggregators.forEach(
(f, supplier) -> fieldAggregators.put(f, supplier.get()));
return new PartialUpdateMergeFunction(
createFieldGetters(tableTypes),
ignoreDelete,
Expand All @@ -425,11 +434,12 @@ public AdjustedProjection adjustProjection(@Nullable int[][] projection) {
int[] topProjects = Projection.of(projection).toTopLevelIndexes();
Set<Integer> indexSet = Arrays.stream(topProjects).boxed().collect(Collectors.toSet());
for (int index : topProjects) {
FieldsComparator comparator = fieldSeqComparators.get(index);
if (comparator == null) {
Supplier<FieldsComparator> comparatorSupplier = fieldSeqComparators.get(index);
if (comparatorSupplier == null) {
continue;
}

FieldsComparator comparator = comparatorSupplier.get();
for (int field : comparator.compareFields()) {
if (!indexSet.contains(field)) {
extraFields.add(field);
Expand Down Expand Up @@ -464,12 +474,12 @@ private String validateFieldName(String fieldName, List<String> fieldNames) {
*
* @return The aggregators for each column.
*/
private Map<Integer, FieldAggregator> createFieldAggregators(
private Map<Integer, Supplier<FieldAggregator>> createFieldAggregators(
RowType rowType, List<String> primaryKeys, CoreOptions options) {

List<String> fieldNames = rowType.getFieldNames();
List<DataType> fieldTypes = rowType.getFieldTypes();
Map<Integer, FieldAggregator> fieldAggregators = new HashMap<>();
Map<Integer, Supplier<FieldAggregator>> fieldAggregators = new HashMap<>();
String defaultAggFunc = options.fieldsDefaultFunc();
for (int i = 0; i < fieldNames.size(); i++) {
String fieldName = fieldNames.get(i);
Expand All @@ -482,23 +492,25 @@ private Map<Integer, FieldAggregator> createFieldAggregators(
if (strAggFunc != null) {
fieldAggregators.put(
i,
FieldAggregator.createFieldAggregator(
fieldType,
strAggFunc,
ignoreRetract,
isPrimaryKey,
options,
fieldName));
() ->
createFieldAggregator(
fieldType,
strAggFunc,
ignoreRetract,
isPrimaryKey,
options,
fieldName));
} else if (defaultAggFunc != null) {
fieldAggregators.put(
i,
FieldAggregator.createFieldAggregator(
fieldType,
defaultAggFunc,
ignoreRetract,
isPrimaryKey,
options,
fieldName));
() ->
createFieldAggregator(
fieldType,
defaultAggFunc,
ignoreRetract,
isPrimaryKey,
options,
fieldName));
}
}
return fieldAggregators;
Expand Down

0 comments on commit e1f6528

Please sign in to comment.