From fdd4fbbb85f3ef5dd55fb1815b4bbdcadb0d0818 Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Thu, 19 Dec 2024 16:02:10 +0100 Subject: [PATCH] Avoid toString conversion --- .../bigquery/types/ConverterProvider.scala | 149 ++++++++++++++---- .../spotify/scio/bigquery/types/package.scala | 16 +- .../types/ConverterProviderTest.scala | 15 +- 3 files changed, 136 insertions(+), 44 deletions(-) diff --git a/scio-google-cloud-platform/src/main/scala/com/spotify/scio/bigquery/types/ConverterProvider.scala b/scio-google-cloud-platform/src/main/scala/com/spotify/scio/bigquery/types/ConverterProvider.scala index 18dfe4acb8..a39a1579ef 100644 --- a/scio-google-cloud-platform/src/main/scala/com/spotify/scio/bigquery/types/ConverterProvider.scala +++ b/scio-google-cloud-platform/src/main/scala/com/spotify/scio/bigquery/types/ConverterProvider.scala @@ -270,50 +270,136 @@ private[types] object ConverterProvider { // Converter helpers // ======================================================================= def cast(tree: Tree, tpe: Type): Tree = { + val msg = Constant(s"Cannot convert to ${tpe.typeSymbol.name}: ") + val fail = q"""throw new _root_.java.lang.IllegalArgumentException($msg + $tree)""" + + def readBase64(term: TermName) = + s"_root_.com.google.common.io.BaseEncoding.base64().decode($term)" + val provider: OverrideTypeProvider = OverrideTypeProviderFinder.getProvider - val s = q"$tree.toString" tpe match { case t if provider.shouldOverrideType(c)(t) => provider.createInstance(c)(t, q"$tree") - case t if t =:= typeOf[Boolean] => q"$s.toBoolean" - case t if t =:= typeOf[Int] => q"$s.toInt" - case t if t =:= typeOf[Long] => q"$s.toLong" - case t if t =:= typeOf[Float] => q"$s.toFloat" - case t if t =:= typeOf[Double] => q"$s.toDouble" - case t if t =:= typeOf[String] => q"$s" + case t if t =:= typeOf[Boolean] => + q"""$tree match { + case b: ${typeOf[java.lang.Boolean]} => _root_.scala.Boolean.unbox(b) + case _ => $fail + }""" + case t if t =:= typeOf[Int] => + q"""$tree match { + case i: ${typeOf[java.lang.Integer]} => _root_.scala.Int.unbox(i) + case s: ${typeOf[String]} => s.toInt + case _ => $fail + }""" + case t if t =:= typeOf[Long] => + q"""$tree match { + case l: ${typeOf[java.lang.Long]} => _root_.scala.Long.unbox(l) + case i: ${typeOf[java.lang.Integer]} => _root_.scala.Long.unbox(i).toLong + case s: ${typeOf[String]} => s.toLong + case _ => $fail + }""" + case t if t =:= typeOf[Float] => + q"""$tree match { + case f: ${typeOf[java.lang.Float]} => _root_.scala.Float.unbox(f) + case d: ${typeOf[java.lang.Double]} => _root_.scala.Double.unbox(d).toFloat + case s: ${typeOf[String]} => s.toFloat + case _ => $fail + }""" + case t if t =:= typeOf[Double] => + q"""$tree match { + case d: ${typeOf[java.lang.Double]} => _root_.scala.Double.unbox(d) + case s: ${typeOf[String]} => s.toDouble + case _ => $fail + }""" + case t if t =:= typeOf[String] => + q"""$tree match { + case s: ${typeOf[String]} => s + case _ => $fail + }""" case t if t =:= typeOf[BigDecimal] => - q"_root_.com.spotify.scio.bigquery.Numeric($s)" - + q"""$tree match { + case bd: ${typeOf[BigDecimal]} => + _root_.com.spotify.scio.bigquery.Numeric(bd) + case s: ${typeOf[String]} => + _root_.com.spotify.scio.bigquery.Numeric(s) + case _ => $fail + }""" case t if t =:= typeOf[ByteString] => - val b = - q"_root_.com.google.common.io.BaseEncoding.base64().decode($s)" - q"_root_.com.google.protobuf.ByteString.copyFrom($b)" + val s = TermName("s") + q"""$tree match { + case bs: ${typeOf[ByteString]} => bs + case $s: ${typeOf[String]} => + _root_.com.google.protobuf.ByteString.copyFrom(${readBase64(s)}) + case _ => $fail + }""" case t if t =:= typeOf[Array[Byte]] => - q"_root_.com.google.common.io.BaseEncoding.base64().decode($s)" - + val s = TermName("s") + q"""$tree match { + case bs: ${typeOf[Array[Byte]]} => bs + case $s: ${typeOf[String]} => + ${readBase64(s)} + case _ => $fail + }""" case t if t =:= typeOf[Instant] => - q"_root_.com.spotify.scio.bigquery.Timestamp.parse($s)" + q"""$tree match { + case i: ${typeOf[Instant]} => i + case s: ${typeOf[String]} => + _root_.com.spotify.scio.bigquery.Timestamp.parse(s) + case _ => $fail + }""" case t if t =:= typeOf[LocalDate] => - q"_root_.com.spotify.scio.bigquery.Date.parse($s)" + q"""$tree match { + case ld: ${typeOf[LocalDate]} => ld + case s: ${typeOf[String]} => + _root_.com.spotify.scio.bigquery.Date.parse(s) + case _ => $fail + }""" case t if t =:= typeOf[LocalTime] => - q"_root_.com.spotify.scio.bigquery.Time.parse($s)" + q"""$tree match { + case lt: ${typeOf[LocalTime]} => lt + case s: ${typeOf[String]} => + _root_.com.spotify.scio.bigquery.Time.parse(s) + case _ => $fail + }""" case t if t =:= typeOf[LocalDateTime] => - q"_root_.com.spotify.scio.bigquery.DateTime.parse($s)" - + q"""$tree match { + case ldt: ${typeOf[LocalDateTime]} => ldt + case s: ${typeOf[String]} => + _root_.com.spotify.scio.bigquery.DateTime.parse(s) + case _ => $fail + }""" // different than nested record match below, even though those are case classes case t if t =:= typeOf[Geography] => - q"_root_.com.spotify.scio.bigquery.types.Geography($s)" + q"""$tree match { + case g: ${typeOf[Geography]} => g + case s: ${typeOf[String]} => + _root_.com.spotify.scio.bigquery.types.Geography(s) + case _ => $fail + }""" case t if t =:= typeOf[Json] => - q"_root_.com.spotify.scio.bigquery.types.Json($s)" + q"""$tree match { + case j: ${typeOf[Json]} => j + case tr: ${typeOf[TableRow]} => + _root_.com.spotify.scio.bigquery.types.Json(tr) + case s: ${typeOf[String]} => + _root_.com.spotify.scio.bigquery.types.Json(s) + case _ => $fail + }""" case t if t =:= typeOf[BigNumeric] => - q"_root_.com.spotify.scio.bigquery.types.BigNumeric($s)" - + q"""$tree match { + case bn: ${typeOf[BigNumeric]} => bn + case bd: ${typeOf[BigDecimal]} => + _root_.com.spotify.scio.bigquery.types.BigNumeric(bd) + case s: ${typeOf[String]} => + _root_.com.spotify.scio.bigquery.types.BigNumeric(s) + case _ => $fail + }""" case t if isCaseClass(c)(t) => // nested records - val fn = TermName("r" + t.typeSymbol.name) - q"""{ - val $fn = $tree.asInstanceOf[_root_.java.util.Map[String, AnyRef]] - ${constructor(t, fn)} + val r = TermName("record") + q"""$tree match { + case $r: ${typeOf[java.util.Map[String, AnyRef]]} => ${constructor(t, r)} + case _ => $fail } """ case _ => c.abort(c.enclosingPosition, s"Unsupported type: $tpe") @@ -321,7 +407,7 @@ private[types] object ConverterProvider { } def option(tree: Tree, tpe: Type): Tree = - q"if ($tree == null) None else Some(${cast(tree, tpe)})" + q"_root_.scala.Option(${cast(tree, tpe)})" def list(tree: Tree, tpe: Type): Tree = { val jl = tq"_root_.java.util.List[AnyRef]" @@ -333,14 +419,17 @@ private[types] object ConverterProvider { val tpe = symbol.asMethod.returnType val tree = q"$fn.get($name)" - def nonNullTree(fType: String) = + def nonNullTree(fType: String) = { + val msg = Constant(s"$fType field $name is null") q"""{ val v = $fn.get($name) if (v == null) { - throw new NullPointerException($fType + " field \"" + $name + "\" is null") + throw new NullPointerException($msg) } v }""" + } + if (tpe.erasure =:= typeOf[Option[_]].erasure) { option(tree, tpe.typeArgs.head) } else if (tpe.erasure =:= typeOf[List[_]].erasure) { diff --git a/scio-google-cloud-platform/src/main/scala/com/spotify/scio/bigquery/types/package.scala b/scio-google-cloud-platform/src/main/scala/com/spotify/scio/bigquery/types/package.scala index d7265d5459..687c2f17d5 100644 --- a/scio-google-cloud-platform/src/main/scala/com/spotify/scio/bigquery/types/package.scala +++ b/scio-google-cloud-platform/src/main/scala/com/spotify/scio/bigquery/types/package.scala @@ -17,10 +17,12 @@ package com.spotify.scio.bigquery +import com.fasterxml.jackson.databind.{ObjectMapper, SerializationFeature} +import com.fasterxml.jackson.datatype.joda.JodaModule +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule import com.spotify.scio.coders.Coder import org.apache.avro.Conversions.DecimalConversion import org.apache.avro.LogicalTypes -import org.apache.beam.sdk.extensions.gcp.util.Transport import org.typelevel.scalaccompat.annotation.nowarn import java.math.MathContext @@ -63,11 +65,15 @@ package object types { */ case class Json(wkt: String) object Json { - @transient - private lazy val jsonFactory = Transport.getJsonFactory + // Use same mapper as the TableRowJsonCoder + private lazy val mapper = new ObjectMapper() + .registerModule(new JavaTimeModule()) + .registerModule(new JodaModule()) + .disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS) + .disable(SerializationFeature.FAIL_ON_EMPTY_BEANS); - def apply(row: TableRow): Json = Json(jsonFactory.toString(row)) - def parse(json: Json): TableRow = jsonFactory.fromString(json.wkt, classOf[TableRow]) + def apply(row: TableRow): Json = Json(mapper.writeValueAsString(row)) + def parse(json: Json): TableRow = mapper.readValue(json.wkt, classOf[TableRow]) } /** diff --git a/scio-google-cloud-platform/src/test/scala/com/spotify/scio/bigquery/types/ConverterProviderTest.scala b/scio-google-cloud-platform/src/test/scala/com/spotify/scio/bigquery/types/ConverterProviderTest.scala index 72bbc289d9..e5873d8dfc 100644 --- a/scio-google-cloud-platform/src/test/scala/com/spotify/scio/bigquery/types/ConverterProviderTest.scala +++ b/scio-google-cloud-platform/src/test/scala/com/spotify/scio/bigquery/types/ConverterProviderTest.scala @@ -17,7 +17,6 @@ package com.spotify.scio.bigquery.types -import com.fasterxml.jackson.databind.node.{JsonNodeFactory, ObjectNode} import com.google.protobuf.ByteString import com.spotify.scio.bigquery._ import org.joda.time.{Instant, LocalDate, LocalDateTime, LocalTime} @@ -51,14 +50,12 @@ class ConverterProviderTest extends AnyFlatSpec with Matchers { it should "handle required json type" in { val wkt = """{"name":"Alice","age":30}""" - val jsNodeFactory = new JsonNodeFactory(false) - val jackson = jsNodeFactory - .objectNode() - .set[ObjectNode]("name", jsNodeFactory.textNode("Alice")) - .set[ObjectNode]("age", jsNodeFactory.numberNode(30)) - - RequiredJson.fromTableRow(TableRow("a" -> jackson)) shouldBe RequiredJson(Json(wkt)) - BigQueryType.toTableRow[RequiredJson](RequiredJson(Json(wkt))) shouldBe TableRow("a" -> jackson) + val parsed = new TableRow() + .set("name", "Alice") + .set("age", 30) + + RequiredJson.fromTableRow(TableRow("a" -> parsed)) shouldBe RequiredJson(Json(wkt)) + BigQueryType.toTableRow[RequiredJson](RequiredJson(Json(wkt))) shouldBe TableRow("a" -> parsed) } it should "handle required big numeric type" in {