Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compile each schema only once #198

Merged
merged 10 commits into from
Jan 1, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 35 additions & 17 deletions avrohugger-core/src/main/scala/Generator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@ package avrohugger

import avrohugger.format.abstractions.SourceFormat
import avrohugger.format._
import avrohugger.generators.{FileGenerator, StringGenerator}
import avrohugger.input.parsers.{FileInputParser, StringInputParser}
import avrohugger.generators.{ FileGenerator, StringGenerator }
import avrohugger.input.parsers.{ FileInputParser, StringInputParser }
import avrohugger.matchers.TypeMatcher
import avrohugger.types.AvroScalaTypes
import avrohugger.stores.{ClassStore, SchemaStore}
import org.apache.avro.{Protocol, Schema}
import avrohugger.stores.{ ClassStore, SchemaStore }
import org.apache.avro.{ Protocol, Schema }
import java.io.File

// Unable to overload this class' methods because outDir uses a default value
case class Generator(format: SourceFormat,
avroScalaCustomTypes: Option[AvroScalaTypes] = None,
avroScalaCustomNamespace: Map[String, String] = Map.empty,
restrictedFieldNumber: Boolean = false,
classLoader: ClassLoader = Thread.currentThread.getContextClassLoader,
targetScalaPartialVersion: String = avrohugger.internal.ScalaVersion.version) {
avroScalaCustomTypes: Option[AvroScalaTypes] = None,
avroScalaCustomNamespace: Map[String, String] = Map.empty,
restrictedFieldNumber: Boolean = false,
classLoader: ClassLoader = Thread.currentThread.getContextClassLoader,
targetScalaPartialVersion: String = avrohugger.internal.ScalaVersion.version) {

val avroScalaTypes = avroScalaCustomTypes.getOrElse(format.defaultTypes)
val defaultOutputDir = "target/generated-sources"
Expand All @@ -25,20 +25,22 @@ case class Generator(format: SourceFormat,
lazy val schemaParser = new Schema.Parser
val classStore = new ClassStore
val schemaStore = new SchemaStore
val fileGenerator = new FileGenerator
val stringGenerator = new StringGenerator
val typeMatcher = new TypeMatcher(avroScalaTypes, avroScalaCustomNamespace)

//////////////// methods for writing definitions out to file /////////////////
// Writes the definitions for an already-parsed Avro Schema out to file.
// Pure delegation: forwards the schema, the output directory (defaultOutputDir
// unless the caller overrides it) and this Generator's format, class store,
// schema store, type matcher, restrictedFieldNumber flag and target Scala
// version to the file generator's schemaToFile.
def schemaToFile(
schema: Schema,
outDir: String = defaultOutputDir): Unit = {
FileGenerator.schemaToFile(
fileGenerator.schemaToFile(
schema, outDir, format, classStore, schemaStore, typeMatcher, restrictedFieldNumber, targetScalaPartialVersion)
}

def protocolToFile(
protocol: Protocol,
outDir: String = defaultOutputDir): Unit = {
FileGenerator.protocolToFile(
fileGenerator.protocolToFile(
protocol,
outDir,
format,
Expand All @@ -52,7 +54,7 @@ case class Generator(format: SourceFormat,
def stringToFile(
schemaStr: String,
outDir: String = defaultOutputDir): Unit = {
FileGenerator.stringToFile(
fileGenerator.stringToFile(
schemaStr,
outDir,
format,
Expand All @@ -67,7 +69,7 @@ case class Generator(format: SourceFormat,
def fileToFile(
inFile: File,
outDir: String = defaultOutputDir): Unit = {
FileGenerator.fileToFile(
fileGenerator.fileToFile(
inFile,
outDir,
format,
Expand All @@ -80,19 +82,35 @@ case class Generator(format: SourceFormat,
targetScalaPartialVersion)
}

// Multi-file variant of the "to file" methods: passes the whole list of input
// files to the file generator in a single call, along with the output directory
// (defaultOutputDir unless overridden) and this Generator's format, stores,
// file parser, type matcher, class loader and settings.
def filesToFile(inFiles: List[File], outDir: String = defaultOutputDir): Unit =
  fileGenerator.filesToFile(
    inFiles, outDir, format, classStore, schemaStore, fileParser,
    typeMatcher, classLoader, restrictedFieldNumber, targetScalaPartialVersion)

//////// methods for writing to a list of definitions in String format ///////
// Returns the definitions for an already-parsed Avro Schema as a list of
// strings instead of writing them to file. Pure delegation to the string
// generator's schemaToStrings, forwarding this Generator's format, stores,
// type matcher, restrictedFieldNumber flag and target Scala version.
def schemaToStrings(schema: Schema): List[String] = {
StringGenerator.schemaToStrings(
stringGenerator.schemaToStrings(
schema, format, classStore, schemaStore, typeMatcher, restrictedFieldNumber, targetScalaPartialVersion)
}

// Returns the definitions for an already-parsed Avro Protocol as a list of
// strings. Pure delegation to the string generator's protocolToStrings,
// forwarding this Generator's format, stores, type matcher,
// restrictedFieldNumber flag and target Scala version.
def protocolToStrings(protocol: Protocol): List[String] = {
StringGenerator.protocolToStrings(
stringGenerator.protocolToStrings(
protocol, format, classStore, schemaStore, typeMatcher, restrictedFieldNumber, targetScalaPartialVersion)
}

def stringToStrings(schemaStr: String): List[String] = {
StringGenerator.stringToStrings(
stringGenerator.stringToStrings(
schemaStr,
format,
classStore,
Expand All @@ -104,7 +122,7 @@ case class Generator(format: SourceFormat,
}

def fileToStrings(inFile: File): List[String] = {
StringGenerator.fileToStrings(
stringGenerator.fileToStrings(
inFile,
format,
classStore,
Expand Down
85 changes: 42 additions & 43 deletions avrohugger-core/src/main/scala/format/abstractions/Importer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,38 +44,38 @@ trait Importer {
// Gets enum schemas which may be dependencies.
//
// Each top-level schema is expanded according to its type:
//   RECORD - the record itself, plus the result of recursing (via nextSchemas)
//            into every field schema that is present in alreadyImported
//   ENUM   - the enum itself
//   UNION  - the first non-NULL member is looked up in alreadyImported; on a hit
//            nextSchemas is invoked with the union schema itself
//   MAP / ARRAY - same as UNION, using the value / element type for the lookup
//   anything else contributes nothing
// The expanded list is then reduced to ENUM schemas and de-duplicated.
//
// NOTE(review): recursion only happens for schemas that are already in
// alreadyImported, so with the default (empty) argument nothing is recursed
// into - confirm this is intended rather than an inverted check.
// NOTE(review): nextSchemas delegates to getRecordSchemas, whose final step
// keeps only records; those results are then removed by the ENUM filter below,
// so it looks like only top-level enums can be returned - confirm.
def getEnumSchemas(
topLevelSchemas: List[Schema],
alreadyImported: List[Schema] = List.empty[Schema]): List[Schema] = {
def nextSchemas(s: Schema, us: List[Schema]) = getRecordSchemas(List(s), us)
alreadyImported: Set[Schema] = Set.empty[Schema]): List[Schema] = {
def nextSchemas(s: Schema, us: Set[Schema]) = getRecordSchemas(List(s), us)

topLevelSchemas
.flatMap(schema => {
schema.getType match {
case RECORD =>
val fieldSchemasWithChildSchemas = getFieldSchemas(schema).toSeq
.filter(s => alreadyImported.contains(s))
.flatMap(s => nextSchemas(s, alreadyImported :+ s))
val fieldSchemasWithChildSchemas = getFieldSchemas(schema).toSet
.intersect(alreadyImported)
.flatMap(s => nextSchemas(s, alreadyImported + s))
Seq(schema) ++ fieldSchemasWithChildSchemas
case ENUM =>
Seq(schema)
case UNION =>
schema.getTypes().asScala
.find(s => s.getType != NULL).toSeq
.filter(s => alreadyImported.contains(s))
.flatMap(s => nextSchemas(schema, alreadyImported :+ s))
.find(s => s.getType != NULL).toSet
.intersect(alreadyImported)
.flatMap(s => nextSchemas(schema, alreadyImported + s))
case MAP =>
Seq(schema.getValueType)
.filter(s => alreadyImported.contains(s))
.flatMap(s => nextSchemas(schema, alreadyImported :+ s))
Set(schema.getValueType)
.intersect(alreadyImported)
.flatMap(s => nextSchemas(schema, alreadyImported + s))
case ARRAY =>
Seq(schema.getElementType)
.filter(s => alreadyImported.contains(s))
.flatMap(s => nextSchemas(schema, alreadyImported :+ s))
Set(schema.getElementType)
.intersect(alreadyImported)
.flatMap(s => nextSchemas(schema, alreadyImported + s))
case _ =>
Seq.empty[Schema]
}
})
.filter(schema => schema.getType == ENUM)
.distinct
.toList
}

def getFixedSchemas(topLevelSchemas: List[Schema]): List[Schema] =
Expand All @@ -88,8 +88,7 @@ trait Importer {
})
.filter(_.getType == FIXED)
.distinct
.toList


// Returns the schema of every field of `schema`, in the order getFields()
// yields them. Calls getFields(), so `schema` is expected to be a record.
def getFieldSchemas(schema: Schema): List[Schema] =
  schema.getFields().asScala.map(_.schema).toList
Expand Down Expand Up @@ -126,74 +125,74 @@ trait Importer {

def requiresImportDef(schema: Schema): Boolean = {
(isRecord(schema) || isEnum(schema) || isFixed(schema)) &&
checkNamespace(schema).isDefined &&
checkNamespace(schema) != namespace
checkNamespace(schema).isDefined &&
checkNamespace(schema) != namespace
}

recordSchemas
.filter(schema => requiresImportDef(schema))
.groupBy(schema => checkNamespace(schema).getOrElse(schema.getNamespace))
.toList
.map(group => group match {
case(packageName, fields) => asImportDef(packageName, fields)
})
.map {
case (packageName, fields) => asImportDef(packageName, fields)
}
}

// Gets record schemas which may be dependencies.
//
// Each top-level schema is expanded according to its type:
//   RECORD - the record itself, plus the result of recursing (via nextSchemas,
//            i.e. this method on a single schema) into every field schema that
//            is present in alreadyImported, with that field added to the
//            collection passed down
//   ENUM   - the enum itself (dropped again by the final record filter)
//   UNION  - the first non-NULL member is looked up in alreadyImported; on a hit
//            nextSchemas is invoked with the union schema itself
//   MAP / ARRAY - same as UNION, using the value / element type for the lookup
//   anything else contributes nothing
// The expanded list is then reduced to record schemas and de-duplicated.
//
// NOTE(review): recursion only happens for schemas that are already in
// alreadyImported, so with the default (empty) argument only the top-level
// records themselves are returned - confirm this is intended rather than an
// inverted check.
// NOTE(review): the UNION / MAP / ARRAY branches recurse on `schema` (the
// container) rather than on the matched member `s` - confirm.
def getRecordSchemas(
topLevelSchemas: List[Schema],
alreadyImported: List[Schema] = List.empty[Schema]): List[Schema] = {
def nextSchemas(s: Schema, us: List[Schema]) = getRecordSchemas(List(s), us)
alreadyImported: Set[Schema] = Set.empty[Schema]): List[Schema] = {
def nextSchemas(s: Schema, us: Set[Schema]) = getRecordSchemas(List(s), us)

topLevelSchemas
.flatMap(schema => {
schema.getType match {
case RECORD =>
val fieldSchemasWithChildSchemas = getFieldSchemas(schema).toSeq
.filter(s => alreadyImported.contains(s))
.flatMap(s => nextSchemas(s, alreadyImported :+ s))
val fieldSchemasWithChildSchemas = getFieldSchemas(schema).toSet
.intersect(alreadyImported)
.flatMap(s => nextSchemas(s, alreadyImported + s))
Seq(schema) ++ fieldSchemasWithChildSchemas
case ENUM =>
Seq(schema)
case UNION =>
schema.getTypes().asScala
.find(s => s.getType != NULL).toSeq
.filter(s => alreadyImported.contains(s))
.flatMap(s => nextSchemas(schema, alreadyImported :+ s))
.find(s => s.getType != NULL).toSet
.intersect(alreadyImported)
.flatMap(s => nextSchemas(schema, alreadyImported + s))
case MAP =>
Seq(schema.getValueType)
.filter(s => alreadyImported.contains(s))
.flatMap(s => nextSchemas(schema, alreadyImported :+ s))
Set(schema.getValueType)
.intersect(alreadyImported)
.flatMap(s => nextSchemas(schema, alreadyImported + s))
case ARRAY =>
Seq(schema.getElementType)
.filter(s => alreadyImported.contains(s))
.flatMap(s => nextSchemas(schema, alreadyImported :+ s))
Set(schema.getElementType)
.intersect(alreadyImported)
.flatMap(s => nextSchemas(schema, alreadyImported + s))
case _ =>
Seq.empty[Schema]
}
})
.filter(schema => isRecord(schema))
.distinct
.toList
}

// Builds the list of schemas to treat as top-level for a schema or a protocol.
//   Left(schema)    - the schema itself followed by whatever
//                     NestedSchemaExtractor.getNestedSchemas returns for it
//   Right(protocol) - the same "schema followed by its nested schemas" list for
//                     each of the protocol's types, concatenated in the order
//                     getTypes() yields them
// schemaStore and typeMatcher are only passed through to NestedSchemaExtractor.
def getTopLevelSchemas(
schemaOrProtocol: Either[Schema, Protocol],
schemaOrProtocol: Either[Schema, Protocol],
schemaStore: SchemaStore,
typeMatcher: TypeMatcher): List[Schema] = {
schemaOrProtocol match {
case Left(schema) =>
schema::(NestedSchemaExtractor.getNestedSchemas(schema, schemaStore, typeMatcher))
schema :: (NestedSchemaExtractor.getNestedSchemas(schema, schemaStore, typeMatcher))
case Right(protocol) => protocol.getTypes().asScala.toList.flatMap(schema => {
schema::(NestedSchemaExtractor.getNestedSchemas(schema, schemaStore, typeMatcher))
schema :: (NestedSchemaExtractor.getNestedSchemas(schema, schemaStore, typeMatcher))
})
}

}

// True when the schema's Avro type is FIXED.
def isFixed(schema: Schema): Boolean = ( schema.getType == FIXED )
def isFixed(schema: Schema): Boolean = (schema.getType == FIXED)

// True when the schema's Avro type is ENUM.
def isEnum(schema: Schema): Boolean = ( schema.getType == ENUM )
def isEnum(schema: Schema): Boolean = (schema.getType == ENUM)

// True when the schema's Avro type is RECORD.
def isRecord(schema: Schema): Boolean = ( schema.getType == RECORD )
def isRecord(schema: Schema): Boolean = (schema.getType == RECORD)

}
Loading