diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoProgressMonitor.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoProgressMonitor.scala
new file mode 100644
index 00000000000..5328f092ec9
--- /dev/null
+++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoProgressMonitor.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kyuubi.engine.trino
+
+import java.util
+
+import scala.collection.JavaConverters._
+import scala.collection.immutable.SortedMap
+import scala.collection.mutable
+
+import io.trino.client.{StageStats, StatementClient, StatementStats}
+
+import org.apache.kyuubi.engine.trino.TrinoProgressMonitor.{COLUMN_1_WIDTH, HEADERS}
+import org.apache.kyuubi.engine.trino.operation.progress.{TrinoStage, TrinoStageProgress}
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TJobExecutionStatus
+
+/**
+ * Builds the progress report of a Trino statement - headers, one row per stage, footer
+ * summary and overall percentage - from the statistics exposed by its [[StatementClient]].
+ *
+ * @param trino client of the statement to report on; when it is null the report has no
+ *              stage rows and a progressed percentage of 0.0
+ */
+class TrinoProgressMonitor(trino: StatementClient) {
+
+  /** Current statistics of the statement, or None if the client or its stats are null. */
+  private def currentStats: Option[StatementStats] =
+    Option(trino).flatMap(client => Option(client.getStats))
+
+  /**
+   * Progress of every stage reachable from the root stage, visited breadth-first and kept
+   * sorted by [[TrinoStage]]. Being a lazy val, it is computed once per monitor instance.
+   */
+  private lazy val progressMap: Map[TrinoStage, TrinoStageProgress] = {
+    val stages = mutable.ListBuffer[(TrinoStage, TrinoStageProgress)]()
+    val stageQueue = mutable.Queue[StageStats]()
+    // The stats and the root stage are null-checked here as well, consistently with
+    // progressedPercentage, so that missing statistics yield an empty map and not an NPE.
+    currentStats.flatMap(stats => Option(stats.getRootStage)).foreach(stageQueue += _)
+    while (stageQueue.nonEmpty) {
+      val stage = stageQueue.dequeue()
+      val stageProgress = TrinoStageProgress(
+        stage.getState,
+        stage.getTotalSplits,
+        stage.getCompletedSplits,
+        stage.getRunningSplits,
+        stage.getFailedTasks)
+      stages.append((TrinoStage(stage.getStageId), stageProgress))
+      stageQueue ++= stage.getSubStages.asScala
+    }
+    SortedMap.empty[TrinoStage, TrinoStageProgress] ++ stages
+  }
+
+  def headers: util.List[String] = HEADERS
+
+  /** One row per stage, in stage order, holding its name, state and split/task counters. */
+  def rows: util.List[util.List[String]] = {
+    val progressRows = progressMap.map {
+      case (stage, progress) =>
+        val complete = progress.completedSplits
+        val total = progress.totalSplits
+        val running = progress.runningSplits
+        val pending = total - complete - running
+        util.Arrays.asList(
+          getNameWithProgress(s"Stage-${stage.stageId}", complete, total),
+          progress.state,
+          String.valueOf(total),
+          String.valueOf(complete),
+          String.valueOf(running),
+          String.valueOf(pending),
+          String.valueOf(progress.failedTasks),
+          "")
+    }
+    new util.ArrayList[util.List[String]](progressRows.asJavaCollection)
+  }
+
+  /** Footer of the report: the number of completed stages over the number of stages. */
+  def footerSummary: String = {
+    "STAGES: %02d/%02d".format(getCompletedStages, progressMap.size)
+  }
+
+  /** Progress percentage reported by Trino, or 0.0 when it is not available. */
+  def progressedPercentage: Double =
+    currentStats.map(_.getProgressPercentage.orElse(0.0d)).getOrElse(0.0d)
+
+  /**
+   * COMPLETE when every known stage has completed all of its splits, IN_PROGRESS otherwise.
+   * NOTE(review): a monitor without any stage also reports COMPLETE (0 of 0 stages) -
+   * confirm this is what clients expect for statements that expose no stage yet.
+   */
+  def executionStatus: TJobExecutionStatus =
+    if (getCompletedStages == progressMap.size) {
+      TJobExecutionStatus.COMPLETE
+    } else {
+      TJobExecutionStatus.IN_PROGRESS
+    }
+
+  /**
+   * Pads or trims the stage name to the width of the first column and uses the space left
+   * in that column as a dotted progress bar proportional to `complete / total`.
+   */
+  private def getNameWithProgress(s: String, complete: Int, total: Int): String = {
+    if (s == null) {
+      ""
+    } else {
+      // a stage without splits is rendered as fully progressed; the ratio is kept within
+      // [0, 1] so that inconsistent counters can never overflow the column
+      val percent =
+        if (total == 0) 1.0f
+        else math.min(1.0f, math.max(0.0f, complete.toFloat / total.toFloat))
+      // if the stage name is longer than column 1 width, trim it down
+      val trimmedName =
+        if (s.length > COLUMN_1_WIDTH) s.substring(0, COLUMN_1_WIDTH - 2) + ".."
+        else s + " "
+      // use the remaining space in column 1 as progress bar; there is none (rather than a
+      // negative amount) once the name fills the whole column
+      val spaceRemaining = math.max(0, COLUMN_1_WIDTH - s.length - 1)
+      val toFill = (spaceRemaining * percent).toInt
+      trimmedName + "." * toFill
+    }
+  }
+
+  /** Number of stages that have at least one split and have completed all of them. */
+  private def getCompletedStages: Int =
+    progressMap.values.count { progress =>
+      progress.totalSplits > 0 && progress.completedSplits == progress.totalSplits
+    }
+
+}
+
+object TrinoProgressMonitor {
+
+  private val HEADERS: util.List[String] = util.Arrays.asList(
+    "STAGES",
+    "STATUS",
+    "TOTAL",
+    "COMPLETED",
+    "RUNNING",
+    "PENDING",
+    "FAILED",
+    "")
+
+  private val COLUMN_1_WIDTH = 16
+
+}
diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/ExecuteStatement.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/ExecuteStatement.scala
index 250b8d64b1e..4f1b42e1d1b 100644
--- a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/ExecuteStatement.scala
+++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/ExecuteStatement.scala
@@ -41,6 +41,9 @@ class ExecuteStatement(
   private val operationLog: OperationLog = OperationLog.createOperationLog(session, getHandle)
   override def getOperationLog: Option[OperationLog] = Option(operationLog)
 
+  // Opt in to the progress report attached by TrinoOperation.getStatus.
+  override protected def supportProgress: Boolean = true
+
   override protected def beforeRun(): Unit = {
     OperationLog.setCurrentOperationLog(operationLog)
     setState(OperationState.PENDING)
diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperation.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperation.scala
index 822f1726a3b..6afd8c09841 100644
--- a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperation.scala
+++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperation.scala
@@ -24,16 +24,19 @@ import io.trino.client.StatementClient
 
 import org.apache.kyuubi.KyuubiSQLException
 import org.apache.kyuubi.Utils
+import org.apache.kyuubi.config.KyuubiConf.SESSION_PROGRESS_ENABLE
 import org.apache.kyuubi.engine.trino.TrinoContext
+import org.apache.kyuubi.engine.trino.TrinoProgressMonitor
 import org.apache.kyuubi.engine.trino.schema.{SchemaHelper, TrinoTRowSetGenerator}
 import org.apache.kyuubi.engine.trino.session.TrinoSessionImpl
 import org.apache.kyuubi.operation.AbstractOperation
 import org.apache.kyuubi.operation.FetchIterator
 import org.apache.kyuubi.operation.FetchOrientation.{FETCH_FIRST, FETCH_NEXT, FETCH_PRIOR, FetchOrientation}
 import org.apache.kyuubi.operation.OperationState
+import org.apache.kyuubi.operation.OperationStatus
 import org.apache.kyuubi.operation.log.OperationLog
 import org.apache.kyuubi.session.Session
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.{TFetchResultsResp, TGetResultSetMetadataResp}
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift.{TFetchResultsResp, TGetResultSetMetadataResp, TProgressUpdateResp}
 
 abstract class TrinoOperation(session: Session) extends AbstractOperation(session) {
 
@@ -45,6 +48,29 @@ abstract class TrinoOperation(session: Session) extends AbstractOperation(sessio
 
   protected var iter: FetchIterator[List[Any]] = _
 
+  /** Whether this operation can report progress; subclasses opt in by overriding it. */
+  protected def supportProgress: Boolean = false
+
+  private val progressEnable: Boolean = session.sessionManager.getConf.get(SESSION_PROGRESS_ENABLE)
+
+  /**
+   * Attaches a fresh snapshot of the Trino progress to the operation before returning its
+   * status, when progress is enabled for the session and supported by the operation.
+   */
+  override def getStatus: OperationStatus = {
+    if (progressEnable && supportProgress) {
+      val progressMonitor = new TrinoProgressMonitor(trino)
+      setOperationJobProgress(new TProgressUpdateResp(
+        progressMonitor.headers,
+        progressMonitor.rows,
+        progressMonitor.progressedPercentage,
+        progressMonitor.executionStatus,
+        progressMonitor.footerSummary,
+        startTime))
+    }
+    super.getStatus
+  }
+
   override def getResultSetMetadata: TGetResultSetMetadataResp = {
     val tTableSchema = SchemaHelper.toTTableSchema(schema)
     val resp = new TGetResultSetMetadataResp
diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/progress/TrinoStage.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/progress/TrinoStage.scala
new file mode 100644
index 00000000000..ce1a89ea611
--- /dev/null
+++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/progress/TrinoStage.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kyuubi.engine.trino.operation.progress
+
+import scala.util.Try
+
+/**
+ * Identifier of a Trino stage, ordered so that stages are listed in a natural order.
+ *
+ * Identifiers that parse as integers are compared by value - as plain strings "10" would
+ * sort before "2" - and come before all other identifiers, which keep the string ordering.
+ */
+case class TrinoStage(stageId: String) extends Comparable[TrinoStage] {
+
+  // None when the identifier is not an integral number (or is null)
+  private def numericId: Option[Long] = Try(stageId.toLong).toOption
+
+  override def compareTo(o: TrinoStage): Int = {
+    (numericId, o.numericId) match {
+      case (Some(id), Some(otherId)) =>
+        val byValue = java.lang.Long.compare(id, otherId)
+        // fall back to the raw strings so that e.g. "1" and "01" never compare as equal,
+        // keeping the ordering consistent with equals
+        if (byValue != 0) byValue else stageId.compareTo(o.stageId)
+      case (Some(_), None) => -1
+      case (None, Some(_)) => 1
+      case (None, None) => stageId.compareTo(o.stageId)
+    }
+  }
+}
+
+/**
+ * Snapshot of the progress of a single stage.
+ *
+ * @param state state of the stage as reported by Trino
+ * @param totalSplits number of splits of the stage
+ * @param completedSplits number of splits already completed
+ * @param runningSplits number of splits currently running
+ * @param failedTasks number of failed tasks of the stage
+ */
+case class TrinoStageProgress(
+    state: String,
+    totalSplits: Int,
+    completedSplits: Int,
+    runningSplits: Int,
+    failedTasks: Int)
diff --git a/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationProgressSuite.scala b/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationProgressSuite.scala
new file mode 100644
index 00000000000..0132735ff2f
--- /dev/null
+++ b/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationProgressSuite.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.trino.operation
+
+import scala.collection.JavaConverters._
+
+import org.scalatest.concurrent.PatienceConfiguration.Timeout
+import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
+
+import org.apache.kyuubi.config.KyuubiConf.{ENGINE_TRINO_CONNECTION_CATALOG, SESSION_PROGRESS_ENABLE}
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift.{TExecuteStatementReq, TGetOperationStatusReq, TJobExecutionStatus}
+
+class TrinoOperationProgressSuite extends TrinoOperationSuite {
+  override def withKyuubiConf: Map[String, String] = Map(
+    ENGINE_TRINO_CONNECTION_CATALOG.key -> "memory",
+    SESSION_PROGRESS_ENABLE.key -> "true")
+
+  test("get operation progress") {
+    val sql = "select * from (select item from (SELECT sequence(0, 100, 1) as t) as a " +
+      "CROSS JOIN UNNEST(t) AS temTable (item)) WHERE random() < 0.1"
+
+    withSessionHandle { (client, handle) =>
+      val req = new TExecuteStatementReq()
+      req.setStatement(sql)
+      req.setRunAsync(true)
+      req.setSessionHandle(handle)
+      val resp = client.ExecuteStatement(req)
+      eventually(Timeout(25.seconds)) {
+        val statusReq = new TGetOperationStatusReq(resp.getOperationHandle)
+        val statusResp = client.GetOperationStatus(statusReq)
+        val headers = statusResp.getProgressUpdateResponse.getHeaderNames
+        val progress = statusResp.getProgressUpdateResponse.getProgressedPercentage
+        val rows = statusResp.getProgressUpdateResponse.getRows
+        val footerSummary = statusResp.getProgressUpdateResponse.getFooterSummary
+        val status = statusResp.getProgressUpdateResponse.getStatus
+        assertResult(Seq(
+          "STAGES",
+          "STATUS",
+          "TOTAL",
+          "COMPLETED",
+          "RUNNING",
+          "PENDING",
+          "FAILED",
+          ""))(headers.asScala)
+        assert(rows.size() == 1)
+        // The expectations below only hold once the statement has finished; until then
+        // this assertion fails and makes `eventually` poll the status again.
+        assert(progress === 100.0)
+        assertResult(Seq(
+          "Stage-0 ........",
+          "FINISHED",
+          "3",
+          "3",
+          "0",
+          "0",
+          "0",
+          ""))(
+          rows.get(0).asScala)
+        assert("STAGES: 01/01" === footerSummary)
+        assert(TJobExecutionStatus.COMPLETE === status)
+      }
+    }
+  }
+}