Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherrypick][PLUGIN-1816] Added fix for decimal issue not having rounding mode and made escape character configurable. #38

Open
wants to merge 2 commits into
base: release/1.1
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

<groupId>io.cdap.plugin</groupId>
<artifactId>snowflake-plugins</artifactId>
<version>1.1.3</version>
<version>1.1.4-SNAPSHOT</version>
<packaging>jar</packaging>
<name>Snowflake plugins</name>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.cdap.plugin.snowflake.common.client.SnowflakeFieldDescriptor;
import io.cdap.plugin.snowflake.common.exception.SchemaParseException;
import io.cdap.plugin.snowflake.source.batch.SnowflakeBatchSourceConfig;
import io.cdap.plugin.snowflake.source.batch.SnowflakeInputFormatProvider;
import io.cdap.plugin.snowflake.source.batch.SnowflakeSourceAccessor;
import java.io.IOException;
import java.sql.Types;
Expand Down Expand Up @@ -62,7 +63,8 @@ public static Schema getSchema(SnowflakeBatchSourceConfig config, FailureCollect
return getParsedSchema(config.getSchema());
}

SnowflakeSourceAccessor snowflakeSourceAccessor = new SnowflakeSourceAccessor(config);
SnowflakeSourceAccessor snowflakeSourceAccessor =
new SnowflakeSourceAccessor(config, SnowflakeInputFormatProvider.PROPERTY_DEFAULT_ESCAPE_CHAR);
return getSchema(snowflakeSourceAccessor, config.getSchema(), collector, config.getImportQuery());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package io.cdap.plugin.snowflake.source.batch;

import com.google.common.base.Strings;
import io.cdap.cdap.api.annotation.Description;
import io.cdap.cdap.api.annotation.Name;
import io.cdap.cdap.api.annotation.Plugin;
Expand All @@ -33,6 +34,7 @@
import io.cdap.plugin.snowflake.common.util.SchemaHelper;
import org.apache.hadoop.io.NullWritable;

import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -68,7 +70,11 @@ public void configurePipeline(PipelineConfigurer pipelineConfigurer) {
public void prepareRun(BatchSourceContext context) {
FailureCollector failureCollector = context.getFailureCollector();
config.validate(failureCollector);

Map<String, String> arguments = new HashMap<>(context.getArguments().asMap());
String escapeChar = arguments.containsKey(SnowflakeInputFormatProvider.PROPERTY_ESCAPE_CHAR) &&
!Strings.isNullOrEmpty(arguments.get(SnowflakeInputFormatProvider.PROPERTY_ESCAPE_CHAR))
? arguments.get(SnowflakeInputFormatProvider.PROPERTY_ESCAPE_CHAR)
: SnowflakeInputFormatProvider.PROPERTY_DEFAULT_ESCAPE_CHAR;
Schema schema = SchemaHelper.getSchema(config, failureCollector);
failureCollector.getOrThrowException();

Expand All @@ -81,7 +87,7 @@ public void prepareRun(BatchSourceContext context) {
.collect(Collectors.toList()));
}

context.setInput(Input.of(config.getReferenceName(), new SnowflakeInputFormatProvider(config)));
context.setInput(Input.of(config.getReferenceName(), new SnowflakeInputFormatProvider(config, escapeChar)));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ private SnowflakeSourceAccessor getSnowflakeAccessor(Configuration configuration
SnowflakeInputFormatProvider.PROPERTY_CONFIG_JSON);
SnowflakeBatchSourceConfig config = GSON.fromJson(
configJson, SnowflakeBatchSourceConfig.class);
return new SnowflakeSourceAccessor(config);
String escapeChar = configuration.get(SnowflakeInputFormatProvider.PROPERTY_ESCAPE_CHAR,
SnowflakeInputFormatProvider.PROPERTY_DEFAULT_ESCAPE_CHAR);
return new SnowflakeSourceAccessor(config, escapeChar);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,17 @@
public class SnowflakeInputFormatProvider implements InputFormatProvider {

public static final String PROPERTY_CONFIG_JSON = "cdap.snowflake.source.config";
public static final String PROPERTY_ESCAPE_CHAR = "cdap.snowflake.source.escape";

public static final String PROPERTY_DEFAULT_ESCAPE_CHAR = "\\";

private static final Gson GSON = new Gson();
private final Map<String, String> conf;

public SnowflakeInputFormatProvider(SnowflakeBatchSourceConfig config) {
public SnowflakeInputFormatProvider(SnowflakeBatchSourceConfig config, String escapeChar) {
  // Capture everything the input format needs at runtime as an immutable
  // property map: the serialized source config plus the CSV escape character.
  this.conf = ImmutableMap.of(
    PROPERTY_CONFIG_JSON, GSON.toJson(config),
    PROPERTY_ESCAPE_CHAR, escapeChar);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.slf4j.LoggerFactory;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalTime;
Expand Down Expand Up @@ -83,7 +84,8 @@ private Object convertValue(String fieldName, String value, Schema fieldSchema)
case TIME_MICROS:
return TimeUnit.NANOSECONDS.toMicros(LocalTime.parse(value).toNanoOfDay());
case DECIMAL:
return new BigDecimal(value).setScale(fieldSchema.getScale()).unscaledValue().toByteArray();
return new BigDecimal(value).setScale(fieldSchema.getScale(),
RoundingMode.HALF_EVEN).unscaledValue().toByteArray();
default:
throw new IllegalArgumentException(
String.format("Field '%s' is of unsupported type '%s'", fieldSchema.getDisplayName(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ public class SnowflakeSourceAccessor extends SnowflakeAccessor {
"OVERWRITE=TRUE HEADER=TRUE SINGLE=FALSE";
private static final String COMMAND_MAX_FILE_SIZE = " MAX_FILE_SIZE=%s";
private final SnowflakeBatchSourceConfig config;
private final char escapeChar;

public SnowflakeSourceAccessor(SnowflakeBatchSourceConfig config) {
/**
 * Creates an accessor for reading Snowflake query results staged as CSV files.
 *
 * @param config     source configuration (connection + query settings)
 * @param escapeChar escape character used when parsing the staged CSV; only
 *                   the first character is used. Falls back to the backslash
 *                   default when null or empty, instead of failing with
 *                   StringIndexOutOfBoundsException from charAt(0) — the value
 *                   may originate from a blank runtime argument.
 */
public SnowflakeSourceAccessor(SnowflakeBatchSourceConfig config, String escapeChar) {
  super(config);
  this.config = config;
  this.escapeChar = (escapeChar == null || escapeChar.isEmpty()) ? '\\' : escapeChar.charAt(0);
}

/**
Expand Down Expand Up @@ -116,7 +118,7 @@ public CSVReader buildCsvReader(String stageSplit) throws IOException {
InputStream downloadStream = connection.unwrap(SnowflakeConnection.class)
.downloadStream("@~", stageSplit, true);
InputStreamReader inputStreamReader = new InputStreamReader(downloadStream);
return new CSVReader(inputStreamReader);
return new CSVReader(inputStreamReader, ',', '"', escapeChar);
} catch (SQLException e) {
throw new IOException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import io.cdap.plugin.snowflake.Constants;
import io.cdap.plugin.snowflake.common.BaseSnowflakeTest;
import io.cdap.plugin.snowflake.source.batch.SnowflakeInputFormatProvider;
import io.cdap.plugin.snowflake.source.batch.SnowflakeSourceAccessor;
import org.junit.Assert;
import org.junit.Test;
Expand All @@ -44,7 +45,8 @@
*/
public class SnowflakeAccessorTest extends BaseSnowflakeTest {

private SnowflakeSourceAccessor snowflakeAccessor = new SnowflakeSourceAccessor(CONFIG);
private SnowflakeSourceAccessor snowflakeAccessor =
new SnowflakeSourceAccessor(CONFIG, SnowflakeInputFormatProvider.PROPERTY_DEFAULT_ESCAPE_CHAR);

@Test
public void testDescribeQuery() throws Exception {
Expand Down
Loading