Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update test file format to support aggregate functions #736

Merged
merged 3 commits into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ repos:
- id: flake8
- repo: local
hooks:
- id: check-substrait-extensions
name: Check Substrait extensions
entry: pytest tests/test_extensions.py::test_read_substrait_extensions
- id: check-substrait-extensions_coverage
name: Check Substrait extensions and test coverage
entry: pytest tests/test_extensions.py::test_substrait_extension_coverage
language: python
pass_filenames: false

11 changes: 11 additions & 0 deletions grammar/FuncTestCaseLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Whitespace : [ \t\n\r]+ -> channel(HIDDEN) ;

TripleHash: '###';
SubstraitScalarTest: 'SUBSTRAIT_SCALAR_TEST';
SubstraitAggregateTest: 'SUBSTRAIT_AGGREGATE_TEST';
SubstraitInclude: 'SUBSTRAIT_INCLUDE';

FormatVersion
Expand All @@ -20,6 +21,7 @@ DescriptionLine
: '# ' ~[\r\n]* '\r'? '\n'
;

Define: 'DEFINE';
ErrorResult: '<!ERROR>';
UndefineResult: '<!UNDEFINED>';
Overflow: 'OVERFLOW';
Expand All @@ -29,6 +31,11 @@ Saturate: 'SATURATE';
Silent: 'SILENT';
TieToEven: 'TIE_TO_EVEN';
NaN: 'NAN';
AcceptNulls: 'ACCEPT_NULLS';
IgnoreNulls: 'IGNORE_NULLS';
NullHandling: 'NULL_HANDLING';
SpacesOnly: 'SPACES_ONLY';
Truncate: 'TRUNCATE';

IntegerLiteral
: [+-]? Int
Expand Down Expand Up @@ -102,3 +109,7 @@ NullLiteral: 'null';
StringLiteral
: '\'' ('\\' . | '\'\'' | ~['\\])* '\''
;

// Column-reference token: the literal text 'COL' immediately followed by
// whatever the Int rule (defined outside this hunk) matches.
// NOTE(review): ANTLR resolves equal-length lexer matches in favour of the
// rule declared first. If an identifier-style rule that also matches this
// text precedes ColumnName in the combined lexer, this token would never be
// produced - confirm the effective rule order.
ColumnName
: 'COL' Int
;
187 changes: 150 additions & 37 deletions grammar/FuncTestCaseParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ header
;

version
: TripleHash SubstraitScalarTest Colon FormatVersion
: TripleHash (SubstraitScalarTest | SubstraitAggregateTest) Colon FormatVersion
;

include
Expand All @@ -27,11 +27,12 @@ testGroupDescription
;

testCase
: functionName=Identifier OParen arguments CParen ( OBracket func_options CBracket )? Eq result
: functionName=identifier OParen arguments CParen ( OBracket func_options CBracket )? Eq result
;

testGroup
: testGroupDescription (testCase)+
: testGroupDescription (testCase)+ #scalarFuncTestGroup
| testGroupDescription (aggFuncTestCase)+ #aggregateFuncTestGroup
;

arguments
Expand All @@ -56,6 +57,64 @@ argument
| timestampTzArg
| intervalYearArg
| intervalDayArg
| listArg
;

// A single aggregate-function test case: an aggregate call, an optional
// bracketed func_options list, then Eq followed by the expected result.
aggFuncTestCase
: aggFuncCall ( OBracket func_options CBracket )? Eq result
;

// The three accepted shapes of an aggregate function call:
//   #multiArgAggregateFuncCall  - a tableData definition, then the function
//                                 name applied to qualifiedAggregateFuncArgs.
//   #compactAggregateFuncCall   - bare tableRows, then the function name
//                                 applied to aggregateFuncArgs.
//   #singleArgAggregateFuncCall - the function name applied to one
//                                 dataColumn, e.g. max((20, -3, 1)::i8).
// NOTE(review): the function-name label is spelled `funcName` in the first
// alternative but `functName` in the other two, so the generated context
// classes expose differently named fields. Renaming either spelling also
// requires updating any visitor/listener code that reads the label.
aggFuncCall
: tableData funcName=identifier OParen qualifiedAggregateFuncArgs CParen #multiArgAggregateFuncCall
| tableRows functName=identifier OParen aggregateFuncArgs CParen #compactAggregateFuncCall
| functName=identifier OParen dataColumn CParen #singleArgAggregateFuncCall
;

// Named table definition: the Define keyword, a table name, a parenthesized
// list of one or more dataType entries, Eq, and then the tableRows payload.
// The name is the raw Identifier token rather than the `identifier` rule, so
// the tokens listed in nonReserved (And, Or, Truncate) are not accepted as
// table names here.
tableData
: Define tableName=Identifier OParen dataType (Comma dataType)* CParen Eq tableRows
;

// Parenthesized, comma-separated list of columnValues groups; the list may
// be empty, i.e. "()" is accepted.
// NOTE(review): the rule is named "rows" but its elements are columnValues;
// whether each inner group is a row or a column is decided by the consumer
// of the parse tree - confirm against the visitor.
tableRows
: OParen (columnValues (Comma columnValues)*)? CParen
;

// A typed column of values: a columnValues group, DoubleColon, and the
// dataType that applies to the group, e.g. (20, -3, 1)::i8.
dataColumn
: columnValues DoubleColon dataType
;

// Parenthesized, comma-separated list of untyped literals. The inner list is
// optional, so an empty group "()" is valid (used for empty-input cases).
columnValues
: OParen (literal (Comma literal)*)? CParen
;

// An untyped literal value. Unlike the typed `*Arg` rules, no DoubleColon
// type suffix is attached here; the type is supplied by the enclosing
// construct (e.g. dataColumn or listArg).
literal
: NullLiteral
| numericLiteral
| BooleanLiteral
| StringLiteral
| DateLiteral
| TimeLiteral
| TimestampLiteral
| TimestampTzLiteral
| IntervalYearLiteral
| IntervalDayLiteral
;

// One or more comma-separated qualifiedAggregateFuncArg entries (the list
// cannot be empty).
qualifiedAggregateFuncArgs
: qualifiedAggregateFuncArg (Comma qualifiedAggregateFuncArg)*
;

// One or more comma-separated aggregateFuncArg entries (the list cannot be
// empty).
aggregateFuncArgs
: aggregateFuncArg (Comma aggregateFuncArg)*
;

// An argument for the table-definition call form: either a column reference
// qualified by a table name (Identifier Dot ColumnName) or an ordinary typed
// `argument`.
// The table name is the raw Identifier token, so tokens listed in
// nonReserved are not accepted in that position.
qualifiedAggregateFuncArg
: tableName=Identifier Dot ColumnName
| argument
;

// An argument for the compact call form: either an unqualified column
// reference carrying its own type (ColumnName DoubleColon dataType) or an
// ordinary typed `argument`.
aggregateFuncArg
: ColumnName DoubleColon dataType
| argument
;

numericLiteral
Expand All @@ -66,7 +125,7 @@ floatLiteral
: FloatLiteral | NaN
;

nullArg: NullLiteral DoubleColon datatype;
nullArg: NullLiteral DoubleColon dataType;

intArg: IntegerLiteral DoubleColon (I8 | I16 | I32 | I64);

Expand All @@ -77,11 +136,11 @@ decimalArg
;

booleanArg
: BooleanLiteral DoubleColon Bool
: BooleanLiteral DoubleColon booleanType
;

stringArg
: StringLiteral DoubleColon Str
: StringLiteral DoubleColon stringType
;

dateArg
Expand All @@ -93,19 +152,27 @@ timeArg
;

timestampArg
: TimestampLiteral DoubleColon Ts
: TimestampLiteral DoubleColon timestampType
;

timestampTzArg
: TimestampTzLiteral DoubleColon TsTZ
: TimestampTzLiteral DoubleColon timestampTZType
;

intervalYearArg
: IntervalYearLiteral DoubleColon IYear
: IntervalYearLiteral DoubleColon intervalYearType
;

intervalDayArg
: IntervalDayLiteral DoubleColon IDay
: IntervalDayLiteral DoubleColon intervalDayType
;

// A list-valued argument: a bracketed literalList, DoubleColon, and the
// listType describing it.
listArg
: literalList DoubleColon listType
;

// Bracketed (OBracket ... CBracket), comma-separated list of untyped
// literals; the list may be empty.
literalList
: OBracket (literal (Comma literal)*)? CBracket
;

intervalYearLiteral
Expand All @@ -126,53 +193,88 @@ timeInterval
| fractionalSeconds=IntegerLiteral FractionalSecondSuffix
;

datatype
dataType
: scalarType
| parameterizedType
;

scalarType
: Bool #Boolean
| I8 #i8
| I16 #i16
| I32 #i32
| I64 #i64
| FP32 #fp32
| FP64 #fp64
| Str #string
| Binary #binary
| Ts #timestamp
| TsTZ #timestampTz
| Date #date
| Time #time
| IDay #intervalDay
| IYear #intervalYear
| UUID #uuid
| UserDefined Identifier #userDefined
: booleanType #boolean
| I8 #i8
| I16 #i16
| I32 #i32
| I64 #i64
| FP32 #fp32
| FP64 #fp64
| stringType #string
| binaryType #binary
| timestampType #timestamp
| timestampTZType #timestampTz
| Date #date
| Time #time
| intervalDayType #intervalDay
| intervalYearType #intervalYear
| UUID #uuid
| UserDefined Identifier #userDefined
;

// Boolean type name: accepts either the Bool or the Boolean token (both
// defined in a lexer not shown here).
booleanType
: (Bool | Boolean)
;

// String type name: accepts either the Str or the String token (both
// defined in a lexer not shown here).
stringType
: (Str | String)
;

// Binary type name: accepts either the Binary or the VBin token (both
// defined in a lexer not shown here).
binaryType
: (Binary | VBin)
;

// Timestamp type name: accepts either the Ts or the Timestamp token (both
// defined in a lexer not shown here).
timestampType
: (Ts | Timestamp)
;

// Timestamp-with-time-zone type name: accepts either the TsTZ or the
// Timestamp_TZ token (both defined in a lexer not shown here).
timestampTZType
: (TsTZ | Timestamp_TZ)
;

// Year-interval type name: accepts either the IYear or the Interval_Year
// token (both defined in a lexer not shown here).
intervalYearType
: (IYear | Interval_Year)
;

// Day-interval type name: accepts either the IDay or the Interval_Day token
// (both defined in a lexer not shown here).
intervalDayType
: (IDay | Interval_Day)
;

fixedCharType
: FChar isnull=QMark? OAngleBracket len=numericParameter CAngleBracket #fixedChar
: (FChar | FixedChar) isnull=QMark? OAngleBracket len=numericParameter CAngleBracket #fixedChar
;

varCharType
: VChar isnull=QMark? OAngleBracket len=numericParameter CAngleBracket #varChar
: (VChar | VarChar) isnull=QMark? OAngleBracket len=numericParameter CAngleBracket #varChar
;

fixedBinaryType
: FBin isnull=QMark? OAngleBracket len=numericParameter CAngleBracket #fixedBinary
: (FBin | FixedBinary) isnull=QMark? OAngleBracket len=numericParameter CAngleBracket #fixedBinary
;

decimalType
: Dec isnull=QMark? (OAngleBracket precision=numericParameter Comma scale=numericParameter CAngleBracket)? #decimal
: (Dec | Decimal) isnull=QMark?
(OAngleBracket precision=numericParameter Comma scale=numericParameter CAngleBracket)? #decimal
;

precisionTimestampType
: PTs isnull=QMark? OAngleBracket precision=numericParameter CAngleBracket #precisionTimestamp
: (PTs | Precision_Timestamp) isnull=QMark?
OAngleBracket precision=numericParameter CAngleBracket #precisionTimestamp
;

precisionTimestampTZType
: PTsTZ isnull=QMark? OAngleBracket precision=numericParameter CAngleBracket #precisionTimestampTZ
: (PTsTZ | Precision_Timestamp_TZ) isnull=QMark?
OAngleBracket precision=numericParameter CAngleBracket #precisionTimestampTZ
;

// List type: the List token, an optional QMark captured as `isnull`, and the
// element type (any dataType, so lists may nest) between angle brackets.
listType
: List isnull=QMark? OAngleBracket elemType=dataType CAngleBracket #list
;

parameterizedType
Expand All @@ -185,7 +287,6 @@ parameterizedType
// TODO implement the rest of the parameterized types
// | Struct isnull='?'? Lt expr (Comma expr)* Gt #struct
// | NStruct isnull='?'? Lt Identifier expr (Comma Identifier expr)* Gt #nStruct
// | List isnull='?'? Lt expr Gt #list
// | Map isnull='?'? Lt key=expr Comma value=expr Gt #map
;

Expand All @@ -202,14 +303,26 @@ func_option
;

option_name
: Overflow | Rounding
: Overflow | Rounding | NullHandling | SpacesOnly
| Identifier
;

option_value
: Error | Saturate | Silent | TieToEven | NaN
: Error | Saturate | Silent | TieToEven | NaN | Truncate | AcceptNulls | IgnoreNulls
| BooleanLiteral
| NullLiteral
| Identifier
;

func_options
: func_option (Comma func_option)*
;

// Keyword tokens that may also be used where an identifier is expected (see
// the `identifier` rule, which accepts nonReserved or Identifier).
nonReserved // IMPORTANT: this rule must only contain tokens
: And | Or | Truncate
;

// Parser-level identifier: either a plain Identifier token or one of the
// keyword tokens listed in nonReserved. Rules that reference the Identifier
// token directly do not get the nonReserved fallback.
identifier
: nonReserved
| Identifier
;
18 changes: 18 additions & 0 deletions tests/cases/arithmetic/max.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
### SUBSTRAIT_AGGREGATE_TEST: v1.0
### SUBSTRAIT_INCLUDE: '/extensions/functions_arithmetic.yaml'

# basic: Basic examples without any special cases
max((20, -3, 1, -10, 0, 5)::i8) = 20::i8
max((-32768, 32767, 20000, -30000)::i16) = 32767::i16
max((-214748648, 214748647, 21470048, 4000000)::i32) = 214748647::i32
max((2000000000, -3217908979, 629000000, -100000000, 0, 987654321)::i64) = 2000000000::i64
max((2.5, 0, 5.0, -2.5, -7.5)::fp32) = 5.0::fp32
max((1.5e+308, 1.5e+10, -1.5e+8, -1.5e+7, -1.5e+70)::fp64) = 1.5e+308::fp64

# null_handling: Examples with null as input or output
max((Null, Null, Null)::i16) = Null::i16
max(()::i16) = Null::i16
max((2000000000, Null, 629000000, -100000000, Null, 987654321)::i64) = 2000000000::i64
max((Null, inf)::fp64) = inf::fp64
max((Null, -inf, -1.5e+8, -1.5e+7, -1.5e+70)::fp64) = -1.5e+7::fp64
max((1.5e+308, 1.5e+10, Null, -1.5e+7, Null)::fp64) = 1.5e+308::fp64
Loading
Loading