refactor: enable Spotless for Scala code (#304)
andrew-coleman authored Oct 16, 2024
1 parent 357cc01 commit 06c9c36
Showing 12 changed files with 164 additions and 56 deletions.
70 changes: 70 additions & 0 deletions spark/.scalafmt.conf
@@ -0,0 +1,70 @@
runner.dialect = scala212

# Version is required to make sure IntelliJ picks the right version
version = 3.7.3
preset = default

# Max column
maxColumn = 100

# This parameter simply says that the .stripMargin method has not been redefined by the user to
# assign special meaning to the indentation preceding the | character, so that indentation can be modified.
assumeStandardLibraryStripMargin = true
align.stripMargin = true

# Align settings
align.preset = none
align.closeParenSite = false
align.openParenCallSite = false
danglingParentheses.defnSite = false
danglingParentheses.callSite = false
danglingParentheses.ctrlSite = true
danglingParentheses.tupleSite = false
align.openParenCallSite = false
align.openParenDefnSite = false
align.openParenTupleSite = false

# Newlines
newlines.alwaysBeforeElseAfterCurlyIf = false
newlines.beforeCurlyLambdaParams = multiline # Newline before lambda params
newlines.afterCurlyLambdaParams = squash # No newline after lambda params
newlines.inInterpolation = "avoid"
newlines.avoidInResultType = true
optIn.annotationNewlines = true

# Scaladoc
docstrings.style = Asterisk # Javadoc style
docstrings.removeEmpty = true
docstrings.oneline = fold
docstrings.forceBlankLineBefore = true

# Indentation
indent.extendSite = 2 # This makes sure extends is not indented the same as the ctor parameters

# Rewrites
rewrite.rules = [AvoidInfix, Imports, RedundantBraces, SortModifiers]

# Imports
rewrite.imports.sort = scalastyle
rewrite.imports.groups = [
["io.substrait.spark\\..*"],
["org.apache.spark\\..*"],
[".*"],
["javax\\..*"],
["java\\..*"],
["scala\\..*"]
]
rewrite.imports.contiguousGroups = no
importSelectors = singleline # Imports in a single line, like IntelliJ

# Remove redundant braces in string interpolation.
rewrite.redundantBraces.stringInterpolation = true
rewrite.redundantBraces.defnBodies = false
rewrite.redundantBraces.generalExpressions = false
rewrite.redundantBraces.ifElseExpressions = false
rewrite.redundantBraces.methodBodies = false
rewrite.redundantBraces.includeUnitMethods = false
rewrite.redundantBraces.maxBreaks = 1

# Remove trailing commas
rewrite.trailingCommas.style = "never"
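
The import groups and rewrite rules above account for most of the changes to the Scala sources further down in this commit: imports are split into blank-line-separated groups in the configured order, redundant braces are dropped from string interpolations, and infix method calls are rewritten as plain selections. A minimal sketch of the intended effect on a hypothetical file (FormattingExample and its members are invented for illustration; only the import layout and rewrites follow the configuration above):

package io.substrait.spark.example

// Group 1: io.substrait.spark.* imports come first.
import io.substrait.spark.ToSubstraitType

// Group 2: org.apache.spark.* imports, separated by a blank line.
import org.apache.spark.sql.types.StructType

// Last group: scala.* imports.
import scala.collection.JavaConverters.asScalaBufferConverter

object FormattingExample {
  // AvoidInfix: an infix call such as `names mkString ", "` becomes a plain method call.
  def joined(names: Seq[String]): String = names.mkString(", ")

  // redundantBraces.stringInterpolation: s"${name}" is rewritten to s"$name".
  def greet(name: String): String = s"hello $name"
}
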
7 changes: 7 additions & 0 deletions spark/build.gradle.kts
@@ -105,6 +105,13 @@ dependencies {
testImplementation("org.apache.spark:spark-catalyst_2.12:${SPARK_VERSION}:tests")
}

spotless {
scala {
scalafmt().configFile(".scalafmt.conf")
toggleOffOn()
}
}

tasks {
test {
dependsOn(":core:shadowJar")
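
The toggleOffOn() call enables Spotless's formatter toggles, so a hand-formatted region of Scala can be excluded from scalafmt with paired comments. A small sketch (the ManuallyAligned object is invented; the spotless:off / spotless:on markers are the Spotless defaults enabled by toggleOffOn()):

object ManuallyAligned {
  // spotless:off
  // (everything down to the matching marker is left untouched by the formatter)
  val lookup = Map(
    "a"   -> 1,
    "bb"  -> 2,
    "ccc" -> 3
  )
  // spotless:on
}

With this block in place, the standard Spotless Gradle tasks apply to the Scala sources: spotlessCheck fails the build when a file does not match the .scalafmt.conf above, and spotlessApply rewrites it in place.
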
5 changes: 3 additions & 2 deletions spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
@@ -16,11 +16,12 @@
*/
package io.substrait.spark

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types._

import io.substrait.`type`.{NamedStruct, Type, TypeVisitor}
import io.substrait.function.TypeExpression
import io.substrait.utils.Util
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types._

import scala.collection.JavaConverters
import scala.collection.JavaConverters.asScalaBufferConverter
@@ -16,6 +16,8 @@
*/
package io.substrait.spark.expression

import io.substrait.spark.ToSubstraitType

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion}
@@ -29,7 +31,6 @@ import io.substrait.expression.{Expression => SExpression, ExpressionCreator, Fu
import io.substrait.expression.Expression.FailureBehavior
import io.substrait.extension.SimpleExtension
import io.substrait.function.{ParameterizedType, ToTypeString}
import io.substrait.spark.ToSubstraitType
import io.substrait.utils.Util

import java.{util => ju}
@@ -56,10 +56,12 @@ class IgnoreNullableAndParameters(val typeToMatch: ParameterizedType)
typeToMatch.isInstanceOf[Type.IntervalYear]

override def visit(`type`: Type.IntervalDay): Boolean =
typeToMatch.isInstanceOf[Type.IntervalDay] || typeToMatch.isInstanceOf[ParameterizedType.IntervalDay]
typeToMatch.isInstanceOf[Type.IntervalDay] || typeToMatch
.isInstanceOf[ParameterizedType.IntervalDay]

override def visit(`type`: Type.IntervalCompound): Boolean =
typeToMatch.isInstanceOf[Type.IntervalCompound] || typeToMatch.isInstanceOf[ParameterizedType.IntervalCompound]
typeToMatch.isInstanceOf[Type.IntervalCompound] || typeToMatch
.isInstanceOf[ParameterizedType.IntervalCompound]

override def visit(`type`: Type.UUID): Boolean = typeToMatch.isInstanceOf[Type.UUID]

@@ -109,11 +111,13 @@ class IgnoreNullableAndParameters(val typeToMatch: ParameterizedType)

@throws[RuntimeException]
override def visit(expr: ParameterizedType.IntervalDay): Boolean =
typeToMatch.isInstanceOf[Type.IntervalDay] || typeToMatch.isInstanceOf[ParameterizedType.IntervalDay]
typeToMatch.isInstanceOf[Type.IntervalDay] || typeToMatch
.isInstanceOf[ParameterizedType.IntervalDay]

@throws[RuntimeException]
override def visit(expr: ParameterizedType.IntervalCompound): Boolean =
typeToMatch.isInstanceOf[Type.IntervalCompound] || typeToMatch.isInstanceOf[ParameterizedType.IntervalCompound]
typeToMatch.isInstanceOf[Type.IntervalCompound] || typeToMatch
.isInstanceOf[ParameterizedType.IntervalCompound]

@throws[RuntimeException]
override def visit(expr: ParameterizedType.Struct): Boolean =
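
The re-wrapping above is driven by maxColumn = 100: once a disjunction of isInstanceOf checks no longer fits on one line, scalafmt breaks before the second selection rather than exceeding the limit. A hedged sketch with invented type names:

object WrapExample {
  sealed trait SomeVeryDescriptiveIntervalDayType
  sealed trait SomeVeryDescriptiveParameterizedIntervalDayType

  // With maxColumn = 100 the second check would overflow the line,
  // so scalafmt breaks before the trailing .isInstanceOf selection.
  def matchesEither(value: Any): Boolean =
    value.isInstanceOf[SomeVeryDescriptiveIntervalDayType] || value
      .isInstanceOf[SomeVeryDescriptiveParameterizedIntervalDayType]
}
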
@@ -18,14 +18,16 @@ package io.substrait.spark.expression

import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtension, ToSubstraitType}
import io.substrait.spark.logical.ToLogicalPlan

import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.substrait.SparkTypeUtil
import org.apache.spark.unsafe.types.UTF8String

import io.substrait.`type`.{StringTypeVisitor, Type}
import io.substrait.{expression => exp}
import io.substrait.expression.{Expression => SExpression}
import io.substrait.util.DecimalUtil
import org.apache.spark.substrait.SparkTypeUtil

import scala.collection.JavaConverters.asScalaBufferConverter

@@ -132,31 +134,34 @@ class ToSparkExpression(
}

expr.declaration.name match {
case "make_decimal" if expr.declaration.uri == SparkExtension.uri => expr.outputType match {
// Need special case handling of this internal function.
// Because the precision and scale arguments are extracted from the output type,
// we can't use the generic scalar function conversion mechanism here.
case d: Type.Decimal => MakeDecimal(args.head, d.precision, d.scale)
case _ => throw new IllegalArgumentException("Output type of MakeDecimal must be a decimal type")
}
case _ => scalarFunctionConverter
.getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType())
.flatMap(sig => Option(sig.makeCall(args)))
.getOrElse({
val msg = String.format(
"Unable to convert scalar function %s(%s).",
expr.declaration.name,
expr.arguments.asScala
.map {
case ea: exp.EnumArg => ea.value.toString
case e: SExpression => e.getType.accept(new StringTypeVisitor)
case t: Type => t.accept(new StringTypeVisitor)
case a => throw new IllegalStateException("Unexpected value: " + a)
}
.mkString(", ")
)
throw new IllegalArgumentException(msg)
})
case "make_decimal" if expr.declaration.uri == SparkExtension.uri =>
expr.outputType match {
// Need special case handling of this internal function.
// Because the precision and scale arguments are extracted from the output type,
// we can't use the generic scalar function conversion mechanism here.
case d: Type.Decimal => MakeDecimal(args.head, d.precision, d.scale)
case _ =>
throw new IllegalArgumentException("Output type of MakeDecimal must be a decimal type")
}
case _ =>
scalarFunctionConverter
.getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType())
.flatMap(sig => Option(sig.makeCall(args)))
.getOrElse({
val msg = String.format(
"Unable to convert scalar function %s(%s).",
expr.declaration.name,
expr.arguments.asScala
.map {
case ea: exp.EnumArg => ea.value.toString
case e: SExpression => e.getType.accept(new StringTypeVisitor)
case t: Type => t.accept(new StringTypeVisitor)
case a => throw new IllegalStateException("Unexpected value: " + a)
}
.mkString(", ")
)
throw new IllegalArgumentException(msg)
})
}
}
}
@@ -19,11 +19,11 @@ package io.substrait.spark.expression
import io.substrait.spark.{HasOutputStack, ToSubstraitType}

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.substrait.SparkTypeUtil

import io.substrait.expression.{Expression => SExpression, ExpressionCreator, FieldReference, ImmutableExpression}
import io.substrait.expression.Expression.FailureBehavior
import io.substrait.utils.Util
import org.apache.spark.substrait.SparkTypeUtil

import scala.collection.JavaConverters.asJavaIterableConverter

@@ -16,13 +16,14 @@
*/
package io.substrait.spark.expression

import io.substrait.spark.ToSubstraitType

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import io.substrait.expression.{Expression => SExpression}
import io.substrait.expression.ExpressionCreator._
import io.substrait.spark.ToSubstraitType

class ToSubstraitLiteral {

26 changes: 16 additions & 10 deletions spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
@@ -18,6 +18,7 @@ package io.substrait.spark.logical

import io.substrait.spark.{DefaultRelVisitor, SparkExtension, ToSubstraitType}
import io.substrait.spark.expression._

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
@@ -29,6 +30,7 @@ import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InMemoryFileIndex, LogicalRelation}
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.types.{DataTypes, IntegerType, StructField, StructType}

import io.substrait.`type`.{StringTypeVisitor, Type}
import io.substrait.{expression => exp}
import io.substrait.expression.{Expression => SExpression}
@@ -167,14 +169,14 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
if (limit >= 0) {
val limitExpr = toLiteral(limit)
if (offset > 0) {
GlobalLimit(limitExpr,
Offset(toLiteral(offset),
LocalLimit(toLiteral(offset + limit), child)))
GlobalLimit(
limitExpr,
Offset(toLiteral(offset), LocalLimit(toLiteral(offset + limit), child)))
} else {
GlobalLimit(limitExpr, LocalLimit(limitExpr, child))
}
} else {
Offset(toLiteral(offset), child)
Offset(toLiteral(offset), child)
}
}
override def visit(sort: relation.Sort): LogicalPlan = {
@@ -213,13 +215,16 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
withChild(child) {
val projections = expand.getFields.asScala
.map {
case sf: SwitchingField => sf.getDuplicates.asScala
.map(expr => expr.accept(expressionConverter))
.map(toNamedExpression)
case _: ConsistentField => throw new UnsupportedOperationException("ConsistentField not currently supported")
case sf: SwitchingField =>
sf.getDuplicates.asScala
.map(expr => expr.accept(expressionConverter))
.map(toNamedExpression)
case _: ConsistentField =>
throw new UnsupportedOperationException("ConsistentField not currently supported")
}

val output = projections.head.zip(names)
val output = projections.head
.zip(names)
.map { case (t, name) => StructField(name, t.dataType, t.nullable) }
.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())

@@ -240,7 +245,8 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
withOutput(children.flatMap(_.output)) {
set.getSetOp match {
case SetOp.UNION_ALL => Union(children, byName = false, allowMissingCol = false)
case op => throw new UnsupportedOperationException(s"Operation not currently supported: $op")
case op =>
throw new UnsupportedOperationException(s"Operation not currently supported: $op")
}
}
}