Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for include and skip directive with references #188

Merged
merged 2 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,17 +1,5 @@
package com.intuit.graphql.orchestrator.batch;

import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.hasResolverDirective;
import static com.intuit.graphql.orchestrator.utils.QueryPathUtils.getNodesAsPathList;
import static com.intuit.graphql.orchestrator.utils.QueryPathUtils.pathListToFQN;
import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.convertGraphqlFieldWithOriginalName;
import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.getRenameKey;
import static graphql.introspection.Introspection.TypeNameMetaFieldDef;
import static graphql.schema.FieldCoordinates.coordinates;
import static graphql.util.TreeTransformerUtil.changeNode;
import static graphql.util.TreeTransformerUtil.deleteNode;
import static java.util.Objects.nonNull;
import static java.util.Objects.requireNonNull;

import com.intuit.graphql.orchestrator.authorization.FieldAuthorization;
import com.intuit.graphql.orchestrator.authorization.FieldAuthorizationEnvironment;
import com.intuit.graphql.orchestrator.authorization.FieldAuthorizationResult;
Expand Down Expand Up @@ -40,17 +28,31 @@
import graphql.schema.GraphQLUnionType;
import graphql.util.TraversalControl;
import graphql.util.TraverserContext;
import lombok.Builder;
import lombok.NonNull;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Builder;
import lombok.NonNull;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;

import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.hasResolverDirective;
import static com.intuit.graphql.orchestrator.utils.QueryDirectivesUtil.shouldIgnoreNode;
import static com.intuit.graphql.orchestrator.utils.QueryPathUtils.getNodesAsPathList;
import static com.intuit.graphql.orchestrator.utils.QueryPathUtils.pathListToFQN;
import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.convertGraphqlFieldWithOriginalName;
import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.getRenameKey;
import static graphql.introspection.Introspection.TypeNameMetaFieldDef;
import static graphql.schema.FieldCoordinates.coordinates;
import static graphql.util.TreeTransformerUtil.changeNode;
import static graphql.util.TreeTransformerUtil.deleteNode;
import static java.util.Objects.nonNull;
import static java.util.Objects.requireNonNull;

/**
* This class modifies for query for a downstream provider.
Expand Down Expand Up @@ -91,6 +93,12 @@ public TraversalControl visitField(Field node, TraverserContext<Node> context) {
requireNonNull(fieldDefinition, "Failed to get Field Definition for " + node.getName());

context.setVar(GraphQLType.class, fieldDefinition.getType());

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also add integration test i.e. send query to graphql orchestrator and validate the outgoing query.

if(shouldIgnoreNode(node, this.queryVariables)) {
decreaseParentSelectionSetCount(context.getParentContext());
return deleteNode(context);
}

FieldAuthorizationResult fieldAuthorizationResult = authorize(node, fieldDefinition, parentType, context);
if (!fieldAuthorizationResult.isAllowed()) {
decreaseParentSelectionSetCount(context.getParentContext());
Expand All @@ -112,8 +120,10 @@ public TraversalControl visitField(Field node, TraverserContext<Node> context) {
GraphQLFieldDefinition fieldDefinition = getFieldDefinition(node.getName(), parentType);
requireNonNull(fieldDefinition, "Failed to get Field Definition for " + node.getName());

if (serviceMetadata.shouldModifyDownStreamQuery() && (hasResolverDirective(fieldDefinition)
|| isExternalField(parentType.getName(), node.getName()))) {
boolean shouldRemoveNode = (serviceMetadata.shouldModifyDownStreamQuery() && (hasResolverDirective(fieldDefinition)
|| isExternalField(parentType.getName(), node.getName())))
|| shouldIgnoreNode(node, this.queryVariables);
if (shouldRemoveNode) {
decreaseParentSelectionSetCount(context.getParentContext());
return deleteNode(context);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,20 @@
import graphql.analysis.QueryVisitorInlineFragmentEnvironment;
import graphql.analysis.QueryVisitorStub;
import graphql.language.Argument;
import graphql.language.AstTransformer;
import graphql.language.Document;
import graphql.language.Field;
import graphql.language.FragmentDefinition;
import graphql.language.FragmentSpread;
import graphql.language.InlineFragment;
import graphql.language.Node;
import graphql.language.NodeVisitorStub;
import graphql.language.OperationDefinition;
import graphql.language.Value;
import graphql.language.VariableReference;
import graphql.schema.GraphQLObjectType;
import graphql.schema.GraphQLSchema;
import graphql.util.TraversalControl;
import graphql.util.TraverserContext;
import lombok.Getter;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand All @@ -36,8 +31,6 @@
*/
public class VariableDefinitionFilter {

private static AstTransformer astTransformer = new AstTransformer();

/**
* Traverses a GraphQL Node and returns all VariableReference names used in all nodes in the graph.
*
Expand All @@ -50,17 +43,17 @@ public class VariableDefinitionFilter {
* reference indicator prefix '$' will be <b>excluded</b> in the result.
*/
public Set<String> getVariableReferencesFromNode(GraphQLSchema graphQLSchema, GraphQLObjectType rootType,
Map<String, FragmentDefinition> fragmentsByName, Map<String, Object> variables, Node<?> rootNode) {
Map<String, FragmentDefinition> fragmentsByName, Map<String, Object> variables, Node<?> rootNode) {
final VariableReferenceVisitor variableReferenceVisitor = new VariableReferenceVisitor();

//need to utilize a better pattern for creating mockable QueryTraverser/QueryTransformer
QueryTraverser queryTraverser = QueryTraverser.newQueryTraverser()
.schema(graphQLSchema)
.rootParentType(rootType) //need to support also for subscription
.fragmentsByName(fragmentsByName)
.variables(variables)
.root(rootNode)
.build();
.schema(graphQLSchema)
.rootParentType(rootType) //need to support also for subscription
.fragmentsByName(fragmentsByName)
.variables(variables)
.root(rootNode)
.build();

queryTraverser.visitPreOrder(variableReferenceVisitor);

Expand All @@ -75,28 +68,16 @@ public Set<String> getVariableReferencesFromNode(GraphQLSchema graphQLSchema, Gr

Set<VariableReference> additionalReferences = operationDirectiveVariableReferences(operationDefinitions);

Stream<VariableReference> variableReferenceStream;
if((variableReferenceVisitor.getVariableReferences().size() + additionalReferences.size()) != variables.size()) {
NodeTraverser nodeTraverser = new NodeTraverser();
astTransformer.transform(rootNode, nodeTraverser);

variableReferenceStream = Stream.of(variableReferenceVisitor.getVariableReferences(),
additionalReferences,
nodeTraverser.getVariableReferenceExtractor().getVariableReferences())
.flatMap(Collection::stream);
} else {
variableReferenceStream = Stream.concat(variableReferenceVisitor.getVariableReferences().stream(), additionalReferences.stream());
}
return variableReferenceStream.map(VariableReference::getName).collect(Collectors.toSet());

return Stream.concat(variableReferenceVisitor.getVariableReferences().stream(), additionalReferences.stream())
.map(VariableReference::getName).collect(Collectors.toSet());
}

private Set<VariableReference> operationDirectiveVariableReferences(List<OperationDefinition> operationDefinitions) {
final List<Value> values = operationDefinitions.stream()
.flatMap(operationDefinition -> operationDefinition.getDirectives().stream())
.flatMap(directive -> directive.getArguments().stream())
.map(Argument::getValue)
.collect(Collectors.toList());
.flatMap(operationDefinition -> operationDefinition.getDirectives().stream())
.flatMap(directive -> directive.getArguments().stream())
.map(Argument::getValue)
.collect(Collectors.toList());

VariableReferenceExtractor extractor = new VariableReferenceExtractor();
extractor.captureVariableReferences(values);
Expand Down Expand Up @@ -138,7 +119,7 @@ public void visitField(final QueryVisitorFieldEnvironment env) {
}

final Stream<Argument> directiveArgumentStream = field.getDirectives().stream()
.flatMap(directive -> directive.getArguments().stream());
.flatMap(directive -> directive.getArguments().stream());

final Stream<Argument> fieldArgumentStream = field.getArguments().stream();

Expand All @@ -154,7 +135,7 @@ public void visitInlineFragment(final QueryVisitorInlineFragmentEnvironment env)
}

Stream<Argument> arguments = env.getInlineFragment().getDirectives().stream()
.flatMap(directive -> directive.getArguments().stream());
.flatMap(directive -> directive.getArguments().stream());

captureVariableReferences(arguments);
}
Expand All @@ -169,33 +150,18 @@ public void visitFragmentSpread(final QueryVisitorFragmentSpreadEnvironment env)
}

final Stream<Argument> allArguments = Stream.concat(
fragmentDefinition.getDirectives().stream(),
fragmentSpread.getDirectives().stream()
fragmentDefinition.getDirectives().stream(),
fragmentSpread.getDirectives().stream()
).flatMap(directive -> directive.getArguments().stream());

captureVariableReferences(allArguments);
}

private void captureVariableReferences(Stream<Argument> arguments) {
final List<Value> values = arguments.map(Argument::getValue)
.collect(Collectors.toList());
.collect(Collectors.toList());

variableReferenceExtractor.captureVariableReferences(values);
}
}

static class NodeTraverser extends NodeVisitorStub {

@Getter
private final VariableReferenceExtractor variableReferenceExtractor = new VariableReferenceExtractor();

public TraversalControl visitArgument(Argument node, TraverserContext<Node> context) {
return this.visitNode(node, context);
}

public TraversalControl visitVariableReference(VariableReference node, TraverserContext<Node> context) {
variableReferenceExtractor.captureVariableReference(node);
return this.visitValue(node, context);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,10 @@ public Set<VariableReference> getVariableReferences() {

public void captureVariableReferences(List<Value> values) {
for (final Value value : values) {
captureVariableReference(value);
doSwitch(value);
}
}

public void captureVariableReference(Value value) {
doSwitch(value);
}

private void doSwitch(Value value) {
if (value instanceof ArrayValue) {
handleArrayValue((ArrayValue) value);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package com.intuit.graphql.orchestrator.utils;

import graphql.language.Argument;
import graphql.language.BooleanValue;
import graphql.language.Directive;
import graphql.language.Field;
import graphql.language.Value;
import graphql.language.VariableReference;

import java.util.Map;
import java.util.Optional;

public class QueryDirectivesUtil {

public static boolean shouldIgnoreNode(Field node, Map<String, Object> queryVariables) {
Optional<Directive> optionalIncludesDir = node.getDirectives("include").stream().findFirst();
Optional<Directive> optionalSkipDir = node.getDirectives("skip").stream().findFirst();
if(optionalIncludesDir.isPresent() || optionalSkipDir.isPresent()) {
if(optionalIncludesDir.isPresent() && (!getIfValue(optionalIncludesDir.get(), queryVariables))) {
return true;
}
return optionalSkipDir.isPresent() && (getIfValue(optionalSkipDir.get(), queryVariables));
}

return false;
}

private static boolean getIfValue(Directive directive, Map<String, Object> queryVariables){
Argument ifArg = directive.getArgument("if");
Value ifValue = ifArg.getValue();

boolean defaultValue = directive.getName().equals("skip");

if(ifValue instanceof VariableReference) {
String variableRefName = ((VariableReference) ifValue).getName();
return (boolean) queryVariables.getOrDefault(variableRefName, defaultValue);
} else if(ifValue instanceof BooleanValue) {
return ((BooleanValue) ifValue).isValue();
}
return false;
}
}
Loading
Loading