diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java b/backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java index f5f75dc1dca6..7b765924fa0d 100644 --- a/backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java +++ b/backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java @@ -19,9 +19,15 @@ import java.util.Set; public class CHNativeCacheManager { - public static void cacheParts(String table, Set columns, boolean async) { - nativeCacheParts(table, String.join(",", columns), async); + public static String cacheParts(String table, Set columns) { + return nativeCacheParts(table, String.join(",", columns)); } - private static native void nativeCacheParts(String table, String columns, boolean async); + private static native String nativeCacheParts(String table, String columns); + + public static CacheResult getCacheStatus(String jobId) { + return nativeGetCacheStatus(jobId); + } + + private static native CacheResult nativeGetCacheStatus(String jobId); } diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/execution/CacheResult.java b/backends-clickhouse/src/main/java/org/apache/gluten/execution/CacheResult.java new file mode 100644 index 000000000000..0fa69e0d0b1f --- /dev/null +++ b/backends-clickhouse/src/main/java/org/apache/gluten/execution/CacheResult.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.execution; + +public class CacheResult { + public enum Status { + RUNNING(0), + SUCCESS(1), + ERROR(2); + + private final int value; + + Status(int value) { + this.value = value; + } + + public int getValue() { + return value; + } + + public static Status fromInt(int value) { + for (Status myEnum : Status.values()) { + if (myEnum.getValue() == value) { + return myEnum; + } + } + throw new IllegalArgumentException("No enum constant for value: " + value); + } + } + + private final Status status; + private final String message; + + public CacheResult(int status, String message) { + this.status = Status.fromInt(status); + this.message = message; + } + + public Status getStatus() { + return status; + } + + public String getMessage() { + return message; + } +} diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala b/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala index 4d90ab6533ba..8a3bde235887 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala @@ -64,8 +64,6 @@ class GlutenExecutorEndpoint(val executorId: String, val conf: SparkConf) hashIds.forEach( resource_id => CHBroadcastBuildSideCache.invalidateBroadcastHashtable(resource_id)) } - case GlutenMergeTreeCacheLoad(mergeTreeTable, columns) => - CHNativeCacheManager.cacheParts(mergeTreeTable, columns, true) case e => logError(s"Received unexpected message. $e") @@ -74,12 +72,16 @@ class GlutenExecutorEndpoint(val executorId: String, val conf: SparkConf) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GlutenMergeTreeCacheLoad(mergeTreeTable, columns) => try { - CHNativeCacheManager.cacheParts(mergeTreeTable, columns, false) - context.reply(CacheLoadResult(true)) + val jobId = CHNativeCacheManager.cacheParts(mergeTreeTable, columns) + context.reply(CacheJobInfo(status = true, jobId)) } catch { case _: Exception => - context.reply(CacheLoadResult(false, s"executor: $executorId cache data failed.")) + context.reply( + CacheJobInfo(status = false, "", s"executor: $executorId cache data failed.")) } + case GlutenMergeTreeCacheLoadStatus(jobId) => + val status = CHNativeCacheManager.getCacheStatus(jobId) + context.reply(status) case e => logError(s"Received unexpected message. $e") } diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala b/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala index d675d705f10a..800b15b9949b 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala @@ -35,8 +35,12 @@ object GlutenRpcMessages { case class GlutenCleanExecutionResource(executionId: String, broadcastHashIds: util.Set[String]) extends GlutenRpcMessage + // for mergetree cache case class GlutenMergeTreeCacheLoad(mergeTreeTable: String, columns: util.Set[String]) extends GlutenRpcMessage - case class CacheLoadResult(success: Boolean, reason: String = "") extends GlutenRpcMessage + case class GlutenMergeTreeCacheLoadStatus(jobId: String) + + case class CacheJobInfo(status: Boolean, jobId: String, reason: String = "") + extends GlutenRpcMessage } diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/commands/GlutenCHCacheDataCommand.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/commands/GlutenCHCacheDataCommand.scala index 1e6b024063b6..f32d22d5eac0 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/commands/GlutenCHCacheDataCommand.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/commands/GlutenCHCacheDataCommand.scala @@ -17,18 +17,20 @@ package org.apache.spark.sql.execution.commands import org.apache.gluten.exception.GlutenException +import org.apache.gluten.execution.CacheResult +import org.apache.gluten.execution.CacheResult.Status import org.apache.gluten.expression.ConverterUtils import org.apache.gluten.substrait.rel.ExtensionTableBuilder import org.apache.spark.affinity.CHAffinity import org.apache.spark.rpc.GlutenDriverEndpoint -import org.apache.spark.rpc.GlutenRpcMessages.{CacheLoadResult, GlutenMergeTreeCacheLoad} +import org.apache.spark.rpc.GlutenRpcMessages.{CacheJobInfo, GlutenMergeTreeCacheLoad, GlutenMergeTreeCacheLoadStatus} import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, GreaterThanOrEqual, IsNotNull, Literal} import org.apache.spark.sql.delta._ import org.apache.spark.sql.execution.command.LeafRunnableCommand -import org.apache.spark.sql.execution.commands.GlutenCHCacheDataCommand.toExecutorId +import org.apache.spark.sql.execution.commands.GlutenCHCacheDataCommand.{checkExecutorId, collectJobTriggerResult, toExecutorId, waitAllJobFinish, waitRpcResults} import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMergeTreeParts import org.apache.spark.sql.types.{BooleanType, StringType} import org.apache.spark.util.ThreadUtils @@ -106,7 +108,8 @@ case class GlutenCHCacheDataCommand( } val selectedAddFiles = if (tsfilter.isDefined) { - val allParts = DeltaAdapter.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, false) + val allParts = + DeltaAdapter.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, keepNumRecords = false) allParts.files.filter(_.modificationTime >= tsfilter.get.toLong).toSeq } else if (partitionColumn.isDefined && partitionValue.isDefined) { val partitionColumns = snapshot.metadata.partitionSchema.fieldNames @@ -126,10 +129,12 @@ case class GlutenCHCacheDataCommand( snapshot, Seq(partitionColumnAttr), Seq(isNotNullExpr, greaterThanOrEqual), - false) + keepNumRecords = false) .files } else { - DeltaAdapter.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, false).files + DeltaAdapter + .snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, keepNumRecords = false) + .files } val executorIdsToAddFiles = @@ -151,9 +156,7 @@ case class GlutenCHCacheDataCommand( if (locations.isEmpty) { // non soft affinity - executorIdsToAddFiles - .get(GlutenCHCacheDataCommand.ALL_EXECUTORS) - .get + executorIdsToAddFiles(GlutenCHCacheDataCommand.ALL_EXECUTORS) .append(mergeTreePart) } else { locations.foreach( @@ -161,7 +164,7 @@ case class GlutenCHCacheDataCommand( if (!executorIdsToAddFiles.contains(executor)) { executorIdsToAddFiles.put(executor, new ArrayBuffer[AddMergeTreeParts]()) } - executorIdsToAddFiles.get(executor).get.append(mergeTreePart) + executorIdsToAddFiles(executor).append(mergeTreePart) }) } }) @@ -201,87 +204,112 @@ case class GlutenCHCacheDataCommand( executorIdsToParts.put(executorId, extensionTableNode.getExtensionTableStr) } }) - - // send rpc call + val futureList = ArrayBuffer[(String, Future[CacheJobInfo])]() if (executorIdsToParts.contains(GlutenCHCacheDataCommand.ALL_EXECUTORS)) { // send all parts to all executors - val tableMessage = executorIdsToParts.get(GlutenCHCacheDataCommand.ALL_EXECUTORS).get - if (asynExecute) { - GlutenDriverEndpoint.executorDataMap.forEach( - (executorId, executor) => { - executor.executorEndpointRef.send( - GlutenMergeTreeCacheLoad(tableMessage, selectedColumns.toSet.asJava)) - }) - Seq(Row(true, "")) - } else { - val futureList = ArrayBuffer[Future[CacheLoadResult]]() - val resultList = ArrayBuffer[CacheLoadResult]() - GlutenDriverEndpoint.executorDataMap.forEach( - (executorId, executor) => { - futureList.append( - executor.executorEndpointRef.ask[CacheLoadResult]( + val tableMessage = executorIdsToParts(GlutenCHCacheDataCommand.ALL_EXECUTORS) + GlutenDriverEndpoint.executorDataMap.forEach( + (executorId, executor) => { + futureList.append( + ( + executorId, + executor.executorEndpointRef.ask[CacheJobInfo]( GlutenMergeTreeCacheLoad(tableMessage, selectedColumns.toSet.asJava) - )) - }) - futureList.foreach( - f => { - resultList.append(ThreadUtils.awaitResult(f, Duration.Inf)) - }) - if (resultList.exists(!_.success)) { - Seq(Row(false, resultList.filter(!_.success).map(_.reason).mkString(";"))) - } else { - Seq(Row(true, "")) - } - } + ))) + }) } else { - if (asynExecute) { - executorIdsToParts.foreach( - value => { - val executorData = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1)) - if (executorData != null) { - executorData.executorEndpointRef.send( - GlutenMergeTreeCacheLoad(value._2, selectedColumns.toSet.asJava)) - } else { - throw new GlutenException( - s"executor ${value._1} not found," + - s" all executors are ${GlutenDriverEndpoint.executorDataMap.toString}") - } - }) - Seq(Row(true, "")) - } else { - val futureList = ArrayBuffer[Future[CacheLoadResult]]() - val resultList = ArrayBuffer[CacheLoadResult]() - executorIdsToParts.foreach( - value => { - val executorData = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1)) - if (executorData != null) { - futureList.append( - executorData.executorEndpointRef.ask[CacheLoadResult]( - GlutenMergeTreeCacheLoad(value._2, selectedColumns.toSet.asJava) - )) - } else { - throw new GlutenException( - s"executor ${value._1} not found," + - s" all executors are ${GlutenDriverEndpoint.executorDataMap.toString}") - } - }) - futureList.foreach( - f => { - resultList.append(ThreadUtils.awaitResult(f, Duration.Inf)) - }) - if (resultList.exists(!_.success)) { - Seq(Row(false, resultList.filter(!_.success).map(_.reason).mkString(";"))) - } else { - Seq(Row(true, "")) - } - } + executorIdsToParts.foreach( + value => { + checkExecutorId(value._1) + val executorData = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1)) + futureList.append( + ( + value._1, + executorData.executorEndpointRef.ask[CacheJobInfo]( + GlutenMergeTreeCacheLoad(value._2, selectedColumns.toSet.asJava) + ))) + }) + } + val resultList = waitRpcResults(futureList) + if (asynExecute) { + val res = collectJobTriggerResult(resultList) + Seq(Row(res._1, res._2.mkString(";"))) + } else { + val res = waitAllJobFinish(resultList) + Seq(Row(res._1, res._2)) } } + } object GlutenCHCacheDataCommand { - val ALL_EXECUTORS = "allExecutors" + private val ALL_EXECUTORS = "allExecutors" private def toExecutorId(executorId: String): String = executorId.split("_").last + + def waitAllJobFinish(jobs: ArrayBuffer[(String, CacheJobInfo)]): (Boolean, String) = { + val res = collectJobTriggerResult(jobs) + var status = res._1 + val messages = res._2 + jobs.foreach( + job => { + if (status) { + var complete = false + while (!complete) { + Thread.sleep(5000) + val future_result = GlutenDriverEndpoint.executorDataMap + .get(toExecutorId(job._1)) + .executorEndpointRef + .ask[CacheResult](GlutenMergeTreeCacheLoadStatus(job._2.jobId)) + val result = ThreadUtils.awaitResult(future_result, Duration.Inf) + result.getStatus match { + case Status.ERROR => + status = false + messages.append( + s"executor : {}, failed with message: {};", + job._1, + result.getMessage) + complete = true + case Status.SUCCESS => + complete = true + case _ => + // still running + } + } + } + }) + (status, messages.mkString(";")) + } + + private def collectJobTriggerResult(jobs: ArrayBuffer[(String, CacheJobInfo)]) = { + var status = true + val messages = ArrayBuffer[String]() + jobs.foreach( + job => { + if (!job._2.status) { + messages.append(job._2.reason) + status = false + } + }) + (status, messages) + } + + private def waitRpcResults = (futureList: ArrayBuffer[(String, Future[CacheJobInfo])]) => { + val resultList = ArrayBuffer[(String, CacheJobInfo)]() + futureList.foreach( + f => { + resultList.append((f._1, ThreadUtils.awaitResult(f._2, Duration.Inf))) + }) + resultList + } + + private def checkExecutorId(executorId: String): Unit = { + if (!GlutenDriverEndpoint.executorDataMap.containsKey(toExecutorId(executorId))) { + throw new GlutenException( + s"executor $executorId not found," + + s" all executors are ${GlutenDriverEndpoint.executorDataMap.toString}") + } + } + } diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 0409b66bd920..8e07eea011b8 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -979,6 +979,7 @@ void BackendInitializerUtil::init(const std::string_view plan) // Init the table metadata cache map StorageMergeTreeFactory::init_cache_map(); + JobScheduler::initialize(SerializedPlanParser::global_context); CacheManager::initialize(SerializedPlanParser::global_context); std::call_once( diff --git a/cpp-ch/local-engine/Common/ConcurrentMap.h b/cpp-ch/local-engine/Common/ConcurrentMap.h index 1719d9b255ea..2db35102215a 100644 --- a/cpp-ch/local-engine/Common/ConcurrentMap.h +++ b/cpp-ch/local-engine/Common/ConcurrentMap.h @@ -16,7 +16,7 @@ */ #pragma once -#include +#include #include namespace local_engine diff --git a/cpp-ch/local-engine/Common/GlutenConfig.h b/cpp-ch/local-engine/Common/GlutenConfig.h index 84744dab21b8..ac82b0fff03a 100644 --- a/cpp-ch/local-engine/Common/GlutenConfig.h +++ b/cpp-ch/local-engine/Common/GlutenConfig.h @@ -183,5 +183,19 @@ struct MergeTreeConfig return config; } }; + +struct GlutenJobSchedulerConfig +{ + inline static const String JOB_SCHEDULER_MAX_THREADS = "job_scheduler_max_threads"; + + size_t job_scheduler_max_threads = 10; + + static GlutenJobSchedulerConfig loadFromContext(DB::ContextPtr context) + { + GlutenJobSchedulerConfig config; + config.job_scheduler_max_threads = context->getConfigRef().getUInt64(JOB_SCHEDULER_MAX_THREADS, 10); + return config; + } +}; } diff --git a/cpp-ch/local-engine/Storages/Cache/CacheManager.cpp b/cpp-ch/local-engine/Storages/Cache/CacheManager.cpp index d2c7b06810db..a97f0c72ada4 100644 --- a/cpp-ch/local-engine/Storages/Cache/CacheManager.cpp +++ b/cpp-ch/local-engine/Storages/Cache/CacheManager.cpp @@ -26,12 +26,13 @@ #include #include #include -#include #include #include #include #include +#include + namespace DB { namespace ErrorCodes @@ -49,6 +50,16 @@ extern const Metric LocalThreadScheduled; namespace local_engine { + +jclass CacheManager::cache_result_class = nullptr; +jmethodID CacheManager::cache_result_constructor = nullptr; + +void CacheManager::initJNI(JNIEnv * env) +{ + cache_result_class = CreateGlobalClassReference(env, "Lorg/apache/gluten/execution/CacheResult;"); + cache_result_constructor = GetMethodID(env, cache_result_class, "", "(ILjava/lang/String;)V"); +} + CacheManager & CacheManager::instance() { static CacheManager cache_manager; @@ -59,13 +70,6 @@ void CacheManager::initialize(DB::ContextMutablePtr context_) { auto & manager = instance(); manager.context = context_; - manager.thread_pool = std::make_unique( - CurrentMetrics::LocalThread, - CurrentMetrics::LocalThreadActive, - CurrentMetrics::LocalThreadScheduled, - manager.context->getConfigRef().getInt("cache_sync_max_threads", 10), - 0, - 0); } struct CacheJobContext @@ -73,17 +77,16 @@ struct CacheJobContext MergeTreeTable table; }; -void CacheManager::cachePart(const MergeTreeTable& table, const MergeTreePart& part, const std::unordered_set & columns, std::shared_ptr latch) +Task CacheManager::cachePart(const MergeTreeTable& table, const MergeTreePart& part, const std::unordered_set & columns) { CacheJobContext job_context{table}; job_context.table.parts.clear(); job_context.table.parts.push_back(part); job_context.table.snapshot_id = ""; - auto job = [job_detail = job_context, context = this->context, read_columns = columns, latch = latch]() + Task task = [job_detail = job_context, context = this->context, read_columns = columns]() { try { - SCOPE_EXIT({ if (latch) latch->count_down();}); auto storage = MergeTreeRelParser::parseStorage(job_detail.table, context, true); auto storage_snapshot = std::make_shared(*storage, storage->getInMemoryMetadataPtr()); NamesAndTypesList names_and_types_list; @@ -113,8 +116,7 @@ void CacheManager::cachePart(const MergeTreeTable& table, const MergeTreePart& p PullingPipelineExecutor executor(pipeline); while (true) { - Chunk chunk; - if (!executor.pull(chunk)) + if (Chunk chunk; !executor.pull(chunk)) break; } LOG_INFO(getLogger("CacheManager"), "Load cache of table {}.{} part {} success.", job_detail.table.database, job_detail.table.table, job_detail.table.parts.front().name); @@ -122,22 +124,58 @@ void CacheManager::cachePart(const MergeTreeTable& table, const MergeTreePart& p catch (std::exception& e) { LOG_ERROR(getLogger("CacheManager"), "Load cache of table {}.{} part {} failed.\n {}", job_detail.table.database, job_detail.table.table, job_detail.table.parts.front().name, e.what()); + std::rethrow_exception(std::current_exception()); } }; LOG_INFO(getLogger("CacheManager"), "Loading cache of table {}.{} part {}", job_context.table.database, job_context.table.table, job_context.table.parts.front().name); - thread_pool->scheduleOrThrowOnError(std::move(job)); + return std::move(task); } -void CacheManager::cacheParts(const String& table_def, const std::unordered_set& columns, bool async) +JobId CacheManager::cacheParts(const String& table_def, const std::unordered_set& columns) { auto table = parseMergeTreeTableString(table_def); - std::shared_ptr latch = nullptr; - if (!async) latch = std::make_shared(table.parts.size()); + JobId id = toString(UUIDHelpers::generateV4()); + Job job(id); for (const auto & part : table.parts) { - cachePart(table, part, columns, latch); + job.addTask(cachePart(table, part, columns)); + } + auto& scheduler = JobScheduler::instance(); + scheduler.scheduleJob(std::move(job)); + return id; +} + +jobject CacheManager::getCacheStatus(JNIEnv * env, const String & jobId) +{ + auto& scheduler = JobScheduler::instance(); + auto job_status = scheduler.getJobSatus(jobId); + int status = 0; + String message; + if (job_status.has_value()) + { + switch (job_status.value().status) + { + case JobSatus::RUNNING: + status = 0; + break; + case JobSatus::FINISHED: + status = 1; + break; + case JobSatus::FAILED: + status = 2; + for (const auto & msg : job_status->messages) + { + message.append(msg); + message.append(";"); + } + break; + } + } + else + { + status = 2; + message = fmt::format("job {} not found", jobId); } - if (latch) - latch->wait(); + return env->NewObject(cache_result_class, cache_result_constructor, status, charTojstring(env, message.c_str())); } } \ No newline at end of file diff --git a/cpp-ch/local-engine/Storages/Cache/CacheManager.h b/cpp-ch/local-engine/Storages/Cache/CacheManager.h index a303b7b7fc63..b88a3ea03e4e 100644 --- a/cpp-ch/local-engine/Storages/Cache/CacheManager.h +++ b/cpp-ch/local-engine/Storages/Cache/CacheManager.h @@ -16,29 +16,32 @@ */ #pragma once #include -#include - +#include +#include namespace local_engine { struct MergeTreePart; struct MergeTreeTable; + + + /*** * Manage the cache of the MergeTree, mainly including meta.bin, data.bin, metadata.gluten */ class CacheManager { public: + static jclass cache_result_class; + static jmethodID cache_result_constructor; + static void initJNI(JNIEnv* env); + static CacheManager & instance(); static void initialize(DB::ContextMutablePtr context); - void cachePart(const MergeTreeTable& table, const MergeTreePart& part, const std::unordered_set& columns, std::shared_ptr latch = nullptr); - void cacheParts(const String& table_def, const std::unordered_set& columns, bool async = true); + Task cachePart(const MergeTreeTable& table, const MergeTreePart& part, const std::unordered_set& columns); + JobId cacheParts(const String& table_def, const std::unordered_set& columns); + static jobject getCacheStatus(JNIEnv * env, const String& jobId); private: CacheManager() = default; - - std::unique_ptr thread_pool; DB::ContextMutablePtr context; - std::unordered_map policy_to_disk; - std::unordered_map disk_to_metadisk; - std::unordered_map policy_to_cache; }; } \ No newline at end of file diff --git a/cpp-ch/local-engine/Storages/Cache/JobScheduler.cpp b/cpp-ch/local-engine/Storages/Cache/JobScheduler.cpp new file mode 100644 index 000000000000..6a43ad644433 --- /dev/null +++ b/cpp-ch/local-engine/Storages/Cache/JobScheduler.cpp @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include "JobScheduler.h" + +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ +extern const int BAD_ARGUMENTS; +} +} + +namespace CurrentMetrics +{ +extern const Metric LocalThread; +extern const Metric LocalThreadActive; +extern const Metric LocalThreadScheduled; +} + +namespace local_engine +{ +std::shared_ptr global_job_scheduler = nullptr; + +void JobScheduler::initialize(DB::ContextPtr context) +{ + auto config = GlutenJobSchedulerConfig::loadFromContext(context); + instance().thread_pool = std::make_unique( + CurrentMetrics::LocalThread, + CurrentMetrics::LocalThreadActive, + CurrentMetrics::LocalThreadScheduled, + config.job_scheduler_max_threads, + 0, + 0); + +} + +JobId JobScheduler::scheduleJob(Job&& job) +{ + cleanFinishedJobs(); + if (job_details.contains(job.id)) + { + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "job {} exists.", job.id); + } + size_t task_num = job.tasks.size(); + auto job_id = job.id; + std::vector task_results; + task_results.reserve(task_num); + JobContext job_context = {std::move(job), std::make_unique(task_num), std::move(task_results)}; + { + std::lock_guard lock(job_details_mutex); + job_details.emplace(job_id, std::move(job_context)); + } + LOG_INFO(logger, "schedule job {}", job_id); + + auto & job_detail = job_details.at(job_id); + + for (auto & task : job_detail.job.tasks) + { + job_detail.task_results.emplace_back(TaskResult()); + auto & task_result = job_detail.task_results.back(); + thread_pool->scheduleOrThrow( + [&]() + { + SCOPE_EXIT({ + job_detail.remain_tasks->fetch_sub(1, std::memory_order::acquire); + if (job_detail.isFinished()) + { + addFinishedJob(job_detail.job.id); + } + }); + try + { + task(); + task_result.status = TaskResult::Status::SUCCESS; + } + catch (std::exception & e) + { + task_result.status = TaskResult::Status::FAILED; + task_result.message = e.what(); + } + }); + } + return job_id; +} + +std::optional JobScheduler::getJobSatus(const JobId & job_id) +{ + if (!job_details.contains(job_id)) + { + return std::nullopt; + } + std::optional res; + auto & job_context = job_details.at(job_id); + if (job_context.isFinished()) + { + std::vector messages; + for (auto & task_result : job_context.task_results) + { + if (task_result.status == TaskResult::Status::FAILED) + { + messages.push_back(task_result.message); + } + } + if (messages.empty()) + res = JobSatus::success(); + else + res= JobSatus::failed(messages); + } + else + res = JobSatus::running(); + return res; +} + +void JobScheduler::cleanupJob(const JobId & job_id) +{ + LOG_INFO(logger, "clean job {}", job_id); + job_details.erase(job_id); +} + +void JobScheduler::addFinishedJob(const JobId & job_id) +{ + std::lock_guard lock(finished_job_mutex); + auto job = std::make_pair(job_id, Stopwatch()); + finished_job.emplace_back(job); +} + +void JobScheduler::cleanFinishedJobs() +{ + std::lock_guard lock(finished_job_mutex); + for (auto it = finished_job.begin(); it != finished_job.end();) + { + // clean finished job after 5 minutes + if (it->second.elapsedSeconds() > 60 * 5) + { + cleanupJob(it->first); + it = finished_job.erase(it); + } + else + ++it; + } +} +} \ No newline at end of file diff --git a/cpp-ch/local-engine/Storages/Cache/JobScheduler.h b/cpp-ch/local-engine/Storages/Cache/JobScheduler.h new file mode 100644 index 000000000000..b5c2f601a92b --- /dev/null +++ b/cpp-ch/local-engine/Storages/Cache/JobScheduler.h @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include +#include + +namespace local_engine +{ + +using JobId = String; +using Task = std::function; + +class Job +{ + friend class JobScheduler; +public: + explicit Job(const JobId& id) + : id(id) + { + } + + void addTask(Task&& task) + { + tasks.emplace_back(task); + } + +private: + JobId id; + std::vector tasks; +}; + + + +struct JobSatus +{ + enum Status + { + RUNNING, + FINISHED, + FAILED + }; + Status status; + std::vector messages; + + static JobSatus success() + { + return JobSatus{FINISHED}; + } + + static JobSatus running() + { + return JobSatus{RUNNING}; + } + + static JobSatus failed(const std::vector & messages) + { + return JobSatus{FAILED, messages}; + } +}; + +struct TaskResult +{ + enum Status + { + SUCCESS, + FAILED, + RUNNING + }; + Status status = RUNNING; + String message; +}; + +class JobContext +{ +public: + Job job; + std::unique_ptr remain_tasks = std::make_unique(); + std::vector task_results; + + bool isFinished() + { + return remain_tasks->load(std::memory_order::relaxed) == 0; + } +}; + +class JobScheduler +{ +public: + static JobScheduler& instance() + { + static JobScheduler global_job_scheduler; + return global_job_scheduler; + } + + static void initialize(DB::ContextPtr context); + + JobId scheduleJob(Job&& job); + + std::optional getJobSatus(const JobId& job_id); + + void cleanupJob(const JobId& job_id); + + void addFinishedJob(const JobId& job_id); + + void cleanFinishedJobs(); +private: + JobScheduler() = default; + std::unique_ptr thread_pool; + std::unordered_map job_details; + std::mutex job_details_mutex; + + std::vector> finished_job; + std::mutex finished_job_mutex; + LoggerPtr logger = getLogger("JobScheduler"); +}; +} diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp index 828556b4abf6..3c3d6d4f89c2 100644 --- a/cpp-ch/local-engine/local_engine_jni.cpp +++ b/cpp-ch/local-engine/local_engine_jni.cpp @@ -163,6 +163,7 @@ JNIEXPORT jint JNI_OnLoad(JavaVM * vm, void * /*reserved*/) env, local_engine::SparkRowToCHColumn::spark_row_interator_class, "nextBatch", "()Ljava/nio/ByteBuffer;"); local_engine::BroadCastJoinBuilder::init(env); + local_engine::CacheManager::initJNI(env); local_engine::JNIUtils::vm = vm; return JNI_VERSION_1_8; @@ -1269,7 +1270,7 @@ JNIEXPORT void Java_org_apache_gluten_utils_TestExceptionUtils_generateNativeExc -JNIEXPORT void Java_org_apache_gluten_execution_CHNativeCacheManager_nativeCacheParts(JNIEnv * env, jobject, jstring table_, jstring columns_, jboolean async_) +JNIEXPORT jstring Java_org_apache_gluten_execution_CHNativeCacheManager_nativeCacheParts(JNIEnv * env, jobject, jstring table_, jstring columns_) { LOCAL_ENGINE_JNI_METHOD_START auto table_def = jstring2string(env, table_); @@ -1280,10 +1281,17 @@ JNIEXPORT void Java_org_apache_gluten_execution_CHNativeCacheManager_nativeCache { column_set.insert(col); } - local_engine::CacheManager::instance().cacheParts(table_def, column_set, async_); - LOCAL_ENGINE_JNI_METHOD_END(env, ); + auto id = local_engine::CacheManager::instance().cacheParts(table_def, column_set); + return local_engine::charTojstring(env, id.c_str()); + LOCAL_ENGINE_JNI_METHOD_END(env, nullptr); } +JNIEXPORT jobject Java_org_apache_gluten_execution_CHNativeCacheManager_nativeGetCacheStatus(JNIEnv * env, jobject, jstring id) +{ + LOCAL_ENGINE_JNI_METHOD_START + return local_engine::CacheManager::instance().getCacheStatus(env, jstring2string(env, id)); + LOCAL_ENGINE_JNI_METHOD_END(env, nullptr); +} #ifdef __cplusplus }