Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CH] A simple job scheduler for merge tree cache sync load #6842

Merged
merged 2 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,15 @@
import java.util.Set;

public class CHNativeCacheManager {
public static void cacheParts(String table, Set<String> columns, boolean async) {
nativeCacheParts(table, String.join(",", columns), async);
public static String cacheParts(String table, Set<String> columns) {
return nativeCacheParts(table, String.join(",", columns));
}

private static native void nativeCacheParts(String table, String columns, boolean async);
private static native String nativeCacheParts(String table, String columns);

public static CacheResult getCacheStatus(String jobId) {
return nativeGetCacheStatus(jobId);
}

private static native CacheResult nativeGetCacheStatus(String jobId);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.execution;

public class CacheResult {
public enum Status {
RUNNING(0),
SUCCESS(1),
ERROR(2);

private final int value;

Status(int value) {
this.value = value;
}

public int getValue() {
return value;
}

public static Status fromInt(int value) {
for (Status myEnum : Status.values()) {
if (myEnum.getValue() == value) {
return myEnum;
}
}
throw new IllegalArgumentException("No enum constant for value: " + value);
}
}

private final Status status;
private final String message;

public CacheResult(int status, String message) {
this.status = Status.fromInt(status);
this.message = message;
}

public Status getStatus() {
return status;
}

public String getMessage() {
return message;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ class GlutenExecutorEndpoint(val executorId: String, val conf: SparkConf)
hashIds.forEach(
resource_id => CHBroadcastBuildSideCache.invalidateBroadcastHashtable(resource_id))
}
case GlutenMergeTreeCacheLoad(mergeTreeTable, columns) =>
CHNativeCacheManager.cacheParts(mergeTreeTable, columns, true)

case e =>
logError(s"Received unexpected message. $e")
Expand All @@ -74,12 +72,16 @@ class GlutenExecutorEndpoint(val executorId: String, val conf: SparkConf)
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case GlutenMergeTreeCacheLoad(mergeTreeTable, columns) =>
try {
CHNativeCacheManager.cacheParts(mergeTreeTable, columns, false)
context.reply(CacheLoadResult(true))
val jobId = CHNativeCacheManager.cacheParts(mergeTreeTable, columns)
context.reply(CacheJobInfo(status = true, jobId))
} catch {
case _: Exception =>
context.reply(CacheLoadResult(false, s"executor: $executorId cache data failed."))
context.reply(
CacheJobInfo(status = false, "", s"executor: $executorId cache data failed."))
}
case GlutenMergeTreeCacheLoadStatus(jobId) =>
val status = CHNativeCacheManager.getCacheStatus(jobId)
context.reply(status)
case e =>
logError(s"Received unexpected message. $e")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ object GlutenRpcMessages {
case class GlutenCleanExecutionResource(executionId: String, broadcastHashIds: util.Set[String])
extends GlutenRpcMessage

// for mergetree cache
case class GlutenMergeTreeCacheLoad(mergeTreeTable: String, columns: util.Set[String])
extends GlutenRpcMessage

case class CacheLoadResult(success: Boolean, reason: String = "") extends GlutenRpcMessage
case class GlutenMergeTreeCacheLoadStatus(jobId: String)

case class CacheJobInfo(status: Boolean, jobId: String, reason: String = "")
extends GlutenRpcMessage
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,20 @@
package org.apache.spark.sql.execution.commands

import org.apache.gluten.exception.GlutenException
import org.apache.gluten.execution.CacheResult
import org.apache.gluten.execution.CacheResult.Status
import org.apache.gluten.expression.ConverterUtils
import org.apache.gluten.substrait.rel.ExtensionTableBuilder

import org.apache.spark.affinity.CHAffinity
import org.apache.spark.rpc.GlutenDriverEndpoint
import org.apache.spark.rpc.GlutenRpcMessages.{CacheLoadResult, GlutenMergeTreeCacheLoad}
import org.apache.spark.rpc.GlutenRpcMessages.{CacheJobInfo, GlutenMergeTreeCacheLoad, GlutenMergeTreeCacheLoadStatus}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, GreaterThanOrEqual, IsNotNull, Literal}
import org.apache.spark.sql.delta._
import org.apache.spark.sql.execution.command.LeafRunnableCommand
import org.apache.spark.sql.execution.commands.GlutenCHCacheDataCommand.toExecutorId
import org.apache.spark.sql.execution.commands.GlutenCHCacheDataCommand.{checkExecutorId, collectJobTriggerResult, toExecutorId, waitAllJobFinish, waitRpcResults}
import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMergeTreeParts
import org.apache.spark.sql.types.{BooleanType, StringType}
import org.apache.spark.util.ThreadUtils
Expand Down Expand Up @@ -106,7 +108,8 @@ case class GlutenCHCacheDataCommand(
}

val selectedAddFiles = if (tsfilter.isDefined) {
val allParts = DeltaAdapter.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, false)
val allParts =
DeltaAdapter.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, keepNumRecords = false)
allParts.files.filter(_.modificationTime >= tsfilter.get.toLong).toSeq
} else if (partitionColumn.isDefined && partitionValue.isDefined) {
val partitionColumns = snapshot.metadata.partitionSchema.fieldNames
Expand All @@ -126,10 +129,12 @@ case class GlutenCHCacheDataCommand(
snapshot,
Seq(partitionColumnAttr),
Seq(isNotNullExpr, greaterThanOrEqual),
false)
keepNumRecords = false)
.files
} else {
DeltaAdapter.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, false).files
DeltaAdapter
.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, keepNumRecords = false)
.files
}

val executorIdsToAddFiles =
Expand All @@ -151,17 +156,15 @@ case class GlutenCHCacheDataCommand(

if (locations.isEmpty) {
// non soft affinity
executorIdsToAddFiles
.get(GlutenCHCacheDataCommand.ALL_EXECUTORS)
.get
executorIdsToAddFiles(GlutenCHCacheDataCommand.ALL_EXECUTORS)
.append(mergeTreePart)
} else {
locations.foreach(
executor => {
if (!executorIdsToAddFiles.contains(executor)) {
executorIdsToAddFiles.put(executor, new ArrayBuffer[AddMergeTreeParts]())
}
executorIdsToAddFiles.get(executor).get.append(mergeTreePart)
executorIdsToAddFiles(executor).append(mergeTreePart)
})
}
})
Expand Down Expand Up @@ -201,87 +204,112 @@ case class GlutenCHCacheDataCommand(
executorIdsToParts.put(executorId, extensionTableNode.getExtensionTableStr)
}
})

// send rpc call
val futureList = ArrayBuffer[(String, Future[CacheJobInfo])]()
if (executorIdsToParts.contains(GlutenCHCacheDataCommand.ALL_EXECUTORS)) {
// send all parts to all executors
val tableMessage = executorIdsToParts.get(GlutenCHCacheDataCommand.ALL_EXECUTORS).get
if (asynExecute) {
GlutenDriverEndpoint.executorDataMap.forEach(
(executorId, executor) => {
executor.executorEndpointRef.send(
GlutenMergeTreeCacheLoad(tableMessage, selectedColumns.toSet.asJava))
})
Seq(Row(true, ""))
} else {
val futureList = ArrayBuffer[Future[CacheLoadResult]]()
val resultList = ArrayBuffer[CacheLoadResult]()
GlutenDriverEndpoint.executorDataMap.forEach(
(executorId, executor) => {
futureList.append(
executor.executorEndpointRef.ask[CacheLoadResult](
val tableMessage = executorIdsToParts(GlutenCHCacheDataCommand.ALL_EXECUTORS)
GlutenDriverEndpoint.executorDataMap.forEach(
(executorId, executor) => {
futureList.append(
(
executorId,
executor.executorEndpointRef.ask[CacheJobInfo](
GlutenMergeTreeCacheLoad(tableMessage, selectedColumns.toSet.asJava)
))
})
futureList.foreach(
f => {
resultList.append(ThreadUtils.awaitResult(f, Duration.Inf))
})
if (resultList.exists(!_.success)) {
Seq(Row(false, resultList.filter(!_.success).map(_.reason).mkString(";")))
} else {
Seq(Row(true, ""))
}
}
)))
})
} else {
if (asynExecute) {
executorIdsToParts.foreach(
value => {
val executorData = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1))
if (executorData != null) {
executorData.executorEndpointRef.send(
GlutenMergeTreeCacheLoad(value._2, selectedColumns.toSet.asJava))
} else {
throw new GlutenException(
s"executor ${value._1} not found," +
s" all executors are ${GlutenDriverEndpoint.executorDataMap.toString}")
}
})
Seq(Row(true, ""))
} else {
val futureList = ArrayBuffer[Future[CacheLoadResult]]()
val resultList = ArrayBuffer[CacheLoadResult]()
executorIdsToParts.foreach(
value => {
val executorData = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1))
if (executorData != null) {
futureList.append(
executorData.executorEndpointRef.ask[CacheLoadResult](
GlutenMergeTreeCacheLoad(value._2, selectedColumns.toSet.asJava)
))
} else {
throw new GlutenException(
s"executor ${value._1} not found," +
s" all executors are ${GlutenDriverEndpoint.executorDataMap.toString}")
}
})
futureList.foreach(
f => {
resultList.append(ThreadUtils.awaitResult(f, Duration.Inf))
})
if (resultList.exists(!_.success)) {
Seq(Row(false, resultList.filter(!_.success).map(_.reason).mkString(";")))
} else {
Seq(Row(true, ""))
}
}
executorIdsToParts.foreach(
value => {
checkExecutorId(value._1)
val executorData = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1))
futureList.append(
(
value._1,
executorData.executorEndpointRef.ask[CacheJobInfo](
GlutenMergeTreeCacheLoad(value._2, selectedColumns.toSet.asJava)
)))
})
}
val resultList = waitRpcResults(futureList)
if (asynExecute) {
val res = collectJobTriggerResult(resultList)
Seq(Row(res._1, res._2.mkString(";")))
} else {
val res = waitAllJobFinish(resultList)
Seq(Row(res._1, res._2))
}
}

}

object GlutenCHCacheDataCommand {
val ALL_EXECUTORS = "allExecutors"
private val ALL_EXECUTORS = "allExecutors"

private def toExecutorId(executorId: String): String =
executorId.split("_").last

def waitAllJobFinish(jobs: ArrayBuffer[(String, CacheJobInfo)]): (Boolean, String) = {
val res = collectJobTriggerResult(jobs)
var status = res._1
val messages = res._2
jobs.foreach(
job => {
if (status) {
var complete = false
while (!complete) {
Thread.sleep(5000)
liuneng1994 marked this conversation as resolved.
Show resolved Hide resolved
val future_result = GlutenDriverEndpoint.executorDataMap
.get(toExecutorId(job._1))
.executorEndpointRef
.ask[CacheResult](GlutenMergeTreeCacheLoadStatus(job._2.jobId))
val result = ThreadUtils.awaitResult(future_result, Duration.Inf)
result.getStatus match {
case Status.ERROR =>
status = false
messages.append(
s"executor : {}, failed with message: {};",
job._1,
result.getMessage)
complete = true
case Status.SUCCESS =>
complete = true
case _ =>
// still running
}
}
}
})
(status, messages.mkString(";"))
}

private def collectJobTriggerResult(jobs: ArrayBuffer[(String, CacheJobInfo)]) = {
var status = true
val messages = ArrayBuffer[String]()
jobs.foreach(
job => {
if (!job._2.status) {
messages.append(job._2.reason)
status = false
}
})
(status, messages)
}

private def waitRpcResults = (futureList: ArrayBuffer[(String, Future[CacheJobInfo])]) => {
val resultList = ArrayBuffer[(String, CacheJobInfo)]()
futureList.foreach(
f => {
resultList.append((f._1, ThreadUtils.awaitResult(f._2, Duration.Inf)))
})
resultList
}

private def checkExecutorId(executorId: String): Unit = {
if (!GlutenDriverEndpoint.executorDataMap.containsKey(toExecutorId(executorId))) {
throw new GlutenException(
s"executor $executorId not found," +
s" all executors are ${GlutenDriverEndpoint.executorDataMap.toString}")
}
}

}
1 change: 1 addition & 0 deletions cpp-ch/local-engine/Common/CHUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,7 @@ void BackendInitializerUtil::init(const std::string_view plan)
// Init the table metadata cache map
StorageMergeTreeFactory::init_cache_map();

JobScheduler::initialize(SerializedPlanParser::global_context);
CacheManager::initialize(SerializedPlanParser::global_context);

std::call_once(
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Common/ConcurrentMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
#pragma once

#include <mutex>
#include <shared_mutex>
#include <unordered_map>

namespace local_engine
Expand Down
Loading
Loading