diff --git a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/aggregate/TableAggregationTest.scala b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/aggregate/TableAggregationTest.scala
index 151fa18fe30..2c93a025552 100644
--- a/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/aggregate/TableAggregationTest.scala
+++ b/engine/flink/components/base-tests/src/test/scala/pl/touk/nussknacker/engine/flink/table/aggregate/TableAggregationTest.scala
@@ -90,19 +90,20 @@ class TableAggregationTest extends AnyFunSuite with TableDrivenPropertyChecks wi
   }
 
   test("aggregations should aggregate by integers") {
-    val input = List(1, 2)
+    val input = List(1, 1, 2)
     val aggregatorWithExpectedResult: List[AggregateByInputTestData] = List(
       "Average" -> 1,
-      "Count" -> 2,
+      "Count" -> 3,
       "Min" -> 1,
       "Max" -> 2,
       "First" -> 1,
       "Last" -> 2,
-      "Sum" -> 3,
+      "Sum" -> 4,
       "Population standard deviation" -> 0,
-      "Sample standard deviation" -> 1,
+      "Sample standard deviation" -> 0,
       "Population variance" -> 0,
-      "Sample variance" -> 1,
+      "Sample variance" -> 0,
+      "Collect" -> Map(1 -> 2, 2 -> 1).asJava
     ).map(a => AggregateByInputTestData(a._1, a._2))
     runMultipleAggregationTest(input, aggregatorWithExpectedResult)
   }
@@ -120,7 +121,8 @@ class TableAggregationTest extends AnyFunSuite with TableDrivenPropertyChecks wi
       "Population standard deviation" -> 0.5,
       "Sample standard deviation" -> 0.7071067811865476,
       "Population variance" -> 0.25,
-      "Sample variance" -> 0.5
+      "Sample variance" -> 0.5,
+      "Collect" -> Map(2.0 -> 1, 1.0 -> 1).asJava
     ).map(a => AggregateByInputTestData(a._1, a._2))
     runMultipleAggregationTest(input, aggregatorWithExpectedResult)
   }
@@ -128,11 +130,12 @@ class TableAggregationTest extends AnyFunSuite with TableDrivenPropertyChecks wi
   test("aggregations should aggregate by strings") {
     val input = List("def", "abc")
     val aggregatorWithExpectedResult: List[AggregateByInputTestData] = List(
-      "Count" -> 2,
-      "Min"   -> "abc",
-      "Max"   -> "def",
-      "First" -> "def",
-      "Last"  -> "abc",
+      "Count"   -> 2,
+      "Min"     -> "abc",
+      "Max"     -> "def",
+      "First"   -> "def",
+      "Last"    -> "abc",
+      "Collect" -> Map("def" -> 1, "abc" -> 1).asJava
     ).map(a => AggregateByInputTestData(a._1, a._2))
     runMultipleAggregationTest(input, aggregatorWithExpectedResult)
   }
@@ -150,7 +153,11 @@ class TableAggregationTest extends AnyFunSuite with TableDrivenPropertyChecks wi
       "Population standard deviation" -> java.math.BigDecimal.valueOf(0.5).setScale(18),
       "Sample standard deviation" -> java.math.BigDecimal.valueOf(0.7071067811865476).setScale(18),
       "Population variance" -> java.math.BigDecimal.valueOf(0.25).setScale(18),
-      "Sample variance" -> java.math.BigDecimal.valueOf(0.5).setScale(18)
+      "Sample variance" -> java.math.BigDecimal.valueOf(0.5).setScale(18),
+      "Collect" -> Map(
+        java.math.BigDecimal.valueOf(1).setScale(18) -> 1,
+        java.math.BigDecimal.valueOf(2).setScale(18) -> 1
+      ).asJava
     ).map(a => AggregateByInputTestData(a._1, a._2))
     runMultipleAggregationTest(input, aggregatorWithExpectedResult)
   }
@@ -158,9 +165,10 @@ class TableAggregationTest extends AnyFunSuite with TableDrivenPropertyChecks wi
   test("max, min and count aggregations should aggregate by date types") {
     val input = List(LocalDate.parse("2000-01-01"), LocalDate.parse("2000-01-02"))
     val aggregatorWithExpectedResult = List(
-      "Count" -> 2,
-      "Min"   -> LocalDate.parse("2000-01-01"),
-      "Max"   -> LocalDate.parse("2000-01-02"),
+      "Count"   -> 2,
+      "Min"     -> LocalDate.parse("2000-01-01"),
+      "Max"     -> LocalDate.parse("2000-01-02"),
+      "Collect" -> Map(LocalDate.parse("2000-01-01") -> 1, LocalDate.parse("2000-01-02") -> 1).asJava
     ).map(a => AggregateByInputTestData(a._1, a._2))
     runMultipleAggregationTest(input, aggregatorWithExpectedResult)
   }
@@ -210,18 +218,17 @@ class TableAggregationTest extends AnyFunSuite with TableDrivenPropertyChecks wi
     }
   }
 
-  test("count aggregation works when aggregating by type aligned to RAW") {
-    val scenario = buildMultipleAggregationsScenario(
-      List(
-        AggregationParameters(aggregator = "'Count'".spel, aggregateBy = "#input".spel, groupBy = "''".spel)
-      )
-    )
-    val result = runner.runWithData(
-      scenario,
-      List(OffsetDateTime.now()),
-      Boundedness.BOUNDED
-    )
-    result shouldBe Symbol("valid")
+  test("count and collect aggregation works when aggregating by type aligned to RAW") {
+    val input =
+      List(OffsetDateTime.parse("2024-01-01T23:59:30+03:00"), OffsetDateTime.parse("2024-01-02T23:59:30+04:00"))
+    val aggregatorWithExpectedResult: List[AggregateByInputTestData] = List(
+      "Count" -> 2,
+      "Collect" -> Map(
+        OffsetDateTime.parse("2024-01-01T23:59:30+03:00") -> 1,
+        OffsetDateTime.parse("2024-01-02T23:59:30+04:00") -> 1
+      ).asJava
+    ).map(a => AggregateByInputTestData(a._1, a._2))
+    runMultipleAggregationTest(input, aggregatorWithExpectedResult)
   }
 
   test("table aggregation should emit groupBy key and aggregated values as separate variables") {
diff --git a/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/aggregate/TableAggregator.scala b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/aggregate/TableAggregator.scala
index bad2de71ca2..953859e87d6 100644
--- a/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/aggregate/TableAggregator.scala
+++ b/engine/flink/components/table/src/main/scala/pl/touk/nussknacker/engine/flink/table/aggregate/TableAggregator.scala
@@ -18,7 +18,6 @@ import pl.touk.nussknacker.engine.flink.table.utils.simulateddatatype.{
 TODO: add remaining aggregations functions:
  - LISTAGG
  - ARRAY_AGG
- - COLLECT
 TODO: unify aggregator function definitions with unbounded-streaming ones. Current duplication may lead to inconsistency
 in naming and may be confusing for users
 
@@ -94,6 +93,12 @@ object TableAggregator extends Enum[TableAggregator] {
     override def inputAllowedTypesConstraint: Option[List[LogicalTypeRoot]] = Some(numericAggregationsAllowedTypes)
   }
 
+  case object Collect extends TableAggregator {
+    override val displayName: String = "Collect"
+    override def flinkFunctionDefinition: BuiltInFunctionDefinition = BuiltInFunctionDefinitions.COLLECT
+    override def inputAllowedTypesConstraint: Option[List[LogicalTypeRoot]] = None
+  }
+
   private val minMaxAllowedTypes = List(
     TINYINT,
     SMALLINT,
diff --git a/engine/flink/test-utils/src/main/scala/pl/touk/nussknacker/engine/flink/test/FlinkTestConfiguration.scala b/engine/flink/test-utils/src/main/scala/pl/touk/nussknacker/engine/flink/test/FlinkTestConfiguration.scala
index 32e4ad93904..4335d07293e 100644
--- a/engine/flink/test-utils/src/main/scala/pl/touk/nussknacker/engine/flink/test/FlinkTestConfiguration.scala
+++ b/engine/flink/test-utils/src/main/scala/pl/touk/nussknacker/engine/flink/test/FlinkTestConfiguration.scala
@@ -22,6 +22,12 @@ object FlinkTestConfiguration {
     // which holds all needed jars/classes in case of running from Scala plugin in IDE.
    // but in case of running from sbt it contains only sbt-launcher.jar
     config.set(PipelineOptions.CLASSPATHS, List("http://dummy-classpath.invalid").asJava)
+
+    // This is to prevent memory problems in tests with multiple Table API based aggregations. An IllegalArgumentException
+    // is thrown with the message "The minBucketMemorySize is not valid!" in
+    // org.apache.flink.table.runtime.util.collections.binary.AbstractBytesHashMap.java:121 where memorySize is set
+    // inside a code-generated operator (like LocalHashAggregateWithKeys).
+    config.set(TaskManagerOptions.MANAGED_MEMORY_SIZE, MemorySize.parse("100m"))
   }
 
 }
diff --git a/utils/utils/src/main/scala/pl/touk/nussknacker/engine/util/json/ToJsonEncoder.scala b/utils/utils/src/main/scala/pl/touk/nussknacker/engine/util/json/ToJsonEncoder.scala
index 3bf0a897425..a57223fdd4c 100644
--- a/utils/utils/src/main/scala/pl/touk/nussknacker/engine/util/json/ToJsonEncoder.scala
+++ b/utils/utils/src/main/scala/pl/touk/nussknacker/engine/util/json/ToJsonEncoder.scala
@@ -68,14 +68,14 @@ case class ToJsonEncoder(
       case a: OffsetDateTime => Encoder[OffsetDateTime].apply(a)
       case a: UUID           => safeString(a.toString)
       case a: DisplayJson    => a.asJson
-      case a: scala.collection.Map[String @unchecked, _] => encodeMap(a.toMap)
-      case a: java.util.Map[String @unchecked, _]        => encodeMap(a.asScala.toMap)
-      case a: Iterable[_]                                 => fromValues(a.map(encode))
-      case a: Enum[_]                                     => safeString(a.toString)
-      case a: java.util.Collection[_]                     => fromValues(a.asScala.map(encode))
-      case a: Array[_]                                    => fromValues(a.map(encode))
-      case _ if !failOnUnknown                            => safeString(any.toString)
-      case a                                              => throw new IllegalArgumentException(s"Invalid type: ${a.getClass}")
+      case a: scala.collection.Map[_, _] => encodeMap(a.toMap)
+      case a: java.util.Map[_, _]        => encodeMap(a.asScala.toMap)
+      case a: Iterable[_]                => fromValues(a.map(encode))
+      case a: Enum[_]                    => safeString(a.toString)
+      case a: java.util.Collection[_]    => fromValues(a.asScala.map(encode))
+      case a: Array[_]                   => fromValues(a.map(encode))
+      case _ if !failOnUnknown           => safeString(any.toString)
+      case a                             => throw new IllegalArgumentException(s"Invalid type: ${a.getClass}")
     }
   )
 
@@ -85,8 +85,13 @@ case class ToJsonEncoder(
     case None    => Null
   }
 
-  private def encodeMap(map: Map[String, _]) = {
-    fromFields(map.mapValuesNow(encode))
+  // TODO: make encoder aware of NU Types to encode things like multiset differently. Right now it's handled by calling
+  // toString on keys.
+  private def encodeMap(map: Map[_, _]) = {
+    val mapWithStringKeys = map.view.map { case (k, v) =>
+      k.toString -> v
+    }.toMap
+    fromFields(mapWithStringKeys.mapValuesNow(encode))
   }
 
 }
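
Context for the ToJsonEncoder change above: the COLLECT function produces a multiset, which reaches the encoder as a java.util.Map whose keys are the collected elements (not necessarily Strings) and whose values are occurrence counts, as the expected test values like Map(1 -> 2, 2 -> 1).asJava show. encodeMap therefore can no longer assume Map[String, _] and instead turns keys into JSON field names via toString. Below is a minimal standalone sketch of that key-stringification idea using circe directly; the object name and sample values are made up for illustration and this is not the actual ToJsonEncoder API.

import io.circe.Json

object MultisetEncodingSketch extends App {
  // e.g. what the "Collect" aggregator yields for List(1, 1, 2): element -> number of occurrences
  val multiset: Map[Any, Int] = Map(1 -> 2, 2 -> 1)

  // Non-String keys become JSON field names via toString, mirroring encodeMap above
  val json: Json = Json.fromFields(
    multiset.map { case (k, v) => k.toString -> Json.fromInt(v) }
  )

  println(json.noSpaces) // {"1":2,"2":1}
}

One consequence of this approach, also noted in the TODO, is that keys are rendered in whatever form toString produces (e.g. for OffsetDateTime or BigDecimal keys) until the encoder becomes aware of Nussknacker types.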