Commit
fix review
loneylee committed Aug 20, 2024
1 parent f5ee78a commit 8afe4a1
Showing 6 changed files with 24 additions and 23 deletions.
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, GreaterThanOrEqual, IsNotNull, Literal}
 import org.apache.spark.sql.delta._
 import org.apache.spark.sql.execution.command.LeafRunnableCommand
+import org.apache.spark.sql.execution.commands.GlutenCacheBase._
 import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMergeTreeParts
 import org.apache.spark.sql.types.{BooleanType, StringType}

@@ -49,8 +50,7 @@ case class GlutenCHCacheDataCommand(
     partitionColumn: Option[String],
     partitionValue: Option[String],
     tablePropertyOverrides: Map[String, String]
-) extends LeafRunnableCommand
-  with GlutenCacheBase {
+) extends LeafRunnableCommand {

   override def output: Seq[Attribute] = Seq(
     AttributeReference("result", BooleanType, nullable = false)(),
@@ -29,10 +29,10 @@ import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.Future
 import scala.concurrent.duration.Duration

-trait GlutenCacheBase {
+object GlutenCacheBase {
   def ALL_EXECUTORS: String = "allExecutors"

-  protected def toExecutorId(executorId: String): String =
+  def toExecutorId(executorId: String): String =
     executorId.split("_").last

   protected def waitRpcResults
@@ -46,7 +46,7 @@ trait GlutenCacheBase
     resultList
   }

-  protected def checkExecutorId(executorId: String): Unit = {
+  def checkExecutorId(executorId: String): Unit = {
     if (!GlutenDriverEndpoint.executorDataMap.containsKey(toExecutorId(executorId))) {
       throw new GlutenException(
         s"executor $executorId not found," +
@@ -87,7 +87,7 @@
     (status, messages.mkString(";"))
   }

-  protected def collectJobTriggerResult(
+  def collectJobTriggerResult(
       jobs: ArrayBuffer[(String, CacheJobInfo)]): (Boolean, ArrayBuffer[String]) = {
     var status = true
     val messages = ArrayBuffer[String]()
@@ -101,7 +101,7 @@
     (status, messages)
   }

-  protected def getResult(
+  def getResult(
       futureList: ArrayBuffer[(String, Future[CacheJobInfo])],
       async: Boolean): Seq[Row] = {
     val resultList = waitRpcResults(futureList)
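Taken together, these Scala hunks convert GlutenCacheBase from a mix-in trait into a standalone object and drop `protected` from the helpers the commands call, so the cache commands can shed the inheritance and pull the members in with a wildcard import. A minimal sketch of that pattern, with hypothetical names standing in for the repository's classes:

object CacheHelpers {
  def ALL_EXECUTORS: String = "allExecutors"

  // "exec_12" -> "12": keep only the suffix of an underscore-separated id.
  def toExecutorId(executorId: String): String =
    executorId.split("_").last
}

// Callers import the object's members instead of extending a trait.
import CacheHelpers._

case class ExampleCacheCommand(executorId: String) {
  def target: String = toExecutorId(executorId) // resolves via the import
}

// ExampleCacheCommand("exec_7").target == "7"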
@@ -25,6 +25,7 @@ import org.apache.spark.rpc.GlutenRpcMessages.{CacheJobInfo, GlutenFilesCacheLoad
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.execution.command.LeafRunnableCommand
+import org.apache.spark.sql.execution.commands.GlutenCacheBase._
 import org.apache.spark.sql.types.{BooleanType, StringType}

 import org.apache.hadoop.conf.Configuration
@@ -42,8 +43,7 @@ case class GlutenCacheFilesCommand(
     selectedColumn: Option[Seq[String]],
     filePath: String,
     propertyOverrides: Map[String, String]
-) extends LeafRunnableCommand
-  with GlutenCacheBase {
+) extends LeafRunnableCommand {

   override def output: Seq[Attribute] = Seq(
     AttributeReference("result", BooleanType, nullable = false)(),
@@ -135,8 +135,7 @@
     executorIdsToLocalFiles.foreach {
       case (executorId, fileNode) =>
         checkExecutorId(executorId)
-        val executor = GlutenDriverEndpoint.executorDataMap.get(
-          GlutenCacheFilesCommand.toExecutorId(executorId))
+        val executor = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(executorId))
         futureList.append(
           (
             executorId,
@@ -187,10 +186,3 @@
     }
   }
 }
-
-object GlutenCacheFilesCommand {
-  val ALL_EXECUTORS = "allExecutors"
-
-  private def toExecutorId(executorId: String): String =
-    executorId.split("_").last
-}
13 changes: 9 additions & 4 deletions cpp-ch/local-engine/Common/QueryContext.cpp
@@ -77,14 +77,19 @@ int64_t QueryContextManager::initializeQuery()

 DB::ContextMutablePtr QueryContextManager::currentQueryContext()
 {
-    if (!CurrentThread::getGroup())
-    {
-        throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Thread group not found.");
-    }
+    auto thread_group = currentThreadGroup();
     int64_t id = reinterpret_cast<int64_t>(CurrentThread::getGroup().get());
     return query_map.get(id)->query_context;
 }

+std::shared_ptr<DB::ThreadGroup> QueryContextManager::currentThreadGroup()
+{
+    if (auto thread_group = CurrentThread::getGroup())
+        return thread_group;
+
+    throw Exception(ErrorCodes::LOGICAL_ERROR, "Thread group not found.");
+}
+
 void QueryContextManager::logCurrentPerformanceCounters(ProfileEvents::Counters & counters)
 {
     if (!CurrentThread::getGroup())
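The shape of this change: the repeated "is there a thread group?" guard moves into one helper, currentThreadGroup(), which either returns the group or throws, and currentQueryContext() (and, below, RelMetric.cpp) simply consume its result. A language-neutral sketch of the same guard-extraction pattern in Scala; the names are illustrative, not the engine's API:

final case class ThreadGroup(id: Long)

object QueryContexts {
  // Stand-in for CurrentThread::getGroup(); empty when the calling
  // thread was never attached to a query.
  var current: Option[ThreadGroup] = Some(ThreadGroup(42))

  // One shared precondition check instead of one per call site.
  def currentThreadGroup(): ThreadGroup =
    current.getOrElse(throw new IllegalStateException("Thread group not found."))

  def currentQueryContext(): String = {
    val group = currentThreadGroup()
    s"query-context-${group.id}"
  }
}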
1 change: 1 addition & 0 deletions cpp-ch/local-engine/Common/QueryContext.h
@@ -30,6 +30,7 @@ class QueryContextManager
     }
     int64_t initializeQuery();
     DB::ContextMutablePtr currentQueryContext();
+    static std::shared_ptr<DB::ThreadGroup> currentThreadGroup();
     void logCurrentPerformanceCounters(ProfileEvents::Counters& counters);
     size_t currentPeakMemory(int64_t id);
     void finalizeQuery(int64_t id);
5 changes: 4 additions & 1 deletion cpp-ch/local-engine/Parser/RelMetric.cpp
@@ -15,10 +15,12 @@
  * limitations under the License.
  */
 #include "RelMetric.h"
+
 #include <Processors/IProcessor.h>
 #include <Processors/QueryPlan/AggregatingStep.h>
 #include <Processors/QueryPlan/ReadFromMergeTree.h>
 #include <Storages/SubstraitSource/SubstraitFileSourceStep.h>
+#include <Common/QueryContext.h>

 using namespace rapidjson;

@@ -47,7 +49,8 @@ namespace local_engine

 static void writeCacheHits(Writer<StringBuffer> & writer)
 {
-    auto & counters = DB::CurrentThread::getProfileEvents();
+    const auto thread_group = QueryContextManager::currentThreadGroup();
+    auto & counters = thread_group->performance_counters;
     auto read_cache_hits = counters[ProfileEvents::CachedReadBufferReadFromCacheHits].load();
     auto miss_cache_hits = counters[ProfileEvents::CachedReadBufferReadFromCacheMisses].load();
     auto read_cache_bytes = counters[ProfileEvents::CachedReadBufferReadFromCacheBytes].load();
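writeCacheHits previously read the calling thread's own profile-event view via DB::CurrentThread::getProfileEvents(); it now resolves the query's thread group through the new helper and reads the group-level performance_counters, so the reported cache statistics cover the whole query rather than a single thread, and a missing group fails loudly instead of reporting from the wrong context. A small Scala sketch of keeping counters on a shared group object rather than per thread, with hypothetical names:

import java.util.concurrent.atomic.AtomicLong

final class Counters {
  val cacheHits = new AtomicLong(0)
  val cacheMisses = new AtomicLong(0)
}

// Counters hang off the shared group: every worker thread increments,
// and every reporter reads, the same query-wide totals.
final class ThreadGroup {
  val performanceCounters = new Counters
}

object MetricWriter {
  def writeCacheHits(group: ThreadGroup): String = {
    val c = group.performanceCounters
    s"cache hits=${c.cacheHits.get}, misses=${c.cacheMisses.get}"
  }
}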
