Skip to content

Commit

Permalink
InitConverter refactor (#1886)
Browse files Browse the repository at this point in the history
Moved `initConverter` static method from an interface to a newly created util class.
  • Loading branch information
sfc-gh-astachowski authored Sep 6, 2024
1 parent 21c1e32 commit 75f8087
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 212 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,8 @@
import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;
import java.util.Map;
import java.util.TimeZone;
import net.snowflake.client.core.DataConversionContext;
import net.snowflake.client.core.SFBaseSession;
import net.snowflake.client.core.SFException;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.jdbc.SnowflakeSQLException;
import net.snowflake.client.jdbc.SnowflakeSQLLoggedException;
import net.snowflake.client.jdbc.SnowflakeType;
import net.snowflake.common.core.SqlState;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.complex.FixedSizeListVector;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.types.Types;

/** Interface to convert from arrow vector values into java data types. */
public interface ArrowVectorConverter {
Expand Down Expand Up @@ -177,201 +163,4 @@ public interface ArrowVectorConverter {
* @param isUTC true or false value of whether NTZ timestamp should be set to UTC
*/
void setTreatNTZAsUTC(boolean isUTC);

/**
 * Given an arrow vector (a single column in a single record batch), return an arrow vector
 * converter. Note, the converter is built on top of the arrow vector, so that arrow data can be
 * converted back to java data.
 *
 * <p>Arrow converter mappings for Snowflake fixed-point numbers:
 *
 * <pre>
 * ---------------------------------------------------------------------
 * Max precision and scale   Converter
 * ---------------------------------------------------------------------
 * number(3,0)               {@link TinyIntToFixedConverter}
 * number(3,2)               {@link TinyIntToScaledFixedConverter}
 * number(5,0)               {@link SmallIntToFixedConverter}
 * number(5,4)               {@link SmallIntToScaledFixedConverter}
 * number(10,0)              {@link IntToFixedConverter}
 * number(10,9)              {@link IntToScaledFixedConverter}
 * number(19,0)              {@link BigIntToFixedConverter}
 * number(19,18)             {@link BigIntToScaledFixedConverter}
 * number(38,37)             {@link DecimalToScaledFixedConverter}
 * ---------------------------------------------------------------------
 * </pre>
 *
 * <p>The converter is selected from the Arrow minor type of the vector together with the
 * {@code logicalType} (and, for fixed-point numbers, {@code scale}) entries of the field's
 * metadata.
 *
 * @param vector an arrow vector
 * @param context data conversion context
 * @param session SFBaseSession for purposes of logging
 * @param idx the index of the vector in its batch
 * @return A converter on top of the vector
 * @throws SnowflakeSQLException if no converter matches the vector's Arrow type, logical type or
 *     field layout, or if the metadata required to choose one is missing
 */
static ArrowVectorConverter initConverter(
    ValueVector vector, DataConversionContext context, SFBaseSession session, int idx)
    throws SnowflakeSQLException {
  // arrow minor type
  Types.MinorType type = Types.getMinorTypeForArrowType(vector.getField().getType());

  // each column's metadata
  Map<String, String> customMeta = vector.getField().getMetadata();

  if (type == Types.MinorType.DECIMAL) {
    // Note: Decimal vector is different from others, it is resolved without looking at metadata
    return new DecimalToScaledFixedConverter(vector, idx, context);
  }

  // A missing "logicalType" entry (including entirely empty metadata) is reported through the
  // generic "Unexpected Arrow Field" error at the end of the method instead of surfacing as a
  // NullPointerException from SnowflakeType.valueOf(null).
  String logicalType = customMeta.get("logicalType");
  if (logicalType != null) {
    SnowflakeType st = SnowflakeType.valueOf(logicalType);
    switch (st) {
      case ANY:
      case CHAR:
      case TEXT:
      case VARIANT:
        return new VarCharConverter(vector, idx, context);

      case MAP:
        // falls back to the VARCHAR converter when the column is not a native arrow map
        if (vector instanceof MapVector) {
          return new MapConverter((MapVector) vector, idx, context);
        }
        return new VarCharConverter(vector, idx, context);

      case VECTOR:
        // NOTE(review): unlike MAP/ARRAY/OBJECT there is no instanceof guard here, so any other
        // vector class fails with a ClassCastException - confirm this is intended.
        return new VectorTypeConverter((FixedSizeListVector) vector, idx, context);

      case ARRAY:
        // falls back to the VARCHAR converter when the column is not a native arrow list
        if (vector instanceof ListVector) {
          return new ArrayConverter((ListVector) vector, idx, context);
        }
        return new VarCharConverter(vector, idx, context);

      case OBJECT:
        // falls back to the VARCHAR converter when the column is not a native arrow struct
        if (vector instanceof StructVector) {
          return new StructConverter((StructVector) vector, idx, context);
        }
        return new VarCharConverter(vector, idx, context);

      case BINARY:
        return new VarBinaryToBinaryConverter(vector, idx, context);

      case BOOLEAN:
        return new BitToBooleanConverter(vector, idx, context);

      case DATE:
        // the flag stays false when the context carries no session
        boolean getFormatDateWithTimeZone = false;
        if (context.getSession() != null) {
          getFormatDateWithTimeZone = context.getSession().getFormatDateWithTimezone();
        }
        return new DateConverter(vector, idx, context, getFormatDateWithTimeZone);

      case FIXED:
        String scaleStr = customMeta.get("scale");
        if (scaleStr == null) {
          // Without this guard Integer.parseInt(null) would raise an unchecked
          // NumberFormatException; report it like every other malformed field instead.
          throw new SnowflakeSQLLoggedException(
              session,
              ErrorCode.INTERNAL_ERROR.getMessageCode(),
              SqlState.INTERNAL_ERROR,
              "Unexpected Arrow Field for ",
              st.name());
        }
        int sfScale = Integer.parseInt(scaleStr);
        switch (type) {
          case TINYINT:
            if (sfScale == 0) {
              return new TinyIntToFixedConverter(vector, idx, context);
            }
            return new TinyIntToScaledFixedConverter(vector, idx, context, sfScale);
          case SMALLINT:
            if (sfScale == 0) {
              return new SmallIntToFixedConverter(vector, idx, context);
            }
            return new SmallIntToScaledFixedConverter(vector, idx, context, sfScale);
          case INT:
            if (sfScale == 0) {
              return new IntToFixedConverter(vector, idx, context);
            }
            return new IntToScaledFixedConverter(vector, idx, context, sfScale);
          case BIGINT:
            if (sfScale == 0) {
              return new BigIntToFixedConverter(vector, idx, context);
            }
            return new BigIntToScaledFixedConverter(vector, idx, context, sfScale);
        }
        // any other arrow type for FIXED ends up in the generic error after the outer switch
        break;

      case REAL:
        return new DoubleToRealConverter(vector, idx, context);

      case TIME:
        switch (type) {
          case INT:
            return new IntToTimeConverter(vector, idx, context);
          case BIGINT:
            return new BigIntToTimeConverter(vector, idx, context);
          default:
            throw new SnowflakeSQLLoggedException(
                session,
                ErrorCode.INTERNAL_ERROR.getMessageCode(),
                SqlState.INTERNAL_ERROR,
                "Unexpected Arrow Field for ",
                st.name());
        }

      case TIMESTAMP_LTZ:
        {
          int childCount = vector.getField().getChildren().size();
          if (childCount == 0) {
            // case when the scale of the timestamp is equal or smaller than millisecs since
            // epoch
            return new BigIntToTimestampLTZConverter(vector, idx, context);
          } else if (childCount == 2) {
            // case when the scale of the timestamp is larger than millisecs since epoch, e.g.,
            // nanosecs
            return new TwoFieldStructToTimestampLTZConverter(vector, idx, context);
          }
          throw new SnowflakeSQLLoggedException(
              session,
              ErrorCode.INTERNAL_ERROR.getMessageCode(),
              SqlState.INTERNAL_ERROR,
              "Unexpected Arrow Field for ",
              st.name());
        }

      case TIMESTAMP_NTZ:
        {
          int childCount = vector.getField().getChildren().size();
          if (childCount == 0) {
            // case when the scale of the timestamp is equal or smaller than 7
            return new BigIntToTimestampNTZConverter(vector, idx, context);
          } else if (childCount == 2) {
            // when the timestamp is represented as a two-field struct
            return new TwoFieldStructToTimestampNTZConverter(vector, idx, context);
          }
          throw new SnowflakeSQLLoggedException(
              session,
              ErrorCode.INTERNAL_ERROR.getMessageCode(),
              SqlState.INTERNAL_ERROR,
              "Unexpected Arrow Field for ",
              st.name());
        }

      case TIMESTAMP_TZ:
        {
          int childCount = vector.getField().getChildren().size();
          if (childCount == 2) {
            // case when the scale of the timestamp is equal or smaller than millisecs since
            // epoch
            return new TwoFieldStructToTimestampTZConverter(vector, idx, context);
          } else if (childCount == 3) {
            // case when the scale of the timestamp is larger than millisecs since epoch, e.g.,
            // nanosecs
            return new ThreeFieldStructToTimestampTZConverter(vector, idx, context);
          }
          // The logical type itself is valid here; it is the arrow field layout that is
          // unexpected, so the message matches the other timestamp branches.
          throw new SnowflakeSQLLoggedException(
              session,
              ErrorCode.INTERNAL_ERROR.getMessageCode(),
              SqlState.INTERNAL_ERROR,
              "Unexpected Arrow Field for ",
              st.name());
        }

      default:
        throw new SnowflakeSQLLoggedException(
            session,
            ErrorCode.INTERNAL_ERROR.getMessageCode(),
            SqlState.INTERNAL_ERROR,
            "Unexpected Arrow Field for ",
            st.name());
    }
  }

  // reached when no logical type is available or FIXED is backed by an unsupported arrow type
  throw new SnowflakeSQLLoggedException(
      session,
      ErrorCode.INTERNAL_ERROR.getMessageCode(),
      SqlState.INTERNAL_ERROR,
      "Unexpected Arrow Field for ",
      type.toString());
}
}
Loading

0 comments on commit 75f8087

Please sign in to comment.