Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat-svcgen-deserialize #495

Merged
merged 3 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ public static final class Context {
public static final Symbol Context = SmithyGoDependency.CONTEXT.valueSymbol("Context");
}

public static final class Encoding {
public static final class Json {
public static final Symbol NewDecoder = SmithyGoDependency.JSON.valueSymbol("NewDecoder");
public static final Symbol Number = SmithyGoDependency.JSON.valueSymbol("Number");
}

public static final class Base64 {
public static final Symbol StdEncoding = SmithyGoDependency.BASE64.valueSymbol("StdEncoding");
}
}

public static final class Fmt {
public static final Symbol Errorf = SmithyGoDependency.FMT.valueSymbol("Errorf");
public static final Symbol Sprintf = SmithyGoDependency.FMT.valueSymbol("Sprintf");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ public static ChainWritable of(GoWriter.Writable... writables) {
return chain;
}

public static ChainWritable of(List<GoWriter.Writable> writables) {
public static ChainWritable of(Collection<GoWriter.Writable> writables) {
var chain = new ChainWritable();
chain.writables.addAll(writables);
return chain;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,20 @@ public static final class Smithy {
public static final Symbol OperationError = SmithyGoDependency.SMITHY.pointableSymbol("OperationError");
}

public static final class Encoding {
public static final class Json {
public static final Symbol NewEncoder = SmithyGoDependency.SMITHY_JSON.valueSymbol("NewEncoder");
public static final Symbol Value = SmithyGoDependency.SMITHY_JSON.valueSymbol("Value");
}
}

public static final class Ptr {
public static final Symbol String = SmithyGoDependency.SMITHY_PTR.valueSymbol("String");
public static final Symbol Bool = SmithyGoDependency.SMITHY_PTR.valueSymbol("Bool");
public static final Symbol Int8 = SmithyGoDependency.SMITHY_PTR.valueSymbol("Int8");
public static final Symbol Int16 = SmithyGoDependency.SMITHY_PTR.valueSymbol("Int16");
public static final Symbol Int32 = SmithyGoDependency.SMITHY_PTR.valueSymbol("Int32");
public static final Symbol Int64 = SmithyGoDependency.SMITHY_PTR.valueSymbol("Int64");
}

public static final class Middleware {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,14 @@ public static boolean isUniverseType(Symbol symbol) {
.orElse(false);
}

public static boolean isPointable(Symbol symbol) {
return symbol.getProperty(SymbolUtils.POINTABLE, Boolean.class).orElse(false);
}

public static Symbol getReference(Symbol symbol) {
return symbol.getProperty(SymbolUtils.GO_ELEMENT_TYPE, Symbol.class).orElse(null);
}

/**
* Builds a symbol within the context of the package in which codegen is taking place.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.smithy.go.codegen.service;

import static java.util.stream.Collectors.toSet;

import java.util.Set;
import java.util.stream.Stream;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.shapes.BlobShape;
import software.amazon.smithy.model.shapes.BooleanShape;
import software.amazon.smithy.model.shapes.ByteShape;
import software.amazon.smithy.model.shapes.DoubleShape;
import software.amazon.smithy.model.shapes.FloatShape;
import software.amazon.smithy.model.shapes.IntegerShape;
import software.amazon.smithy.model.shapes.LongShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.ShortShape;
import software.amazon.smithy.model.shapes.StringShape;
import software.amazon.smithy.model.shapes.TimestampShape;

public final class Util {
private Util() {}

public static Set<Shape> getShapesToSerde(Model model, Shape shape) {
return Stream.concat(
Stream.of(normalize(shape)),
shape.members().stream()
.map(it -> model.expectShape(it.getTarget()))
.flatMap(it -> getShapesToSerde(model, it).stream())
).filter(it -> !it.getId().toString().equals("smithy.api#Unit")).collect(toSet());
}

public static Shape normalize(Shape shape) {
return switch (shape.getType()) {
// TODO should be marked synthetic and keyed into from there by caller to avoid shape name conflicts
case BLOB -> BlobShape.builder().id("com.amazonaws.synthetic#Blob").build();
case BOOLEAN -> BooleanShape.builder().id("com.amazonaws.synthetic#Bool").build();
case STRING -> StringShape.builder().id("com.amazonaws.synthetic#String").build();
case TIMESTAMP -> TimestampShape.builder().id("com.amazonaws.synthetic#Time").build();
case BYTE -> ByteShape.builder().id("com.amazonaws.synthetic#Int8").build();
case SHORT -> ShortShape.builder().id("com.amazonaws.synthetic#Int16").build();
case INTEGER -> IntegerShape.builder().id("com.amazonaws.synthetic#Int32").build();
case LONG -> LongShape.builder().id("com.amazonaws.synthetic#Int64").build();
case FLOAT -> FloatShape.builder().id("com.amazonaws.synthetic#Float32").build();
case DOUBLE -> DoubleShape.builder().id("com.amazonaws.synthetic#Float64").build();
default -> shape;
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.smithy.go.codegen.service.protocol;

import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;
import static software.amazon.smithy.go.codegen.SymbolUtils.getReference;
import static software.amazon.smithy.go.codegen.SymbolUtils.isPointable;
import static software.amazon.smithy.go.codegen.service.Util.normalize;

import java.util.Set;
import software.amazon.smithy.codegen.core.CodegenException;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.GoStdlibTypes;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoTypes;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.shapes.CollectionShape;
import software.amazon.smithy.model.shapes.MapShape;
import software.amazon.smithy.model.shapes.MemberShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.utils.MapUtils;
import software.amazon.smithy.utils.SmithyInternalApi;

@SmithyInternalApi
public final class JsonDeserializerGenerator {
private final Model model;
private final SymbolProvider symbolProvider;

public JsonDeserializerGenerator(Model model, SymbolProvider symbolProvider) {
this.model = model;
this.symbolProvider = symbolProvider;
}

public static String getDeserializerName(Shape shape) {
return "deserialize" + shape.getId().getName();
}

public GoWriter.Writable generate(Set<Shape> shapes) {
return GoWriter.ChainWritable.of(
shapes.stream()
.map(this::generateShapeDeserializer)
.toList()
).compose();
}

private GoWriter.Writable generateShapeDeserializer(Shape shape) {
return goTemplate("""
func $name:L(v interface{}) ($shapeType:P, error) {
av, ok := v.($assert:W)
if !ok {
return $zero:W, $error:T("invalid")
}
$deserialize:W
}
""",
MapUtils.of(
"name", getDeserializerName(shape),
"shapeType", symbolProvider.toSymbol(shape),
"assert", generateOpaqueAssert(shape),
"zero", generateZeroValue(shape),
"error", GoStdlibTypes.Fmt.Errorf,
"deserialize", generateDeserializeAssertedValue(shape, "av")
));
}

private GoWriter.Writable generateOpaqueAssert(Shape shape) {
return switch (shape.getType()) {
case BYTE, SHORT, INTEGER, LONG, FLOAT, DOUBLE, INT_ENUM ->
goTemplate("$T", GoStdlibTypes.Encoding.Json.Number);
case STRING, BLOB, TIMESTAMP, ENUM, BIG_DECIMAL, BIG_INTEGER ->
goTemplate("string");
case BOOLEAN ->
goTemplate("bool");
case LIST, SET ->
goTemplate("[]interface{}");
case MAP, STRUCTURE, UNION ->
goTemplate("map[string]interface{}");
case DOCUMENT ->
throw new CodegenException("TODO: document is special");
default ->
throw new CodegenException("? " + shape.getType());
};
}

private GoWriter.Writable generateZeroValue(Shape shape) {
return switch (shape.getType()) {
case BYTE, SHORT, INTEGER, LONG, FLOAT, DOUBLE ->
goTemplate("0");
case STRING ->
goTemplate("\"\"");
case BOOLEAN ->
goTemplate("false");
case BLOB, LIST, SET, MAP, STRUCTURE, UNION ->
goTemplate("nil");
case ENUM ->
goTemplate("$T(\"\")", symbolProvider.toSymbol(shape));
case INT_ENUM ->
goTemplate("$T(0)", symbolProvider.toSymbol(shape));
case DOCUMENT ->
throw new CodegenException("TODO: document is special");
default ->
throw new CodegenException("? " + shape.getType());
};
}

private GoWriter.Writable generateDeserializeAssertedValue(Shape shape, String ident) {
return switch (shape.getType()) {
case BYTE -> generateDeserializeIntegral(ident, "int8", Byte.MIN_VALUE, Byte.MAX_VALUE);
case SHORT -> generateDeserializeIntegral(ident, "int16", Short.MIN_VALUE, Short.MAX_VALUE);
case INTEGER -> generateDeserializeIntegral(ident, "int32", Integer.MIN_VALUE, Integer.MAX_VALUE);
case LONG -> generateDeserializeIntegral(ident, "int64", Long.MIN_VALUE, Long.MAX_VALUE);
case STRING, BOOLEAN -> goTemplate("return $L, nil", ident);
case ENUM -> goTemplate("return $T($L), nil", symbolProvider.toSymbol(shape), ident);
case BLOB -> goTemplate("""
p, err := $b64:T.DecodeString($ident:L)
if err != nil {
return nil, err
}
return p, nil
""",
MapUtils.of(
"ident", ident,
"b64", GoStdlibTypes.Encoding.Base64.StdEncoding
));
case LIST, SET -> {
var target = normalize(model.expectShape(((CollectionShape) shape).getMember().getTarget()));
var symbol = symbolProvider.toSymbol(shape);
var targetSymbol = symbolProvider.toSymbol(target);
yield goTemplate("""
var deserializedList $type:T
for _, serializedItem := range $ident:L {
deserializedItem, err := $deserialize:L(serializedItem)
if err != nil {
return nil, err
}
deserializedList = append(deserializedList, $deref:L)
}
return deserializedList, nil
""",
MapUtils.of(
"type", symbol,
"ident", ident,
"deserialize", getDeserializerName(target),
"deref", isPointable(getReference(symbol)) != isPointable(targetSymbol)
? "*deserializedItem" : "deserializedItem"
));
}
case MAP -> {
var value = normalize(model.expectShape(((MapShape) shape).getValue().getTarget()));
var symbol = symbolProvider.toSymbol(shape);
var valueSymbol = symbolProvider.toSymbol(value);
yield goTemplate("""
deserializedMap := $type:T{}
for key, serializedValue := range $ident:L {
deserializedValue, err := $deserialize:L(serializedValue)
if err != nil {
return nil, err
}
deserializedMap[key] = $deref:L
}
return deserializedMap, nil
""",
MapUtils.of(
"type", symbol,
"ident", ident,
"deserialize", getDeserializerName(value),
"deref", isPointable(getReference(symbol)) != isPointable(valueSymbol)
? "*deserializedValue" : "deserializedValue"
));
}
case STRUCTURE -> goTemplate("""
deserializedStruct := &$type:T{}
for key, serializedValue := range $ident:L {
$deserializeFields:W
}
return deserializedStruct, nil
""",
MapUtils.of(
"type", symbolProvider.toSymbol(shape),
"ident", ident,
"deserializeFields", GoWriter.ChainWritable.of(
shape.getAllMembers().entrySet().stream()
.map(it -> {
var target = model.expectShape(it.getValue().getTarget());
return goTemplate("""
if key == $field:S {
fieldValue, err := $deserialize:L(serializedValue)
if err != nil {
return nil, err
}
deserializedStruct.$fieldName:L = $deref:W
}
""",
MapUtils.of(
"field", it.getKey(),
"fieldName", symbolProvider.toMemberName(it.getValue()),
"deserialize", getDeserializerName(normalize(target)),
"deref", generateStructFieldDeref(
it.getValue(), "fieldValue")
));
})
.toList()
).compose(false)
));
case UNION -> goTemplate("// TODO (union)");
default ->
throw new CodegenException("? " + shape.getType());
};
}

private GoWriter.Writable generateDeserializeIntegral(String ident, String castTo, long min, long max) {
return goTemplate("""
$nextident:L, err := $ident:L.Int64()
if err != nil {
return 0, err
}
if $nextident:L < $min:L || $nextident:L > $max:L {
return 0, $errorf:T("invalid")
}
return $cast:L($nextident:L), nil
""",
MapUtils.of(
"errorf", GoStdlibTypes.Fmt.Errorf,
"ident", ident,
"nextident", ident + "_",
"min", min,
"max", max,
"cast", castTo
));
}

private GoWriter.Writable generateStructFieldDeref(MemberShape member, String ident) {
var symbol = symbolProvider.toSymbol(member);
if (!isPointable(symbol)) {
return goTemplate(ident);
}
return switch (model.expectShape(member.getTarget()).getType()) {
case BYTE -> goTemplate("$T($L)", SmithyGoTypes.Ptr.Int8, ident);
case SHORT -> goTemplate("$T($L)", SmithyGoTypes.Ptr.Int16, ident);
case INTEGER -> goTemplate("$T($L)", SmithyGoTypes.Ptr.Int32, ident);
case LONG -> goTemplate("$T($L)", SmithyGoTypes.Ptr.Int64, ident);
case STRING -> goTemplate("$T($L)", SmithyGoTypes.Ptr.String, ident);
case BOOLEAN -> goTemplate("$T($L)", SmithyGoTypes.Ptr.Bool, ident);
default -> goTemplate(ident);
};
}

}
Loading
Loading