Skip to content

Commit

Permalink
Add a special case for how S3 requests should be signed (#1605)
Browse files Browse the repository at this point in the history
* Add a special case for how S3 requests should be signed

* Changelog

* Re-enable tests for non-S3 signatures

* Avoid un-necessary flatten

* Regenerated

* jvm17 ordering

* Regen copyright headers

* Fix changelog
  • Loading branch information
Baccata authored Jan 2, 2025
1 parent f2639a8 commit 1d128fa
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 91 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Thank you!
# 0.18.28

* Better support for timestamps before Linux Epoch and trimming the Timestamp nanosecond part (see [#1623](https://github.com/disneystreaming/smithy4s/pull/1623))
* Adds a special for AWS request signing when S3 is being used (see see [#1605](https://github.com/disneystreaming/smithy4s/pull/1605))

# 0.18.27

Expand Down
176 changes: 105 additions & 71 deletions modules/aws-http4s/src/smithy4s/aws/internals/AwsSigning.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ import java.nio.charset.StandardCharsets
*/
private[aws] object AwsSigning {

// see https://raw.githubusercontent.com/awslabs/aws-sdk-kotlin/main/codegen/sdk/aws-models/s3.json
val S3 = "AmazonS3"

def middleware[F[_]: Concurrent](
awsEnvironment: AwsEnvironment[F]
): Endpoint.Middleware[Client[F]] = new Endpoint.Middleware[Client[F]] {
Expand Down Expand Up @@ -88,6 +91,12 @@ private[aws] object AwsSigning {
credentials: F[AwsCredentials],
region: F[AwsRegion]
): Request[F] => F[Request[F]] = {

// S3 has special rules, in that it expects the X-Amz-Content-SHA256 to be set.
val preSign: PreSigner[F] =
if (serviceName == S3) new PreSigner.S3InMemorySigned[F]
else new PreSigner.Standard[F]

val contentType = org.http4s.headers.`Content-Type`.headerInstance
val `Content-Type` = contentType.name

Expand All @@ -106,79 +115,78 @@ private[aws] object AwsSigning {
}

// scalafmt: { align.preset = most, danglingParentheses.preset = false, maxColumn = 240, align.tokens = [{code = ":"}]}
(request: Request[F]) => {

val bodyF = request.body.chunks.compile.to(Chunk).map(_.flatten)
val awsHeadersF = (bodyF, timestamp, credentials, region).mapN { case (body, timestamp, credentials, region) =>
val credentialsScope = s"${timestamp.conciseDate}/$region/$endpointPrefix/aws4_request"
val queryParams: Vector[(String, String)] =
request.uri.query.toVector.sorted.map { case (k, v) => k -> v.getOrElse("") }
val canonicalQueryString =
if (queryParams.isEmpty) ""
else
queryParams
.map { case (k, v) =>
URLEncoder.encode(k, StandardCharsets.UTF_8.name()) + "=" + URLEncoder.encode(v, StandardCharsets.UTF_8.name())
}
.mkString("&")

// // !\ Important: these must remain in the same order
val baseHeadersList = List(
`Content-Type` -> request.contentType.map(contentType.value(_)).orNull,
`Host` -> request.uri.host.map(_.renderString).orNull,
`X-Amz-Date` -> timestamp.conciseDateTime,
`X-Amz-Security-Token` -> credentials.sessionToken.orNull,
`X-Amz-Target` -> (serviceName + "." + operationName)
).filterNot(_._2 == null)

val canonicalHeadersString = baseHeadersList
.map { case (key, value) =>
key.toString.toLowerCase + ":" + value.trim
}
.mkString(newline)
lazy val signedHeadersString = baseHeadersList.map(_._1).map(_.toString.toLowerCase()).mkString(";")

val payloadHash = sha256HexDigest(body.toArray)
val pathString = request.uri.path.toAbsolute.renderString
val canonicalRequest = new StringBuilder()
.append(request.method.name.toUpperCase())
.append(newline)
.append(pathString)
.append(newline)
.append(canonicalQueryString)
.append(newline)
.append(canonicalHeadersString)
.append(newline)
.append(newline)
.append(signedHeadersString)
.append(newline)
.append(payloadHash)
.result()

val canonicalRequestHash = sha256HexDigest(canonicalRequest)
val signatureKey = getSignatureKey(
credentials.secretAccessKey,
timestamp.conciseDate,
region.value,
endpointPrefix
)
val stringToSign = List[String](
algorithm,
timestamp.conciseDateTime,
credentialsScope,
canonicalRequestHash
).mkString(newline)
val signature = toHexString(hmacSha256(stringToSign, signatureKey))
val authHeaderValue = s"${algorithm} Credential=${credentials.accessKeyId}/$credentialsScope, SignedHeaders=$signedHeadersString, Signature=$signature"
val authHeader = Headers("Authorization" -> authHeaderValue)
val baseHeaders = Headers(baseHeadersList.map { case (k, v) => Header.Raw(k, v) })
authHeader ++ baseHeaders
}
(request: Request[F]) =>
preSign(request).flatMap { case (payloadHash, preparedRequest) =>
val awsHeadersF = (timestamp, credentials, region).mapN { case (timestamp, credentials, region) =>
val credentialsScope = s"${timestamp.conciseDate}/$region/$endpointPrefix/aws4_request"
val queryParams: Vector[(String, String)] =
request.uri.query.toVector.sorted.map { case (k, v) => k -> v.getOrElse("") }
val canonicalQueryString =
if (queryParams.isEmpty) ""
else
queryParams
.map { case (k, v) =>
URLEncoder.encode(k, StandardCharsets.UTF_8.name()) + "=" + URLEncoder.encode(v, StandardCharsets.UTF_8.name())
}
.mkString("&")

// // !\ Important: these must remain in the same order
val baseHeadersList = List(
`Content-Type` -> preparedRequest.contentType.map(contentType.value(_)).orNull,
`Host` -> preparedRequest.uri.host.map(_.renderString).orNull,
`X-Amz-Content-SHA256` -> preparedRequest.headers.get(`X-Amz-Content-SHA256`).map(_.head.value).orNull,
`X-Amz-Date` -> timestamp.conciseDateTime,
`X-Amz-Security-Token` -> credentials.sessionToken.orNull,
`X-Amz-Target` -> (serviceName + "." + operationName)
).filterNot(_._2 == null)

val canonicalHeadersString = baseHeadersList
.map { case (key, value) =>
key.toString.toLowerCase + ":" + value.trim
}
.mkString(newline)
lazy val signedHeadersString = baseHeadersList.map(_._1).map(_.toString.toLowerCase()).mkString(";")

val pathString = preparedRequest.uri.path.toAbsolute.renderString
val canonicalRequest = new StringBuilder()
.append(request.method.name.toUpperCase())
.append(newline)
.append(pathString)
.append(newline)
.append(canonicalQueryString)
.append(newline)
.append(canonicalHeadersString)
.append(newline)
.append(newline)
.append(signedHeadersString)
.append(newline)
.append(payloadHash)
.result()

val canonicalRequestHash = sha256HexDigest(canonicalRequest)
val signatureKey = getSignatureKey(
credentials.secretAccessKey,
timestamp.conciseDate,
region.value,
endpointPrefix
)
val stringToSign = List[String](
algorithm,
timestamp.conciseDateTime,
credentialsScope,
canonicalRequestHash
).mkString(newline)
val signature = toHexString(hmacSha256(stringToSign, signatureKey))
val authHeaderValue = s"${algorithm} Credential=${credentials.accessKeyId}/$credentialsScope, SignedHeaders=$signedHeadersString, Signature=$signature"
val authHeader = Headers("Authorization" -> authHeaderValue)
val baseHeaders = Headers(baseHeadersList.map { case (k, v) => Header.Raw(k, v) })
authHeader ++ baseHeaders
}

awsHeadersF.map { headers =>
request.transformHeaders(_ ++ headers)
awsHeadersF.map { headers =>
preparedRequest.transformHeaders(_ ++ headers)
}
}
}
}

private val newline = System.lineSeparator()
Expand All @@ -187,5 +195,31 @@ private[aws] object AwsSigning {
private val `X-Amz-Security-Token` = CIString("X-Amz-Security-Token")
private val `X-Amz-Target` = CIString("X-Amz-Target")
private val algorithm = "AWS4-HMAC-SHA256"
private val `X-Amz-Content-SHA256` = CIString("X-Amz-Content-SHA256")

private sealed trait PreSigner[F[_]] {
def apply(request: Request[F]): F[(String, Request[F])]
}
private object PreSigner {
class Standard[F[_]](implicit F: Concurrent[F]) extends PreSigner[F] {
def apply(request: Request[F]): F[(String, Request[F])] = {
request.body.compile.to(Chunk).map { inMemoryBody =>
val payloadHash = sha256HexDigest(inMemoryBody.toArray)
val newRequest = request.withBodyStream(fs2.Stream.chunk(inMemoryBody))
(payloadHash, newRequest)
}
}
}

class S3InMemorySigned[F[_]](implicit F: Concurrent[F]) extends PreSigner[F] {
def apply(request: Request[F]): F[(String, Request[F])] = {
request.body.compile.to(Chunk).map { inMemoryBody =>
val payloadHash = sha256HexDigest(inMemoryBody.toArray)
val newRequest = request.withBodyStream(fs2.Stream.chunk(inMemoryBody)).transformHeaders(_.put(Header.Raw(`X-Amz-Content-SHA256`, payloadHash)))
(payloadHash, newRequest)
}
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ import java.time.Clock
import java.time.ZoneId
import scala.jdk.CollectionConverters._
import scala.jdk.OptionConverters._
import software.amazon.awssdk.auth.signer.AwsS3V4Signer
import software.amazon.awssdk.auth.signer.params.AwsS3V4SignerParams

/**
* This suite verifies our implementation of the AWS signature algorithm against
Expand Down Expand Up @@ -92,9 +94,9 @@ object AwsSignatureTest extends SimpleIOSuite with Checkers {
val genAwsRequest = for {
httpMethod <- Gen.oneOf(SdkHttpMethod.values().toList)
host <- Gen.identifier
path <- Gen.listOf(Gen.identifier).map(_.mkString("/"))
path <- Gen.listOfN(3, Gen.identifier).map(_.mkString("/"))
content <- Gen.asciiStr
queryParams <- Gen.listOf(Gen.zip(Gen.identifier, Gen.alphaNumStr))
queryParams <- Gen.listOfN(3, Gen.zip(Gen.identifier, Gen.alphaNumStr))
} yield {
val builder = SdkHttpFullRequest
.builder()
Expand All @@ -112,7 +114,7 @@ object AwsSignatureTest extends SimpleIOSuite with Checkers {
}

val gen: Gen[TestInput] = for {
serviceName <- Gen.identifier
serviceName <- Gen.oneOf(Gen.const("AmazonS3"), Gen.identifier)
operationName <- Gen.identifier
timestamp <- Gen.chooseNum(0L, 4102444800L).map(Timestamp.fromEpochSecond)
accessKeyId <- Gen.identifier
Expand Down Expand Up @@ -154,23 +156,53 @@ object AwsSignatureTest extends SimpleIOSuite with Checkers {
}

val region = Region.of(smithy4sRegion.value)
val signedAwsRequest = if (testInput.serviceName == AwsSigning.S3) {

val params = Aws4SignerParams
.builder()
.awsCredentials(creds)
.signingRegion(region)
.signingClockOverride(fixedClock)
.signingName(serviceName)
.build()

val awsSigner = Aws4Signer.create()
// Amending the AWS Request to force the AMZ target as it's added automatically
// by our implementation
val amendedAwsRequest = awsRequest
.toBuilder()
.appendHeader("X-Amz-Target", serviceName + "." + operationName)
.build()
val signedAwsRequest = awsSigner.sign(amendedAwsRequest, params)
val signerParams = AwsS3V4SignerParams
.builder()
.awsCredentials(creds)
.signingRegion(region)
.signingClockOverride(fixedClock)
.enablePayloadSigning(true)
.signingName(serviceName)
.build()

// yes, this is an S3-specific signer.
val awsSigner = AwsS3V4Signer.create()

// Amending the AWS Request to force the AMZ target as it's added automatically
// by our implementation
//
// The hardcoded "required" header value is understood by the S3 signer as a signal that the `X-Amz-Content-SHA256` header
// should be replaced by the hash of the request payload, and that the same hash should be used in the signature.
val amendedAwsRequest = awsRequest
.toBuilder()
.appendHeader("X-Amz-Target", serviceName + "." + operationName)
.appendHeader(
"X-Amz-Content-SHA256",
"required"
) // this is a magic addition that is understood by the S3 signer
.build()

awsSigner.sign(amendedAwsRequest, signerParams)
} else {
val params = Aws4SignerParams
.builder()
.awsCredentials(creds)
.signingRegion(region)
.signingClockOverride(fixedClock)
.signingName(serviceName)
.build()

val awsSigner = Aws4Signer.create()
// Amending the AWS Request to force the AMZ target as it's added automatically
// by our implementation
val amendedAwsRequest = awsRequest
.toBuilder()
.appendHeader("X-Amz-Target", serviceName + "." + operationName)
.build()
awsSigner.sign(amendedAwsRequest, params)
}

val smithy4sSigner = AwsSigning.signingFunction[IO](
serviceName,
Expand Down
3 changes: 2 additions & 1 deletion modules/core/src/smithy4s/Blob.scala
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,12 @@ sealed trait Blob {

final def ++(other: Blob) = concat(other)

override def equals(other: Any): Boolean =
override def equals(other: Any): Boolean = {
other match {
case otherBlob: Blob => sameBytesAs(otherBlob)
case _ => false
}
}

override def hashCode(): Int = {
import util.hashing.MurmurHash3
Expand Down

0 comments on commit 1d128fa

Please sign in to comment.