Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stable TableRow converted from BQ types #5536

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -270,77 +270,164 @@ private[types] object ConverterProvider {
// Converter helpers
// =======================================================================
def cast(tree: Tree, tpe: Type): Tree = {
val msg = Constant(s"Cannot convert to ${tpe.typeSymbol.name}: ")
val fail = q"""throw new _root_.java.lang.IllegalArgumentException($msg + $tree)"""

def readBase64(term: TermName) =
q"_root_.com.google.common.io.BaseEncoding.base64().decode($term)"

val provider: OverrideTypeProvider =
OverrideTypeProviderFinder.getProvider
val s = q"$tree.toString"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We were forcing `toString` before converting the value back to the desired type.

tpe match {
case t if provider.shouldOverrideType(c)(t) =>
provider.createInstance(c)(t, q"$tree")
case t if t =:= typeOf[Boolean] => q"$s.toBoolean"
case t if t =:= typeOf[Int] => q"$s.toInt"
case t if t =:= typeOf[Long] => q"$s.toLong"
case t if t =:= typeOf[Float] => q"$s.toFloat"
case t if t =:= typeOf[Double] => q"$s.toDouble"
case t if t =:= typeOf[String] => q"$s"
case t if t =:= typeOf[Boolean] =>
q"""$tree match {
case b: ${typeOf[java.lang.Boolean]} => _root_.scala.Boolean.unbox(b)
case _ => $fail
}"""
case t if t =:= typeOf[Int] =>
q"""$tree match {
case i: ${typeOf[java.lang.Integer]} => _root_.scala.Int.unbox(i)
case s: ${typeOf[String]} => s.toInt
case _ => $fail
}"""
case t if t =:= typeOf[Long] =>
q"""$tree match {
case l: ${typeOf[java.lang.Long]} => _root_.scala.Long.unbox(l)
case i: ${typeOf[java.lang.Integer]} => _root_.scala.Long.unbox(i).toLong
case s: ${typeOf[String]} => s.toLong
case _ => $fail
}"""
case t if t =:= typeOf[Float] =>
q"""$tree match {
case f: ${typeOf[java.lang.Float]} => _root_.scala.Float.unbox(f)
case d: ${typeOf[java.lang.Double]} => _root_.scala.Double.unbox(d).toFloat
case s: ${typeOf[String]} => s.toFloat
case _ => $fail
}"""
case t if t =:= typeOf[Double] =>
q"""$tree match {
case d: ${typeOf[java.lang.Double]} => _root_.scala.Double.unbox(d)
case s: ${typeOf[String]} => s.toDouble
case _ => $fail
}"""
case t if t =:= typeOf[String] =>
q"""$tree match {
case s: ${typeOf[String]} => s
case _ => $fail
}"""
case t if t =:= typeOf[BigDecimal] =>
q"_root_.com.spotify.scio.bigquery.Numeric($s)"

q"""$tree match {
case bd: ${typeOf[BigDecimal]} =>
_root_.com.spotify.scio.bigquery.Numeric(bd)
case s: ${typeOf[String]} =>
_root_.com.spotify.scio.bigquery.Numeric(s)
case _ => $fail
}"""
case t if t =:= typeOf[ByteString] =>
val b =
q"_root_.com.google.common.io.BaseEncoding.base64().decode($s)"
q"_root_.com.google.protobuf.ByteString.copyFrom($b)"
val s = TermName("s")
q"""$tree match {
case bs: ${typeOf[ByteString]} => bs
case $s: ${typeOf[String]} =>
_root_.com.google.protobuf.ByteString.copyFrom(${readBase64(s)})
case _ => $fail
}"""
case t if t =:= typeOf[Array[Byte]] =>
q"_root_.com.google.common.io.BaseEncoding.base64().decode($s)"

val s = TermName("s")
q"""$tree match {
case bs: ${typeOf[Array[Byte]]} => bs
case $s: ${typeOf[String]} =>
${readBase64(s)}
case _ => $fail
}"""
case t if t =:= typeOf[Instant] =>
q"_root_.com.spotify.scio.bigquery.Timestamp.parse($s)"
q"""$tree match {
case i: ${typeOf[Instant]} => i
case s: ${typeOf[String]} =>
_root_.com.spotify.scio.bigquery.Timestamp.parse(s)
case _ => $fail
}"""
case t if t =:= typeOf[LocalDate] =>
q"_root_.com.spotify.scio.bigquery.Date.parse($s)"
q"""$tree match {
case ld: ${typeOf[LocalDate]} => ld
case s: ${typeOf[String]} =>
_root_.com.spotify.scio.bigquery.Date.parse(s)
case _ => $fail
}"""
case t if t =:= typeOf[LocalTime] =>
q"_root_.com.spotify.scio.bigquery.Time.parse($s)"
q"""$tree match {
case lt: ${typeOf[LocalTime]} => lt
case s: ${typeOf[String]} =>
_root_.com.spotify.scio.bigquery.Time.parse(s)
case _ => $fail
}"""
case t if t =:= typeOf[LocalDateTime] =>
q"_root_.com.spotify.scio.bigquery.DateTime.parse($s)"

q"""$tree match {
case ldt: ${typeOf[LocalDateTime]} => ldt
case s: ${typeOf[String]} =>
_root_.com.spotify.scio.bigquery.DateTime.parse(s)
case _ => $fail
}"""
// different than nested record match below, even though those are case classes
case t if t =:= typeOf[Geography] =>
q"_root_.com.spotify.scio.bigquery.types.Geography($s)"
q"""$tree match {
case g: ${typeOf[Geography]} => g
case s: ${typeOf[String]} =>
_root_.com.spotify.scio.bigquery.types.Geography(s)
case _ => $fail
}"""
case t if t =:= typeOf[Json] =>
q"_root_.com.spotify.scio.bigquery.types.Json($s)"
q"""$tree match {
case j: ${typeOf[Json]} => j
case tr: ${typeOf[TableRow]} =>
_root_.com.spotify.scio.bigquery.types.Json(tr)
case s: ${typeOf[String]} =>
_root_.com.spotify.scio.bigquery.types.Json(s)
case _ => $fail
}"""
case t if t =:= typeOf[BigNumeric] =>
q"_root_.com.spotify.scio.bigquery.types.BigNumeric($s)"

q"""$tree match {
case bn: ${typeOf[BigNumeric]} => bn
case bd: ${typeOf[BigDecimal]} =>
_root_.com.spotify.scio.bigquery.types.BigNumeric(bd)
case s: ${typeOf[String]} =>
_root_.com.spotify.scio.bigquery.types.BigNumeric(s)
case _ => $fail
}"""
case t if isCaseClass(c)(t) => // nested records
val fn = TermName("r" + t.typeSymbol.name)
q"""{
val $fn = $tree.asInstanceOf[_root_.java.util.Map[String, AnyRef]]
${constructor(t, fn)}
val r = TermName("record")
q"""$tree match {
case $r: ${typeOf[java.util.Map[String, AnyRef]]} => ${constructor(t, r)}
case _ => $fail
}
"""
case _ => c.abort(c.enclosingPosition, s"Unsupported type: $tpe")
}
}

def option(tree: Tree, tpe: Type): Tree =
q"if ($tree == null) None else Some(${cast(tree, tpe)})"
q"_root_.scala.Option($tree).map(x => ${cast(q"x", tpe)})"

def list(tree: Tree, tpe: Type): Tree = {
val jl = tq"_root_.java.util.List[AnyRef]"
q"asScala($tree.asInstanceOf[$jl].iterator).map(x => ${cast(q"x", tpe)}).toList"
}
def list(tree: Tree, tpe: Type): Tree =
q"asScala($tree.asInstanceOf[${typeOf[java.util.List[AnyRef]]}].iterator).map(x => ${cast(q"x", tpe)}).toList"

def field(symbol: Symbol, fn: TermName): Tree = {
val name = symbol.name.toString
val tpe = symbol.asMethod.returnType

val tree = q"$fn.get($name)"
def nonNullTree(fType: String) =
def nonNullTree(fType: String) = {
val msg = Constant(s"$fType field '$name' is null")
q"""{
val v = $fn.get($name)
if (v == null) {
throw new NullPointerException($fType + " field \"" + $name + "\" is null")
throw new NullPointerException($msg)
}
v
}"""
}

if (tpe.erasure =:= typeOf[Option[_]].erasure) {
option(tree, tpe.typeArgs.head)
} else if (tpe.erasure =:= typeOf[List[_]].erasure) {
Expand Down Expand Up @@ -387,10 +474,10 @@ private[types] object ConverterProvider {
case t if provider.shouldOverrideType(c)(t) => q"$tree.toString"
case t if t =:= typeOf[Boolean] => tree
case t if t =:= typeOf[Int] => tree
case t if t =:= typeOf[Long] => tree
case t if t =:= typeOf[Float] => tree
case t if t =:= typeOf[Double] => tree
case t if t =:= typeOf[String] => tree
case t if t =:= typeOf[Long] => q"$tree.toString" // json doesn't support long
case t if t =:= typeOf[Float] => q"$tree.toDouble" // json doesn't support float
case t if t =:= typeOf[Double] => tree
case t if t =:= typeOf[String] => tree

case t if t =:= typeOf[BigDecimal] =>
q"_root_.com.spotify.scio.bigquery.Numeric($tree).toString"
Expand All @@ -412,7 +499,7 @@ private[types] object ConverterProvider {
case t if t =:= typeOf[Geography] =>
q"$tree.wkt"
case t if t =:= typeOf[Json] =>
// for TableRow/json, use JSON to prevent escaping
// for TableRow/json, use parsed JSON to prevent escaping
q"_root_.com.spotify.scio.bigquery.types.Json.parse($tree)"
case t if t =:= typeOf[BigNumeric] =>
// for TableRow/json, use string to avoid precision loss (like numeric)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

package com.spotify.scio.bigquery

import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper}
import com.fasterxml.jackson.databind.{ObjectMapper, SerializationFeature}
import com.fasterxml.jackson.datatype.joda.JodaModule
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule
import com.spotify.scio.coders.Coder
import org.apache.avro.Conversions.DecimalConversion
import org.apache.avro.LogicalTypes
Expand Down Expand Up @@ -63,10 +65,15 @@ package object types {
*/
case class Json(wkt: String)
object Json {
// Use same mapper as the TableRowJsonCoder
private lazy val mapper = new ObjectMapper()
.registerModule(new JavaTimeModule())
.registerModule(new JodaModule())
.disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS)
.disable(SerializationFeature.FAIL_ON_EMPTY_BEANS);

def apply(node: JsonNode): Json = Json(mapper.writeValueAsString(node))
def parse(json: Json): JsonNode = mapper.readTree(json.wkt)
def apply(row: TableRow): Json = Json(mapper.writeValueAsString(row))
def parse(json: Json): TableRow = mapper.readValue(json.wkt, classOf[TableRow])
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ final class ConverterProviderSpec
}
implicit val arbJson: Arbitrary[Json] = Arbitrary(
for {
key <- Gen.alphaStr
// "f" is a reserved field name in TableRow.
// The Jackson ObjectMapper will fail when deserializing a key named "f".
key <- Gen.alphaStr.retryUntil(_ != "f")
Comment on lines +56 to +58
Copy link
Contributor Author

@RustedBones RustedBones Dec 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering what happens if a BQ table has a field named `f`. It's probable that we can't use the TableRow API in that case.

value <- Gen.alphaStr
} yield Json(s"""{"$key":"$value"}""")
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

package com.spotify.scio.bigquery.types

import com.fasterxml.jackson.databind.node.{JsonNodeFactory, ObjectNode}
import com.google.protobuf.ByteString
import com.spotify.scio.bigquery._
import org.joda.time.{Instant, LocalDate, LocalDateTime, LocalTime}
import org.scalatest.matchers.should.Matchers
import org.scalatest.flatspec.AnyFlatSpec

Expand All @@ -28,7 +29,7 @@ class ConverterProviderTest extends AnyFlatSpec with Matchers {
"ConverterProvider" should "throw NPE with meaningful message for null in REQUIRED field" in {
the[NullPointerException] thrownBy {
Required.fromTableRow(TableRow())
} should have message """REQUIRED field "a" is null"""
} should have message "REQUIRED field 'a' is null"
}

it should "handle null in NULLABLE field" in {
Expand All @@ -38,7 +39,7 @@ class ConverterProviderTest extends AnyFlatSpec with Matchers {
it should "throw NPE with meaningful message for null in REPEATED field" in {
the[NullPointerException] thrownBy {
Repeated.fromTableRow(TableRow())
} should have message """REPEATED field "a" is null"""
} should have message "REPEATED field 'a' is null"
}

it should "handle required geography type" in {
Expand All @@ -49,14 +50,12 @@ class ConverterProviderTest extends AnyFlatSpec with Matchers {

it should "handle required json type" in {
val wkt = """{"name":"Alice","age":30}"""
val jsNodeFactory = new JsonNodeFactory(false)
val jackson = jsNodeFactory
.objectNode()
.set[ObjectNode]("name", jsNodeFactory.textNode("Alice"))
.set[ObjectNode]("age", jsNodeFactory.numberNode(30))

RequiredJson.fromTableRow(TableRow("a" -> jackson)) shouldBe RequiredJson(Json(wkt))
BigQueryType.toTableRow[RequiredJson](RequiredJson(Json(wkt))) shouldBe TableRow("a" -> jackson)
val parsed = new TableRow()
.set("name", "Alice")
.set("age", 30)

RequiredJson.fromTableRow(TableRow("a" -> parsed)) shouldBe RequiredJson(Json(wkt))
BigQueryType.toTableRow[RequiredJson](RequiredJson(Json(wkt))) shouldBe TableRow("a" -> parsed)
}

it should "handle required big numeric type" in {
Expand All @@ -74,6 +73,13 @@ class ConverterProviderTest extends AnyFlatSpec with Matchers {
RequiredWithMethod.fromTableRow(TableRow("a" -> "")) shouldBe RequiredWithMethod("")
BigQueryType.toTableRow[RequiredWithMethod](RequiredWithMethod("")) shouldBe TableRow("a" -> "")
}

it should "convert to stable types for the coder" in {
import com.spotify.scio.testing.CoderAssertions._
// Coder[TableRow] is destructive
// make sure the target TableRow format chosen by the BigQueryType conversion is stable
AllTypes.toTableRow(AllTypes()) coderShould roundtrip()
}
}

object ConverterProviderTest {
Expand Down Expand Up @@ -102,4 +108,23 @@ object ConverterProviderTest {
def accessorMethod: String = ""
def method(x: String): String = x
}

@BigQueryType.toTable
case class AllTypes(
bool: Boolean = true,
int: Int = 1,
long: Long = 2L,
float: Float = 3.3f,
double: Double = 4.4,
numeric: BigDecimal = BigDecimal(5),
string: String = "6",
byteString: ByteString = ByteString.copyFromUtf8("7"),
timestamp: Instant = Instant.now(),
date: LocalDate = LocalDate.now(),
time: LocalTime = LocalTime.now(),
datetime: LocalDateTime = LocalDateTime.now(),
geography: Geography = Geography("POINT (8 8)"),
json: Json = Json("""{"key": 9,"value": 10}"""),
bigNumeric: BigNumeric = BigNumeric(BigDecimal(11))
)
}
Loading