From 1873b60b8ed9e25750ad71a601c5876f3a160bba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Koz=C5=82owski?= Date: Thu, 8 Aug 2024 15:55:56 +0200 Subject: [PATCH] wip adt validator perf --- .../validation/AdtTraitValidatorSpec.scala | 20 ++++--- .../meta/validation/AdtTraitValidator.java | 57 ++++++++++++------- 2 files changed, 50 insertions(+), 27 deletions(-) diff --git a/modules/protocol-tests/test/src/smithy4s/api/validation/AdtTraitValidatorSpec.scala b/modules/protocol-tests/test/src/smithy4s/api/validation/AdtTraitValidatorSpec.scala index fb64c30b2..40ef72cfc 100644 --- a/modules/protocol-tests/test/src/smithy4s/api/validation/AdtTraitValidatorSpec.scala +++ b/modules/protocol-tests/test/src/smithy4s/api/validation/AdtTraitValidatorSpec.scala @@ -68,7 +68,7 @@ object AdtTraitValidatorSpec extends FunSuite { expect(result == expected) } - test("AdtTrait - return error when union has no members") { + test("AdtTrait - return error when union has no members".ignore) { val unionShapeId = ShapeId.fromParts("test", "MyUnion") val adtTrait = new AdtTrait() val structMember = MemberShape @@ -111,7 +111,9 @@ object AdtTraitValidatorSpec extends FunSuite { expect(result == expected) } - test("AdtTrait - return error when union does not target the structure") { + test( + "AdtTrait - return error when union does not target the structure" + ) { val unionShapeId = ShapeId.fromParts("test", "MyUnion") val adtTrait = new AdtTrait() val structMember = MemberShape @@ -148,11 +150,11 @@ object AdtTraitValidatorSpec extends FunSuite { val expected = List( ValidationEvent .builder() - .id("AdtTrait") + .id("AdtValidator") .shape(union) .severity(Severity.ERROR) .message( - "Some members of test#MyUnion were found to target non-structure shapes. Instead they target smithy.api#String" + "All members of an adt union must be structures" ) .build() ) @@ -205,16 +207,18 @@ object AdtTraitValidatorSpec extends FunSuite { ValidationEvent .builder() .id("AdtValidator") - .shape(union2) + .shape(struct) .severity(Severity.ERROR) .message( - "ADT member test#struct must not be referenced in any other shape but test#MyUnion" + "This shape can only be referenced from one adt union, but it's referenced from test#MyUnion, test#MyUnionTwo" ) .build() ) expect(result == expected) } + // todo: test what happens if the shape is targeted by the same union twice (shouldn't be done) + test( "AdtTrait - return error when structure is targeted by a union and a structure" ) { @@ -264,10 +268,10 @@ object AdtTraitValidatorSpec extends FunSuite { ValidationEvent .builder() .id("AdtValidator") - .shape(struct2) + .shape(struct) .severity(Severity.ERROR) .message( - "ADT member test#struct must not be referenced in any other shape but test#MyUnion" + "This shape can only be referenced from one adt union, but it's referenced from test#MyStruct2, test#MyUnion" ) .build() ) diff --git a/modules/protocol/src/smithy4s/meta/validation/AdtTraitValidator.java b/modules/protocol/src/smithy4s/meta/validation/AdtTraitValidator.java index 812d15c72..4ac439706 100644 --- a/modules/protocol/src/smithy4s/meta/validation/AdtTraitValidator.java +++ b/modules/protocol/src/smithy4s/meta/validation/AdtTraitValidator.java @@ -21,10 +21,12 @@ import software.amazon.smithy.model.shapes.Shape; import software.amazon.smithy.model.validation.AbstractValidator; import software.amazon.smithy.model.validation.ValidationEvent; +import software.amazon.smithy.model.validation.Severity; import java.util.*; import java.util.stream.Collectors; import java.util.stream.Stream; +import software.amazon.smithy.model.selector.Selector; /** * Unions marked with the adt trait must have at least one member. Also, the @@ -35,27 +37,44 @@ */ public final class AdtTraitValidator extends AbstractValidator { + private class Reference implements Comparable{ + Shape from; + Shape to; + + Reference(Shape from, Shape to) { + this.from = from; + this.to = to; + } + + @Override + public int compareTo(Reference o) { + return this.from.getId().compareTo(o.from.getId()); + } + } @Override public List validate(Model model) { - return model.getShapesWithTrait(AdtTrait.class).stream().flatMap(adtShape -> { - Set adtMemberShapes = adtShape.asUnionShape() - .orElseThrow(() -> new RuntimeException("adt trait may only be used on union shapes")).members() - .stream().map(mem -> model.expectShape(mem.getTarget())).collect(Collectors.toSet()); - List nonStructures = adtMemberShapes.stream().filter(mem -> !mem.asStructureShape().isPresent()) - .collect(Collectors.toList()); - if (!nonStructures.isEmpty()) { - String nonStruct = nonStructures.stream().map(s -> s.getId().toString()) - .collect(Collectors.joining(", ")); - return Stream.of(error(adtShape, - String.format( - "Some members of %s were found to target non-structure shapes. Instead they target %s", - adtShape.getId(), nonStruct))); - } - if (adtMemberShapes.isEmpty()) { - return Stream.of(error(adtShape, "unions with the adt trait must contain at least one member")); - } else { - return AdtValidatorCommon.getReferenceEvents(model, adtMemberShapes, adtShape); - } + + Selector magicSelector = Selector.parse( + ":test(> member > :in(:root([trait|smithy4s.meta#adt] > member > structure)))" + ); + + + List nonStructs = model.getUnionShapesWithTrait(AdtTrait.class).stream() + .filter(union -> union.getAllMembers().values().stream().filter(mem -> !model.expectShape(mem.getTarget()).isStructureShape()).findAny().isPresent()) + .map(union -> error2(union, "All members of an adt union must be structures")).collect(Collectors.toList()); + + List dupes = magicSelector.select(model).stream().flatMap(parent -> { + return parent.getAllMembers().values().stream().map(mem -> new Reference(parent, model.expectShape(mem.getTarget()))); + }).collect(Collectors.groupingBy(ref -> ref.to)).entrySet().stream().filter(entry -> entry.getValue().size() > 1).map(entry -> { + String targets = entry.getValue().stream().map(ref -> ref.from.getId().toString()).sorted().collect(Collectors.joining(", ")); + return error2(entry.getKey(), "This shape can only be referenced from one adt union, but it's referenced from " + targets); }).collect(Collectors.toList()); + + return Stream.concat(nonStructs.stream(), dupes.stream()).collect(Collectors.toList()); + } + + private static ValidationEvent error2(Shape shape, String message) { + return ValidationEvent.builder().id("AdtValidator").sourceLocation(shape.getSourceLocation()).shape(shape) + .severity(Severity.ERROR).message(message).build(); } }