Skip to content

Commit

Permalink
restructure some things around ctx
Browse files Browse the repository at this point in the history
  • Loading branch information
lucix-aws committed Feb 26, 2024
1 parent 97e23c5 commit 70a8eaf
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 137 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.TriConsumer;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.utils.SmithyUnstableApi;

Expand Down Expand Up @@ -138,9 +137,7 @@ default void writeAdditionalFiles(
*
* @return Returns the list of protocol generators to register.
*/
default List<ServiceProtocolGenerator> getProtocolGenerators(
Model model, ServiceShape service, SymbolProvider symbolProvider
) {
default List<ServiceProtocolGenerator> getProtocolGenerators(GoCodegenContext ctx) {
return Collections.emptyList();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static java.util.stream.Collectors.toSet;
import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;
import static software.amazon.smithy.go.codegen.service.ServiceCodegenUtils.getShapesToSerde;
import static software.amazon.smithy.go.codegen.service.ServiceCodegenUtils.isUnit;
import static software.amazon.smithy.go.codegen.service.ServiceCodegenUtils.withUnit;

import java.util.List;
Expand All @@ -31,6 +32,7 @@
import software.amazon.smithy.codegen.core.directed.GenerateEnumDirective;
import software.amazon.smithy.codegen.core.directed.GenerateErrorDirective;
import software.amazon.smithy.codegen.core.directed.GenerateIntEnumDirective;
import software.amazon.smithy.codegen.core.directed.GenerateOperationDirective;
import software.amazon.smithy.codegen.core.directed.GenerateServiceDirective;
import software.amazon.smithy.codegen.core.directed.GenerateStructureDirective;
import software.amazon.smithy.codegen.core.directed.GenerateUnionDirective;
Expand All @@ -48,37 +50,36 @@
import software.amazon.smithy.go.codegen.SymbolVisitor;
import software.amazon.smithy.go.codegen.UnionGenerator;
import software.amazon.smithy.model.knowledge.TopDownIndex;
import software.amazon.smithy.model.shapes.EnumShape;
import software.amazon.smithy.model.shapes.IntEnumShape;
import software.amazon.smithy.model.shapes.StringShape;
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.model.shapes.UnionShape;

public class ServiceDirectedCodegen implements DirectedCodegen<GoCodegenContext, GoSettings, GoServiceIntegration> {
@Override
public SymbolProvider createSymbolProvider(CreateSymbolProviderDirective directive) {
return new SymbolVisitor(withUnit(directive.model()), (GoSettings) directive.settings());
public SymbolProvider createSymbolProvider(CreateSymbolProviderDirective<GoSettings> directive) {
return new SymbolVisitor(withUnit(directive.model()), directive.settings());
}

@Override
public GoCodegenContext createContext(CreateContextDirective directive) {
public GoCodegenContext createContext(CreateContextDirective<GoSettings, GoServiceIntegration> directive) {
return new GoCodegenContext(
withUnit(directive.model()),
(GoSettings) directive.settings(),
directive.settings(),
directive.symbolProvider(),
directive.fileManifest(),
new WriterDelegator<>(directive.fileManifest(), directive.symbolProvider(), (filename, namespace) ->
new GoWriter(namespace)),
new WriterDelegator<>(directive.fileManifest(), directive.symbolProvider(),
(filename, namespace) -> new GoWriter(namespace)),
directive.integrations()
);
}

@Override
public void generateService(GenerateServiceDirective directive) {
var namespace = ((GoSettings) directive.settings()).getModuleName();
public void generateService(GenerateServiceDirective<GoCodegenContext, GoSettings> directive) {
var namespace = directive.settings().getModuleName();
var delegator = directive.context().writerDelegator();
var settings = ((GoSettings) directive.settings());
var settings = directive.settings();

var protocolGenerator = resolveProtocolGenerator(directive);
var protocolGenerator = resolveProtocolGenerator(directive.context());

var model = directive.model();
var service = directive.service();
Expand Down Expand Up @@ -125,11 +126,18 @@ public void generateService(GenerateServiceDirective directive) {
}

@Override
public void generateStructure(GenerateStructureDirective directive) {
public void generateOperation(GenerateOperationDirective<GoCodegenContext, GoSettings> directive) {
var protocolGenerator = resolveProtocolGenerator(directive.context());
directive.context().writerDelegator().useShapeWriter(directive.shape(),
protocolGenerator.generateHandleOperation(directive.shape()));
}

@Override
public void generateStructure(GenerateStructureDirective<GoCodegenContext, GoSettings> directive) {
if (directive.shape().getId().getNamespace().equals(CodegenUtils.getSyntheticTypeNamespace())) {
return;
}
if (directive.shape().getId().toString().equals("smithy.api#Unit")) {
if (isUnit(directive.shape().getId())) {
return;
}

Expand All @@ -138,69 +146,61 @@ public void generateStructure(GenerateStructureDirective directive) {
new StructureGenerator(
directive.model(),
directive.symbolProvider(),
(GoWriter) writer,
writer,
directive.service(),
(StructureShape) directive.shape(),
directive.shape(),
directive.symbolProvider().toSymbol(directive.shape()),
null
).run()
);
}

@Override
public void generateError(GenerateErrorDirective directive) {
public void generateError(GenerateErrorDirective<GoCodegenContext, GoSettings> directive) {
var delegator = directive.context().writerDelegator();
delegator.useShapeWriter(directive.shape(), writer ->
new StructureGenerator(
directive.model(),
directive.symbolProvider(),
(GoWriter) writer,
writer,
directive.service(),
(StructureShape) directive.shape(),
directive.shape(),
directive.symbolProvider().toSymbol(directive.shape()),
null
).run()
);
}

@Override
public void generateUnion(GenerateUnionDirective directive) {
public void generateUnion(GenerateUnionDirective<GoCodegenContext, GoSettings> directive) {
var delegator = directive.context().writerDelegator();
delegator.useShapeWriter(directive.shape(), writer ->
new UnionGenerator(directive.model(), directive.symbolProvider(), (UnionShape) directive.shape())
.generateUnion((GoWriter) writer)
new UnionGenerator(directive.model(), directive.symbolProvider(), directive.shape())
.generateUnion(writer)
);
}

@Override
public void generateEnumShape(GenerateEnumDirective directive) {
public void generateEnumShape(GenerateEnumDirective<GoCodegenContext, GoSettings> directive) {
var delegator = directive.context().writerDelegator();
delegator.useShapeWriter(directive.shape(), writer ->
new EnumGenerator(directive.symbolProvider(), (GoWriter) writer, (StringShape) directive.shape())
.run()
new EnumGenerator(directive.symbolProvider(), writer, (EnumShape) directive.shape()).run()
);
}

@Override
public void generateIntEnumShape(GenerateIntEnumDirective directive) {
public void generateIntEnumShape(GenerateIntEnumDirective<GoCodegenContext, GoSettings> directive) {
directive.context().writerDelegator().useShapeWriter(directive.shape(), writer ->
new IntEnumGenerator(
directive.symbolProvider(),
(GoWriter) writer,
(IntEnumShape) directive.shape()
).run()
new IntEnumGenerator(directive.symbolProvider(), writer, (IntEnumShape) directive.shape()).run()
);
}

private ServiceProtocolGenerator resolveProtocolGenerator(
GenerateServiceDirective<GoCodegenContext, GoSettings> directive
) {
var model = directive.model();
var service = directive.service();
var symbolProvider = directive.symbolProvider();
private ServiceProtocolGenerator resolveProtocolGenerator(GoCodegenContext ctx) {
var model = ctx.model();
var service = ctx.settings().getService(model);

var protocolGenerators = directive.context().integrations().stream()
.flatMap(it -> it.getProtocolGenerators(model, service, symbolProvider).stream())
var protocolGenerators = ctx.integrations().stream()
.flatMap(it -> it.getProtocolGenerators(ctx).stream())
.filter(it -> service.hasTrait(it.getProtocol()))
.toList();
if (protocolGenerators.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.Set;
import software.amazon.smithy.go.codegen.ApplicationProtocol;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.utils.SmithyInternalApi;
Expand All @@ -32,6 +33,8 @@ public interface ServiceProtocolGenerator {
// Go
GoWriter.Writable generateHandleRequest();

GoWriter.Writable generateHandleOperation(OperationShape operation);

GoWriter.Writable generateOptions();

GoWriter.Writable generateDeserializers(Set<Shape> shape);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,17 @@
package software.amazon.smithy.go.codegen.service.integration;

import java.util.List;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.GoCodegenContext;
import software.amazon.smithy.go.codegen.service.GoServiceIntegration;
import software.amazon.smithy.go.codegen.service.ServiceProtocolGenerator;
import software.amazon.smithy.go.codegen.service.protocol.aws.AwsJson10ProtocolGenerator;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.utils.ListUtils;

public class DefaultProtocols implements GoServiceIntegration {
@Override
public List<ServiceProtocolGenerator> getProtocolGenerators(
Model model, ServiceShape service, SymbolProvider symbolProvider
) {
public List<ServiceProtocolGenerator> getProtocolGenerators(GoCodegenContext ctx) {
return ListUtils.of(
new AwsJson10ProtocolGenerator(model, service, symbolProvider)
new AwsJson10ProtocolGenerator(ctx)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,20 @@

package software.amazon.smithy.go.codegen.service.protocol;

import static software.amazon.smithy.go.codegen.GoWriter.emptyGoTemplate;
import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;

import software.amazon.smithy.go.codegen.ApplicationProtocol;
import software.amazon.smithy.go.codegen.GoCodegenContext;
import software.amazon.smithy.go.codegen.GoStdlibTypes;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoTypes;
import software.amazon.smithy.go.codegen.knowledge.GoValidationIndex;
import software.amazon.smithy.go.codegen.service.RequestHandler;
import software.amazon.smithy.go.codegen.service.ServiceProtocolGenerator;
import software.amazon.smithy.go.codegen.service.ServiceValidationGenerator;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.utils.MapUtils;
import software.amazon.smithy.utils.SmithyInternalApi;

Expand All @@ -31,6 +38,16 @@
*/
@SmithyInternalApi
public abstract class HttpHandlerProtocolGenerator implements ServiceProtocolGenerator {
protected final GoCodegenContext ctx;

private final GoValidationIndex validationIndex;

protected HttpHandlerProtocolGenerator(GoCodegenContext ctx) {
this.ctx = ctx;

this.validationIndex = GoValidationIndex.of(ctx.model());
}

@Override
public ApplicationProtocol getApplicationProtocol() {
return ApplicationProtocol.createDefaultHttpApplicationProtocol();
Expand Down Expand Up @@ -90,8 +107,95 @@ public GoWriter.Writable generateProtocolSource() {
));
}

@Override
public final GoWriter.Writable generateHandleOperation(OperationShape operation) {
var service = ctx.settings().getService(ctx.model());
var input = ctx.model().expectShape(operation.getInputShape());
return goTemplate("""
func (h *$requestHandler:L) $funcName:L(w $rw:T, r $r:P) {
id, err := $newUuid:T($rand:T).GetUUID()
if err != nil {
serializeError(w, err)
return
}
$beforeDeserialize:W
$deserialize:W
$afterDeserialize:W
$validate:W
out, err := h.service.$operation:L(r.Context(), in)
if err != nil {
serializeError(w, err)
return
}
$beforeSerialize:W
$beforeWriteResponse:W
$serialize:W
}
""",
MapUtils.of(
"requestHandler", RequestHandler.NAME,
"funcName", getOperationHandlerName(operation),
"rw", GoStdlibTypes.Net.Http.ResponseWriter,
"r", GoStdlibTypes.Net.Http.Request
),
MapUtils.of(
"newUuid", SmithyGoTypes.Rand.NewUUID,
"rand", GoStdlibTypes.Crypto.Rand.Reader,
"deserialize", generateDeserializeRequest(operation),
"validate", validationIndex.operationRequiresValidation(service, operation)
? generateValidateInput(input)
: emptyGoTemplate(),
"operation", ctx.symbolProvider().toSymbol(operation).getName(),
"serialize", generateSerializeResponse(operation),
"beforeDeserialize", generateInvokeInterceptor("BeforeDeserialize", "r"),
"afterDeserialize", generateInvokeInterceptor("AfterDeserialize", "in"),
"beforeSerialize", generateInvokeInterceptor("BeforeSerialize", "out"),
"beforeWriteResponse", generateInvokeInterceptor("BeforeWriteResponse", "w")
));
}

/**
* Generates the net/http.Handler's ServeHTTP implementation for this protocol.
* Individual operation handlers are generated by generateServeHttpOperation. Implementors should fill in logic here
* to route requests to those methods according to the protocol.
*/
public abstract GoWriter.Writable generateServeHttp();

/**
* Generates a block of logic to convert the input http.Request `r` into the modeled input structure `in`.
*/
public abstract GoWriter.Writable generateDeserializeRequest(OperationShape operation);

/**
* Generates a block of serialize the modeled output structure `out` to the http.ResponseWriter `w`.
*/
public abstract GoWriter.Writable generateSerializeResponse(OperationShape operation);

protected final String getOperationHandlerName(OperationShape operation) {
return "serveHTTP" + operation.getId().getName();
}

private GoWriter.Writable generateValidateInput(Shape input) {
return goTemplate("""
if err := $L(in); err != nil {
serializeError(w, err)
return
}
""", ServiceValidationGenerator.getShapeValidatorName(input));
}

private GoWriter.Writable generateInvokeInterceptor(String type, String args) {
return goTemplate("""
for _, i := range h.options.Interceptors.$1L {
if err := i.$1L(r.Context(), id, $2L); err != nil {
serializeError(w, err)
return
}
}
""", type, args);
}
}
Loading

0 comments on commit 70a8eaf

Please sign in to comment.