diff --git a/paimon-core/src/main/java/org/apache/paimon/index/HashBucketAssigner.java b/paimon-core/src/main/java/org/apache/paimon/index/HashBucketAssigner.java index aac1f03339bf..329d3b9712ed 100644 --- a/paimon-core/src/main/java/org/apache/paimon/index/HashBucketAssigner.java +++ b/paimon-core/src/main/java/org/apache/paimon/index/HashBucketAssigner.java @@ -157,6 +157,7 @@ private PartitionIndex loadIndex(BinaryRow partition) { indexFileHandler, partition, targetBucketRowNumber, - (hash) -> computeAssignId(hash) == assignId); + (hash) -> computeAssignId(hash) == assignId, + (bucket) -> computeAssignId(bucket) == assignId); } } diff --git a/paimon-core/src/main/java/org/apache/paimon/index/PartitionIndex.java b/paimon-core/src/main/java/org/apache/paimon/index/PartitionIndex.java index 6748fcdd843d..28428aa2d22f 100644 --- a/paimon-core/src/main/java/org/apache/paimon/index/PartitionIndex.java +++ b/paimon-core/src/main/java/org/apache/paimon/index/PartitionIndex.java @@ -27,8 +27,11 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.function.IntPredicate; import static org.apache.paimon.index.HashIndexFile.HASH_INDEX; @@ -38,7 +41,9 @@ public class PartitionIndex { public final Int2ShortHashMap hash2Bucket; - public final Map bucketInformation; + public final Map nonFullBucketInformation; + + public final Set totalBucket; private final long targetBucketRowNumber; @@ -51,7 +56,8 @@ public PartitionIndex( Map bucketInformation, long targetBucketRowNumber) { this.hash2Bucket = hash2Bucket; - this.bucketInformation = bucketInformation; + this.nonFullBucketInformation = bucketInformation; + this.totalBucket = new HashSet<>(bucketInformation.keySet()); this.targetBucketRowNumber = targetBucketRowNumber; this.lastAccessedCommitIdentifier = Long.MIN_VALUE; this.accessed = true; @@ -66,33 +72,36 @@ public int assign(int hash, IntPredicate bucketFilter) { } // 2. find bucket from existing buckets - for (Integer bucket : bucketInformation.keySet()) { - if (bucketFilter.test(bucket)) { - // it is my bucket - Long number = bucketInformation.get(bucket); - if (number < targetBucketRowNumber) { - bucketInformation.put(bucket, number + 1); - hash2Bucket.put(hash, bucket.shortValue()); - return bucket; - } + Iterator> iterator = + nonFullBucketInformation.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + Integer bucket = entry.getKey(); + Long number = entry.getValue(); + if (number < targetBucketRowNumber) { + entry.setValue(number + 1); + hash2Bucket.put(hash, bucket.shortValue()); + return bucket; + } else { + iterator.remove(); } } // 3. create a new bucket for (int i = 0; i < Short.MAX_VALUE; i++) { - if (bucketFilter.test(i) && !bucketInformation.containsKey(i)) { + if (bucketFilter.test(i) && !totalBucket.contains(i)) { hash2Bucket.put(hash, (short) i); - bucketInformation.put(i, 1L); + nonFullBucketInformation.put(i, 1L); + totalBucket.add(i); return i; } } @SuppressWarnings("OptionalGetWithoutIsPresent") - int maxBucket = - bucketInformation.keySet().stream().mapToInt(Integer::intValue).max().getAsInt(); + int maxBucket = totalBucket.stream().mapToInt(Integer::intValue).max().getAsInt(); throw new RuntimeException( String.format( - "To more bucket %s, you should increase target bucket row number %s.", + "Too more bucket %s, you should increase target bucket row number %s.", maxBucket, targetBucketRowNumber)); } @@ -100,7 +109,8 @@ public static PartitionIndex loadIndex( IndexFileHandler indexFileHandler, BinaryRow partition, long targetBucketRowNumber, - IntPredicate loadFilter) { + IntPredicate loadFilter, + IntPredicate bucketFilter) { Int2ShortHashMap map = new Int2ShortHashMap(); List files = indexFileHandler.scan(HASH_INDEX, partition); Map buckets = new HashMap<>(); @@ -112,8 +122,11 @@ public static PartitionIndex loadIndex( if (loadFilter.test(hash)) { map.put(hash, (short) file.bucket()); } - buckets.compute( - file.bucket(), (bucket, number) -> number == null ? 1 : number + 1); + if (bucketFilter.test(file.bucket())) { + buckets.compute( + file.bucket(), + (bucket, number) -> number == null ? 1 : number + 1); + } } catch (EOFException ignored) { break; } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala index 0954e2136588..d16ef746daec 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala @@ -245,7 +245,8 @@ object WriteIntoPaimonTable { indexFileHandler, partition, targetBucketRowNumber, - (_) => true)) + (_) => true, + buckFilter)) val bucket = index.assign(hash, buckFilter) val sparkInternalRow = toRow(row) sparkInternalRow.setInt(bucketColIndex, bucket)