diff --git a/aws/sdk/integration-tests/dynamodb/tests/test-error-classification.rs b/aws/sdk/integration-tests/dynamodb/tests/test-error-classification.rs index 244f20c871..0e6ed27cd2 100644 --- a/aws/sdk/integration-tests/dynamodb/tests/test-error-classification.rs +++ b/aws/sdk/integration-tests/dynamodb/tests/test-error-classification.rs @@ -62,6 +62,7 @@ async fn assert_error_not_transient(error: ReplayedEvent) { let client = Client::from_conf(config); let _item = client .get_item() + .table_name("arn:aws:dynamodb:us-east-2:333333333333:table/table_name") .key("foo", AttributeValue::Bool(true)) .send() .await diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt index 09ffd39864..114289c8b5 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt @@ -23,6 +23,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.config.confi import software.amazon.smithy.rust.codegen.client.smithy.generators.config.loadFromConfigBag import software.amazon.smithy.rust.codegen.client.smithy.generators.waiters.RustJmespathShapeTraversalGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.waiters.TraversalBinding +import software.amazon.smithy.rust.codegen.client.smithy.generators.waiters.TraversalContext import software.amazon.smithy.rust.codegen.client.smithy.generators.waiters.TraversedShape import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustType @@ -168,16 +169,18 @@ class EndpointParamsInterceptorGenerator( TraversedShape.from(model, operationShape.inputShape(model)), ), ), + TraversalContext(retainOption = false), ) when (pathTraversal.outputType) { is RustType.Vec -> { - rust(".$setterName($getterName(_input))") - } - - else -> { - rust(".$setterName($getterName(_input).cloned())") + if (pathTraversal.outputType.member is RustType.Reference) { + rust(".$setterName($getterName(_input).map(|v| v.into_iter().cloned().collect::>()))") + } else { + rust(".$setterName($getterName(_input))") + } } + else -> rust(".$setterName($getterName(_input).cloned())") } } @@ -211,6 +214,7 @@ class EndpointParamsInterceptorGenerator( TraversedShape.from(model, operationShape.inputShape(model)), ), ), + TraversalContext(retainOption = false), ) rust("// Generated from JMESPath Expression: $pathValue") diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointResolverGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointResolverGenerator.kt index 13b7df06ef..7a2676949c 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointResolverGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointResolverGenerator.kt @@ -154,6 +154,8 @@ internal class EndpointResolverGenerator( "clippy::comparison_to_empty", // we generate `if let Some(_) = ... { ... }` "clippy::redundant_pattern_matching", + // we generate `if (s.as_ref() as &str) == ("arn:") { ... }`, and `s` can be either `String` or `&str` + "clippy::useless_asref", ) private val context = Context(registry, runtimeConfig) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointTestGenerator.kt index 3e5f0ef335..5c1d758cde 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointTestGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointTestGenerator.kt @@ -18,7 +18,6 @@ import software.amazon.smithy.rulesengine.traits.ExpectedEndpoint import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.endpoint.Types import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName -import software.amazon.smithy.rust.codegen.client.smithy.generators.ClientInstantiator import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.docs import software.amazon.smithy.rust.codegen.core.rustlang.escape @@ -50,8 +49,6 @@ internal class EndpointTestGenerator( "capture_request" to RuntimeType.captureRequest(runtimeConfig), ) - private val instantiator = ClientInstantiator(codegenContext) - private fun EndpointTestCase.docs(): Writable { val self = this return writable { docs(self.documentation.orElse("no docs")) } @@ -134,7 +131,23 @@ internal class EndpointTestGenerator( value.values.map { member -> writable { rustTemplate( - "#{Document}::from(#{value:W})", + /* + * If we wrote "#{Document}::from(#{value:W})" here, we could encounter a + * compile error due to the following type mismatch: + * the trait `From>` is not implemented for `Vec` + * + * given the following method signature: + * fn resource_arn_list(mut self, value: impl Into<::std::vec::Vec<::std::string::String>>) + * + * with a call site like this: + * .resource_arn_list(vec![::aws_smithy_types::Document::from( + * "arn:aws:dynamodb:us-east-1:333333333333:table/table_name".to_string(), + * )]) + * + * For this reason we use `into()` instead to allow types that need to be converted + * to `Document` to continue working as before, and to support the above use case. + */ + "#{value:W}.into()", *codegenScope, "value" to generateValue(member), ) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGenerator.kt index 2c544a11f6..80bba425fd 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGenerator.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.client.smithy.endpoint.rulesgen import org.jetbrains.annotations.Contract import software.amazon.smithy.rulesengine.language.evaluation.type.BooleanType import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType +import software.amazon.smithy.rulesengine.language.evaluation.type.StringType import software.amazon.smithy.rulesengine.language.evaluation.type.Type import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression import software.amazon.smithy.rulesengine.language.syntax.expressions.ExpressionVisitor @@ -24,6 +25,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.util.PANIC +import java.lang.RuntimeException /** * Root expression generator. @@ -56,7 +58,18 @@ class ExpressionGenerator( else -> rust("${ref.name.rustName()}.to_owned()") } } else { - rust(ref.name.rustName()) + try { + when (ref.type()) { + // This ensures we obtain a `&str`, regardless of whether `ref.name.rustName()` returns a `String` or a `&str`. + // Typically, we don't know which type will be returned due to code generation. + is StringType -> rust("${ref.name.rustName()}.as_ref() as &str") + else -> rust(ref.name.rustName()) + } + } catch (_: RuntimeException) { + // Because Typechecking was never invoked upon calling `.type()` on Reference for an expression + // like "{ref}: rust". See `generateLiterals2` in ExprGeneratorTest. + rust(ref.name.rustName()) + } } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustJmespathShapeTraversalGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustJmespathShapeTraversalGenerator.kt index ce46aff229..a7aac0120e 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustJmespathShapeTraversalGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustJmespathShapeTraversalGenerator.kt @@ -36,14 +36,15 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.SafeNamer import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.asRef import software.amazon.smithy.rust.codegen.core.rustlang.plus import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter @@ -253,6 +254,18 @@ sealed class TraversalBinding { typealias TraversalBindings = List +/** + * Bag of metadata accessible from the generate* methods that can affect how the resulting Rust code should be generated. + * + * [retainOption] determines whether `Option`s are preserved in the context of a projected list. + * Specifically, when applying selectors (used in multi-select lists) to each entity on the left-hand side of a + * projection, we want the resulting map function to return `Option>>` (when `retainOption` is true) + * rather than `Option>` (when it is false). + * This distinction is crucial because the latter could incorrectly result in `None` if any of the selectors + * refer to a field with a `None` value. + */ +data class TraversalContext(val retainOption: Boolean) + /** * Indicates a feature that's part of the JmesPath spec, but that we explicitly decided * not to support in smithy-rs due to the complexity of code generating it for Rust. @@ -292,24 +305,25 @@ class RustJmespathShapeTraversalGenerator( fun generate( expr: JmespathExpression, bindings: TraversalBindings, + context: TraversalContext, ): GeneratedExpression { fun String.attachExpression() = this.substringBefore("\nExpression:") + "\nExpression: ${ExpressionSerializer().serialize(expr)}" try { val result = when (expr) { - is ComparatorExpression -> generateComparator(expr, bindings) - is FunctionExpression -> generateFunction(expr, bindings) - is FieldExpression -> generateField(expr, bindings) + is ComparatorExpression -> generateComparator(expr, bindings, context) + is FunctionExpression -> generateFunction(expr, bindings, context) + is FieldExpression -> generateField(expr, bindings, context) is LiteralExpression -> generateLiteral(expr) - is MultiSelectListExpression -> generateMultiSelectList(expr, bindings) - is AndExpression -> generateAnd(expr, bindings) - is OrExpression -> generateOr(expr, bindings) - is NotExpression -> generateNot(expr, bindings) - is ObjectProjectionExpression -> generateObjectProjection(expr, bindings) - is FilterProjectionExpression -> generateFilterProjection(expr, bindings) - is ProjectionExpression -> generateProjection(expr, bindings) - is Subexpression -> generateSubexpression(expr, bindings) + is MultiSelectListExpression -> generateMultiSelectList(expr, bindings, context) + is AndExpression -> generateAnd(expr, bindings, context) + is OrExpression -> generateOr(expr, bindings, context) + is NotExpression -> generateNot(expr, bindings, context) + is ObjectProjectionExpression -> generateObjectProjection(expr, bindings, context) + is FilterProjectionExpression -> generateFilterProjection(expr, bindings, context) + is ProjectionExpression -> generateProjection(expr, bindings, context) + is Subexpression -> generateSubexpression(expr, bindings, context) is CurrentExpression -> throw JmesPathTraversalCodegenBugException("current expression must be handled in each expression type that can have one") is ExpressionTypeExpression -> throw UnsupportedJmesPathException("Expression type expressions are not supported by smithy-rs") is IndexExpression -> throw UnsupportedJmesPathException("Index expressions are not supported by smithy-rs") @@ -338,15 +352,20 @@ class RustJmespathShapeTraversalGenerator( private fun generateComparator( expr: ComparatorExpression, bindings: TraversalBindings, + context: TraversalContext, ): GeneratedExpression { - val left = generate(expr.left, bindings) - val right = generate(expr.right, bindings) + // When applying a comparator to the left and right operands, both must be non-optional types. + // For this, we avoid retaining `Option` values, even when `generateComparator` is invoked + // further down the chain from a projection expression. + val left = generate(expr.left, bindings, context.copy(retainOption = false)) + val right = generate(expr.right, bindings, context.copy(retainOption = false)) return generateCompare(safeNamer, left, right, expr.comparator.toString()) } private fun generateFunction( expr: FunctionExpression, bindings: TraversalBindings, + context: TraversalContext, ): GeneratedExpression { val ident = safeNamer.safeName("_ret") return when (expr.name) { @@ -354,7 +373,7 @@ class RustJmespathShapeTraversalGenerator( if (expr.arguments.size != 1) { throw InvalidJmesPathTraversalException("Length function takes exactly one argument") } - val arg = generate(expr.arguments[0], bindings) + val arg = generate(expr.arguments[0], bindings, context) if (!arg.isArray() && !arg.isString()) { throw InvalidJmesPathTraversalException("Argument to `length` function must be a collection or string type") } @@ -374,14 +393,14 @@ class RustJmespathShapeTraversalGenerator( if (expr.arguments.size != 2) { throw InvalidJmesPathTraversalException("Contains function takes exactly two arguments") } - val left = generate(expr.arguments[0], bindings) + val left = generate(expr.arguments[0], bindings, context) if (!left.isArray() && !left.isString()) { throw InvalidJmesPathTraversalException("First argument to `contains` function must be a collection or string type") } if (expr.arguments[1].isLiteralNull()) { throw UnsupportedJmesPathException("Checking for null with `contains` is not supported in smithy-rs") } - val right = generate(expr.arguments[1], bindings) + val right = generate(expr.arguments[1], bindings, context) if (!right.isBool() && !right.isNumber() && !right.isString() && !right.isEnum()) { throw UnsupportedJmesPathException("Checking for anything other than booleans, numbers, strings, or enums in the `contains` function is not supported in smithy-rs") } @@ -415,7 +434,8 @@ class RustJmespathShapeTraversalGenerator( outputType = RustType.Reference( lifetime = null, - member = left.outputType.collectionValue(), + member = + left.outputType.collectionValue(), ), output = writable {}, ), @@ -435,7 +455,7 @@ class RustJmespathShapeTraversalGenerator( if (expr.arguments.size != 1) { throw InvalidJmesPathTraversalException("Keys function takes exactly one argument") } - val arg = generate(expr.arguments[0], bindings) + val arg = generate(expr.arguments[0], bindings, context) if (!arg.isObject()) { throw InvalidJmesPathTraversalException("Argument to `keys` function must be an object type") } @@ -446,8 +466,7 @@ class RustJmespathShapeTraversalGenerator( output = writable { arg.output(this) - val outputShape = arg.outputShape.shape - when (outputShape) { + when (val outputShape = arg.outputShape.shape) { is StructureShape -> { // Can't iterate a struct in Rust so source the keys from smithy val keys = @@ -473,6 +492,7 @@ class RustJmespathShapeTraversalGenerator( private fun generateField( expr: FieldExpression, bindings: TraversalBindings, + context: TraversalContext, ): GeneratedExpression { val globalBinding = bindings.find { it is TraversalBinding.Global } val namedBinding = bindings.find { it is TraversalBinding.Named && it.jmespathName == expr.name } @@ -495,21 +515,50 @@ class RustJmespathShapeTraversalGenerator( val targetSym = symbolProvider.toSymbol(target) val ident = safeNamer.safeName("_fld") - return GeneratedExpression( - identifier = ident, - outputShape = TraversedShape.from(model, target), - outputType = targetSym.rustType().asRef(), - output = - writable { - rust( - if (memberSym.isOptional()) { - "let $ident = ${globalBinding.rustName}.${memberSym.name}.as_ref()?;" - } else { - "let $ident = &${globalBinding.rustName}.${memberSym.name};" - }, - ) - }, - ) + if (context.retainOption) { + return GeneratedExpression( + identifier = ident, + outputShape = TraversedShape.from(model, target), + outputType = RustType.Option(targetSym.rustType().asRef()), + output = + writable { + rustTemplate( + if (globalBinding.rustName.startsWith("_fld")) { + if (memberSym.isOptional()) { + // This ensures that `ident` has a type with a single level of `Option`, rather than being + // doubly nested as `Option>`. + "let $ident = ${globalBinding.rustName}.and_then(|v| v.${memberSym.name}.as_ref());" + } else { + "let $ident = ${globalBinding.rustName}.map(|v| &v.${memberSym.name});" + } + } else { + if (memberSym.isOptional()) { + "let $ident = ${globalBinding.rustName}.${memberSym.name}.as_ref();" + } else { + "let $ident = #{Some}(&${globalBinding.rustName}.${memberSym.name});" + } + }, + *preludeScope, + ) + }, + ) + } else { + return GeneratedExpression( + identifier = ident, + outputShape = TraversedShape.from(model, target), + outputType = targetSym.rustType().asRef(), + output = + writable { + rust( + if (memberSym.isOptional()) { + "let $ident = ${globalBinding.rustName}.${memberSym.name}.as_ref()?;" + } else { + "let $ident = &${globalBinding.rustName}.${memberSym.name};" + }, + ) + }, + ) + } } else if (namedBinding != null || globalBinding != null) { throw InvalidJmesPathTraversalException("Cannot look up fields in non-struct shapes") } else { @@ -566,10 +615,11 @@ class RustJmespathShapeTraversalGenerator( private fun generateMultiSelectList( expr: MultiSelectListExpression, bindings: TraversalBindings, + context: TraversalContext, ): GeneratedExpression { val expressions = expr.expressions.map { subexpr -> - generate(subexpr, bindings) + generate(subexpr, bindings, context) } // If we wanted to support mixed-types, we would need to use tuples, add tuple support to RustType, // and update supported functions such as `contains` to operate on tuples. @@ -596,20 +646,23 @@ class RustJmespathShapeTraversalGenerator( private fun generateAnd( expr: AndExpression, bindings: TraversalBindings, - ): GeneratedExpression = generateBooleanOp(expr, "&&", bindings) + context: TraversalContext, + ): GeneratedExpression = generateBooleanOp(expr, "&&", bindings, context) private fun generateOr( expr: OrExpression, bindings: TraversalBindings, - ): GeneratedExpression = generateBooleanOp(expr, "||", bindings) + context: TraversalContext, + ): GeneratedExpression = generateBooleanOp(expr, "||", bindings, context) private fun generateBooleanOp( expr: BinaryExpression, op: String, bindings: TraversalBindings, + context: TraversalContext, ): GeneratedExpression { - val left = generate(expr.left, bindings) - val right = generate(expr.right, bindings) + val left = generate(expr.left, bindings, context) + val right = generate(expr.right, bindings, context) if (!left.isBool() || !right.isBool()) { throw UnsupportedJmesPathException("Applying the `$op` operation doesn't support non-boolean types in smithy-rs") } @@ -632,8 +685,9 @@ class RustJmespathShapeTraversalGenerator( private fun generateNot( expr: NotExpression, bindings: TraversalBindings, + context: TraversalContext, ): GeneratedExpression { - val inner = generate(expr.expression, bindings) + val inner = generate(expr.expression, bindings, context) if (!inner.isBool()) { throw UnsupportedJmesPathException("Negation of a non-boolean type is not supported by smithy-rs") } @@ -655,36 +709,35 @@ class RustJmespathShapeTraversalGenerator( private fun generateProjection( expr: ProjectionExpression, bindings: TraversalBindings, + context: TraversalContext, ): GeneratedExpression { val maybeFlatten = expr.left if (maybeFlatten is SliceExpression) { throw UnsupportedJmesPathException("Slice expressions are not supported by smithy-rs") } - if (maybeFlatten !is FlattenExpression) { - throw UnsupportedJmesPathException("Only projection expressions with flattens are supported by smithy-rs") + val left = + when (maybeFlatten) { + is FlattenExpression -> generate(maybeFlatten.expression, bindings, context) + else -> generate(expr.left, bindings, context) + } + + // Short-circuit in the case where the projection is unnecessary + if (left.isArray() && expr.right is CurrentExpression) { + return left } - val left = generate(maybeFlatten.expression, bindings) + val leftTarget = ( left.outputShape as? TraversedShape.Array ?: throw InvalidJmesPathTraversalException("Left side of the flatten projection MUST resolve to a list or set shape") ).member val leftTargetSym: Any = (leftTarget.shape?.let { symbolProvider.toSymbol(it) }) ?: left.outputType + val leftBinding = "_v" - // Short-circuit in the case where the projection is unnecessary - if (left.isArray() && expr.right is CurrentExpression) { - return left - } + val right = + generate(expr.right, listOf(TraversalBinding.Global(leftBinding, leftTarget)), context.copy(retainOption = true)) - val right = generate(expr.right, listOf(TraversalBinding.Global("v", leftTarget))) - - // If the right expression results in a collection type, then the resulting vec will need to get flattened. - // Otherwise, you'll get `Vec<&Vec>` instead of `Vec<&T>`, which causes later projections to fail to compile. - val (projectionType, flattenNeeded) = - when { - right.isArray() -> right.outputType.stripOuter() to true - else -> RustType.Vec(right.outputType.asRef()) to false - } + val (projectionType, flattenNeeded) = projectionType(right) return safeNamer.safeName("_prj").let { ident -> GeneratedExpression( @@ -696,19 +749,15 @@ class RustJmespathShapeTraversalGenerator( writable { rust("let $ident = ${left.identifier}.iter()") withBlock(".flat_map(|v| {", "})") { - rustBlockTemplate( - "fn map(v: &#{Left}) -> #{Option}<#{Right}>", - *preludeScope, - "Left" to leftTargetSym, - "Right" to right.outputType, - ) { - right.output(this) - rustTemplate("#{Some}(${right.identifier})", *preludeScope) - } + renderMapToProject(this, leftBinding, leftTargetSym, right) rust("map(v)") } if (flattenNeeded) { rust(".flatten()") + // Eliminate temporary `Option` introduced by `retainOption = true` above. + if (right.outputType.isCollectionOfOptions()) { + rust(".flatten()") + } } rustTemplate(".collect::<#{Vec}<_>>();", *preludeScope) }, @@ -719,35 +768,40 @@ class RustJmespathShapeTraversalGenerator( private fun generateFilterProjection( expr: FilterProjectionExpression, bindings: TraversalBindings, + context: TraversalContext, ): GeneratedExpression { - val left = generate(expr.left, bindings) + val left = generate(expr.left, bindings, context) if (!left.isArray()) { throw UnsupportedJmesPathException("Filter projections can only be done on lists or sets in smithy-rs") } val leftTarget = (left.outputShape as TraversedShape.Array).member val leftTargetSym = symbolProvider.toSymbol(leftTarget.shape) + val leftBinding = "_v" val right = if (expr.right is CurrentExpression) { left.copy( outputType = left.outputType.collectionValue().asRef(), + outputShape = leftTarget, output = writable {}, ) } else { - generate(expr.right, listOf(TraversalBinding.Global("_v", leftTarget))) + generate(expr.right, listOf(TraversalBinding.Global(leftBinding, leftTarget)), context.copy(retainOption = true)) } - val comparison = generate(expr.comparison, listOf(TraversalBinding.Global("_v", leftTarget))) + val comparison = generate(expr.comparison, listOf(TraversalBinding.Global("_v", leftTarget)), context) if (!comparison.isBool()) { throw InvalidJmesPathTraversalException("The filter expression comparison must result in a boolean") } + val (projectionType, flattenNeeded) = projectionType(right) + return safeNamer.safeName("_fprj").let { ident -> GeneratedExpression( identifier = ident, outputShape = TraversedShape.Array(null, right.outputShape), - outputType = RustType.Vec(right.outputType), + outputType = projectionType, output = left.output + writable { @@ -765,17 +819,16 @@ class RustJmespathShapeTraversalGenerator( } if (expr.right !is CurrentExpression) { withBlock(".flat_map({", "})") { - rustBlockTemplate( - "fn map(_v: &#{Left}) -> #{Option}<#{Right}>", - *preludeScope, - "Left" to leftTargetSym, - "Right" to right.outputType, - ) { - right.output(this) - rustTemplate("#{Some}(${right.identifier})", *preludeScope) - } + renderMapToProject(this, leftBinding, leftTargetSym, right) rust("map") } + if (flattenNeeded) { + rust(".flatten()") + // Eliminate temporary `Option` introduced by `retainOption = true` above. + if (right.outputType.isCollectionOfOptions()) { + rust(".flatten()") + } + } } rustTemplate(".collect::<#{Vec}<_>>();", *preludeScope) }, @@ -786,11 +839,12 @@ class RustJmespathShapeTraversalGenerator( private fun generateObjectProjection( expr: ObjectProjectionExpression, bindings: TraversalBindings, + context: TraversalContext, ): GeneratedExpression { if (expr.left is CurrentExpression) { throw UnsupportedJmesPathException("Object projection cannot be done on computed maps in smithy-rs") } - val left = generate(expr.left, bindings) + val left = generate(expr.left, bindings, context) if (!left.outputType.isMap()) { throw UnsupportedJmesPathException("Object projection is only supported on map types in smithy-rs") } @@ -800,41 +854,45 @@ class RustJmespathShapeTraversalGenerator( val leftTarget = model.expectShape((left.outputShape.shape as MapShape).value.target) val leftTargetSym = symbolProvider.toSymbol(leftTarget) + val leftBinding = "_v" val right = if (expr.right is CurrentExpression) { left.copy( - outputType = left.outputType.collectionValue().asRef(), + outputType = + left.outputType.collectionValue().asRef(), + outputShape = TraversedShape.from(model, leftTarget), output = writable {}, ) } else { - generate(expr.right, listOf(TraversalBinding.Global("_v", TraversedShape.from(model, leftTarget)))) + generate(expr.right, listOf(TraversalBinding.Global(leftBinding, TraversedShape.from(model, leftTarget))), context.copy(retainOption = true)) } + val (projectionType, flattenNeeded) = projectionType(right) + val ident = safeNamer.safeName("_oprj") return GeneratedExpression( identifier = ident, outputShape = TraversedShape.Array(null, right.outputShape), - outputType = RustType.Vec(right.outputType), + outputType = projectionType, output = left.output + writable { if (expr.right is CurrentExpression) { rustTemplate("let $ident = ${left.identifier}.values().collect::<#{Vec}<_>>();", *preludeScope) } else { - rustBlock("let $ident = ${left.identifier}.values().flat_map(") { - rustBlockTemplate( - "fn map(_v: &#{Left}) -> #{Option}<#{Right}>", - *preludeScope, - "Left" to leftTargetSym, - "Right" to right.outputType, - ) { - right.output(this) - rustTemplate("#{Some}(${right.identifier})", *preludeScope) - } + withBlock("let $ident = ${left.identifier}.values().flat_map({", "})") { + renderMapToProject(this, leftBinding, leftTargetSym, right) rust("map") } - rustTemplate(").collect::<#{Vec}<_>>();", *preludeScope) + if (flattenNeeded) { + rust(".flatten()") + if (right.outputType.isCollectionOfOptions()) { + // Eliminate temporary `Option` introduced by `retainOption = true` above. + rust(".flatten()") + } + } + rustTemplate(".collect::<#{Vec}<_>>();", *preludeScope) } }, ) @@ -843,9 +901,10 @@ class RustJmespathShapeTraversalGenerator( private fun generateSubexpression( expr: Subexpression, bindings: TraversalBindings, + context: TraversalContext, ): GeneratedExpression { - val left = generate(expr.left, bindings) - val right = generate(expr.right, listOf(TraversalBinding.Global(left.identifier, left.outputShape))) + val left = generate(expr.left, bindings, context) + val right = generate(expr.right, listOf(TraversalBinding.Global(left.identifier, left.outputShape)), context) return GeneratedExpression( identifier = right.identifier, outputShape = right.outputShape, @@ -899,6 +958,57 @@ internal fun generateCompare( } } +private fun renderMapToProject( + writer: RustWriter, + leftBinding: String, + leftTargetSym: Any, + right: GeneratedExpression, +) { + writer.apply { + Attribute.AllowClippyLetAndReturn.render(this) + rustBlockTemplate( + if (right.outputType is RustType.Option) { + "fn map($leftBinding: &#{Left}) -> #{Right}" + } else { + "fn map($leftBinding: &#{Left}) -> #{Option}<#{Right}>" + }, + *preludeScope, + "Left" to leftTargetSym, + "Right" to right.outputType, + ) { + right.output(this) + if (right.outputType is RustType.Option) { + rust(right.identifier) + } else { + rustTemplate("#{Some}(${right.identifier})", *preludeScope) + } + } + } +} + +/** + * This function takes the `GeneratedExpression` of a projection expression's right-hand side (RHS) + * and returns a pair: + * - A `RustType` representing the final evaluation of the projection expression. + * - A `Boolean` indicating whether the resulting vector needs to be flattened. + * Flattening ensures you get `Vec<&T>` instead of `Vec<&Vec>`, which would otherwise cause + * subsequent projections to fail to compile. + */ +private fun projectionType(right: GeneratedExpression) = + when { + right.isArray() && right.outputType is RustType.Vec -> { + // A case like `lists.structs[].[integer]` where RHS output type (`[integer]`) is `Vec>`, and we want Vec<&T> + RustType.Vec(right.outputType.member.stripOuter()) to true + } + right.isArray() && right.outputType is RustType.Option -> { + // A case like `maps.structs[].strings` where RHS (strings) output type (`[strings]`) is `Option<&Vec>`, and we want Vec<&T> + RustType.Vec(right.outputType.member.stripOuter().stripOuter().asRef()) to true + } + else -> { + RustType.Vec(right.outputType.stripOuter()) to false + } + } + private fun RustType.dereference(): RustType = if (this is RustType.Reference) { this.member.dereference() @@ -916,6 +1026,13 @@ private fun RustType.isNumber(): Boolean = this.dereference().let { it is RustTy private fun RustType.isDoubleReference(): Boolean = this is RustType.Reference && this.member is RustType.Reference +private fun RustType.isCollectionOfOptions(): Boolean = + try { + collectionValue() is RustType.Option + } catch (_: RuntimeException) { + false + } + private fun RustType.collectionValue(): RustType = when (this) { is RustType.Reference -> member.collectionValue() diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustWaiterMatcherGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustWaiterMatcherGenerator.kt index 23123eb6dc..b90f8e0c22 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustWaiterMatcherGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustWaiterMatcherGenerator.kt @@ -109,6 +109,7 @@ class RustWaiterMatcherGenerator( RustJmespathShapeTraversalGenerator(codegenContext).generate( pathExpression, listOf(TraversalBinding.Global("_output", TraversedShape.from(model, outputShape))), + TraversalContext(retainOption = false), ) generatePathTraversalMatcher( @@ -132,6 +133,7 @@ class RustWaiterMatcherGenerator( TraversalBinding.Named("input", "_input", TraversedShape.from(model, inputShape)), TraversalBinding.Named("output", "_output", TraversedShape.from(model, outputShape)), ), + TraversalContext(retainOption = false), ) generatePathTraversalMatcher( diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGeneratorTest.kt index 89e75fc36d..540dbc8755 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGeneratorTest.kt @@ -76,7 +76,7 @@ internal class ExprGeneratorTest { rust("assert_eq!(true, #W);", gen.generate(Expression.of(true))) rust("assert_eq!(false, #W);", gen.generate(Expression.of(false))) rust("""assert_eq!("blah", #W);""", gen.generate(Expression.of("blah"))) - rust("""assert_eq!("helloworld: rust", #W);""", gen.generate(Expression.of("{ref}: rust"))) + rust("""assert_eq!("helloworld: rust", #W);""", gen.generate(Expression.of("{extra}: rust"))) rustTemplate( """ let mut expected = std::collections::HashMap::new(); @@ -94,6 +94,6 @@ internal class ExprGeneratorTest { ), ), ) - } + }.compileAndTest(runClippy = true) } } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustJmespathShapeTraversalGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustJmespathShapeTraversalGeneratorTest.kt index 84b3b8798d..41297ecccb 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustJmespathShapeTraversalGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustJmespathShapeTraversalGeneratorTest.kt @@ -50,6 +50,8 @@ class RustJmespathShapeTraversalGeneratorTest { val entityMaps = symbolProvider.toSymbol(model.lookup("test#EntityMaps")) val enum = symbolProvider.toSymbol(model.lookup("test#Enum")) val struct = symbolProvider.toSymbol(model.lookup("test#Struct")) + val subStruct = symbolProvider.toSymbol(model.lookup("test#SubStruct")) + val traversalContext = TraversalContext(retainOption = false) val testInputDataFn: RuntimeType val testOutputDataFn: RuntimeType @@ -107,15 +109,33 @@ class RustJmespathShapeTraversalGeneratorTest { .enums(#{Enum}::Two) .int_enums(2) .structs(#{Struct}::builder() + .required_integer(1) .primitives(primitives.clone()) - .build()) + .strings("lists_structs1_strings1") + .sub_structs(#{SubStruct}::builder().sub_struct_primitives(primitives.clone()).build()) + .sub_structs(#{SubStruct}::builder().sub_struct_primitives( + #{EntityPrimitives}::builder() + .required_boolean(false) + .required_string("why") + .build() + .unwrap()) + .build()) + .build() + .unwrap()) + .structs(#{Struct}::builder() + .required_integer(2) + .integer(1) + .string("foo") + .build() + .unwrap()) .build()) .maps(#{EntityMaps}::builder() .strings("foo", "foo_oo") .strings("bar", "bar_ar") .booleans("foo", true) .booleans("bar", false) - .structs("foo", #{Struct}::builder().integer(5).build()) + .structs("foo", #{Struct}::builder().required_integer(2).integer(5).strings("maps_foo_struct_strings1").build().unwrap()) + .structs("bar", #{Struct}::builder().required_integer(3).primitives(primitives).integer(7).build().unwrap()) .build()) .build() } @@ -126,6 +146,7 @@ class RustJmespathShapeTraversalGeneratorTest { "EntityMaps" to entityMaps, "Enum" to enum, "Struct" to struct, + "SubStruct" to subStruct, ) } } @@ -150,7 +171,7 @@ class RustJmespathShapeTraversalGeneratorTest { else -> listOf(TraversalBinding.Global("_output", TraversedShape.from(model, outputShape))) } - val generated = generator.generate(parsed, bindings) + val generated = generator.generate(parsed, bindings, traversalContext) rustCrate.unitTest(testName) { rust("// jmespath: $expression") rust("// jmespath parsed: $parsed") @@ -194,7 +215,7 @@ class RustJmespathShapeTraversalGeneratorTest { else -> listOf(TraversalBinding.Global("_output", TraversedShape.from(model, outputShape))) } - generator.generate(parsed, bindings).output(RustWriter.forModule("unsupported")) + generator.generate(parsed, bindings, traversalContext).output(RustWriter.forModule("unsupported")) fail("expression '$expression' should have thrown InvalidJmesPathTraversalException") } catch (ex: InvalidJmesPathTraversalException) { ex.message shouldContain contains @@ -211,6 +232,7 @@ class RustJmespathShapeTraversalGeneratorTest { generator.generate( parsed, listOf(TraversalBinding.Global("_output", TraversedShape.from(model, outputShape))), + traversalContext, ).output(RustWriter.forModule("unsupported")) fail("expression '$expression' should have thrown UnsupportedJmesPathException") } catch (ex: UnsupportedJmesPathException) { @@ -248,6 +270,7 @@ class RustJmespathShapeTraversalGeneratorTest { filterProjections() booleanOperations() multiSelectLists() + projectionFollowedByMultiSelectLists() complexCombinationsOfFeatures() unsupported("&('foo')", "Expression type expressions") @@ -319,6 +342,35 @@ class RustJmespathShapeTraversalGeneratorTest { ) } + private fun TestCase.projectionFollowedByMultiSelectLists() { + fun test( + name: String, + expression: String, + assertions: RustWriter.() -> Unit, + ) = testCase("traverse_$name", expression, assertions) + + // Each struct in projection sets at least one of the selected fields, e.g. either `string` or `primitives.string` is `Some`. + test("wildcard_projection_followed_by_multiselectlists", "lists.structs[*].[string, primitives.string][]") { + rust("""assert_eq!(vec!["test", "foo"], result);""") + } + + // The `primitives` field is `None` in structs obtained via `lists.structs[?string == 'foo']` + test("filter_projection_followed_by_multiselectlists_empty", "lists.structs[?string == 'foo'].[primitives.string, primitives.requiredString][]") { + rust("assert!(result.is_empty());") + } + + // Unlike the previous, the `integer` field is set in a struct in the projection. + test("filter_projection_followed_by_multiselectlists", "lists.structs[?string == 'foo'].[integer, primitives.integer][]") { + rust("assert_eq!(vec![&1], result);") + } + + test("object_projection_followed_by_multiselectlists", "maps.structs.*.[integer, primitives.integer][]") { + rust("let mut result = result;") + rust("result.sort();") + rust("assert_eq!(vec![&4, &5, &7], result);") + } + } + private fun TestCase.flattenExpressions() { fun test( name: String, @@ -333,6 +385,10 @@ class RustJmespathShapeTraversalGeneratorTest { rust("assert_eq!(1, result.len());") rust("assert_eq!(\"test\", result[0]);") } + test("no_shortcircuit_continued", "lists.structs[].strings") { + rust("assert_eq!(1, result.len());") + rust("assert_eq!(\"lists_structs1_strings1\", result[0]);") + } test("nested_flattens", "lists.structs[].subStructs[].subStructPrimitives.string") { // it should compile } @@ -364,7 +420,7 @@ class RustJmespathShapeTraversalGeneratorTest { assertions: RustWriter.() -> Unit, ) = testCase("traverse_fn_$name", expression, assertions) - test("list_length", "length(lists.structs[])", simple("assert_eq!(1, result);")) + test("list_length", "length(lists.structs[])", simple("assert_eq!(2, result);")) test("string_length", "length(primitives.string)", simple("assert_eq!(4, result);")) test("string_contains_false", "contains(primitives.string, 'foo')", expectFalse) @@ -481,17 +537,28 @@ class RustJmespathShapeTraversalGeneratorTest { assertions: RustWriter.() -> Unit, ) = testCase("traverse_obj_projection_$name", expression, assertions) - test("traverse_obj_projection_simple", "maps.booleans.*") { + test("simple", "maps.booleans.*") { rust("assert_eq!(2, result.len());") // Order is non-deterministic because we're getting the values of a hash map rust("assert_eq!(1, result.iter().filter(|&&&b| b == true).count());") rust("assert_eq!(1, result.iter().filter(|&&&b| b == false).count());") } - test("traverse_obj_projection_continued", "maps.structs.*.integer") { - rust("assert_eq!(1, result.len());") - rust("assert_eq!(5, **result.get(0).unwrap());") + test("continued", "maps.structs.*.integer") { + rust("let mut result = result;") + rust("result.sort();") + rust("assert_eq!(vec![&5, &7], result);") + } + test("followed_by_optional_array", "maps.structs.*.strings") { + rust("assert_eq!(vec![\"maps_foo_struct_strings1\"], result);") + } + test("w_function", "length(maps.structs.*.strings) == `1`", expectTrue) + + // Derived from https://github.com/awslabs/aws-sdk-rust/blob/8848f51e58fead8d230a0c15f0434b2812825c38/aws-models/ses.json#L2985 + test("followed_by_required_field", "maps.structs.*.requiredInteger") { + rust("let mut result = result;") + rust("result.sort();") + rust("assert_eq!(vec![&2, &3], result);") } - test("traverse_obj_projection_complex", "length(maps.structs.*.strings) == `0`", expectTrue) unsupported("primitives.integer.*", "Object projection is only supported on map types") unsupported("lists.structs[?`true`].*", "Object projection cannot be done on computed maps") @@ -505,13 +572,13 @@ class RustJmespathShapeTraversalGeneratorTest { ) = testCase("traverse_filter_projection_$name", expression, assertions) test("boollit", "lists.structs[?`true`]") { - rust("assert_eq!(1, result.len());") + rust("assert_eq!(2, result.len());") } test("intcmp", "lists.structs[?primitives.integer > `0`]") { rust("assert_eq!(1, result.len());") } test("boollit_continued_empty", "lists.structs[?`true`].integer") { - rust("assert_eq!(0, result.len());") + rust("assert_eq!(1, result.len());") } test("boollit_continued", "lists.structs[?`true`].primitives.integer") { rust("assert_eq!(1, result.len());") @@ -580,6 +647,14 @@ class RustJmespathShapeTraversalGeneratorTest { "(length(lists.structs[?!(integer < `0`) && integer >= `0` || `false`]) == `5`) == contains(lists.integers, length(maps.structs.*.strings))", itCompiles, ) + + // Derived from https://github.com/awslabs/aws-sdk-rust/blob/8848f51e58fead8d230a0c15f0434b2812825c38/aws-models/auto-scaling.json#L4202 + // The first argument to `contains` evaluates to `Some([true])` since `length(...)` is 1 and `requiredInteger` in that struct is 1. + test( + "2", + "contains(lists.structs[].[length(subStructs[?subStructPrimitives.requiredString=='why']) >= requiredInteger][], `true`)", + expectTrue, + ) } private fun testModel() = @@ -660,6 +735,8 @@ class RustJmespathShapeTraversalGeneratorTest { } structure Struct { + @required + requiredInteger: Integer, primitives: EntityPrimitives, strings: StringList, integer: Integer,