From 20cbe1dc986e41ead848b8cbfa3803a110603012 Mon Sep 17 00:00:00 2001 From: Luc Talatinian Date: Wed, 22 May 2024 16:27:08 -0400 Subject: [PATCH] add shape deserializer overrides --- .../HttpBindingProtocolGenerator.java | 27 ++++++++++++++++--- .../integration/RuntimeClientPlugin.java | 25 +++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java index a5c345028..ef1fb545c 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java @@ -19,7 +19,9 @@ import java.util.Collection; import java.util.Comparator; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.TreeSet; @@ -75,6 +77,7 @@ public abstract class HttpBindingProtocolGenerator implements ProtocolGenerator private final Set serializeDocumentBindingShapes = new TreeSet<>(); private final Set deserializeDocumentBindingShapes = new TreeSet<>(); private final Set deserializingErrorShapes = new TreeSet<>(); + private final Map deserializerOverrides = new HashMap<>(); /** * Creates a Http binding protocol generator. @@ -1082,6 +1085,13 @@ private String conditionallyBase64Encode( @Override public void generateResponseDeserializers(GenerationContext context) { + deserializerOverrides.putAll( + context.getIntegrations().stream() + .flatMap(it -> it.getClientPlugins(context.getModel(), context.getService()).stream()) + .flatMap(it -> it.getShapeDeserializers().entrySet().stream()) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) + ); + EventStreamIndex streamIndex = EventStreamIndex.of(context.getModel()); for (OperationShape operation : getHttpBindingOperations(context)) { @@ -1347,13 +1357,24 @@ private void writeHeaderDeserializerFunction( ) { writer.openBlock("if headerValues := response.Header.Values($S); len(headerValues) != 0 {", "}", binding.getLocationName(), () -> { - Shape targetShape = context.getModel().expectShape(memberShape.getTarget()); + var target = memberShape.getTarget(); + Shape targetShape = context.getModel().expectShape(target); String operand = "headerValues"; operand = writeHeaderValueAccessor(context, writer, targetShape, binding, operand); - String value = generateHttpHeaderValue(context, writer, memberShape, binding, - operand); + if (deserializerOverrides.containsKey(target)) { + writer.write(""" + deserOverride, err := $T($L) + if err != nil { + return err + } + v.$L = deserOverride + """, deserializerOverrides.get(target), operand, memberName); + return; + } + + var value = generateHttpHeaderValue(context, writer, memberShape, binding, operand); writer.write("v.$L = $L", memberName, CodegenUtils.getAsPointerIfPointable(context.getModel(), writer, GoPointableIndex.of(context.getModel()), memberShape, value)); diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/RuntimeClientPlugin.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/RuntimeClientPlugin.java index 94f499649..add1ccdae 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/RuntimeClientPlugin.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/RuntimeClientPlugin.java @@ -23,6 +23,7 @@ import java.util.Optional; import java.util.Set; import java.util.function.BiPredicate; +import software.amazon.smithy.codegen.core.Symbol; import software.amazon.smithy.go.codegen.GoWriter; import software.amazon.smithy.go.codegen.auth.AuthParameter; import software.amazon.smithy.go.codegen.auth.AuthParametersResolver; @@ -54,6 +55,7 @@ public final class RuntimeClientPlugin implements ToSmithyBuilder endpointBuiltinBindings; private final Map authSchemeDefinitions; + private final Map shapeDeserializers; private RuntimeClientPlugin(Builder builder) { operationPredicate = builder.operationPredicate; @@ -67,6 +69,7 @@ private RuntimeClientPlugin(Builder builder) { configFieldResolvers = builder.configFieldResolvers; endpointBuiltinBindings = builder.endpointBuiltinBindings; authSchemeDefinitions = builder.authSchemeDefinitions; + shapeDeserializers = builder.shapeDeserializers; } @FunctionalInterface @@ -130,6 +133,14 @@ public Map getAuthSchemeDefinitions() { return authSchemeDefinitions; } + /** + * Gets the registered shape deserializers. + * @return the deserializers. + */ + public Map getShapeDeserializers() { + return shapeDeserializers; + } + /** * Gets the optionally present middleware registrar object that resolves to middleware registering function. * @@ -242,6 +253,7 @@ public static final class Builder implements SmithyBuilder private Map endpointBuiltinBindings = new HashMap<>(); private MiddlewareRegistrar registerMiddleware; private Map authSchemeDefinitions = new HashMap<>(); + private Map shapeDeserializers = new HashMap<>(); @Override public RuntimeClientPlugin build() { @@ -496,5 +508,18 @@ public Builder addAuthSchemeDefinition(ShapeId schemeId, AuthSchemeDefinition de this.authSchemeDefinitions.put(schemeId, definition); return this; } + + /** + * Registers a codegen definition for a custom shape deserializer. This feature is currently only supported for + * overriding deserialization in HTTP bindings. + * @param id The shape id. + * @param deserializer The deserializer symbol. The written code MUST be a function which accepts the + * corresponding type for the shape and returns (*type, error) accordingly. + * @return Returns the builder. + */ + public Builder addShapeDeserializer(ShapeId id, Symbol deserializer) { + this.shapeDeserializers.put(id, deserializer); + return this; + } } }