Skip to content

Commit

Permalink
add sigv4a (#465)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucix-aws authored Nov 6, 2023
1 parent 561f4cd commit 9d77a17
Show file tree
Hide file tree
Showing 18 changed files with 386 additions and 192 deletions.
21 changes: 21 additions & 0 deletions auth/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,24 @@ type IdentityResolver interface {
type IdentityResolverOptions interface {
GetIdentityResolver(schemeID string) IdentityResolver
}

// AnonymousIdentity is a sentinel to indicate no identity.
type AnonymousIdentity struct{}

var _ Identity = (*AnonymousIdentity)(nil)

// Expiration returns the zero value for time, as anonymous identity never
// expires.
func (*AnonymousIdentity) Expiration() time.Time {
return time.Time{}
}

// AnonymousIdentityResolver returns AnonymousIdentity.
type AnonymousIdentityResolver struct{}

var _ IdentityResolver = (*AnonymousIdentityResolver)(nil)

// GetIdentity returns AnonymousIdentity.
func (*AnonymousIdentityResolver) GetIdentity(ctx context.Context, _ smithy.Properties) (Identity, error) {
return &AnonymousIdentity{}, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,8 @@ public static final class Http {
public static final Symbol Response = SmithyGoDependency.NET_HTTP.pointableSymbol("Response");
}
}

public static final class Path {
public static final Symbol Join = SmithyGoDependency.PATH.valueSymbol("Join");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,10 @@ private boolean isTargetDeprecated(Model model, MemberShape member) {
&& !Prelude.isPreludeShape(member.getTarget());
}

public void write(Writable w) {
write("$W", w);
}

@Override
public String toString() {
String contents = super.toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
import software.amazon.smithy.codegen.core.CodegenException;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.auth.GetIdentityMiddlewareGenerator;
import software.amazon.smithy.go.codegen.auth.ResolveAuthSchemeMiddlewareGenerator;
import software.amazon.smithy.go.codegen.auth.SignRequestMiddlewareGenerator;
import software.amazon.smithy.go.codegen.endpoints.EndpointParameterOperationBindingsGenerator;
import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar;
import software.amazon.smithy.go.codegen.integration.ProtocolGenerator;
Expand Down Expand Up @@ -204,8 +201,6 @@ private void generateAddOperationMiddleware() {
MiddlewareRegistrar middlewareRegistrar = runtimeClientPlugin.registerMiddleware().get();
Collection<Symbol> functionArguments = middlewareRegistrar.getFunctionArguments();

// TODO these functions do not all return err like they should. This should be fixed.
// TODO Must be fixed for all public functions.
if (middlewareRegistrar.getInlineRegisterMiddlewareStatement() != null) {
String registerStatement = String.format("if err = stack.%s",
middlewareRegistrar.getInlineRegisterMiddlewareStatement());
Expand All @@ -232,12 +227,6 @@ private void generateAddOperationMiddleware() {
}
});

writer.write("$W", GoWriter.ChainWritable.of(
ResolveAuthSchemeMiddlewareGenerator.generateAddMiddleware(operationSymbol.getName()),
SignRequestMiddlewareGenerator.generateAddMiddleware(),
GetIdentityMiddlewareGenerator.generateAddMiddleware()
).compose());

writer.write("return nil");
});
}
Expand All @@ -251,6 +240,12 @@ private void generateOperationProtocolMiddlewareAdders() {
}
writer.addUseImports(SmithyGoDependency.SMITHY_MIDDLEWARE);

// persist operation input to context for internal build/finalize middleware access
writer.write("""
if err := stack.Serialize.Add(&setOperationInputMiddleware{}, middleware.After); err != nil {
return err
}""");

// Add request serializer middleware
String serializerMiddlewareName = ProtocolGenerator.getSerializeMiddlewareName(
operation.getId(), service, protocolGenerator.getProtocolName());
Expand All @@ -262,6 +257,15 @@ private void generateOperationProtocolMiddlewareAdders() {
operation.getId(), service, protocolGenerator.getProtocolName());
writer.write("err = stack.Deserialize.Add(&$L{}, middleware.After)", deserializerMiddlewareName);
writer.write("if err != nil { return err }");

// FUTURE: retry middleware should be at the front of finalize, right now it's added by the SDK
writer.write("""
if err := addProtocolFinalizerMiddlewares(stack, options, $S); err != nil {
return $T("add protocol finalizers: %v", err)
}""",
operationSymbol.getName(),
GoStdlibTypes.Fmt.Errorf);
writer.write("");
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,26 @@
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.auth.AuthSchemeResolverGenerator;
import software.amazon.smithy.go.codegen.auth.GetIdentityMiddlewareGenerator;
import software.amazon.smithy.go.codegen.auth.ResolveAuthSchemeMiddlewareGenerator;
import software.amazon.smithy.go.codegen.auth.SignRequestMiddlewareGenerator;
import software.amazon.smithy.go.codegen.endpoints.EndpointMiddlewareGenerator;
import software.amazon.smithy.go.codegen.integration.AuthSchemeDefinition;
import software.amazon.smithy.go.codegen.integration.ClientMember;
import software.amazon.smithy.go.codegen.integration.ClientMemberResolver;
import software.amazon.smithy.go.codegen.integration.ConfigField;
import software.amazon.smithy.go.codegen.integration.ConfigFieldResolver;
import software.amazon.smithy.go.codegen.integration.GoIntegration;
import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin;
import software.amazon.smithy.go.codegen.integration.auth.AnonymousDefinition;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.knowledge.ServiceIndex;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.traits.synthetic.NoAuthTrait;
import software.amazon.smithy.utils.MapUtils;

/**
Expand Down Expand Up @@ -83,7 +90,7 @@ final class ServiceGenerator implements Runnable {

@Override
public void run() {
writer.write("$W", generate());
writer.write(generate());
writeProtocolResolverImpls();
}

Expand All @@ -93,7 +100,9 @@ private GoWriter.Writable generate() {
generateClient(),
generateNew(),
generateOptions(),
generateInvokeOperation()
generateInvokeOperation(),
generateInputContextFuncs(),
generateAddProtocolFinalizerMiddleware()
).compose();
}

Expand Down Expand Up @@ -172,11 +181,13 @@ func New(options $options:L, optFns ...func(*$options:L)) *$client:L {
fn(&options)
}
$finalizeResolvers:W
client := &$client:L{
options: options,
}
$finalizeResolvers:W
$finalizeWithClientResolvers:W
$clientMemberResolvers:W
Expand All @@ -188,20 +199,22 @@ func New(options $options:L, optFns ...func(*$options:L)) *$client:L {
"client", serviceSymbol.getName(),
"protocolResolvers", generateProtocolResolvers(),
"initializeResolvers", GoWriter.ChainWritable.of(
plugins.stream()
.flatMap(it -> it.getConfigFieldResolvers().stream())
.filter(it -> it.getLocation().equals(ConfigFieldResolver.Location.CLIENT))
.filter(it -> it.getTarget().equals(ConfigFieldResolver.Target.INITIALIZATION))
.map(this::generateConfigFieldResolver)
.toList()
getConfigResolvers(
ConfigFieldResolver.Location.CLIENT,
ConfigFieldResolver.Target.INITIALIZATION
).map(this::generateConfigFieldResolver).toList()
).compose(),
"finalizeResolvers", GoWriter.ChainWritable.of(
plugins.stream()
.flatMap(it -> it.getConfigFieldResolvers().stream())
.filter(it -> it.getLocation().equals(ConfigFieldResolver.Location.CLIENT))
.filter(it -> it.getTarget().equals(ConfigFieldResolver.Target.FINALIZATION))
.map(this::generateConfigFieldResolver)
.toList()
getConfigResolvers(
ConfigFieldResolver.Location.CLIENT,
ConfigFieldResolver.Target.FINALIZATION
).map(this::generateConfigFieldResolver).toList()
).compose(),
"finalizeWithClientResolvers", GoWriter.ChainWritable.of(
getConfigResolvers(
ConfigFieldResolver.Location.CLIENT,
ConfigFieldResolver.Target.FINALIZATION_WITH_CLIENT
).map(this::generateConfigFieldResolver).toList()
).compose(),
"clientMemberResolvers", GoWriter.ChainWritable.of(
plugins.stream()
Expand Down Expand Up @@ -443,6 +456,7 @@ func resolveAuthSchemes(options *Options) {
private GoWriter.Writable generateOptionsGetIdentityResolver() {
return goTemplate("""
func (o $L) GetIdentityResolver(schemeID string) $T {
$W
$W
return nil
}
Expand All @@ -455,7 +469,8 @@ private GoWriter.Writable generateOptionsGetIdentityResolver() {
.filter(authSchemes::containsKey)
.map(trait -> generateGetIdentityResolverMapping(trait, authSchemes.get(trait)))
.toList()
).compose(false));
).compose(false),
generateGetIdentityResolverMapping(NoAuthTrait.ID, new AnonymousDefinition()));
}

private GoWriter.Writable generateGetIdentityResolverMapping(ShapeId schemeId, AuthSchemeDefinition scheme) {
Expand All @@ -467,9 +482,6 @@ private GoWriter.Writable generateGetIdentityResolverMapping(ShapeId schemeId, A

@SuppressWarnings("checkstyle:LineLength")
private GoWriter.Writable generateInvokeOperation() {
var plugins = runtimePlugins.stream()
.filter(it -> it.matchesService(model, service))
.toList();
return goTemplate("""
func (c *Client) invokeOperation(ctx $context:T, opID string, params interface{}, optFns []func(*Options), stackFns ...func($stack:P, Options) error) (result interface{}, metadata $metadata:T, err error) {
ctx = $clearStackValues:T(ctx)
Expand Down Expand Up @@ -516,20 +528,16 @@ private GoWriter.Writable generateInvokeOperation() {
"newStackHandler", generateNewStackHandler(),
"operationError", SmithyGoTypes.Smithy.OperationError,
"resolvers", GoWriter.ChainWritable.of(
plugins.stream()
.flatMap(it -> it.getConfigFieldResolvers().stream())
.filter(it -> it.getLocation().equals(ConfigFieldResolver.Location.OPERATION))
.filter(it -> it.getTarget().equals(ConfigFieldResolver.Target.INITIALIZATION))
.map(this::generateConfigFieldResolver)
.toList()
getConfigResolvers(
ConfigFieldResolver.Location.OPERATION,
ConfigFieldResolver.Target.INITIALIZATION
).map(this::generateConfigFieldResolver).toList()
).compose(),
"finalizers", GoWriter.ChainWritable.of(
plugins.stream()
.flatMap(it -> it.getConfigFieldResolvers().stream())
.filter(it -> it.getLocation().equals(ConfigFieldResolver.Location.OPERATION))
.filter(it -> it.getTarget().equals(ConfigFieldResolver.Target.FINALIZATION))
.map(this::generateConfigFieldResolver)
.toList()
getConfigResolvers(
ConfigFieldResolver.Location.OPERATION,
ConfigFieldResolver.Target.FINALIZATION
).map(this::generateConfigFieldResolver).toList()
).compose()
));
}
Expand All @@ -552,4 +560,51 @@ private void ensureSupportedProtocol() {
"Protocols other than HTTP are not yet implemented: " + applicationProtocol);
}
}

private Stream<ConfigFieldResolver> getConfigResolvers(
ConfigFieldResolver.Location location, ConfigFieldResolver.Target target
) {
return runtimePlugins.stream()
.filter(it -> it.matchesService(model, service))
.flatMap(it -> it.getConfigFieldResolvers().stream())
.filter(it -> it.getLocation() == location && it.getTarget() == target);
}

private GoWriter.Writable generateInputContextFuncs() {
return goTemplate("""
type operationInputKey struct{}
func setOperationInput(ctx $1T, input interface{}) $1T {
return $2T(ctx, operationInputKey{}, input)
}
func getOperationInput(ctx $1T) interface{} {
return $3T(ctx, operationInputKey{})
}
$4W
""",
GoStdlibTypes.Context.Context,
SmithyGoTypes.Middleware.WithStackValue,
SmithyGoTypes.Middleware.GetStackValue,
new SetOperationInputContextMiddleware().generate());
}

private GoWriter.Writable generateAddProtocolFinalizerMiddleware() {
ensureSupportedProtocol();
return goTemplate("""
func addProtocolFinalizerMiddlewares(stack $P, options $L, operation string) error {
$W
return nil
}
""",
SmithyGoTypes.Middleware.Stack,
CONFIG_NAME,
GoWriter.ChainWritable.of(
ResolveAuthSchemeMiddlewareGenerator.generateAddToProtocolFinalizers(),
GetIdentityMiddlewareGenerator.generateAddToProtocolFinalizers(),
EndpointMiddlewareGenerator.generateAddToProtocolFinalizers(),
SignRequestMiddlewareGenerator.generateAddToProtocolFinalizers()
).compose(false));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.smithy.go.codegen;

import static software.amazon.smithy.go.codegen.GoStackStepMiddlewareGenerator.createSerializeStepMiddleware;
import static software.amazon.smithy.go.codegen.GoWriter.emptyGoTemplate;
import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;

/**
* Middleware to set the final operation input on the context at the start of the serialize step such that protocol
* middlewares in later phases can use it.
*/
public class SetOperationInputContextMiddleware {
public static final String MIDDLEWARE_NAME = "setOperationInputMiddleware";
public static final String MIDDLEWARE_ID = "setOperationInput";

public GoWriter.Writable generate() {
return createSerializeStepMiddleware(MIDDLEWARE_NAME, MiddlewareIdentifier.string(MIDDLEWARE_ID))
.asWritable(generateBody(), emptyGoTemplate());
}

private GoWriter.Writable generateBody() {
return goTemplate("""
ctx = setOperationInput(ctx, in.Parameters)
return next.HandleSerialize(ctx, in)
""");
}
}
Loading

0 comments on commit 9d77a17

Please sign in to comment.