[WIP]Refactor metrics #49629

Draft
wants to merge 2 commits into master
89 changes: 89 additions & 0 deletions core/src/main/scala/org/apache/spark/util/MetricUtils.scala
@@ -0,0 +1,89 @@
package org.apache.spark.util

import java.text.NumberFormat
import java.util.{Arrays, Locale}

import scala.concurrent.duration._

import org.apache.spark.SparkException
import org.apache.spark.util.Utils

object MetricUtils {

val SUM_METRIC: String = "sum"
val SIZE_METRIC: String = "size"
val TIMING_METRIC: String = "timing"
val NS_TIMING_METRIC: String = "nsTiming"
val AVERAGE_METRIC: String = "average"
private val baseForAvgMetric: Int = 1000
private val METRICS_NAME_SUFFIX = "(min, med, max (stageId: taskId))"

private def toNumberFormat(value: Long): String = {
val numberFormat = NumberFormat.getNumberInstance(Locale.US)
numberFormat.format(value.toDouble / baseForAvgMetric)
}

def metricNeedsMax(metricsType: String): Boolean = {
metricsType != SUM_METRIC
}

/**
* A function that defines how we aggregate the final accumulator results among all tasks,
* and represent them as a string for a SQL physical operator.
*/
def stringValue(metricsType: String, values: Array[Long], maxMetrics: Array[Long]): String = {
// taskInfo = "(driver)" OR (stage ${stageId}.${attemptId}: task $taskId)
val taskInfo = if (maxMetrics.isEmpty) {
"(driver)"
} else {
s"(stage ${maxMetrics(1)}.${maxMetrics(2)}: task ${maxMetrics(3)})"
}
if (metricsType == SUM_METRIC) {
val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
numberFormat.format(values.sum)
} else if (metricsType == AVERAGE_METRIC) {
val validValues = values.filter(_ > 0)
// When there is only one metric value (or none), there is no need to display max/min/median.
// This is common for driver-side SQL metrics.
if (validValues.length <= 1) {
toNumberFormat(validValues.headOption.getOrElse(0))
} else {
val Seq(min, med, max) = {
Arrays.sort(validValues)
Seq(
toNumberFormat(validValues(0)),
toNumberFormat(validValues(validValues.length / 2)),
toNumberFormat(validValues(validValues.length - 1)))
}
s"$METRICS_NAME_SUFFIX:\n($min, $med, $max $taskInfo)"
}
} else {
val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
Utils.bytesToString
} else if (metricsType == TIMING_METRIC) {
Utils.msDurationToString
} else if (metricsType == NS_TIMING_METRIC) {
duration => Utils.msDurationToString(duration.nanos.toMillis)
} else {
throw SparkException.internalError(s"unexpected metrics type: $metricsType")
}

val validValues = values.filter(_ >= 0)
// When there is only one metric value (or none), there is no need to display max/min/median.
// This is common for driver-side SQL metrics.
if (validValues.length <= 1) {
strFormat(validValues.headOption.getOrElse(0))
} else {
val Seq(sum, min, med, max) = {
Arrays.sort(validValues)
Seq(
strFormat(validValues.sum),
strFormat(validValues(0)),
strFormat(validValues(validValues.length / 2)),
strFormat(validValues(validValues.length - 1)))
}
s"total $METRICS_NAME_SUFFIX\n$sum ($min, $med, $max $taskInfo)"
}
}
}
}
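With the formatter now living in core, a rough usage sketch of the relocated helper looks like the following (illustrative only and not part of this change: MetricUtilsExample is a made-up name, and the exact formatted strings depend on Utils.bytesToString and the locale-aware number formats, so the values in the comments are only indicative):

import org.apache.spark.util.MetricUtils

object MetricUtilsExample {
  def main(args: Array[String]): Unit = {
    // Driver-only metric: maxMetrics is empty, so the task info renders as "(driver)".
    val sum = MetricUtils.stringValue(MetricUtils.SUM_METRIC, Array(40L, 2L), Array.empty[Long])
    // -> "42"

    // Per-task size metric: maxMetrics is assumed to be
    // (maxValue, stageId, stageAttemptId, taskId), matching how stringValue indexes it.
    val size = MetricUtils.stringValue(
      MetricUtils.SIZE_METRIC,
      Array(1024L, 4096L, 2048L),
      Array(4096L, 0L, 0L, 2L))
    // -> roughly "total (min, med, max (stageId: taskId))\n<sum> (<min>, <med>, <max> (stage 0.0: task 2))"

    println(sum)
    println(size)
  }
}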
@@ -17,8 +17,9 @@
package org.apache.spark.sql.execution.datasources.v2.python

import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.PythonSQLMetrics
import org.apache.spark.util.MetricUtils


class PythonCustomMetric(
@@ -28,7 +29,7 @@ class PythonCustomMetric(
def this() = this(null, null)

override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
MetricUtils.stringValue("size", taskMetrics, Array.empty[Long])
}
}
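Any other CustomMetric implementation can now delegate to the shared formatter in the same way; a minimal hypothetical sketch (BytesReadCustomMetric is made up and not part of this PR):

import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.util.MetricUtils

class BytesReadCustomMetric extends CustomMetric {
  override def name(): String = "bytesRead"
  override def description(): String = "bytes read by this source"
  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
    // "size" selects Utils.bytesToString formatting; an empty maxMetrics array
    // makes the formatter label the result "(driver)" instead of a stage/task id.
    MetricUtils.stringValue(MetricUtils.SIZE_METRIC, taskMetrics, Array.empty[Long])
  }
}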

@@ -17,19 +17,16 @@

package org.apache.spark.sql.execution.metric

import java.text.NumberFormat
import java.util.{Arrays, Locale}

import scala.concurrent.duration._
// import scala.concurrent.duration._

import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}

import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.SparkContext
import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}
import org.apache.spark.util.{AccumulatorContext, AccumulatorV2}
import org.apache.spark.util.AccumulatorContext.internOption

/**
@@ -72,7 +69,7 @@ class SQLMetric(

// This is used to filter out metrics. Metrics with value equal to initValue should
// be filtered out, since they are either invalid or safe to filter without changing
// the aggregation defined in [[SQLMetrics.stringValue]].
// the aggregation defined in [[MetricUtils.stringValue]].
// Note that we don't use 0 here since we may want to collect 0 metrics for
// calculating min, max, etc. See SPARK-11013.
override def isZero: Boolean = _value == initValue
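A small illustrative sketch of the initValue semantics referenced above (InitValueDemo is hypothetical, assumes an active SparkContext, and is not part of this diff):

import org.apache.spark.SparkContext
import org.apache.spark.sql.execution.metric.SQLMetrics

object InitValueDemo {
  // Why isZero compares against initValue rather than 0: size metrics start at -1,
  // so metrics that were never updated are filtered out by MetricUtils.stringValue,
  // while a genuine observation of 0 is still kept for min/med/max.
  def demo(sc: SparkContext): Unit = {
    val bytes = SQLMetrics.createSizeMetric(sc, "bytes written")
    assert(bytes.isZero)   // still at initValue (-1): treated as "no data"
    bytes.set(0L)          // a real zero observation
    assert(!bytes.isZero)  // 0 != initValue, so it participates in aggregation
  }
}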
@@ -106,8 +103,8 @@ class SQLMetric(
SQLMetrics.cachedSQLAccumIdentifier)
}

// We should provide the raw value which can be -1, so that `SQLMetrics.stringValue` can correctly
// filter out the invalid -1 values.
// We should provide the raw value which can be -1, so that `MetricUtils.stringValue` can
// correctly filter out the invalid -1 values.
override def toInfoUpdate: AccumulableInfo = {
AccumulableInfo(id, name, internOption(Some(_value)), None, true, true,
SQLMetrics.cachedSQLAccumIdentifier)
@@ -203,77 +200,6 @@ object SQLMetrics {
acc
}

private def toNumberFormat(value: Long): String = {
val numberFormat = NumberFormat.getNumberInstance(Locale.US)
numberFormat.format(value.toDouble / baseForAvgMetric)
}

def metricNeedsMax(metricsType: String): Boolean = {
metricsType != SUM_METRIC
}

private val METRICS_NAME_SUFFIX = "(min, med, max (stageId: taskId))"

/**
* A function that defines how we aggregate the final accumulator results among all tasks,
* and represent it in string for a SQL physical operator.
*/
def stringValue(metricsType: String, values: Array[Long], maxMetrics: Array[Long]): String = {
// taskInfo = "(driver)" OR (stage ${stageId}.${attemptId}: task $taskId)
val taskInfo = if (maxMetrics.isEmpty) {
"(driver)"
} else {
s"(stage ${maxMetrics(1)}.${maxMetrics(2)}: task ${maxMetrics(3)})"
}
if (metricsType == SUM_METRIC) {
val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
numberFormat.format(values.sum)
} else if (metricsType == AVERAGE_METRIC) {
val validValues = values.filter(_ > 0)
// When there are only 1 metrics value (or None), no need to display max/min/median. This is
// common for driver-side SQL metrics.
if (validValues.length <= 1) {
toNumberFormat(validValues.headOption.getOrElse(0))
} else {
val Seq(min, med, max) = {
Arrays.sort(validValues)
Seq(
toNumberFormat(validValues(0)),
toNumberFormat(validValues(validValues.length / 2)),
toNumberFormat(validValues(validValues.length - 1)))
}
s"$METRICS_NAME_SUFFIX:\n($min, $med, $max $taskInfo)"
}
} else {
val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
Utils.bytesToString
} else if (metricsType == TIMING_METRIC) {
Utils.msDurationToString
} else if (metricsType == NS_TIMING_METRIC) {
duration => Utils.msDurationToString(duration.nanos.toMillis)
} else {
throw SparkException.internalError(s"unexpected metrics type: $metricsType")
}

val validValues = values.filter(_ >= 0)
// When there are only 1 metrics value (or None), no need to display max/min/median. This is
// common for driver-side SQL metrics.
if (validValues.length <= 1) {
strFormat(validValues.headOption.getOrElse(0))
} else {
val Seq(sum, min, med, max) = {
Arrays.sort(validValues)
Seq(
strFormat(validValues.sum),
strFormat(validValues(0)),
strFormat(validValues(validValues.length / 2)),
strFormat(validValues(validValues.length - 1)))
}
s"total $METRICS_NAME_SUFFIX\n$sum ($min, $med, $max $taskInfo)"
}
}
}

def postDriverMetricsUpdatedByValue(
sc: SparkContext,
executionId: String,
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.metric._
import org.apache.spark.sql.internal.StaticSQLConf._
import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity}
import org.apache.spark.util.Utils
import org.apache.spark.util.{MetricUtils, Utils}
import org.apache.spark.util.collection.OpenHashMap

class SQLAppStatusListener(
@@ -235,7 +235,7 @@
}
}.getOrElse(
// Built-in SQLMetric
SQLMetrics.stringValue(m.metricType, _, _)
MetricUtils.stringValue(m.metricType, _, _)
)
(m.accumulatorId, metricAggMethod)
}.toMap
@@ -554,7 +554,7 @@ private class LiveStageMetrics(
/**
* Task metrics values for the stage. Maps the metric ID to the metric values for each
* index. For each metric ID, there will be the same number of values as the number
* of indices. This relies on `SQLMetrics.stringValue` treating 0 as a neutral value,
* of indices. This relies on `MetricUtils.stringValue` treating 0 as a neutral value,
* independent of the actual metric type.
*/
private val taskMetrics = new ConcurrentHashMap[Long, Array[Long]]()
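A sketch of the layout described in the comment above (illustrative only; the accumulator id and task count are made up):

object TaskMetricsLayoutSketch {
  import java.util.concurrent.ConcurrentHashMap

  val numTasks = 4
  val taskMetrics = new ConcurrentHashMap[Long, Array[Long]]()

  // One slot per task index; slots that are never updated stay at 0, which the
  // aggregation in MetricUtils.stringValue treats as a neutral value, as the
  // comment above notes.
  val values: Array[Long] = taskMetrics.computeIfAbsent(7L, _ => new Array[Long](numTasks))
  values(2) = 1024L // the task at index 2 reported a value; the other slots remain 0
}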
@@ -601,7 +601,7 @@ private class LiveStageMetrics(
val metricValues = taskMetrics.computeIfAbsent(acc.id, _ => new Array(numTasks))
metricValues(taskIdx) = value

if (SQLMetrics.metricNeedsMax(accumIdsToMetricType(acc.id))) {
if (MetricUtils.metricNeedsMax(accumIdsToMetricType(acc.id))) {
val maxMetricsTaskId = metricsIdToMaxTaskValue.computeIfAbsent(acc.id, _ => Array(value,
taskId))

@@ -57,7 +57,7 @@ import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.status.{AppStatusStore, ElementTrackingStore}
import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator, SerializableConfiguration, Utils}
import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator, MetricUtils, SerializableConfiguration, Utils}
import org.apache.spark.util.kvstore.InMemoryStore


@@ -597,9 +597,9 @@ abstract class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTes
val metrics = statusStore.executionMetrics(execId)
val driverMetric = physicalPlan.metrics("dummy")
val driverMetric2 = physicalPlan.metrics("dummy2")
val expectedValue = SQLMetrics.stringValue(driverMetric.metricType,
val expectedValue = MetricUtils.stringValue(driverMetric.metricType,
Array(expectedAccumValue), Array.empty[Long])
val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType,
val expectedValue2 = MetricUtils.stringValue(driverMetric2.metricType,
Array(expectedAccumValue2), Array.empty[Long])

assert(metrics.contains(driverMetric.id))