Skip to content

Commit

Permalink
Do not set partition count if it's close to the number of workers
Browse files Browse the repository at this point in the history
  • Loading branch information
sopel39 committed Sep 13, 2023
1 parent a12541d commit b0fe9a6
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,7 @@ public PlanOptimizers(
builder.add(new UnaliasSymbolReferences(metadata));
builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new AddExchanges(plannerContext, typeAnalyzer, statsCalculator, taskCountEstimator)));
// It can only run after AddExchanges since it estimates the hash partition count for all remote exchanges
builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new DeterminePartitionCount(statsCalculator)));
builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new DeterminePartitionCount(statsCalculator, taskCountEstimator)));
}

// use cost calculator without estimated exchanges after AddExchanges
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.trino.cost.StatsCalculator;
import io.trino.cost.StatsProvider;
import io.trino.cost.TableStatsProvider;
import io.trino.cost.TaskCountEstimator;
import io.trino.execution.querystats.PlanOptimizersStatsCollector;
import io.trino.execution.warnings.WarningCollector;
import io.trino.operator.RetryPolicy;
Expand Down Expand Up @@ -94,10 +95,12 @@ public class DeterminePartitionCount
private static final List<Class<? extends PlanNode>> INSERT_NODES = ImmutableList.of(TableExecuteNode.class, TableWriterNode.class, MergeWriterNode.class);

private final StatsCalculator statsCalculator;
private final TaskCountEstimator taskCountEstimator;

public DeterminePartitionCount(StatsCalculator statsCalculator)
/**
 * Creates the optimizer, storing the collaborators it is given.
 *
 * @param statsCalculator stats calculator kept in the {@code statsCalculator} field; must not be null
 * @param taskCountEstimator estimator whose {@code estimateHashedTaskCount(session)} result is compared
 *         against twice the determined partition count; must not be null
 * @throws NullPointerException if either argument is null
 */
public DeterminePartitionCount(StatsCalculator statsCalculator, TaskCountEstimator taskCountEstimator)
{
this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null");
this.taskCountEstimator = requireNonNull(taskCountEstimator, "taskCountEstimator is null");
}

@Override
Expand All @@ -115,6 +118,7 @@ public PlanNode optimize(
requireNonNull(session, "session is null");
requireNonNull(types, "types is null");
requireNonNull(tableStatsProvider, "tableStatsProvider is null");
requireNonNull(taskCountEstimator, "taskCountEstimator is null");

// Skip partition count determination if no partitioned remote exchanges exist in the plan anyway
if (!isEligibleRemoteExchangePresent(plan)) {
Expand Down Expand Up @@ -208,6 +212,11 @@ private Optional<Integer> determinePartitionCount(
return Optional.empty();
}

if (partitionCount * 2 >= taskCountEstimator.estimateHashedTaskCount(session) && !getRetryPolicy(session).equals(RetryPolicy.TASK)) {
// Do not cap partition count if it's already close to the possible number of tasks.
return Optional.empty();
}

log.debug("Estimated remote exchange partition count for query %s is %s", session.getQueryId(), partitionCount);
return Optional.of(partitionCount);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ protected LocalQueryRunner createLocalQueryRunner()
.setCatalog(catalogName)
.setSchema("default")
.build();
LocalQueryRunner queryRunner = LocalQueryRunner.create(session);
LocalQueryRunner queryRunner = LocalQueryRunner.builder(session)
.withNodeCountForStats(100)
.build();
queryRunner.createCatalog(
catalogName,
connectorFactory,
Expand Down Expand Up @@ -175,7 +177,7 @@ SELECT count(column_a) FROM table_with_stats_a group by column_b
assertDistributedPlan(
query,
Session.builder(getQueryRunner().getDefaultSession())
.setSystemProperty(MAX_HASH_PARTITION_COUNT, "20")
.setSystemProperty(MAX_HASH_PARTITION_COUNT, "21")
.setSystemProperty(MIN_HASH_PARTITION_COUNT, "4")
.setSystemProperty(MIN_INPUT_SIZE_PER_TASK, "20MB")
.setSystemProperty(MIN_INPUT_ROWS_PER_TASK, "400")
Expand All @@ -190,6 +192,33 @@ SELECT count(column_a) FROM table_with_stats_a group by column_b
node(TableScanNode.class)))))))));
}

/**
 * Checks that the remote repartitioning exchange of a grouped aggregation over
 * {@code table_with_stats_a} is planned without an explicit partition count
 * (i.e. {@code Optional.empty()}) under the session limits configured below.
 */
@Test
public void testDoesNotSetPartitionCountWhenNodeCountIsSmall()
{
    @Language("SQL") String query = """
            SELECT count(column_a) FROM table_with_stats_a group by column_b
            """;

    Session session = Session.builder(getQueryRunner().getDefaultSession())
            .setSystemProperty(MAX_HASH_PARTITION_COUNT, "20")
            .setSystemProperty(MIN_HASH_PARTITION_COUNT, "4")
            .setSystemProperty(MIN_INPUT_SIZE_PER_TASK, "20MB")
            .setSystemProperty(MIN_INPUT_ROWS_PER_TASK, "400")
            .build();

    // DeterminePartitionCount should not set a partition count when twice the
    // determined partition count is greater than or equal to the number of workers,
    // so the remote exchange is expected to carry Optional.empty().
    assertDistributedPlan(
            query,
            session,
            output(
                    project(
                            node(AggregationNode.class,
                                    exchange(LOCAL,
                                            exchange(REMOTE, REPARTITION, Optional.empty(),
                                                    node(AggregationNode.class,
                                                            project(
                                                                    node(TableScanNode.class)))))))));
}

@Test
public void testPlanWhenTableStatisticsAreAbsent()
{
Expand Down Expand Up @@ -371,7 +400,7 @@ public void testEstimatedPartitionCountShouldNotBeLessThanMinLimit()
assertDistributedPlan(
query,
Session.builder(getQueryRunner().getDefaultSession())
.setSystemProperty(MAX_HASH_PARTITION_COUNT, "20")
.setSystemProperty(MAX_HASH_PARTITION_COUNT, "40")
.setSystemProperty(MIN_HASH_PARTITION_COUNT, "15")
.setSystemProperty(MIN_INPUT_SIZE_PER_TASK, "20MB")
.setSystemProperty(MIN_INPUT_ROWS_PER_TASK, "400")
Expand Down

0 comments on commit b0fe9a6

Please sign in to comment.