Skip to content

Commit

Permalink
[core] Support default aggregate function for partial update and aggr…
Browse files Browse the repository at this point in the history
…egate merge function (apache#3374)
  • Loading branch information
FangYongs authored and joyCurry30 committed May 30, 2024
1 parent abc38dd commit bef29ae
Show file tree
Hide file tree
Showing 10 changed files with 329 additions and 1 deletion.
28 changes: 28 additions & 0 deletions docs/content/primary-key-table/merge-engine.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,34 @@ INSERT INTO t VALUES (1, CAST(NULL AS INT), CAST(NULL AS INT), 2, 2);
SELECT * FROM t; -- output 1, 2, 1, 2, 3
```

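You can specify a default aggregate function for all input fields with `fields.default-aggregate-function`. A field that sets its own `fields.<field-name>.aggregate-function` (like `d` below) keeps that function instead. See the following example: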

```sql
CREATE TABLE t (
k INT,
a INT,
b INT,
c INT,
d INT,
PRIMARY KEY (k) NOT ENFORCED
) WITH (
'merge-engine'='partial-update',
'fields.a.sequence-group' = 'b',
'fields.c.sequence-group' = 'd',
'fields.default-aggregate-function' = 'last_non_null_value',
'fields.d.aggregate-function' = 'sum'
);

INSERT INTO t VALUES (1, 1, 1, CAST(NULL AS INT), CAST(NULL AS INT));
INSERT INTO t VALUES (1, CAST(NULL AS INT), CAST(NULL AS INT), 1, 1);
INSERT INTO t VALUES (1, 2, 2, CAST(NULL AS INT), CAST(NULL AS INT));
INSERT INTO t VALUES (1, CAST(NULL AS INT), CAST(NULL AS INT), 2, 2);

SELECT * FROM t; -- output 1, 2, 2, 2, 3
```

## Aggregation

{{< hint info >}}
Expand Down
6 changes: 6 additions & 0 deletions docs/layouts/shortcodes/generated/core_configuration.html
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,12 @@
<td>Boolean</td>
<td>Whether only overwrite dynamic partition when overwriting a partitioned table with dynamic partition columns. Works only when the table has partition keys.</td>
</tr>
<tr>
<td><h5>fields.default-aggregate-function</h5></td>
<td style="word-wrap: break-word;">(none)</td>
<td>String</td>
<td>Default aggregate function of all fields for partial-update and aggregate merge function</td>
</tr>
<tr>
<td><h5>file-index.in-manifest-threshold</h5></td>
<td style="word-wrap: break-word;">500 bytes</td>
Expand Down
16 changes: 16 additions & 0 deletions paimon-common/src/main/java/org/apache/paimon/CoreOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public class CoreOptions implements Serializable {
public static final String FIELDS_PREFIX = "fields";

public static final String AGG_FUNCTION = "aggregate-function";
public static final String DEFAULT_AGG_FUNCTION = "default-aggregate-function";

public static final String IGNORE_RETRACT = "ignore-retract";

Expand Down Expand Up @@ -1151,6 +1152,13 @@ public class CoreOptions implements Serializable {
.withDescription(
"Time field for record level expire, it should be a seconds INT.");

/**
 * Table-level option whose key resolves to {@code fields.default-aggregate-function}
 * ({@link #FIELDS_PREFIX} + "." + {@link #DEFAULT_AGG_FUNCTION}). It declares no default
 * value, so it only takes effect when set explicitly.
 */
public static final ConfigOption<String> FIELDS_DEFAULT_AGG_FUNC =
key(FIELDS_PREFIX + "." + DEFAULT_AGG_FUNCTION)
.stringType()
.noDefaultValue()
.withDescription(
"Default aggregate function of all fields for partial-update and aggregate merge function");

private final Options options;

public CoreOptions(Map<String, String> options) {
Expand Down Expand Up @@ -1260,7 +1268,15 @@ private static String normalizeFileFormat(String fileFormat) {
return fileFormat.toLowerCase();
}

/**
 * Looks up the table-wide default aggregate function configured through
 * {@link #FIELDS_DEFAULT_AGG_FUNC}.
 *
 * @return the configured function name; the option declares no default value, so this is
 *     presumably {@code null} when the option is not set (TODO confirm against {@code
 *     Options#get})
 */
public String fieldsDefaultFunc() {
    final String defaultAggFunc = options.get(FIELDS_DEFAULT_AGG_FUNC);
    return defaultAggFunc;
}

public boolean definedAggFunc() {
if (options.contains(FIELDS_DEFAULT_AGG_FUNC)) {
return true;
}

for (String key : options.toMap().keySet()) {
if (key.startsWith(FIELDS_PREFIX) && key.endsWith(AGG_FUNCTION)) {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ private Map<Integer, FieldAggregator> createFieldAggregators(
List<String> fieldNames = rowType.getFieldNames();
List<DataType> fieldTypes = rowType.getFieldTypes();
Map<Integer, FieldAggregator> fieldAggregators = new HashMap<>();
String defaultAggFunc = options.fieldsDefaultFunc();
for (int i = 0; i < fieldNames.size(); i++) {
String fieldName = fieldNames.get(i);
DataType fieldType = fieldTypes.get(i);
Expand All @@ -379,6 +380,16 @@ private Map<Integer, FieldAggregator> createFieldAggregators(
isPrimaryKey,
options,
fieldName));
} else if (defaultAggFunc != null) {
fieldAggregators.put(
i,
FieldAggregator.createFieldAggregator(
fieldType,
defaultAggFunc,
ignoreRetract,
isPrimaryKey,
options,
fieldName));
}
}
return fieldAggregators;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,15 @@ public MergeFunction<KeyValue> create(@Nullable int[][] projection) {
}

FieldAggregator[] fieldAggregators = new FieldAggregator[fieldNames.size()];
String defaultAggFunc = options.fieldsDefaultFunc();
for (int i = 0; i < fieldNames.size(); i++) {
String fieldName = fieldNames.get(i);
DataType fieldType = fieldTypes.get(i);
// aggregate by primary keys, so they do not aggregate
boolean isPrimaryKey = primaryKeys.contains(fieldName);
String strAggFunc = options.fieldAggFunc(fieldName);
strAggFunc = strAggFunc == null ? defaultAggFunc : strAggFunc;

boolean ignoreRetract = options.fieldAggIgnoreRetract(fieldName);
fieldAggregators[i] =
FieldAggregator.createFieldAggregator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import static org.apache.paimon.CoreOptions.CHANGELOG_NUM_RETAINED_MAX;
import static org.apache.paimon.CoreOptions.CHANGELOG_NUM_RETAINED_MIN;
import static org.apache.paimon.CoreOptions.CHANGELOG_PRODUCER;
import static org.apache.paimon.CoreOptions.DEFAULT_AGG_FUNCTION;
import static org.apache.paimon.CoreOptions.FIELDS_PREFIX;
import static org.apache.paimon.CoreOptions.FULL_COMPACTION_DELTA_COMMITS;
import static org.apache.paimon.CoreOptions.INCREMENTAL_BETWEEN;
Expand Down Expand Up @@ -349,7 +350,8 @@ private static void validateFieldsPrefix(TableSchema schema, CoreOptions options
if (k.startsWith(FIELDS_PREFIX)) {
String fieldName = k.split("\\.")[1];
checkArgument(
fieldNames.contains(fieldName),
DEFAULT_AGG_FUNCTION.equals(fieldName)
|| fieldNames.contains(fieldName),
String.format(
"Field %s can not be found in table schema.",
fieldName));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import org.junit.jupiter.api.Test;

import static org.apache.paimon.CoreOptions.FIELDS_DEFAULT_AGG_FUNC;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

Expand Down Expand Up @@ -94,6 +95,34 @@ public void testSequenceGroup() {
validate(func, 1, null, null, 6, null, null, 6);
}

/**
 * Two sequence groups ({@code f3 -> f1,f2} and {@code f6 -> f4,f5}) with no per-field aggregate
 * function: every field relies on the default {@code last_non_null_value}.
 */
@Test
public void testSequenceGroupDefaultAggFunc() {
    Options conf = new Options();
    conf.set("fields.f3.sequence-group", "f1,f2");
    conf.set("fields.f6.sequence-group", "f4,f5");
    conf.set(FIELDS_DEFAULT_AGG_FUNC, "last_non_null_value");

    // Seven INT columns; "f0" is the one handed to the factory as the key field.
    RowType schema =
            RowType.of(
                    DataTypes.INT(),
                    DataTypes.INT(),
                    DataTypes.INT(),
                    DataTypes.INT(),
                    DataTypes.INT(),
                    DataTypes.INT(),
                    DataTypes.INT());
    MergeFunction<KeyValue> mergeFunction =
            PartialUpdateMergeFunction.factory(conf, schema, ImmutableList.of("f0")).create();
    mergeFunction.reset();

    add(mergeFunction, 1, 1, 1, 1, 1, 1, 1);

    // f6 is null: f4 and f5 keep their previous values while group f3 advances.
    add(mergeFunction, 1, 2, 2, 2, 2, 2, null);
    validate(mergeFunction, 1, 2, 2, 2, 1, 1, 1);

    // f3 = 1 is lower than the stored 2: f1, f2 and f3 stay put while group f6 advances.
    add(mergeFunction, 1, 3, 3, 1, 3, 3, 3);
    validate(mergeFunction, 1, 2, 2, 2, 3, 3, 3);

    // Both groups advance; the null inputs for f2 and f5 leave the last non-null values.
    add(mergeFunction, 1, 4, null, 4, 5, null, 5);
    validate(mergeFunction, 1, 4, 2, 4, 5, 3, 5);
}

@Test
public void testSequenceGroupDefinedNoField() {
Options options = new Options();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.paimon.mergetree.compact.aggregate;

import org.apache.paimon.KeyValue;
import org.apache.paimon.data.GenericRow;
import org.apache.paimon.mergetree.compact.MergeFunction;
import org.apache.paimon.options.Options;
import org.apache.paimon.types.DataTypes;
import org.apache.paimon.types.RowKind;

import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.Collections;

import static org.apache.paimon.CoreOptions.FIELDS_DEFAULT_AGG_FUNC;
import static org.assertj.core.api.Assertions.assertThat;

/** Test for aggregate merge function. */
class AggregateMergeFunctionTest {

    /**
     * Fields without their own aggregate function use the configured default ({@code
     * first_non_null_value}); field {@code b} overrides it with {@code sum}.
     */
    @Test
    void testDefaultAggFunc() {
        Options conf = new Options();
        conf.set(FIELDS_DEFAULT_AGG_FUNC, "first_non_null_value");
        conf.set("fields.b.aggregate-function", "sum");

        MergeFunction<KeyValue> mergeFunction =
                AggregateMergeFunction.factory(
                                conf,
                                Arrays.asList("k", "a", "b", "c", "d"),
                                Arrays.asList(
                                        DataTypes.INT(),
                                        DataTypes.INT(),
                                        DataTypes.INT(),
                                        DataTypes.INT(),
                                        DataTypes.INT()),
                                Collections.singletonList("k"))
                        .create();
        mergeFunction.reset();

        // Columns are k, a, b, c, d; each of the first four rows has its null in a different column.
        Integer[][] inputs = {
            {1, null, 1, 1, 1},
            {1, 2, null, 2, 2},
            {1, 3, 3, null, 3},
            {1, 4, 4, 4, null},
            {1, 5, 5, 5, 5},
        };
        for (Integer[] input : inputs) {
            mergeFunction.add(insertOf(input));
        }

        // a, c, d: first non-null input; b: 1 + 3 + 4 + 5 = 13.
        GenericRow expected = GenericRow.of(1, 2, 13, 1, 1);
        assertThat(mergeFunction.getResult().value()).isEqualTo(expected);
    }

    /** Wraps the given field values in an INSERT record keyed by the first value. */
    private KeyValue insertOf(Integer... fields) {
        GenericRow key = GenericRow.of(fields[0]);
        GenericRow row = GenericRow.of(fields);
        return new KeyValue().replace(key, RowKind.INSERT, row);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,42 @@ public void testSequenceGroup() {
assertThat(sql("SELECT c, d FROM SG")).containsExactlyInAnyOrder(Row.of(5, null));
}

/**
 * Partial-update table with two sequence groups (g_1 -> a,b and g_2 -> c,d) and no per-field
 * aggregate function, so the default 'last_non_null_value' is the only aggregate configured.
 */
@Test
public void testSequenceGroupWithDefaultAggFunc() {
sql(
"CREATE TABLE SG ("
+ "k INT, a INT, b INT, g_1 INT, c INT, d INT, g_2 INT, PRIMARY KEY (k) NOT ENFORCED)"
+ " WITH ("
+ "'merge-engine'='partial-update', "
+ "'fields.g_1.sequence-group'='a,b', "
+ "'fields.g_2.sequence-group'='c,d', "
+ "'fields.default-aggregate-function'='last_non_null_value');");

sql("INSERT INTO SG VALUES (1, 1, 1, 1, 1, 1, 1)");

// g_2 is NULL in this row, so c, d and g_2 must keep their previous values
sql("INSERT INTO SG VALUES (1, 2, 2, 2, 2, 2, CAST(NULL AS INT))");

// full row: group g_1 advanced to 2, group g_2 still holds the first insert
assertThat(sql("SELECT * FROM SG")).containsExactlyInAnyOrder(Row.of(1, 2, 2, 2, 1, 1, 1));

// the same check through a projection on group g_2 only
assertThat(sql("SELECT c, d FROM SG")).containsExactlyInAnyOrder(Row.of(1, 1));

// g_1 = 1 is lower than the stored 2, so a, b and g_1 must not change; group g_2 advances to 3
sql("INSERT INTO SG VALUES (1, 3, 3, 1, 3, 3, 3)");

assertThat(sql("SELECT * FROM SG")).containsExactlyInAnyOrder(Row.of(1, 2, 2, 2, 3, 3, 3));

// first two rows: g_2 is NULL, so c and d are untouched while a and b advance to 4
// third row: g_1 = 3 is lower than the stored 4, so a and b stay 4; g_2 = 4 sets c to 5,
// and the NULL input for d keeps the last non-null value 3
sql("INSERT INTO SG VALUES (1, 3, 3, 3, 2, 2, CAST(NULL AS INT))");
sql("INSERT INTO SG VALUES (1, 4, 4, 4, 2, 2, CAST(NULL AS INT))");
sql("INSERT INTO SG VALUES (1, 5, 5, 3, 5, CAST(NULL AS INT), 4)");

assertThat(sql("SELECT a, b FROM SG")).containsExactlyInAnyOrder(Row.of(4, 4));
assertThat(sql("SELECT c, d FROM SG")).containsExactlyInAnyOrder(Row.of(5, 3));
}

@Test
public void testInvalidSequenceGroup() {
Assertions.assertThatThrownBy(
Expand Down Expand Up @@ -319,6 +355,46 @@ public void testPartialUpdateWithAggregation() {
assertThat(sql("SELECT a, b, c FROM AGG")).containsExactlyInAnyOrder(Row.of(6, 3, null));
}

/**
 * Partial-update table that combines a field-level aggregate ('sum' on a) with the default
 * aggregate function 'last_non_null_value' configured for the whole table.
 */
@Test
public void testPartialUpdateWithDefaultAndFieldAggregation() {
sql(
"CREATE TABLE AGG ("
+ "k INT, a INT, b INT, g_1 INT, c VARCHAR, g_2 INT, PRIMARY KEY (k) NOT ENFORCED)"
+ " WITH ("
+ "'merge-engine'='partial-update', "
+ "'fields.a.aggregate-function'='sum', "
+ "'fields.g_1.sequence-group'='a', "
+ "'fields.g_2.sequence-group'='c', "
+ "'fields.default-aggregate-function'='last_non_null_value');");
// a is in group g_1 and has its own 'sum' aggregate
// b is not in any sequence group and has no field-level aggregate
// c is in group g_2 and has no field-level aggregate

sql("INSERT INTO AGG VALUES (1, 1, 1, 1, '1', 1)");

// g_2 is NULL in this row, so c and g_2 must keep their previous values
sql("INSERT INTO AGG VALUES (1, 2, 2, 2, '2', CAST(NULL AS INT))");

// full row: a is summed to 3 (1 + 2), b and g_1 take 2, c and g_2 are unchanged
assertThat(sql("SELECT * FROM AGG")).containsExactlyInAnyOrder(Row.of(1, 3, 2, 2, "1", 1));

// the same check through a projection
assertThat(sql("SELECT a, c FROM AGG")).containsExactlyInAnyOrder(Row.of(3, "1"));

// g_1 = 1 is lower than the stored 2, so g_1 stays 2; a still accumulates to 6 (3 + 3)
sql("INSERT INTO AGG VALUES (1, 3, 3, 1, '3', 3)");

assertThat(sql("SELECT * FROM AGG")).containsExactlyInAnyOrder(Row.of(1, 6, 3, 2, "3", 3));

sql(
"INSERT INTO AGG VALUES (1, CAST(NULL AS INT), CAST(NULL AS INT), 2, CAST(NULL AS VARCHAR), 4)");

// a keeps the accumulated sum 6 (the NULL input adds nothing)
// b is not overwritten by NULL and stays 3
// c stays "3": the NULL input does not replace the last non-null value
assertThat(sql("SELECT a, b, c FROM AGG")).containsExactlyInAnyOrder(Row.of(6, 3, "3"));
}

@Test
public void testFirstValuePartialUpdate() {
sql(
Expand Down
Loading

0 comments on commit bef29ae

Please sign in to comment.