From b194e192f23e761f91f56281665b679039592313 Mon Sep 17 00:00:00 2001 From: askwang <135721692+askwang@users.noreply.github.com> Date: Tue, 6 Aug 2024 17:42:56 +0800 Subject: [PATCH] [spark] Support lower case callback param if use dataframe write (#3839) --- .../java/org/apache/paimon/CoreOptions.java | 7 +- .../paimon/spark/PaimonCommitTest.scala | 77 +++++++++++++++++++ 2 files changed, 83 insertions(+), 1 deletion(-) create mode 100644 paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/PaimonCommitTest.scala diff --git a/paimon-common/src/main/java/org/apache/paimon/CoreOptions.java b/paimon-common/src/main/java/org/apache/paimon/CoreOptions.java index 93ada725fddb..660580284f64 100644 --- a/paimon-common/src/main/java/org/apache/paimon/CoreOptions.java +++ b/paimon-common/src/main/java/org/apache/paimon/CoreOptions.java @@ -49,6 +49,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -1959,7 +1960,11 @@ private Map callbacks( continue; } - String param = options.get(callbackParam.key().replace("#", className)); + String originParamKey = callbackParam.key().replace("#", className); + String param = options.get(originParamKey); + if (param == null) { + param = options.get(originParamKey.toLowerCase(Locale.ROOT)); + } result.put(className, param); } return result; diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/PaimonCommitTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/PaimonCommitTest.scala new file mode 100644 index 000000000000..0095e1024a86 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/PaimonCommitTest.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark + +import org.apache.paimon.CoreOptions +import org.apache.paimon.manifest.{ManifestCommittable, ManifestEntry} +import org.apache.paimon.table.sink.CommitCallback + +import org.junit.jupiter.api.Assertions + +import java.lang +import java.util.List + +class PaimonCommitTest extends PaimonSparkTestBase { + + test("test commit callback parameter compatibility") { + withTable("tb") { + spark.sql(""" + |CREATE TABLE tb (id int, dt string) using paimon + |TBLPROPERTIES ('file.format'='parquet', 'primary-key'='id', 'bucket'='1') + |""".stripMargin) + + val table = loadTable("tb") + val location = table.location().toString + + val _spark = spark + import _spark.implicits._ + val df = Seq((1, "a"), (2, "b")).toDF("a", "b") + df.write + .format("paimon") + .option(CoreOptions.COMMIT_CALLBACKS.key(), classOf[CustomCommitCallback].getName) + .option( + CoreOptions.COMMIT_CALLBACK_PARAM + .key() + .replace("#", classOf[CustomCommitCallback].getName), + "testid-100") + .mode("append") + .save(location) + + Assertions.assertEquals(PaimonCommitTest.id, "testid-100") + } + } +} + +object PaimonCommitTest { + var id = "" +} + +case class CustomCommitCallback(testId: String) extends CommitCallback { + + override def call( + committedEntries: List[ManifestEntry], + identifier: Long, + watermark: lang.Long): Unit = { + PaimonCommitTest.id = testId + } + + override def retry(committable: ManifestCommittable): Unit = {} + + override def close(): Unit = {} +}