diff --git a/spring-cloud-aws-s3/src/main/java/io/awspring/cloud/s3/ObjectMetadata.java b/spring-cloud-aws-s3/src/main/java/io/awspring/cloud/s3/ObjectMetadata.java index 95a3a51ea..fe4132732 100644 --- a/spring-cloud-aws-s3/src/main/java/io/awspring/cloud/s3/ObjectMetadata.java +++ b/spring-cloud-aws-s3/src/main/java/io/awspring/cloud/s3/ObjectMetadata.java @@ -25,6 +25,7 @@ * Container for S3 Object Metadata. For information about each field look at {@link PutObjectRequest} Javadocs. * * @author Maciej Walkowiak + * @author Hardik Singh Behl * @since 3.0 */ public class ObjectMetadata { @@ -116,6 +117,9 @@ public class ObjectMetadata { @Nullable private final String checksumAlgorithm; + @Nullable + private final String contentMD5; + public static Builder builder() { return new Builder(); } @@ -130,7 +134,7 @@ public static Builder builder() { @Nullable String ssekmsKeyId, @Nullable String ssekmsEncryptionContext, @Nullable Boolean bucketKeyEnabled, @Nullable String requestPayer, @Nullable String tagging, @Nullable String objectLockMode, @Nullable Instant objectLockRetainUntilDate, @Nullable String objectLockLegalHoldStatus, - @Nullable String expectedBucketOwner, @Nullable String checksumAlgorithm) { + @Nullable String expectedBucketOwner, @Nullable String checksumAlgorithm, @Nullable String contentMD5) { this.acl = acl; this.cacheControl = cacheControl; this.contentDisposition = contentDisposition; @@ -160,6 +164,7 @@ public static Builder builder() { this.objectLockLegalHoldStatus = objectLockLegalHoldStatus; this.expectedBucketOwner = expectedBucketOwner; this.checksumAlgorithm = checksumAlgorithm; + this.contentMD5 = contentMD5; } void apply(PutObjectRequest.Builder builder) { @@ -250,6 +255,9 @@ void apply(PutObjectRequest.Builder builder) { if (checksumAlgorithm != null) { builder.checksumAlgorithm(checksumAlgorithm); } + if (contentMD5 != null) { + builder.contentMD5(contentMD5); + } } void apply(CreateMultipartUploadRequest.Builder builder) { @@ -523,6 +531,11 @@ public String getChecksumAlgorithm() { return checksumAlgorithm; } + @Nullable + public String getContentMD5() { + return contentMD5; + } + public static class Builder { private final Map metadata = new HashMap<>(); @@ -611,6 +624,9 @@ public static class Builder { @Nullable private String checksumAlgorithm; + @Nullable + private String contentMD5; + public Builder acl(@Nullable String acl) { this.acl = acl; return this; @@ -785,13 +801,18 @@ public Builder checksumAlgorithm(@Nullable ChecksumAlgorithm checksumAlgorithm) return checksumAlgorithm(checksumAlgorithm != null ? checksumAlgorithm.toString() : null); } + public Builder contentMD5(@Nullable String contentMD5) { + this.contentMD5 = contentMD5; + return this; + } + public ObjectMetadata build() { return new ObjectMetadata(acl, cacheControl, contentDisposition, contentEncoding, contentLanguage, contentType, contentLength, expires, grantFullControl, grantRead, grantReadACP, grantWriteACP, metadata, serverSideEncryption, storageClass, websiteRedirectLocation, sseCustomerAlgorithm, sseCustomerKey, sseCustomerKeyMD5, ssekmsKeyId, ssekmsEncryptionContext, bucketKeyEnabled, requestPayer, tagging, objectLockMode, objectLockRetainUntilDate, objectLockLegalHoldStatus, - expectedBucketOwner, checksumAlgorithm); + expectedBucketOwner, checksumAlgorithm, contentMD5); } } diff --git a/spring-cloud-aws-s3/src/test/java/io/awspring/cloud/s3/S3TemplateIntegrationTests.java b/spring-cloud-aws-s3/src/test/java/io/awspring/cloud/s3/S3TemplateIntegrationTests.java index 8194f3dea..3517132f7 100644 --- a/spring-cloud-aws-s3/src/test/java/io/awspring/cloud/s3/S3TemplateIntegrationTests.java +++ b/spring-cloud-aws-s3/src/test/java/io/awspring/cloud/s3/S3TemplateIntegrationTests.java @@ -20,12 +20,17 @@ import static org.assertj.core.api.Assertions.assertThatNoException; import com.fasterxml.jackson.databind.ObjectMapper; + +import net.bytebuddy.utility.RandomString; + import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.net.URL; import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; import java.time.Duration; +import java.util.Base64; import java.util.List; import org.apache.http.HttpEntity; import org.apache.http.HttpResponse; @@ -47,6 +52,7 @@ import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.ResponseInputStream; import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.http.HttpStatusCode; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.model.GetObjectResponse; @@ -62,6 +68,7 @@ * @author Maciej Walkowiak * @author Yuki Yoshida * @author Ziemowit Stolarczyk + * @author Hardik Singh Behl */ @Testcontainers class S3TemplateIntegrationTests { @@ -70,7 +77,7 @@ class S3TemplateIntegrationTests { @Container static LocalStackContainer localstack = new LocalStackContainer( - DockerImageName.parse("localstack/localstack:3.8.1")); + DockerImageName.parse("localstack/localstack:3.8.1")).withEnv("S3_SKIP_SIGNATURE_VALIDATION", "0"); private static S3Client client; @@ -268,7 +275,12 @@ void createsWorkingSignedGetURL() throws IOException { @Test void createsWorkingSignedPutURL() throws IOException { - ObjectMetadata metadata = ObjectMetadata.builder().metadata("testkey", "testvalue").build(); + String fileContent = RandomString.make(); + long contentLength = fileContent.length(); + String contentMD5 = calculateContentMD5(fileContent); + + ObjectMetadata metadata = ObjectMetadata.builder().metadata("testkey", "testvalue").contentLength(contentLength) + .contentMD5(contentMD5).build(); URL signedPutUrl = s3Template.createSignedPutURL(BUCKET_NAME, "file.txt", Duration.ofMinutes(1), metadata, "text/plain"); @@ -276,7 +288,8 @@ void createsWorkingSignedPutURL() throws IOException { HttpPut httpPut = new HttpPut(signedPutUrl.toString()); httpPut.setHeader("x-amz-meta-testkey", "testvalue"); httpPut.setHeader("Content-Type", "text/plain"); - HttpEntity body = new StringEntity("hello"); + httpPut.setHeader("Content-MD5", contentMD5); + HttpEntity body = new StringEntity(fileContent); httpPut.setEntity(body); HttpResponse response = httpClient.execute(httpPut); @@ -285,11 +298,36 @@ void createsWorkingSignedPutURL() throws IOException { HeadObjectResponse headObjectResponse = client .headObject(HeadObjectRequest.builder().bucket(BUCKET_NAME).key("file.txt").build()); - assertThat(headObjectResponse.contentLength()).isEqualTo(5); + assertThat(response.getStatusLine().getStatusCode()).isEqualTo(HttpStatusCode.OK); + assertThat(headObjectResponse.contentLength()).isEqualTo(contentLength); assertThat(headObjectResponse.metadata().containsKey("testkey")).isTrue(); assertThat(headObjectResponse.metadata().get("testkey")).isEqualTo("testvalue"); } + @Test + void signedPutURLFailsForNonMatchingSignature() throws IOException { + String fileContent = RandomString.make(); + long contentLength = fileContent.length(); + String contentMD5 = calculateContentMD5(fileContent); + String maliciousContent = RandomString.make(); + + ObjectMetadata metadata = ObjectMetadata.builder().contentLength(contentLength).contentMD5(contentMD5).build(); + URL signedPutUrl = s3Template.createSignedPutURL(BUCKET_NAME, "file.txt", Duration.ofMinutes(1), metadata, + "text/plain"); + + CloseableHttpClient httpClient = HttpClients.createDefault(); + HttpPut httpPut = new HttpPut(signedPutUrl.toString()); + httpPut.setHeader("Content-Type", "text/plain"); + httpPut.setHeader("Content-MD5", contentMD5); + HttpEntity body = new StringEntity(fileContent + maliciousContent); + httpPut.setEntity(body); + + HttpResponse response = httpClient.execute(httpPut); + httpClient.close(); + + assertThat(response.getStatusLine().getStatusCode()).isEqualTo(HttpStatusCode.FORBIDDEN); + } + private void bucketDoesNotExist(ListBucketsResponse r, String bucketName) { assertThat(r.buckets().stream().filter(b -> b.name().equals(bucketName)).findAny()).isEmpty(); } @@ -298,6 +336,17 @@ private void bucketExists(ListBucketsResponse r, String bucketName) { assertThat(r.buckets().stream().filter(b -> b.name().equals(bucketName)).findAny()).isPresent(); } + private String calculateContentMD5(String content) { + try { + MessageDigest md = MessageDigest.getInstance("MD5"); + byte[] contentBytes = content.getBytes(StandardCharsets.UTF_8); + byte[] mdBytes = md.digest(contentBytes); + return Base64.getEncoder().encodeToString(mdBytes); + } catch (Exception exception) { + throw new RuntimeException("Failed to calculate Content-MD5", exception); + } + } + static class Person { private String firstName; private String lastName;