diff --git a/src/main/java/com/intuit/graphql/orchestrator/batch/AuthDownstreamQueryModifier.java b/src/main/java/com/intuit/graphql/orchestrator/batch/AuthDownstreamQueryModifier.java index 70d7fdee..91e4cfca 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/batch/AuthDownstreamQueryModifier.java +++ b/src/main/java/com/intuit/graphql/orchestrator/batch/AuthDownstreamQueryModifier.java @@ -1,17 +1,5 @@ package com.intuit.graphql.orchestrator.batch; -import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.hasResolverDirective; -import static com.intuit.graphql.orchestrator.utils.QueryPathUtils.getNodesAsPathList; -import static com.intuit.graphql.orchestrator.utils.QueryPathUtils.pathListToFQN; -import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.convertGraphqlFieldWithOriginalName; -import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.getRenameKey; -import static graphql.introspection.Introspection.TypeNameMetaFieldDef; -import static graphql.schema.FieldCoordinates.coordinates; -import static graphql.util.TreeTransformerUtil.changeNode; -import static graphql.util.TreeTransformerUtil.deleteNode; -import static java.util.Objects.nonNull; -import static java.util.Objects.requireNonNull; - import com.intuit.graphql.orchestrator.authorization.FieldAuthorization; import com.intuit.graphql.orchestrator.authorization.FieldAuthorizationEnvironment; import com.intuit.graphql.orchestrator.authorization.FieldAuthorizationResult; @@ -40,6 +28,11 @@ import graphql.schema.GraphQLUnionType; import graphql.util.TraversalControl; import graphql.util.TraverserContext; +import lombok.Builder; +import lombok.NonNull; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.collections4.MapUtils; + import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -47,10 +40,19 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; -import lombok.Builder; -import lombok.NonNull; -import org.apache.commons.collections4.CollectionUtils; -import org.apache.commons.collections4.MapUtils; + +import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.hasResolverDirective; +import static com.intuit.graphql.orchestrator.utils.QueryDirectivesUtil.shouldIgnoreNode; +import static com.intuit.graphql.orchestrator.utils.QueryPathUtils.getNodesAsPathList; +import static com.intuit.graphql.orchestrator.utils.QueryPathUtils.pathListToFQN; +import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.convertGraphqlFieldWithOriginalName; +import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.getRenameKey; +import static graphql.introspection.Introspection.TypeNameMetaFieldDef; +import static graphql.schema.FieldCoordinates.coordinates; +import static graphql.util.TreeTransformerUtil.changeNode; +import static graphql.util.TreeTransformerUtil.deleteNode; +import static java.util.Objects.nonNull; +import static java.util.Objects.requireNonNull; /** * This class modifies for query for a downstream provider. @@ -91,6 +93,12 @@ public TraversalControl visitField(Field node, TraverserContext context) { requireNonNull(fieldDefinition, "Failed to get Field Definition for " + node.getName()); context.setVar(GraphQLType.class, fieldDefinition.getType()); + + if(shouldIgnoreNode(node, this.queryVariables)) { + decreaseParentSelectionSetCount(context.getParentContext()); + return deleteNode(context); + } + FieldAuthorizationResult fieldAuthorizationResult = authorize(node, fieldDefinition, parentType, context); if (!fieldAuthorizationResult.isAllowed()) { decreaseParentSelectionSetCount(context.getParentContext()); @@ -112,8 +120,10 @@ public TraversalControl visitField(Field node, TraverserContext context) { GraphQLFieldDefinition fieldDefinition = getFieldDefinition(node.getName(), parentType); requireNonNull(fieldDefinition, "Failed to get Field Definition for " + node.getName()); - if (serviceMetadata.shouldModifyDownStreamQuery() && (hasResolverDirective(fieldDefinition) - || isExternalField(parentType.getName(), node.getName()))) { + boolean shouldRemoveNode = (serviceMetadata.shouldModifyDownStreamQuery() && (hasResolverDirective(fieldDefinition) + || isExternalField(parentType.getName(), node.getName()))) + || shouldIgnoreNode(node, this.queryVariables); + if (shouldRemoveNode) { decreaseParentSelectionSetCount(context.getParentContext()); return deleteNode(context); } else { diff --git a/src/main/java/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilter.java b/src/main/java/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilter.java index 14ffcf2d..51c9c8a1 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilter.java +++ b/src/main/java/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilter.java @@ -6,25 +6,20 @@ import graphql.analysis.QueryVisitorInlineFragmentEnvironment; import graphql.analysis.QueryVisitorStub; import graphql.language.Argument; -import graphql.language.AstTransformer; import graphql.language.Document; import graphql.language.Field; import graphql.language.FragmentDefinition; import graphql.language.FragmentSpread; import graphql.language.InlineFragment; import graphql.language.Node; -import graphql.language.NodeVisitorStub; import graphql.language.OperationDefinition; import graphql.language.Value; import graphql.language.VariableReference; import graphql.schema.GraphQLObjectType; import graphql.schema.GraphQLSchema; -import graphql.util.TraversalControl; -import graphql.util.TraverserContext; import lombok.Getter; import java.util.ArrayList; -import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Set; @@ -36,8 +31,6 @@ */ public class VariableDefinitionFilter { - private static AstTransformer astTransformer = new AstTransformer(); - /** * Traverses a GraphQL Node and returns all VariableReference names used in all nodes in the graph. * @@ -50,17 +43,17 @@ public class VariableDefinitionFilter { * reference indicator prefix '$' will be excluded in the result. */ public Set getVariableReferencesFromNode(GraphQLSchema graphQLSchema, GraphQLObjectType rootType, - Map fragmentsByName, Map variables, Node rootNode) { + Map fragmentsByName, Map variables, Node rootNode) { final VariableReferenceVisitor variableReferenceVisitor = new VariableReferenceVisitor(); //need to utilize a better pattern for creating mockable QueryTraverser/QueryTransformer QueryTraverser queryTraverser = QueryTraverser.newQueryTraverser() - .schema(graphQLSchema) - .rootParentType(rootType) //need to support also for subscription - .fragmentsByName(fragmentsByName) - .variables(variables) - .root(rootNode) - .build(); + .schema(graphQLSchema) + .rootParentType(rootType) //need to support also for subscription + .fragmentsByName(fragmentsByName) + .variables(variables) + .root(rootNode) + .build(); queryTraverser.visitPreOrder(variableReferenceVisitor); @@ -75,28 +68,16 @@ public Set getVariableReferencesFromNode(GraphQLSchema graphQLSchema, Gr Set additionalReferences = operationDirectiveVariableReferences(operationDefinitions); - Stream variableReferenceStream; - if((variableReferenceVisitor.getVariableReferences().size() + additionalReferences.size()) != variables.size()) { - NodeTraverser nodeTraverser = new NodeTraverser(); - astTransformer.transform(rootNode, nodeTraverser); - - variableReferenceStream = Stream.of(variableReferenceVisitor.getVariableReferences(), - additionalReferences, - nodeTraverser.getVariableReferenceExtractor().getVariableReferences()) - .flatMap(Collection::stream); - } else { - variableReferenceStream = Stream.concat(variableReferenceVisitor.getVariableReferences().stream(), additionalReferences.stream()); - } - return variableReferenceStream.map(VariableReference::getName).collect(Collectors.toSet()); - + return Stream.concat(variableReferenceVisitor.getVariableReferences().stream(), additionalReferences.stream()) + .map(VariableReference::getName).collect(Collectors.toSet()); } private Set operationDirectiveVariableReferences(List operationDefinitions) { final List values = operationDefinitions.stream() - .flatMap(operationDefinition -> operationDefinition.getDirectives().stream()) - .flatMap(directive -> directive.getArguments().stream()) - .map(Argument::getValue) - .collect(Collectors.toList()); + .flatMap(operationDefinition -> operationDefinition.getDirectives().stream()) + .flatMap(directive -> directive.getArguments().stream()) + .map(Argument::getValue) + .collect(Collectors.toList()); VariableReferenceExtractor extractor = new VariableReferenceExtractor(); extractor.captureVariableReferences(values); @@ -138,7 +119,7 @@ public void visitField(final QueryVisitorFieldEnvironment env) { } final Stream directiveArgumentStream = field.getDirectives().stream() - .flatMap(directive -> directive.getArguments().stream()); + .flatMap(directive -> directive.getArguments().stream()); final Stream fieldArgumentStream = field.getArguments().stream(); @@ -154,7 +135,7 @@ public void visitInlineFragment(final QueryVisitorInlineFragmentEnvironment env) } Stream arguments = env.getInlineFragment().getDirectives().stream() - .flatMap(directive -> directive.getArguments().stream()); + .flatMap(directive -> directive.getArguments().stream()); captureVariableReferences(arguments); } @@ -169,8 +150,8 @@ public void visitFragmentSpread(final QueryVisitorFragmentSpreadEnvironment env) } final Stream allArguments = Stream.concat( - fragmentDefinition.getDirectives().stream(), - fragmentSpread.getDirectives().stream() + fragmentDefinition.getDirectives().stream(), + fragmentSpread.getDirectives().stream() ).flatMap(directive -> directive.getArguments().stream()); captureVariableReferences(allArguments); @@ -178,24 +159,9 @@ public void visitFragmentSpread(final QueryVisitorFragmentSpreadEnvironment env) private void captureVariableReferences(Stream arguments) { final List values = arguments.map(Argument::getValue) - .collect(Collectors.toList()); + .collect(Collectors.toList()); variableReferenceExtractor.captureVariableReferences(values); } } - - static class NodeTraverser extends NodeVisitorStub { - - @Getter - private final VariableReferenceExtractor variableReferenceExtractor = new VariableReferenceExtractor(); - - public TraversalControl visitArgument(Argument node, TraverserContext context) { - return this.visitNode(node, context); - } - - public TraversalControl visitVariableReference(VariableReference node, TraverserContext context) { - variableReferenceExtractor.captureVariableReference(node); - return this.visitValue(node, context); - } - } } diff --git a/src/main/java/com/intuit/graphql/orchestrator/batch/VariableReferenceExtractor.java b/src/main/java/com/intuit/graphql/orchestrator/batch/VariableReferenceExtractor.java index ca20637c..218530ae 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/batch/VariableReferenceExtractor.java +++ b/src/main/java/com/intuit/graphql/orchestrator/batch/VariableReferenceExtractor.java @@ -19,14 +19,10 @@ public Set getVariableReferences() { public void captureVariableReferences(List values) { for (final Value value : values) { - captureVariableReference(value); + doSwitch(value); } } - public void captureVariableReference(Value value) { - doSwitch(value); - } - private void doSwitch(Value value) { if (value instanceof ArrayValue) { handleArrayValue((ArrayValue) value); diff --git a/src/main/java/com/intuit/graphql/orchestrator/utils/QueryDirectivesUtil.java b/src/main/java/com/intuit/graphql/orchestrator/utils/QueryDirectivesUtil.java new file mode 100644 index 00000000..ae825c03 --- /dev/null +++ b/src/main/java/com/intuit/graphql/orchestrator/utils/QueryDirectivesUtil.java @@ -0,0 +1,42 @@ +package com.intuit.graphql.orchestrator.utils; + +import graphql.language.Argument; +import graphql.language.BooleanValue; +import graphql.language.Directive; +import graphql.language.Field; +import graphql.language.Value; +import graphql.language.VariableReference; + +import java.util.Map; +import java.util.Optional; + +public class QueryDirectivesUtil { + + public static boolean shouldIgnoreNode(Field node, Map queryVariables) { + Optional optionalIncludesDir = node.getDirectives("include").stream().findFirst(); + Optional optionalSkipDir = node.getDirectives("skip").stream().findFirst(); + if(optionalIncludesDir.isPresent() || optionalSkipDir.isPresent()) { + if(optionalIncludesDir.isPresent() && (!getIfValue(optionalIncludesDir.get(), queryVariables))) { + return true; + } + return optionalSkipDir.isPresent() && (getIfValue(optionalSkipDir.get(), queryVariables)); + } + + return false; + } + + private static boolean getIfValue(Directive directive, Map queryVariables){ + Argument ifArg = directive.getArgument("if"); + Value ifValue = ifArg.getValue(); + + boolean defaultValue = directive.getName().equals("skip"); + + if(ifValue instanceof VariableReference) { + String variableRefName = ((VariableReference) ifValue).getName(); + return (boolean) queryVariables.getOrDefault(variableRefName, defaultValue); + } else if(ifValue instanceof BooleanValue) { + return ((BooleanValue) ifValue).isValue(); + } + return false; + } +} diff --git a/src/test/groovy/com/intuit/graphql/orchestrator/authorization/AuthDownstreamQueryRedactorVisitorSpec.groovy b/src/test/groovy/com/intuit/graphql/orchestrator/authorization/AuthDownstreamQueryRedactorVisitorSpec.groovy index 7ed843ed..f8b1367a 100644 --- a/src/test/groovy/com/intuit/graphql/orchestrator/authorization/AuthDownstreamQueryRedactorVisitorSpec.groovy +++ b/src/test/groovy/com/intuit/graphql/orchestrator/authorization/AuthDownstreamQueryRedactorVisitorSpec.groovy @@ -54,6 +54,34 @@ class AuthDownstreamQueryRedactorVisitorSpec extends Specification { } """ + def skipQuery = """ query skipQuery(\$shouldSkip: Boolean) { + a { + b1 @skip(if: \$shouldSkip) { + c1 { + s1 + } + } + b2 { + i1 + } + } + } + """ + + def includesQuery = """ query includesQuery(\$shouldInclude: Boolean) { + a { + b1 { + c1 { + s1 + } + } + b2 @include(if: \$shouldInclude) { + i1 + } + } + } + """ + static final Object TEST_AUTH_DATA = "TestAuthDataCanBeAnyObject" Field mockField = Mock() @@ -134,6 +162,171 @@ class AuthDownstreamQueryRedactorVisitorSpec extends Specification { argumentValueResolver.resolve(_, _, _) >> Collections.emptyMap() } + def "skip query with skip directive true removes selection set"() { + given: + Document document = new Parser().parseDocument(skipQuery) + OperationDefinition operationDefinition = document.getDefinitionsOfType(OperationDefinition.class).get(0) + Field rootField = SelectionSetUtil.getFieldByPath(Arrays.asList("a"), operationDefinition.getSelectionSet()) + GraphQLFieldsContainer rootFieldParentType = (GraphQLFieldsContainer) testGraphQLSchema.getType("Query") + + Map queryVariables = new HashMap<>() + queryVariables.put("shouldSkip", true) + + AuthDownstreamQueryModifier specUnderTest = AuthDownstreamQueryModifier.builder() + .rootParentType((GraphQLFieldsContainer) rootFieldParentType) + .fieldAuthorization(mockFieldAuthorization) + .graphQLContext(mockGraphQLContext) + .queryVariables(queryVariables) + .graphQLSchema(testGraphQLSchema) + .selectionCollector(new SelectionCollector(fragmentsByName)) + .serviceMetadata(mockServiceMetadata) + .authData(TEST_AUTH_DATA) + .build() + + when: + Field transformedField = (Field) astTransformer.transform(rootField, specUnderTest) + + then: + transformedField.getName() == "a" + Object[] selectionSet = transformedField.getSelectionSet() + .getSelections() + .asList() + selectionSet.size() == 1 + ((Field)selectionSet.first()).getName() == ("b2") + + 1 * mockFieldAuthorization.authorize(queryA) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + 1 * mockFieldAuthorization.authorize(aB2) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + 1 * mockFieldAuthorization.authorize(b2i1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockRenamedMetadata.getOriginalFieldNamesByRenamedName() >> Collections.emptyMap() + mockServiceMetadata.getRenamedMetadata() >> mockRenamedMetadata + } + + def "skip query with skip directive false keeps selection set"() { + given: + Document document = new Parser().parseDocument(skipQuery) + OperationDefinition operationDefinition = document.getDefinitionsOfType(OperationDefinition.class).get(0) + Field rootField = SelectionSetUtil.getFieldByPath(Arrays.asList("a"), operationDefinition.getSelectionSet()) + GraphQLFieldsContainer rootFieldParentType = (GraphQLFieldsContainer) testGraphQLSchema.getType("Query") + + Map queryVariables = new HashMap<>() + queryVariables.put("shouldSkip", false) + + AuthDownstreamQueryModifier specUnderTest = AuthDownstreamQueryModifier.builder() + .rootParentType((GraphQLFieldsContainer) rootFieldParentType) + .fieldAuthorization(mockFieldAuthorization) + .graphQLContext(mockGraphQLContext) + .queryVariables(queryVariables) + .graphQLSchema(testGraphQLSchema) + .selectionCollector(new SelectionCollector(fragmentsByName)) + .serviceMetadata(mockServiceMetadata) + .authData(TEST_AUTH_DATA) + .build() + + when: + Field transformedField = (Field) astTransformer.transform(rootField, specUnderTest) + + then: + transformedField.getName() == "a" + Object[] selectionSet = transformedField.getSelectionSet() + .getSelections() + .asList() + selectionSet.size() == 2 + ((Field)selectionSet[0]).getName() == "b1" + ((Field)selectionSet[1]).getName() == "b2" + + 1 * mockFieldAuthorization.authorize(queryA) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + 1 * mockFieldAuthorization.authorize(aB1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + 1 * mockFieldAuthorization.authorize(b1C1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + 1 * mockFieldAuthorization.authorize(c1S1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + 1 * mockFieldAuthorization.authorize(aB2) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + 1 * mockFieldAuthorization.authorize(b2i1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockRenamedMetadata.getOriginalFieldNamesByRenamedName() >> Collections.emptyMap() + mockServiceMetadata.getRenamedMetadata() >> mockRenamedMetadata + } + + def "includes query with include directive true keeps selection set"() { + given: + Document document = new Parser().parseDocument(includesQuery) + OperationDefinition operationDefinition = document.getDefinitionsOfType(OperationDefinition.class).get(0) + Field rootField = SelectionSetUtil.getFieldByPath(Arrays.asList("a"), operationDefinition.getSelectionSet()) + GraphQLFieldsContainer rootFieldParentType = (GraphQLFieldsContainer) testGraphQLSchema.getType("Query") + + Map queryVariables = new HashMap<>() + queryVariables.put("shouldInclude", true) + + AuthDownstreamQueryModifier specUnderTest = AuthDownstreamQueryModifier.builder() + .rootParentType((GraphQLFieldsContainer) rootFieldParentType) + .fieldAuthorization(mockFieldAuthorization) + .graphQLContext(mockGraphQLContext) + .queryVariables(queryVariables) + .graphQLSchema(testGraphQLSchema) + .selectionCollector(new SelectionCollector(fragmentsByName)) + .serviceMetadata(mockServiceMetadata) + .authData(TEST_AUTH_DATA) + .build() + + when: + Field transformedField = (Field) astTransformer.transform(rootField, specUnderTest) + + then: + transformedField.getName() == "a" + Object[] selectionSet = transformedField.getSelectionSet() + .getSelections() + .asList() + selectionSet.size() == 2 + ((Field)selectionSet[0]).getName() == "b1" + ((Field)selectionSet[1]).getName() == "b2" + + 1 * mockFieldAuthorization.authorize(queryA) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + 1 * mockFieldAuthorization.authorize(aB1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + 1 * mockFieldAuthorization.authorize(b1C1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + 1 * mockFieldAuthorization.authorize(c1S1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + 1 * mockFieldAuthorization.authorize(aB2) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + 1 * mockFieldAuthorization.authorize(b2i1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockRenamedMetadata.getOriginalFieldNamesByRenamedName() >> Collections.emptyMap() + mockServiceMetadata.getRenamedMetadata() >> mockRenamedMetadata + } + + def "includes query with include directive false removes selection set"() { + given: + Document document = new Parser().parseDocument(includesQuery) + OperationDefinition operationDefinition = document.getDefinitionsOfType(OperationDefinition.class).get(0) + Field rootField = SelectionSetUtil.getFieldByPath(Arrays.asList("a"), operationDefinition.getSelectionSet()) + GraphQLFieldsContainer rootFieldParentType = (GraphQLFieldsContainer) testGraphQLSchema.getType("Query") + + Map queryVariables = new HashMap<>() + queryVariables.put("shouldInclude", false) + + AuthDownstreamQueryModifier specUnderTest = AuthDownstreamQueryModifier.builder() + .rootParentType((GraphQLFieldsContainer) rootFieldParentType) + .fieldAuthorization(mockFieldAuthorization) + .graphQLContext(mockGraphQLContext) + .queryVariables(queryVariables) + .graphQLSchema(testGraphQLSchema) + .selectionCollector(new SelectionCollector(fragmentsByName)) + .serviceMetadata(mockServiceMetadata) + .authData(TEST_AUTH_DATA) + .build() + + when: + Field transformedField = (Field) astTransformer.transform(rootField, specUnderTest) + + then: + transformedField.getName() == "a" + Object[] selectionSet = transformedField.getSelectionSet() + .getSelections() + .asList() + selectionSet.size() == 1 + ((Field)selectionSet[0]).getName() == "b1" + + 1 * mockFieldAuthorization.authorize(queryA) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + 1 * mockFieldAuthorization.authorize(aB1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + 1 * mockFieldAuthorization.authorize(b1C1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + 1 * mockFieldAuthorization.authorize(c1S1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT + mockRenamedMetadata.getOriginalFieldNamesByRenamedName() >> Collections.emptyMap() + mockServiceMetadata.getRenamedMetadata() >> mockRenamedMetadata + } + def "redact query, results to empty selection set"() { given: diff --git a/src/test/groovy/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilterSpec.groovy b/src/test/groovy/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilterSpec.groovy index 67e11fc5..d4258f94 100644 --- a/src/test/groovy/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilterSpec.groovy +++ b/src/test/groovy/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilterSpec.groovy @@ -47,17 +47,6 @@ class VariableDefinitionFilterSpec extends Specification { directive @field_directive_argument(arg: InputObject) on FIELD_DEFINITION ''' - private String schema2 = ''' - type Query { person: Person } - - type Person { - address : Address - id: String - } - - type Address { city: String state: String zip: String } - ''' - private GraphQLSchema graphQLSchema private VariableDefinitionFilter variableDefinitionFilter @@ -74,12 +63,6 @@ class VariableDefinitionFilterSpec extends Specification { RuntimeWiring.newRuntimeWiring().build()) } - private GraphQLSchema getSchema2() { - return new SchemaGenerator() - .makeExecutableSchema(new SchemaParser().parse(schema2), - RuntimeWiring.newRuntimeWiring().build()) - } - private Map getFragmentsByName(Document document) { return document.getDefinitionsOfType(FragmentDefinition.class).stream() .inject([:]) {map, it -> map << [(it.getName()): it]} @@ -196,62 +179,6 @@ class VariableDefinitionFilterSpec extends Specification { results.containsAll("int_arg", "string_arg") } - def "variable References In Built in Query Directive includes"() { - given: - String query = ''' - query($includeContext: Boolean!) { - consumer { - liabilities(arg: 1) @include(if: $includeContext) { - totalDebt(arg: 1) - } - income - } - } - ''' - - Document document = parser.parseDocument(query) - HashMap variables = new HashMap<>() - variables.put("includeContext", false) - - when: - final Set results = variableDefinitionFilter - .getVariableReferencesFromNode(graphQLSchema, graphQLSchema.getQueryType(), Collections.emptyMap(), - variables, document) - - then: - results.size() == 1 - - results.containsAll("includeContext") - } - - def "variable References In Built in Query Directive skip"() { - given: - String query = ''' - query($includeContext: Boolean!) { - consumer { - liabilities(arg: 1) @skip(if: $includeContext) { - totalDebt(arg: 1) - } - income - } - } - ''' - - Document document = parser.parseDocument(query) - HashMap variables = new HashMap<>() - variables.put("includeContext", true) - - when: - final Set results = variableDefinitionFilter - .getVariableReferencesFromNode(graphQLSchema, graphQLSchema.getQueryType(), Collections.emptyMap(), - variables, document) - - then: - results.size() == 1 - - results.containsAll("includeContext") - } - def "test Negative Cases"() { given: final String negativeTestCaseQuery = "query { consumer { liabilities { totalDebt(arg: 1234) } } }" diff --git a/src/test/groovy/com/intuit/graphql/orchestrator/utils/QueryDirectivesUtilSpec.groovy b/src/test/groovy/com/intuit/graphql/orchestrator/utils/QueryDirectivesUtilSpec.groovy new file mode 100644 index 00000000..f32cc226 --- /dev/null +++ b/src/test/groovy/com/intuit/graphql/orchestrator/utils/QueryDirectivesUtilSpec.groovy @@ -0,0 +1,133 @@ +package com.intuit.graphql.orchestrator.utils + + +import graphql.language.* +import spock.lang.Specification + +class QueryDirectivesUtilSpec extends Specification { + BooleanValue booleanTrue = BooleanValue.newBooleanValue(true).build() + BooleanValue booleanFalse = BooleanValue.newBooleanValue(false).build() + VariableReference variableReference = VariableReference.newVariableReference().name("variable").build() + + Argument ifTrueArg = Argument.newArgument("if", booleanTrue).build() + Argument ifFalseArg = Argument.newArgument("if", booleanFalse).build() + Argument ifRefArg = Argument.newArgument("if", variableReference).build() + + Directive skipTrueDirective = Directive.newDirective().name("skip").argument(ifTrueArg).build() + Directive skipFalseDirective = Directive.newDirective().name("skip").argument(ifFalseArg).build() + Directive skipRefDirective = Directive.newDirective().name("skip").argument(ifRefArg).build() + + Directive includeTrueDirective = Directive.newDirective().name("include").argument(ifTrueArg).build() + Directive includeFalseDirective = Directive.newDirective().name("include").argument(ifFalseArg).build() + Directive includeRefDirective = Directive.newDirective().name("include").argument(ifRefArg).build() + + def "shouldIgnoreNode node without skip and includes returns false"(){ + given: + Field node = Field.newField("test").build() + + when: + boolean result = QueryDirectivesUtil.shouldIgnoreNode(node, new HashMap<>()) + then: + !result + } + def "shouldIgnoreNode node with skip and if as true returns true"(){ + given: + Field node = Field.newField("test").directive(skipTrueDirective).build() + + when: + boolean result = QueryDirectivesUtil.shouldIgnoreNode(node, new HashMap<>()) + then: + result + } + def "shouldIgnoreNode node with skip and if as false returns false"(){ + given: + Field node = Field.newField("test").directive(skipFalseDirective).build() + + when: + boolean result = QueryDirectivesUtil.shouldIgnoreNode(node, new HashMap<>()) + then: + !result + } + def "shouldIgnoreNode node with skip and if as ref as true returns true"(){ + given: + Field node = Field.newField("test").directive(skipRefDirective).build() + + when: + HashMap variableMap = new HashMap<>() + variableMap.put("variable", true) + boolean result = QueryDirectivesUtil.shouldIgnoreNode(node, variableMap) + then: + result + } + def "shouldIgnoreNode node with skip and if as ref as false returns false"(){ + given: + Field node = Field.newField("test").directive(skipRefDirective).build() + + when: + HashMap variableMap = new HashMap<>() + variableMap.put("variable", false) + boolean result = QueryDirectivesUtil.shouldIgnoreNode(node, variableMap) + then: + !result + } + def "shouldIgnoreNode node with skip and if as ref as null returns true"(){ + given: + Field node = Field.newField("test").directive(skipRefDirective).build() + + when: + HashMap variableMap = new HashMap<>() + boolean result = QueryDirectivesUtil.shouldIgnoreNode(node, variableMap) + then: + result + } + def "shouldIgnoreNode node with include and if as true returns false"(){ + given: + Field node = Field.newField("test").directive(includeTrueDirective).build() + + when: + boolean result = QueryDirectivesUtil.shouldIgnoreNode(node, new HashMap<>()) + then: + !result + } + def "shouldIgnoreNode node with include and if as false returns true"(){ + given: + Field node = Field.newField("test").directive(includeFalseDirective).build() + + when: + boolean result = QueryDirectivesUtil.shouldIgnoreNode(node, new HashMap<>()) + then: + result + } + def "shouldIgnoreNode node with include and if as ref as true returns false"(){ + given: + Field node = Field.newField("test").directive(includeRefDirective).build() + + when: + HashMap variableMap = new HashMap<>() + variableMap.put("variable", true) + boolean result = QueryDirectivesUtil.shouldIgnoreNode(node, variableMap) + then: + !result + } + def "shouldIgnoreNode node with include and if as ref as false returns true"(){ + given: + Field node = Field.newField("test").directive(includeRefDirective).build() + + when: + HashMap variableMap = new HashMap<>() + variableMap.put("variable", false) + boolean result = QueryDirectivesUtil.shouldIgnoreNode(node, variableMap) + then: + result + } + def "shouldIgnoreNode node with include and if as ref as null returns true"(){ + given: + Field node = Field.newField("test").directive(includeRefDirective).build() + + when: + HashMap variableMap = new HashMap<>() + boolean result = QueryDirectivesUtil.shouldIgnoreNode(node, variableMap) + then: + result + } +}