Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove explicitness information #171

Closed
wants to merge 13 commits into from
12 changes: 6 additions & 6 deletions src/main/scala/apps/cameraPipelineRewrite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ object cameraPipelineRewrite {
// idx i >> f -> map f >> idx i
def idxAfterF: Strategy[Rise] = {
case expr @ App(f, App(App(p.idx(), i), in)) =>
Success(p.idx(i)(p.map(f)(in)) !: expr.t)
Success(p.idx(i)(p.map(f)(in)) !: expr)
case _ => Failure(idxAfterF)
}

Expand Down Expand Up @@ -383,18 +383,18 @@ object cameraPipelineRewrite {

def letHoist: Strategy[Rise] = {
case expr @ App(f, App(App(p.let(), v), Lambda(x, b))) =>
Success(letf(lambda(eraseType(x), preserveType(f)(b)))(v) !: expr.t)
Success(letf(lambda(eraseType(x), preserveType(f)(b)))(v) !: expr)
// TODO: normal form / non-map specific?
case expr @ App(App(p.map(), Lambda(y,
App(App(p.let(), v), Lambda(x, b))
)), in) if !contains[Rise](y).apply(v) =>
Success(letf(lambda(eraseType(x), p.map(lambda(eraseType(y), b))(in)))(v) !: expr.t)
Success(letf(lambda(eraseType(x), p.map(lambda(eraseType(y), b))(in)))(v) !: expr)
case expr @ App(p.map(), Lambda(y,
App(App(p.let(), v), Lambda(x, b))
)) if !contains[Rise](y).apply(v) =>
Success(fun(in =>
letf(lambda(eraseType(x), p.map(lambda(eraseType(y), b))(in)))(v)
) !: expr.t)
) !: expr)
case _ => Failure(letHoist)
}

Expand All @@ -404,7 +404,7 @@ object cameraPipelineRewrite {
argument(argument({
case expr @ App(Lambda(x, color_correct), matrix) =>
Success(letf(lambda(toBeTyped(x), color_correct))(
p.mapSeq(p.mapSeq(fun(x => x)))(matrix)) !: expr.t)
p.mapSeq(p.mapSeq(fun(x => x)))(matrix)) !: expr)
case _ => Failure(precomputeColorCorrectionMatrix)
})) `;`
normalize.apply(gentleBetaReduction() <+ letHoist)
Expand All @@ -422,7 +422,7 @@ object cameraPipelineRewrite {
argument(function(isEqualTo(p.generate.primitive))) `;`
argument({ curve =>
Success(letf(fun(x => x))(
p.mapSeq(fun(x => x))(curve)) !: curve.t)
p.mapSeq(fun(x => x))(curve)) !: curve)
})
)
))) `;`
Expand Down
1 change: 0 additions & 1 deletion src/main/scala/rise/core/DSL/ToBeTyped.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ final case class ToBeTyped[+T <: Expr](private val e: T) {
case Opaque(x, t) => expr(x)
case tl@TopLevel(x, t) => expr(x)
case TypeAnnotation(e, t) => expr(e)
case TypeAssertion(e, t) => expr(e)
case p => super.`expr`(p.setType(TypePlaceholder))
}
}.expr(e).unwrap
Expand Down
34 changes: 14 additions & 20 deletions src/main/scala/rise/core/DSL/Type.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,28 @@ object Type {
// type level lambdas
object n2dtFun {
def apply(f: NatIdentifier => DataType): NatToDataLambda = {
val x = NatIdentifier(freshName("n"), isExplicit = true)
val x = NatIdentifier(freshName("n"))
NatToDataLambda(x, f(x))
}

def apply(
r: arithexpr.arithmetic.Range
)(f: NatIdentifier => DataType): NatToDataLambda = {
val x = NatIdentifier(freshName("n"), r, isExplicit = true)
def apply(r: arithexpr.arithmetic.Range)(f: NatIdentifier => DataType): NatToDataLambda = {
val x = NatIdentifier(freshName("n"), r)
NatToDataLambda(x, f(x))
}

def apply(
upperBound: Nat
)(f: NatIdentifier => DataType): NatToDataLambda = {
def apply(upperBound: Nat)(f: NatIdentifier => DataType): NatToDataLambda = {
apply(RangeAdd(0, upperBound, 1))(f)
}
}

object n2nFun {
def apply(f: NatIdentifier => Nat): NatToNatLambda = {
val x = NatIdentifier(freshName("n2n"), isExplicit = true)
val x = NatIdentifier(freshName("n2n"))
NatToNatLambda(x, f(x))
}

def apply(
r: arithexpr.arithmetic.Range
)(f: NatIdentifier => Nat): NatToNatLambda = {
val x = NatIdentifier(freshName("n2n"), r, isExplicit = true)
def apply(r: arithexpr.arithmetic.Range)(f: NatIdentifier => Nat): NatToNatLambda = {
val x = NatIdentifier(freshName("n2n"), r)
NatToNatLambda(x, f(x))
}

Expand Down Expand Up @@ -93,27 +87,27 @@ object Type {

object expl {
def apply(w: NatFunctionWrapper[Type]): Type = {
val x = NatIdentifier(freshName("n"), isExplicit = true)
val x = NatIdentifier(freshName("n"))
DepFunType[NatKind, Type](x, w.f(x))
}

def apply(w: DataTypeFunctionWrapper[Type]): Type = {
val x = DataTypeIdentifier(freshName("dt"), isExplicit = true)
val x = DataTypeIdentifier(freshName("dt"))
DepFunType[DataKind, Type](x, w.f(x))
}

def apply(w: NatToDataFunctionWrapper[Type]): Type = {
val x = NatToDataIdentifier(freshName("n2d"), isExplicit = true)
val x = NatToDataIdentifier(freshName("n2d"))
DepFunType[NatToDataKind, Type](x, w.f(x))
}

def apply(w: NatToNatFunctionWrapper[Type]): Type = {
val x = NatToNatIdentifier(freshName("n2n"), isExplicit = true)
val x = NatToNatIdentifier(freshName("n2n"))
DepFunType[NatToNatKind, Type](x, w.f(x))
}

def apply(w: AddressSpaceFunctionWrapper[Type]): Type = {
val x = AddressSpaceIdentifier(freshName("a"), isExplicit = true)
val x = AddressSpaceIdentifier(freshName("a"))
DepFunType[AddressSpaceKind, Type](x, w.f(x))
}
}
Expand Down Expand Up @@ -159,14 +153,14 @@ object Type {
// dependent pairs
object Nat {
def `**`(f: Nat => DataType): Type = {
val x = NatIdentifier(freshName("n"), isExplicit = true)
val x = NatIdentifier(freshName("n"))
DepPairType[NatKind](x, f(x))
}
}

object NatCollection {
def `**`(f: NatCollection => DataType): Type = {
val x = NatCollectionIdentifier(freshName("ns"), isExplicit = true)
val x = NatCollectionIdentifier(freshName("ns"))
DepPairType[NatCollectionKind](x, f(x))
}
}
Expand Down
76 changes: 13 additions & 63 deletions src/main/scala/rise/core/DSL/infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,11 @@ object infer {
traverse(e, Traversal(Set()))._1
}

def preservingWithEnv(e: Expr, env: Map[String, Type], preserve: Set[Kind.Identifier]): Expr = {
val (typed_e, constraints) = constrainTypes(env)(e)
val solution = Constraint.solve(constraints, preserve, Seq())(
Flags.ExplicitDependence.Off)
solution(typed_e)
}

// TODO: Get rid of TypeAssertion and deprecate, instead evaluate !: in place and use `preserving` directly
private [DSL] def apply(e: Expr,
printFlag: Flags.PrintTypesAndTypeHoles = Flags.PrintTypesAndTypeHoles.Off,
explDep: Flags.ExplicitDependence = Flags.ExplicitDependence.Off): Expr = {
// Collect FTVs in assertions and opaques; transform assertions into annotations
val (preserve, e_wo_assertions) = traverse(e, collectPreserve)
infer.preserving(e_wo_assertions, preserve, printFlag, explDep)
}

private [DSL] def preserving(wo_assertions: Expr, preserve : Set[Kind.Identifier],
printFlag: Flags.PrintTypesAndTypeHoles = Flags.PrintTypesAndTypeHoles.Off,
explDep: Flags.ExplicitDependence = Flags.ExplicitDependence.Off): Expr = {
def apply(e: Expr, env: Map[String, Type] = Map(), preserve : Set[String] = Set(),
printFlag: Flags.PrintTypesAndTypeHoles = Flags.PrintTypesAndTypeHoles.Off,
explDep: Flags.ExplicitDependence = Flags.ExplicitDependence.Off): Expr = {
// Collect constraints
val (typed_e, constraints) = constrainTypes(Map())(wo_assertions)
val (typed_e, constraints) = constrainTypes(env)(e)
// Solve constraints while preserving the FTVs in preserve
val solution = Constraint.solve(constraints, preserve, Seq())(explDep)
// Apply the solution
Expand Down Expand Up @@ -83,42 +67,14 @@ object infer {
}
}

val FTVGathering = new PureAccumulatorTraversal[Seq[Kind.Identifier]] {
override val accumulator = SeqMonoid
override def typeIdentifier[I <: Kind.Identifier]: VarType => I => Pair[I] = _ => {
case i: Kind.Explicitness => accumulate(if (!i.isExplicit) Seq(i) else Seq())(i.asInstanceOf[I])
case i => accumulate(Seq(i))(i)
}
override def nat: Nat => Pair[Nat] = ae => {
val ftvs = mutable.ListBuffer[Kind.Identifier]()
val r = ae.visitAndRebuild({
case i: NatIdentifier if !i.isExplicit => ftvs += i; i
case n => n
})
accumulate(ftvs.toSeq)(r)
}
}

def getFTVs(t: Type): Seq[Kind.Identifier] = {
traverse(t, FTVGathering)._1.distinct
}

def getFTVsRec(e: Expr): Seq[Kind.Identifier] = {
traverse(e, FTVGathering)._1.distinct
// TODO: remove, use IsClosedForm.freeVars directly
def getFTVs(t: Type): Set[Kind.Identifier] = {
IsClosedForm.freeVars(t)._2
}

private val collectPreserve = new PureAccumulatorTraversal[Set[Kind.Identifier]] {
override val accumulator = SetMonoid
override def expr: Expr => Pair[Expr] = {
// Transform assertions into annotations, collect FTVs
case TypeAssertion(e, t) =>
val (s, e1) = expr(e).unwrap
accumulate(s ++ getFTVs(t))(TypeAnnotation(e1, t) : Expr)
// Collect FTVs
case Opaque(e, t) =>
accumulate(getFTVs(t).toSet)(Opaque(e, t) : Expr)
case e => super.expr(e)
}
// TODO: remove, use IsClosedForm.freeVars directly
def getFTVsRec(e: Expr): Set[Kind.Identifier] = {
IsClosedForm.freeVars(e)._2
}

private val genType : Expr => Type =
Expand Down Expand Up @@ -175,11 +131,6 @@ object infer {
val c = TypeConstraint(te.t, t)
(te, csE :+ c)

case TypeAssertion(e, t) =>
val (te, csE) = constrainTypes(env)(e)
val c = TypeConstraint(te.t, t)
(te, csE :+ c)

case o: Opaque => (o, Nil)
case l: Literal => (l, Nil)
case p: Primitive => (p.setType(p.typeScheme), Nil)
Expand All @@ -200,8 +151,7 @@ object infer {
}

object inferDependent {
def apply(e: ToBeTyped[Expr],
printFlag: Flags.PrintTypesAndTypeHoles = Flags.PrintTypesAndTypeHoles.Off): Expr = infer(e match {
case ToBeTyped(e) => e
}, printFlag, Flags.ExplicitDependence.On)
def apply(e: ToBeTyped[Expr], env : Map[String, Type] = Map(), preserve : Set[String] = Set(),
printFlag: Flags.PrintTypesAndTypeHoles = Flags.PrintTypesAndTypeHoles.Off): Expr =
infer(e match { case ToBeTyped(e) => e }, env, preserve, printFlag, Flags.ExplicitDependence.On)
}
62 changes: 24 additions & 38 deletions src/main/scala/rise/core/DSL/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import rise.core.primitives._
import rise.core.semantics._
import rise.core.traverse._
import rise.core.types._
import rise.elevate.rewrite

import scala.language.implicitConversions

Expand All @@ -19,10 +20,7 @@ package object DSL {
x >>= (x => e >>= (e => toBeTyped(Lambda(x, e)(TypePlaceholder))))
def app(f: ToBeTyped[Expr], e: ToBeTyped[Expr]): ToBeTyped[App] =
f >>= (f => e >>= (e => toBeTyped(App(f, e)(TypePlaceholder))))
def depLambda[K <: Kind: KindName](
x: K#I with Kind.Explicitness,
e: ToBeTyped[Expr]
): ToBeTyped[DepLambda[K]] =
def depLambda[K <: Kind: KindName](x: K#I, e: ToBeTyped[Expr]): ToBeTyped[DepLambda[K]] =
e >>= (e => toBeTyped(DepLambda[K](x, e)(TypePlaceholder)))
def depApp[K <: Kind](f: ToBeTyped[Expr], x: K#T): ToBeTyped[DepApp[K]] =
f >>= (f => toBeTyped(DepApp[K](f, x)(TypePlaceholder)))
Expand Down Expand Up @@ -79,9 +77,8 @@ package object DSL {
def `@`(i: ToBeTyped[Expr]): ToBeTyped[App] = idx(i)(e)
}

implicit class TypeAssertionHelper(t: Type) {
def !:[T <: Expr](e: ToBeTyped[T]): ToBeTyped[Expr] =
e >>= (e => toBeTyped(TypeAssertion(e, t)))
implicit class TypeAssertionHelper(lhs: Expr) {
def !:[T <: Expr](e: ToBeTyped[T]): ToBeTyped[Expr] = rewrite(lhs)(e)
}

implicit class TypeAnnotationHelper(t: Type) {
Expand Down Expand Up @@ -403,71 +400,60 @@ package object DSL {
}

object depFun {
def apply(r: arithexpr.arithmetic.Range,
w: NatFunction1Wrapper[ToBeTyped[Expr]]
): ToBeTyped[DepLambda[NatKind]] = {
val n = NatIdentifier(freshName("n"), r, isExplicit = true)
def apply(r: arithexpr.arithmetic.Range, w: NatFunction1Wrapper[ToBeTyped[Expr]]): ToBeTyped[DepLambda[NatKind]] = {
val n = NatIdentifier(freshName("n"), r)
depLambda[NatKind](n, w.f(n))
}

def apply(w: NatFunction1Wrapper[ToBeTyped[Expr]]
): ToBeTyped[DepLambda[NatKind]] = {
def apply(w: NatFunction1Wrapper[ToBeTyped[Expr]]): ToBeTyped[DepLambda[NatKind]] = {
val r = arithexpr.arithmetic.RangeAdd(0, arithexpr.arithmetic.PosInf, 1)
val n = NatIdentifier(freshName("n"), r, isExplicit = true)
val n = NatIdentifier(freshName("n"), r)
depLambda[NatKind](n, w.f(n))
}

def apply(w: NatFunction2Wrapper[ToBeTyped[Expr]]
): ToBeTyped[DepLambda[NatKind]] = {
def apply(w: NatFunction2Wrapper[ToBeTyped[Expr]]): ToBeTyped[DepLambda[NatKind]] = {
val r = arithexpr.arithmetic.RangeAdd(0, arithexpr.arithmetic.PosInf, 1)
val n1 = NatIdentifier(freshName("n"), r, isExplicit = true)
val n1 = NatIdentifier(freshName("n"), r)
depLambda[NatKind](n1, depFun((n2: Nat) => w.f(n1, n2)))
}

def apply(w: NatFunction3Wrapper[ToBeTyped[Expr]]
): ToBeTyped[DepLambda[NatKind]] = {
def apply(w: NatFunction3Wrapper[ToBeTyped[Expr]]): ToBeTyped[DepLambda[NatKind]] = {
val r = arithexpr.arithmetic.RangeAdd(0, arithexpr.arithmetic.PosInf, 1)
val n1 = NatIdentifier(freshName("n"), r, isExplicit = true)
val n1 = NatIdentifier(freshName("n"), r)
depLambda[NatKind](n1, depFun((n2: Nat, n3: Nat) => w.f(n1, n2, n3)))
}

def apply(w: NatFunction4Wrapper[ToBeTyped[Expr]]
): ToBeTyped[DepLambda[NatKind]] = {
def apply(w: NatFunction4Wrapper[ToBeTyped[Expr]]): ToBeTyped[DepLambda[NatKind]] = {
val r = arithexpr.arithmetic.RangeAdd(0, arithexpr.arithmetic.PosInf, 1)
val n1 = NatIdentifier(freshName("n"), r, isExplicit = true)
val n1 = NatIdentifier(freshName("n"), r)
depLambda[NatKind](n1, depFun((n2: Nat, n3: Nat, n4: Nat) =>
w.f(n1, n2, n3, n4)))
}

def apply(w: NatFunction5Wrapper[ToBeTyped[Expr]]
): ToBeTyped[DepLambda[NatKind]] = {
def apply(w: NatFunction5Wrapper[ToBeTyped[Expr]]): ToBeTyped[DepLambda[NatKind]] = {
val r = arithexpr.arithmetic.RangeAdd(0, arithexpr.arithmetic.PosInf, 1)
val n1 = NatIdentifier(freshName("n"), r, isExplicit = true)
val n1 = NatIdentifier(freshName("n"), r)
depLambda[NatKind](n1, depFun((n2: Nat, n3: Nat, n4: Nat, n5: Nat) =>
w.f(n1, n2, n3, n4, n5)))
}

def apply(w: DataTypeFunctionWrapper[ToBeTyped[Expr]]
): ToBeTyped[DepLambda[DataKind]] = {
val x = DataTypeIdentifier(freshName("dt"), isExplicit = true)
def apply(w: DataTypeFunctionWrapper[ToBeTyped[Expr]]): ToBeTyped[DepLambda[DataKind]] = {
val x = DataTypeIdentifier(freshName("dt"))
depLambda[DataKind](x, w.f(x))
}

def apply(w: NatToDataFunctionWrapper[ToBeTyped[Expr]]
): ToBeTyped[DepLambda[NatToDataKind]] = {
val x = NatToDataIdentifier(freshName("n2d"), isExplicit = true)
def apply(w: NatToDataFunctionWrapper[ToBeTyped[Expr]]): ToBeTyped[DepLambda[NatToDataKind]] = {
val x = NatToDataIdentifier(freshName("n2d"))
depLambda[NatToDataKind](x, w.f(x))
}

def apply(w: NatToNatFunctionWrapper[ToBeTyped[Expr]]
): ToBeTyped[DepLambda[NatToNatKind]] = {
val x = NatToNatIdentifier(freshName("n2n"), isExplicit = true)
def apply(w: NatToNatFunctionWrapper[ToBeTyped[Expr]]): ToBeTyped[DepLambda[NatToNatKind]] = {
val x = NatToNatIdentifier(freshName("n2n"))
depLambda[NatToNatKind](x, w.f(x))
}

def apply(w: AddressSpaceFunctionWrapper[ToBeTyped[Expr]]
): ToBeTyped[DepLambda[AddressSpaceKind]] = {
val x = AddressSpaceIdentifier(freshName("a"), isExplicit = true)
def apply(w: AddressSpaceFunctionWrapper[ToBeTyped[Expr]]): ToBeTyped[DepLambda[AddressSpaceKind]] = {
val x = AddressSpaceIdentifier(freshName("a"))
depLambda[AddressSpaceKind](x, w.f(x))
}
}
Expand Down
Loading