diff --git a/scio-google-cloud-platform/src/main/scala/com/spotify/scio/bigquery/types/ConverterProvider.scala b/scio-google-cloud-platform/src/main/scala/com/spotify/scio/bigquery/types/ConverterProvider.scala index 18dfe4acb8..1a6b81a315 100644 --- a/scio-google-cloud-platform/src/main/scala/com/spotify/scio/bigquery/types/ConverterProvider.scala +++ b/scio-google-cloud-platform/src/main/scala/com/spotify/scio/bigquery/types/ConverterProvider.scala @@ -270,50 +270,136 @@ private[types] object ConverterProvider { // Converter helpers // ======================================================================= def cast(tree: Tree, tpe: Type): Tree = { + val msg = Constant(s"Cannot convert to ${tpe.typeSymbol.name}: ") + val fail = q"""throw new _root_.java.lang.IllegalArgumentException($msg + $tree)""" + + def readBase64(term: TermName) = + q"_root_.com.google.common.io.BaseEncoding.base64().decode($term)" + val provider: OverrideTypeProvider = OverrideTypeProviderFinder.getProvider - val s = q"$tree.toString" tpe match { case t if provider.shouldOverrideType(c)(t) => provider.createInstance(c)(t, q"$tree") - case t if t =:= typeOf[Boolean] => q"$s.toBoolean" - case t if t =:= typeOf[Int] => q"$s.toInt" - case t if t =:= typeOf[Long] => q"$s.toLong" - case t if t =:= typeOf[Float] => q"$s.toFloat" - case t if t =:= typeOf[Double] => q"$s.toDouble" - case t if t =:= typeOf[String] => q"$s" + case t if t =:= typeOf[Boolean] => + q"""$tree match { + case b: ${typeOf[java.lang.Boolean]} => _root_.scala.Boolean.unbox(b) + case _ => $fail + }""" + case t if t =:= typeOf[Int] => + q"""$tree match { + case i: ${typeOf[java.lang.Integer]} => _root_.scala.Int.unbox(i) + case s: ${typeOf[String]} => s.toInt + case _ => $fail + }""" + case t if t =:= typeOf[Long] => + q"""$tree match { + case l: ${typeOf[java.lang.Long]} => _root_.scala.Long.unbox(l) + case i: ${typeOf[java.lang.Integer]} => _root_.scala.Int.unbox(i).toLong + case s: ${typeOf[String]} => s.toLong + case _ => 
$fail + }""" + case t if t =:= typeOf[Float] => + q"""$tree match { + case f: ${typeOf[java.lang.Float]} => _root_.scala.Float.unbox(f) + case d: ${typeOf[java.lang.Double]} => _root_.scala.Double.unbox(d).toFloat + case s: ${typeOf[String]} => s.toFloat + case _ => $fail + }""" + case t if t =:= typeOf[Double] => + q"""$tree match { + case d: ${typeOf[java.lang.Double]} => _root_.scala.Double.unbox(d) + case s: ${typeOf[String]} => s.toDouble + case _ => $fail + }""" + case t if t =:= typeOf[String] => + q"""$tree match { + case s: ${typeOf[String]} => s + case _ => $fail + }""" case t if t =:= typeOf[BigDecimal] => - q"_root_.com.spotify.scio.bigquery.Numeric($s)" - + q"""$tree match { + case bd: ${typeOf[BigDecimal]} => + _root_.com.spotify.scio.bigquery.Numeric(bd) + case s: ${typeOf[String]} => + _root_.com.spotify.scio.bigquery.Numeric(s) + case _ => $fail + }""" case t if t =:= typeOf[ByteString] => - val b = - q"_root_.com.google.common.io.BaseEncoding.base64().decode($s)" - q"_root_.com.google.protobuf.ByteString.copyFrom($b)" + val s = TermName("s") + q"""$tree match { + case bs: ${typeOf[ByteString]} => bs + case $s: ${typeOf[String]} => + _root_.com.google.protobuf.ByteString.copyFrom(${readBase64(s)}) + case _ => $fail + }""" case t if t =:= typeOf[Array[Byte]] => - q"_root_.com.google.common.io.BaseEncoding.base64().decode($s)" - + val s = TermName("s") + q"""$tree match { + case bs: ${typeOf[Array[Byte]]} => bs + case $s: ${typeOf[String]} => + ${readBase64(s)} + case _ => $fail + }""" case t if t =:= typeOf[Instant] => - q"_root_.com.spotify.scio.bigquery.Timestamp.parse($s)" + q"""$tree match { + case i: ${typeOf[Instant]} => i + case s: ${typeOf[String]} => + _root_.com.spotify.scio.bigquery.Timestamp.parse(s) + case _ => $fail + }""" case t if t =:= typeOf[LocalDate] => - q"_root_.com.spotify.scio.bigquery.Date.parse($s)" + q"""$tree match { + case ld: ${typeOf[LocalDate]} => ld + case s: ${typeOf[String]} => + 
_root_.com.spotify.scio.bigquery.Date.parse(s) + case _ => $fail + }""" case t if t =:= typeOf[LocalTime] => - q"_root_.com.spotify.scio.bigquery.Time.parse($s)" + q"""$tree match { + case lt: ${typeOf[LocalTime]} => lt + case s: ${typeOf[String]} => + _root_.com.spotify.scio.bigquery.Time.parse(s) + case _ => $fail + }""" case t if t =:= typeOf[LocalDateTime] => - q"_root_.com.spotify.scio.bigquery.DateTime.parse($s)" - + q"""$tree match { + case ldt: ${typeOf[LocalDateTime]} => ldt + case s: ${typeOf[String]} => + _root_.com.spotify.scio.bigquery.DateTime.parse(s) + case _ => $fail + }""" // different than nested record match below, even though those are case classes case t if t =:= typeOf[Geography] => - q"_root_.com.spotify.scio.bigquery.types.Geography($s)" + q"""$tree match { + case g: ${typeOf[Geography]} => g + case s: ${typeOf[String]} => + _root_.com.spotify.scio.bigquery.types.Geography(s) + case _ => $fail + }""" case t if t =:= typeOf[Json] => - q"_root_.com.spotify.scio.bigquery.types.Json($s)" + q"""$tree match { + case j: ${typeOf[Json]} => j + case tr: ${typeOf[TableRow]} => + _root_.com.spotify.scio.bigquery.types.Json(tr) + case s: ${typeOf[String]} => + _root_.com.spotify.scio.bigquery.types.Json(s) + case _ => $fail + }""" case t if t =:= typeOf[BigNumeric] => - q"_root_.com.spotify.scio.bigquery.types.BigNumeric($s)" - + q"""$tree match { + case bn: ${typeOf[BigNumeric]} => bn + case bd: ${typeOf[BigDecimal]} => + _root_.com.spotify.scio.bigquery.types.BigNumeric(bd) + case s: ${typeOf[String]} => + _root_.com.spotify.scio.bigquery.types.BigNumeric(s) + case _ => $fail + }""" case t if isCaseClass(c)(t) => // nested records - val fn = TermName("r" + t.typeSymbol.name) - q"""{ - val $fn = $tree.asInstanceOf[_root_.java.util.Map[String, AnyRef]] - ${constructor(t, fn)} + val r = TermName("record") + q"""$tree match { + case $r: ${typeOf[java.util.Map[String, AnyRef]]} => ${constructor(t, r)} + case _ => $fail } """ case _ => 
c.abort(c.enclosingPosition, s"Unsupported type: $tpe") @@ -321,26 +407,27 @@ private[types] object ConverterProvider { } def option(tree: Tree, tpe: Type): Tree = - q"if ($tree == null) None else Some(${cast(tree, tpe)})" + q"_root_.scala.Option($tree).map(x => ${cast(q"x", tpe)})" - def list(tree: Tree, tpe: Type): Tree = { - val jl = tq"_root_.java.util.List[AnyRef]" - q"asScala($tree.asInstanceOf[$jl].iterator).map(x => ${cast(q"x", tpe)}).toList" - } + def list(tree: Tree, tpe: Type): Tree = + q"asScala($tree.asInstanceOf[${typeOf[java.util.List[AnyRef]]}].iterator).map(x => ${cast(q"x", tpe)}).toList" def field(symbol: Symbol, fn: TermName): Tree = { val name = symbol.name.toString val tpe = symbol.asMethod.returnType val tree = q"$fn.get($name)" - def nonNullTree(fType: String) = + def nonNullTree(fType: String) = { + val msg = Constant(s"$fType field '$name' is null") q"""{ val v = $fn.get($name) if (v == null) { - throw new NullPointerException($fType + " field \"" + $name + "\" is null") + throw new NullPointerException($msg) } v }""" + } + if (tpe.erasure =:= typeOf[Option[_]].erasure) { option(tree, tpe.typeArgs.head) } else if (tpe.erasure =:= typeOf[List[_]].erasure) { diff --git a/scio-google-cloud-platform/src/main/scala/com/spotify/scio/bigquery/types/package.scala b/scio-google-cloud-platform/src/main/scala/com/spotify/scio/bigquery/types/package.scala index d7265d5459..687c2f17d5 100644 --- a/scio-google-cloud-platform/src/main/scala/com/spotify/scio/bigquery/types/package.scala +++ b/scio-google-cloud-platform/src/main/scala/com/spotify/scio/bigquery/types/package.scala @@ -17,10 +17,12 @@ package com.spotify.scio.bigquery +import com.fasterxml.jackson.databind.{ObjectMapper, SerializationFeature} +import com.fasterxml.jackson.datatype.joda.JodaModule +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule import com.spotify.scio.coders.Coder import org.apache.avro.Conversions.DecimalConversion import org.apache.avro.LogicalTypes -import 
org.apache.beam.sdk.extensions.gcp.util.Transport import org.typelevel.scalaccompat.annotation.nowarn import java.math.MathContext @@ -63,11 +65,15 @@ package object types { */ case class Json(wkt: String) object Json { - @transient - private lazy val jsonFactory = Transport.getJsonFactory + // Use same mapper as the TableRowJsonCoder + private lazy val mapper = new ObjectMapper() + .registerModule(new JavaTimeModule()) + .registerModule(new JodaModule()) + .disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS) + .disable(SerializationFeature.FAIL_ON_EMPTY_BEANS) - def apply(row: TableRow): Json = Json(jsonFactory.toString(row)) - def parse(json: Json): TableRow = jsonFactory.fromString(json.wkt, classOf[TableRow]) + def apply(row: TableRow): Json = Json(mapper.writeValueAsString(row)) + def parse(json: Json): TableRow = mapper.readValue(json.wkt, classOf[TableRow]) } /** diff --git a/scio-google-cloud-platform/src/test/scala/com/spotify/scio/bigquery/types/ConverterProviderSpec.scala b/scio-google-cloud-platform/src/test/scala/com/spotify/scio/bigquery/types/ConverterProviderSpec.scala index 797837c052..68203d2e74 100644 --- a/scio-google-cloud-platform/src/test/scala/com/spotify/scio/bigquery/types/ConverterProviderSpec.scala +++ b/scio-google-cloud-platform/src/test/scala/com/spotify/scio/bigquery/types/ConverterProviderSpec.scala @@ -53,7 +53,9 @@ final class ConverterProviderSpec } implicit val arbJson: Arbitrary[Json] = Arbitrary( for { - key <- Gen.alphaStr + // "f" is a field of TableRow itself. 
+ // The Jackson ObjectMapper fails on such a key, so exclude it + key <- Gen.alphaStr.retryUntil(_ != "f") value <- Gen.alphaStr } yield Json(s"""{"$key":"$value"}""") ) diff --git a/scio-google-cloud-platform/src/test/scala/com/spotify/scio/bigquery/types/ConverterProviderTest.scala b/scio-google-cloud-platform/src/test/scala/com/spotify/scio/bigquery/types/ConverterProviderTest.scala index 72bbc289d9..d62929bc78 100644 --- a/scio-google-cloud-platform/src/test/scala/com/spotify/scio/bigquery/types/ConverterProviderTest.scala +++ b/scio-google-cloud-platform/src/test/scala/com/spotify/scio/bigquery/types/ConverterProviderTest.scala @@ -17,7 +17,6 @@ package com.spotify.scio.bigquery.types -import com.fasterxml.jackson.databind.node.{JsonNodeFactory, ObjectNode} import com.google.protobuf.ByteString import com.spotify.scio.bigquery._ import org.joda.time.{Instant, LocalDate, LocalDateTime, LocalTime} @@ -30,7 +29,7 @@ class ConverterProviderTest extends AnyFlatSpec with Matchers { "ConverterProvider" should "throw NPE with meaningful message for null in REQUIRED field" in { the[NullPointerException] thrownBy { Required.fromTableRow(TableRow()) - } should have message """REQUIRED field "a" is null""" + } should have message "REQUIRED field 'a' is null" } it should "handle null in NULLABLE field" in { @@ -40,7 +39,7 @@ class ConverterProviderTest extends AnyFlatSpec with Matchers { it should "throw NPE with meaningful message for null in REPEATED field" in { the[NullPointerException] thrownBy { Repeated.fromTableRow(TableRow()) - } should have message """REPEATED field "a" is null""" + } should have message "REPEATED field 'a' is null" } it should "handle required geography type" in { @@ -51,14 +50,12 @@ class ConverterProviderTest extends AnyFlatSpec with Matchers { it should "handle required json type" in { val wkt = """{"name":"Alice","age":30}""" - val jsNodeFactory = new JsonNodeFactory(false) - val jackson = jsNodeFactory - .objectNode() - .set[ObjectNode]("name", 
jsNodeFactory.textNode("Alice")) - .set[ObjectNode]("age", jsNodeFactory.numberNode(30)) - - RequiredJson.fromTableRow(TableRow("a" -> jackson)) shouldBe RequiredJson(Json(wkt)) - BigQueryType.toTableRow[RequiredJson](RequiredJson(Json(wkt))) shouldBe TableRow("a" -> jackson) + val parsed = new TableRow() + .set("name", "Alice") + .set("age", 30) + + RequiredJson.fromTableRow(TableRow("a" -> parsed)) shouldBe RequiredJson(Json(wkt)) + BigQueryType.toTableRow[RequiredJson](RequiredJson(Json(wkt))) shouldBe TableRow("a" -> parsed) } it should "handle required big numeric type" in {