Skip to content

Commit

Permalink
adjust param type
Browse files Browse the repository at this point in the history
  • Loading branch information
askwang committed Sep 13, 2024
1 parent 5564403 commit 6400f12
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

import org.apache.paimon.options.ExpireConfig;
import org.apache.paimon.table.ExpireSnapshots;
import org.apache.paimon.utils.DateTimeUtils;
import org.apache.paimon.utils.StringUtils;
import org.apache.paimon.utils.TimeUtils;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.Identifier;
Expand All @@ -29,10 +32,10 @@
import org.apache.spark.sql.types.StructType;

import java.time.Duration;
import java.util.TimeZone;

import static org.apache.spark.sql.types.DataTypes.IntegerType;
import static org.apache.spark.sql.types.DataTypes.StringType;
import static org.apache.spark.sql.types.DataTypes.TimestampType;

/** A procedure to expire snapshots. */
public class ExpireSnapshotsProcedure extends BaseProcedure {
Expand All @@ -42,8 +45,9 @@ public class ExpireSnapshotsProcedure extends BaseProcedure {
ProcedureParameter.required("table", StringType),
ProcedureParameter.optional("retain_max", IntegerType),
ProcedureParameter.optional("retain_min", IntegerType),
ProcedureParameter.optional("older_than", TimestampType),
ProcedureParameter.optional("max_deletes", IntegerType)
ProcedureParameter.optional("older_than", StringType),
ProcedureParameter.optional("max_deletes", IntegerType),
ProcedureParameter.optional("time_retained", StringType)
};

private static final StructType OUTPUT_TYPE =
Expand Down Expand Up @@ -72,8 +76,11 @@ public InternalRow[] call(InternalRow args) {
Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name());
Integer retainMax = args.isNullAt(1) ? null : args.getInt(1);
Integer retainMin = args.isNullAt(2) ? null : args.getInt(2);
Long olderThanMills = args.isNullAt(3) ? null : args.getLong(3) / 1000;
String olderThanStr = args.isNullAt(3) ? null : args.getString(3);
Integer maxDeletes = args.isNullAt(4) ? null : args.getInt(4);
Duration timeRetained =
args.isNullAt(5) ? null : TimeUtils.parseDuration(args.getString(5));

return modifyPaimonTable(
tableIdent,
table -> {
Expand All @@ -85,13 +92,32 @@ public InternalRow[] call(InternalRow args) {
if (retainMin != null) {
builder.snapshotRetainMin(retainMin);
}
if (olderThanMills != null) {
builder.snapshotTimeRetain(
Duration.ofMillis(System.currentTimeMillis() - olderThanMills));
if (!StringUtils.isBlank(olderThanStr) && timeRetained != null) {
throw new IllegalArgumentException(
"older_than and time_retained cannot be used together.");
}
if (!StringUtils.isBlank(olderThanStr)) {
long olderThanMills;
// forward compatibility for timestamp type
if (StringUtils.isNumeric(olderThanStr)) {
olderThanMills = Long.parseLong(olderThanStr) / 1000;
builder.snapshotTimeRetain(
Duration.ofMillis(System.currentTimeMillis() - olderThanMills));
} else {
olderThanMills =
DateTimeUtils.parseTimestampData(
olderThanStr, 3, TimeZone.getDefault())
.getMillisecond();
builder.snapshotTimeRetain(
Duration.ofMillis(System.currentTimeMillis() - olderThanMills));
}
}
if (maxDeletes != null) {
builder.snapshotMaxDeletes(maxDeletes);
}
if (timeRetained != null) {
builder.snapshotTimeRetain(timeRetained);
}
int deleted = expireSnapshots.config(builder.build()).expire();
return new InternalRow[] {newInternalRow(deleted)};
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@
package org.apache.paimon.spark.procedure

import org.apache.paimon.spark.PaimonSparkTestBase
import org.apache.paimon.utils.SnapshotManager

import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.StreamTest
import org.assertj.core.api.Assertions.{assertThat, assertThatIllegalArgumentException}

import java.sql.Timestamp

class ExpireSnapshotsProcedureTest extends PaimonSparkTestBase with StreamTest {

Expand Down Expand Up @@ -136,4 +140,66 @@ class ExpireSnapshotsProcedureTest extends PaimonSparkTestBase with StreamTest {
}
}
}

test("test parameter order_than with old timestamp type and new string type") {
sql(
"CREATE TABLE T (a INT, b STRING) " +
"TBLPROPERTIES ( 'num-sorted-run.compaction-trigger' = '999' )")
val table = loadTable("T")
val snapshotManager = table.snapshotManager

// generate 5 snapshot
for (i <- 1 to 5) {
sql(s"INSERT INTO T VALUES ($i, '$i')")
}
checkSnapshots(snapshotManager, 1, 5)

// older_than with old timestamp type, like "1724664000000"
val nanosecond: Long = snapshotManager.latestSnapshot().timeMillis * 1000
spark.sql(
s"CALL paimon.sys.expire_snapshots(table => 'test.T', older_than => $nanosecond, max_deletes => 2)")
checkSnapshots(snapshotManager, 3, 5)

// older_than with new string type, like "2024-08-26 17:20:00"
val timestamp = new Timestamp(snapshotManager.latestSnapshot().timeMillis)
spark.sql(
s"CALL paimon.sys.expire_snapshots(table => 'test.T', older_than => '${timestamp.toString}', max_deletes => 2)")
checkSnapshots(snapshotManager, 5, 5)
}

test("test new parameter time_retained") {
sql(
"CREATE TABLE T (a INT, b STRING) " +
"TBLPROPERTIES ( 'num-sorted-run.compaction-trigger' = '999' )")
val table = loadTable("T")
val snapshotManager = table.snapshotManager

// generate 5 snapshot
for (i <- 1 to 5) {
sql(s"INSERT INTO T VALUES ($i, '$i')")
}
checkSnapshots(snapshotManager, 1, 5)

// no snapshots expired
spark.sql(s"CALL paimon.sys.expire_snapshots(table => 'test.T', time_retained => '1h')")
checkSnapshots(snapshotManager, 1, 5)

// expire assert throw exception
val timestamp = snapshotManager.latestSnapshot().timeMillis
assertThrows[IllegalArgumentException] {
spark.sql(
s"CALL paimon.sys.expire_snapshots(table => 'test.T', older_than => '${timestamp.toString}', time_retained => '1h')")
}

// all snapshot are expired, keep latest snapshot
Thread.sleep(1000)
spark.sql(s"CALL paimon.sys.expire_snapshots(table => 'test.T', time_retained => '1s')")
checkSnapshots(snapshotManager, 5, 5)
}

def checkSnapshots(sm: SnapshotManager, earliest: Int, latest: Int): Unit = {
assertThat(sm.snapshotCount).isEqualTo(latest - earliest + 1)
assertThat(sm.earliestSnapshotId).isEqualTo(earliest)
assertThat(sm.latestSnapshotId).isEqualTo(latest)
}
}

0 comments on commit 6400f12

Please sign in to comment.