Skip to content

Commit

Permalink
[followup]
Browse files Browse the repository at this point in the history
  • Loading branch information
YannByron committed Jun 4, 2024
1 parent 2d4ec0a commit d4d12a9
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;

/** Utils for spark {@link DataType}. */
public class SparkTypeUtils {
Expand Down Expand Up @@ -206,34 +207,42 @@ protected DataType defaultMethod(org.apache.paimon.types.DataType dataType) {
private static class SparkToPaimonTypeVisitor {

static org.apache.paimon.types.DataType visit(DataType type) {
return visit(type, new SparkToPaimonTypeVisitor());
AtomicInteger atomicInteger = new AtomicInteger(-1);
return visit(type, new SparkToPaimonTypeVisitor(), atomicInteger);
}

static org.apache.paimon.types.DataType visit(
DataType type, SparkToPaimonTypeVisitor visitor) {
DataType type, SparkToPaimonTypeVisitor visitor, AtomicInteger atomicInteger) {
if (type instanceof StructType) {
StructField[] fields = ((StructType) type).fields();
List<org.apache.paimon.types.DataType> fieldResults =
new ArrayList<>(fields.length);

for (StructField field : fields) {
fieldResults.add(visit(field.dataType(), visitor));
fieldResults.add(visit(field.dataType(), visitor, atomicInteger));
}

return visitor.struct((StructType) type, fieldResults);
return visitor.struct((StructType) type, fieldResults, atomicInteger);

} else if (type instanceof org.apache.spark.sql.types.MapType) {
return visitor.map(
(org.apache.spark.sql.types.MapType) type,
visit(((org.apache.spark.sql.types.MapType) type).keyType(), visitor),
visit(((org.apache.spark.sql.types.MapType) type).valueType(), visitor));
visit(
((org.apache.spark.sql.types.MapType) type).keyType(),
visitor,
atomicInteger),
visit(
((org.apache.spark.sql.types.MapType) type).valueType(),
visitor,
atomicInteger));

} else if (type instanceof org.apache.spark.sql.types.ArrayType) {
return visitor.array(
(org.apache.spark.sql.types.ArrayType) type,
visit(
((org.apache.spark.sql.types.ArrayType) type).elementType(),
visitor));
visitor,
atomicInteger));

} else if (type instanceof UserDefinedType) {
throw new UnsupportedOperationException("User-defined types are not supported");
Expand All @@ -244,15 +253,19 @@ static org.apache.paimon.types.DataType visit(
}

public org.apache.paimon.types.DataType struct(
StructType struct, List<org.apache.paimon.types.DataType> fieldResults) {
StructType struct,
List<org.apache.paimon.types.DataType> fieldResults,
AtomicInteger atomicInteger) {
StructField[] fields = struct.fields();
List<DataField> newFields = new ArrayList<>(fields.length);
for (int i = 0; i < fields.length; i += 1) {
for (int i = 0; i < fields.length; i++) {
StructField field = fields[i];
org.apache.paimon.types.DataType fieldType =
fieldResults.get(i).copy(field.nullable());
String comment = field.getComment().getOrElse(() -> null);
newFields.add(new DataField(i, field.name(), fieldType, comment));
newFields.add(
new DataField(
atomicInteger.incrementAndGet(), field.name(), fieldType, comment));
}

return new RowType(newFields);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.concurrent.atomic.AtomicInteger;

import static org.apache.paimon.spark.SparkTypeUtils.fromPaimonRowType;
import static org.apache.paimon.spark.SparkTypeUtils.toPaimonType;
import static org.assertj.core.api.Assertions.assertThat;

/** Test for {@link SparkTypeUtils}. */
Expand Down Expand Up @@ -76,14 +77,14 @@ public void testAllTypes() {
String nestedRowMapType =
"StructField(locations,MapType("
+ "StringType,"
+ "StructType(StructField(posX,DoubleType,true),StructField(posY,DoubleType,true)),true),true)";
+ "StructType(StructField(posX,DoubleType,false),StructField(posY,DoubleType,false)),true),true)";
String expected =
"StructType("
+ "StructField(id,IntegerType,true),"
+ "StructField(id,IntegerType,false),"
+ "StructField(name,StringType,true),"
+ "StructField(char,CharType(10),true),"
+ "StructField(varchar,VarcharType(10),true),"
+ "StructField(salary,DoubleType,true),"
+ "StructField(salary,DoubleType,false),"
+ nestedRowMapType
+ ","
+ "StructField(strArray,ArrayType(StringType,true),true),"
Expand All @@ -102,7 +103,6 @@ public void testAllTypes() {
StructType sparkType = fromPaimonRowType(ALL_TYPES);
assertThat(sparkType.toString().replace(", ", ",")).isEqualTo(expected);

// Ignore the assertion below, since we force to make all the fields nullable.
// assertThat(toPaimonType(sparkType)).isEqualTo(ALL_TYPES);
assertThat(toPaimonType(sparkType)).isEqualTo(ALL_TYPES);
}
}

0 comments on commit d4d12a9

Please sign in to comment.