Skip to content

Commit

Permalink
[Feature][Transform-V2][SQL] Support case when clause for SQL Transfo…
Browse files Browse the repository at this point in the history
…rm plugin (apache#5013)

Co-authored-by: javalover123 <[email protected]>
  • Loading branch information
javalover123 committed Aug 1, 2023
1 parent eb7378c commit 495f2c8
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,9 @@ public Object computeForValue(Expression expression, Object[] inputFields) {
}
if (expression instanceof CaseExpression) {
CaseExpression caseExpression = (CaseExpression) expression;
return executeCaseExpr(caseExpression, inputFields);
final Object value = executeCaseExpr(caseExpression, inputFields);
SeaTunnelDataType<?> type = zetaSQLType.getExpressionType(expression);
return SystemFunction.castAs(value, type);
}
if (expression instanceof BinaryExpression) {
return executeBinaryExpr((BinaryExpression) expression, inputFields);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.apache.seatunnel.common.exception.CommonErrorCode;
import org.apache.seatunnel.transform.exception.TransformException;

import org.apache.commons.collections4.CollectionUtils;

import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.CaseExpression;
import net.sf.jsqlparser.expression.CastExpression;
Expand All @@ -38,13 +40,15 @@
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.TimeKeyExpression;
import net.sf.jsqlparser.expression.WhenClause;
import net.sf.jsqlparser.expression.operators.arithmetic.Concat;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.schema.Column;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

public class ZetaSQLType {
public static final String DECIMAL = "DECIMAL";
Expand Down Expand Up @@ -157,8 +161,66 @@ public SeaTunnelDataType<?> getExpressionType(Expression expression) {
String.format("Unsupported SQL Expression: %s ", expression.toString()));
}

public SeaTunnelDataType<?> getMaxType(
SeaTunnelDataType<?> leftType, SeaTunnelDataType<?> rightType) {
if (leftType == null || rightType == null) {
return leftType != null ? leftType : rightType;
}
if (leftType.equals(rightType)) {
return leftType;
}
if (leftType.getSqlType() == SqlType.INT && rightType.getSqlType() == SqlType.INT) {
return BasicType.INT_TYPE;
}
if ((leftType.getSqlType() == SqlType.INT || leftType.getSqlType() == SqlType.BIGINT)
&& (rightType.getSqlType() == SqlType.INT
|| rightType.getSqlType() == SqlType.BIGINT)) {
return BasicType.LONG_TYPE;
}
if (leftType.getSqlType() == SqlType.DECIMAL || rightType.getSqlType() == SqlType.DECIMAL) {
int precision = 0;
int scale = 0;
if (leftType.getSqlType() == SqlType.DECIMAL) {
DecimalType decimalType = (DecimalType) leftType;
precision = decimalType.getPrecision();
scale = decimalType.getScale();
}
if (rightType.getSqlType() == SqlType.DECIMAL) {
DecimalType decimalType = (DecimalType) rightType;
precision = Math.max(decimalType.getPrecision(), precision);
scale = Math.max(decimalType.getScale(), scale);
}
return new DecimalType(precision, scale);
}
if ((leftType.getSqlType() == SqlType.FLOAT || leftType.getSqlType() == SqlType.DOUBLE)
|| (rightType.getSqlType() == SqlType.FLOAT
|| rightType.getSqlType() == SqlType.DOUBLE)) {
return BasicType.DOUBLE_TYPE;
}
throw new TransformException(
CommonErrorCode.UNSUPPORTED_OPERATION, leftType + " type not equals " + rightType);
}

public SeaTunnelDataType<?> getMaxType(List<SeaTunnelDataType<?>> types) {
if (CollectionUtils.isEmpty(types)) {
throw new TransformException(
CommonErrorCode.UNSUPPORTED_OPERATION, "getMaxType parameter is null");
}
SeaTunnelDataType<?> result = types.get(0);
for (int i = 0, j = types.size(); i < j; i++) {
result = getMaxType(result, types.get(i));
}
return result;
}

private SeaTunnelDataType<?> getCaseType(CaseExpression caseExpression) {
return getExpressionType(caseExpression.getElseExpression());
final List<SeaTunnelDataType<?>> types =
caseExpression.getWhenClauses().stream()
.map(WhenClause::getThenExpression)
.map(this::getExpressionType)
.collect(Collectors.toList());
types.add(getExpressionType(caseExpression.getElseExpression()));
return getMaxType(types);
}

private SeaTunnelDataType<?> getCastType(CastExpression castExpression) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.seatunnel.transform.sql.zeta.functions;

import org.apache.seatunnel.api.table.type.DecimalType;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.common.exception.CommonErrorCode;
import org.apache.seatunnel.transform.exception.TransformException;

Expand All @@ -25,6 +27,7 @@
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.util.ArrayList;
import java.util.List;

public class SystemFunction {
Expand Down Expand Up @@ -60,6 +63,18 @@ public static Object nullif(List<Object> args) {
return v1;
}

public static Object castAs(Object arg, SeaTunnelDataType<?> type) {
final ArrayList<Object> args = new ArrayList<>(4);
args.add(arg);
args.add(type.getSqlType().toString());
if (DecimalType.class.equals(type.getClass())) {
final DecimalType decimalType = (DecimalType) type;
args.add(decimalType.getPrecision());
args.add(decimalType.getScale());
}
return castAs(args);
}

public static Object castAs(List<Object> args) {
Object v1 = args.get(0);
String v2 = (String) args.get(1);
Expand Down

0 comments on commit 495f2c8

Please sign in to comment.