Skip to content

Commit

Permalink
feat-svcgen-deserialize (#495)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucix-aws committed Feb 26, 2024
1 parent a310e49 commit 3455ea1
Show file tree
Hide file tree
Showing 8 changed files with 603 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ public static final class Context {
public static final Symbol Background = SmithyGoDependency.CONTEXT.valueSymbol("Background");
}

public static final class Encoding {
public static final class Json {
public static final Symbol NewDecoder = SmithyGoDependency.JSON.valueSymbol("NewDecoder");
public static final Symbol Number = SmithyGoDependency.JSON.valueSymbol("Number");
}

public static final class Base64 {
public static final Symbol StdEncoding = SmithyGoDependency.BASE64.valueSymbol("StdEncoding");
}
}

public static final class Fmt {
public static final Symbol Errorf = SmithyGoDependency.FMT.valueSymbol("Errorf");
public static final Symbol Sprintf = SmithyGoDependency.FMT.valueSymbol("Sprintf");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1061,7 +1061,7 @@ public static ChainWritable of(GoWriter.Writable... writables) {
return chain;
}

public static ChainWritable of(List<GoWriter.Writable> writables) {
public static ChainWritable of(Collection<GoWriter.Writable> writables) {
var chain = new ChainWritable();
chain.writables.addAll(writables);
return chain;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,20 @@ public static final class Smithy {
public static final Symbol OperationError = SmithyGoDependency.SMITHY.pointableSymbol("OperationError");
}

public static final class Encoding {
public static final class Json {
public static final Symbol NewEncoder = SmithyGoDependency.SMITHY_JSON.valueSymbol("NewEncoder");
public static final Symbol Value = SmithyGoDependency.SMITHY_JSON.valueSymbol("Value");
}
}

public static final class Ptr {
public static final Symbol String = SmithyGoDependency.SMITHY_PTR.valueSymbol("String");
public static final Symbol Bool = SmithyGoDependency.SMITHY_PTR.valueSymbol("Bool");
public static final Symbol Int8 = SmithyGoDependency.SMITHY_PTR.valueSymbol("Int8");
public static final Symbol Int16 = SmithyGoDependency.SMITHY_PTR.valueSymbol("Int16");
public static final Symbol Int32 = SmithyGoDependency.SMITHY_PTR.valueSymbol("Int32");
public static final Symbol Int64 = SmithyGoDependency.SMITHY_PTR.valueSymbol("Int64");
}

public static final class Middleware {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,14 @@ public static boolean isUniverseType(Symbol symbol) {
.orElse(false);
}

public static boolean isPointable(Symbol symbol) {
return symbol.getProperty(SymbolUtils.POINTABLE, Boolean.class).orElse(false);
}

public static Symbol getReference(Symbol symbol) {
return symbol.getProperty(SymbolUtils.GO_ELEMENT_TYPE, Symbol.class).orElse(null);
}

/**
* Builds a symbol within the context of the package in which codegen is taking place.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.smithy.go.codegen.service;

import static java.util.stream.Collectors.toSet;

import java.util.Set;
import java.util.stream.Stream;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.shapes.BlobShape;
import software.amazon.smithy.model.shapes.BooleanShape;
import software.amazon.smithy.model.shapes.ByteShape;
import software.amazon.smithy.model.shapes.DoubleShape;
import software.amazon.smithy.model.shapes.FloatShape;
import software.amazon.smithy.model.shapes.IntegerShape;
import software.amazon.smithy.model.shapes.LongShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.ShortShape;
import software.amazon.smithy.model.shapes.StringShape;
import software.amazon.smithy.model.shapes.TimestampShape;

public final class Util {
private Util() {}

public static Set<Shape> getShapesToSerde(Model model, Shape shape) {
return Stream.concat(
Stream.of(normalize(shape)),
shape.members().stream()
.map(it -> model.expectShape(it.getTarget()))
.flatMap(it -> getShapesToSerde(model, it).stream())
).filter(it -> !it.getId().toString().equals("smithy.api#Unit")).collect(toSet());
}

public static Shape normalize(Shape shape) {
return switch (shape.getType()) {
// TODO should be marked synthetic and keyed into from there by caller to avoid shape name conflicts
case BLOB -> BlobShape.builder().id("com.amazonaws.synthetic#Blob").build();
case BOOLEAN -> BooleanShape.builder().id("com.amazonaws.synthetic#Bool").build();
case STRING -> StringShape.builder().id("com.amazonaws.synthetic#String").build();
case TIMESTAMP -> TimestampShape.builder().id("com.amazonaws.synthetic#Time").build();
case BYTE -> ByteShape.builder().id("com.amazonaws.synthetic#Int8").build();
case SHORT -> ShortShape.builder().id("com.amazonaws.synthetic#Int16").build();
case INTEGER -> IntegerShape.builder().id("com.amazonaws.synthetic#Int32").build();
case LONG -> LongShape.builder().id("com.amazonaws.synthetic#Int64").build();
case FLOAT -> FloatShape.builder().id("com.amazonaws.synthetic#Float32").build();
case DOUBLE -> DoubleShape.builder().id("com.amazonaws.synthetic#Float64").build();
default -> shape;
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.smithy.go.codegen.service.protocol;

import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;
import static software.amazon.smithy.go.codegen.SymbolUtils.getReference;
import static software.amazon.smithy.go.codegen.SymbolUtils.isPointable;
import static software.amazon.smithy.go.codegen.service.Util.normalize;

import java.util.Set;
import software.amazon.smithy.codegen.core.CodegenException;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.GoStdlibTypes;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoTypes;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.shapes.CollectionShape;
import software.amazon.smithy.model.shapes.MapShape;
import software.amazon.smithy.model.shapes.MemberShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.utils.MapUtils;
import software.amazon.smithy.utils.SmithyInternalApi;

@SmithyInternalApi
public final class JsonDeserializerGenerator {
private final Model model;
private final SymbolProvider symbolProvider;

public JsonDeserializerGenerator(Model model, SymbolProvider symbolProvider) {
this.model = model;
this.symbolProvider = symbolProvider;
}

public static String getDeserializerName(Shape shape) {
return "deserialize" + shape.getId().getName();
}

public GoWriter.Writable generate(Set<Shape> shapes) {
return GoWriter.ChainWritable.of(
shapes.stream()
.map(this::generateShapeDeserializer)
.toList()
).compose();
}

private GoWriter.Writable generateShapeDeserializer(Shape shape) {
return goTemplate("""
func $name:L(v interface{}) ($shapeType:P, error) {
av, ok := v.($assert:W)
if !ok {
return $zero:W, $error:T("invalid")
}
$deserialize:W
}
""",
MapUtils.of(
"name", getDeserializerName(shape),
"shapeType", symbolProvider.toSymbol(shape),
"assert", generateOpaqueAssert(shape),
"zero", generateZeroValue(shape),
"error", GoStdlibTypes.Fmt.Errorf,
"deserialize", generateDeserializeAssertedValue(shape, "av")
));
}

private GoWriter.Writable generateOpaqueAssert(Shape shape) {
return switch (shape.getType()) {
case BYTE, SHORT, INTEGER, LONG, FLOAT, DOUBLE, INT_ENUM ->
goTemplate("$T", GoStdlibTypes.Encoding.Json.Number);
case STRING, BLOB, TIMESTAMP, ENUM, BIG_DECIMAL, BIG_INTEGER ->
goTemplate("string");
case BOOLEAN ->
goTemplate("bool");
case LIST, SET ->
goTemplate("[]interface{}");
case MAP, STRUCTURE, UNION ->
goTemplate("map[string]interface{}");
case DOCUMENT ->
throw new CodegenException("TODO: document is special");
default ->
throw new CodegenException("? " + shape.getType());
};
}

private GoWriter.Writable generateZeroValue(Shape shape) {
return switch (shape.getType()) {
case BYTE, SHORT, INTEGER, LONG, FLOAT, DOUBLE ->
goTemplate("0");
case STRING ->
goTemplate("\"\"");
case BOOLEAN ->
goTemplate("false");
case BLOB, LIST, SET, MAP, STRUCTURE, UNION ->
goTemplate("nil");
case ENUM ->
goTemplate("$T(\"\")", symbolProvider.toSymbol(shape));
case INT_ENUM ->
goTemplate("$T(0)", symbolProvider.toSymbol(shape));
case DOCUMENT ->
throw new CodegenException("TODO: document is special");
default ->
throw new CodegenException("? " + shape.getType());
};
}

private GoWriter.Writable generateDeserializeAssertedValue(Shape shape, String ident) {
return switch (shape.getType()) {
case BYTE -> generateDeserializeIntegral(ident, "int8", Byte.MIN_VALUE, Byte.MAX_VALUE);
case SHORT -> generateDeserializeIntegral(ident, "int16", Short.MIN_VALUE, Short.MAX_VALUE);
case INTEGER -> generateDeserializeIntegral(ident, "int32", Integer.MIN_VALUE, Integer.MAX_VALUE);
case LONG -> generateDeserializeIntegral(ident, "int64", Long.MIN_VALUE, Long.MAX_VALUE);
case STRING, BOOLEAN -> goTemplate("return $L, nil", ident);
case ENUM -> goTemplate("return $T($L), nil", symbolProvider.toSymbol(shape), ident);
case BLOB -> goTemplate("""
p, err := $b64:T.DecodeString($ident:L)
if err != nil {
return nil, err
}
return p, nil
""",
MapUtils.of(
"ident", ident,
"b64", GoStdlibTypes.Encoding.Base64.StdEncoding
));
case LIST, SET -> {
var target = normalize(model.expectShape(((CollectionShape) shape).getMember().getTarget()));
var symbol = symbolProvider.toSymbol(shape);
var targetSymbol = symbolProvider.toSymbol(target);
yield goTemplate("""
var deserializedList $type:T
for _, serializedItem := range $ident:L {
deserializedItem, err := $deserialize:L(serializedItem)
if err != nil {
return nil, err
}
deserializedList = append(deserializedList, $deref:L)
}
return deserializedList, nil
""",
MapUtils.of(
"type", symbol,
"ident", ident,
"deserialize", getDeserializerName(target),
"deref", isPointable(getReference(symbol)) != isPointable(targetSymbol)
? "*deserializedItem" : "deserializedItem"
));
}
case MAP -> {
var value = normalize(model.expectShape(((MapShape) shape).getValue().getTarget()));
var symbol = symbolProvider.toSymbol(shape);
var valueSymbol = symbolProvider.toSymbol(value);
yield goTemplate("""
deserializedMap := $type:T{}
for key, serializedValue := range $ident:L {
deserializedValue, err := $deserialize:L(serializedValue)
if err != nil {
return nil, err
}
deserializedMap[key] = $deref:L
}
return deserializedMap, nil
""",
MapUtils.of(
"type", symbol,
"ident", ident,
"deserialize", getDeserializerName(value),
"deref", isPointable(getReference(symbol)) != isPointable(valueSymbol)
? "*deserializedValue" : "deserializedValue"
));
}
case STRUCTURE -> goTemplate("""
deserializedStruct := &$type:T{}
for key, serializedValue := range $ident:L {
$deserializeFields:W
}
return deserializedStruct, nil
""",
MapUtils.of(
"type", symbolProvider.toSymbol(shape),
"ident", ident,
"deserializeFields", GoWriter.ChainWritable.of(
shape.getAllMembers().entrySet().stream()
.map(it -> {
var target = model.expectShape(it.getValue().getTarget());
return goTemplate("""
if key == $field:S {
fieldValue, err := $deserialize:L(serializedValue)
if err != nil {
return nil, err
}
deserializedStruct.$fieldName:L = $deref:W
}
""",
MapUtils.of(
"field", it.getKey(),
"fieldName", symbolProvider.toMemberName(it.getValue()),
"deserialize", getDeserializerName(normalize(target)),
"deref", generateStructFieldDeref(
it.getValue(), "fieldValue")
));
})
.toList()
).compose(false)
));
case UNION -> goTemplate("// TODO (union)");
default ->
throw new CodegenException("? " + shape.getType());
};
}

private GoWriter.Writable generateDeserializeIntegral(String ident, String castTo, long min, long max) {
return goTemplate("""
$nextident:L, err := $ident:L.Int64()
if err != nil {
return 0, err
}
if $nextident:L < $min:L || $nextident:L > $max:L {
return 0, $errorf:T("invalid")
}
return $cast:L($nextident:L), nil
""",
MapUtils.of(
"errorf", GoStdlibTypes.Fmt.Errorf,
"ident", ident,
"nextident", ident + "_",
"min", min,
"max", max,
"cast", castTo
));
}

private GoWriter.Writable generateStructFieldDeref(MemberShape member, String ident) {
var symbol = symbolProvider.toSymbol(member);
if (!isPointable(symbol)) {
return goTemplate(ident);
}
return switch (model.expectShape(member.getTarget()).getType()) {
case BYTE -> goTemplate("$T($L)", SmithyGoTypes.Ptr.Int8, ident);
case SHORT -> goTemplate("$T($L)", SmithyGoTypes.Ptr.Int16, ident);
case INTEGER -> goTemplate("$T($L)", SmithyGoTypes.Ptr.Int32, ident);
case LONG -> goTemplate("$T($L)", SmithyGoTypes.Ptr.Int64, ident);
case STRING -> goTemplate("$T($L)", SmithyGoTypes.Ptr.String, ident);
case BOOLEAN -> goTemplate("$T($L)", SmithyGoTypes.Ptr.Bool, ident);
default -> goTemplate(ident);
};
}

}
Loading

0 comments on commit 3455ea1

Please sign in to comment.