diff --git a/src/main/java/net/snowflake/client/jdbc/CompressedStreamFactory.java b/src/main/java/net/snowflake/client/jdbc/CompressedStreamFactory.java index d678145c8..ebb376db9 100644 --- a/src/main/java/net/snowflake/client/jdbc/CompressedStreamFactory.java +++ b/src/main/java/net/snowflake/client/jdbc/CompressedStreamFactory.java @@ -7,7 +7,6 @@ import com.github.luben.zstd.ZstdInputStream; import java.io.IOException; import java.io.InputStream; -import java.io.PushbackInputStream; import java.util.zip.GZIPInputStream; import net.snowflake.common.core.SqlState; import org.apache.http.Header; @@ -16,16 +15,16 @@ class CompressedStreamFactory { private static final int STREAM_BUFFER_SIZE = MB; + /** + * Determine the format of the response, if it is not either plain text or gzip, raise an error. + */ public InputStream createBasedOnEncodingHeader(InputStream is, Header encoding) throws IOException, SnowflakeSQLException { - InputStream inputStream = is; // Determine the format of the response, if it is not - // either plain text or gzip, raise an error. if (encoding != null) { if (GZIP.name().equalsIgnoreCase(encoding.getValue())) { - /* specify buffer size for GZIPInputStream */ - inputStream = new GZIPInputStream(is, STREAM_BUFFER_SIZE); + return new GZIPInputStream(is, STREAM_BUFFER_SIZE); } else if (ZSTD.name().equalsIgnoreCase(encoding.getValue())) { - inputStream = new ZstdInputStream(is); + return new ZstdInputStream(is); } else { throw new SnowflakeSQLException( SqlState.INTERNAL_ERROR, @@ -33,22 +32,7 @@ public InputStream createBasedOnEncodingHeader(InputStream is, Header encoding) "Exception: unexpected compression got " + encoding.getValue()); } } else { - inputStream = detectGzipAndGetStream(is); - } - - return inputStream; - } - - private InputStream detectGzipAndGetStream(InputStream is) throws IOException { - PushbackInputStream pb = new PushbackInputStream(is, 2); - byte[] signature = new byte[2]; - int len = pb.read(signature); - pb.unread(signature, 0, len); - // https://tools.ietf.org/html/rfc1952 - if (signature[0] == (byte) 0x1f && signature[1] == (byte) 0x8b) { - return new GZIPInputStream(pb); - } else { - return pb; + return DefaultResultStreamProvider.detectGzipAndGetStream(is); } } } diff --git a/src/main/java/net/snowflake/client/jdbc/DefaultResultStreamProvider.java b/src/main/java/net/snowflake/client/jdbc/DefaultResultStreamProvider.java index cb4d2e0ed..e7a1e8a0c 100644 --- a/src/main/java/net/snowflake/client/jdbc/DefaultResultStreamProvider.java +++ b/src/main/java/net/snowflake/client/jdbc/DefaultResultStreamProvider.java @@ -2,8 +2,10 @@ import java.io.IOException; import java.io.InputStream; +import java.io.PushbackInputStream; import java.net.URISyntaxException; import java.util.Map; +import java.util.zip.GZIPInputStream; import net.snowflake.client.core.ExecTimeTelemetryData; import net.snowflake.client.core.HttpUtil; import net.snowflake.client.log.ArgSupplier; @@ -145,4 +147,17 @@ else if (context.getQrmk() != null) { response); return response; } + + public static InputStream detectGzipAndGetStream(InputStream is) throws IOException { + PushbackInputStream pb = new PushbackInputStream(is, 2); + byte[] signature = new byte[2]; + int len = pb.read(signature); + pb.unread(signature, 0, len); + // https://tools.ietf.org/html/rfc1952 + if (signature[0] == (byte) 0x1f && signature[1] == (byte) 0x8b) { + return new GZIPInputStream(pb); + } else { + return pb; + } + } } diff --git a/src/test/java/net/snowflake/client/jdbc/CompressedStreamFactoryTest.java b/src/test/java/net/snowflake/client/jdbc/CompressedStreamFactoryTest.java index e34528cc1..86eb5764a 100644 --- a/src/test/java/net/snowflake/client/jdbc/CompressedStreamFactoryTest.java +++ b/src/test/java/net/snowflake/client/jdbc/CompressedStreamFactoryTest.java @@ -1,6 +1,7 @@ package net.snowflake.client.jdbc; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import com.github.luben.zstd.ZstdInputStream; import com.github.luben.zstd.ZstdOutputStream; @@ -10,6 +11,7 @@ import java.nio.charset.StandardCharsets; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; +import org.apache.commons.io.IOUtils; import org.apache.http.Header; import org.apache.http.message.BasicHeader; import org.junit.Test; @@ -42,16 +44,8 @@ public void testDetectContentEncodingAndGetInputStream_Gzip() throws Exception { InputStream resultStream = factory.createBasedOnEncodingHeader(gzipStream, encodingHeader); // Decompress and validate the data matches original - ByteArrayOutputStream decompressedOutput = new ByteArrayOutputStream(); - byte[] buffer = new byte[1024]; - int bytesRead; - try (GZIPInputStream gzipInputStream = (GZIPInputStream) resultStream) { - while ((bytesRead = gzipInputStream.read(buffer)) != -1) { - decompressedOutput.write(buffer, 0, bytesRead); - } - } - String decompressedData = new String(decompressedOutput.toByteArray(), StandardCharsets.UTF_8); - + assertTrue(resultStream instanceof GZIPInputStream); + String decompressedData = IOUtils.toString(resultStream, StandardCharsets.UTF_8); assertEquals(originalData, decompressedData); } @@ -79,16 +73,8 @@ public void testDetectContentEncodingAndGetInputStream_Zstd() throws Exception { InputStream resultStream = factory.createBasedOnEncodingHeader(zstdStream, encodingHeader); // Decompress and validate the data matches original - ByteArrayOutputStream decompressedOutput = new ByteArrayOutputStream(); - byte[] buffer = new byte[1024]; - int bytesRead; - try (ZstdInputStream zstdInputStream = (ZstdInputStream) resultStream) { - while ((bytesRead = zstdInputStream.read(buffer)) != -1) { - decompressedOutput.write(buffer, 0, bytesRead); - } - } - String decompressedData = new String(decompressedOutput.toByteArray(), StandardCharsets.UTF_8); - + assertTrue(resultStream instanceof ZstdInputStream); + String decompressedData = IOUtils.toString(resultStream, StandardCharsets.UTF_8); assertEquals(originalData, decompressedData); } }