diff --git a/modules/server/src/main/scala/gql/server/interpreter/SubgraphBatches.scala b/modules/server/src/main/scala/gql/server/interpreter/SubgraphBatches.scala index 339acf59c..a49d15176 100644 --- a/modules/server/src/main/scala/gql/server/interpreter/SubgraphBatches.scala +++ b/modules/server/src/main/scala/gql/server/interpreter/SubgraphBatches.scala @@ -35,11 +35,11 @@ object SubgraphBatches { final case class BatchNodeId(id: NodeId) final case class State( - childBatches: List[BatchNodeId], - accum: Map[MulitplicityNode, List[BatchNodeId]] + childBatches: Set[BatchNodeId], + accum: Map[MulitplicityNode, Set[BatchNodeId]] ) object State { - def empty = State(List.empty, Map.empty) + def empty = State(Set.empty, Map.empty) implicit val monoid: Monoid[State] = new Monoid[State] { def empty = State.empty def combine(x: State, y: State) = State(x.childBatches ++ y.childBatches, x.accum ++ y.accum) @@ -54,16 +54,18 @@ object SubgraphBatches { case Compose(_, l, r) => countStep(state, r).flatMap(countStep(_, l)) case alg: Choose[F, ?, ?, ?, ?] => - val s1F = countStep(state, alg.fac).map { s => - s.copy(accum = s.accum + (MulitplicityNode(alg.fac.nodeId) -> s.childBatches)) - } - val s2F = countStep(state, alg.fbd).map { s => - s.copy(accum = s.accum + (MulitplicityNode(alg.fbd.nodeId) -> s.childBatches)) - } - (s1F, s2F).mapN(_ |+| _) + for { + s1 <- countStep(state, alg.fac) + s2 <- countStep(state, alg.fbd) + s1Unique = s1.childBatches -- s2.childBatches + s1Out = s1.copy(accum = s1.accum + (MulitplicityNode(alg.fac.nodeId) -> s1Unique)) + + s2Unique = s2.childBatches -- s1.childBatches + s2Out = s2.copy(accum = s2.accum + (MulitplicityNode(alg.fbd.nodeId) -> s2Unique)) + } yield s1Out |+| s2Out case alg: First[F, ?, ?, ?] => countStep(state, alg.step) - case alg: Batch[F, ?, ?] => Eval.now(state.copy(childBatches = BatchNodeId(alg.nodeId) :: state.childBatches)) - case alg: InlineBatch[F, ?, ?] => Eval.now(state.copy(childBatches = BatchNodeId(alg.nodeId) :: state.childBatches)) + case alg: Batch[F, ?, ?] => Eval.now(state.copy(childBatches = state.childBatches + BatchNodeId(alg.nodeId))) + case alg: InlineBatch[F, ?, ?] => Eval.now(state.copy(childBatches = state.childBatches + BatchNodeId(alg.nodeId))) } } @@ -80,7 +82,7 @@ object SubgraphBatches { def countPrep[F[_]](prep: Prepared[F, ?]): Eval[State] = Eval.defer { prep match { - case PreparedLeaf(_, _, _) => Eval.now(State(Nil, Map.empty)) + case PreparedLeaf(_, _, _) => Eval.now(State(Set.empty, Map.empty)) case Selection(_, fields, _) => fields.foldMapA(countField(_)) case PreparedList(id, of, _) => countCont(of.edges, of.cont).map(s => s.copy(accum = s.accum + (MulitplicityNode(id) -> s.childBatches))) @@ -130,7 +132,7 @@ object SubgraphBatches { } .map(_.toMap) val inlineBatchIds: F[Map[BatchNodeId, Ref[F, BatchFamily[F, ?, ?]]]] = - (countState.childBatches.toSet -- groups.map { case (_, vs) => vs.toList }.flatten).toList + (countState.childBatches -- groups.map { case (_, vs) => vs.toList }.flatten).toList .traverse { id => F.ref[BatchFamily[F, ?, ?]](BatchFamily(1, Set.empty, Nil, Chain.empty)).tupleLeft(id) } @@ -195,7 +197,7 @@ object SubgraphBatches { val toAdd = n - 1 countState.accum .get(MulitplicityNode(mulId)) - .traverse_(_.traverse_ { id => + .traverse_(_.toList.traverse_ { id => val ref = batches(id) ref.update { case bf => bf.copy(pendingInputs = bf.pendingInputs + toAdd)