Skip to content

Commit

Permalink
Track new identifier explicitness through the preserve set
Browse files Browse the repository at this point in the history
  • Loading branch information
umazalakain committed May 7, 2021
1 parent 7f84b7c commit 7ad1001
Showing 1 changed file with 19 additions and 22 deletions.
41 changes: 19 additions & 22 deletions src/main/scala/rise/core/types/Constraints.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ object Constraint {
// scalastyle:off method.length
def solveOne(c: Constraint, preserve: Set[String], trace: Seq[Constraint]) (implicit explDep: Flags.ExplicitDependence): Solution = {
implicit val _trace: Seq[Constraint] = trace
def decomposed(cs: Seq[Constraint]) = solve(cs, preserve, c +: trace)
def decomposed(cs: Seq[Constraint], preserve: Set[String]) = solve(cs, preserve, c +: trace)

c match {
case TypeConstraint(a, b) =>
Expand All @@ -93,22 +93,21 @@ object Constraint {
_: IndexType | _: VectorType)
if a =~= b => Solution()
case (IndexType(sa), IndexType(sb)) =>
decomposed(Seq(NatConstraint(sa, sb)))
decomposed(Seq(NatConstraint(sa, sb)), preserve)
case (ArrayType(sa, ea), ArrayType(sb, eb)) =>
decomposed(Seq(NatConstraint(sa, sb), TypeConstraint(ea, eb)))
decomposed(Seq(NatConstraint(sa, sb), TypeConstraint(ea, eb)), preserve)
case (VectorType(sa, ea), VectorType(sb, eb)) =>
decomposed(Seq(NatConstraint(sa, sb), TypeConstraint(ea, eb)))
decomposed(Seq(NatConstraint(sa, sb), TypeConstraint(ea, eb)), preserve)
case (FragmentType(rowsa, columnsa, d3a, dta, fragTypea, layouta), FragmentType(rowsb, columnsb, d3b, dtb, fragTypeb, layoutb)) =>
decomposed(Seq(NatConstraint(rowsa, rowsb), NatConstraint(columnsa, columnsb), NatConstraint(d3a, d3b),
TypeConstraint(dta, dtb), FragmentTypeConstraint(fragTypea, fragTypeb), MatrixLayoutConstraint(layouta, layoutb)))
TypeConstraint(dta, dtb), FragmentTypeConstraint(fragTypea, fragTypeb), MatrixLayoutConstraint(layouta, layoutb)),
preserve)
case (DepArrayType(sa, ea), DepArrayType(sb, eb)) =>
decomposed(Seq(NatConstraint(sa, sb), NatToDataConstraint(ea, eb)))
decomposed(Seq(NatConstraint(sa, sb), NatToDataConstraint(ea, eb)), preserve)
case (PairType(pa1, pa2), PairType(pb1, pb2)) =>
decomposed(Seq(TypeConstraint(pa1, pb1), TypeConstraint(pa2, pb2)))
decomposed(Seq(TypeConstraint(pa1, pb1), TypeConstraint(pa2, pb2)), preserve)
case (FunType(ina, outa), FunType(inb, outb)) =>
decomposed(
Seq(TypeConstraint(ina, inb), TypeConstraint(outa, outb))
)
decomposed( Seq(TypeConstraint(ina, inb), TypeConstraint(outa, outb)), preserve)
case (
DepFunType(na: NatIdentifier, ta),
DepFunType(nb: NatIdentifier, tb)
Expand All @@ -124,24 +123,23 @@ object Constraint {
* initial constrain-types phase?
*/
val (nTa, nTaSub) = dependence.explicitlyDependent(
substitute.natInType(n, `for`=na, ta), n, preserve)
substitute.natInType(n, `for`=na, ta), n, preserve + n.name)
val (nTb, nTbSub) = dependence.explicitlyDependent(
substitute.natInType(n, `for`= nb, tb), n, preserve)
substitute.natInType(n, `for`= nb, tb), n, preserve + n.name)
nTaSub ++ nTbSub ++ decomposed(
Seq(
NatConstraint(n, na),
NatConstraint(n, nb),
TypeConstraint(nTa, nTb)
))
), preserve + n.name - na.name - nb.name)
case ExplicitDependence.Off =>
val n = NatIdentifier(freshName("n"))
decomposed(
Seq(
NatConstraint(n, na),
NatConstraint(n, nb),
TypeConstraint(ta, tb)
)
)
), preserve + n.name - na.name - nb.name)
}
case (
DepFunType(dta: DataTypeIdentifier, ta),
Expand All @@ -153,8 +151,7 @@ object Constraint {
TypeConstraint(dt, dta),
TypeConstraint(dt, dtb),
TypeConstraint(ta, tb)
)
)
), preserve + dt.name - dta.name - dtb.name)
case (
DepFunType(_: AddressSpaceIdentifier, _),
DepFunType(_: AddressSpaceIdentifier, _)
Expand All @@ -171,7 +168,7 @@ object Constraint {
NatConstraint(n, x1),
NatConstraint(n, x2),
TypeConstraint(t1, t2)
))
), preserve + n.name - x1.name - x2.name)

case (
DepPairType(x1: NatCollectionIdentifier, t1),
Expand All @@ -183,7 +180,7 @@ object Constraint {
NatCollectionConstraint(n, x1),
NatCollectionConstraint(n, x2),
TypeConstraint(t1, t2)
))
), preserve + n.name - x1.name - x2.name)

case (
NatToDataApply(f: NatToDataIdentifier, _),
Expand Down Expand Up @@ -219,7 +216,7 @@ object Constraint {
df match {
case _: DepFunType[_, _] =>
val applied = liftDependentFunctionType(df)(arg)
decomposed(Seq(TypeConstraint(applied, t)))
decomposed(Seq(TypeConstraint(applied, t)), preserve)
case _ =>
error(s"expected a dependent function type, but got $df")
}
Expand All @@ -237,7 +234,7 @@ object Constraint {
NatConstraint(n, x1),
NatConstraint(n, x2),
TypeConstraint(dt1, dt2)
))
), preserve + n.name - x1.name - x2.name)

case _ => error(s"cannot unify $a and $b")
}
Expand Down Expand Up @@ -511,7 +508,7 @@ object Constraint {
nat.unify(
substitute.natInNat(n, `for` = x1, body1),
substitute.natInNat(n, `for`=x2, body2),
preserve)
preserve + n.name)
}
}
}
Expand Down

0 comments on commit 7ad1001

Please sign in to comment.