Skip to content

Commit

Permalink
feat: add authentication middleware flow
Browse files Browse the repository at this point in the history
  • Loading branch information
lucix-aws committed Oct 17, 2023
1 parent 1cac3db commit 15a7ba2
Show file tree
Hide file tree
Showing 23 changed files with 737 additions and 267 deletions.
2 changes: 1 addition & 1 deletion auth/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type Identity interface {
// IdentityResolver defines the interface through which an Identity is
// retrieved.
type IdentityResolver interface {
GetIdentity(ctx context.Context, params *smithy.Properties) (Identity, error)
GetIdentity(context.Context, smithy.Properties) (Identity, error)
}

// IdentityResolverOptions defines the interface through which an entity can be
Expand Down
15 changes: 15 additions & 0 deletions auth/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,24 @@ package auth

import "github.com/aws/smithy-go"

type (
authOptionsKey struct{}
)

// Option represents a possible authentication method for an operation.
type Option struct {
SchemeID string
IdentityProperties smithy.Properties
SignerProperties smithy.Properties
}

// GetAuthOptions gets auth Options from Properties.
func GetAuthOptions(p *smithy.Properties) ([]*Option, bool) {
v, ok := p.Get(authOptionsKey{}).([]*Option)
return v, ok
}

// SetAuthOptions sets auth Options on Properties.
func SetAuthOptions(p *smithy.Properties, options []*Option) {
p.Set(authOptionsKey{}, options)
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,22 @@ public static GoStackStepMiddlewareGenerator createSerializeStepMiddleware(Strin
SymbolUtils.createValueSymbolBuilder("SerializeHandler", SmithyGoDependency.SMITHY_MIDDLEWARE).build());
}

/**
* Create a new FinalizeStep middleware generator with the provided type name.
*
* @param name is the type name to identify the middleware.
* @param id the unique ID for the middleware.
* @return the middleware generator.
*/
public static GoStackStepMiddlewareGenerator createFinalizeStepMiddleware(String name, MiddlewareIdentifier id) {
return createMiddleware(name,
id,
"HandleFinalize",
SmithyGoTypes.Middleware.FinalizeInput,
SmithyGoTypes.Middleware.FinalizeOutput,
SmithyGoTypes.Middleware.FinalizeHandler);
}

/**
* Create a new DeserializeStep middleware generator with the provided type name.
*
Expand Down Expand Up @@ -216,6 +232,20 @@ public void writeMiddleware(
});
}

/**
* Creates a Writable which renders the middleware.
* @param body A Writable that renders the middleware body.
* @param fields A Writable that renders the middleware struct's fields.
* @return the writable.
*/
public GoWriter.Writable asWritable(GoWriter.Writable body, GoWriter.Writable fields) {
return writer -> writeMiddleware(
writer,
(generator, bodyWriter) -> bodyWriter.write("$W", body),
(generator, fieldWriter) -> fieldWriter.write("$W", fields)
);
}

/**
* Returns a new middleware generator builder.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import software.amazon.smithy.model.traits.RequiredTrait;
import software.amazon.smithy.model.traits.StringTrait;
import software.amazon.smithy.utils.AbstractCodeWriter;
import software.amazon.smithy.utils.ListUtils;
import software.amazon.smithy.utils.StringUtils;

/**
Expand Down Expand Up @@ -998,6 +999,18 @@ public ChainWritable() {
writables = new ArrayList<>();
}

public static ChainWritable of(GoWriter.Writable... writables) {
var chain = new ChainWritable();
chain.writables.addAll(ListUtils.of(writables));
return chain;
}

public static ChainWritable of(List<GoWriter.Writable> writables) {
var chain = new ChainWritable();
chain.writables.addAll(writables);
return chain;
}

public boolean isEmpty() {
return writables.isEmpty();
}
Expand All @@ -1019,17 +1032,21 @@ public ChainWritable add(boolean include, GoWriter.Writable writable) {
return this;
}

public GoWriter.Writable compose() {
public GoWriter.Writable compose(boolean writeNewlines) {
return (GoWriter writer) -> {
var hasPrevious = false;
for (GoWriter.Writable writable : writables) {
if (hasPrevious) {
if (hasPrevious && writeNewlines) {
writer.write("");
}
hasPrevious = true;
writer.write("$W", writable);
}
};
}

public GoWriter.Writable compose() {
return compose(true);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import software.amazon.smithy.codegen.core.CodegenException;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.auth.GetIdentityMiddlewareGenerator;
import software.amazon.smithy.go.codegen.auth.ResolveAuthSchemeMiddlewareGenerator;
import software.amazon.smithy.go.codegen.auth.SignRequestMiddlewareGenerator;
import software.amazon.smithy.go.codegen.endpoints.EndpointParameterOperationBindingsGenerator;
import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar;
import software.amazon.smithy.go.codegen.integration.ProtocolGenerator;
Expand Down Expand Up @@ -229,6 +232,12 @@ private void generateAddOperationMiddleware() {
}
});

writer.write("$W", GoWriter.ChainWritable.of(
ResolveAuthSchemeMiddlewareGenerator.generateAddMiddleware(operationSymbol.getName()),
SignRequestMiddlewareGenerator.generateAddMiddleware(),
GetIdentityMiddlewareGenerator.generateAddMiddleware()
).compose());

writer.write("return nil");
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,28 @@

package software.amazon.smithy.go.codegen;

import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.auth.AuthSchemeResolverGenerator;
import software.amazon.smithy.go.codegen.integration.AuthSchemeDefinition;
import software.amazon.smithy.go.codegen.integration.ClientMember;
import software.amazon.smithy.go.codegen.integration.ClientMemberResolver;
import software.amazon.smithy.go.codegen.integration.ConfigField;
import software.amazon.smithy.go.codegen.integration.ConfigFieldResolver;
import software.amazon.smithy.go.codegen.integration.GoIntegration;
import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.knowledge.ServiceIndex;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.ShapeId;

/**
* Generates a service client and configuration.
Expand All @@ -46,6 +53,7 @@ final class ServiceGenerator implements Runnable {
private final List<GoIntegration> integrations;
private final List<RuntimeClientPlugin> runtimePlugins;
private final ApplicationProtocol applicationProtocol;
private final Map<ShapeId, AuthSchemeDefinition> authSchemes;

ServiceGenerator(
GoSettings settings,
Expand All @@ -65,6 +73,10 @@ final class ServiceGenerator implements Runnable {
this.integrations = integrations;
this.runtimePlugins = runtimePlugins;
this.applicationProtocol = applicationProtocol;
this.authSchemes = integrations.stream()
.flatMap(it -> it.getClientPlugins(model, service).stream())
.flatMap(it -> it.getAuthSchemeDefinitions().entrySet().stream())
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}

@Override
Expand Down Expand Up @@ -96,6 +108,7 @@ public void run() {
generateConstructor(serviceSymbol);
generateConfig();
generateClientInvokeOperation();
generateProtocolResolvers();
}

private void writeClientMemberResolvers(
Expand Down Expand Up @@ -153,6 +166,7 @@ private void generateConstructor(Symbol serviceSymbol) {
resolver.getLocation() == ConfigFieldResolver.Location.CLIENT
&& resolver.getTarget() == ConfigFieldResolver.Target.INITIALIZATION);
}
writeProtocolResolvers();

writer.openBlock("for _, fn := range optFns {", "}", () -> writer.write("fn(&options)"));
writer.write("");
Expand Down Expand Up @@ -203,6 +217,8 @@ private void generateConfig() {
generateApplicationProtocolConfig();
}).write("");

generateGetIdentityResolver();

writer.writeDocs("WithAPIOptions returns a functional option for setting the Client's APIOptions option.");
writer.openBlock("func WithAPIOptions(optFns ...func(*middleware.Stack) error) func(*Options) {", "}", () -> {
writer.openBlock("return func(o *Options) {", "}", () -> {
Expand Down Expand Up @@ -284,6 +300,12 @@ private void generateApplicationProtocolConfig() {
writer.writeDocs(
"The HTTP client to invoke API calls with. Defaults to client's default HTTP implementation if nil.");
writer.write("HTTPClient HTTPClient").write("");

writer.writeDocs("The auth scheme resolver which determines how to authenticate for each operation.");
writer.write("AuthSchemeResolver $L", AuthSchemeResolverGenerator.INTERFACE_NAME).write("");

writer.writeDocs("The list of auth schemes supported by the client.");
writer.write("AuthSchemes []$T", SmithyGoTypes.Transport.Http.AuthScheme).write("");
}

private void generateApplicationProtocolTypes() {
Expand All @@ -294,6 +316,71 @@ private void generateApplicationProtocolTypes() {
}).write("");
}

private void writeProtocolResolvers() {
ensureSupportedProtocol();

writer.write("""
resolveAuthSchemeResolver(&options)
resolveAuthSchemes(&options)
""");
}

private void generateProtocolResolvers() {
ensureSupportedProtocol();

var schemeMappings = GoWriter.ChainWritable.of(
ServiceIndex.of(model)
.getEffectiveAuthSchemes(service).keySet().stream()
.filter(authSchemes::containsKey)
.map(authSchemes::get)
.map(it -> goTemplate("$W, ", it.generateDefaultAuthScheme()))
.toList()
).compose(false);

writer.write("""
func resolveAuthSchemeResolver(options *Options) {
options.AuthSchemeResolver = &$L{}
}
func resolveAuthSchemes(options *Options) {
options.AuthSchemes = []$T{
$W
}
}
""",
AuthSchemeResolverGenerator.DEFAULT_NAME,
SmithyGoTypes.Transport.Http.AuthScheme,
schemeMappings);
}

private void generateGetIdentityResolver() {
var resolverMappings = GoWriter.ChainWritable.of(
ServiceIndex.of(model)
.getEffectiveAuthSchemes(service).keySet().stream()
.filter(authSchemes::containsKey)
.map(trait -> generateGetIdentityResolverMapping(trait, authSchemes.get(trait)))
.toList()
);

writer.write("""
func (o $L) GetIdentityResolver(schemeID string) $T {
$W
return nil
}
""",
CONFIG_NAME,
SmithyGoTypes.Auth.IdentityResolver,
resolverMappings.compose(false));
}

private GoWriter.Writable generateGetIdentityResolverMapping(ShapeId schemeId, AuthSchemeDefinition scheme) {
return goTemplate("""
if schemeID == $S {
return $W
}""", schemeId.toString(), scheme.generateOptionsIdentityResolver());
}

private void generateClientInvokeOperation() {
writer.addUseImports(SmithyGoDependency.CONTEXT);
writer.addUseImports(SmithyGoDependency.SMITHY);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,27 +35,56 @@ public static final class Ptr {

public static final class Middleware {
public static final Symbol Stack = SmithyGoDependency.SMITHY_MIDDLEWARE.pointableSymbol("Stack");
public static final Symbol Metadata = SmithyGoDependency.SMITHY_MIDDLEWARE.pointableSymbol("Metadata");
public static final Symbol WithStackValue = SmithyGoDependency.SMITHY_MIDDLEWARE.valueSymbol("WithStackValue");
public static final Symbol GetStackValue = SmithyGoDependency.SMITHY_MIDDLEWARE.valueSymbol("GetStackValue");
public static final Symbol After = SmithyGoDependency.SMITHY_MIDDLEWARE.valueSymbol("After");
public static final Symbol Before = SmithyGoDependency.SMITHY_MIDDLEWARE.valueSymbol("Before");

public static final Symbol SerializeInput = SmithyGoDependency.SMITHY_MIDDLEWARE.pointableSymbol("SerializeInput");
public static final Symbol SerializeOutput = SmithyGoDependency.SMITHY_MIDDLEWARE.pointableSymbol("SerializeOutput");
public static final Symbol SerializeHandler = SmithyGoDependency.SMITHY_MIDDLEWARE.pointableSymbol("SerializeHandler");
public static final Symbol Metadata = SmithyGoDependency.SMITHY_MIDDLEWARE.pointableSymbol("Metadata");
public static final Symbol FinalizeInput = SmithyGoDependency.SMITHY_MIDDLEWARE.pointableSymbol("FinalizeInput");
public static final Symbol FinalizeOutput = SmithyGoDependency.SMITHY_MIDDLEWARE.pointableSymbol("FinalizeOutput");
public static final Symbol FinalizeHandler = SmithyGoDependency.SMITHY_MIDDLEWARE.pointableSymbol("FinalizeHandler");
}

public static final class Transport {
public static final class Http {
public static final Symbol Request = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.pointableSymbol("Request");

public static final Symbol AuthScheme = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.valueSymbol("AuthScheme");
public static final Symbol SchemeIDAnonymous = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.valueSymbol("SchemeIDAnonymous");
public static final Symbol NewAnonymousScheme = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.valueSymbol("NewAnonymousScheme");

public static final Symbol NewSigV4Option = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.valueSymbol("NewSigV4Option");
public static final Symbol NewSigV4AOption = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.valueSymbol("NewSigV4AOption");
public static final Symbol NewBearerOption = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.valueSymbol("NewBearerOption");
public static final Symbol NewAnonymousOption = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.valueSymbol("NewAnonymousOption");

public static final Symbol SigV4Properties = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.pointableSymbol("SigV4Properties");
public static final Symbol SigV4AProperties = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.pointableSymbol("SigV4AProperties");

public static final Symbol SetSigV4SigningName = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.valueSymbol("SetSigV4SigningName");
public static final Symbol SetSigV4SigningRegion = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.valueSymbol("SetSigV4SigningRegion");
public static final Symbol SetSigV4ASigningName = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.valueSymbol("SetSigV4ASigningName");
public static final Symbol SetSigV4ASigningRegions = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.valueSymbol("SetSigV4ASigningRegions");
public static final Symbol SetIsUnsignedPayload = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.valueSymbol("SetIsUnsignedPayload");
public static final Symbol SetDisableDoubleEncoding = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.valueSymbol("SetDisableDoubleEncoding");
}
}

public static final class Auth {
public static final Symbol Option = SmithyGoDependency.SMITHY_AUTH.pointableSymbol("Option");
public static final Symbol IdentityResolver = SmithyGoDependency.SMITHY_AUTH.valueSymbol("IdentityResolver");
public static final Symbol Identity = SmithyGoDependency.SMITHY_AUTH.valueSymbol("Identity");
public static final Symbol GetAuthOptions = SmithyGoDependency.SMITHY_AUTH.valueSymbol("GetAuthOptions");
public static final Symbol SetAuthOptions = SmithyGoDependency.SMITHY_AUTH.valueSymbol("SetAuthOptions");

public static final class Bearer {
public static final Symbol TokenProvider = SmithyGoDependency.SMITHY_AUTH_BEARER.valueSymbol("TokenProvider");
public static final Symbol Signer = SmithyGoDependency.SMITHY_AUTH_BEARER.valueSymbol("Signer");
public static final Symbol NewSignHTTPSMessage = SmithyGoDependency.SMITHY_AUTH_BEARER.valueSymbol("NewSignHTTPSMessage");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ public void generate() {
}

context.getWriter().get()
.write("$W", new AuthParametersGenerator(context).generate())
.write("")
.write("$W", new AuthParametersResolverGenerator(context).generate())
.write("")
.write("$W", getResolverGenerator().generate());
.write("$W\n", new AuthParametersGenerator(context).generate())
.write("$W\n", new AuthParametersResolverGenerator(context).generate())
.write("$W\n", getResolverGenerator().generate())
.write("$W\n", new ResolveAuthSchemeMiddlewareGenerator(context).generate())
.write("$W\n", new GetIdentityMiddlewareGenerator(context).generate())
.write("$W\n", new SignRequestMiddlewareGenerator(context).generate());
}

// TODO(i&a): allow consuming generators to overwrite
Expand Down
Loading

0 comments on commit 15a7ba2

Please sign in to comment.