Skip to content

Commit

Permalink
Use a ListSet to backup IsClosedForm
Browse files Browse the repository at this point in the history
  • Loading branch information
umazalakain committed May 16, 2021
1 parent 9346a78 commit 69e4d97
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/main/scala/rise/core/DSL/infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ object infer {
val (typed_e, constraints) = constrainTypes(exprEnv)(e_wo_assertions)
// Collect free variables both before and after constraint gathering:
// Some types in the collected constraints might not be present after constraint gathering (Fixme: BUG?)
val ftvs = IsClosedForm.freeVars(e_wo_assertions)._2.toSet ++ IsClosedForm.freeVars(typed_e)._2.toSet
val ftvs = IsClosedForm.freeVars(e_wo_assertions)._2 ++ IsClosedForm.freeVars(typed_e)._2
val toSubstitute = (ftvs removedAll e_preserve) removedAll typeEnv
// Solve constraints while preserving the FTVs in preserve
val solution = Constraint.solve(constraints, toSubstitute, Seq())
Expand Down
42 changes: 21 additions & 21 deletions src/main/scala/rise/core/IsClosedForm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@ import arithexpr.arithmetic.NamedVar
import rise.core.traverse._
import rise.core.types._

import scala.collection.immutable.ListSet

object IsClosedForm {
case class Visitor(boundV: Seq[Identifier], boundT: Seq[Kind.Identifier])
extends PureAccumulatorTraversal[(Seq[Identifier], Seq[Kind.Identifier])]
// Use a ListSet to keep the accumulator both small and deterministic
case class Visitor(boundV: Set[Identifier], boundT: Set[Kind.Identifier])
extends PureAccumulatorTraversal[(Set[Identifier], Set[Kind.Identifier])]
{
override val accumulator = PairMonoid(SeqMonoid, SeqMonoid)
override val accumulator = PairMonoid(SetMonoid, SetMonoid)

override def identifier[I <: Identifier]: VarType => I => Pair[I] = vt => i => {
for { t2 <- `type`(i.t);
i2 <- if (vt == Reference && !boundV.contains(i)) {
accumulate((Seq(i), Seq()))(i)
accumulate((ListSet(i), ListSet()))(i)
} else {
return_(i)
}}
Expand All @@ -23,68 +26,65 @@ object IsClosedForm {

override def typeIdentifier[I <: Kind.Identifier]: VarType => I => Pair[I] = {
case Reference => i =>
if (boundT.contains(i)) return_(i) else accumulate((Seq(), Seq(i)))(i)
if (boundT.contains(i)) return_(i) else accumulate((ListSet(), ListSet(i)))(i)
case _ => return_
}

override def nat: Nat => Pair[Nat] = n => {
val free = n.varList.foldLeft(Seq[Kind.Identifier]()) {
case (free, v: NamedVar) if !boundT.contains(NatIdentifier(v)) => NatIdentifier(v) +: free
val free = n.varList.foldLeft(ListSet[Kind.Identifier]()) {
case (free, v: NamedVar) if !boundT(NatIdentifier(v)) => free + NatIdentifier(v)
case (free, _) => free
}
accumulate((Seq(), free))(n)
accumulate((ListSet(), free))(n)
}

override def expr: Expr => Pair[Expr] = {
case l@Lambda(x, e) =>
// The binder's type itself might contain free type variables
val ((fVx, fTx), x1) = identifier(Binding)(x).unwrap
val ((fVe, fTe), e1) = this.copy(boundV = x1 +: boundV).expr(e).unwrap
val ((fVe, fTe), e1) = this.copy(boundV = boundV + x1).expr(e).unwrap
val ((fVt, fTt), t1) = `type`(l.t).unwrap
accumulate((fVx ++ fVe ++ fVt, fTx ++ fTe ++ fTt))(Lambda(x1, e1)(t1) : Expr)
case DepLambda(x, b) => this.copy(boundT = x +: boundT).expr(b)
case DepLambda(x, b) => this.copy(boundT = boundT + x).expr(b)
case e => super.expr(e)
}

override def natToData: NatToData => Pair[NatToData] = {
case NatToDataLambda(x, e) =>
for { p <- this.copy(boundT = x +: boundT).`type`(e) }
for { p <- this.copy(boundT = boundT + x).`type`(e) }
yield (p._1, NatToDataLambda(x, e))
case t => super.natToData(t)
}

override def natToNat: NatToNat => Pair[NatToNat] = {
case NatToNatLambda(x, n) =>
for { p <- this.copy(boundT = x +: boundT).nat(n) }
for { p <- this.copy(boundT = boundT + x).nat(n) }
yield (p._1, NatToNatLambda(x, n))
case n => super.natToNat(n)
}

override def `type`[T <: Type]: T => Pair[T] = {
case d@DepFunType(x, t) =>
for { p <- this.copy(boundT = x +: boundT).`type`(t) }
for { p <- this.copy(boundT = boundT + x).`type`(t) }
yield (p._1, d.asInstanceOf[T])
case d@DepPairType(x, dt) =>
for { p <- this.copy(boundT = x +: boundT).datatype(dt) }
for { p <- this.copy(boundT = boundT + x).datatype(dt) }
yield (p._1, d.asInstanceOf[T])
case t => super.`type`(t)
}
}

def freeVars(expr: Expr): (Seq[Identifier], Seq[Kind.Identifier]) =
traverse(expr, Visitor(Seq(), Seq()))._1

def freeVars(t: Type): (Seq[Identifier], Seq[Kind.Identifier]) =
traverse(t, Visitor(Seq(), Seq()))._1
def freeVars(expr: Expr): (Set[Identifier], Set[Kind.Identifier]) =
traverse(expr, Visitor(ListSet(), ListSet()))._1

def varsToClose(expr : Expr): (Seq[Identifier], Seq[Kind.Identifier]) = {
val (fV, fT) = freeVars(expr)
// Exclude matrix layout and fragment kind identifiers, since they cannot currently be bound
(fV, fT.flatMap {
(fV.toSeq, fT.flatMap {
case i : MatrixLayoutIdentifier => Seq()
case i : FragmentKindIdentifier => Seq()
case e => Seq(e)
})
}.toSeq)
}

def apply(expr: Expr): Boolean = {
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/rise/eqsat/NamedRewrite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ object NamedRewrite {
}
val typedLhs = infer(lhs, untypedFreeV, Set())
val freeV = infer.collectFreeEnv(typedLhs)
val freeT = rise.core.IsClosedForm.freeVars(typedLhs)._2.toSet
val freeT = rise.core.IsClosedForm.freeVars(typedLhs)._2
val typedRhs = infer(rc.TypeAnnotation(rhs, typedLhs.t), freeV, freeT)

trait PatVarStatus
Expand Down

0 comments on commit 69e4d97

Please sign in to comment.