Skip to content

Commit

Permalink
[scio-core](feature) Add readFiles and readFilesWithPath apis (#5350)
Browse files Browse the repository at this point in the history
  • Loading branch information
RustedBones authored May 17, 2024
1 parent e9a564a commit db08aac
Show file tree
Hide file tree
Showing 9 changed files with 651 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ import com.spotify.scio.avro._
import com.spotify.scio.avro.types.AvroType.HasAvroAnnotation
import com.spotify.scio.coders.Coder
import com.spotify.scio.io.ClosedTap
import com.spotify.scio.util.FilenamePolicySupplier
import com.spotify.scio.util.{FilenamePolicySupplier, ScioUtil}
import com.spotify.scio.values._
import org.apache.avro.Schema
import org.apache.avro.file.CodecFactory
import org.apache.avro.specific.SpecificRecord
import org.apache.avro.specific.{SpecificData, SpecificRecord}
import org.apache.avro.generic.GenericRecord
import org.apache.beam.sdk.extensions.avro.io.AvroDatumFactory
import org.apache.beam.sdk.extensions.avro.io.{AvroDatumFactory, AvroIO => BAvroIO, AvroSource}

import scala.reflect.ClassTag
import scala.reflect.runtime.universe._
Expand Down Expand Up @@ -207,6 +207,63 @@ final class ProtobufSCollectionOps[T <: Message](private val self: SCollection[T
}
}

/**
 * Enhanced version of an `SCollection[String]` of file paths with methods for
 * reading the referenced Avro container files.
 */
final class FilesSCollectionOps(private val self: SCollection[String]) extends AnyVal {

  /**
   * Read each input path as an Avro file of generic records with the given `schema`.
   *
   * @param schema the Avro schema of the records
   * @param datumFactory factory used to create datum readers; a `null` value falls
   *                     back to [[GenericRecordDatumFactory]]
   */
  def readAvroGenericFiles(
    schema: Schema,
    datumFactory: AvroDatumFactory[GenericRecord] = GenericRecordIO.ReadParam.DefaultDatumFactory
  ): SCollection[GenericRecord] = {
    // Guard against an explicit null argument.
    val factory = if (datumFactory == null) GenericRecordDatumFactory else datumFactory
    implicit val coder: Coder[GenericRecord] = avroCoder(factory, schema)
    val readTransform = BAvroIO
      .readFilesGenericRecords(schema)
      .withDatumReaderFactory(factory)
    self.readFiles(filesTransform = readTransform)
  }

  /**
   * Read each input path as an Avro file of specific records of type `T`.
   *
   * @param datumFactory factory used to create datum readers; a `null` value falls
   *                     back to a [[SpecificRecordDatumFactory]] for `T`
   */
  def readAvroSpecificFiles[T <: SpecificRecord: ClassTag](
    datumFactory: AvroDatumFactory[T] = SpecificRecordIO.ReadParam.DefaultDatumFactory
  ): SCollection[T] = {
    val clazz = ScioUtil.classOf[T]
    // Schema is derived from the generated specific-record class.
    val schema = SpecificData.get().getSchema(clazz)
    val factory =
      if (datumFactory == null) new SpecificRecordDatumFactory(clazz) else datumFactory
    implicit val coder: Coder[T] = avroCoder(factory, schema)
    val readTransform = BAvroIO
      .readFiles(clazz)
      .withDatumReaderFactory(factory)
    self.readFiles(filesTransform = readTransform)
  }

  /**
   * Read each input path as an Avro file of generic records with the given `schema`,
   * pairing every record with the path of the file it was read from.
   */
  def readAvroGenericFilesWithPath(
    schema: Schema,
    datumFactory: AvroDatumFactory[GenericRecord] = GenericRecordIO.ReadParam.DefaultDatumFactory
  ): SCollection[(String, GenericRecord)] = {
    val factory = if (datumFactory == null) GenericRecordDatumFactory else datumFactory
    implicit val coder: Coder[GenericRecord] = avroCoder(factory, schema)
    self.readFilesWithPath() { file =>
      AvroSource
        .from(file)
        .withSchema(schema)
        .withDatumReaderFactory(factory)
    }
  }

  /**
   * Read each input path as an Avro file of specific records of type `T`,
   * pairing every record with the path of the file it was read from.
   */
  def readAvroSpecificFilesWithPath[T <: SpecificRecord: ClassTag](
    datumFactory: AvroDatumFactory[T] = SpecificRecordIO.ReadParam.DefaultDatumFactory
  ): SCollection[(String, T)] = {
    val clazz = ScioUtil.classOf[T]
    val schema = SpecificData.get().getSchema(clazz)
    val factory =
      if (datumFactory == null) new SpecificRecordDatumFactory(clazz) else datumFactory
    implicit val coder: Coder[T] = avroCoder(factory, schema)
    self.readFilesWithPath() { file =>
      AvroSource
        .from(file)
        .withSchema(clazz)
        .withDatumReaderFactory(factory)
    }
  }
}

/** Enhanced with Avro methods. */
trait SCollectionSyntax {
implicit def avroGenericRecordSCollectionOps(
Expand All @@ -228,4 +285,9 @@ trait SCollectionSyntax {
/** Implicit conversion adding protobuf-over-Avro save operations to an `SCollection[T]`. */
implicit def avroProtobufSCollectionOps[T <: Message](
c: SCollection[T]
): ProtobufSCollectionOps[T] = new ProtobufSCollectionOps[T](c)

/**
 * Implicit conversion adding Avro file-reading operations to any `SCollection[T]`
 * whose element type is provably `String` (via the `T <:< String` evidence).
 */
implicit def avroFilesSCollectionOps[T](
c: SCollection[T]
)(implicit ev: T <:< String): FilesSCollectionOps =
// covary_ presumably upcasts the element type using the evidence — TODO confirm its contract
new FilesSCollectionOps(c.covary_)
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package com.spotify.scio.avro
import com.spotify.scio._
import com.spotify.scio.avro.AvroUtils._
import com.spotify.scio.coders.Coder
import com.spotify.scio.io.ReadIO
import com.spotify.scio.testing.PipelineSpec
import org.apache.avro.generic.GenericRecord

Expand Down Expand Up @@ -54,6 +55,29 @@ object GenericAvroFileJob {
}
}

/**
 * Test job: reads the Avro files listed under `--input` as generic records and
 * writes them to the `--output` Avro file.
 */
object ReadGenericAvroFilesJob {
  def main(cmdlineArgs: Array[String]): Unit = {
    val (sc, args) = ContextAndArgs(cmdlineArgs)
    val inputPaths = sc.parallelize(args.list("input"))
    val records = inputPaths.readAvroGenericFiles(AvroUtils.schema)
    records.saveAsAvroFile(args("output"), schema = AvroUtils.schema)
    sc.run()
    ()
  }
}

/**
 * Test job: reads the Avro files listed under `--input` as specific `TestRecord`s,
 * stamps each record's string field with the path it was read from, and writes
 * the results to the `--output` Avro file.
 */
object ReadSpecificAvroFilesWithPathJob {
  def main(cmdlineArgs: Array[String]): Unit = {
    val (sc, args) = ContextAndArgs(cmdlineArgs)
    sc.parallelize(args.list("input"))
      .readAvroSpecificFilesWithPath[TestRecord]()
      .map { case (path, record) =>
        // Record the originating file in the record itself.
        TestRecord.newBuilder(record).setStringField(path).build()
      }
      .saveAsAvroFile(args("output"))
    sc.run()
    ()
  }
}

object GenericParseFnAvroFileJob {

implicit val coder: Coder[GenericRecord] = avroGenericRecordCoder(AvroUtils.schema)
Expand Down Expand Up @@ -114,7 +138,8 @@ class AvroJobTestTest extends PipelineSpec {
}

def testGenericAvroFileJob(xs: Seq[GenericRecord]): Unit = {
implicit val coder = avroGenericRecordCoder
implicit val coder: Coder[GenericRecord] =
avroGenericRecordCoder(AvroUtils.schema)
JobTest[GenericAvroFileJob.type]
.args("--input=in.avro", "--output=out.avro")
.input(AvroIO[GenericRecord]("in.avro"), (1 to 3).map(newGenericRecord))
Expand All @@ -137,7 +162,8 @@ class AvroJobTestTest extends PipelineSpec {

def testGenericParseAvroFileJob(xs: Seq[GenericRecord]): Unit = {
import GenericParseFnAvroFileJob.PartialFieldsAvro
implicit val coder: Coder[GenericRecord] = avroGenericRecordCoder
implicit val coder: Coder[GenericRecord] =
avroGenericRecordCoder(AvroUtils.schema)
JobTest[GenericParseFnAvroFileJob.type]
.args("--input=in.avro", "--output=out.avro")
.input(AvroIO[PartialFieldsAvro]("in.avro"), (1 to 3).map(PartialFieldsAvro))
Expand All @@ -148,7 +174,7 @@ class AvroJobTestTest extends PipelineSpec {
.run()
}

it should "pass when correct generic parsed records" in {
it should "pass when correct parsed generic records" in {
testGenericParseAvroFileJob((1 to 3).map(newGenericRecord))
}

Expand All @@ -160,4 +186,34 @@ class AvroJobTestTest extends PipelineSpec {
testGenericParseAvroFileJob((1 to 4).map(newGenericRecord))
}
}

// Fix: the test exercises GenericRecord end-to-end (ReadGenericAvroFilesJob,
// avroGenericRecordCoder, ReadIO[GenericRecord]) but the description claimed
// "specific records" — corrected to "generic records".
"Read avro files" should "pass when correct generic records" in {
  implicit val coder: Coder[GenericRecord] =
    avroGenericRecordCoder(AvroUtils.schema)
  // Six records split evenly across two simulated input files; the job should
  // simply pass them all through to the output.
  val expected = (1 to 6).map(newGenericRecord)
  val (part1, part2) = expected.splitAt(3)
  JobTest[ReadGenericAvroFilesJob.type]
    .args("--input=in1.avro", "--input=in2.avro", "--output=out.avro")
    .input(ReadIO[GenericRecord]("in1.avro"), part1)
    .input(ReadIO[GenericRecord]("in2.avro"), part2)
    .output(AvroIO[GenericRecord]("out.avro"))(coll => coll should containInAnyOrder(expected))
    .run()
}

"Read avro files with path" should "pass when correct specific records" in {
  val input = (1 to 6).map(newSpecificRecord)
  val (part1, part2) = input.splitAt(3)
  // The job overwrites each record's string field with the name of the file it
  // was read from, so expected records carry the matching input file name.
  // Build fresh records rather than mutating `input`, which feeds the JobTest.
  val expected = (1 to 6).map { i =>
    val record = newSpecificRecord(i)
    val sourceFile = if (i <= 3) "in1.avro" else "in2.avro"
    record.setStringField(sourceFile)
    record
  }

  JobTest[ReadSpecificAvroFilesWithPathJob.type]
    .args("--input=in1.avro", "--input=in2.avro", "--output=out.avro")
    .input(ReadIO[TestRecord]("in1.avro"), part1)
    .input(ReadIO[TestRecord]("in2.avro"), part2)
    .output(AvroIO[TestRecord]("out.avro")) { results =>
      results should containInAnyOrder(expected)
    }
    .run()
}
}
Loading

0 comments on commit db08aac

Please sign in to comment.