[FLINK-37180][table] Support running stateless PTFs #26076

Status: Open · wants to merge 1 commit into base: master
@@ -177,6 +177,14 @@ public List<FunctionTestStep> getSetupFunctionTestSteps() {
.collect(Collectors.toList());
}

+ /** Convenience method to avoid casting. It assumes that the order of steps is not important. */
Contributor commented:

I've usually seen adj + noun rather than noun + noun... not sure how correct it is...

Suggested change:
- /** Convenience method to avoid casting. It assumes that the order of steps is not important. */
+ /** Convenient method to avoid casting. It assumes that the order of steps is not important. */

+ public List<SqlTestStep> getSetupSqlTestSteps() {
+ return setupSteps.stream()
+ .filter(s -> s.getKind() == TestKind.SQL)
+ .map(SqlTestStep.class::cast)
+ .collect(Collectors.toList());
+ }

/** Convenience method to avoid casting. It assumes that the order of steps is not important. */
public List<TemporalFunctionTestStep> getSetupTemporalFunctionTestSteps() {
return setupSteps.stream()
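The added getSetupSqlTestSteps follows the same filter-and-cast pattern as the surrounding accessors. Purely as an aside (not part of this PR), the repetition could be collapsed into one generic helper; the sketch below assumes the TestStep/TestKind types visible in this diff and invents the helper name:

```java
/** Hypothetical generalization of the convenience methods above; imports as in the surrounding file. */
private <T extends TestStep> List<T> getSetupStepsOfKind(TestKind kind, Class<T> stepClass) {
    return setupSteps.stream()
            .filter(s -> s.getKind() == kind)
            .map(stepClass::cast)
            .collect(Collectors.toList());
}

// Usage, equivalent to getSetupSqlTestSteps():
// List<SqlTestStep> sqlSteps = getSetupStepsOfKind(TestKind.SQL, SqlTestStep.class);
```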
@@ -125,6 +125,11 @@ public enum ArgumentTrait {
* in a way that it can encode changes. In other words: choose a row type that exposes the
* {@link RowKind} change flag.
*
+ * <p>This trait is intended for advanced use cases. Please note that inputs are always
+ * insert-only in batch mode. Thus, if the PTF should produce the same results in both batch and
+ * streaming mode, results should be emitted based on watermarks and event-time. The trait
+ * {@link #PASS_COLUMNS_THROUGH} is not supported if this trait is declared.
+ *
* <p>Note: This trait is valid for {@link #TABLE_AS_ROW} and {@link #TABLE_AS_SET} arguments.
*/
SUPPORT_UPDATES(false, StaticArgumentTrait.SUPPORT_UPDATES);
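To make the trait concrete: a PTF that declares SUPPORT_UPDATES on its table argument could look roughly like the sketch below. This is illustrative only — the class, its logic, and its output are invented; ArgumentHint/ArgumentTrait are the annotation API this enum belongs to:

```java
import org.apache.flink.table.annotation.ArgumentHint;
import org.apache.flink.table.annotation.ArgumentTrait;
import org.apache.flink.table.functions.ProcessTableFunction;
import org.apache.flink.types.Row;

// Sketch: a PTF whose table argument may receive updating input.
public class ChangeTrackingFunction extends ProcessTableFunction<String> {
    public void eval(
            @ArgumentHint({ArgumentTrait.TABLE_AS_SET, ArgumentTrait.SUPPORT_UPDATES}) Row input) {
        // Rows now carry a RowKind change flag (+I, -U, +U, -D) that the
        // function can inspect and react to.
        collect(input.getKind().shortString() + ": " + input);
    }
}
```

As the new javadoc states, such a declaration cannot additionally use PASS_COLUMNS_THROUGH.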
@@ -35,7 +35,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
- import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
@@ -57,10 +56,15 @@
@Internal
public class SystemTypeInference {

- private static final List<StaticArgument> PROCESS_TABLE_FUNCTION_SYSTEM_ARGS =
+ public static final List<StaticArgument> PROCESS_TABLE_FUNCTION_SYSTEM_ARGS =
List.of(StaticArgument.scalar("uid", DataTypes.STRING(), true));

- /** Format of unique identifiers for {@link ProcessTableFunction}. */
+ /**
+ * Format of unique identifiers for {@link ProcessTableFunction}.
+ *
+ * <p>Leading digits are not allowed. This also prevents that a custom PTF uid can infer with
Contributor (Author) commented:

Suggested change:
- * <p>Leading digits are not allowed. This also prevents that a custom PTF uid can infer with
+ * <p>Leading digits are not allowed. This also prevents that a custom PTF uid can interfere with

+ * {@code ExecutionConfigOptions#TABLE_EXEC_UID_FORMAT}.
+ */
private static final Predicate<String> UID_FORMAT =
Pattern.compile("^[a-zA-Z_][a-zA-Z-_0-9]*$").asPredicate();

@@ -81,6 +85,7 @@ public static TypeInference of(FunctionKind functionKind, TypeInference origin)
return builder.build();
}

+ @SuppressWarnings("BooleanMethodIsAlwaysInverted")
public static boolean isValidUidForProcessTableFunction(String uid) {
Contributor commented:

nit: maybe invert it then?

return UID_FORMAT.test(uid);
}
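For readers unfamiliar with the pattern above, the behavior follows directly from the regex ^[a-zA-Z_][a-zA-Z-_0-9]*$ (the examples are illustrative, not part of the PR):

```java
SystemTypeInference.isValidUidForProcessTableFunction("dedup_1");  // true
SystemTypeInference.isValidUidForProcessTableFunction("my-ptf");   // true ('-' is allowed after the first character)
SystemTypeInference.isValidUidForProcessTableFunction("1dedup");   // false (leading digit)
SystemTypeInference.isValidUidForProcessTableFunction("my ptf");   // false (whitespace)
```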
@@ -114,6 +119,8 @@ private static void checkScalarArgsOnly(List<StaticArgument> defaultArgs) {
}

checkReservedArgs(declaredArgs);
+ checkMultipleTableArgs(declaredArgs);
+ checkUpdatingPassThroughColumns(declaredArgs);

final List<StaticArgument> newStaticArgs = new ArrayList<>(declaredArgs);
newStaticArgs.addAll(PROCESS_TABLE_FUNCTION_SYSTEM_ARGS);
@@ -135,6 +142,25 @@ private static void checkReservedArgs(List<StaticArgument> staticArgs) {
}
}

+ private static void checkMultipleTableArgs(List<StaticArgument> staticArgs) {
+ if (staticArgs.stream().filter(arg -> arg.is(StaticArgumentTrait.TABLE)).count() > 1) {
Contributor commented:
nit: would it make sense to add a limit, like this?

Suggested change:
- if (staticArgs.stream().filter(arg -> arg.is(StaticArgumentTrait.TABLE)).count() > 1) {
+ if (staticArgs.stream().filter(arg -> arg.is(StaticArgumentTrait.TABLE)).limit(2).count() > 1) {

+ throw new ValidationException(
+ "Currently, only signatures with at most one table argument are supported.");
+ }
+ }

+ private static void checkUpdatingPassThroughColumns(List<StaticArgument> staticArgs) {
+ final Set<StaticArgumentTrait> traits =
+ staticArgs.stream()
+ .flatMap(arg -> arg.getTraits().stream())
+ .collect(Collectors.toSet());
+ if (traits.contains(StaticArgumentTrait.SUPPORT_UPDATES)
+ && traits.contains(StaticArgumentTrait.PASS_COLUMNS_THROUGH)) {
+ throw new ValidationException(
+ "Signatures with updating inputs must not pass columns through.");
+ }
+ }
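To make the two new validations concrete, the sketches below show PTF signatures that each check rejects. The class names are invented; the traits and error messages come from this diff, and the annotation-based declaration style mirrors the PTF API:

```java
import org.apache.flink.table.annotation.ArgumentHint;
import org.apache.flink.table.annotation.ArgumentTrait;
import org.apache.flink.table.functions.ProcessTableFunction;
import org.apache.flink.types.Row;

// Rejected by checkMultipleTableArgs:
// "Currently, only signatures with at most one table argument are supported."
class TwoTableArgsFunction extends ProcessTableFunction<String> {
    public void eval(
            @ArgumentHint(ArgumentTrait.TABLE_AS_SET) Row left,
            @ArgumentHint(ArgumentTrait.TABLE_AS_SET) Row right) {}
}

// Rejected by checkUpdatingPassThroughColumns:
// "Signatures with updating inputs must not pass columns through."
class UpdatingPassThroughFunction extends ProcessTableFunction<String> {
    public void eval(
            @ArgumentHint({
                        ArgumentTrait.TABLE_AS_SET,
                        ArgumentTrait.SUPPORT_UPDATES,
                        ArgumentTrait.PASS_COLUMNS_THROUGH
                    })
                    Row input) {}
}
```

Note that both checks now run once against the declared static arguments during system type inference, instead of per call in inferInputTypes (see the removal of the CallContext-based checkMultipleTableArgs further down).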

private static InputTypeStrategy deriveSystemInputStrategy(
FunctionKind functionKind,
@Nullable List<StaticArgument> staticArgs,
@@ -288,7 +314,6 @@ public Optional<List<DataType>> inferInputTypes(
}

checkUidArg(callContext);
- checkMultipleTableArgs(callContext);
checkTableArgTraits(staticArgs, callContext);

return Optional.of(inferredDataTypes);
@@ -318,19 +343,6 @@ private static void checkUidArg(CallContext callContext) {
}
}

- private static void checkMultipleTableArgs(CallContext callContext) {
- final List<DataType> args = callContext.getArgumentDataTypes();
-
- final List<TableSemantics> tableSemantics =
- IntStream.range(0, args.size())
- .mapToObj(pos -> callContext.getTableSemantics(pos).orElse(null))
- .collect(Collectors.toList());
- if (tableSemantics.stream().filter(Objects::nonNull).count() > 1) {
- throw new ValidationException(
- "Currently, only signatures with at most one table argument are supported.");
- }
- }

private static void checkTableArgTraits(
List<StaticArgument> staticArgs, CallContext callContext) {
IntStream.range(0, staticArgs.size())
@@ -23,11 +23,11 @@

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
- import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlDataTypeSpec;
+ import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
@@ -157,54 +157,44 @@ private RexNode convertTableArgs(SqlRexContext cx, final SqlCall call) {
call.getOperator() instanceof SqlTableFunction,
"Only table functions can have set semantics arguments.");
final SqlOperator operator = call.getOperator();
- final RelDataType returnType = cx.getValidator().getValidatedNodeType(call);

final List<RexNode> rewrittenOperands = new ArrayList<>();
int tableInputCount = 0;
for (int pos = 0; pos < call.getOperandList().size(); pos++) {
final SqlNode operand = call.operand(pos);
if (operand.getKind() == SqlKind.SET_SEMANTICS_TABLE) {
+ final RelDataType tableType = cx.getValidator().getValidatedNodeType(operand);
final SqlBasicCall setSemanticsCall = (SqlBasicCall) operand;
final SqlNodeList partitionKeys = setSemanticsCall.operand(1);
final SqlNodeList orderKeys = setSemanticsCall.operand(2);
checkArgument(
orderKeys.isEmpty(), "Table functions do not support order keys yet.");
+ final int[] keys = getPartitionKeyIndices(tableType, partitionKeys);

rewrittenOperands.add(
- new RexTableArgCall(
- cx.getValidator().getValidatedNodeType(operand),
- tableInputCount++,
- getPartitionKeyIndices(cx, partitionKeys),
- new int[0]));
+ new RexTableArgCall(tableType, tableInputCount++, keys, new int[0]));
} else if (operand.isA(SqlKind.QUERY)) {
+ final RelDataType tableType = cx.getValidator().getValidatedNodeType(operand);
rewrittenOperands.add(
- new RexTableArgCall(
- cx.getValidator().getValidatedNodeType(operand),
- tableInputCount++,
- new int[0],
- new int[0]));
+ new RexTableArgCall(tableType, tableInputCount++, new int[0], new int[0]));
} else {
rewrittenOperands.add(cx.convertExpression(operand));
}
}

+ final RelDataType returnType = cx.getValidator().getValidatedNodeType(call);
return cx.getRexBuilder().makeCall(returnType, operator, rewrittenOperands);
}

- private static int[] getPartitionKeyIndices(SqlRexContext cx, SqlNodeList partitions) {
+ private static int[] getPartitionKeyIndices(RelDataType tableType, SqlNodeList partitions) {
+ // Due to incorrect scoping of SET_SEMANTICS_TABLE, we have to resolve identifiers manually.
+ // See FLINK-37211.
+ final List<String> tableColumns = tableType.getFieldNames();
final int[] result = new int[partitions.size()];
for (int i = 0; i < partitions.getList().size(); i++) {
- final RexNode expr = cx.convertExpression(partitions.get(i));
- result[i] = parseFieldIdx(expr);
+ final SqlIdentifier column = (SqlIdentifier) partitions.get(i);
+ result[i] = tableColumns.indexOf(column.getSimple());
}
return result;
}
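Because identifiers are now resolved by name against the validated table type (instead of via cx.convertExpression and RexInputRef), the index computation reduces to a plain list lookup. A self-contained illustration with ordinary Java collections standing in for the Calcite types:

```java
import java.util.Arrays;
import java.util.List;

public class PartitionKeyIndexDemo {
    public static void main(String[] args) {
        // Stand-ins for tableType.getFieldNames() and the PARTITION BY identifiers:
        List<String> tableColumns = List.of("user", "amount", "ts");
        List<String> partitionKeys = List.of("ts", "user");

        // Same lookup as getPartitionKeyIndices above:
        int[] indices = partitionKeys.stream().mapToInt(tableColumns::indexOf).toArray();
        System.out.println(Arrays.toString(indices)); // prints [2, 0]
    }
}
```

One caveat of indexOf: an unknown column yields -1 rather than an error; presumably the validator has already rejected unknown identifiers before this point.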

- private static int parseFieldIdx(RexNode e) {
- if (SqlKind.INPUT_REF == e.getKind()) {
- final RexInputRef ref = (RexInputRef) e;
- return ref.getIndex();
- }
- // should not happen
- throw new IllegalStateException("Unsupported partition key with type: " + e.getKind());
- }
}
@@ -127,8 +127,7 @@ public Optional<TableSemantics> getTableSemantics(int pos) {
return Optional.empty();
}
return Optional.of(
- CallBindingTableSemantics.create(
- argumentDataTypes.get(pos), staticArguments.get(pos), sqlNode));
+ SqlBindingTableSemantics.create(argumentDataTypes.get(pos), staticArg, sqlNode));
}

@Override
@@ -145,15 +144,15 @@ public Optional<DataType> getOutputDataType() {
// TableSemantics
// --------------------------------------------------------------------------------------------

- private static class CallBindingTableSemantics implements TableSemantics {
+ private static class SqlBindingTableSemantics implements TableSemantics {
Contributor commented:
nit: I usually don't have an opinion on names, so feel free to ignore this comment.

You have SqlCallContext, SqlCallBinding, shouldn't it be called SqlCallBindingTableSemantics then?


private final DataType dataType;
private final int[] partitionByColumns;

- public static CallBindingTableSemantics create(
+ public static SqlBindingTableSemantics create(
DataType tableDataType, StaticArgument staticArg, SqlNode sqlNode) {
checkNoOrderBy(sqlNode);
- return new CallBindingTableSemantics(
+ return new SqlBindingTableSemantics(
createDataType(tableDataType, staticArg),
createPartitionByColumns(tableDataType, sqlNode));
}
@@ -211,7 +210,7 @@ private static int[] createPartitionByColumns(DataType tableDataType, SqlNode sq
.toArray();
}

- private CallBindingTableSemantics(DataType dataType, int[] partitionByColumns) {
+ private SqlBindingTableSemantics(DataType dataType, int[] partitionByColumns) {
this.dataType = dataType;
this.partitionByColumns = partitionByColumns;
}