From 502f0ecc9ad5d4b89912ebe310d82366a9727105 Mon Sep 17 00:00:00 2001
From: Lantao Jin
Date: Wed, 6 Nov 2024 08:35:26 +0800
Subject: [PATCH 1/4] [followup] Refactor JSON function and add TO_JSON_STRING,
 ARRAY_LENGTH functions (#870)

Signed-off-by: Lantao Jin
---
 docs/ppl-lang/functions/ppl-json.md           | 73 +++++++++++++++----
 .../FlintSparkPPLJsonFunctionITSuite.scala    | 38 +++++-----
 .../src/main/antlr4/OpenSearchPPLLexer.g4     |  2 +
 .../src/main/antlr4/OpenSearchPPLParser.g4    |  2 +
 .../function/BuiltinFunctionName.java         |  2 +
 .../ppl/utils/BuiltinFunctionTransformer.java | 22 ++----
 ...PlanJsonFunctionsTranslatorTestSuite.scala | 20 +++--
 7 files changed, 100 insertions(+), 59 deletions(-)

diff --git a/docs/ppl-lang/functions/ppl-json.md b/docs/ppl-lang/functions/ppl-json.md
index 1953e8c70..5b26ee427 100644
--- a/docs/ppl-lang/functions/ppl-json.md
+++ b/docs/ppl-lang/functions/ppl-json.md
@@ -4,11 +4,11 @@

 **Description**

-`json(value)` Evaluates whether a value can be parsed as JSON. Returns the json string if valid, null otherwise.
+`json(value)` Evaluates whether a string can be parsed as JSON. Returns the string value if valid, null otherwise.

-**Argument type:** STRING/JSON_ARRAY/JSON_OBJECT
+**Argument type:** STRING

-**Return type:** STRING
+**Return type:** STRING/NULL

 A STRING expression of a valid JSON object format.

@@ -47,7 +47,7 @@ A StructType expression of a valid JSON object.

 Example:

-    os> source=people | eval result = json(json_object('key', 123.45)) | fields result
+    os> source=people | eval result = json_object('key', 123.45) | fields result
     fetched rows / total rows = 1/1
     +------------------+
     | result           |
     +------------------+
     | {"key":123.45}   |
     +------------------+

-    os> source=people | eval result = json(json_object('outer', json_object('inner', 123.45))) | fields result
+    os> source=people | eval result = json_object('outer', json_object('inner', 123.45)) | fields result
     fetched rows / total rows = 1/1
     +------------------------------+
     | result                       |
     +------------------------------+
@@ -81,13 +81,13 @@ Example:

     os> source=people | eval `json_array` = json_array(1, 2, 0, -1, 1.1, -0.11)
     fetched rows / total rows = 1/1
-    +----------------------------+
-    | json_array                 |
-    +----------------------------+
-    | 1.0,2.0,0.0,-1.0,1.1,-0.11 |
-    +----------------------------+
+    +------------------------------+
+    | json_array                   |
+    +------------------------------+
+    | [1.0,2.0,0.0,-1.0,1.1,-0.11] |
+    +------------------------------+

-    os> source=people | eval `json_array_object` = json(json_object("array", json_array(1, 2, 0, -1, 1.1, -0.11)))
+    os> source=people | eval `json_array_object` = json_object("array", json_array(1, 2, 0, -1, 1.1, -0.11))
     fetched rows / total rows = 1/1
     +----------------------------------------+
     | json_array_object                      |
     +----------------------------------------+
@@ -95,15 +95,44 @@ Example:
     | {"array":[1.0,2.0,0.0,-1.0,1.1,-0.11]} |
     +----------------------------------------+

+### `TO_JSON_STRING`
+
+**Description**
+
+`to_json_string(jsonObject)` Returns the JSON string representation of a given JSON object value.
+
+**Argument type:** JSON_OBJECT (Spark StructType/ArrayType)
+
+**Return type:** STRING
+
+Example:
+
+    os> source=people | eval `json_string` = to_json_string(json_array(1, 2, 0, -1, 1.1, -0.11)) | fields json_string
+    fetched rows / total rows = 1/1
+    +--------------------------------+
+    | json_string                    |
+    +--------------------------------+
+    | [1.0,2.0,0.0,-1.0,1.1,-0.11]   |
+    +--------------------------------+
+
+    os> source=people | eval `json_string` = to_json_string(json_object('key', 123.45)) | fields json_string
+    fetched rows / total rows = 1/1
+    +-----------------+
+    | json_string     |
+    +-----------------+
+    | {"key":123.45}  |
+    +-----------------+
+
+
 ### `JSON_ARRAY_LENGTH`

 **Description**

-`json_array_length(jsonArray)` Returns the number of elements in the outermost JSON array.
+`json_array_length(jsonArrayString)` Returns the number of elements in the outermost array of a JSON array string.

-**Argument type:** STRING/JSON_ARRAY
+**Argument type:** STRING

-A STRING expression of a valid JSON array format, or JSON_ARRAY object.
+A STRING expression of a valid JSON array format.

 **Return type:** INTEGER

@@ -119,6 +148,21 @@ Example:
     | 4         | 5         | null        |
     +-----------+-----------+-------------+

+
+### `ARRAY_LENGTH`
+
+**Description**
+
+`array_length(jsonArray)` Returns the number of elements in the outermost array.
+
+**Argument type:** ARRAY
+
+An ARRAY or JSON_ARRAY object.
+
+**Return type:** INTEGER
+
+Example:
+
     os> source=people | eval `json_array` = json_array_length(json_array(1,2,3,4)), `empty_array` = json_array_length(json_array())
     fetched rows / total rows = 1/1
     +--------------+---------------+
     | json_array   | empty_array   |
     +--------------+---------------+
     | 4            | 0             |
     +--------------+---------------+

+
 ### `JSON_EXTRACT`

 **Description**
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala
index 7cc0a221d..fca758101 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala
@@ -163,30 +163,32 @@ class FlintSparkPPLJsonFunctionITSuite
     assert(ex.getMessage().contains("should all be the same type"))
   }

-  test("test json_array() with json()") {
+  test("test json_array() with to_json_string()") {
     val frame = sql(s"""
-                 | source = $testTable | eval result = json(json_array(1,2,0,-1,1.1,-0.11)) | head 1 | fields result
+                 | source = $testTable | eval result = to_json_string(json_array(1,2,0,-1,1.1,-0.11)) | head 1 | fields result
                  | """.stripMargin)
     assertSameRows(Seq(Row("""[1.0,2.0,0.0,-1.0,1.1,-0.11]""")), frame)
   }

-  test("test json_array_length()") {
+  test("test array_length()") {
     var frame = sql(s"""
-                 | source = $testTable | eval result = json_array_length(json_array('this', 'is', 'a', 'string', 'array')) | head 1 | fields result
-                 | """.stripMargin)
+                 | source = $testTable| eval result = array_length(json_array('this', 'is', 'a', 'string', 'array')) | head 1 | fields result
+                 | """.stripMargin)
     assertSameRows(Seq(Row(5)), frame)

     frame = sql(s"""
-                 | source = $testTable | eval result = json_array_length(json_array(1, 2, 0, -1, 1.1, -0.11)) | head 1 | fields result
-                 | """.stripMargin)
+                 | source = $testTable| eval result = array_length(json_array(1, 2, 0, -1, 1.1, -0.11)) | head 1 | fields result
+                 | """.stripMargin)
     assertSameRows(Seq(Row(6)), frame)

     frame = sql(s"""
-                 | source = $testTable |
eval result = json_array_length(json_array()) | head 1 | fields result - | """.stripMargin) + | source = $testTable| eval result = array_length(json_array()) | head 1 | fields result + | """.stripMargin) assertSameRows(Seq(Row(0)), frame) + } - frame = sql(s""" + test("test json_array_length()") { + var frame = sql(s""" | source = $testTable | eval result = json_array_length('[]') | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row(0)), frame) @@ -211,24 +213,24 @@ class FlintSparkPPLJsonFunctionITSuite test("test json_object()") { // test value is a string var frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', 'string_value')) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', 'string_value')) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":"string_value"}""")), frame) // test value is a number frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', 123.45)) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', 123.45)) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":123.45}""")), frame) // test value is a boolean frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', true)) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', true)) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":true}""")), frame) frame = sql(s""" - | source = $testTable| eval result = json(json_object("a", 1, "b", 2, "c", 3)) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object("a", 1, "b", 2, "c", 3)) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"a":1,"b":2,"c":3}""")), frame) } @@ -236,13 +238,13 @@ class FlintSparkPPLJsonFunctionITSuite test("test json_object() and json_array()") { // test value is an empty array var frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', array())) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', array())) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":[]}""")), frame) // test value is an array frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', array(1, 2, 3))) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', array(1, 2, 3))) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":[1,2,3]}""")), frame) @@ -272,14 +274,14 @@ class FlintSparkPPLJsonFunctionITSuite test("test json_object() nested") { val frame = sql(s""" - | source = $testTable | eval result = json(json_object('outer', json_object('inner', 123.45))) | head 1 | fields result + | source = $testTable | eval result = to_json_string(json_object('outer', json_object('inner', 123.45))) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"outer":{"inner":123.45}}""")), frame) } test("test json_object(), json_array() and json()") { val frame = sql(s""" - | source = $testTable | eval result = json(json_object("array", json_array(1,2,0,-1,1.1,-0.11))) | head 1 | fields result + | source = $testTable | eval result = to_json_string(json_object("array", json_array(1,2,0,-1,1.1,-0.11))) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"array":[1.0,2.0,0.0,-1.0,1.1,-0.11]}""")), frame) } diff --git 
a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index fcec4d13f..93efb2df1 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -378,6 +378,7 @@ JSON: 'JSON'; JSON_OBJECT: 'JSON_OBJECT'; JSON_ARRAY: 'JSON_ARRAY'; JSON_ARRAY_LENGTH: 'JSON_ARRAY_LENGTH'; +TO_JSON_STRING: 'TO_JSON_STRING'; JSON_EXTRACT: 'JSON_EXTRACT'; JSON_KEYS: 'JSON_KEYS'; JSON_VALID: 'JSON_VALID'; @@ -393,6 +394,7 @@ JSON_VALID: 'JSON_VALID'; // COLLECTION FUNCTIONS ARRAY: 'ARRAY'; +ARRAY_LENGTH: 'ARRAY_LENGTH'; // LAMBDA FUNCTIONS //EXISTS: 'EXISTS'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index b7f293a4a..06dffa55c 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -860,6 +860,7 @@ jsonFunctionName | JSON_OBJECT | JSON_ARRAY | JSON_ARRAY_LENGTH + | TO_JSON_STRING | JSON_EXTRACT | JSON_KEYS | JSON_VALID @@ -876,6 +877,7 @@ jsonFunctionName collectionFunctionName : ARRAY + | ARRAY_LENGTH ; lambdaFunctionName diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 13b5c20ef..1959d0f6d 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -213,6 +213,7 @@ public enum BuiltinFunctionName { JSON_OBJECT(FunctionName.of("json_object")), JSON_ARRAY(FunctionName.of("json_array")), JSON_ARRAY_LENGTH(FunctionName.of("json_array_length")), + TO_JSON_STRING(FunctionName.of("to_json_string")), JSON_EXTRACT(FunctionName.of("json_extract")), JSON_KEYS(FunctionName.of("json_keys")), JSON_VALID(FunctionName.of("json_valid")), @@ -228,6 +229,7 @@ public enum BuiltinFunctionName { /** COLLECTION Functions **/ ARRAY(FunctionName.of("array")), + ARRAY_LENGTH(FunctionName.of("array_length")), /** LAMBDA Functions **/ ARRAY_FORALL(FunctionName.of("forall")), diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java index e39c9ab38..0b0fb8314 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java @@ -28,6 +28,7 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.ARRAY_LENGTH; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATEDIFF; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATE_ADD; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATE_SUB; @@ -58,6 +59,7 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.SYSDATE; import static org.opensearch.sql.expression.function.BuiltinFunctionName.TIMESTAMPADD; import static 
org.opensearch.sql.expression.function.BuiltinFunctionName.TIMESTAMPDIFF; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.TO_JSON_STRING; import static org.opensearch.sql.expression.function.BuiltinFunctionName.TRIM; import static org.opensearch.sql.expression.function.BuiltinFunctionName.UTC_TIMESTAMP; import static org.opensearch.sql.expression.function.BuiltinFunctionName.WEEK; @@ -102,7 +104,9 @@ public interface BuiltinFunctionTransformer { .put(COALESCE, "coalesce") .put(LENGTH, "length") .put(TRIM, "trim") + .put(ARRAY_LENGTH, "array_size") // json functions + .put(TO_JSON_STRING, "to_json") .put(JSON_KEYS, "json_object_keys") .put(JSON_EXTRACT, "get_json_object") .build(); @@ -126,26 +130,12 @@ public interface BuiltinFunctionTransformer { .put( JSON_ARRAY_LENGTH, args -> { - // Check if the input is an array (from json_array()) or a JSON string - if (args.get(0) instanceof UnresolvedFunction) { - // Input is a JSON array - return UnresolvedFunction$.MODULE$.apply("json_array_length", - seq(UnresolvedFunction$.MODULE$.apply("to_json", seq(args), false)), false); - } else { - // Input is a JSON string - return UnresolvedFunction$.MODULE$.apply("json_array_length", seq(args.get(0)), false); - } + return UnresolvedFunction$.MODULE$.apply("json_array_length", seq(args.get(0)), false); }) .put( JSON, args -> { - // Check if the input is a named_struct (from json_object()) or a JSON string - if (args.get(0) instanceof UnresolvedFunction) { - return UnresolvedFunction$.MODULE$.apply("to_json", seq(args.get(0)), false); - } else { - return UnresolvedFunction$.MODULE$.apply("get_json_object", - seq(args.get(0), Literal$.MODULE$.apply("$")), false); - } + return UnresolvedFunction$.MODULE$.apply("get_json_object", seq(args.get(0), Literal$.MODULE$.apply("$")), false); }) .put( JSON_VALID, diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala index 216c0f232..6193bc43f 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala @@ -48,7 +48,7 @@ class PPLLogicalPlanJsonFunctionsTranslatorTestSuite val context = new CatalystPlanContext val logPlan = planTransformer.visit( - plan(pplParser, """source=t a = json(json_object('key', array(1, 2, 3)))"""), + plan(pplParser, """source=t a = to_json_string(json_object('key', array(1, 2, 3)))"""), context) val table = UnresolvedRelation(Seq("t")) @@ -97,7 +97,9 @@ class PPLLogicalPlanJsonFunctionsTranslatorTestSuite val context = new CatalystPlanContext val logPlan = planTransformer.visit( - plan(pplParser, """source=t a = json(json_object('key', json_array(1, 2, 3)))"""), + plan( + pplParser, + """source=t a = to_json_string(json_object('key', json_array(1, 2, 3)))"""), context) val table = UnresolvedRelation(Seq("t")) @@ -139,25 +141,21 @@ class PPLLogicalPlanJsonFunctionsTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } - test("test json_array_length(json_array())") { + test("test array_length(json_array())") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( - plan(pplParser, """source=t a = json_array_length(json_array(1,2,3))"""), + plan(pplParser, """source=t a = 
array_length(json_array(1,2,3))"""), context) val table = UnresolvedRelation(Seq("t")) val jsonFunc = UnresolvedFunction( - "json_array_length", + "array_size", Seq( UnresolvedFunction( - "to_json", - Seq( - UnresolvedFunction( - "array", - Seq(Literal(1), Literal(2), Literal(3)), - isDistinct = false)), + "array", + Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)), isDistinct = false) val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) From cfd41a36a853f5b413fca7d03584bd1ba95e64bf Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Wed, 6 Nov 2024 09:19:56 +0800 Subject: [PATCH 2/4] Join side aliases should be optional (#862) * Join side aliases should be optional Signed-off-by: Lantao Jin * address comments Signed-off-by: Lantao Jin * typo Signed-off-by: Lantao Jin --------- Signed-off-by: Lantao Jin --- docs/ppl-lang/PPL-Example-Commands.md | 6 +- docs/ppl-lang/ppl-join-command.md | 23 +- .../spark/ppl/FlintSparkPPLBasicITSuite.scala | 64 ++++ .../spark/ppl/FlintSparkPPLJoinITSuite.scala | 269 +++++++++++++++- .../src/main/antlr4/OpenSearchPPLParser.g4 | 2 +- .../sql/ast/tree/DescribeRelation.java | 4 +- .../org/opensearch/sql/ast/tree/Join.java | 6 +- .../org/opensearch/sql/ast/tree/Relation.java | 41 +-- .../sql/ast/tree/SubqueryAlias.java | 10 +- .../sql/ppl/CatalystExpressionVisitor.java | 10 + .../sql/ppl/CatalystPlanContext.java | 60 ++-- .../sql/ppl/CatalystQueryPlanVisitor.java | 3 +- .../opensearch/sql/ppl/parser/AstBuilder.java | 28 +- ...lPlanBasicQueriesTranslatorTestSuite.scala | 40 ++- ...PLLogicalPlanJoinTranslatorTestSuite.scala | 292 +++++++++++++++++- 15 files changed, 755 insertions(+), 103 deletions(-) diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index e780f688d..e80f8c906 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -306,7 +306,11 @@ source = table | where ispresent(a) | - `source = table1 | left semi join left = l right = r on l.a = r.a table2` - `source = table1 | left anti join left = l right = r on l.a = r.a table2` - `source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ]` - +- `source = table1 | inner join on table1.a = table2.a table2 | fields table1.a, table2.a, table1.b, table1.c` (directly refer table name) +- `source = table1 | inner join on a = c table2 | fields a, b, c, d` (ignore side aliases as long as no ambiguous) +- `source = table1 as t1 | join left = l right = r on l.a = r.a table2 as t2 | fields l.a, r.a` (side alias overrides table alias) +- `source = table1 as t1 | join left = l right = r on l.a = r.a table2 as t2 | fields t1.a, t2.a` (error, side alias overrides table alias) +- `source = table1 | join left = l right = r on l.a = r.a [ source = table2 ] as s | fields l.a, s.a` (error, side alias overrides subquery alias) #### **Lookup** [See additional command details](ppl-lookup-command.md) diff --git a/docs/ppl-lang/ppl-join-command.md b/docs/ppl-lang/ppl-join-command.md index 525373f7c..b374bce5f 100644 --- a/docs/ppl-lang/ppl-join-command.md +++ b/docs/ppl-lang/ppl-join-command.md @@ -65,8 +65,8 @@ WHERE t1.serviceName = `order` SEARCH source= | | [joinType] JOIN - leftAlias - rightAlias + [leftAlias] + [rightAlias] [joinHints] ON joinCriteria @@ -79,12 +79,12 @@ SEARCH source= **leftAlias** - Syntax: `left = ` -- Required +- Optional - Description: The subquery alias to use with the left join side, to avoid ambiguous naming. 
**rightAlias** - Syntax: `right = ` -- Required +- Optional - Description: The subquery alias to use with the right join side, to avoid ambiguous naming. **joinHints** @@ -138,11 +138,11 @@ Rewritten by PPL Join query: ```sql SEARCH source=customer | FIELDS c_custkey -| LEFT OUTER JOIN left = c, right = o - ON c.c_custkey = o.o_custkey AND o_comment NOT LIKE '%unusual%packages%' +| LEFT OUTER JOIN + ON c_custkey = o_custkey AND o_comment NOT LIKE '%unusual%packages%' orders -| STATS count(o_orderkey) AS c_count BY c.c_custkey -| STATS count(1) AS custdist BY c_count +| STATS count(o_orderkey) AS c_count BY c_custkey +| STATS count() AS custdist BY c_count | SORT - custdist, - c_count ``` _- **Limitation: sub-searches is unsupported in join right side**_ @@ -151,14 +151,15 @@ If sub-searches is supported, above ppl query could be rewritten as: ```sql SEARCH source=customer | FIELDS c_custkey -| LEFT OUTER JOIN left = c, right = o ON c.c_custkey = o.o_custkey +| LEFT OUTER JOIN + ON c_custkey = o_custkey [ SEARCH source=orders | WHERE o_comment NOT LIKE '%unusual%packages%' | FIELDS o_orderkey, o_custkey ] -| STATS count(o_orderkey) AS c_count BY c.c_custkey -| STATS count(1) AS custdist BY c_count +| STATS count(o_orderkey) AS c_count BY c_custkey +| STATS count() AS custdist BY c_count | SORT - custdist, - c_count ``` diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala index cbc4308b0..3bd98edf1 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala @@ -597,4 +597,68 @@ class FlintSparkPPLBasicITSuite | """.stripMargin)) assert(ex.getMessage().contains("Invalid table name")) } + + test("Search multiple tables - translated into union call with fields") { + val frame = sql(s""" + | source = $t1, $t2 + | """.stripMargin) + assertSameRows( + Seq( + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4), + Row("John", 25, "Ontario", "Canada", 2023, 4), + Row("John", 25, "Ontario", "Canada", 2023, 4)), + frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + + val allFields1 = UnresolvedStar(None) + val allFields2 = UnresolvedStar(None) + + val projectedTable1 = Project(Seq(allFields1), table1) + val projectedTable2 = Project(Seq(allFields2), table2) + + val expectedPlan = + Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("Search multiple tables - with table alias") { + val frame = sql(s""" + | source = $t1, $t2 as t | where t.country = "USA" + | """.stripMargin) + assertSameRows( + Seq( + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4)), + frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table1 
= UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + + val plan1 = Filter( + EqualTo(UnresolvedAttribute("t.country"), Literal("USA")), + SubqueryAlias("t", table1)) + val plan2 = Filter( + EqualTo(UnresolvedAttribute("t.country"), Literal("USA")), + SubqueryAlias("t", table2)) + + val projectedTable1 = Project(Seq(UnresolvedStar(None)), plan1) + val projectedTable2 = Project(Seq(UnresolvedStar(None)), plan2) + + val expectedPlan = + Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala index 00e55d50a..3127325c8 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark.ppl -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Divide, EqualTo, Floor, GreaterThan, LessThan, Literal, Multiply, Or, SortOrder} import org.apache.spark.sql.catalyst.plans.{Cross, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} @@ -924,4 +924,271 @@ class FlintSparkPPLJoinITSuite s }.size == 13) } + + test("test multiple joins without table aliases") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN ON $testTable1.name = $testTable2.name $testTable2 + | | JOIN ON $testTable2.name = $testTable3.name $testTable3 + | | fields $testTable1.name, $testTable2.name, $testTable3.name + | """.stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello"), + Row("John", "John", "John"), + Row("David", "David", "David"), + Row("David", "David", "David"), + Row("Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + table1, + table2, + Inner, + Some( + EqualTo( + UnresolvedAttribute(s"$testTable1.name"), + UnresolvedAttribute(s"$testTable2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + table3, + Inner, + Some( + EqualTo( + UnresolvedAttribute(s"$testTable2.name"), + UnresolvedAttribute(s"$testTable3.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute(s"$testTable1.name"), + UnresolvedAttribute(s"$testTable2.name"), + UnresolvedAttribute(s"$testTable3.name")), + joinPlan2) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with part subquery aliases") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | fields t1.name, t2.name, t3.name + | """.stripMargin) + assertSameRows( + Array( + 
Row("Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello"), + Row("John", "John", "John"), + Row("David", "David", "David"), + Row("David", "David", "David"), + Row("Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name")), + joinPlan2) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with self join 1") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN right = t4 ON t1.name = t4.name $testTable1 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello", "Hello"), + Row("John", "John", "John", "John"), + Row("David", "David", "David", "David"), + Row("David", "David", "David", "David"), + Row("Jane", "Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with self join 2") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN ON t1.name = t4.name + | [ + | source = $testTable1 + | ] as t4 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello", "Hello"), + Row("John", "John", "John", "John"), + Row("David", "David", "David", "David"), + Row("David", "David", "David", "David"), + Row("Jane", "Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = 
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("check access the reference by aliases") { + var frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name $testTable2 as t2 + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + + frame = sql(s""" + | source = $testTable1 as t1 + | | JOIN ON t1.name = t2.name $testTable2 as t2 + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + + frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name [ source = $testTable2 ] as t2 + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + + frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name [ source = $testTable2 as t2 ] + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + } + + test("access the reference by override aliases should throw exception") { + var ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 as tt + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 as tt + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 as tt ] + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name [ source = $testTable2 as tt ] as t2 + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 ] as tt + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 as tt + | | JOIN left = t1 ON t1.name = t2.name $testTable2 as t2 + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + } } diff --git 
a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 06dffa55c..123d1e15a 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -339,7 +339,7 @@ joinType ; sideAlias - : LEFT EQUAL leftAlias = ident COMMA? RIGHT EQUAL rightAlias = ident + : (LEFT EQUAL leftAlias = ident)? COMMA? (RIGHT EQUAL rightAlias = ident)? ; joinCriteria diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java index b513d01bf..dd9947329 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java @@ -8,12 +8,14 @@ import lombok.ToString; import org.opensearch.sql.ast.expression.UnresolvedExpression; +import java.util.Collections; + /** * Extend Relation to describe the table itself */ @ToString public class DescribeRelation extends Relation{ public DescribeRelation(UnresolvedExpression tableName) { - super(tableName); + super(Collections.singletonList(tableName)); } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java index 89f787d34..176902911 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java @@ -25,15 +25,15 @@ public class Join extends UnresolvedPlan { private UnresolvedPlan left; private final UnresolvedPlan right; - private final String leftAlias; - private final String rightAlias; + private final Optional leftAlias; + private final Optional rightAlias; private final JoinType joinType; private final Optional joinCondition; private final JoinHint joinHint; @Override public UnresolvedPlan attach(UnresolvedPlan child) { - this.left = new SubqueryAlias(leftAlias, child); + this.left = leftAlias.isEmpty() ? child : new SubqueryAlias(leftAlias.get(), child); return this; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java index 1b30a7998..d8ea104a4 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java @@ -6,53 +6,34 @@ package org.opensearch.sql.ast.tree; import com.google.common.collect.ImmutableList; -import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; -import lombok.Setter; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; -import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; /** Logical plan node of Relation, the interface for building the searching sources. 
 */
-@AllArgsConstructor
 @ToString
+@Getter
 @EqualsAndHashCode(callSuper = false)
 @RequiredArgsConstructor
 public class Relation extends UnresolvedPlan {
   private static final String COMMA = ",";

-  private final List<UnresolvedExpression> tableName;
-
-  public Relation(UnresolvedExpression tableName) {
-    this(tableName, null);
-  }
-
-  public Relation(UnresolvedExpression tableName, String alias) {
-    this.tableName = Arrays.asList(tableName);
-    this.alias = alias;
-  }
-
-  /** Optional alias name for the relation. */
-  @Setter @Getter private String alias;
-
-  /**
-   * Return table name.
-   *
-   * @return table name
-   */
-  public List<String> getTableName() {
-    return tableName.stream().map(Object::toString).collect(Collectors.toList());
-  }
+  // A relation could contain more than one table/index name, such as
+  // source=account1, account2
+  // source=`account1`,`account2`
+  // source=`account*`
+  // They are translated into a union call with fields.
+  private final List<UnresolvedExpression> tableNames;

   public List<QualifiedName> getQualifiedNames() {
-    return tableName.stream().map(t -> (QualifiedName) t).collect(Collectors.toList());
+    return tableNames.stream().map(t -> (QualifiedName) t).collect(Collectors.toList());
   }

   /**
@@ -63,11 +44,11 @@ public List<QualifiedName> getQualifiedNames() {
    * @return TableQualifiedName.
    */
   public QualifiedName getTableQualifiedName() {
-    if (tableName.size() == 1) {
-      return (QualifiedName) tableName.get(0);
+    if (tableNames.size() == 1) {
+      return (QualifiedName) tableNames.get(0);
     } else {
       return new QualifiedName(
-          tableName.stream()
+          tableNames.stream()
               .map(UnresolvedExpression::toString)
               .collect(Collectors.joining(COMMA)));
     }
   }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java
index 29c3d4b90..ba66cca80 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java
@@ -6,19 +6,14 @@
 package org.opensearch.sql.ast.tree;

 import com.google.common.collect.ImmutableList;
-import lombok.AllArgsConstructor;
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
-import lombok.RequiredArgsConstructor;
 import lombok.ToString;
 import org.opensearch.sql.ast.AbstractNodeVisitor;

 import java.util.List;
-import java.util.Objects;

-@AllArgsConstructor
 @EqualsAndHashCode(callSuper = false)
-@RequiredArgsConstructor
 @ToString
 public class SubqueryAlias extends UnresolvedPlan {
   @Getter private final String alias;
@@ -32,6 +27,11 @@ public SubqueryAlias(UnresolvedPlan child, String suffix) {
     this.child = child;
   }

+  public SubqueryAlias(String alias, UnresolvedPlan child) {
+    this.alias = alias;
+    this.child = child;
+  }
+
   public List getChild() {
     return ImmutableList.of(child);
   }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java
index a0506ceee..4c8d117b3 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java
@@ -110,6 +110,11 @@ public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext c
     return unresolved.accept(this, context);
   }

+  /** This method is only for analyzing the join condition expression */
+  public Expression analyzeJoinCondition(UnresolvedExpression unresolved, CatalystPlanContext context)
{ + return context.resolveJoinCondition(unresolved, this::analyze); + } + @Override public Expression visitLiteral(Literal node, CatalystPlanContext context) { return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Literal( @@ -197,6 +202,11 @@ public Expression visitCompare(Compare node, CatalystPlanContext context) { @Override public Expression visitQualifiedName(QualifiedName node, CatalystPlanContext context) { + // When the qualified name is part of join condition, for example: table1.id = table2.id + // findRelation(context.traversalContext() only returns relation table1 which cause table2.id fail to resolve + if (context.isResolvingJoinCondition()) { + return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } List relation = findRelation(context.traversalContext()); if (!relation.isEmpty()) { Optional resolveField = resolveField(relation, node, context.getRelations()); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java index 61762f616..53dc17576 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java @@ -5,6 +5,7 @@ package org.opensearch.sql.ppl; +import lombok.Getter; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.expressions.AttributeReference; import org.apache.spark.sql.catalyst.expressions.Expression; @@ -39,19 +40,19 @@ public class CatalystPlanContext { /** * Catalyst relations list **/ - private List projectedFields = new ArrayList<>(); + @Getter private List projectedFields = new ArrayList<>(); /** * Catalyst relations list **/ - private List relations = new ArrayList<>(); + @Getter private List relations = new ArrayList<>(); /** * Catalyst SubqueryAlias list **/ - private List subqueryAlias = new ArrayList<>(); + @Getter private List subqueryAlias = new ArrayList<>(); /** * Catalyst evolving logical plan **/ - private Stack planBranches = new Stack<>(); + @Getter private Stack planBranches = new Stack<>(); /** * The current traversal context the visitor is going threw */ @@ -60,28 +61,12 @@ public class CatalystPlanContext { /** * NamedExpression contextual parameters **/ - private final Stack namedParseExpressions = new Stack<>(); + @Getter private final Stack namedParseExpressions = new Stack<>(); /** * Grouping NamedExpression contextual parameters **/ - private final Stack groupingParseExpressions = new Stack<>(); - - public Stack getPlanBranches() { - return planBranches; - } - - public List getRelations() { - return relations; - } - - public List getSubqueryAlias() { - return subqueryAlias; - } - - public List getProjectedFields() { - return projectedFields; - } + @Getter private final Stack groupingParseExpressions = new Stack<>(); public LogicalPlan getPlan() { if (this.planBranches.isEmpty()) return null; @@ -101,10 +86,6 @@ public Stack traversalContext() { return planTraversalContext; } - public Stack getNamedParseExpressions() { - return namedParseExpressions; - } - public void setNamedParseExpressions(Stack namedParseExpressions) { this.namedParseExpressions.clear(); this.namedParseExpressions.addAll(namedParseExpressions); @@ -114,10 +95,6 @@ public Optional popNamedParseExpressions() { return namedParseExpressions.isEmpty() ? 
Optional.empty() : Optional.of(namedParseExpressions.pop()); } - public Stack getGroupingParseExpressions() { - return groupingParseExpressions; - } - /** * define new field * @@ -154,13 +131,13 @@ public LogicalPlan withProjectedFields(List projectedField this.projectedFields.addAll(projectedFields); return getPlan(); } - + public LogicalPlan applyBranches(List> plans) { plans.forEach(plan -> with(plan.apply(planBranches.get(0)))); planBranches.remove(0); return getPlan(); - } - + } + /** * append plan with evolving plans branches * @@ -288,4 +265,21 @@ public static Optional findRelation(LogicalPlan plan) { return Optional.empty(); } + @Getter private boolean isResolvingJoinCondition = false; + + /** + * Resolve the join condition with the given function. + * A flag will be set to true ahead expression resolving, then false after resolving. + * @param expr + * @param transformFunction + * @return + */ + public Expression resolveJoinCondition( + UnresolvedExpression expr, + BiFunction transformFunction) { + isResolvingJoinCondition = true; + Expression result = transformFunction.apply(expr, this); + isResolvingJoinCondition = false; + return result; + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 669459fba..a43378480 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -273,7 +273,8 @@ public LogicalPlan visitJoin(Join node, CatalystPlanContext context) { node.getChild().get(0).accept(this, context); return context.apply(left -> { LogicalPlan right = node.getRight().accept(this, context); - Optional joinCondition = node.getJoinCondition().map(c -> visitExpression(c, context)); + Optional joinCondition = node.getJoinCondition() + .map(c -> expressionAnalyzer.analyzeJoinCondition(c, context)); context.retainAllNamedParseExpressions(p -> p); context.retainAllPlans(p -> p); return join(left, right, node.getJoinType(), joinCondition, node.getJoinHint()); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 4e6b1f131..36a34cd06 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -155,14 +155,25 @@ public UnresolvedPlan visitJoinCommand(OpenSearchPPLParser.JoinCommandContext ct joinType = Join.JoinType.CROSS; } Join.JoinHint joinHint = getJoinHint(ctx.joinHintList()); - String leftAlias = ctx.sideAlias().leftAlias.getText(); - String rightAlias = ctx.sideAlias().rightAlias.getText(); + Optional leftAlias = ctx.sideAlias().leftAlias != null ? Optional.of(ctx.sideAlias().leftAlias.getText()) : Optional.empty(); + Optional rightAlias = Optional.empty(); if (ctx.tableOrSubqueryClause().alias != null) { - // left and right aliases are required in join syntax. 
Setting by 'AS' causes ambiguous - throw new SyntaxCheckException("'AS' is not allowed in right subquery, use right= instead"); + rightAlias = Optional.of(ctx.tableOrSubqueryClause().alias.getText()); } + if (ctx.sideAlias().rightAlias != null) { + rightAlias = Optional.of(ctx.sideAlias().rightAlias.getText()); + } + UnresolvedPlan rightRelation = visit(ctx.tableOrSubqueryClause()); - UnresolvedPlan right = new SubqueryAlias(rightAlias, rightRelation); + // Add a SubqueryAlias to the right plan when the right alias is present and no duplicated alias existing in right. + UnresolvedPlan right; + if (rightAlias.isEmpty() || + (rightRelation instanceof SubqueryAlias && + rightAlias.get().equals(((SubqueryAlias) rightRelation).getAlias()))) { + right = rightRelation; + } else { + right = new SubqueryAlias(rightAlias.get(), rightRelation); + } Optional joinCondition = ctx.joinCriteria() == null ? Optional.empty() : Optional.of(expressionBuilder.visitJoinCriteria(ctx.joinCriteria())); @@ -370,7 +381,7 @@ public UnresolvedPlan visitPatternsCommand(OpenSearchPPLParser.PatternsCommandCo /** Lookup command */ @Override public UnresolvedPlan visitLookupCommand(OpenSearchPPLParser.LookupCommandContext ctx) { - Relation lookupRelation = new Relation(this.internalVisitExpression(ctx.tableSource())); + Relation lookupRelation = new Relation(Collections.singletonList(this.internalVisitExpression(ctx.tableSource()))); Lookup.OutputStrategy strategy = ctx.APPEND() != null ? Lookup.OutputStrategy.APPEND : Lookup.OutputStrategy.REPLACE; java.util.Map lookupMappingList = buildLookupPair(ctx.lookupMappingList().lookupPair()); @@ -509,9 +520,8 @@ public UnresolvedPlan visitTableOrSubqueryClause(OpenSearchPPLParser.TableOrSubq @Override public UnresolvedPlan visitTableSourceClause(OpenSearchPPLParser.TableSourceClauseContext ctx) { - return ctx.alias == null - ? new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())) - : new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList()), ctx.alias.getText()); + Relation relation = new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())); + return ctx.alias != null ? 
new SubqueryAlias(ctx.alias.getText(), relation) : relation; } @Override diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala index 2a569dbdf..50ef985d6 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala @@ -13,7 +13,7 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, GreaterThan, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, EqualTo, GreaterThan, Literal, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command.DescribeTableCommand @@ -292,6 +292,44 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("Search multiple tables - with table alias") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """ + | source=table1, table2, table3 as t + | | where t.name = 'Molly' + |""".stripMargin), + context) + + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val star = UnresolvedStar(None) + val plan1 = Project( + Seq(star), + Filter( + EqualTo(UnresolvedAttribute("t.name"), Literal("Molly")), + SubqueryAlias("t", table1))) + val plan2 = Project( + Seq(star), + Filter( + EqualTo(UnresolvedAttribute("t.name"), Literal("Molly")), + SubqueryAlias("t", table2))) + val plan3 = Project( + Seq(star), + Filter( + EqualTo(UnresolvedAttribute("t.name"), Literal("Molly")), + SubqueryAlias("t", table3))) + + val expectedPlan = + Union(Seq(plan1, plan2, plan3), byName = true, allowMissingCol = true) + + comparePlans(expectedPlan, logPlan, false) + } + test("test fields + field list") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala index 3ceff7735..f4ed397e3 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala @@ -271,9 +271,9 @@ class PPLLogicalPlanJoinTranslatorTestSuite pplParser, s""" | source = $testTable1 - | | inner JOIN left = l,right = r ON l.id = r.id $testTable2 - | | left JOIN left = l,right = r ON l.name = r.name $testTable3 - | | cross JOIN left = l,right = r $testTable4 + | | inner JOIN left = l right = r ON l.id = r.id $testTable2 + | | left JOIN left = l right = r ON l.name = r.name $testTable3 + | | cross JOIN left = l right = r $testTable4 | """.stripMargin) val logicalPlan = 
planTransformer.visit(logPlan, context) val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) @@ -443,17 +443,17 @@ class PPLLogicalPlanJoinTranslatorTestSuite s""" | source = $testTable1 | | head 10 - | | inner JOIN left = l,right = r ON l.id = r.id + | | inner JOIN left = l right = r ON l.id = r.id | [ | source = $testTable2 | | where id > 10 | ] - | | left JOIN left = l,right = r ON l.name = r.name + | | left JOIN left = l right = r ON l.name = r.name | [ | source = $testTable3 | | fields id | ] - | | cross JOIN left = l,right = r + | | cross JOIN left = l right = r | [ | source = $testTable4 | | sort id @@ -565,4 +565,284 @@ class PPLLogicalPlanJoinTranslatorTestSuite val expectedPlan = Project(Seq(UnresolvedStar(None)), sort) comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + + test("test multiple joins with table alias") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 as t1 + | | JOIN ON t1.id = t2.id + | [ + | source = table2 as t2 + | ] + | | JOIN ON t2.id = t3.id + | [ + | source = table3 as t3 + | ] + | | JOIN ON t3.id = t4.id + | [ + | source = table4 as t4 + | ] + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.id"), UnresolvedAttribute("t2.id"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t2.id"), UnresolvedAttribute("t3.id"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table4), + Inner, + Some(EqualTo(UnresolvedAttribute("t3.id"), UnresolvedAttribute("t4.id"))), + JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with table and subquery alias") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 as t1 + | | JOIN left = l right = r ON t1.id = t2.id + | [ + | source = table2 as t2 + | ] + | | JOIN left = l right = r ON t2.id = t3.id + | [ + | source = table3 as t3 + | ] + | | JOIN left = l right = r ON t3.id = t4.id + | [ + | source = table4 as t4 + | ] + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + SubqueryAlias("l", SubqueryAlias("t1", table1)), + SubqueryAlias("r", SubqueryAlias("t2", table2)), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.id"), UnresolvedAttribute("t2.id"))), + JoinHint.NONE) + val joinPlan2 = Join( + SubqueryAlias("l", joinPlan1), + SubqueryAlias("r", SubqueryAlias("t3", table3)), + Inner, + Some(EqualTo(UnresolvedAttribute("t2.id"), UnresolvedAttribute("t3.id"))), + JoinHint.NONE) + val joinPlan3 = Join( + SubqueryAlias("l", joinPlan2), + SubqueryAlias("r", SubqueryAlias("t4", table4)), + Inner, + Some(EqualTo(UnresolvedAttribute("t3.id"), UnresolvedAttribute("t4.id"))), + JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), 
joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins without table aliases") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 + | | JOIN ON table1.id = table2.id table2 + | | JOIN ON table1.id = table3.id table3 + | | JOIN ON table2.id = table4.id table4 + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + table1, + table2, + Inner, + Some(EqualTo(UnresolvedAttribute("table1.id"), UnresolvedAttribute("table2.id"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + table3, + Inner, + Some(EqualTo(UnresolvedAttribute("table1.id"), UnresolvedAttribute("table3.id"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + table4, + Inner, + Some(EqualTo(UnresolvedAttribute("table2.id"), UnresolvedAttribute("table4.id"))), + JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with part subquery aliases") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name table2 + | | JOIN right = t3 ON t1.name = t3.name table3 + | | JOIN right = t4 ON t2.name = t4.name table4 + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table4), + Inner, + Some(EqualTo(UnresolvedAttribute("t2.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with self join 1") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN right = t4 ON t1.name = t4.name $testTable1 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + 
Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with self join 2") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN ON t1.name = t4.name + | [ + | source = $testTable1 + | ] as t4 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test side alias will override the subquery alias") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 as ttt ] as tt + | | fields t1.name, t2.name + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", SubqueryAlias("tt", SubqueryAlias("ttt", table2))), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val expectedPlan = + Project(Seq(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name")), joinPlan1) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } } From 48be5cc1224ea39a38dc55ba465577d24dc99d7f Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Thu, 7 Nov 2024 10:15:04 +0800 Subject: [PATCH 3/4] Add TPC-H PPL query suite (#830) * Add TPC-H PPL query suite Signed-off-by: Lantao Jin * fix failure of loading resources Signed-off-by: Lantao Jin * fix data_add() Signed-off-by: Lantao Jin * enable q21 and add docs Signed-off-by: Lantao Jin --------- Signed-off-by: Lantao Jin --- build.sbt | 3 +- docs/ppl-lang/README.md | 4 + docs/ppl-lang/ppl-tpch.md | 102 ++++++++++ 
.../src/integration/resources/tpch/q1.ppl | 35 ++++ .../src/integration/resources/tpch/q10.ppl | 45 +++++ .../src/integration/resources/tpch/q11.ppl | 45 +++++ .../src/integration/resources/tpch/q12.ppl | 42 +++++ .../src/integration/resources/tpch/q13.ppl | 31 +++ .../src/integration/resources/tpch/q14.ppl | 25 +++ .../src/integration/resources/tpch/q15.ppl | 52 +++++ .../src/integration/resources/tpch/q16.ppl | 45 +++++ .../src/integration/resources/tpch/q17.ppl | 34 ++++ .../src/integration/resources/tpch/q18.ppl | 48 +++++ .../src/integration/resources/tpch/q19.ppl | 61 ++++++ .../src/integration/resources/tpch/q2.ppl | 62 ++++++ .../src/integration/resources/tpch/q20.ppl | 62 ++++++ .../src/integration/resources/tpch/q21.ppl | 64 +++++++ .../src/integration/resources/tpch/q22.ppl | 58 ++++++ .../src/integration/resources/tpch/q3.ppl | 33 ++++ .../src/integration/resources/tpch/q4.ppl | 33 ++++ .../src/integration/resources/tpch/q5.ppl | 36 ++++ .../src/integration/resources/tpch/q6.ppl | 18 ++ .../src/integration/resources/tpch/q7.ppl | 56 ++++++ .../src/integration/resources/tpch/q8.ppl | 60 ++++++ .../src/integration/resources/tpch/q9.ppl | 50 +++++ .../flint/spark/ppl/tpch/TPCHQueryBase.scala | 177 ++++++++++++++++++ .../spark/ppl/tpch/TPCHQueryITSuite.scala | 43 +++++ 27 files changed, 1323 insertions(+), 1 deletion(-) create mode 100644 docs/ppl-lang/ppl-tpch.md create mode 100644 integ-test/src/integration/resources/tpch/q1.ppl create mode 100644 integ-test/src/integration/resources/tpch/q10.ppl create mode 100644 integ-test/src/integration/resources/tpch/q11.ppl create mode 100644 integ-test/src/integration/resources/tpch/q12.ppl create mode 100644 integ-test/src/integration/resources/tpch/q13.ppl create mode 100644 integ-test/src/integration/resources/tpch/q14.ppl create mode 100644 integ-test/src/integration/resources/tpch/q15.ppl create mode 100644 integ-test/src/integration/resources/tpch/q16.ppl create mode 100644 integ-test/src/integration/resources/tpch/q17.ppl create mode 100644 integ-test/src/integration/resources/tpch/q18.ppl create mode 100644 integ-test/src/integration/resources/tpch/q19.ppl create mode 100644 integ-test/src/integration/resources/tpch/q2.ppl create mode 100644 integ-test/src/integration/resources/tpch/q20.ppl create mode 100644 integ-test/src/integration/resources/tpch/q21.ppl create mode 100644 integ-test/src/integration/resources/tpch/q22.ppl create mode 100644 integ-test/src/integration/resources/tpch/q3.ppl create mode 100644 integ-test/src/integration/resources/tpch/q4.ppl create mode 100644 integ-test/src/integration/resources/tpch/q5.ppl create mode 100644 integ-test/src/integration/resources/tpch/q6.ppl create mode 100644 integ-test/src/integration/resources/tpch/q7.ppl create mode 100644 integ-test/src/integration/resources/tpch/q8.ppl create mode 100644 integ-test/src/integration/resources/tpch/q9.ppl create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryBase.scala create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryITSuite.scala diff --git a/build.sbt b/build.sbt index 66b06d6be..1300f68a0 100644 --- a/build.sbt +++ b/build.sbt @@ -238,7 +238,8 @@ lazy val integtest = (project in file("integ-test")) inConfig(IntegrationTest)(Defaults.testSettings ++ Seq( IntegrationTest / javaSource := baseDirectory.value / "src/integration/java", IntegrationTest / scalaSource := baseDirectory.value / "src/integration/scala", - IntegrationTest / parallelExecution := 
false, + IntegrationTest / resourceDirectory := baseDirectory.value / "src/integration/resources", + IntegrationTest / parallelExecution := false, IntegrationTest / fork := true, )), inConfig(AwsIntegrationTest)(Defaults.testSettings ++ Seq( diff --git a/docs/ppl-lang/README.md b/docs/ppl-lang/README.md index d78f4c030..ef186e5f2 100644 --- a/docs/ppl-lang/README.md +++ b/docs/ppl-lang/README.md @@ -104,6 +104,10 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). ### Example PPL Queries See samples of [PPL queries](PPL-Example-Commands.md) +--- +### TPC-H PPL Query Rewriting +See samples of [TPC-H PPL query rewriting](ppl-tpch.md) + --- ### Planned PPL Commands diff --git a/docs/ppl-lang/ppl-tpch.md b/docs/ppl-lang/ppl-tpch.md new file mode 100644 index 000000000..ef5846ce0 --- /dev/null +++ b/docs/ppl-lang/ppl-tpch.md @@ -0,0 +1,102 @@ +## TPC-H Benchmark + +TPC-H is a decision support benchmark designed to evaluate the performance of database systems in handling complex business-oriented queries and concurrent data modifications. The benchmark utilizes a dataset that is broadly representative of various industries, making it widely applicable. TPC-H simulates a decision support environment where large volumes of data are analyzed, intricate queries are executed, and critical business questions are answered. + +### Test PPL Queries + +TPC-H 22 test query statements: [TPCH-Query-PPL](https://github.com/opensearch-project/opensearch-spark/blob/main/integ-test/src/integration/resources/tpch) + +### Data Preparation + +#### Option 1 - from PyPi + +``` +# Create the virtual environment +python3 -m venv .venv + +# Activate the virtual environment +. .venv/bin/activate + +pip install tpch-datagen +``` + +#### Option 2 - from source + +``` +git clone https://github.com/gizmodata/tpch-datagen + +cd tpch-datagen + +# Create the virtual environment +python3 -m venv .venv + +# Activate the virtual environment +. .venv/bin/activate + +# Upgrade pip, setuptools, and wheel +pip install --upgrade pip setuptools wheel + +# Install TPC-H Datagen - in editable mode with client and dev dependencies +pip install --editable .[dev] +``` + +#### Usage + +Here are the options for the tpch-datagen command: +``` +tpch-datagen --help +Usage: tpch-datagen [OPTIONS] + +Options: + --version / --no-version Prints the TPC-H Datagen package version and + exits. [required] + --scale-factor INTEGER The TPC-H Scale Factor to use for data + generation. + --data-directory TEXT The target output data directory to put the + files into [default: data; required] + --work-directory TEXT The work directory to use for data + generation. [default: /tmp; required] + --overwrite / --no-overwrite Can we overwrite the target directory if it + already exists... [default: no-overwrite; + required] + --num-chunks INTEGER The number of chunks that will be generated + - more chunks equals smaller memory + requirements, but more files generated. + [default: 10; required] + --num-processes INTEGER The maximum number of processes for the + multi-processing pool to use for data + generation. [default: 10; required] + --duckdb-threads INTEGER The number of DuckDB threads to use for data + generation (within each job process). + [default: 1; required] + --per-thread-output / --no-per-thread-output + Controls whether to write the output to a + single file or multiple files (for each + process). 
[default: per-thread-output;
+                                  required]
+  --compression-method [none|snappy|gzip|zstd]
+                                  The compression method to use for the
+                                  parquet files generated.  [default: zstd;
+                                  required]
+  --file-size-bytes TEXT          The target file size for the parquet files
+                                  generated.  [default: 100m; required]
+  --help                          Show this message and exit.
+```
+
+### Generate 1 GB of data with zstd compression (the default)
+
+```
+tpch-datagen --scale-factor 1
+```
+
+### Generate 10 GB of data with snappy compression
+
+```
+tpch-datagen --scale-factor 10 --compression-method snappy
+```
+
+### Query Test
+
+All TPC-H PPL queries are located in the `integ-test/src/integration/resources/tpch` folder.
+
+To test all queries, run `org.opensearch.flint.spark.ppl.tpch.TPCHQueryITSuite`.
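+
+For example, a single suite can be run with sbt from the repository root (an
+illustrative invocation, assuming the `integtest` project and `IntegrationTest`
+configuration defined in `build.sbt`; adjust the names to your build):
+
+```
+sbt "integtest/IntegrationTest/testOnly org.opensearch.flint.spark.ppl.tpch.TPCHQueryITSuite"
+```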
\ No newline at end of file
diff --git a/integ-test/src/integration/resources/tpch/q1.ppl b/integ-test/src/integration/resources/tpch/q1.ppl
new file mode 100644
index 000000000..885ce35c6
--- /dev/null
+++ b/integ-test/src/integration/resources/tpch/q1.ppl
@@ -0,0 +1,35 @@
+/*
+select
+    l_returnflag,
+    l_linestatus,
+    sum(l_quantity) as sum_qty,
+    sum(l_extendedprice) as sum_base_price,
+    sum(l_extendedprice * (1 - l_discount)) as sum_disc_price,
+    sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge,
+    avg(l_quantity) as avg_qty,
+    avg(l_extendedprice) as avg_price,
+    avg(l_discount) as avg_disc,
+    count(*) as count_order
+from
+    lineitem
+where
+    l_shipdate <= date '1998-12-01' - interval '90' day
+group by
+    l_returnflag,
+    l_linestatus
+order by
+    l_returnflag,
+    l_linestatus
+*/
+
+source = lineitem
+| where l_shipdate <= subdate(date('1998-12-01'), 90)
+| stats sum(l_quantity) as sum_qty,
+        sum(l_extendedprice) as sum_base_price,
+        sum(l_extendedprice * (1 - l_discount)) as sum_disc_price,
+        sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge,
+        avg(l_quantity) as avg_qty, avg(l_extendedprice) as avg_price,
+        avg(l_discount) as avg_disc,
+        count() as count_order
+  by l_returnflag, l_linestatus
+| sort l_returnflag, l_linestatus
\ No newline at end of file
diff --git a/integ-test/src/integration/resources/tpch/q10.ppl b/integ-test/src/integration/resources/tpch/q10.ppl
new file mode 100644
index 000000000..10a050785
--- /dev/null
+++ b/integ-test/src/integration/resources/tpch/q10.ppl
@@ -0,0 +1,45 @@
+/*
+select
+    c_custkey,
+    c_name,
+    sum(l_extendedprice * (1 - l_discount)) as revenue,
+    c_acctbal,
+    n_name,
+    c_address,
+    c_phone,
+    c_comment
+from
+    customer,
+    orders,
+    lineitem,
+    nation
+where
+    c_custkey = o_custkey
+    and l_orderkey = o_orderkey
+    and o_orderdate >= date '1993-10-01'
+    and o_orderdate < date '1993-10-01' + interval '3' month
+    and l_returnflag = 'R'
+    and c_nationkey = n_nationkey
+group by
+    c_custkey,
+    c_name,
+    c_acctbal,
+    c_phone,
+    n_name,
+    c_address,
+    c_comment
+order by
+    revenue desc
+limit 20
+*/
+
+source = customer
+| join ON c_custkey = o_custkey orders
+| join ON l_orderkey = o_orderkey lineitem
+| join ON c_nationkey = n_nationkey nation
+| where o_orderdate >= date('1993-10-01')
+  AND o_orderdate < date_add(date('1993-10-01'), interval 3 month)
+  AND l_returnflag = 'R'
+| stats sum(l_extendedprice * (1 - l_discount)) as revenue by c_custkey, c_name, c_acctbal, c_phone, n_name, c_address, c_comment
+| sort - revenue
+| head 20
\ No newline at end of file
diff --git a/integ-test/src/integration/resources/tpch/q11.ppl b/integ-test/src/integration/resources/tpch/q11.ppl
new file mode 100644
index 000000000..3a55d986e
--- /dev/null
+++ b/integ-test/src/integration/resources/tpch/q11.ppl
@@ -0,0 +1,45 @@
+/*
+select
+    ps_partkey,
+    sum(ps_supplycost * ps_availqty) as value
+from
+    partsupp,
+    supplier,
+    nation
+where
+    ps_suppkey = s_suppkey
+    and s_nationkey = n_nationkey
+    and n_name = 'GERMANY'
+group by
+    ps_partkey having
+        sum(ps_supplycost * ps_availqty) > (
+            select
+                sum(ps_supplycost * ps_availqty) * 0.0001000000
+            from
+                partsupp,
+                supplier,
+                nation
+            where
+                ps_suppkey = s_suppkey
+                and s_nationkey = n_nationkey
+                and n_name = 'GERMANY'
+        )
+order by
+    value desc
+*/
+
+source = partsupp
+| join ON ps_suppkey = s_suppkey supplier
+| join ON s_nationkey = n_nationkey nation
+| where n_name = 'GERMANY'
+| stats sum(ps_supplycost * ps_availqty) as value by ps_partkey
+| where value > [
+    source = partsupp
+    | join ON ps_suppkey = s_suppkey supplier
+    | join ON s_nationkey = n_nationkey nation
+    | where n_name = 'GERMANY'
+    | stats sum(ps_supplycost * ps_availqty) as check
+    | eval threshold = check * 0.0001000000
+    | fields threshold
+  ]
+| sort - value
\ No newline at end of file
diff --git a/integ-test/src/integration/resources/tpch/q12.ppl b/integ-test/src/integration/resources/tpch/q12.ppl
new file mode 100644
index 000000000..79672d844
--- /dev/null
+++ b/integ-test/src/integration/resources/tpch/q12.ppl
@@ -0,0 +1,42 @@
+/*
+select
+    l_shipmode,
+    sum(case
+        when o_orderpriority = '1-URGENT'
+            or o_orderpriority = '2-HIGH'
+            then 1
+        else 0
+    end) as high_line_count,
+    sum(case
+        when o_orderpriority <> '1-URGENT'
+            and o_orderpriority <> '2-HIGH'
+            then 1
+        else 0
+    end) as low_line_count
+from
+    orders,
+    lineitem
+where
+    o_orderkey = l_orderkey
+    and l_shipmode in ('MAIL', 'SHIP')
+    and l_commitdate < l_receiptdate
+    and l_shipdate < l_commitdate
+    and l_receiptdate >= date '1994-01-01'
+    and l_receiptdate < date '1994-01-01' + interval '1' year
+group by
+    l_shipmode
+order by
+    l_shipmode
+*/
+
+source = orders
+| join ON o_orderkey = l_orderkey lineitem
+| where l_commitdate < l_receiptdate
+    and l_shipdate < l_commitdate
+    and l_shipmode in ('MAIL', 'SHIP')
+    and l_receiptdate >= date('1994-01-01')
+    and l_receiptdate < date_add(date('1994-01-01'), interval 1 year)
+| stats sum(case(o_orderpriority = '1-URGENT' or o_orderpriority = '2-HIGH', 1 else 0)) as high_line_count,
+        sum(case(o_orderpriority != '1-URGENT' and o_orderpriority != '2-HIGH', 1 else 0)) as low_line_count
+        by l_shipmode
+| sort l_shipmode
\ No newline at end of file
diff --git a/integ-test/src/integration/resources/tpch/q13.ppl b/integ-test/src/integration/resources/tpch/q13.ppl
new file mode 100644
index 000000000..6e77c9b0a
--- /dev/null
+++ b/integ-test/src/integration/resources/tpch/q13.ppl
@@ -0,0 +1,31 @@
+/*
+select
+    c_count,
+    count(*) as custdist
+from
+    (
+        select
+            c_custkey,
+            count(o_orderkey) as c_count
+        from
+            customer left outer join orders on
+                c_custkey = o_custkey
+                and o_comment not like '%special%requests%'
+        group by
+            c_custkey
+    ) as c_orders
+group by
+    c_count
+order by
+    custdist desc,
+    c_count desc
+*/
+
+source = [
+    source = customer
+    | left outer join ON c_custkey = o_custkey AND not like(o_comment, '%special%requests%')
+      orders
+    | stats count(o_orderkey) as c_count by c_custkey
+  ] as c_orders
+| stats count() as custdist by c_count
+| sort - custdist, - c_count
\ No newline at end of file
diff --git a/integ-test/src/integration/resources/tpch/q14.ppl b/integ-test/src/integration/resources/tpch/q14.ppl
new file mode 100644
index 000000000..553f1e549
--- /dev/null
+++
b/integ-test/src/integration/resources/tpch/q14.ppl @@ -0,0 +1,25 @@ +/* +select + 100.00 * sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue +from + lineitem, + part +where + l_partkey = p_partkey + and l_shipdate >= date '1995-09-01' + and l_shipdate < date '1995-09-01' + interval '1' month +*/ + +source = lineitem +| join ON l_partkey = p_partkey + AND l_shipdate >= date('1995-09-01') + AND l_shipdate < date_add(date('1995-09-01'), interval 1 month) + part +| stats sum(case(like(p_type, 'PROMO%'), l_extendedprice * (1 - l_discount) else 0)) as sum1, + sum(l_extendedprice * (1 - l_discount)) as sum2 +| eval promo_revenue = 100.00 * sum1 / sum2 // Stats and Eval commands can combine when issues/819 resolved +| fields promo_revenue \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q15.ppl b/integ-test/src/integration/resources/tpch/q15.ppl new file mode 100644 index 000000000..96f5ecea2 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q15.ppl @@ -0,0 +1,52 @@ +/* +with revenue0 as + (select + l_suppkey as supplier_no, + sum(l_extendedprice * (1 - l_discount)) as total_revenue + from + lineitem + where + l_shipdate >= date '1996-01-01' + and l_shipdate < date '1996-01-01' + interval '3' month + group by + l_suppkey) +select + s_suppkey, + s_name, + s_address, + s_phone, + total_revenue +from + supplier, + revenue0 +where + s_suppkey = supplier_no + and total_revenue = ( + select + max(total_revenue) + from + revenue0 + ) +order by + s_suppkey +*/ + +// CTE is unsupported in PPL +source = supplier +| join right = revenue0 ON s_suppkey = supplier_no [ + source = lineitem + | where l_shipdate >= date('1996-01-01') AND l_shipdate < date_add(date('1996-01-01'), interval 3 month) + | eval supplier_no = l_suppkey + | stats sum(l_extendedprice * (1 - l_discount)) as total_revenue by supplier_no + ] +| where total_revenue = [ + source = [ + source = lineitem + | where l_shipdate >= date('1996-01-01') AND l_shipdate < date_add(date('1996-01-01'), interval 3 month) + | eval supplier_no = l_suppkey + | stats sum(l_extendedprice * (1 - l_discount)) as total_revenue by supplier_no + ] + | stats max(total_revenue) + ] +| sort s_suppkey +| fields s_suppkey, s_name, s_address, s_phone, total_revenue diff --git a/integ-test/src/integration/resources/tpch/q16.ppl b/integ-test/src/integration/resources/tpch/q16.ppl new file mode 100644 index 000000000..4c5765f04 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q16.ppl @@ -0,0 +1,45 @@ +/* +select + p_brand, + p_type, + p_size, + count(distinct ps_suppkey) as supplier_cnt +from + partsupp, + part +where + p_partkey = ps_partkey + and p_brand <> 'Brand#45' + and p_type not like 'MEDIUM POLISHED%' + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in ( + select + s_suppkey + from + supplier + where + s_comment like '%Customer%Complaints%' + ) +group by + p_brand, + p_type, + p_size +order by + supplier_cnt desc, + p_brand, + p_type, + p_size +*/ + +source = partsupp +| join ON p_partkey = ps_partkey part +| where p_brand != 'Brand#45' + and not like(p_type, 'MEDIUM POLISHED%') + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in [ + source = supplier + | where like(s_comment, '%Customer%Complaints%') + | fields s_suppkey + ] +| stats distinct_count(ps_suppkey) as supplier_cnt by p_brand, p_type, p_size +| sort - supplier_cnt, p_brand, p_type, p_size \ No newline at end 
of file diff --git a/integ-test/src/integration/resources/tpch/q17.ppl b/integ-test/src/integration/resources/tpch/q17.ppl new file mode 100644 index 000000000..994b7ee18 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q17.ppl @@ -0,0 +1,34 @@ +/* +select + sum(l_extendedprice) / 7.0 as avg_yearly +from + lineitem, + part +where + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < ( + select + 0.2 * avg(l_quantity) + from + lineitem + where + l_partkey = p_partkey + ) +*/ + +source = lineitem +| join ON p_partkey = l_partkey part +| where p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < [ + source = lineitem + | where l_partkey = p_partkey + | stats avg(l_quantity) as avg + | eval `0.2 * avg` = 0.2 * avg // Stats and Eval commands can combine when issues/819 resolved + | fields `0.2 * avg` + ] +| stats sum(l_extendedprice) as sum +| eval avg_yearly = sum / 7.0 // Stats and Eval commands can combine when issues/819 resolved +| fields avg_yearly \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q18.ppl b/integ-test/src/integration/resources/tpch/q18.ppl new file mode 100644 index 000000000..1dab3d473 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q18.ppl @@ -0,0 +1,48 @@ +/* +select + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice, + sum(l_quantity) +from + customer, + orders, + lineitem +where + o_orderkey in ( + select + l_orderkey + from + lineitem + group by + l_orderkey having + sum(l_quantity) > 300 + ) + and c_custkey = o_custkey + and o_orderkey = l_orderkey +group by + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice +order by + o_totalprice desc, + o_orderdate +limit 100 +*/ + +source = customer +| join ON c_custkey = o_custkey orders +| join ON o_orderkey = l_orderkey lineitem +| where o_orderkey in [ + source = lineitem + | stats sum(l_quantity) as sum by l_orderkey + | where sum > 300 + | fields l_orderkey + ] +| stats sum(l_quantity) by c_name, c_custkey, o_orderkey, o_orderdate, o_totalprice +| sort - o_totalprice, o_orderdate +| head 100 \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q19.ppl b/integ-test/src/integration/resources/tpch/q19.ppl new file mode 100644 index 000000000..630d63bcc --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q19.ppl @@ -0,0 +1,61 @@ +/* +select + sum(l_extendedprice* (1 - l_discount)) as revenue +from + lineitem, + part +where + ( + p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 and l_quantity <= 1 + 10 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 10 + 10 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) +*/ + +source = lineitem +| join ON p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 
1 and l_quantity <= 1 + 10 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + OR p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 10 + 10 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + OR p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + part \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q2.ppl b/integ-test/src/integration/resources/tpch/q2.ppl new file mode 100644 index 000000000..aa95d9d14 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q2.ppl @@ -0,0 +1,62 @@ +/* +select + s_acctbal, + s_name, + n_name, + p_partkey, + p_mfgr, + s_address, + s_phone, + s_comment +from + part, + supplier, + partsupp, + nation, + region +where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and p_size = 15 + and p_type like '%BRASS' + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + and ps_supplycost = ( + select + min(ps_supplycost) + from + partsupp, + supplier, + nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + ) +order by + s_acctbal desc, + n_name, + s_name, + p_partkey +limit 100 +*/ + +source = part +| join ON p_partkey = ps_partkey partsupp +| join ON s_suppkey = ps_suppkey supplier +| join ON s_nationkey = n_nationkey nation +| join ON n_regionkey = r_regionkey region +| where p_size = 15 AND like(p_type, '%BRASS') AND r_name = 'EUROPE' AND ps_supplycost = [ + source = partsupp + | join ON s_suppkey = ps_suppkey supplier + | join ON s_nationkey = n_nationkey nation + | join ON n_regionkey = r_regionkey region + | where r_name = 'EUROPE' + | stats MIN(ps_supplycost) + ] +| sort - s_acctbal, n_name, s_name, p_partkey +| head 100 \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q20.ppl b/integ-test/src/integration/resources/tpch/q20.ppl new file mode 100644 index 000000000..08bd21277 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q20.ppl @@ -0,0 +1,62 @@ +/* +select + s_name, + s_address +from + supplier, + nation +where + s_suppkey in ( + select + ps_suppkey + from + partsupp + where + ps_partkey in ( + select + p_partkey + from + part + where + p_name like 'forest%' + ) + and ps_availqty > ( + select + 0.5 * sum(l_quantity) + from + lineitem + where + l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + ) + ) + and s_nationkey = n_nationkey + and n_name = 'CANADA' +order by + s_name +*/ + +source = supplier +| join ON s_nationkey = n_nationkey nation +| where n_name = 'CANADA' + and s_suppkey in [ + source = partsupp + | where ps_partkey in [ + source = part + | where like(p_name, 'forest%') + | fields p_partkey + ] + and ps_availqty > [ + source = lineitem + | where l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date('1994-01-01') + and l_shipdate < date_add(date('1994-01-01'), interval 1 year) + | stats sum(l_quantity) as sum_l_quantity + | eval 
half_sum_l_quantity = 0.5 * sum_l_quantity // Stats and Eval commands can combine when issues/819 resolved + | fields half_sum_l_quantity + ] + | fields ps_suppkey + ] \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q21.ppl b/integ-test/src/integration/resources/tpch/q21.ppl new file mode 100644 index 000000000..0eb7149f6 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q21.ppl @@ -0,0 +1,64 @@ +/* +select + s_name, + count(*) as numwait +from + supplier, + lineitem l1, + orders, + nation +where + s_suppkey = l1.l_suppkey + and o_orderkey = l1.l_orderkey + and o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists ( + select + * + from + lineitem l2 + where + l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey <> l1.l_suppkey + ) + and not exists ( + select + * + from + lineitem l3 + where + l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey <> l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ) + and s_nationkey = n_nationkey + and n_name = 'SAUDI ARABIA' +group by + s_name +order by + numwait desc, + s_name +limit 100 +*/ + +source = supplier +| join ON s_suppkey = l1.l_suppkey lineitem as l1 +| join ON o_orderkey = l1.l_orderkey orders +| join ON s_nationkey = n_nationkey nation +| where o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists [ + source = lineitem as l2 + | where l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey != l1.l_suppkey + ] + and not exists [ + source = lineitem as l3 + | where l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey != l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ] + and n_name = 'SAUDI ARABIA' +| stats count() as numwait by s_name +| sort - numwait, s_name +| head 100 \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q22.ppl b/integ-test/src/integration/resources/tpch/q22.ppl new file mode 100644 index 000000000..811308cb0 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q22.ppl @@ -0,0 +1,58 @@ +/* +select + cntrycode, + count(*) as numcust, + sum(c_acctbal) as totacctbal +from + ( + select + substring(c_phone, 1, 2) as cntrycode, + c_acctbal + from + customer + where + substring(c_phone, 1, 2) in + ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > ( + select + avg(c_acctbal) + from + customer + where + c_acctbal > 0.00 + and substring(c_phone, 1, 2) in + ('13', '31', '23', '29', '30', '18', '17') + ) + and not exists ( + select + * + from + orders + where + o_custkey = c_custkey + ) + ) as custsale +group by + cntrycode +order by + cntrycode +*/ + +source = [ + source = customer + | where substring(c_phone, 1, 2) in ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > [ + source = customer + | where c_acctbal > 0.00 + and substring(c_phone, 1, 2) in ('13', '31', '23', '29', '30', '18', '17') + | stats avg(c_acctbal) + ] + and not exists [ + source = orders + | where o_custkey = c_custkey + ] + | eval cntrycode = substring(c_phone, 1, 2) + | fields cntrycode, c_acctbal + ] as custsale +| stats count() as numcust, sum(c_acctbal) as totacctbal by cntrycode +| sort cntrycode \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q3.ppl b/integ-test/src/integration/resources/tpch/q3.ppl new file mode 100644 index 000000000..0ece358ab --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q3.ppl @@ -0,0 +1,33 @@ +/* +select + l_orderkey, + sum(l_extendedprice * (1 - l_discount)) as revenue, + o_orderdate, + o_shippriority +from + customer, + orders, + lineitem 
+where + c_mktsegment = 'BUILDING' + and c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate < date '1995-03-15' + and l_shipdate > date '1995-03-15' +group by + l_orderkey, + o_orderdate, + o_shippriority +order by + revenue desc, + o_orderdate +limit 10 +*/ + +source = customer +| join ON c_custkey = o_custkey orders +| join ON l_orderkey = o_orderkey lineitem +| where c_mktsegment = 'BUILDING' AND o_orderdate < date('1995-03-15') AND l_shipdate > date('1995-03-15') +| stats sum(l_extendedprice * (1 - l_discount)) as revenue by l_orderkey, o_orderdate, o_shippriority +| sort - revenue, o_orderdate +| head 10 \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q4.ppl b/integ-test/src/integration/resources/tpch/q4.ppl new file mode 100644 index 000000000..cc01bda7d --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q4.ppl @@ -0,0 +1,33 @@ +/* +select + o_orderpriority, + count(*) as order_count +from + orders +where + o_orderdate >= date '1993-07-01' + and o_orderdate < date '1993-07-01' + interval '3' month + and exists ( + select + * + from + lineitem + where + l_orderkey = o_orderkey + and l_commitdate < l_receiptdate + ) +group by + o_orderpriority +order by + o_orderpriority +*/ + +source = orders +| where o_orderdate >= date('1993-07-01') + and o_orderdate < date_add(date('1993-07-01'), interval 3 month) + and exists [ + source = lineitem + | where l_orderkey = o_orderkey and l_commitdate < l_receiptdate + ] +| stats count() as order_count by o_orderpriority +| sort o_orderpriority \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q5.ppl b/integ-test/src/integration/resources/tpch/q5.ppl new file mode 100644 index 000000000..4761b0365 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q5.ppl @@ -0,0 +1,36 @@ +/* +select + n_name, + sum(l_extendedprice * (1 - l_discount)) as revenue +from + customer, + orders, + lineitem, + supplier, + nation, + region +where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and l_suppkey = s_suppkey + and c_nationkey = s_nationkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'ASIA' + and o_orderdate >= date '1994-01-01' + and o_orderdate < date '1994-01-01' + interval '1' year +group by + n_name +order by + revenue desc +*/ + +source = customer +| join ON c_custkey = o_custkey orders +| join ON l_orderkey = o_orderkey lineitem +| join ON l_suppkey = s_suppkey AND c_nationkey = s_nationkey supplier +| join ON s_nationkey = n_nationkey nation +| join ON n_regionkey = r_regionkey region +| where r_name = 'ASIA' AND o_orderdate >= date('1994-01-01') AND o_orderdate < date_add(date('1994-01-01'), interval 1 year) +| stats sum(l_extendedprice * (1 - l_discount)) as revenue by n_name +| sort - revenue \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q6.ppl b/integ-test/src/integration/resources/tpch/q6.ppl new file mode 100644 index 000000000..6a77877c3 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q6.ppl @@ -0,0 +1,18 @@ +/* +select + sum(l_extendedprice * l_discount) as revenue +from + lineitem +where + l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + and l_discount between .06 - 0.01 and .06 + 0.01 + and l_quantity < 24 +*/ + +source = lineitem +| where l_shipdate >= date('1994-01-01') + and l_shipdate < adddate(date('1994-01-01'), 365) + and l_discount between .06 - 0.01 and .06 + 0.01 + and l_quantity < 24 +| stats 
sum(l_extendedprice * l_discount) as revenue
\ No newline at end of file
diff --git a/integ-test/src/integration/resources/tpch/q7.ppl b/integ-test/src/integration/resources/tpch/q7.ppl
new file mode 100644
index 000000000..ceda602b3
--- /dev/null
+++ b/integ-test/src/integration/resources/tpch/q7.ppl
@@ -0,0 +1,56 @@
+/*
+select
+    supp_nation,
+    cust_nation,
+    l_year,
+    sum(volume) as revenue
+from
+    (
+        select
+            n1.n_name as supp_nation,
+            n2.n_name as cust_nation,
+            year(l_shipdate) as l_year,
+            l_extendedprice * (1 - l_discount) as volume
+        from
+            supplier,
+            lineitem,
+            orders,
+            customer,
+            nation n1,
+            nation n2
+        where
+            s_suppkey = l_suppkey
+            and o_orderkey = l_orderkey
+            and c_custkey = o_custkey
+            and s_nationkey = n1.n_nationkey
+            and c_nationkey = n2.n_nationkey
+            and (
+                (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY')
+                or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')
+            )
+            and l_shipdate between date '1995-01-01' and date '1996-12-31'
+    ) as shipping
+group by
+    supp_nation,
+    cust_nation,
+    l_year
+order by
+    supp_nation,
+    cust_nation,
+    l_year
+*/
+
+source = [
+    source = supplier
+    | join ON s_suppkey = l_suppkey lineitem
+    | join ON o_orderkey = l_orderkey orders
+    | join ON c_custkey = o_custkey customer
+    | join ON s_nationkey = n1.n_nationkey nation as n1
+    | join ON c_nationkey = n2.n_nationkey nation as n2
+    | where l_shipdate between date('1995-01-01') and date('1996-12-31')
+      and ((n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE'))
+    | eval supp_nation = n1.n_name, cust_nation = n2.n_name, l_year = year(l_shipdate), volume = l_extendedprice * (1 - l_discount)
+    | fields supp_nation, cust_nation, l_year, volume
+  ] as shipping
+| stats sum(volume) as revenue by supp_nation, cust_nation, l_year
+| sort supp_nation, cust_nation, l_year
\ No newline at end of file
diff --git a/integ-test/src/integration/resources/tpch/q8.ppl b/integ-test/src/integration/resources/tpch/q8.ppl
new file mode 100644
index 000000000..a73c7f7c3
--- /dev/null
+++ b/integ-test/src/integration/resources/tpch/q8.ppl
@@ -0,0 +1,60 @@
+/*
+select
+    o_year,
+    sum(case
+        when nation = 'BRAZIL' then volume
+        else 0
+    end) / sum(volume) as mkt_share
+from
+    (
+        select
+            year(o_orderdate) as o_year,
+            l_extendedprice * (1 - l_discount) as volume,
+            n2.n_name as nation
+        from
+            part,
+            supplier,
+            lineitem,
+            orders,
+            customer,
+            nation n1,
+            nation n2,
+            region
+        where
+            p_partkey = l_partkey
+            and s_suppkey = l_suppkey
+            and l_orderkey = o_orderkey
+            and o_custkey = c_custkey
+            and c_nationkey = n1.n_nationkey
+            and n1.n_regionkey = r_regionkey
+            and r_name = 'AMERICA'
+            and s_nationkey = n2.n_nationkey
+            and o_orderdate between date '1995-01-01' and date '1996-12-31'
+            and p_type = 'ECONOMY ANODIZED STEEL'
+    ) as all_nations
+group by
+    o_year
+order by
+    o_year
+*/
+
+source = [
+    source = part
+    | join ON p_partkey = l_partkey lineitem
+    | join ON s_suppkey = l_suppkey supplier
+    | join ON l_orderkey = o_orderkey orders
+    | join ON o_custkey = c_custkey customer
+    | join ON c_nationkey = n1.n_nationkey nation as n1
+    | join ON s_nationkey = n2.n_nationkey nation as n2
+    | join ON n1.n_regionkey = r_regionkey region
+    | where r_name = 'AMERICA' AND p_type = 'ECONOMY ANODIZED STEEL'
+      and o_orderdate between date('1995-01-01') and date('1996-12-31')
+    | eval o_year = year(o_orderdate)
+    | eval volume = l_extendedprice * (1 - l_discount)
+    | eval nation = n2.n_name
+    | fields o_year, volume, nation
+  ] as all_nations
+| stats sum(case(nation =
'BRAZIL', volume else 0)) as sum_case, sum(volume) as sum_volume by o_year +| eval mkt_share = sum_case / sum_volume +| fields mkt_share, o_year +| sort o_year \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q9.ppl b/integ-test/src/integration/resources/tpch/q9.ppl new file mode 100644 index 000000000..7692afd74 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q9.ppl @@ -0,0 +1,50 @@ +/* +select + nation, + o_year, + sum(amount) as sum_profit +from + ( + select + n_name as nation, + year(o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount + from + part, + supplier, + lineitem, + partsupp, + orders, + nation + where + s_suppkey = l_suppkey + and ps_suppkey = l_suppkey + and ps_partkey = l_partkey + and p_partkey = l_partkey + and o_orderkey = l_orderkey + and s_nationkey = n_nationkey + and p_name like '%green%' + ) as profit +group by + nation, + o_year +order by + nation, + o_year desc +*/ + +source = [ + source = part + | join ON p_partkey = l_partkey lineitem + | join ON s_suppkey = l_suppkey supplier + | join ON ps_partkey = l_partkey and ps_suppkey = l_suppkey partsupp + | join ON o_orderkey = l_orderkey orders + | join ON s_nationkey = n_nationkey nation + | where like(p_name, '%green%') + | eval nation = n_name + | eval o_year = year(o_orderdate) + | eval amount = l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity + | fields nation, o_year, amount + ] as profit +| stats sum(amount) as sum_profit by nation, o_year +| sort nation, - o_year \ No newline at end of file diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryBase.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryBase.scala new file mode 100644 index 000000000..fb14210e9 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryBase.scala @@ -0,0 +1,177 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl.tpch + +import org.opensearch.flint.spark.ppl.FlintPPLSuite + +import org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeFormatter, CodeGenerator} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_SECOND +import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.internal.SQLConf + +trait TPCHQueryBase extends FlintPPLSuite { + + override protected def sparkConf: SparkConf = { + super.sparkConf.set(SQLConf.MAX_TO_STRING_FIELDS.key, Int.MaxValue.toString) + } + + override def beforeAll(): Unit = { + super.beforeAll() + RuleExecutor.resetMetrics() + CodeGenerator.resetCompileTime() + WholeStageCodegenExec.resetCodeGenTime() + tpchCreateTable.values.foreach { ppl => + sql(ppl) + } + } + + override def afterAll(): Unit = { + try { + tpchCreateTable.keys.foreach { tableName => + spark.sessionState.catalog.dropTable(TableIdentifier(tableName), true, true) + } + // For debugging dump some statistics about how much time was spent in various optimizer rules + // code generation, and compilation. 
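+      // Note: these counters were reset in beforeAll() (RuleExecutor.resetMetrics(),
+      // CodeGenerator.resetCompileTime(), WholeStageCodegenExec.resetCodeGenTime()),
+      // so the timings reported below cover only the queries executed by this suite.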
+ logWarning(RuleExecutor.dumpTimeSpent()) + val codeGenTime = WholeStageCodegenExec.codeGenTime.toDouble / NANOS_PER_SECOND + val compileTime = CodeGenerator.compileTime.toDouble / NANOS_PER_SECOND + val codegenInfo = + s""" + |=== Metrics of Whole-stage Codegen === + |Total code generation time: $codeGenTime seconds + |Total compile time: $compileTime seconds + """.stripMargin + logWarning(codegenInfo) + spark.sessionState.catalog.reset() + } finally { + super.afterAll() + } + } + + def checkGeneratedCode(plan: SparkPlan, checkMethodCodeSize: Boolean = true): Unit = { + val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]() + + def findSubtrees(plan: SparkPlan): Unit = { + plan foreach { + case s: WholeStageCodegenExec => + codegenSubtrees += s + case s => + s.subqueries.foreach(findSubtrees) + } + } + + findSubtrees(plan) + codegenSubtrees.toSeq.foreach { subtree => + val code = subtree.doCodeGen()._2 + val (_, ByteCodeStats(maxMethodCodeSize, _, _)) = + try { + // Just check the generated code can be properly compiled + CodeGenerator.compile(code) + } catch { + case e: Exception => + val msg = + s""" + |failed to compile: + |Subtree: + |$subtree + |Generated code: + |${CodeFormatter.format(code)} + """.stripMargin + throw new Exception(msg, e) + } + + assert( + !checkMethodCodeSize || + maxMethodCodeSize <= CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT, + s"too long generated codes found in the WholeStageCodegenExec subtree (id=${subtree.id}) " + + s"and JIT optimization might not work:\n${subtree.treeString}") + } + } + + val tpchCreateTable = Map( + "orders" -> + """ + |CREATE TABLE `orders` ( + |`o_orderkey` BIGINT, `o_custkey` BIGINT, `o_orderstatus` STRING, + |`o_totalprice` DECIMAL(10,0), `o_orderdate` DATE, `o_orderpriority` STRING, + |`o_clerk` STRING, `o_shippriority` INT, `o_comment` STRING) + |USING parquet + """.stripMargin, + "nation" -> + """ + |CREATE TABLE `nation` ( + |`n_nationkey` BIGINT, `n_name` STRING, `n_regionkey` BIGINT, `n_comment` STRING) + |USING parquet + """.stripMargin, + "region" -> + """ + |CREATE TABLE `region` ( + |`r_regionkey` BIGINT, `r_name` STRING, `r_comment` STRING) + |USING parquet + """.stripMargin, + "part" -> + """ + |CREATE TABLE `part` (`p_partkey` BIGINT, `p_name` STRING, `p_mfgr` STRING, + |`p_brand` STRING, `p_type` STRING, `p_size` INT, `p_container` STRING, + |`p_retailprice` DECIMAL(10,0), `p_comment` STRING) + |USING parquet + """.stripMargin, + "partsupp" -> + """ + |CREATE TABLE `partsupp` (`ps_partkey` BIGINT, `ps_suppkey` BIGINT, + |`ps_availqty` INT, `ps_supplycost` DECIMAL(10,0), `ps_comment` STRING) + |USING parquet + """.stripMargin, + "customer" -> + """ + |CREATE TABLE `customer` (`c_custkey` BIGINT, `c_name` STRING, `c_address` STRING, + |`c_nationkey` BIGINT, `c_phone` STRING, `c_acctbal` DECIMAL(10,0), + |`c_mktsegment` STRING, `c_comment` STRING) + |USING parquet + """.stripMargin, + "supplier" -> + """ + |CREATE TABLE `supplier` (`s_suppkey` BIGINT, `s_name` STRING, `s_address` STRING, + |`s_nationkey` BIGINT, `s_phone` STRING, `s_acctbal` DECIMAL(10,0), `s_comment` STRING) + |USING parquet + """.stripMargin, + "lineitem" -> + """ + |CREATE TABLE `lineitem` (`l_orderkey` BIGINT, `l_partkey` BIGINT, `l_suppkey` BIGINT, + |`l_linenumber` INT, `l_quantity` DECIMAL(10,0), `l_extendedprice` DECIMAL(10,0), + |`l_discount` DECIMAL(10,0), `l_tax` DECIMAL(10,0), `l_returnflag` STRING, + |`l_linestatus` STRING, `l_shipdate` DATE, `l_commitdate` DATE, `l_receiptdate` DATE, + |`l_shipinstruct` STRING, 
`l_shipmode` STRING, `l_comment` STRING) + |USING parquet + """.stripMargin) + + val tpchQueries = Seq( + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "q16", + "q17", + "q18", + "q19", + "q20", + "q21", + "q22") +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryITSuite.scala new file mode 100644 index 000000000..1b9681618 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryITSuite.scala @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl.tpch + +import org.opensearch.flint.spark.ppl.LogicalPlanTestUtils + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.util.resourceToString +import org.apache.spark.sql.streaming.StreamTest + +class TPCHQueryITSuite + extends QueryTest + with LogicalPlanTestUtils + with TPCHQueryBase + with StreamTest { + + override def beforeAll(): Unit = { + super.beforeAll() + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + tpchQueries.foreach { name => + val queryString = resourceToString( + s"tpch/$name.ppl", + classLoader = Thread.currentThread().getContextClassLoader) + test(name) { + // check the plans can be properly generated + val plan = sql(queryString).queryExecution.executedPlan + checkGeneratedCode(plan) + } + } +} From 4303057aad2c0edd0ae2c75ef48bee81cd4bb7af Mon Sep 17 00:00:00 2001 From: YANGDB Date: Wed, 6 Nov 2024 18:25:25 -0800 Subject: [PATCH 4/4] Expand ppl command (#868) * add expand command Signed-off-by: YANGDB * add expand command with visitor Signed-off-by: YANGDB * create unit / integration tests Signed-off-by: YANGDB * update expand tests Signed-off-by: YANGDB * add tests Signed-off-by: YANGDB * update doc Signed-off-by: YANGDB * update docs with examples Signed-off-by: YANGDB * update scala style Signed-off-by: YANGDB * update with additional test case remove outer generator Signed-off-by: YANGDB * update with additional test case remove outer generator Signed-off-by: YANGDB * update documentation Signed-off-by: YANGDB --------- Signed-off-by: YANGDB --- docs/ppl-lang/PPL-Example-Commands.md | 37 ++- docs/ppl-lang/README.md | 2 + docs/ppl-lang/ppl-expand-command.md | 45 +++ .../flint/spark/FlintSparkSuite.scala | 22 ++ .../ppl/FlintSparkPPLExpandITSuite.scala | 255 ++++++++++++++++ .../src/main/antlr4/OpenSearchPPLLexer.g4 | 1 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 6 + .../sql/ast/AbstractNodeVisitor.java | 4 + .../org/opensearch/sql/ast/tree/Expand.java | 44 +++ .../org/opensearch/sql/ast/tree/Flatten.java | 4 +- .../sql/ppl/CatalystQueryPlanVisitor.java | 25 +- .../opensearch/sql/ppl/parser/AstBuilder.java | 6 + ...PlanExpandCommandTranslatorTestSuite.scala | 281 ++++++++++++++++++ 13 files changed, 716 insertions(+), 16 deletions(-) create mode 100644 docs/ppl-lang/ppl-expand-command.md create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Expand.java create mode 100644 
ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala
diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md
index e80f8c906..4ea564111 100644
--- a/docs/ppl-lang/PPL-Example-Commands.md
+++ b/docs/ppl-lang/PPL-Example-Commands.md
@@ -441,8 +441,30 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in
 _- **Limitation: another command usage of (relation) subquery is in `appendcols` commands which is unsupported**_
 
 ---
-#### Experimental Commands:
+
+#### **fillnull**
+[See additional command details](ppl-fillnull-command.md)
+```sql
+ - `source=accounts | fillnull fields status_code=101`
+ - `source=accounts | fillnull fields request_path='/not_found', timestamp='*'`
+ - `source=accounts | fillnull using field1=101`
+ - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5`
+ - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5, field6 = 'N/A'`
+```
+
+#### **expand**
+[See additional command details](ppl-expand-command.md)
+```sql
+ - `source = table | expand field_with_array as array_list`
+ - `source = table | expand employee | stats max(salary) as max by state, company`
+ - `source = table | expand employee as worker | stats max(salary) as max by state, company`
+ - `source = table | expand employee as worker | eval bonus = salary * 3 | fields worker, bonus`
+ - `source = table | expand employee | parse description '(?<email>.+@.+)' | fields employee, email`
+ - `source = table | eval array=json_array(1, 2, 3) | expand array as uid | fields name, occupation, uid`
+ - `source = table | expand multi_valueA as multiA | expand multi_valueB as multiB`
+```
+
+#### Correlation Commands:
 [See additional command details](ppl-correlation-command.md)
 
 ```sql
@@ -454,14 +476,3 @@ _- **Limitation: another command usage of (relation) subquery is in `appendcols`
 > ppl-correlation-command is an experimental command - it may be removed in future versions
 
 ---
-### Planned Commands:
-
-#### **fillnull**
-[See additional command details](ppl-fillnull-command.md)
-```sql
- - `source=accounts | fillnull fields status_code=101`
- - `source=accounts | fillnull fields request_path='/not_found', timestamp='*'`
- - `source=accounts | fillnull using field1=101`
- - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5`
- - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5, field6 = 'N/A'`
-```
diff --git a/docs/ppl-lang/README.md b/docs/ppl-lang/README.md
index ef186e5f2..d72c973be 100644
--- a/docs/ppl-lang/README.md
+++ b/docs/ppl-lang/README.md
@@ -71,6 +71,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md).
     - [`correlation commands`](ppl-correlation-command.md)
 
     - [`trendline commands`](ppl-trendline-command.md)
+
+    - [`expand commands`](ppl-expand-command.md)
 
 * **Functions**
diff --git a/docs/ppl-lang/ppl-expand-command.md b/docs/ppl-lang/ppl-expand-command.md
new file mode 100644
index 000000000..144c0aafa
--- /dev/null
+++ b/docs/ppl-lang/ppl-expand-command.md
@@ -0,0 +1,45 @@
+## PPL `expand` command
+
+### Description
+Use the `expand` command to flatten a field of type:
+- `Array`
+- `Map`
+
+
+### Syntax
+`expand <field> [As alias]`
+
+* field: the field to be expanded (exploded). The field must be of a supported type.
+* alias: optional. The name under which the exploded values appear instead of the original field name.
+
+### Usage Guidelines
+The expand command produces a row for each element in the specified array or map field, where:
+- Array elements become individual rows.
+- Map key-value pairs are broken into separate rows, with each key-value pair represented as a row.
+- When an alias is provided, the exploded values are returned under the alias instead of the original field name.
+- The command can be used in combination with other commands, such as stats, eval, and parse, to manipulate or extract data post-expansion.
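+
+For intuition, a minimal sketch (hypothetical field names, not taken from the
+examples below): a single input row with `name=Jake` and `uid_list=[1, 2, 3]`, run
+through `source = table | expand uid_list as uid`, produces one row per element:
+
+```
+name | uid
+Jake | 1
+Jake | 2
+Jake | 3
+```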
+
+### Examples:
+- `source = table | expand employee | stats max(salary) as max by state, company`
+- `source = table | expand employee as worker | stats max(salary) as max by state, company`
+- `source = table | expand employee as worker | eval bonus = salary * 3 | fields worker, bonus`
+- `source = table | expand employee | parse description '(?<email>.+@.+)' | fields employee, email`
+- `source = table | eval array=json_array(1, 2, 3) | expand array as uid | fields name, occupation, uid`
+- `source = table | expand multi_valueA as multiA | expand multi_valueB as multiB`
+
+- The expand command can be used in combination with other commands, such as `eval` and `stats`.
+- Using multiple expand commands will create a Cartesian product of all the internal elements within each composite array or map.
+
+### Effective SQL push-down query
+The expand command is translated into an equivalent SQL operation using LATERAL VIEW explode, allowing for efficient exploding of arrays or maps at the SQL query level.
+
+```sql
+SELECT customer, exploded_productId
+FROM table
+LATERAL VIEW explode(productId) AS exploded_productId
+```
+Where the `explode` function offers the following functionality:
+- it is a column operation that returns a new column
+- it creates a new row for every element in the exploded column
+- internal `null`s are ignored as part of the exploded field (no row is created/exploded for null)
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
index c53eee548..68d370791 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
@@ -559,6 +559,28 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
        |""".stripMargin)
   }
 
+  protected def createMultiColumnArrayTable(testTable: String): Unit = {
+    // CSV doesn't support struct field
+    sql(s"""
+           | CREATE TABLE $testTable
+           | (
+           |   int_col INT,
+           |   multi_valueA Array<STRUCT<name: STRING, value: INT>>,
+           |   multi_valueB Array<STRUCT<name: STRING, value: INT>>
+           | )
+           | USING JSON
+           |""".stripMargin)
+
+    sql(s"""
+           | INSERT INTO $testTable
+           | VALUES
+           | ( 1, array(STRUCT("1_one", 1), STRUCT(null, 11), STRUCT("1_three", null)), array(STRUCT("2_Monday", 2), null) ),
+           | ( 2, array(STRUCT("2_Monday", 2), null) , array(STRUCT("3_third", 3), STRUCT("3_4th", 4)) ),
+           | ( 3, array(STRUCT("3_third", 3), STRUCT("3_4th", 4)) , array(STRUCT("1_one", 1))),
+           | ( 4, null, array(STRUCT("1_one", 1)))
+           |""".stripMargin)
+  }
+
   protected def createTableIssue112(testTable: String): Unit = {
     sql(s"""
       | CREATE TABLE $testTable (
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala
new file mode 100644
index
000000000..f0404bf7b --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala @@ -0,0 +1,255 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + +import java.nio.file.Files + +import scala.collection.mutable + +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Explode, GeneratorOuter, Literal, Or} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLExpandITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + private val testTable = "flint_ppl_test" + private val occupationTable = "spark_catalog.default.flint_ppl_flat_table_test" + private val structNestedTable = "spark_catalog.default.flint_ppl_struct_nested_test" + private val structTable = "spark_catalog.default.flint_ppl_struct_test" + private val multiValueTable = "spark_catalog.default.flint_ppl_multi_value_test" + private val multiArraysTable = "spark_catalog.default.flint_ppl_multi_array_test" + private val tempFile = Files.createTempFile("jsonTestData", ".json") + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createNestedJsonContentTable(tempFile, testTable) + createOccupationTable(occupationTable) + createStructNestedTable(structNestedTable) + createStructTable(structTable) + createMultiValueStructTable(multiValueTable) + createMultiColumnArrayTable(multiArraysTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + override def afterAll(): Unit = { + super.afterAll() + Files.deleteIfExists(tempFile) + } + + test("expand for eval field of an array") { + val frame = sql( + s""" source = $occupationTable | eval array=json_array(1, 2, 3) | expand array as uid | fields name, occupation, uid + """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row("Jake", "Engineer", 1), + Row("Jake", "Engineer", 2), + Row("Jake", "Engineer", 3), + Row("Hello", "Artist", 1), + Row("Hello", "Artist", 2), + Row("Hello", "Artist", 3), + Row("John", "Doctor", 1), + Row("John", "Doctor", 2), + Row("John", "Doctor", 3), + Row("David", "Doctor", 1), + Row("David", "Doctor", 2), + Row("David", "Doctor", 3), + Row("David", "Unemployed", 1), + Row("David", "Unemployed", 2), + Row("David", "Unemployed", 3), + Row("Jane", "Scientist", 1), + Row("Jane", "Scientist", 2), + Row("Jane", "Scientist", 3)) + + // Compare the results + assert(results.toSet == expectedResults.toSet) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // expected plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_flat_table_test")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "array")() + val project = Project(seq(UnresolvedStar(None), aliasA), table) + val generate = Generate( + Explode(UnresolvedAttribute("array")), + seq(), + false, + None, + seq(UnresolvedAttribute("uid")), + project) + val dropSourceColumn = + 
DataFrameDropColumns(Seq(UnresolvedAttribute("array")), generate) + val expectedPlan = Project( + seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("occupation"), + UnresolvedAttribute("uid")), + dropSourceColumn) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("expand for structs") { + val frame = sql( + s""" source = $multiValueTable | expand multi_value AS exploded_multi_value | fields exploded_multi_value + """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(Row("1_one", 1)), + Row(Row(null, 11)), + Row(Row("1_three", null)), + Row(Row("2_Monday", 2)), + Row(null), + Row(Row("3_third", 3)), + Row(Row("3_4th", 4)), + Row(null)) + // Compare the results + assert(results.toSet == expectedResults.toSet) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // expected plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_multi_value_test")) + val generate = Generate( + Explode(UnresolvedAttribute("multi_value")), + seq(), + outer = false, + None, + seq(UnresolvedAttribute("exploded_multi_value")), + table) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_value")), generate) + val expectedPlan = Project(Seq(UnresolvedAttribute("exploded_multi_value")), dropSourceColumn) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("expand for array of structs") { + val frame = sql(s""" + | source = $testTable + | | where country = 'England' or country = 'Poland' + | | expand bridges + | | fields bridges + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(mutable.WrappedArray.make(Array(Row(801, "Tower Bridge"), Row(928, "London Bridge")))), + Row(mutable.WrappedArray.make(Array(Row(801, "Tower Bridge"), Row(928, "London Bridge")))) + // Row(null)) -> in case of outerGenerator = GeneratorOuter(Explode(UnresolvedAttribute("bridges"))) it will include the `null` row + ) + + // Compare the results + assert(results.toSet == expectedResults.toSet) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val filter = Filter( + Or( + EqualTo(UnresolvedAttribute("country"), Literal("England")), + EqualTo(UnresolvedAttribute("country"), Literal("Poland"))), + table) + val generate = + Generate(Explode(UnresolvedAttribute("bridges")), seq(), outer = false, None, seq(), filter) + val expectedPlan = Project(Seq(UnresolvedAttribute("bridges")), generate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("expand for array of structs with alias") { + val frame = sql(s""" + | source = $testTable + | | where country = 'England' + | | expand bridges as britishBridges + | | fields britishBridges + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(Row(801, "Tower Bridge")), + Row(Row(928, "London Bridge")), + Row(Row(801, "Tower Bridge")), + Row(Row(928, "London Bridge"))) + // Compare the results + assert(results.toSet == expectedResults.toSet) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val filter = Filter(EqualTo(UnresolvedAttribute("country"), Literal("England")), table) + val generate = Generate( + Explode(UnresolvedAttribute("bridges")), + seq(), + outer = false, + None, + seq(UnresolvedAttribute("britishBridges")), + filter) + val 
dropColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("bridges")), generate) + val expectedPlan = Project(Seq(UnresolvedAttribute("britishBridges")), dropColumn) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("expand multi columns array table") { + val frame = sql(s""" + | source = $multiArraysTable + | | expand multi_valueA as multiA + | | expand multi_valueB as multiB + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1, Row("1_one", 1), Row("2_Monday", 2)), + Row(1, Row("1_one", 1), null), + Row(1, Row(null, 11), Row("2_Monday", 2)), + Row(1, Row(null, 11), null), + Row(1, Row("1_three", null), Row("2_Monday", 2)), + Row(1, Row("1_three", null), null), + Row(2, Row("2_Monday", 2), Row("3_third", 3)), + Row(2, Row("2_Monday", 2), Row("3_4th", 4)), + Row(2, null, Row("3_third", 3)), + Row(2, null, Row("3_4th", 4)), + Row(3, Row("3_third", 3), Row("1_one", 1)), + Row(3, Row("3_4th", 4), Row("1_one", 1))) + // Compare the results + assert(results.toSet == expectedResults.toSet) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_multi_array_test")) + val generatorA = Explode(UnresolvedAttribute("multi_valueA")) + val generateA = + Generate(generatorA, seq(), false, None, seq(UnresolvedAttribute("multiA")), table) + val dropSourceColumnA = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueA")), generateA) + val generatorB = Explode(UnresolvedAttribute("multi_valueB")) + val generateB = Generate( + generatorB, + seq(), + false, + None, + seq(UnresolvedAttribute("multiB")), + dropSourceColumnA) + val dropSourceColumnB = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueB")), generateB) + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumnB) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 93efb2df1..2c3344b3c 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -37,6 +37,7 @@ KMEANS: 'KMEANS'; AD: 'AD'; ML: 'ML'; FILLNULL: 'FILLNULL'; +EXPAND: 'EXPAND'; FLATTEN: 'FLATTEN'; TRENDLINE: 'TRENDLINE'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 123d1e15a..1cfd172f7 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -54,6 +54,7 @@ commands | fillnullCommand | fieldsummaryCommand | flattenCommand + | expandCommand | trendlineCommand ; @@ -82,6 +83,7 @@ commandName | PATTERNS | LOOKUP | RENAME + | EXPAND | FILLNULL | FIELDSUMMARY | FLATTEN @@ -250,6 +252,10 @@ fillnullCommand : expression ; +expandCommand + : EXPAND fieldExpression (AS alias = qualifiedName)? 
+    ;
+
flattenCommand
    : FLATTEN fieldExpression
    ;
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java
index 189d9084a..54e1205cb 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java
@@ -108,6 +108,10 @@ public T visitFilter(Filter node, C context) {
    return visitChildren(node, context);
  }

+  public T visitExpand(Expand node, C context) {
+    return visitChildren(node, context);
+  }
+
  public T visitLookup(Lookup node, C context) {
    return visitChildren(node, context);
  }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Expand.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Expand.java
new file mode 100644
index 000000000..0e164ccd7
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Expand.java
@@ -0,0 +1,44 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.ast.tree;
+
+import lombok.Getter;
+import lombok.RequiredArgsConstructor;
+import org.opensearch.sql.ast.AbstractNodeVisitor;
+import org.opensearch.sql.ast.Node;
+import org.opensearch.sql.ast.expression.Field;
+import org.opensearch.sql.ast.expression.UnresolvedAttribute;
+import org.opensearch.sql.ast.expression.UnresolvedExpression;
+
+import java.util.List;
+import java.util.Optional;
+
+/** Logical plan node of Expand */
+@RequiredArgsConstructor
+public class Expand extends UnresolvedPlan {
+  private UnresolvedPlan child;
+
+  @Getter
+  private final Field field;
+  @Getter
+  private final Optional<UnresolvedExpression> alias;
+
+  @Override
+  public Expand attach(UnresolvedPlan child) {
+    this.child = child;
+    return this;
+  }
+
+  @Override
+  public List<? extends Node> getChild() {
+    return child == null ? List.of() : List.of(child);
+  }
+
+  @Override
+  public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
+    return nodeVisitor.visitExpand(this, context);
+  }
+}
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java
index e31fbb6e3..9c57d2adf 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java
@@ -14,7 +14,7 @@ public class Flatten extends UnresolvedPlan {
  private UnresolvedPlan child;

  @Getter
-  private final Field fieldToBeFlattened;
+  private final Field field;

  @Override
  public UnresolvedPlan attach(UnresolvedPlan child) {
@@ -26,7 +26,7 @@ public UnresolvedPlan attach(UnresolvedPlan child) {
  public List<? extends Node> getChild() {
    return child == null ? 
List.of() : List.of(child);
  }
-
+
  @Override
  public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
    return nodeVisitor.visitFlatten(this, context);
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index a43378480..d2ee46ae6 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -11,6 +11,7 @@
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
import org.apache.spark.sql.catalyst.expressions.Ascending$;
import org.apache.spark.sql.catalyst.expressions.Descending$;
+import org.apache.spark.sql.catalyst.expressions.Explode;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.GeneratorOuter;
import org.apache.spark.sql.catalyst.expressions.In$;
@@ -93,6 +94,7 @@
import static java.util.Collections.emptyList;
import static java.util.List.of;
+import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation;
import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEvents;
import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEventsAndKeepEmpty;
@@ -460,13 +462,34 @@ public LogicalPlan visitFlatten(Flatten flatten, CatalystPlanContext context) {
            // Create an UnresolvedStar for all-fields projection
            context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.<Seq<Expression>>empty()));
        }
-        Expression field = visitExpression(flatten.getFieldToBeFlattened(), context);
+        Expression field = visitExpression(flatten.getField(), context);
        context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
        FlattenGenerator flattenGenerator = new FlattenGenerator(field);
        context.apply(p -> new Generate(new GeneratorOuter(flattenGenerator), seq(), true, (Option) None$.MODULE$, seq(), p));
        return context.apply(logicalPlan -> DataFrameDropColumns$.MODULE$.apply(seq(field), logicalPlan));
    }

+    @Override
+    public LogicalPlan visitExpand(org.opensearch.sql.ast.tree.Expand node, CatalystPlanContext context) {
+        node.getChild().get(0).accept(this, context);
+        if (context.getNamedParseExpressions().isEmpty()) {
+            // Create an UnresolvedStar for all-fields projection
+            context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.<Seq<Expression>>empty()));
+        }
+        Expression field = visitExpression(node.getField(), context);
+        Optional<Expression> alias = node.getAlias().map(aliasNode -> visitExpression(aliasNode, context));
+        context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
+        Explode explodeGenerator = new Explode(field);
+        scala.collection.mutable.Seq outputs = alias.isEmpty() ? 
seq() : seq(alias.get());
+        if (alias.isEmpty())
+            return context.apply(p -> new Generate(explodeGenerator, seq(), false, (Option) None$.MODULE$, outputs, p));
+        else {
+            // in case an alias does appear, remove the original field from the returned columns
+            context.apply(p -> new Generate(explodeGenerator, seq(), false, (Option) None$.MODULE$, outputs, p));
+            return context.apply(logicalPlan -> DataFrameDropColumns$.MODULE$.apply(seq(field), logicalPlan));
+        }
+    }
+
    private void visitFieldList(List<Field> fieldList, CatalystPlanContext context) {
        fieldList.forEach(field -> visitExpression(field, context));
    }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
index 36a34cd06..f6581016f 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
@@ -131,6 +131,12 @@ public UnresolvedPlan visitWhereCommand(OpenSearchPPLParser.WhereCommandContext
    return new Filter(internalVisitExpression(ctx.logicalExpression()));
  }

+  @Override
+  public UnresolvedPlan visitExpandCommand(OpenSearchPPLParser.ExpandCommandContext ctx) {
+    return new Expand((Field) internalVisitExpression(ctx.fieldExpression()),
+        ctx.alias != null ? Optional.of(internalVisitExpression(ctx.alias)) : Optional.empty());
+  }
+
  @Override
  public UnresolvedPlan visitCorrelateCommand(OpenSearchPPLParser.CorrelateCommandContext ctx) {
    return new Correlation(ctx.correlationType().getText(),
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala
new file mode 100644
index 000000000..2acaac529
--- /dev/null
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala
@@ -0,0 +1,281 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.ppl
+
+import org.opensearch.flint.spark.FlattenGenerator
+import org.opensearch.flint.spark.ppl.PlaneUtils.plan
+import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
+import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
+import org.scalatest.matchers.should.Matchers
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Explode, GeneratorOuter, Literal, RegExpExtract}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, Generate, Project}
+import org.apache.spark.sql.types.IntegerType
+
+class PPLLogicalPlanExpandCommandTranslatorTestSuite
+    extends SparkFunSuite
+    with PlanTest
+    with LogicalPlanTestUtils
+    with Matchers {
+
+  private val planTransformer = new CatalystQueryPlanVisitor()
+  private val pplParser = new PPLSyntaxParser()
+
+  test("test expand only field") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(plan(pplParser, "source=relation | expand field_with_array"), context)
+
+    val relation = UnresolvedRelation(Seq("relation"))
+    val generator = Explode(UnresolvedAttribute("field_with_array"))
+    val generate = 
Generate(generator, seq(), false, None, seq(), relation) + val expectedPlan = Project(seq(UnresolvedStar(None)), generate) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("expand multi columns array table") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + s""" + | source = table + | | expand multi_valueA as multiA + | | expand multi_valueB as multiB + | """.stripMargin), + context) + + val relation = UnresolvedRelation(Seq("table")) + val generatorA = Explode(UnresolvedAttribute("multi_valueA")) + val generateA = + Generate(generatorA, seq(), false, None, seq(UnresolvedAttribute("multiA")), relation) + val dropSourceColumnA = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueA")), generateA) + val generatorB = Explode(UnresolvedAttribute("multi_valueB")) + val generateB = Generate( + generatorB, + seq(), + false, + None, + seq(UnresolvedAttribute("multiB")), + dropSourceColumnA) + val dropSourceColumnB = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueB")), generateB) + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumnB) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand on array field which is eval array=json_array") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = table | eval array=json_array(1, 2, 3) | expand array as uid | fields uid"), + context) + + val relation = UnresolvedRelation(Seq("table")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "array")() + val project = Project(seq(UnresolvedStar(None), aliasA), relation) + val generate = Generate( + Explode(UnresolvedAttribute("array")), + seq(), + false, + None, + seq(UnresolvedAttribute("uid")), + project) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("array")), generate) + val expectedPlan = Project(seq(UnresolvedAttribute("uid")), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand only field with alias") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | expand field_with_array as array_list "), + context) + + val relation = UnresolvedRelation(Seq("relation")) + val generate = Generate( + Explode(UnresolvedAttribute("field_with_array")), + seq(), + false, + None, + seq(UnresolvedAttribute("array_list")), + relation) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("field_with_array")), generate) + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and stats") { + val context = new CatalystPlanContext + val query = + "source = table | expand employee | stats max(salary) as max by state, company" + val logPlan = + planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("table")) + val generate = + Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table) + val average = Alias( + UnresolvedFunction(seq("MAX"), seq(UnresolvedAttribute("salary")), false, None, false), + "max")() + val state = Alias(UnresolvedAttribute("state"), "state")() + val company = Alias(UnresolvedAttribute("company"), "company")() + val groupingState = Alias(UnresolvedAttribute("state"), "state")() + val 
groupingCompany = Alias(UnresolvedAttribute("company"), "company")()
+    val aggregate =
+      Aggregate(Seq(groupingState, groupingCompany), Seq(average, state, company), generate)
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate)
+
+    comparePlans(expectedPlan, logPlan, checkAnalysis = false)
+  }
+
+  test("test expand and stats with alias") {
+    val context = new CatalystPlanContext
+    val query =
+      "source = table | expand employee as workers | stats max(salary) as max by state, company"
+    val logPlan =
+      planTransformer.visit(plan(pplParser, query), context)
+    val table = UnresolvedRelation(Seq("table"))
+    val generate = Generate(
+      Explode(UnresolvedAttribute("employee")),
+      seq(),
+      false,
+      None,
+      seq(UnresolvedAttribute("workers")),
+      table)
+    val dropSourceColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate)
+    val average = Alias(
+      UnresolvedFunction(seq("MAX"), seq(UnresolvedAttribute("salary")), false, None, false),
+      "max")()
+    val state = Alias(UnresolvedAttribute("state"), "state")()
+    val company = Alias(UnresolvedAttribute("company"), "company")()
+    val groupingState = Alias(UnresolvedAttribute("state"), "state")()
+    val groupingCompany = Alias(UnresolvedAttribute("company"), "company")()
+    val aggregate = Aggregate(
+      Seq(groupingState, groupingCompany),
+      Seq(average, state, company),
+      dropSourceColumn)
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate)
+
+    comparePlans(expectedPlan, logPlan, checkAnalysis = false)
+  }
+
+  test("test expand and eval") {
+    val context = new CatalystPlanContext
+    val query = "source = table | expand employee | eval bonus = salary * 3"
+    val logPlan = planTransformer.visit(plan(pplParser, query), context)
+    val table = UnresolvedRelation(Seq("table"))
+    val generate =
+      Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table)
+    val bonusProject = Project(
+      Seq(
+        UnresolvedStar(None),
+        Alias(
+          UnresolvedFunction(
+            "*",
+            Seq(UnresolvedAttribute("salary"), Literal(3, IntegerType)),
+            isDistinct = false),
+          "bonus")()),
+      generate)
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), bonusProject)
+    comparePlans(expectedPlan, logPlan, checkAnalysis = false)
+  }
+
+  test("test expand and eval with fields and alias") {
+    val context = new CatalystPlanContext
+    val query =
+      "source = table | expand employee as worker | eval bonus = salary * 3 | fields worker, bonus "
+    val logPlan = planTransformer.visit(plan(pplParser, query), context)
+    val table = UnresolvedRelation(Seq("table"))
+    val generate = Generate(
+      Explode(UnresolvedAttribute("employee")),
+      seq(),
+      false,
+      None,
+      seq(UnresolvedAttribute("worker")),
+      table)
+    val dropSourceColumn =
+      DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate)
+    val bonusProject = Project(
+      Seq(
+        UnresolvedStar(None),
+        Alias(
+          UnresolvedFunction(
+            "*",
+            Seq(UnresolvedAttribute("salary"), Literal(3, IntegerType)),
+            isDistinct = false),
+          "bonus")()),
+      dropSourceColumn)
+    val expectedPlan =
+      Project(Seq(UnresolvedAttribute("worker"), UnresolvedAttribute("bonus")), bonusProject)
+    comparePlans(expectedPlan, logPlan, checkAnalysis = false)
+  }
+
+  test("test expand and parse and fields") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(
+          pplParser,
+          "source=table | expand employee | parse description '(?<email>.+@.+)' | fields employee, email"),
+        context)
+    val table = UnresolvedRelation(Seq("table"))
+    val generator =
+      
Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table)
+    val emailAlias =
+      Alias(
+        RegExpExtract(UnresolvedAttribute("description"), Literal("(?<email>.+@.+)"), Literal(1)),
+        "email")()
+    val parseProject = Project(
+      Seq(UnresolvedAttribute("description"), emailAlias, UnresolvedStar(None)),
+      generator)
+    val expectedPlan =
+      Project(Seq(UnresolvedAttribute("employee"), UnresolvedAttribute("email")), parseProject)
+    comparePlans(expectedPlan, logPlan, checkAnalysis = false)
+  }
+
+  test("test expand and parse and flatten ") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(
+          pplParser,
+          "source=relation | expand employee | parse description '(?<email>.+@.+)' | flatten roles "),
+        context)
+    val table = UnresolvedRelation(Seq("relation"))
+    val generateEmployee =
+      Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table)
+    val emailAlias =
+      Alias(
+        RegExpExtract(UnresolvedAttribute("description"), Literal("(?<email>.+@.+)"), Literal(1)),
+        "email")()
+    val parseProject = Project(
+      Seq(UnresolvedAttribute("description"), emailAlias, UnresolvedStar(None)),
+      generateEmployee)
+    val generateRoles = Generate(
+      GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("roles"))),
+      seq(),
+      true,
+      None,
+      seq(),
+      parseProject)
+    val dropSourceColumnRoles =
+      DataFrameDropColumns(Seq(UnresolvedAttribute("roles")), generateRoles)
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), dropSourceColumnRoles)
+    comparePlans(expectedPlan, logPlan, checkAnalysis = false)
+  }
+
+}
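
Note: the alias-form plans asserted above line up with the `LATERAL VIEW` push-down described in docs/ppl-lang/ppl-expand-command.md. For instance, a sketch using the hypothetical `table`/`employee`/`worker` names from the tests (the shape of the pushed-down query, not generated output):

```sql
-- PPL: source = table | expand employee as worker | fields worker
SELECT worker
FROM table
LATERAL VIEW explode(employee) AS worker
```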