Skip to content

Commit

Permalink
[NU-1763] Add Collect aggregation (#6939)
Browse files Browse the repository at this point in the history
add collect aggregation
  • Loading branch information
mslabek authored Sep 27, 2024
1 parent f06ea95 commit a254240
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,20 @@ class TableAggregationTest extends AnyFunSuite with TableDrivenPropertyChecks wi
}

test("aggregations should aggregate by integers") {
val input = List(1, 2)
val input = List(1, 1, 2)
val aggregatorWithExpectedResult: List[AggregateByInputTestData] = List(
"Average" -> 1,
"Count" -> 2,
"Count" -> 3,
"Min" -> 1,
"Max" -> 2,
"First" -> 1,
"Last" -> 2,
"Sum" -> 3,
"Sum" -> 4,
"Population standard deviation" -> 0,
"Sample standard deviation" -> 1,
"Sample standard deviation" -> 0,
"Population variance" -> 0,
"Sample variance" -> 1,
"Sample variance" -> 0,
"Collect" -> Map(1 -> 2, 2 -> 1).asJava
).map(a => AggregateByInputTestData(a._1, a._2))
runMultipleAggregationTest(input, aggregatorWithExpectedResult)
}
Expand All @@ -120,19 +121,21 @@ class TableAggregationTest extends AnyFunSuite with TableDrivenPropertyChecks wi
"Population standard deviation" -> 0.5,
"Sample standard deviation" -> 0.7071067811865476,
"Population variance" -> 0.25,
"Sample variance" -> 0.5
"Sample variance" -> 0.5,
"Collect" -> Map(2.0 -> 1, 1.0 -> 1).asJava
).map(a => AggregateByInputTestData(a._1, a._2))
runMultipleAggregationTest(input, aggregatorWithExpectedResult)
}

test("aggregations should aggregate by strings") {
val input = List("def", "abc")
val aggregatorWithExpectedResult: List[AggregateByInputTestData] = List(
"Count" -> 2,
"Min" -> "abc",
"Max" -> "def",
"First" -> "def",
"Last" -> "abc",
"Count" -> 2,
"Min" -> "abc",
"Max" -> "def",
"First" -> "def",
"Last" -> "abc",
"Collect" -> Map("def" -> 1, "abc" -> 1).asJava
).map(a => AggregateByInputTestData(a._1, a._2))
runMultipleAggregationTest(input, aggregatorWithExpectedResult)
}
Expand All @@ -150,17 +153,22 @@ class TableAggregationTest extends AnyFunSuite with TableDrivenPropertyChecks wi
"Population standard deviation" -> java.math.BigDecimal.valueOf(0.5).setScale(18),
"Sample standard deviation" -> java.math.BigDecimal.valueOf(0.7071067811865476).setScale(18),
"Population variance" -> java.math.BigDecimal.valueOf(0.25).setScale(18),
"Sample variance" -> java.math.BigDecimal.valueOf(0.5).setScale(18)
"Sample variance" -> java.math.BigDecimal.valueOf(0.5).setScale(18),
"Collect" -> Map(
java.math.BigDecimal.valueOf(1).setScale(18) -> 1,
java.math.BigDecimal.valueOf(2).setScale(18) -> 1
).asJava
).map(a => AggregateByInputTestData(a._1, a._2))
runMultipleAggregationTest(input, aggregatorWithExpectedResult)
}

test("max, min and count aggregations should aggregate by date types") {
val input = List(LocalDate.parse("2000-01-01"), LocalDate.parse("2000-01-02"))
val aggregatorWithExpectedResult = List(
"Count" -> 2,
"Min" -> LocalDate.parse("2000-01-01"),
"Max" -> LocalDate.parse("2000-01-02"),
"Count" -> 2,
"Min" -> LocalDate.parse("2000-01-01"),
"Max" -> LocalDate.parse("2000-01-02"),
"Collect" -> Map(LocalDate.parse("2000-01-01") -> 1, LocalDate.parse("2000-01-02") -> 1).asJava
).map(a => AggregateByInputTestData(a._1, a._2))
runMultipleAggregationTest(input, aggregatorWithExpectedResult)
}
Expand Down Expand Up @@ -210,18 +218,17 @@ class TableAggregationTest extends AnyFunSuite with TableDrivenPropertyChecks wi
}
}

test("count aggregation works when aggregating by type aligned to RAW") {
val scenario = buildMultipleAggregationsScenario(
List(
AggregationParameters(aggregator = "'Count'".spel, aggregateBy = "#input".spel, groupBy = "''".spel)
)
)
val result = runner.runWithData(
scenario,
List(OffsetDateTime.now()),
Boundedness.BOUNDED
)
result shouldBe Symbol("valid")
test("count and collect aggregation works when aggregating by type aligned to RAW") {
  // Two distinct timestamps with different offsets; each is expected to show up exactly once in the collected result.
  val firstTimestamp  = OffsetDateTime.parse("2024-01-01T23:59:30+03:00")
  val secondTimestamp = OffsetDateTime.parse("2024-01-02T23:59:30+04:00")
  val expectedCollect = Map(firstTimestamp -> 1, secondTimestamp -> 1).asJava
  val aggregatorWithExpectedResult: List[AggregateByInputTestData] =
    List("Count" -> 2, "Collect" -> expectedCollect).map { case (aggregatorName, expectedResult) =>
      AggregateByInputTestData(aggregatorName, expectedResult)
    }
  runMultipleAggregationTest(List(firstTimestamp, secondTimestamp), aggregatorWithExpectedResult)
}

test("table aggregation should emit groupBy key and aggregated values as separate variables") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import pl.touk.nussknacker.engine.flink.table.utils.simulateddatatype.{
TODO: add remaining aggregations functions:
- LISTAGG
- ARRAY_AGG
- COLLECT
TODO: unify aggregator function definitions with unbounded-streaming ones. Current duplication may lead to
inconsistency in naming and may be confusing for users
Expand Down Expand Up @@ -94,6 +93,12 @@ object TableAggregator extends Enum[TableAggregator] {
override def inputAllowedTypesConstraint: Option[List[LogicalTypeRoot]] = Some(numericAggregationsAllowedTypes)
}

// Aggregator registered under the display name "Collect" and backed by Flink's built-in COLLECT function.
// Judging by the accompanying test expectations, the result is a map from each distinct aggregated value to the
// number of times it occurred (a multiset) - confirm against the Flink documentation.
case object Collect extends TableAggregator {
override val displayName: String = "Collect"
override def flinkFunctionDefinition: BuiltInFunctionDefinition = BuiltInFunctionDefinitions.COLLECT
// Unlike the numeric aggregators, which pass Some(list of allowed types), no constraint is declared here.
// Presumably None means that any input type is accepted - TODO confirm where the constraint is evaluated.
override def inputAllowedTypesConstraint: Option[List[LogicalTypeRoot]] = None
}

private val minMaxAllowedTypes = List(
TINYINT,
SMALLINT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ object FlinkTestConfiguration {
// which holds all needed jars/classes in case of running from Scala plugin in IDE.
// but in case of running from sbt it contains only sbt-launcher.jar
config.set(PipelineOptions.CLASSPATHS, List("http://dummy-classpath.invalid").asJava)

// This is to prevent a memory problem in tests with multiple Table API based aggregations. An IllegalArgumentException
// is thrown with the message "The minBucketMemorySize is not valid!" in
// org.apache.flink.table.runtime.util.collections.binary.AbstractBytesHashMap.java:121 where memorySize is set
// inside a code-generated operator (like LocalHashAggregateWithKeys).
config.set(TaskManagerOptions.MANAGED_MEMORY_SIZE, MemorySize.parse("100m"))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@ case class ToJsonEncoder(
case a: OffsetDateTime => Encoder[OffsetDateTime].apply(a)
case a: UUID => safeString(a.toString)
case a: DisplayJson => a.asJson
case a: scala.collection.Map[String @unchecked, _] => encodeMap(a.toMap)
case a: java.util.Map[String @unchecked, _] => encodeMap(a.asScala.toMap)
case a: Iterable[_] => fromValues(a.map(encode))
case a: Enum[_] => safeString(a.toString)
case a: java.util.Collection[_] => fromValues(a.asScala.map(encode))
case a: Array[_] => fromValues(a.map(encode))
case _ if !failOnUnknown => safeString(any.toString)
case a => throw new IllegalArgumentException(s"Invalid type: ${a.getClass}")
case a: scala.collection.Map[_, _] => encodeMap(a.toMap)
case a: java.util.Map[_, _] => encodeMap(a.asScala.toMap)
case a: Iterable[_] => fromValues(a.map(encode))
case a: Enum[_] => safeString(a.toString)
case a: java.util.Collection[_] => fromValues(a.asScala.map(encode))
case a: Array[_] => fromValues(a.map(encode))
case _ if !failOnUnknown => safeString(any.toString)
case a => throw new IllegalArgumentException(s"Invalid type: ${a.getClass}")
}
)

Expand All @@ -85,8 +85,13 @@ case class ToJsonEncoder(
case None => Null
}

private def encodeMap(map: Map[String, _]) = {
fromFields(map.mapValuesNow(encode))
// TODO: make encoder aware of NU Types to encode things like multiset differently. Right now it's handled by calling
// toString on keys.
//
// Encodes a map with keys of any type: every key is converted with toString, every value is passed through encode,
// and the resulting pairs are handed to fromFields.
// NOTE: keys that are distinct but share the same toString form collapse into a single entry in the toMap step below,
// and a null key would fail on k.toString.
private def encodeMap(map: Map[_, _]) = {
// Stringify the keys first; values are left untouched at this stage.
val mapWithStringKeys = map.view.map { case (k, v) =>
k.toString -> v
}.toMap
// Encode each value only after the keys have been normalised to String.
fromFields(mapWithStringKeys.mapValuesNow(encode))
}

}

0 comments on commit a254240

Please sign in to comment.