Skip to content

Commit

Permalink
Make FixLogicalTypeSupplier more permissive (#5233)
Browse files: browse the repository at this point in the history
  • Loading branch information
clairemcginty authored Feb 9, 2024
1 parent 5e9fac0 commit f5cdf9e
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,22 @@ object FixLogicalTypeSuppliers {
conf.setClass(AvroReadSupport.AVRO_DATA_SUPPLIER, classOf[AvroLogicalTypeSupplier], classOf[AvroDataSupplier])
conf.setClass(AvroWriteSupport.AVRO_DATA_SUPPLIER, classOf[AvroLogicalTypeSupplier], classOf[LogicalTypeSupplier])
conf.setClass("someClass", classOf[String], classOf[CharSequence])

implicit class WrappedSCollection(val sc: ScioContext) extends AnyVal {
def customMethod[T](input: String, conf: Option[Configuration] = None): SCollection[T] = ???
}

sc.customMethod[String](
"input",
conf = Some(ParquetConfiguration.of(
AvroReadSupport.AVRO_DATA_SUPPLIER -> (classOf[LogicalTypeSupplier])
))
)

sc.customMethod[String](
"input",
Some(ParquetConfiguration.of(
AvroReadSupport.AVRO_DATA_SUPPLIER -> (classOf[LogicalTypeSupplier])
))
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,12 @@ object FixLogicalTypeSuppliers {


conf.setClass("someClass", classOf[String], classOf[CharSequence])

implicit class WrappedSCollection(val sc: ScioContext) extends AnyVal {
def customMethod[T](input: String, conf: Option[Configuration] = None): SCollection[T] = ???
}

sc.customMethod[String]("input")

sc.customMethod[String]("input")
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,13 @@ object FixLogicalTypeSupplier {

// Matches the symbol `java/lang/Class`.
val JavaClassMatcher: SymbolMatcher = SymbolMatcher.normalized("java/lang/Class")

// Matches `scala/Some` and `scala/Option`; the rule uses it to recognise an
// Option-style wrapper around a ParquetConfiguration argument.
val OptionMatcher: SymbolMatcher = SymbolMatcher.normalized("scala/Some", "scala/Option")

// Common symbol prefix shared by the scio parquet-avro matchers below.
private val ParquetAvroPrefix = "com/spotify/scio/parquet/avro"
// Matches either of the two LogicalTypeSupplier classes: the scio parquet-avro
// `LogicalTypeSupplier` and the Beam SMB `AvroLogicalTypeSupplier`.
val LogicalTypeSupplierMatcher: SymbolMatcher = SymbolMatcher.normalized(
s"$ParquetAvroPrefix/LogicalTypeSupplier",
"org/apache/beam/sdk/extensions/smb/AvroLogicalTypeSupplier"
)

// Matches the scio parquet-avro read (`parquetAvroFile`) and write
// (`saveAsParquetAvroFile`) syntax methods.
private val ParquetAvroMatcher = SymbolMatcher.normalized(
s"$ParquetAvroPrefix/syntax/ScioContextOps#parquetAvroFile",
s"$ParquetAvroPrefix/syntax/SCollectionOps#saveAsParquetAvroFile"
)
}

class FixLogicalTypeSupplier extends SemanticRule("FixLogicalTypeSupplier") {
Expand All @@ -49,23 +46,51 @@ class FixLogicalTypeSupplier extends SemanticRule("FixLogicalTypeSupplier") {
}

/**
 * Rewrites the argument list of a call so that every `ParquetConfiguration`-style
 * argument only keeps the entries returned by `parquetConfigurationArgs`.
 *
 * Four argument shapes are recognised, where `conf(...)` is any call whose function
 * symbol is matched by `ParquetConfigurationMatcher` and `wrap` is any term matched
 * by `OptionMatcher`:
 *   - `name = conf(...)`
 *   - `conf(...)`
 *   - `name = wrap(conf(...))`
 *   - `wrap(conf(...))`
 *
 * When no entries survive the filtering, the whole argument is dropped from the
 * list. Otherwise it is re-emitted as `ParquetConfiguration.of(..entries)`, keeping
 * the `name =` prefix if there was one. A wrapped argument is always re-emitted with
 * `Some(...)`, regardless of whether the original wrapper was `Some` or `Option`.
 * Every other argument is passed through untouched.
 *
 * NOTE(review): `parquetConfigurationArgs` is defined outside this block; presumably
 * it strips the LogicalTypeSupplier entries -- confirm against its definition.
 *
 * @param fnArgs the original arguments of the call being rewritten
 * @return the arguments to use in the rewritten call, in their original order
 */
private def updateIOArgs(fnArgs: List[Term])(implicit doc: SemanticDocument): List[Term] = {
  // Builds the replacement for a single configuration argument, or None if the
  // argument should be removed because nothing is left after filtering.
  def filterArgs(lhsOpt: Option[Term], rhsOption: Boolean, confArgs: List[Term]): Option[Term] = {
    val filtered = parquetConfigurationArgs(confArgs)
    if (filtered.isEmpty) {
      None
    } else {
      val conf: Term = q"ParquetConfiguration.of(..$filtered)"
      val rhs: Term = if (rhsOption) q"Some($conf)" else conf
      // Re-attach the named-argument prefix only when the original had one.
      Some(lhsOpt.fold(rhs)(lhs => q"$lhs = $rhs"))
    }
  }

  fnArgs.flatMap {
    case q"$lhs = $fn(..$confArgs)" if ParquetConfigurationMatcher.matches(fn.symbol) =>
      filterArgs(Some(lhs), false, confArgs)
    case q"$fn(..$confArgs)" if ParquetConfigurationMatcher.matches(fn.symbol) =>
      filterArgs(None, false, confArgs)
    case q"$lhs = $maybeOpt($fn(..$confArgs))"
        if ParquetConfigurationMatcher.matches(fn.symbol) && OptionMatcher.matches(maybeOpt) =>
      filterArgs(Some(lhs), true, confArgs)
    case q"$maybeOpt($fn(..$confArgs))"
        if ParquetConfigurationMatcher.matches(fn.symbol) && OptionMatcher.matches(maybeOpt) =>
      filterArgs(None, true, confArgs)
    case other =>
      Some(other)
  }
}

/**
 * Tells whether any of the given call arguments is a `ParquetConfiguration`-style
 * argument, i.e. a call whose function symbol is matched by
 * `ParquetConfigurationMatcher`, optionally passed as a named argument and/or
 * wrapped in a term matched by `OptionMatcher` (`Some(...)` or `Option(...)`).
 *
 * The accepted shapes mirror the ones `updateIOArgs` knows how to rewrite, so every
 * call that `updateIOArgs` could change is also detected here.
 *
 * @param args the arguments of the call under inspection
 * @return true if at least one argument has one of the recognised shapes
 */
private def containsConfArg(args: Seq[Term])(implicit doc: SemanticDocument): Boolean = {
  def isParquetConf(term: Term): Boolean = ParquetConfigurationMatcher.matches(term.symbol)
  def isOptionWrapper(term: Term): Boolean = OptionMatcher.matches(term)

  args.exists {
    case q"$_ = $fn(..$confArgs)" if isParquetConf(fn) => true
    case q"$fn(..$confArgs)" if isParquetConf(fn) => true
    case q"$_ = $wrapper($fn(..$confArgs))" if isOptionWrapper(wrapper) && isParquetConf(fn) => true
    case q"$wrapper($fn(..$confArgs))" if isOptionWrapper(wrapper) && isParquetConf(fn) => true
    case _ => false
  }
}

override def fix(implicit doc: SemanticDocument): Patch = {
doc.tree.collect {
case method @ q"$fn(..$args)" if ParquetAvroMatcher.matches(fn.symbol) =>
case method @ q"$coll.$fn(..$args)" if containsConfArg(args) =>
val newArgs = updateIOArgs(args)
Patch.replaceTree(method, q"$coll.$fn(..$newArgs)".syntax)
case method @ q"$coll.$fn[$exprs](..$args)" if containsConfArg(args) =>
val newArgs = updateIOArgs(args)
Patch.replaceTree(method, q"$fn(..$newArgs)".syntax)
Patch.replaceTree(method, q"$coll.$fn[$exprs](..$newArgs)".syntax)
case method @ q"$_.$fn($_, $theClass, $xface)" if SetClassMatcher.matches(fn.symbol) =>
if (isLogicalTypeSupplier(theClass) || isLogicalTypeSupplier(xface)) {
Patch.removeTokens(method.tokens)
Expand Down

0 comments on commit f5cdf9e

Please sign in to comment.