Support database name in table name (#10)
* Support schema name

Signed-off-by: Chen Dai <[email protected]>

* Add IT

Signed-off-by: Chen Dai <[email protected]>

* Add more IT

Signed-off-by: Chen Dai <[email protected]>

---------

Signed-off-by: Chen Dai <[email protected]>
dai-chen authored Jul 31, 2023
1 parent 541db1b commit 365188c
Showing 8 changed files with 108 additions and 45 deletions.
@@ -256,7 +256,7 @@ object FlintSpark {
    * Configure which source table the index is based on.
    *
    * @param tableName
-   *   source table name
+   *   full table name
    * @return
    *   index builder
    */
@@ -32,7 +32,7 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan]
             Some(table),
             false))
           if hasNoDisjunction(condition) && !location.isInstanceOf[FlintSparkSkippingFileIndex] =>
-        val indexName = getSkippingIndexName(table.identifier.table) // TODO: database name
+        val indexName = getSkippingIndexName(table.identifier.unquotedString)
         val index = flint.describeIndex(indexName)
         if (index.exists(_.kind == SKIPPING_INDEX_TYPE)) {
           val skippingIndex = index.get.asInstanceOf[FlintSparkSkippingIndex]
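
The rewrite rule above can rely on unquotedString because Spark's TableIdentifier renders as "database.table" whenever the database part is present, so the derived index name is now database-qualified. A quick illustrative check (assumed Spark behavior, not part of this commit):

    import org.apache.spark.sql.catalyst.TableIdentifier

    val id = TableIdentifier("apply_skipping_index_test", Some("default"))
    // unquotedString joins the database and table parts with a dot,
    // which is exactly the format getSkippingIndexName now requires
    assert(id.unquotedString == "default.apply_skipping_index_test")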
@@ -108,9 +108,13 @@ object FlintSparkSkippingIndex {
    * for now.
    *
    * @param tableName
-   *   source table name
+   *   full table name
    * @return
    *   Flint skipping index name
    */
-  def getSkippingIndexName(tableName: String): String = s"flint_${tableName}_skipping_index"
+  def getSkippingIndexName(tableName: String): String = {
+    require(tableName.contains("."), "Full table name database.table is required")
+
+    s"flint_${tableName.replace(".", "_")}_skipping_index"
+  }
 }
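
To make the new naming contract concrete, a minimal sketch of getSkippingIndexName's behavior (expected values taken from the unit and integration tests later in this commit):

    import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName

    // A qualified "database.table" name maps to flint_<db>_<table>_skipping_index
    getSkippingIndexName("default.test") // "flint_default_test_skipping_index"

    // An unqualified name now fails fast on the require(...) check
    getSkippingIndexName("test") // throws IllegalArgumentException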
@@ -5,9 +5,10 @@

 package org.opensearch.flint.spark.sql

+import org.antlr.v4.runtime.tree.RuleNode
 import org.opensearch.flint.spark.FlintSpark
 import org.opensearch.flint.spark.FlintSpark.RefreshMode
-import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET}
 import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._
@@ -28,7 +29,7 @@ class FlintSparkSqlAstBuilder extends FlintSparkSqlExtensionsBaseVisitor[Command]
       // Create skipping index
       val indexBuilder = flint
         .skippingIndex()
-        .onTable(ctx.tableName.getText)
+        .onTable(getFullTableName(flint, ctx.tableName))

       ctx.indexColTypeList().indexColType().forEach { colTypeCtx =>
         val colName = colTypeCtx.identifier().getText
@@ -43,7 +44,7 @@ class FlintSparkSqlAstBuilder extends FlintSparkSqlExtensionsBaseVisitor[Command]

       // Trigger auto refresh if enabled
       if (isAutoRefreshEnabled(ctx.propertyList())) {
-        val indexName = getSkippingIndexName(ctx.tableName.getText)
+        val indexName = getSkippingIndexName(flint, ctx.tableName)
         flint.refreshIndex(indexName, RefreshMode.INCREMENTAL)
       }
       Seq.empty
@@ -52,7 +53,7 @@ class FlintSparkSqlAstBuilder extends FlintSparkSqlExtensionsBaseVisitor[Command]
   override def visitRefreshSkippingIndexStatement(
       ctx: RefreshSkippingIndexStatementContext): Command =
     FlintSparkSqlCommand() { flint =>
-      val indexName = getSkippingIndexName(ctx.tableName.getText)
+      val indexName = getSkippingIndexName(flint, ctx.tableName)
       flint.refreshIndex(indexName, RefreshMode.FULL)
       Seq.empty
     }
@@ -65,7 +66,7 @@ class FlintSparkSqlAstBuilder extends FlintSparkSqlExtensionsBaseVisitor[Command]
       AttributeReference("skip_type", StringType, nullable = false)())

     FlintSparkSqlCommand(outputSchema) { flint =>
-      val indexName = getSkippingIndexName(ctx.tableName.getText)
+      val indexName = getSkippingIndexName(flint, ctx.tableName)
       flint
         .describeIndex(indexName)
         .map { case index: FlintSparkSkippingIndex =>
@@ -78,8 +79,7 @@ class FlintSparkSqlAstBuilder extends FlintSparkSqlExtensionsBaseVisitor[Command]

   override def visitDropSkippingIndexStatement(ctx: DropSkippingIndexStatementContext): Command =
     FlintSparkSqlCommand() { flint =>
-      val tableName = ctx.tableName.getText // TODO: handle schema name
-      val indexName = getSkippingIndexName(tableName)
+      val indexName = getSkippingIndexName(flint, ctx.tableName)
       flint.deleteIndex(indexName)
       Seq.empty
     }
@@ -99,6 +99,19 @@ class FlintSparkSqlAstBuilder extends FlintSparkSqlExtensionsBaseVisitor[Command]
     }
   }

+  private def getSkippingIndexName(flint: FlintSpark, tableNameCtx: RuleNode): String =
+    FlintSparkSkippingIndex.getSkippingIndexName(getFullTableName(flint, tableNameCtx))
+
+  private def getFullTableName(flint: FlintSpark, tableNameCtx: RuleNode): String = {
+    val tableName = tableNameCtx.getText
+    if (tableName.contains(".")) {
+      tableName
+    } else {
+      val db = flint.spark.catalog.currentDatabase
+      s"$db.$tableName"
+    }
+  }
+
   override def aggregateResult(aggregate: Command, nextResult: Command): Command =
     if (nextResult != null) nextResult else aggregate
 }
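
The qualification rule in getFullTableName is simple: a bare table name inherits the session's current database, while an already-qualified name passes through unchanged. A standalone sketch of that rule (the qualify helper below is illustrative only; the real method takes the ANTLR RuleNode from the parser):

    import org.apache.spark.sql.SparkSession

    // Mirrors getFullTableName's logic on a plain string
    def qualify(spark: SparkSession, tableName: String): String =
      if (tableName.contains(".")) tableName
      else s"${spark.catalog.currentDatabase}.$tableName"

    val spark = SparkSession.builder().appName("demo").master("local[1]").getOrCreate()
    assert(qualify(spark, "flint_sql_test") == "default.flint_sql_test")
    assert(qualify(spark, "sample.test") == "sample.test") // already qualified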
@@ -9,7 +9,7 @@ import org.mockito.ArgumentMatchers.any
 import org.mockito.Mockito.{doAnswer, when}
 import org.mockito.invocation.InvocationOnMock
 import org.opensearch.flint.spark.FlintSpark
-import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE}
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.SkippingKind
 import org.scalatest.matchers.{Matcher, MatchResult}
 import org.scalatest.matchers.should.Matchers
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 class ApplyFlintSparkSkippingIndexSuite extends SparkFunSuite with Matchers {

   /** Test table and index */
-  private val testTable = "apply_skipping_index_test"
+  private val testTable = "default.apply_skipping_index_test"
   private val testIndex = getSkippingIndexName(testTable)
   private val testSchema = StructType(
     Seq(
@@ -117,9 +117,9 @@ class ApplyFlintSparkSkippingIndexSuite extends SparkFunSuite with Matchers {
     private var relation: LogicalRelation = _
     private var plan: LogicalPlan = _

-    def withSourceTable(name: String, schema: StructType): AssertionHelper = {
+    def withSourceTable(fullname: String, schema: StructType): AssertionHelper = {
       val table = CatalogTable(
-        identifier = TableIdentifier(name),
+        identifier = TableIdentifier(fullname.split('.')(1), Some(fullname.split('.')(0))),
         tableType = CatalogTableType.EXTERNAL,
         storage = CatalogStorageFormat.empty,
         schema = null)
@@ -135,8 +135,11 @@ class ApplyFlintSparkSkippingIndexSuite extends SparkFunSuite with Matchers {
     }

     def withSkippingIndex(indexName: String, indexCols: String*): AssertionHelper = {
-      val skippingIndex =
-        new FlintSparkSkippingIndex(indexName, indexCols.map(FakeSkippingStrategy))
+      val skippingIndex = mock[FlintSparkSkippingIndex]
+      when(skippingIndex.kind).thenReturn(SKIPPING_INDEX_TYPE)
+      when(skippingIndex.name()).thenReturn(indexName)
+      when(skippingIndex.indexedColumns).thenReturn(indexCols.map(FakeSkippingStrategy))
+
       when(flint.describeIndex(any())).thenReturn(Some(skippingIndex))
       this
     }
@@ -19,24 +19,30 @@ import org.apache.spark.sql.functions.col
 class FlintSparkSkippingIndexSuite extends FlintSuite {

   test("get skipping index name") {
-    val index = new FlintSparkSkippingIndex("test", Seq(mock[FlintSparkSkippingStrategy]))
-    index.name() shouldBe "flint_test_skipping_index"
+    val index = new FlintSparkSkippingIndex("default.test", Seq(mock[FlintSparkSkippingStrategy]))
+    index.name() shouldBe "flint_default_test_skipping_index"
   }

   test("can build index building job with unique ID column") {
     val indexCol = mock[FlintSparkSkippingStrategy]
     when(indexCol.outputSchema()).thenReturn(Map("name" -> "string"))
     when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("name").expr)))
-    val index = new FlintSparkSkippingIndex("test", Seq(indexCol))
+    val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol))

     val df = spark.createDataFrame(Seq(("hello", 20))).toDF("name", "age")
     val indexDf = index.build(df)
     indexDf.schema.fieldNames should contain only ("name", FILE_PATH_COLUMN, ID_COLUMN)
   }

+  test("should fail if get index name without full table name") {
+    assertThrows[IllegalArgumentException] {
+      FlintSparkSkippingIndex.getSkippingIndexName("test")
+    }
+  }
+
   test("should fail if no indexed column given") {
     assertThrows[IllegalArgumentException] {
-      new FlintSparkSkippingIndex("test", Seq.empty)
+      new FlintSparkSkippingIndex("default.test", Seq.empty)
     }
   }
 }
@@ -38,7 +38,7 @@ class FlintSparkSkippingIndexITSuite
   }

   /** Test table and index name */
-  private val testTable = "test"
+  private val testTable = "default.test"
   private val testIndex = getSkippingIndexName(testTable)

   override def beforeAll(): Unit = {
@@ -91,7 +91,7 @@ class FlintSparkSkippingIndexITSuite
       .addMinMax("age")
       .create()

-    val indexName = s"flint_${testTable}_skipping_index"
+    val indexName = s"flint_default_test_skipping_index"
     val index = flint.describeIndex(indexName)
     index shouldBe defined
     index.get.metadata().getContent should matchJson(s"""{
@@ -118,7 +118,7 @@ class FlintSparkSkippingIndexITSuite
        |     "columnName": "age",
        |     "columnType": "int"
        |   }],
-       |   "source": "$testTable"
+       |   "source": "default.test"
        | },
        | "properties": {
        |   "year": {
@@ -152,11 +152,7 @@ class FlintSparkSkippingIndexITSuite
       .create()
     flint.refreshIndex(testIndex, FULL)

-    val indexData =
-      spark.read
-        .format(FLINT_DATASOURCE)
-        .options(openSearchOptions)
-        .load(testIndex)
+    val indexData = flint.queryIndex(testIndex)
     indexData.columns should not contain ID_COLUMN
   }

@@ -313,6 +309,25 @@ class FlintSparkSkippingIndexITSuite
         hasIndexFilter(col("MinMax_age_0") <= 25 && col("MinMax_age_1") >= 25))
   }

+  test("should rewrite applicable query with table name without database specified") {
+    flint
+      .skippingIndex()
+      .onTable(testTable)
+      .addPartitions("year")
+      .create()
+
+    // Table name without database name "default"
+    val query = sql(s"""
+        | SELECT name
+        | FROM test
+        | WHERE year = 2023
+        |""".stripMargin)
+
+    query.queryExecution.executedPlan should
+      useFlintSparkSkippingFileIndex(
+        hasIndexFilter(col("year") === 2023))
+  }
+
   test("should not rewrite original query if filtering condition has disjunction") {
     flint
       .skippingIndex()
@@ -28,7 +28,7 @@ class FlintSparkSqlITSuite
   private lazy val flint: FlintSpark = new FlintSpark(spark)

   /** Test table and index name */
-  private val testTable = "flint_sql_test"
+  private val testTable = "default.flint_sql_test"
   private val testIndex = getSkippingIndexName(testTable)

   override def beforeAll(): Unit = {
@@ -64,17 +64,6 @@ class FlintSparkSqlITSuite
       | """.stripMargin)
   }

-  protected override def beforeEach(): Unit = {
-    super.beforeEach()
-    flint
-      .skippingIndex()
-      .onTable(testTable)
-      .addPartitions("year")
-      .addValueSet("name")
-      .addMinMax("age")
-      .create()
-  }
-
   protected override def afterEach(): Unit = {
     super.afterEach()
     flint.deleteIndex(testIndex)
@@ -87,7 +76,6 @@ class FlintSparkSqlITSuite
   }

   test("create skipping index with auto refresh") {
-    flint.deleteIndex(testIndex)
     sql(s"""
       | CREATE SKIPPING INDEX ON $testTable
       | (
@@ -111,7 +99,6 @@ class FlintSparkSqlITSuite
   }

   test("create skipping index with manual refresh") {
-    flint.deleteIndex(testIndex)
     sql(s"""
       | CREATE SKIPPING INDEX ON $testTable
       | (
@@ -131,6 +118,14 @@ class FlintSparkSqlITSuite
   }

   test("describe skipping index") {
+    flint
+      .skippingIndex()
+      .onTable(testTable)
+      .addPartitions("year")
+      .addValueSet("name")
+      .addMinMax("age")
+      .create()
+
     val result = sql(s"DESC SKIPPING INDEX ON $testTable")

     checkAnswer(
@@ -141,14 +136,41 @@ class FlintSparkSqlITSuite
       Row("age", "int", "MIN_MAX")))
   }

-  test("should return empty if no skipping index to describe") {
-    flint.deleteIndex(testIndex)
+  test("create skipping index on table without database name") {
+    sql(s"""
+        | CREATE SKIPPING INDEX ON flint_sql_test
+        | (
+        |   year PARTITION,
+        |   name VALUE_SET,
+        |   age MIN_MAX
+        | )
+        | """.stripMargin)
+
+    flint.describeIndex(testIndex) shouldBe defined
+  }
+
+  test("create skipping index on table in other database") {
+    sql("CREATE SCHEMA sample")
+    sql("USE sample")
+    sql("CREATE TABLE test (name STRING) USING CSV")
+    sql("CREATE SKIPPING INDEX ON test (name VALUE_SET)")
+
+    flint.describeIndex("flint_sample_test_skipping_index") shouldBe defined
+  }
+
+  test("should return empty if no skipping index to describe") {
     val result = sql(s"DESC SKIPPING INDEX ON $testTable")

     checkAnswer(result, Seq.empty)
   }

   test("drop skipping index") {
+    flint
+      .skippingIndex()
+      .onTable(testTable)
+      .addPartitions("year")
+      .create()
+
     sql(s"DROP SKIPPING INDEX ON $testTable")

     flint.describeIndex(testIndex) shouldBe empty
