Skip to content

Commit

Permalink
[core] Refactor the Field Aggregator factory to validate data types when creating function (#4446)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuangchong authored Nov 5, 2024
1 parent 992a406 commit 9339ee6
Show file tree
Hide file tree
Showing 49 changed files with 591 additions and 603 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.paimon.data.GenericRow;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.mergetree.compact.aggregate.FieldAggregator;
import org.apache.paimon.mergetree.compact.aggregate.factory.FieldAggregatorFactory;
import org.apache.paimon.options.Options;
import org.apache.paimon.types.DataField;
import org.apache.paimon.types.DataType;
Expand Down Expand Up @@ -51,7 +52,6 @@
import static org.apache.paimon.CoreOptions.FIELDS_PREFIX;
import static org.apache.paimon.CoreOptions.FIELDS_SEPARATOR;
import static org.apache.paimon.CoreOptions.PARTIAL_UPDATE_REMOVE_RECORD_ON_DELETE;
import static org.apache.paimon.mergetree.compact.aggregate.FieldAggregator.createFieldAggregator;
import static org.apache.paimon.utils.InternalRowUtils.createFieldGetters;

/**
Expand Down Expand Up @@ -495,7 +495,7 @@ private Map<Integer, Supplier<FieldAggregator>> createFieldAggregators(
fieldAggregators.put(
i,
() ->
createFieldAggregator(
FieldAggregatorFactory.create(
fieldType,
strAggFunc,
ignoreRetract,
Expand All @@ -506,7 +506,7 @@ private Map<Integer, Supplier<FieldAggregator>> createFieldAggregators(
fieldAggregators.put(
i,
() ->
createFieldAggregator(
FieldAggregatorFactory.create(
fieldType,
defaultAggFunc,
ignoreRetract,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.mergetree.compact.MergeFunction;
import org.apache.paimon.mergetree.compact.MergeFunctionFactory;
import org.apache.paimon.mergetree.compact.aggregate.factory.FieldAggregatorFactory;
import org.apache.paimon.options.Options;
import org.apache.paimon.types.DataType;
import org.apache.paimon.types.RowKind;
Expand Down Expand Up @@ -142,7 +143,7 @@ public MergeFunction<KeyValue> create(@Nullable int[][] projection) {

boolean ignoreRetract = options.fieldAggIgnoreRetract(fieldName);
fieldAggregators[i] =
FieldAggregator.createFieldAggregator(
FieldAggregatorFactory.create(
fieldType,
strAggFunc,
ignoreRetract,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,62 +18,23 @@

package org.apache.paimon.mergetree.compact.aggregate;

import org.apache.paimon.CoreOptions;
import org.apache.paimon.factories.FactoryUtil;
import org.apache.paimon.mergetree.compact.aggregate.factory.FieldAggregatorFactory;
import org.apache.paimon.types.DataType;

import javax.annotation.Nullable;

import java.io.Serializable;

/** abstract class of aggregating a field of a row. */
public abstract class FieldAggregator implements Serializable {
protected DataType fieldType;

private static final long serialVersionUID = 1L;

public FieldAggregator(DataType dataType) {
this.fieldType = dataType;
}

public static FieldAggregator createFieldAggregator(
DataType fieldType,
@Nullable String strAgg,
boolean ignoreRetract,
boolean isPrimaryKey,
CoreOptions options,
String field) {
FieldAggregator fieldAggregator;
if (isPrimaryKey) {
strAgg = FieldPrimaryKeyAgg.NAME;
} else if (strAgg == null) {
strAgg = FieldLastNonNullValueAgg.NAME;
}

FieldAggregatorFactory fieldAggregatorFactory =
FactoryUtil.discoverFactory(
FieldAggregator.class.getClassLoader(),
FieldAggregatorFactory.class,
strAgg);
if (fieldAggregatorFactory == null) {
throw new RuntimeException(
String.format(
"Use unsupported aggregation: %s or spell aggregate function incorrectly!",
strAgg));
}
protected final DataType fieldType;
protected final String name;

fieldAggregator = fieldAggregatorFactory.create(fieldType, options, field);

if (ignoreRetract) {
fieldAggregator = new FieldIgnoreRetractAgg(fieldAggregator);
}

return fieldAggregator;
public FieldAggregator(String name, DataType dataType) {
this.name = name;
this.fieldType = dataType;
}

public abstract String name();

public abstract Object agg(Object accumulator, Object inputField);

public Object aggReversed(Object accumulator, Object inputField) {
Expand All @@ -89,6 +50,6 @@ public Object retract(Object accumulator, Object retractField) {
"Aggregate function '%s' does not support retraction,"
+ " If you allow this function to ignore retraction messages,"
+ " you can configure 'fields.${field_name}.ignore-retract'='true'.",
name()));
name));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,43 +18,22 @@

package org.apache.paimon.mergetree.compact.aggregate;

import org.apache.paimon.types.DataType;
import org.apache.paimon.types.BooleanType;

/** bool_and aggregate a field of a row. */
public class FieldBoolAndAgg extends FieldAggregator {

public static final String NAME = "bool_and";

private static final long serialVersionUID = 1L;

public FieldBoolAndAgg(DataType dataType) {
super(dataType);
}

@Override
public String name() {
return NAME;
public FieldBoolAndAgg(String name, BooleanType dataType) {
super(name, dataType);
}

@Override
public Object agg(Object accumulator, Object inputField) {
Object boolAnd;

if (accumulator == null || inputField == null) {
boolAnd = (inputField == null) ? accumulator : inputField;
} else {
switch (fieldType.getTypeRoot()) {
case BOOLEAN:
boolAnd = (boolean) accumulator && (boolean) inputField;
break;
default:
String msg =
String.format(
"type %s not support in %s",
fieldType.getTypeRoot().toString(), this.getClass().getName());
throw new IllegalArgumentException(msg);
}
return accumulator == null ? inputField : accumulator;
}
return boolAnd;
return (boolean) accumulator && (boolean) inputField;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,43 +18,22 @@

package org.apache.paimon.mergetree.compact.aggregate;

import org.apache.paimon.types.DataType;
import org.apache.paimon.types.BooleanType;

/** bool_or aggregate a field of a row. */
public class FieldBoolOrAgg extends FieldAggregator {

public static final String NAME = "bool_or";

private static final long serialVersionUID = 1L;

public FieldBoolOrAgg(DataType dataType) {
super(dataType);
}

@Override
public String name() {
return NAME;
public FieldBoolOrAgg(String name, BooleanType dataType) {
super(name, dataType);
}

@Override
public Object agg(Object accumulator, Object inputField) {
Object boolOr;

if (accumulator == null || inputField == null) {
boolOr = (inputField == null) ? accumulator : inputField;
} else {
switch (fieldType.getTypeRoot()) {
case BOOLEAN:
boolOr = (boolean) accumulator || (boolean) inputField;
break;
default:
String msg =
String.format(
"type %s not support in %s",
fieldType.getTypeRoot().toString(), this.getClass().getName());
throw new IllegalArgumentException(msg);
}
return accumulator == null ? inputField : accumulator;
}
return boolOr;
return (boolean) accumulator || (boolean) inputField;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,14 @@
/** Collect elements into an ARRAY. */
public class FieldCollectAgg extends FieldAggregator {

public static final String NAME = "collect";

private static final long serialVersionUID = 1L;

private final boolean distinct;
private final InternalArray.ElementGetter elementGetter;
@Nullable private final BiFunction<Object, Object, Boolean> equaliser;

public FieldCollectAgg(ArrayType dataType, boolean distinct) {
super(dataType);
public FieldCollectAgg(String name, ArrayType dataType, boolean distinct) {
super(name, dataType);
this.distinct = distinct;
this.elementGetter = InternalArray.createElementGetter(dataType.getElementType());

Expand Down Expand Up @@ -84,11 +82,6 @@ public FieldCollectAgg(ArrayType dataType, boolean distinct) {
}
}

@Override
public String name() {
return NAME;
}

@Override
public Object aggReversed(Object accumulator, Object inputField) {
// we don't need to actually do the reverse here for this agg
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,12 @@
/** first non-null value aggregate a field of a row. */
public class FieldFirstNonNullValueAgg extends FieldAggregator {

public static final String NAME = "first_non_null_value";
public static final String LEGACY_NAME = "first_not_null_value";

private static final long serialVersionUID = 1L;

private boolean initialized;

public FieldFirstNonNullValueAgg(DataType dataType) {
super(dataType);
}

@Override
public String name() {
return NAME;
public FieldFirstNonNullValueAgg(String name, DataType dataType) {
super(name, dataType);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,12 @@
/** first value aggregate a field of a row. */
public class FieldFirstValueAgg extends FieldAggregator {

public static final String NAME = "first_value";

private static final long serialVersionUID = 1L;

private boolean initialized;

public FieldFirstValueAgg(DataType dataType) {
super(dataType);
}

@Override
public String name() {
return NAME;
public FieldFirstValueAgg(String name, DataType dataType) {
super(name, dataType);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,14 @@
/** HllSketch aggregate a field of a row. */
public class FieldHllSketchAgg extends FieldAggregator {

public static final String NAME = "hll_sketch";

private static final long serialVersionUID = 1L;

public FieldHllSketchAgg(VarBinaryType dataType) {
super(dataType);
}

@Override
public String name() {
return NAME;
public FieldHllSketchAgg(String name, VarBinaryType dataType) {
super(name, dataType);
}

@Override
public Object agg(Object accumulator, Object inputField) {
if (accumulator == null && inputField == null) {
return null;
}

if (accumulator == null || inputField == null) {
return accumulator == null ? inputField : accumulator;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,15 @@
/** An aggregator which ignores retraction messages. */
public class FieldIgnoreRetractAgg extends FieldAggregator {

private final FieldAggregator aggregator;

private static final long serialVersionUID = 1L;

private final FieldAggregator aggregator;

public FieldIgnoreRetractAgg(FieldAggregator aggregator) {
super(aggregator.fieldType);
super(aggregator.name, aggregator.fieldType);
this.aggregator = aggregator;
}

@Override
public String name() {
return aggregator.name();
}

@Override
public Object agg(Object accumulator, Object inputField) {
return aggregator.agg(accumulator, inputField);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,10 @@
/** last non-null value aggregate a field of a row. */
public class FieldLastNonNullValueAgg extends FieldAggregator {

public static final String NAME = "last_non_null_value";

private static final long serialVersionUID = 1L;

public FieldLastNonNullValueAgg(DataType dataType) {
super(dataType);
}

@Override
public String name() {
return NAME;
public FieldLastNonNullValueAgg(String name, DataType dataType) {
super(name, dataType);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,10 @@
/** last value aggregate a field of a row. */
public class FieldLastValueAgg extends FieldAggregator {

public static final String NAME = "last_value";

private static final long serialVersionUID = 1L;

public FieldLastValueAgg(DataType dataType) {
super(dataType);
}

@Override
public String name() {
return NAME;
public FieldLastValueAgg(String name, DataType dataType) {
super(name, dataType);
}

@Override
Expand Down
Loading

0 comments on commit 9339ee6

Please sign in to comment.