From 288eb1e1d6e2229399e9f66729c4da0a6f9bbba4 Mon Sep 17 00:00:00 2001 From: Uma Zalakain Date: Fri, 28 May 2021 15:33:23 +0100 Subject: [PATCH 1/3] Expose OrderedSet in freeVars --- src/main/scala/rise/core/IsClosedForm.scala | 24 ++++++++++---------- src/main/scala/rise/eqsat/NamedRewrite.scala | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/main/scala/rise/core/IsClosedForm.scala b/src/main/scala/rise/core/IsClosedForm.scala index ceb9f09a3..462d2f6cb 100644 --- a/src/main/scala/rise/core/IsClosedForm.scala +++ b/src/main/scala/rise/core/IsClosedForm.scala @@ -6,15 +6,15 @@ import rise.core.traverse._ import rise.core.types._ object IsClosedForm { - case class OrderedSet[T](ordered : Seq[T], unique : Set[T]) + case class OrderedSet[T](seq : Seq[T], set : Set[T]) object OrderedSet { - def one[T] : T => OrderedSet[T] = t => OrderedSet(Seq(t), Set(t)) - def add[T] : T => OrderedSet[T] => OrderedSet[T] = t => ts => - if (ts.unique.contains(t)) ts else OrderedSet(t +: ts.ordered, ts.unique + t) def empty[T] : OrderedSet[T] = OrderedSet(Seq(), Set()) + def add[T] : T => OrderedSet[T] => OrderedSet[T] = t => ts => + if (ts.set.contains(t)) ts else OrderedSet(t +: ts.seq, ts.set + t) + def one[T] : T => OrderedSet[T] = add(_)(empty) def append[T] : OrderedSet[T] => OrderedSet[T] => OrderedSet[T] = x => y => { - val ordered = x.ordered.filter(!y.unique.contains(_)) ++ y.ordered - val unique = x.unique ++ y.unique + val ordered = x.seq.filter(!y.set.contains(_)) ++ y.seq + val unique = x.set ++ y.set OrderedSet(ordered, unique) } } @@ -90,14 +90,14 @@ object IsClosedForm { } } - def freeVars(expr: Expr): (Seq[Identifier], Seq[Kind.Identifier]) = { + def freeVars(expr: Expr): (OrderedSet[Identifier], OrderedSet[Kind.Identifier]) = { val ((fV, fT), _) = traverse(expr, Visitor(Set(), Set())) - (fV.ordered, fT.ordered) + (fV, fT) } - def freeVars(t: Type): Seq[Kind.Identifier] = { + def freeVars(t: Type): OrderedSet[Kind.Identifier] = { val ((_, ftv), _) = traverse(t, Visitor(Set(), Set())) - ftv.ordered + ftv } // Exclude matrix layout and fragment kind identifiers, since they cannot currently be bound @@ -109,10 +109,10 @@ object IsClosedForm { def varsToClose(expr : Expr): (Seq[Identifier], Seq[Kind.Identifier]) = { val (fV, fT) = freeVars(expr) - (fV, needsClosing(fT)) + (fV.seq, needsClosing(fT.seq)) } - def varsToClose(t : Type): Seq[Kind.Identifier] = needsClosing(freeVars(t)) + def varsToClose(t : Type): Seq[Kind.Identifier] = needsClosing(freeVars(t).seq) def apply(expr: Expr): Boolean = { val (freeV, freeT) = varsToClose(expr) diff --git a/src/main/scala/rise/eqsat/NamedRewrite.scala b/src/main/scala/rise/eqsat/NamedRewrite.scala index 4ecc3e636..804cff60e 100644 --- a/src/main/scala/rise/eqsat/NamedRewrite.scala +++ b/src/main/scala/rise/eqsat/NamedRewrite.scala @@ -19,7 +19,7 @@ object NamedRewrite { } val typedLhs = infer(lhs, untypedFreeV, Set()) val freeV = infer.collectFreeEnv(typedLhs) - val freeT = rise.core.IsClosedForm.freeVars(typedLhs)._2.toSet + val freeT = rise.core.IsClosedForm.freeVars(typedLhs)._2.set val typedRhs = infer(rc.TypeAnnotation(rhs, typedLhs.t), freeV, freeT) trait PatVarStatus From 264abd57f63f866d3805e9d41aa559419bf880ec Mon Sep 17 00:00:00 2001 From: Uma Zalakain Date: Fri, 28 May 2021 15:36:06 +0100 Subject: [PATCH 2/3] Do not skip matrix layout and fragment kind identifiers in collectPreserve --- src/main/scala/rise/core/DSL/infer.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/scala/rise/core/DSL/infer.scala b/src/main/scala/rise/core/DSL/infer.scala index 76f5e702e..bcce60d96 100644 --- a/src/main/scala/rise/core/DSL/infer.scala +++ b/src/main/scala/rise/core/DSL/infer.scala @@ -83,10 +83,10 @@ object infer { // Transform assertions into annotations, collect FTVs case TypeAssertion(e, t) => val (s1, e1) = expr(e).unwrap - accumulate(s1 ++ IsClosedForm.varsToClose(t))(TypeAnnotation(e1, t) : Expr) + accumulate(s1 ++ IsClosedForm.freeVars(t).set)(TypeAnnotation(e1, t) : Expr) // Collect FTVs case Opaque(e, t) => - accumulate(IsClosedForm.varsToClose(t).toSet)(Opaque(e, t) : Expr) + accumulate(IsClosedForm.freeVars(t).set)(Opaque(e, t) : Expr) case e => super.expr(e) } } From ddc21fa04271040d4281a417eaceb84e7c10033f Mon Sep 17 00:00:00 2001 From: Michel Steuwer Date: Sat, 15 May 2021 13:29:58 +0100 Subject: [PATCH 3/3] Removed type projection from Kinds --- .../scala/meta/generator/DPIAPrimitives.scala | 40 ++++--- .../scala/meta/generator/RisePrimitives.scala | 14 +-- .../scala/apps/cameraPipelineRewrite.scala | 2 +- src/main/scala/rise/core/Builder.scala | 12 +- src/main/scala/rise/core/DSL/Type.scala | 24 ++-- src/main/scala/rise/core/DSL/infer.scala | 19 +--- src/main/scala/rise/core/DSL/package.scala | 73 ++++++------ src/main/scala/rise/core/Expr.scala | 17 ++- src/main/scala/rise/core/IsClosedForm.scala | 6 +- src/main/scala/rise/core/dotPrinter.scala | 8 +- src/main/scala/rise/core/equality.scala | 86 +++++++------- src/main/scala/rise/core/lifting.scala | 42 ++----- src/main/scala/rise/core/makeClosed.scala | 10 +- src/main/scala/rise/core/package.scala | 4 +- .../core/primitives/foreignFunction.scala | 2 +- src/main/scala/rise/core/showRise.scala | 8 +- src/main/scala/rise/core/showScala.scala | 16 +-- src/main/scala/rise/core/substitute.scala | 42 ++++--- src/main/scala/rise/core/traverse.scala | 69 ++++------- src/main/scala/rise/core/typedLifting.scala | 42 ++----- .../scala/rise/core/types/Constraints.scala | 28 ++--- src/main/scala/rise/core/types/Kinds.scala | 71 +++--------- src/main/scala/rise/core/types/Solution.scala | 28 ++--- src/main/scala/rise/core/types/Type.scala | 29 ++--- src/main/scala/rise/core/types/check.scala | 18 +-- src/main/scala/rise/core/types/package.scala | 8 +- src/main/scala/rise/core/uniqueNames.scala | 24 ++-- .../rise/elevate/rules/algorithmic.scala | 30 ++--- .../scala/rise/elevate/rules/lowering.scala | 10 +- .../scala/rise/elevate/rules/movement.scala | 62 +++++----- .../scala/rise/elevate/rules/package.scala | 8 +- .../scala/rise/elevate/rules/traversal.scala | 34 +----- .../scala/rise/elevate/rules/vectorize.scala | 30 ++--- .../rise/elevate/strategies/lowering.scala | 8 +- .../rise/elevate/strategies/predicate.scala | 2 +- src/main/scala/rise/eqsat/Expr.scala | 20 ++-- src/main/scala/rise/eqsat/NamedRewrite.scala | 24 ++-- src/main/scala/rise/eqsat/TypeNode.scala | 12 +- src/main/scala/rise/openCL/DSL.scala | 8 +- .../shine/C/Compilation/CodeGenerator.scala | 16 +-- .../C/Compilation/TranslationContext.scala | 2 +- .../Compilation/AcceptorTranslation.scala | 22 ++-- .../Compilation/ContinuationTranslation.scala | 12 +- .../DPIA/Compilation/FedeTranslation.scala | 10 +- .../scala/shine/DPIA/Compilation/FunDef.scala | 4 +- .../DPIA/Compilation/Passes/UnrollLoops.scala | 2 +- .../DPIA/Compilation/StreamTranslation.scala | 14 +-- src/main/scala/shine/DPIA/DSL/Core.scala | 16 ++- .../shine/DPIA/DSL/ImperativePrimitives.scala | 2 +- src/main/scala/shine/DPIA/DSL/package.scala | 4 +- .../shine/DPIA/InferAccessAnnotation.scala | 80 ++++++------- src/main/scala/shine/DPIA/Lifting.scala | 26 ++--- .../scala/shine/DPIA/Phrases/Phrase.scala | 29 +++-- .../DPIA/Phrases/PrettyPhrasePrinter.scala | 4 +- .../shine/DPIA/Phrases/VisitAndRebuild.scala | 54 ++++----- src/main/scala/shine/DPIA/Types/Kind.scala | 101 ++++------------- .../scala/shine/DPIA/Types/PhraseType.scala | 47 ++++---- .../scala/shine/DPIA/Types/TypeCheck.scala | 8 +- src/main/scala/shine/DPIA/Types/package.scala | 24 ++-- src/main/scala/shine/DPIA/fromRise.scala | 107 +++++++++--------- src/main/scala/shine/DPIA/package.scala | 41 ++++--- .../primitives/functional/DepMapSeq.scala | 6 +- .../DPIA/primitives/functional/Iterate.scala | 4 +- .../DPIA/primitives/imperative/ForNat.scala | 6 +- .../AdjustArraySizesForAllocations.scala | 4 +- .../Compilation/KernelCodeGenerator.scala | 2 +- .../Passes/FlagPrivateArrayLoops.scala | 10 +- .../Passes/InsertMemoryBarriers.scala | 12 +- .../SeparateHostAndKernelCode.scala | 16 +-- src/main/scala/shine/OpenCL/DSL/package.scala | 4 +- .../OpenCL/primitives/functional/DepMap.scala | 6 +- .../primitives/functional/Iterate.scala | 4 +- .../primitives/imperative/ParForNat.scala | 6 +- .../scala/shine/OpenMP/CodeGenerator.scala | 4 +- .../primitives/functional/DepMapPar.scala | 4 +- .../primitives/imperative/ParForNat.scala | 4 +- src/test/scala/apps/asum.scala | 2 +- src/test/scala/apps/dot.scala | 2 +- src/test/scala/apps/gemvCheck.scala | 4 +- .../separableConvolution2DNaiveEqsat.scala | 10 +- src/test/scala/rise/core/showRise.scala | 16 +-- src/test/scala/rise/elevate/algorithmic.scala | 16 +-- src/test/scala/rise/elevate/tiling.scala | 2 +- src/test/scala/rise/elevate/traversals.scala | 6 +- .../scala/rise/elevate/util/package.scala | 2 +- .../scala/shine/DPIA/InferAccessTypes.scala | 2 +- src/test/scala/shine/cuda/MMTest.scala | 8 +- src/test/scala/shine/cuda/basic.scala | 5 +- 88 files changed, 781 insertions(+), 1001 deletions(-) diff --git a/meta/src/main/scala/meta/generator/DPIAPrimitives.scala b/meta/src/main/scala/meta/generator/DPIAPrimitives.scala index 8c77dcafb..83327e594 100644 --- a/meta/src/main/scala/meta/generator/DPIAPrimitives.scala +++ b/meta/src/main/scala/meta/generator/DPIAPrimitives.scala @@ -140,23 +140,22 @@ ${generateCaseClass(Type.Name(name), toParamList(definition, scalaParams), param case DPIA.Type.AST.CommType => t"CommType" case DPIA.Type.AST.PairType(lhs, rhs) => t"PhrasePairType[${generatePhraseType(lhs)}, ${generatePhraseType(rhs)}]" case DPIA.Type.AST.FunType(inT, outT) => t"FunType[${generatePhraseType(inT)}, ${generatePhraseType(outT)}]" - case DPIA.Type.AST.DepFunType(id, kind, t) => t"DepFunType[${generateKindType(kind)}, ${generatePhraseType(t)}]" + case DPIA.Type.AST.DepFunType(id, kind, t) => t"DepFunType[${generateKindIdentifierType(kind)}, ${generatePhraseType(t)}]" case DPIA.Type.AST.Identifier(name) => Type.Name(name) case DPIA.Type.AST.VariadicType(_, _) => throw new Exception("Can not generate Phrase Type for Variadic Type") } - // generate Scala type for representing the DPIA/rise kinds themselves - def generateKindType(kindAST: DPIA.Kind.AST): scala.meta.Type = kindAST match { + def generateKindIdentifierType(kindAST: DPIA.Kind.AST): scala.meta.Type = kindAST match { case DPIA.Kind.AST.RiseKind(riseKind) => riseKind match { - case rise.Kind.AST.Data => Type.Name("DataKind") - case rise.Kind.AST.Address => Type.Name("AddressSpaceKind") - case rise.Kind.AST.Nat2Nat => Type.Name("NatToNatKind") - case rise.Kind.AST.Nat2Data => Type.Name("NatToDataKind") - case rise.Kind.AST.Nat => Type.Name("NatKind") - case rise.Kind.AST.Fragment => throw new Exception("Can not generate Kind for Fragment") - case rise.Kind.AST.MatrixLayout => throw new Exception("Can not generate Kind for Matrix Layout") - } - case DPIA.Kind.AST.Access => Type.Name("AccessKind") + case rise.Kind.AST.Data => Type.Name("DataTypeIdentifier") + case rise.Kind.AST.Address => Type.Name("AddressSpaceIdentifier") + case rise.Kind.AST.Nat2Nat => Type.Name("NatToNatIdentifier") + case rise.Kind.AST.Nat2Data => Type.Name("NatToDataIdentifier") + case rise.Kind.AST.Nat => Type.Name("NatIdentifier") + case rise.Kind.AST.Fragment => throw new Exception("Can not generate Kind for Fragment") + case rise.Kind.AST.MatrixLayout => throw new Exception("Can not generate Kind for Matrix Layout") + } + case DPIA.Kind.AST.Access => Type.Name("AccessTypeIdentifier") case DPIA.Kind.AST.VariadicKind(_, _) => throw new Exception("Can not generate Kind for Variadic Kind") } @@ -287,7 +286,7 @@ ${generateCaseClass(Type.Name(name), toParamList(definition, scalaParams), param case DPIA.Type.AST.FunType(inT, outT) => q"FunType(${generateTerm(inT)}, ${generateTerm(outT)})" case DPIA.Type.AST.DepFunType(id, kind, t) => - q"DepFunType[${generateKindType(kind)}, PhraseType](${Term.Name(id.name)}, ${generateTerm(t)})" + q"DepFunType(${generateKindType(kind)}, ${Term.Name(id.name)}, ${generateTerm(t)})" case DPIA.Type.AST.VariadicType(_, _) => throw new Exception("Can not generate Term for Variadic Type") } @@ -297,6 +296,21 @@ ${generateCaseClass(Type.Name(name), toParamList(definition, scalaParams), param case DPIA.Type.Access.AST.Write =>Term.Name("write") } + // generate Scala type for representing the DPIA/rise kinds themselves + def generateKindType(kindAST: DPIA.Kind.AST): scala.meta.Term = kindAST match { + case DPIA.Kind.AST.RiseKind(riseKind) => riseKind match { + case rise.Kind.AST.Data => Term.Name("DataKind") + case rise.Kind.AST.Address => Term.Name("AddressSpaceKind") + case rise.Kind.AST.Nat2Nat => Term.Name("NatToNatKind") + case rise.Kind.AST.Nat2Data => Term.Name("NatToDataKind") + case rise.Kind.AST.Nat => Term.Name("NatKind") + case rise.Kind.AST.Fragment => throw new Exception("Can not generate Kind for Fragment") + case rise.Kind.AST.MatrixLayout => throw new Exception("Can not generate Kind for Matrix Layout") + } + case DPIA.Kind.AST.Access => Term.Name("AccessKind") + case DPIA.Kind.AST.VariadicKind(_, _) => throw new Exception("Can not generate Kind for Variadic Kind") + } + def generateVisitAndRebuild(name: scala.meta.Type.Name, paramLists: List[List[scala.meta.Term.Param]]): scala.meta.Defn.Def = { // little pattern matching helper that ignores if a type name is written with a package prefix diff --git a/meta/src/main/scala/meta/generator/RisePrimitives.scala b/meta/src/main/scala/meta/generator/RisePrimitives.scala index eeafb6f54..df1249d69 100644 --- a/meta/src/main/scala/meta/generator/RisePrimitives.scala +++ b/meta/src/main/scala/meta/generator/RisePrimitives.scala @@ -180,24 +180,24 @@ import arithexpr.arithmetic._ // val ids = Seq.fill(n)(DataTypeIdentifier(freshName("dt"), isExplicit = true)) // ids.foldRight(t){ case (id, t) => DepFunType[DataKind](id, t) } // to represent n-many dependent function types: (id0: kind) -> (id1: kind) -> ... -> t - val (createIds, typeName) = kind match { + val (createIds, kindName) = kind match { case AST.Data => - (q"""DataTypeIdentifier(freshName("dt"), isExplicit = true)""", Type.Name("DataKind")) + (q"""DataTypeIdentifier(freshName("dt"), isExplicit = true)""", Term.Name("DataKind")) case AST.Address => - (q"""AddressSpaceIdentifier(freshName("a"), isExplicit = true)""", Type.Name("AddressSpaceKind")) + (q"""AddressSpaceIdentifier(freshName("a"), isExplicit = true)""", Term.Name("AddressSpaceKind")) case AST.Nat2Nat => - (q"""NatToNatIdentifier(freshName("n2n"), isExplicit = true)""", Type.Name("NatToNatKind")) + (q"""NatToNatIdentifier(freshName("n2n"), isExplicit = true)""", Term.Name("NatToNatKind")) case AST.Nat2Data => - (q"""NatToDataIdentifier(freshName("n2d"), isExplicit = true)""", Type.Name("NatToDataKind")) + (q"""NatToDataIdentifier(freshName("n2d"), isExplicit = true)""", Term.Name("NatToDataKind")) case AST.Nat => - (q"""NatIdentifier(freshName("n"), isExplicit = true)""", Type.Name("NatKind")) + (q"""NatIdentifier(freshName("n"), isExplicit = true)""", Term.Name("NatKind")) case AST.Fragment => throw new Exception("No support for Fragment Kind yet") case AST.MatrixLayout => throw new Exception("No support for Matrix Layout Kind yet") } q"""{ val ${Pat.Var(Term.Name(ids.name))} = Seq.fill(${Term.Name(n.name)})($createIds) ${Term.Name(ids.name)}.foldRight(${generateTypeScheme(t)}: Type) { - case (id, t) => DepFunType[$typeName, Type](id, t) + case (id, t) => DepFunType($kindName, id, t) } }""" case _ => generateDataType(typeAST) diff --git a/src/main/scala/apps/cameraPipelineRewrite.scala b/src/main/scala/apps/cameraPipelineRewrite.scala index 1e6768aef..e27aad62a 100644 --- a/src/main/scala/apps/cameraPipelineRewrite.scala +++ b/src/main/scala/apps/cameraPipelineRewrite.scala @@ -29,7 +29,7 @@ object cameraPipelineRewrite { case class depFunction(s: Strategy[Rise]) extends Strategy[Rise] { def apply(e: Rise): RewriteResult[Rise] = e match { - case ap @ DepApp(f, x) => s(f).mapSuccess(DepApp(_, x)(ap.t)) + case ap @ DepApp(kind, f, x) => s(f).mapSuccess(DepApp(kind, _, x)(ap.t)) case _ => Failure(s) } override def toString: String = s"depFunction($s)" diff --git a/src/main/scala/rise/core/Builder.scala b/src/main/scala/rise/core/Builder.scala index 660baf88e..25ef2047c 100644 --- a/src/main/scala/rise/core/Builder.scala +++ b/src/main/scala/rise/core/Builder.scala @@ -7,12 +7,12 @@ trait Builder { throw new Exception("apply method must be overridden") def apply(e: DSL.ToBeTyped[Expr]): DSL.ToBeTyped[App] = DSL.app(DSL.toBeTyped(apply), e) - def apply(n: Nat): DSL.ToBeTyped[DepApp[NatKind]] = - DSL.depApp[NatKind](DSL.toBeTyped(apply), n) - def apply(dt: DataType): DSL.ToBeTyped[DepApp[DataKind]] = - DSL.depApp[DataKind](DSL.toBeTyped(apply), dt) - def apply(a: AddressSpace): DSL.ToBeTyped[DepApp[AddressSpaceKind]] = - DSL.depApp[AddressSpaceKind](DSL.toBeTyped(apply), a) + def apply(n: Nat): DSL.ToBeTyped[DepApp[Nat]] = + DSL.depApp(NatKind, DSL.toBeTyped(apply), n) + def apply(dt: DataType): DSL.ToBeTyped[DepApp[DataType]] = + DSL.depApp(DataKind, DSL.toBeTyped(apply), dt) + def apply(a: AddressSpace): DSL.ToBeTyped[DepApp[AddressSpace]] = + DSL.depApp(AddressSpaceKind, DSL.toBeTyped(apply), a) def unapply(arg: Expr): Boolean = throw new Exception("unapply method must be overridden") diff --git a/src/main/scala/rise/core/DSL/Type.scala b/src/main/scala/rise/core/DSL/Type.scala index eda277348..c21390792 100644 --- a/src/main/scala/rise/core/DSL/Type.scala +++ b/src/main/scala/rise/core/DSL/Type.scala @@ -94,27 +94,27 @@ object Type { object expl { def apply(w: NatFunctionWrapper[Type]): Type = { val x = NatIdentifier(freshName("n"), isExplicit = true) - DepFunType[NatKind, Type](x, w.f(x)) + DepFunType(NatKind, x, w.f(x)) } def apply(w: DataTypeFunctionWrapper[Type]): Type = { val x = DataTypeIdentifier(freshName("dt"), isExplicit = true) - DepFunType[DataKind, Type](x, w.f(x)) + DepFunType(DataKind, x, w.f(x)) } def apply(w: NatToDataFunctionWrapper[Type]): Type = { val x = NatToDataIdentifier(freshName("n2d"), isExplicit = true) - DepFunType[NatToDataKind, Type](x, w.f(x)) + DepFunType(NatToDataKind, x, w.f(x)) } def apply(w: NatToNatFunctionWrapper[Type]): Type = { val x = NatToNatIdentifier(freshName("n2n"), isExplicit = true) - DepFunType[NatToNatKind, Type](x, w.f(x)) + DepFunType(NatToNatKind, x, w.f(x)) } def apply(w: AddressSpaceFunctionWrapper[Type]): Type = { val x = AddressSpaceIdentifier(freshName("a"), isExplicit = true) - DepFunType[AddressSpaceKind, Type](x, w.f(x)) + DepFunType(AddressSpaceKind, x, w.f(x)) } } @@ -160,24 +160,24 @@ object Type { object Nat { def `**`(f: Nat => DataType): Type = { val x = NatIdentifier(freshName("n"), isExplicit = true) - DepPairType[NatKind](x, f(x)) + DepPairType(NatKind, x, f(x)) } } object NatCollection { def `**`(f: NatCollection => DataType): Type = { val x = NatCollectionIdentifier(freshName("ns"), isExplicit = true) - DepPairType[NatCollectionKind](x, f(x)) + DepPairType(NatCollectionKind, x, f(x)) } } object `:Nat **` { - def unapply(arg: DepPairType[NatKind]): Option[(NatIdentifier, DataType)] = + def unapply(arg: DepPairType[Nat, NatIdentifier]): Option[(NatIdentifier, DataType)] = Some(arg.x, arg.t) } object `:NatCollection **` { - def unapply(arg: DepPairType[NatCollectionKind]): Option[(NatCollectionIdentifier, DataType)] = + def unapply(arg: DepPairType[NatCollection, NatCollectionIdentifier]): Option[(NatCollectionIdentifier, DataType)] = Some(arg.x, arg.t) } @@ -195,7 +195,7 @@ object Type { } object `(Addr)->:` { - def unapply[K <: Kind, T <: Type](funType: DepFunType[K, T]): Option[(AddressSpaceIdentifier, T)] = { + def unapply[T, I <: Kind.Identifier, U <: Type](funType: DepFunType[T, I, U]): Option[(AddressSpaceIdentifier, U)] = { funType.x match { case a: AddressSpaceIdentifier => Some((a, funType.t)) case _ => throw new Exception("Expected AddressSpace DepFunType") @@ -204,7 +204,7 @@ object Type { } object `(Nat)->:` { - def unapply[K <: Kind, T <: Type](funType: DepFunType[K, T]): Option[(NatIdentifier, T)] = { + def unapply[T, I <: Kind.Identifier, U <: Type](funType: DepFunType[T, I, U]): Option[(NatIdentifier, U)] = { funType.x match { case n: NatIdentifier => Some((n, funType.t)) case _ => throw new Exception("Expected Nat DepFunType") @@ -213,7 +213,7 @@ object Type { } object `(NatToNat)->:` { - def unapply[K <: Kind, T <: Type](funType: DepFunType[K, T]): Option[(NatToNatIdentifier, T)] = { + def unapply[T, I <: Kind.Identifier, U <: Type](funType: DepFunType[T, I, U]): Option[(NatToNatIdentifier, U)] = { funType.x match { case n: NatToNatIdentifier => Some((n, funType.t)) case _ => throw new Exception("Expected NatToNat DepFunType") diff --git a/src/main/scala/rise/core/DSL/infer.scala b/src/main/scala/rise/core/DSL/infer.scala index bcce60d96..df77ebe08 100644 --- a/src/main/scala/rise/core/DSL/infer.scala +++ b/src/main/scala/rise/core/DSL/infer.scala @@ -118,26 +118,17 @@ object infer { val c = TypeConstraint(tf.t, FunType(te.t, exprT)) (App(tf, te)(exprT), csF ++ csE :+ c) - case expr@DepLambda(x, e) => + case expr@DepLambda(kind, x, e) => val (te, csE) = constrainTypes(exprEnv)(e) - val tf = x match { - case n: NatIdentifier => - DepLambda[NatKind](n, te)(DepFunType[NatKind, Type](n, te.t)) - case dt: DataTypeIdentifier => - DepLambda[DataKind](dt, te)(DepFunType[DataKind, Type](dt, te.t)) - case ad: AddressSpaceIdentifier => - DepLambda[AddressSpaceKind](ad, te)(DepFunType[AddressSpaceKind, Type](ad, te.t)) - case n2n: NatToNatIdentifier => - DepLambda[NatToNatKind](n2n, te)(DepFunType[NatToNatKind, Type](n2n, te.t)) - } + val tf = DepLambda(kind, x, te)(DepFunType(kind, x, te.t)) val csE1 = ifTyped(expr.t)(TypeConstraint(expr.t, tf.t)) (tf, csE ++ csE1) - case expr@DepApp(f, x) => + case expr@DepApp(kind, f, x) => val (tf, csF) = constrainTypes(exprEnv)(f) val exprT = genType(expr) - val c = DepConstraint(tf.t, x, exprT) - (DepApp(tf, x)(exprT), csF :+ c) + val c = DepConstraint(kind, tf.t, x, exprT) + (DepApp(kind, tf, x)(exprT), csF :+ c) case TypeAnnotation(e, t) => val (te, csE) = constrainTypes(exprEnv)(e) diff --git a/src/main/scala/rise/core/DSL/package.scala b/src/main/scala/rise/core/DSL/package.scala index b629b4c84..de272e884 100644 --- a/src/main/scala/rise/core/DSL/package.scala +++ b/src/main/scala/rise/core/DSL/package.scala @@ -19,13 +19,12 @@ package object DSL { x >>= (x => e >>= (e => toBeTyped(Lambda(x, e)(TypePlaceholder)))) def app(f: ToBeTyped[Expr], e: ToBeTyped[Expr]): ToBeTyped[App] = f >>= (f => e >>= (e => toBeTyped(App(f, e)(TypePlaceholder)))) - def depLambda[K <: Kind: KindName]( - x: K#I with Kind.Explicitness, - e: ToBeTyped[Expr] - ): ToBeTyped[DepLambda[K]] = - e >>= (e => toBeTyped(DepLambda[K](x, e)(TypePlaceholder))) - def depApp[K <: Kind](f: ToBeTyped[Expr], x: K#T): ToBeTyped[DepApp[K]] = - f >>= (f => toBeTyped(DepApp[K](f, x)(TypePlaceholder))) + def depLambda[T, I <: Kind.Identifier](kind: Kind[T, I], + x: I with Kind.Explicitness, + e: ToBeTyped[Expr]): ToBeTyped[DepLambda[T, I]] = + e >>= (e => toBeTyped(DepLambda(kind, x, e)(TypePlaceholder))) + def depApp[T](kind: Kind[T, _ <: Kind.Identifier], f: ToBeTyped[Expr], x: T): ToBeTyped[DepApp[T]] = + f >>= (f => toBeTyped(DepApp(kind, f, x)(TypePlaceholder))) def literal(d: semantics.Data): ToBeTyped[Literal] = toBeTyped(Literal(d)) def store(cont: ToBeTyped[Expr] => ToBeTyped[Expr]): ToBeTyped[Expr] = @@ -109,16 +108,16 @@ package object DSL { e5: ToBeTyped[Expr]): ToBeTyped[App] = f(e1)(e2)(e3)(e4)(e5) - def apply(n: Nat): ToBeTyped[DepApp[NatKind]] = - depApp[NatKind](f, n) - def apply(dt: DataType): ToBeTyped[DepApp[DataKind]] = - depApp[DataKind](f, dt) - def apply(a: AddressSpace): ToBeTyped[DepApp[AddressSpaceKind]] = - depApp[AddressSpaceKind](f, a) - def apply(n2n: NatToNat): ToBeTyped[DepApp[NatToNatKind]] = - depApp[NatToNatKind](f, n2n) - def apply(n2d: NatToData): ToBeTyped[DepApp[NatToDataKind]] = - depApp[NatToDataKind](f, n2d) + def apply(n: Nat): ToBeTyped[DepApp[Nat]] = + depApp(NatKind, f, n) + def apply(dt: DataType): ToBeTyped[DepApp[DataType]] = + depApp(DataKind, f, dt) + def apply(a: AddressSpace): ToBeTyped[DepApp[AddressSpace]] = + depApp(AddressSpaceKind, f, a) + def apply(n2n: NatToNat): ToBeTyped[DepApp[NatToNat]] = + depApp(NatToNatKind, f, n2n) + def apply(n2d: NatToData): ToBeTyped[DepApp[NatToData]] = + depApp(NatToDataKind, f, n2d) } implicit class FunPipe(e: ToBeTyped[Expr]) { @@ -405,70 +404,70 @@ package object DSL { object depFun { def apply(r: arithexpr.arithmetic.Range, w: NatFunction1Wrapper[ToBeTyped[Expr]] - ): ToBeTyped[DepLambda[NatKind]] = { + ): ToBeTyped[DepLambda[Nat, NatIdentifier]] = { val n = NatIdentifier(freshName("n"), r, isExplicit = true) - depLambda[NatKind](n, w.f(n)) + depLambda(NatKind, n, w.f(n)) } def apply(w: NatFunction1Wrapper[ToBeTyped[Expr]] - ): ToBeTyped[DepLambda[NatKind]] = { + ): ToBeTyped[DepLambda[Nat, NatIdentifier]] = { val r = arithexpr.arithmetic.RangeAdd(0, arithexpr.arithmetic.PosInf, 1) val n = NatIdentifier(freshName("n"), r, isExplicit = true) - depLambda[NatKind](n, w.f(n)) + depLambda(NatKind, n, w.f(n)) } def apply(w: NatFunction2Wrapper[ToBeTyped[Expr]] - ): ToBeTyped[DepLambda[NatKind]] = { + ): ToBeTyped[DepLambda[Nat, NatIdentifier]] = { val r = arithexpr.arithmetic.RangeAdd(0, arithexpr.arithmetic.PosInf, 1) val n1 = NatIdentifier(freshName("n"), r, isExplicit = true) - depLambda[NatKind](n1, depFun((n2: Nat) => w.f(n1, n2))) + depLambda(NatKind, n1, depFun((n2: Nat) => w.f(n1, n2))) } def apply(w: NatFunction3Wrapper[ToBeTyped[Expr]] - ): ToBeTyped[DepLambda[NatKind]] = { + ): ToBeTyped[DepLambda[Nat, NatIdentifier]] = { val r = arithexpr.arithmetic.RangeAdd(0, arithexpr.arithmetic.PosInf, 1) val n1 = NatIdentifier(freshName("n"), r, isExplicit = true) - depLambda[NatKind](n1, depFun((n2: Nat, n3: Nat) => w.f(n1, n2, n3))) + depLambda(NatKind, n1, depFun((n2: Nat, n3: Nat) => w.f(n1, n2, n3))) } def apply(w: NatFunction4Wrapper[ToBeTyped[Expr]] - ): ToBeTyped[DepLambda[NatKind]] = { + ): ToBeTyped[DepLambda[Nat, NatIdentifier]] = { val r = arithexpr.arithmetic.RangeAdd(0, arithexpr.arithmetic.PosInf, 1) val n1 = NatIdentifier(freshName("n"), r, isExplicit = true) - depLambda[NatKind](n1, depFun((n2: Nat, n3: Nat, n4: Nat) => + depLambda(NatKind, n1, depFun((n2: Nat, n3: Nat, n4: Nat) => w.f(n1, n2, n3, n4))) } def apply(w: NatFunction5Wrapper[ToBeTyped[Expr]] - ): ToBeTyped[DepLambda[NatKind]] = { + ): ToBeTyped[DepLambda[Nat, NatIdentifier]] = { val r = arithexpr.arithmetic.RangeAdd(0, arithexpr.arithmetic.PosInf, 1) val n1 = NatIdentifier(freshName("n"), r, isExplicit = true) - depLambda[NatKind](n1, depFun((n2: Nat, n3: Nat, n4: Nat, n5: Nat) => + depLambda(NatKind, n1, depFun((n2: Nat, n3: Nat, n4: Nat, n5: Nat) => w.f(n1, n2, n3, n4, n5))) } def apply(w: DataTypeFunctionWrapper[ToBeTyped[Expr]] - ): ToBeTyped[DepLambda[DataKind]] = { + ): ToBeTyped[DepLambda[DataType, DataTypeIdentifier]] = { val x = DataTypeIdentifier(freshName("dt"), isExplicit = true) - depLambda[DataKind](x, w.f(x)) + depLambda(DataKind, x, w.f(x)) } def apply(w: NatToDataFunctionWrapper[ToBeTyped[Expr]] - ): ToBeTyped[DepLambda[NatToDataKind]] = { + ): ToBeTyped[DepLambda[NatToData, NatToDataIdentifier]] = { val x = NatToDataIdentifier(freshName("n2d"), isExplicit = true) - depLambda[NatToDataKind](x, w.f(x)) + depLambda(NatToDataKind, x, w.f(x)) } def apply(w: NatToNatFunctionWrapper[ToBeTyped[Expr]] - ): ToBeTyped[DepLambda[NatToNatKind]] = { + ): ToBeTyped[DepLambda[NatToNat, NatToNatIdentifier]] = { val x = NatToNatIdentifier(freshName("n2n"), isExplicit = true) - depLambda[NatToNatKind](x, w.f(x)) + depLambda(NatToNatKind, x, w.f(x)) } def apply(w: AddressSpaceFunctionWrapper[ToBeTyped[Expr]] - ): ToBeTyped[DepLambda[AddressSpaceKind]] = { + ): ToBeTyped[DepLambda[AddressSpace, AddressSpaceIdentifier]] = { val x = AddressSpaceIdentifier(freshName("a"), isExplicit = true) - depLambda[AddressSpaceKind](x, w.f(x)) + depLambda(AddressSpaceKind, x, w.f(x)) } } diff --git a/src/main/scala/rise/core/Expr.scala b/src/main/scala/rise/core/Expr.scala index 951034084..2fd0183d7 100644 --- a/src/main/scala/rise/core/Expr.scala +++ b/src/main/scala/rise/core/Expr.scala @@ -30,19 +30,16 @@ final case class App(f: Expr, e: Expr)(override val t: Type) override def setType(t: Type): App = this.copy(f, e)(t) } -final case class DepLambda[K <: Kind: KindName]( - x: K#I with Kind.Explicitness, - e: Expr -)(override val t: Type) +final case class DepLambda[T, I <: Kind.Identifier](kind: Kind[T, I], + x: I with Kind.Explicitness, + e: Expr)(override val t: Type) extends Expr { - val kindName: String = implicitly[KindName[K]].get - override def setType(t: Type): DepLambda[K] = this.copy(x, e)(t) + val kindName: String = kind.name + override def setType(t: Type): DepLambda[T, I] = this.copy(kind, x, e)(t) } -final case class DepApp[K <: Kind](f: Expr, x: K#T)( - override val t: Type -) extends Expr { - override def setType(t: Type): DepApp[K] = this.copy(f, x)(t) +final case class DepApp[T](kind: Kind[T, _ <: Kind.Identifier], f: Expr, x: T)(override val t: Type) extends Expr { + override def setType(t: Type): DepApp[T] = this.copy(kind, f, x)(t) } final case class Literal(d: semantics.Data) extends Expr { diff --git a/src/main/scala/rise/core/IsClosedForm.scala b/src/main/scala/rise/core/IsClosedForm.scala index 462d2f6cb..fb8bf5ed5 100644 --- a/src/main/scala/rise/core/IsClosedForm.scala +++ b/src/main/scala/rise/core/IsClosedForm.scala @@ -61,7 +61,7 @@ object IsClosedForm { val fV = OrderedSet.append(OrderedSet.append(fVx)(fVe))(fVt) val fT = OrderedSet.append(OrderedSet.append(fTx)(fTe))(fTt) accumulate((fV, fT))(Lambda(x1, e1)(t1): Expr) - case DepLambda(x, b) => this.copy(boundT = boundT + x).expr(b) + case DepLambda(_, x, b) => this.copy(boundT = boundT + x).expr(b) case e => super.expr(e) } @@ -80,10 +80,10 @@ object IsClosedForm { } override def `type`[T <: Type]: T => Pair[T] = { - case d@DepFunType(x, t) => + case d@DepFunType(_, x, t) => for { p <- this.copy(boundT = boundT + x).`type`(t) } yield (p._1, d.asInstanceOf[T]) - case d@DepPairType(x, dt) => + case d@DepPairType(_, x, dt) => for { p <- this.copy(boundT = boundT + x).datatype(dt) } yield (p._1, d.asInstanceOf[T]) case t => super.`type`(t) diff --git a/src/main/scala/rise/core/dotPrinter.scala b/src/main/scala/rise/core/dotPrinter.scala index 3bb40a1e8..062b36126 100644 --- a/src/main/scala/rise/core/dotPrinter.scala +++ b/src/main/scala/rise/core/dotPrinter.scala @@ -114,7 +114,7 @@ case object dotPrinter { |${recurse(e, eID)} |$parent -> $eID ${edgeLabel("arg")};""".stripMargin - case DepLambda(x, e) if !inlineLambdaIdentifier => + case DepLambda(kind, x, e) if !inlineLambdaIdentifier => val id = getID(x) val expr = getID(e) s"""$parent ${attr(fillWhite + Label("Λ").bold.toString)} @@ -123,13 +123,13 @@ case object dotPrinter { |$id ${attr(fillWhite + Label(x.name).orange.toString)} |${recurse(e, expr)}""".stripMargin - case DepLambda(x, e) if inlineLambdaIdentifier => + case DepLambda(_, x, e) if inlineLambdaIdentifier => val expr = getID(e) s"""$parent ${attr(fillWhite + Label(s"Λ.${x.name}").toString)} |$parent -> $expr ${edgeLabel("body")}; |${recurse(e, expr)}""".stripMargin - case DepApp(f, e) if applyNodes => + case DepApp(_, f, e) if applyNodes => val fun = getID(f) val arg = getID(e) s""" @@ -139,7 +139,7 @@ case object dotPrinter { |$arg ${attr(fillWhite + Label(e.toString).toString)} |${recurse(f, fun)}""".stripMargin - case DepApp(f, e) if !applyNodes => + case DepApp(_, f, e) if !applyNodes => val eID = getID(e) s""" |${recurse(f, parent)} diff --git a/src/main/scala/rise/core/equality.scala b/src/main/scala/rise/core/equality.scala index f2aea4dd8..814d67b46 100644 --- a/src/main/scala/rise/core/equality.scala +++ b/src/main/scala/rise/core/equality.scala @@ -21,24 +21,24 @@ object equality { } trait TypeEq { - final type Eq[K <: Kind] = K#T => K#T => Boolean - def apply[K <: Kind] : Eq[K] = equiv[K](Env()) - def equiv[K <: Kind] : Env[Kind.Identifier] => Eq[K] - def hash[K <: Kind] : K#T => Int + final type Eq[T] = T => T => Boolean + def apply[T]: Eq[T] = equiv[T](Env()) + def equiv[T]: Env[Kind.Identifier] => Eq[T] + def hash[T]: T => Int } object typeErasure extends TypeEq { - override def hash[K <: Kind]: K#T => Int = _ => 0 - override def equiv[K <: Kind]: Env[Kind.Identifier] => Eq[K] = _ => _ => _ => true + override def hash[T]: T => Int = _ => 0 + override def equiv[T]: Env[Kind.Identifier] => Eq[T] = _ => _ => _ => true } object typePartialAlphaEq extends TypeEq { - override def hash[K <: Kind]: K#T => Int = _ => 0 - override def equiv[K <: Kind]: Env[Kind.Identifier] => Eq[K] = env => a => b => (a, b) match { + override def hash[T]: T => Int = _ => 0 + override def equiv[T]: Env[Kind.Identifier] => Eq[T] = env => a => b => (a, b) match { case (a : Type, b : Type) => (a, b) match { case (TypePlaceholder, _) => true case (_, TypePlaceholder) => true - case _ => typeAlphaEq.equiv[TypeKind](env)(a)(b) + case _ => typeAlphaEq.equiv[Type](env)(a)(b) } case _ => typeAlphaEq.equiv(env)(a)(b) } @@ -48,7 +48,7 @@ object equality { /** Alpha equivalence on types. * Kind equality is checked on dependent functions and pairs. */ - override def equiv[K <: Kind]: Env[Kind.Identifier] => Eq[K] = env => a => b => { + override def equiv[T]: Env[Kind.Identifier] => Eq[T] = env => a => b => { val and = PatternMatching.matchWithDefault(b, false) a match { case a : Nat => and {case b : Nat => equivNat(env)(a)(b)} @@ -56,7 +56,7 @@ object equality { case ia: Kind.Identifier => and { case ib: Kind.Identifier => env.check(ia, ib) } case a: AddressSpace => and { case b: AddressSpace => (a : AddressSpace) == (b : AddressSpace) } case NatToNatLambda(na, ba) => and { case NatToNatLambda(nb, bb) => equivNat(env.add(na, nb))(ba)(bb) } - case NatToDataLambda(na, ba) => and { case NatToDataLambda(nb, bb) => equiv[DataKind](env.add(na, nb))(ba)(bb) } + case NatToDataLambda(na, ba) => and { case NatToDataLambda(nb, bb) => equiv[DataType](env.add(na, nb))(ba)(bb) } case NatCollectionFromArray(a) => and { case NatCollectionFromArray(b) => a == b } // FIXME: should use exprEq } } @@ -76,7 +76,7 @@ object equality { // Base cases -> identifier lookup in nat expressions case IndexType(sa) => and { case IndexType(sb) => equivNat(env)(sa)(sb) } case DepArrayType(sa, da) => and { case DepArrayType(sb, db) => - equivNat(env)(sa)(sb) && equiv[NatToDataKind](env)(da)(db) } + equivNat(env)(sa)(sb) && equiv[NatToData](env)(da)(db) } // Should we move this into its own equality check? case NatToDataApply(fa, na) => and { case NatToDataApply(fb, nb) => @@ -98,11 +98,11 @@ object equality { } // Recursive cases -> binding tracking - case DepFunType(xa, ta) => and { case DepFunType(xb, tb) => - xa.getClass == xb.getClass && equivType(env.add(xa, xb))(ta)(tb) + case DepFunType(ka, xa, ta) => and { case DepFunType(kb, xb, tb) => + ka == kb && equivType(env.add(xa, xb))(ta)(tb) } - case DepPairType(xa, ta) => and { case DepPairType(xb, tb) => - xa.getClass == xb.getClass && equivType(env.add(xa, xb))(ta)(tb) + case DepPairType(ka, xa, ta) => and { case DepPairType(kb, xb, tb) => + ka == kb && equivType(env.add(xa, xb))(ta)(tb) } } } @@ -110,11 +110,11 @@ object equality { /** Alpha renaming respecting hash function on types. * All identifiers are considered equal and therefore ignored. */ - override def hash[K <: Kind]: K#T => Int = { + override def hash[T]: T => Int = { case t: Type => hashType(t) case _: Kind.Identifier => 7 case a: AddressSpace => a.hashCode() - case NatToNatLambda(na, ba) => hash[NatKind](ba) + case NatToNatLambda(na, ba) => hash[Nat](ba) case NatToDataLambda(na, ba) => hashType(ba) case NatCollectionFromArray(a) => 17 } @@ -123,17 +123,17 @@ object equality { case TypeIdentifier(_) => 7 case DataTypeIdentifier(_, _) => 11 case FunType(inT, outT) => 13 * hashType(inT) + 17 * hashType(outT) - case DepFunType(_, t) => 19 * hashType(t) + case DepFunType(_, _, t) => 19 * hashType(t) case st: ScalarType => 29 * st.hashCode() case NatType => 23 - case VectorType(size, elemType) => 31 * hash[NatKind](size) + 37 * hashType(elemType) - case IndexType(size) => 41 * hash[NatKind](size) + case VectorType(size, elemType) => 31 * hash[Nat](size) + 37 * hashType(elemType) + case IndexType(size) => 41 * hash[Nat](size) case PairType(dt1, dt2) => 43 * hashType(dt1) + 47 * hashType(dt2) - case DepPairType(x, t) => 53 * hashType(t) - case NatToDataApply(f, n) => 59 * hash[NatToDataKind](f) + 61 * hash[NatKind](n) - case ArrayType(size, elemType) => 67 * hash[NatKind](size) + 71 * hashType(elemType) - case DepArrayType(size, fdt) => 73 * hash[NatKind](size) + 79 * hash[NatToDataKind](fdt) - case FragmentType(r, c, d, dt, fk, ml) => 83 * hash[NatKind](r) + 89 * hash[NatKind](c) + 97 * hash[NatKind](d) + + case DepPairType(_, _, t) => 53 * hashType(t) + case NatToDataApply(f, n) => 59 * hash[NatToData](f) + 61 * hash[Nat](n) + case ArrayType(size, elemType) => 67 * hash[Nat](size) + 71 * hashType(elemType) + case DepArrayType(size, fdt) => 73 * hash[Nat](size) + 79 * hash[NatToData](fdt) + case FragmentType(r, c, d, dt, fk, ml) => 83 * hash[Nat](r) + 89 * hash[Nat](c) + 97 * hash[Nat](d) + 101 * hashType(dt) + 103*fk.hashCode() } } @@ -153,23 +153,23 @@ object equality { */ override val equiv: Env[Kind.Identifier] => Env[String] => Eq = typeEnv => exprEnv => a => b => { val and = PatternMatching.matchWithDefault(b, false) // Make the match exhaustive - typeEq.equiv[TypeKind](typeEnv)(a.t)(b.t) && (a match { + typeEq.equiv[Type](typeEnv)(a.t)(b.t) && (a match { case Identifier(na) => and { case Identifier(nb) => exprEnv.check(na, nb)} case Literal(da) => and { case Literal(db) => equivData(typeEnv)(da)(db) } case App(fa, ea) => and { case App(fb, eb) => equiv(typeEnv)(exprEnv)(fa)(fb) && equiv(typeEnv)(exprEnv)(ea)(eb) } - case DepApp(fa, xa) => and { case DepApp(fb, xb) => + case DepApp(_, fa, xa) => and { case DepApp(_, fb, xb) => typeEq.equiv(typeEnv)(xa)(xb) && equiv(typeEnv)(exprEnv)(fa)(fb)} case Lambda(xa, ta) => and { case Lambda(xb, tb) => - typeEq.equiv[TypeKind](typeEnv)(xa.t)(xb.t) && equiv(typeEnv)(exprEnv.add(xa.name, xb.name))(ta)(tb) } - case DepLambda(xa, ea) => and { case DepLambda(xb, eb) => - xa.getClass == xb.getClass && equiv(typeEnv.add(xa, xb))(exprEnv)(ea)(eb) } + typeEq.equiv[Type](typeEnv)(xa.t)(xb.t) && equiv(typeEnv)(exprEnv.add(xa.name, xb.name))(ta)(tb) } + case DepLambda(ka, xa, ea) => and { case DepLambda(kb, xb, eb) => + ka == kb && equiv(typeEnv.add(xa, xb))(exprEnv)(ea)(eb) } case Opaque(e1, t1) => and { case Opaque(e2, t2) => - equiv(typeEnv)(exprEnv)(e1)(e2) && typeEq.equiv[TypeKind](typeEnv)(t1)(t2) } + equiv(typeEnv)(exprEnv)(e1)(e2) && typeEq.equiv[Type](typeEnv)(t1)(t2) } case TypeAnnotation(e1, t1) => and { case TypeAnnotation(e2, t2) => - equiv(typeEnv)(exprEnv)(e1)(e2) && typeEq.equiv[TypeKind](typeEnv)(t1)(t2) } + equiv(typeEnv)(exprEnv)(e1)(e2) && typeEq.equiv[Type](typeEnv)(t1)(t2) } case TypeAssertion(e1, t1) => and { case TypeAssertion(e2, t2) => - equiv(typeEnv)(exprEnv)(e1)(e2) && typeEq.equiv[TypeKind](typeEnv)(t1)(t2) } + equiv(typeEnv)(exprEnv)(e1)(e2) && typeEq.equiv[Type](typeEnv)(t1)(t2) } // TODO: TopLevel case a: Primitive => and { case b: Primitive => a.primEq(b) } }) @@ -196,20 +196,20 @@ object equality { * All identifiers are considered equal and therefore ignored. */ override val hash: Expr => Int = { - case i: Identifier => 5 + typeEq.hash[TypeKind](i.t) - case Lambda(x, e) => 7 * hash(e) + typeEq.hash[TypeKind](x.t) + typeEq.hash[TypeKind](e.t) - case App(f, e) => 11 * hash(f) + 13 * hash(e) + typeEq.hash[TypeKind](f.t) + typeEq.hash[TypeKind](e.t) - case DepLambda(x, e) => 17 * hash(e) + typeEq.hash[TypeKind](e.t) - case DepApp(f, _) => 19 * hash(f) + typeEq.hash[TypeKind](f.t) + case i: Identifier => 5 + typeEq.hash[Type](i.t) + case Lambda(x, e) => 7 * hash(e) + typeEq.hash[Type](x.t) + typeEq.hash[Type](e.t) + case App(f, e) => 11 * hash(f) + 13 * hash(e) + typeEq.hash[Type](f.t) + typeEq.hash[Type](e.t) + case DepLambda(_, _, e) => 17 * hash(e) + typeEq.hash[Type](e.t) + case DepApp(_, f, _) => 19 * hash(f) + typeEq.hash[Type](f.t) case l@Literal(_: ScalarData | _: VectorData) => l.d.hashCode() case Literal(_: NatData) => 91 case Literal(_: IndexData) => 93 case Literal(_: ArrayData) => 95 case Literal(_: PairData) => 97 - case Opaque(e, t) => 101*hash(e) + 103*typeEq.hash[TypeKind](t) - case TypeAnnotation(e, t) => 107*hash(e) + 109*typeEq.hash[TypeKind](t) - case TypeAssertion(e, t) => 113*hash(e) + 127*typeEq.hash[TypeKind](t) - case p: Primitive => 131*p.name.hashCode() + 137*typeEq.hash[TypeKind](p.t) + case Opaque(e, t) => 101*hash(e) + 103*typeEq.hash[Type](t) + case TypeAnnotation(e, t) => 107*hash(e) + 109*typeEq.hash[Type](t) + case TypeAssertion(e, t) => 113*hash(e) + 127*typeEq.hash[Type](t) + case p: Primitive => 131*p.name.hashCode() + 137*typeEq.hash[Type](p.t) } } } \ No newline at end of file diff --git a/src/main/scala/rise/core/lifting.scala b/src/main/scala/rise/core/lifting.scala index 252ceac6e..76f853c57 100644 --- a/src/main/scala/rise/core/lifting.scala +++ b/src/main/scala/rise/core/lifting.scala @@ -37,50 +37,28 @@ object lifting { case Lambda(x, body) => Reducing((e: Expr) => substitute.exprInExpr(e, `for` = x, in = body)) case App(f, e) => chain(liftFunExpr(f).map(lf => lf(e))) - case DepApp(f, x) => - x match { - case t: DataType => - chain(liftDepFunExpr[DataKind](f).map(lf => lf(t))) - case n: Nat => chain(liftDepFunExpr[NatKind](f).map(lf => lf(n))) - case a: AddressSpace => - chain(liftDepFunExpr[AddressSpaceKind](f).map(lf => lf(a))) - case n2n: NatToNat => - chain(liftDepFunExpr[NatToNatKind](f).map(lf => lf(n2n))) - case n2d: NatToData => - chain(liftDepFunExpr[NatToDataKind](f).map(lf => lf(n2d))) - } + case DepApp(kind, f, x) => chain(liftDepFunExpr(kind, f).map(lf => lf(x))) case _ => chain(Expanding(p)) } } - def liftDepFunExpr[K <: Kind](p: Expr): Result[K#T => Expr] = { - def chain(r: Result[Expr]): Result[K#T => Expr] = - r.bind(liftDepFunExpr[K], f => Expanding((x: K#T) => depApp[K](f, x))) + def liftDepFunExpr[T](kind: Kind[T, _ <: Kind.Identifier], p: Expr): Result[T => Expr] = { + def chain(r: Result[Expr]): Result[T => Expr] = + r.bind(liftDepFunExpr[T](kind, _), f => Expanding((x: T) => depApp(kind, f, x))) p match { - case DepLambda(x, e) => - Reducing((a: K#T) => substitute.kindInExpr(a, `for` = x, in = e)) + case DepLambda(kind, x, e) => + Reducing((a: T) => substitute.kindInExpr(kind, a, `for` = x, in = e)) case App(f, e) => chain(liftFunExpr(f).map(lf => lf(e))) - case DepApp(f, x) => - x match { - case t: DataType => - chain(liftDepFunExpr[DataKind](f).map(lf => lf(t))) - case n: Nat => chain(liftDepFunExpr[NatKind](f).map(lf => lf(n))) - case a: AddressSpace => - chain(liftDepFunExpr[AddressSpaceKind](f).map(lf => lf(a))) - case n2n: NatToNat => - chain(liftDepFunExpr[NatToNatKind](f).map(lf => lf(n2n))) - case n2d: NatToData => - chain(liftDepFunExpr[NatToDataKind](f).map(lf => lf(n2d))) - } + case DepApp(kind, f, x) => chain(liftDepFunExpr(kind, f).map(lf => lf(x))) case _ => chain(Expanding(p)) } } - def liftDependentFunctionType[K <: Kind](ty: Type): K#T => Type = { + def liftDependentFunctionType[T, I <: Kind.Identifier](kind: Kind[T, I], ty: Type): T => Type = { ty match { - case DepFunType(x, t) => - (a: K#T) => substitute.kindInType(a, `for` = x, in = t) + case DepFunType(_, x, t) => + (a: T) => substitute.kindInType(kind, a, `for` = x, in = t) case _ => throw new Exception(s"did not expect $ty") } } diff --git a/src/main/scala/rise/core/makeClosed.scala b/src/main/scala/rise/core/makeClosed.scala index c5d20e199..2f6ab1ffb 100644 --- a/src/main/scala/rise/core/makeClosed.scala +++ b/src/main/scala/rise/core/makeClosed.scala @@ -15,23 +15,23 @@ object makeClosed { case (expr, (ts, ns, as, n2ds)) => ftv match { case i: TypeIdentifier => val dt = DataTypeIdentifier(freshName("dt"), isExplicit = true) - (DepLambda[DataKind](dt, expr)(DepFunType[DataKind, Type](dt, expr.t)), + (DepLambda(DataKind, dt, expr)(DepFunType(DataKind, dt, expr.t)), (ts ++ Map(i -> dt), ns, as , n2ds)) case i: DataTypeIdentifier => val dt = i.asExplicit - (DepLambda[DataKind](dt, expr)(DepFunType[DataKind, Type](dt, expr.t)), + (DepLambda(DataKind, dt, expr)(DepFunType(DataKind, dt, expr.t)), (ts ++ Map(i -> dt), ns, as , n2ds)) case i: NatIdentifier => val n = i.asExplicit - (DepLambda[NatKind](n, expr)(DepFunType[NatKind, Type](n, expr.t)), + (DepLambda(NatKind, n, expr)(DepFunType(NatKind, n, expr.t)), (ts, ns ++ Map(i -> n), as, n2ds)) case i: AddressSpaceIdentifier => val a = i.asExplicit - (DepLambda[AddressSpaceKind](a, expr)(DepFunType[AddressSpaceKind, Type](a, expr.t)), + (DepLambda(AddressSpaceKind, a, expr)(DepFunType(AddressSpaceKind, a, expr.t)), (ts, ns, as ++ Map(i -> a), n2ds)) case i: NatToDataIdentifier => val n2d = i.asExplicit - (DepLambda[NatToDataKind](n2d, expr)(DepFunType[NatToDataKind, Type](n2d, expr.t)), + (DepLambda(NatToDataKind, n2d, expr)(DepFunType(NatToDataKind, n2d, expr.t)), (ts, ns, as, n2ds ++ Map(i -> n2d))) case i => throw TypeException(s"${i.getClass} is not supported yet") } diff --git a/src/main/scala/rise/core/package.scala b/src/main/scala/rise/core/package.scala index 43cc436dc..3918ca254 100644 --- a/src/main/scala/rise/core/package.scala +++ b/src/main/scala/rise/core/package.scala @@ -23,7 +23,7 @@ package object core { s"Lambda(${toEvaluableString(x)}, ${toEvaluableString(e)})" case App(f, e) => s"Apply(${toEvaluableString(f)}, ${toEvaluableString(e)})" - case DepLambda(x, e) => + case DepLambda(_, x, e) => x match { case n: NatIdentifier => s"""DepLambda[NatKind](NatIdentifier("id$n"), @@ -31,7 +31,7 @@ package object core { case dt: DataTypeIdentifier => s"""DepLambda[DataKind]("id$dt", ${toEvaluableString(e)})""" } - case DepApp(f, x) => + case DepApp(_, f, x) => x match { case n: Nat => s"DepApply[NatKind](${toEvaluableString(f)}, $n)" case dt: DataType => diff --git a/src/main/scala/rise/core/primitives/foreignFunction.scala b/src/main/scala/rise/core/primitives/foreignFunction.scala index f6fe871b3..7dcad9145 100644 --- a/src/main/scala/rise/core/primitives/foreignFunction.scala +++ b/src/main/scala/rise/core/primitives/foreignFunction.scala @@ -27,7 +27,7 @@ object foreignFunction { lhsT ->: rhsT }) }: Type)({ case (id, t) => - DepFunType[DataKind, Type](id, t) + DepFunType(DataKind, id, t) }) } override def primEq(obj: rise.core.Primitive): Boolean = obj match { diff --git a/src/main/scala/rise/core/showRise.scala b/src/main/scala/rise/core/showRise.scala index ee87069dc..be9faac92 100644 --- a/src/main/scala/rise/core/showRise.scala +++ b/src/main/scala/rise/core/showRise.scala @@ -243,7 +243,7 @@ class ShowRiseCompact { (false, newSize, fr >@> (fd => er >@> (ed => fd :+> ed))) } - case dl @ DepLambda(x, e) => + case dl @ DepLambda(_, x, e) => val xs = s"${x.name}:${dl.kindName}" val (eInline, eSize, er) = drawAST(e) val newSize = eSize + 1 @@ -256,10 +256,10 @@ class ShowRiseCompact { (false, newSize, er >@> (ed => block(s"Λ$xs", ed))) } - case DepApp(f, x) => + case DepApp(kind, f, x) => val (fInline, fSize, fr) = f match { - case _: DepLambda[_] => drawAST(f, wrapped = true) - case _ => drawAST(f) + case _: DepLambda[_, _] => drawAST(f, wrapped = true) + case _ => drawAST(f) } val xs = x.toString val newSize = fSize + 1 diff --git a/src/main/scala/rise/core/showScala.scala b/src/main/scala/rise/core/showScala.scala index 5f4ac69d8..7eca459b4 100644 --- a/src/main/scala/rise/core/showScala.scala +++ b/src/main/scala/rise/core/showScala.scala @@ -3,7 +3,7 @@ package rise.core import rise.core.types._ object showScala { - private def kindIdent[K <: Kind](x: K#I): String = { + private def kindIdent[I <: Kind.Identifier](x: I): String = { x match { case n: NatIdentifier => s"""NatIdentifier("${n.name}", ${n.range}, ${n.isExplicit})""" @@ -48,7 +48,7 @@ object showScala { case TypeIdentifier(n) => s"""TypeIdentifier("$n")""" case FunType(inT, outT) => s"FunType(${`type`(inT)}, ${`type`(outT)})" - case DepFunType(x, t) => + case DepFunType(_, x, t) => s"DepFunType(${kindIdent(x)}, ${`type`(t)})" case DataTypeIdentifier(n, isE) => s"""DataTypeIdentifier("$n", $isE)""" case ArrayType(n, e) => @@ -74,12 +74,12 @@ object showScala { case Literal(d) => s"Literal(${data(d)})" case App(f, a) => s"App(${expr(f)}, ${expr(a)})(${`type`(e.t)})" case Lambda(x, b) => s"Lambda(${expr(x)}, ${expr(b)})(${`type`(e.t)})" - case DepApp(f, v: Nat) => - s"DepApp[NatKind](${expr(f)}, $v)(${`type`(e.t)})" - case DepApp(f, v: AddressSpace) => - s"DepApp[AddressSpaceKind](${expr(f)}, $v)(${`type`(e.t)})" - case DepApp(_, _) => ??? - case DepLambda(x, b) => s"DepLambda(${kindIdent(x)}, ${expr(b)})(${`type`(e.t)})" + case DepApp(NatKind, f, v: Nat) => + s"DepApp(NatKind, ${expr(f)}, $v)(${`type`(e.t)})" + case DepApp(AddressSpaceKind, f, v: AddressSpace) => + s"DepApp(AddressSpaceKind, ${expr(f)}, $v)(${`type`(e.t)})" + case DepApp(_, _, _) => ??? + case DepLambda(_, x, b) => s"DepLambda(${kindIdent(x)}, ${expr(b)})(${`type`(e.t)})" } } } diff --git a/src/main/scala/rise/core/substitute.scala b/src/main/scala/rise/core/substitute.scala index a7c3e5b8a..e942c82f8 100644 --- a/src/main/scala/rise/core/substitute.scala +++ b/src/main/scala/rise/core/substitute.scala @@ -9,18 +9,18 @@ object substitute { // substitute in Expr - def kindInExpr[K <: Kind](x: K#T, `for`: K#I, in: Expr): Expr = - (x, `for`) match { - case (dt: DataType, forDt: DataTypeIdentifier) => + def kindInExpr[T, I <: Kind.Identifier](kind: Kind[T, I], x: T, `for`: I, in: Expr): Expr = + (kind, x, `for`) match { + case (DataKind, dt: DataType, forDt: DataTypeIdentifier) => dataTypeInExpr(dt, forDt, in) - case (n: Nat, forN: NatIdentifier) => natInExpr(n, forN, in) - case (a: AddressSpace, forA: AddressSpaceIdentifier) => + case (NatKind, n: Nat, forN: NatIdentifier) => natInExpr(n, forN, in) + case (AddressSpaceKind, a: AddressSpace, forA: AddressSpaceIdentifier) => addressSpaceInExpr(a, forA, in) - case (n2n: NatToNat, forN2N: NatToNatIdentifier) => + case (NatToNatKind, n2n: NatToNat, forN2N: NatToNatIdentifier) => n2nInExpr(n2n, forN2N, in) - case (n2d: NatToData, forN2D: NatToDataIdentifier) => + case (NatToDataKind, n2d: NatToData, forN2D: NatToDataIdentifier) => n2dInExpr(n2d, forN2D, in) - case (_, _) => ??? + case (_, _, _) => ??? } def exprInExpr(expression : Expr, `for`: Expr, in: Expr): Expr = { @@ -53,8 +53,8 @@ object substitute { case i: Identifier => Set(i) case Lambda(x, e) => FV(e) - x case App(f, e) => FV(f) ++ FV(e) - case DepLambda(_, e) => FV(e) - case DepApp(f, _) => FV(f) + case DepLambda(_, _, e) => FV(e) + case DepApp(_, f, _) => FV(f) case Literal(_) => Set() case TypeAnnotation(e, _) => FV(e) case TypeAssertion(e, _) => FV(e) @@ -100,20 +100,16 @@ object substitute { // substitute in Type - def kindInType[K <: Kind, T <: Type](x: K#T, `for`: K#I, in: T): T = - (x, `for`) match { - case (dt: DataType, forDt: DataTypeIdentifier) => - typeInType(dt, forDt, in) - case (n: Nat, forN: NatIdentifier) => - natInType(n, forN, in) - case (a: AddressSpace, forA: AddressSpaceIdentifier) => - addressSpaceInType(a, forA, in) - case (n2n: NatToNat, forN2N: NatToNatIdentifier) => - n2nInType(n2n, forN2N, in) - case (n2d: NatToData, forN2D: NatToDataIdentifier) => - n2dInType(n2d, forN2D, in) - case (_, _) => ??? + def kindInType[T, I <: Kind.Identifier, U <: Type](kind: Kind[T, I], x: T, `for`: I, in: U): U = { + (kind, x, `for`) match { + case (DataKind, dt: DataType, forDt: DataTypeIdentifier) => typeInType(dt, forDt, in) + case (NatKind, n: Nat, forN: NatIdentifier) => natInType(n, forN, in) + case (AddressSpaceKind, a: AddressSpace, forA: AddressSpaceIdentifier) => addressSpaceInType(a, forA, in) + case (NatToNatKind, n2n: NatToNat, forN2N: NatToNatIdentifier) => n2nInType(n2n, forN2N, in) + case (NatToDataKind, n2d: NatToData, forN2D: NatToDataIdentifier) => n2dInType(n2d, forN2D, in) + case (_, _, _) => ??? } + } def typeInType[B <: Type](ty: Type, `for`: Type, in: B): B = { object Visitor extends PureTraversal { diff --git a/src/main/scala/rise/core/traverse.scala b/src/main/scala/rise/core/traverse.scala index 6f224bb60..9141ca277 100644 --- a/src/main/scala/rise/core/traverse.scala +++ b/src/main/scala/rise/core/traverse.scala @@ -65,9 +65,9 @@ object traverse { case PairType(p1, p2) => for {p11 <- `type`(p1); p21 <- `type`(p2)} yield PairType(p11, p21) - case pair@DepPairType(x, d) => + case DepPairType(kind, x, d) => for {x1 <- typeIdentifierDispatch(Binding)(x); d1 <- dataTypeDispatch(Reference)(d)} - yield DepPairType(x1, d1)(pair.kindName) + yield DepPairType(kind, x1, d1) case IndexType(n) => for {n1 <- natDispatch(Reference)(n)} yield IndexType(n1) @@ -129,23 +129,9 @@ object traverse { case FunType(a, b) => for {a1 <- `type`(a); b1 <- `type`(b)} yield FunType(a1, b1) - case DepFunType(x, t) => x match { - case n: NatIdentifier => - for { n1 <- typeIdentifierDispatch(Binding)(n); t1 <- `type`(t)} - yield DepFunType[NatKind, Type](n1, t1) - case dt: DataTypeIdentifier => - for { dt1 <- typeIdentifierDispatch(Binding)(dt); t1 <- `type`(t)} - yield DepFunType[DataKind, Type](dt1, t1) - case a: AddressSpaceIdentifier => - for { a1 <- typeIdentifierDispatch(Binding)(a); t1 <- `type`(t)} - yield DepFunType[AddressSpaceKind, Type](a1, t1) - case n2n: NatToNatIdentifier => - for { n2n1 <- typeIdentifierDispatch(Binding)(n2n); t1 <- `type`(t)} - yield DepFunType[NatToNatKind, Type](n2n1, t1) - case n2d: NatToDataIdentifier => - for { n2d1 <- typeIdentifierDispatch(Binding)(n2d); t1 <- `type`(t)} - yield DepFunType[NatToDataKind, Type](n2d1, t1) - } + case DepFunType(kind, x, t) => + for { n1 <- typeIdentifierDispatch(Binding)(x); t1 <- `type`(t)} + yield DepFunType(kind, n1, t1) }).asInstanceOf[M[T]] def expr : Expr => M[Expr] = { @@ -162,31 +148,26 @@ object traverse { e1 <- expr(e) t1 <- `type`(a.t) } yield App(f1, e1)(t1) - case dl@DepLambda(x,e) => x match { - case n: NatIdentifier => - for {n1 <- typeIdentifierDispatch(Binding)(n); e1 <- expr(e); t1 <- `type`(dl.t)} - yield DepLambda[NatKind](n1, e1)(t1) - case dt: DataTypeIdentifier => - for {dt1 <- typeIdentifierDispatch(Binding)(dt); e1 <- expr(e); t1 <- `type`(dl.t)} - yield DepLambda[DataKind](dt1, e1)(t1) - } - case da@DepApp(f, x) => x match { - case n: Nat => - for {f1 <- expr(f); n1 <- natDispatch(Reference)(n); t1 <- `type`(da.t)} - yield DepApp[NatKind](f1, n1)(t1) - case dt: DataType => - for {f1 <- expr(f); dt1 <- `type`(dt); t1 <- `type`(da.t)} - yield DepApp[DataKind](f1, dt1)(t1) - case a: AddressSpace => - for {f1 <- expr(f); a1 <- addressSpace(a); t1 <- `type`(da.t)} - yield DepApp[AddressSpaceKind](f1, a1)(t1) - case n2n: NatToNat => - for {f1 <- expr(f); n2n1 <- natToNat(n2n); t1 <- `type`(da.t)} - yield DepApp[NatToNatKind](f1, n2n1)(t1) - case n2d: NatToData => - for {f1 <- expr(f); n2d1 <- natToData(n2d); t1 <- `type`(da.t)} - yield DepApp[NatToDataKind](f1, n2d1)(t1) - } + case dl@DepLambda(kind, x,e) => + for {n1 <- typeIdentifierDispatch(Binding)(x); e1 <- expr(e); t1 <- `type`(dl.t)} + yield DepLambda(kind, n1, e1)(t1) + case da@DepApp(NatKind, f, x: Nat) => + for {f1 <- expr(f); n1 <- natDispatch(Reference)(x); t1 <- `type`(da.t)} + yield DepApp(NatKind, f1, n1)(t1) + case da@DepApp(DataKind, f, x: DataType) => + for {f1 <- expr(f); dt1 <- `type`(x); t1 <- `type`(da.t)} + yield DepApp(DataKind, f1, dt1)(t1) + case da@DepApp(AddressSpaceKind, f, x: AddressSpace) => + for {f1 <- expr(f); a1 <- addressSpace(x); t1 <- `type`(da.t)} + yield DepApp(AddressSpaceKind, f1, a1)(t1) + case da@DepApp(NatToNatKind, f, x: NatToNat) => + for {f1 <- expr(f); n2n1 <- natToNat(x); t1 <- `type`(da.t)} + yield DepApp(NatToNatKind, f1, n2n1)(t1) + case da@DepApp(NatToDataKind, f, x: NatToData) => + for {f1 <- expr(f); n2d1 <- natToData(x); t1 <- `type`(da.t)} + yield DepApp(NatToDataKind, f1, n2d1)(t1) + case DepApp(_, _, _) => + ??? case Literal(d) => for { d1 <- data(d) } yield Literal(d1) diff --git a/src/main/scala/rise/core/typedLifting.scala b/src/main/scala/rise/core/typedLifting.scala index 4bcd0ca80..ddc9b983d 100644 --- a/src/main/scala/rise/core/typedLifting.scala +++ b/src/main/scala/rise/core/typedLifting.scala @@ -22,51 +22,29 @@ object typedLifting { case Lambda(x, body) => Reducing((e: Expr) => substitute.exprInExpr(e, `for` = x, in = body)) case App(f, e) => chain(liftFunExpr(f).map(lf => lf(e))) - case DepApp(f, x) => - x match { - case t: DataType => - chain(liftDepFunExpr[DataKind](f).map(lf => lf(t))) - case n: Nat => chain(liftDepFunExpr[NatKind](f).map(lf => lf(n))) - case a: AddressSpace => - chain(liftDepFunExpr[AddressSpaceKind](f).map(lf => lf(a))) - case n2n: NatToNat => - chain(liftDepFunExpr[NatToNatKind](f).map(lf => lf(n2n))) - case n2d: NatToData => - chain(liftDepFunExpr[NatToDataKind](f).map(lf => lf(n2d))) - } + case DepApp(kind, f, x) => chain(liftDepFunExpr(kind, f).map(lf => lf(x))) case _ => chain(Expanding(p)) } } - def liftDepFunExpr[K <: Kind](p: Expr): Result[K#T => Expr] = { - def chain(r: Result[Expr]): Result[K#T => Expr] = + def liftDepFunExpr[T](kind: Kind[T, _ <: Kind.Identifier], p: Expr): Result[T => Expr] = { + def chain(r: Result[Expr]): Result[T => Expr] = r.bind( - liftDepFunExpr, + liftDepFunExpr(kind, _), f => - Expanding((x: K#T) => - DepApp(f, x)(f.t match { - case DepFunType(_, _) => lifting.liftDependentFunctionType(f.t)(x) + Expanding((x: T) => + DepApp(kind, f, x)(f.t match { + case DepFunType(_, _, _) => lifting.liftDependentFunctionType(kind, f.t)(x) case _ => throw TypeException(s"$f cannot be lifted") }) ) ) p match { - case DepLambda(x, e) => - Reducing((a: K#T) => substitute.kindInExpr(a, `for` = x, in = e)) + case DepLambda(kind, x, e) => + Reducing((a: T) => substitute.kindInExpr(kind, a, `for` = x, in = e)) case App(f, e) => chain(liftFunExpr(f).map(lf => lf(e))) - case DepApp(f, x) => - x match { - case t: DataType => - chain(liftDepFunExpr[DataKind](f).map(lf => lf(t))) - case n: Nat => chain(liftDepFunExpr[NatKind](f).map(lf => lf(n))) - case a: AddressSpace => - chain(liftDepFunExpr[AddressSpaceKind](f).map(lf => lf(a))) - case n2n: NatToNat => - chain(liftDepFunExpr[NatToNatKind](f).map(lf => lf(n2n))) - case n2d: NatToData => - chain(liftDepFunExpr[NatToDataKind](f).map(lf => lf(n2d))) - } + case DepApp(kind, f, x) => chain(liftDepFunExpr(kind, f).map(lf => lf(x))) case _ => chain(Expanding(p)) } } diff --git a/src/main/scala/rise/core/types/Constraints.scala b/src/main/scala/rise/core/types/Constraints.scala index bf8839f4c..0d1c5870a 100644 --- a/src/main/scala/rise/core/types/Constraints.scala +++ b/src/main/scala/rise/core/types/Constraints.scala @@ -34,7 +34,7 @@ case class NatToDataConstraint(a: NatToData, b: NatToData) extends Constraint { override def toString: String = s"$a ~ $b" } -case class DepConstraint[K <: Kind](df: Type, arg: K#T, t: Type) +case class DepConstraint[T](kind: Kind[T, _ <: Kind.Identifier], df: Type, arg: T, t: Type) extends Constraint { override def toString: String = s"$df ($arg) ~ $t" } @@ -107,8 +107,8 @@ object Constraint { case (FunType(ina, outa), FunType(inb, outb)) => decomposed(Seq(TypeConstraint(ina, inb), TypeConstraint(outa, outb))) case ( - DepFunType(na: NatIdentifier, ta), - DepFunType(nb: NatIdentifier, tb) + DepFunType(NatKind, na: NatIdentifier, ta), + DepFunType(NatKind, nb: NatIdentifier, tb) ) => val n = NatIdentifier(freshName("n"), isExplicit = true) decomposedPreserve(Seq( @@ -117,8 +117,8 @@ object Constraint { TypeConstraint(ta, tb), ), preserve + n - na - nb) case ( - DepFunType(dta: DataTypeIdentifier, ta), - DepFunType(dtb: DataTypeIdentifier, tb) + DepFunType(DataKind, dta: DataTypeIdentifier, ta), + DepFunType(DataKind, dtb: DataTypeIdentifier, tb) ) => val dt = DataTypeIdentifier(freshName("t"), isExplicit = true) decomposedPreserve(Seq( @@ -127,14 +127,14 @@ object Constraint { TypeConstraint(ta, tb), ), preserve + dt - dta - dtb) case ( - DepFunType(_: AddressSpaceIdentifier, _), - DepFunType(_: AddressSpaceIdentifier, _) + DepFunType(AddressSpaceKind, _: AddressSpaceIdentifier, _), + DepFunType(AddressSpaceKind, _: AddressSpaceIdentifier, _) ) => ??? case ( - DepPairType(x1: NatIdentifier, t1), - DepPairType(x2: NatIdentifier, t2) + DepPairType(NatKind, x1: NatIdentifier, t1), + DepPairType(NatKind, x2: NatIdentifier, t2) ) => val n = NatIdentifier(freshName("n"), isExplicit = true) decomposedPreserve(Seq( @@ -144,8 +144,8 @@ object Constraint { ), preserve + n - x1 - x2) case ( - DepPairType(x1: NatCollectionIdentifier, t1), - DepPairType(x2: NatCollectionIdentifier, t2) + DepPairType(NatCollectionKind, x1: NatCollectionIdentifier, t1), + DepPairType(NatCollectionKind, x2: NatCollectionIdentifier, t2) ) => val n = NatCollectionIdentifier(freshName("n"), isExplicit = true) decomposedPreserve(Seq( @@ -184,10 +184,10 @@ object Constraint { } - case DepConstraint(df, arg, t) => + case DepConstraint(kind, df, arg, t) => df match { - case _: DepFunType[_, _] => - val applied = liftDependentFunctionType(df)(arg) + case _: DepFunType[_, _, _] => + val applied = liftDependentFunctionType(kind, df)(arg) decomposed(Seq(TypeConstraint(applied, t))) case _ => error(s"expected a dependent function type, but got $df") diff --git a/src/main/scala/rise/core/types/Kinds.scala b/src/main/scala/rise/core/types/Kinds.scala index 81742e322..6df05637d 100644 --- a/src/main/scala/rise/core/types/Kinds.scala +++ b/src/main/scala/rise/core/types/Kinds.scala @@ -1,8 +1,7 @@ package rise.core.types -sealed trait Kind { - type T - type I <: Kind.Identifier with T +sealed trait Kind[+T, +I <: Kind.Identifier] { + def name :String } object Kind { @@ -16,68 +15,30 @@ object Kind { } } -sealed trait TypeKind extends Kind { - override type T = Type - override type I = TypeIdentifier +case object TypeKind extends Kind[Type, TypeIdentifier] { + override def name: String = "type" } -sealed trait DataKind extends Kind { - override type T = DataType - override type I = DataTypeIdentifier +case object DataKind extends Kind[DataType, DataTypeIdentifier] { + override def name: String = "data" } -sealed trait NatKind extends Kind { - override type T = Nat - override type I = NatIdentifier +case object NatKind extends Kind[Nat, NatIdentifier] { + override def name: String = "nat" } -sealed trait AddressSpaceKind extends Kind { - override type T = AddressSpace - override type I = AddressSpaceIdentifier +case object AddressSpaceKind extends Kind[AddressSpace, AddressSpaceIdentifier] { + override def name: String = "addressSpace" } -sealed trait NatToNatKind extends Kind { - override type T = NatToNat - override type I = NatToNatIdentifier +case object NatToNatKind extends Kind[NatToNat, NatToNatIdentifier] { + override def name: String = "nat->nat" } -sealed trait NatToDataKind extends Kind { - override type T = NatToData - override type I = NatToDataIdentifier +case object NatToDataKind extends Kind[NatToData, NatToDataIdentifier] { + override def name: String = "nat->data" } -sealed trait NatCollectionKind extends Kind { - override type T = NatCollection - override type I = NatCollectionIdentifier -} - -trait KindName[K <: Kind] { - def get: String -} - -object KindName { - implicit val typeKN: KindName[TypeKind] = new KindName[TypeKind] { - def get = "type" - } - implicit val dataKN: KindName[DataKind] = new KindName[DataKind] { - def get = "data" - } - implicit val natKN: KindName[NatKind] = new KindName[NatKind] { - def get = "nat" - } - implicit val addressSpaceKN: KindName[AddressSpaceKind] = - new KindName[AddressSpaceKind] { - def get = "addressSpace" - } - implicit val n2nKN: KindName[NatToNatKind] = new KindName[NatToNatKind] { - def get = "nat->nat" - } - implicit val n2dtKN: KindName[NatToDataKind] = new KindName[NatToDataKind] { - def get = "nat->data" - } - - implicit val natsKN: KindName[NatCollectionKind] = - new KindName[NatCollectionKind] { - override def get: String = "nats" - } +case object NatCollectionKind extends Kind[NatCollection, NatCollectionIdentifier] { + override def name: String = "nats" } diff --git a/src/main/scala/rise/core/types/Solution.scala b/src/main/scala/rise/core/types/Solution.scala index c00b5fd27..10ab2aa2b 100644 --- a/src/main/scala/rise/core/types/Solution.scala +++ b/src/main/scala/rise/core/types/Solution.scala @@ -141,19 +141,19 @@ case class Solution(ts: Map[Type, Type], case FragmentTypeConstraint(a, b) => FragmentTypeConstraint(apply(a), apply(b)) case NatToDataConstraint(a, b) => NatToDataConstraint(apply(a), apply(b)) case NatCollectionConstraint(a, b) => NatCollectionConstraint(apply(a), apply(b)) - case DepConstraint(df, arg: Nat, t) => DepConstraint[NatKind](apply(df), apply(arg), apply(t)) - case DepConstraint(df, arg: DataType, t) => - DepConstraint[DataKind](apply(df), apply(arg).asInstanceOf[DataType], apply(t)) - case DepConstraint(df, arg: Type, t) => - DepConstraint[TypeKind](apply(df), apply(arg), apply(t)) - case DepConstraint(df, arg: AddressSpace, t) => - DepConstraint[AddressSpaceKind](apply(df), apply(arg), apply(t)) - case DepConstraint(df, arg: NatToData, t) => - DepConstraint[NatToDataKind](apply(df), apply(arg), apply(t)) - case DepConstraint(df, arg: NatToNat, t) => - DepConstraint[NatToNatKind](apply(df), apply(arg), apply(t)) - case DepConstraint(df, arg: NatCollection, t) => - DepConstraint[NatCollectionKind](apply(df), apply(arg), apply(t)) - case DepConstraint(_, _, _) => throw new Exception("Impossible case") + case DepConstraint(NatKind, df, arg: Nat, t) => DepConstraint(NatKind, apply(df), apply(arg), apply(t)) + case DepConstraint(DataKind, df, arg: DataType, t) => + DepConstraint(DataKind, apply(df), apply(arg).asInstanceOf[DataType], apply(t)) + case DepConstraint(TypeKind, df, arg: Type, t) => + DepConstraint(TypeKind, apply(df), apply(arg), apply(t)) + case DepConstraint(AddressSpaceKind, df, arg: AddressSpace, t) => + DepConstraint(AddressSpaceKind, apply(df), apply(arg), apply(t)) + case DepConstraint(NatToDataKind, df, arg: NatToData, t) => + DepConstraint(NatToDataKind, apply(df), apply(arg), apply(t)) + case DepConstraint(NatToNatKind, df, arg: NatToNat, t) => + DepConstraint(NatToNatKind, apply(df), apply(arg), apply(t)) + case DepConstraint(NatCollectionKind, df, arg: NatCollection, t) => + DepConstraint(NatCollectionKind, apply(df), apply(arg), apply(t)) + case DepConstraint(_, _, _, _) => throw new Exception("Impossible case") } } diff --git a/src/main/scala/rise/core/types/Type.scala b/src/main/scala/rise/core/types/Type.scala index 1ada7f7ad..cb2414801 100644 --- a/src/main/scala/rise/core/types/Type.scala +++ b/src/main/scala/rise/core/types/Type.scala @@ -5,8 +5,8 @@ import rise.core._ import rise.core.equality._ sealed trait Type { - def =~=(b: Type): Boolean = typeAlphaEq[TypeKind](this)(b) - def =~~=(b: Type): Boolean = typePartialAlphaEq[TypeKind](this)(b) + def =~=(b: Type): Boolean = typeAlphaEq[Type](this)(b) + def =~~=(b: Type): Boolean = typePartialAlphaEq[Type](this)(b) } object TypePlaceholder extends Type { @@ -24,12 +24,11 @@ final case class FunType[T <: Type, U <: Type](inT: T, outT: U) override def toString: String = s"($inT -> $outT)" } -final case class DepFunType[K <: Kind: KindName, T <: Type]( - x: K#I with Kind.Explicitness, - t: T -) extends Type { - override def toString: String = - s"(${x.name}: ${implicitly[KindName[K]].get} -> $t)" +final case class DepFunType[T, I <: Kind.Identifier, U <: Type] + (kind: Kind[T, I], + x: I with Kind.Explicitness, + t: U) extends Type { + override def toString: String = s"(${x.name}: ${kind.name} -> $t)" } // == Data types ============================================================== @@ -153,18 +152,8 @@ final case class ManagedBufferType(dt: DataType) extends DataType { } -final case class DepPairType[K <: Kind: KindName]( - x: K#I, - t: DataType - ) extends DataType { - type Kind = K - - // Note(federico): for pattern-matching purposes, if we ever need to - // recover the kind name from a pattern-match over just DataType - val kindName: KindName[K] = implicitly[KindName[K]] - - override def toString: String = - s"(${x.name}: ${kindName.get} ** $t)" +final case class DepPairType[T, I <: Kind.Identifier](kind: Kind[T, I], x: I, t: DataType) extends DataType { + override def toString: String = s"(${x.name}: ${kind.name} ** $t)" } diff --git a/src/main/scala/rise/core/types/check.scala b/src/main/scala/rise/core/types/check.scala index 1cfafd300..20c6312d9 100644 --- a/src/main/scala/rise/core/types/check.scala +++ b/src/main/scala/rise/core/types/check.scala @@ -42,24 +42,18 @@ object check { // ----------- App expr `:` t2 - case DepLambda(x, e) => + case DepLambda(kind, x, e) => val t = ctx `|-` e // ----------- DepLambda - expr `:` (x match { - case n: NatIdentifier => DepFunType[NatKind, Type](n, t) - case dt: DataTypeIdentifier => DepFunType[DataKind, Type](dt, t) - case a: AddressSpaceIdentifier => DepFunType[AddressSpaceKind, Type](a, t) - case n2n: NatToNatIdentifier => DepFunType[NatToNatKind, Type](n2n, t) - case n2d: NatToDataIdentifier => DepFunType[NatToDataKind, Type](n2d, t) - }) - - case DepApp(f, e) => + expr `:` DepFunType(kind, x, t) + + case DepApp(kind, f, e) => val (x, t) = ctx `|-` f match { - case DepFunType(x, t) => (x, t) + case DepFunType(kind2, x, t) if kind == kind2 => (x, t) case t => throw TypeException(s"expected dependent function type and got $t") } // ----------- DepApp - expr `:` substitute.kindInType(e, `for`= x, in = t) + expr `:` substitute.kindInType(kind, e, `for`= x, in = t) case Literal(d) => // ----------- Literal diff --git a/src/main/scala/rise/core/types/package.scala b/src/main/scala/rise/core/types/package.scala index 6b360d938..50e98873c 100644 --- a/src/main/scala/rise/core/types/package.scala +++ b/src/main/scala/rise/core/types/package.scala @@ -8,14 +8,14 @@ package object types { type ->[T1 <: Type, T2 <: Type] = FunType[T1, T2] type `(dt)->`[T <: Type] = DataDepFunType[T] - type DataDepFunType[T <: Type] = DepFunType[DataKind, T] + type DataDepFunType[T <: Type] = DepFunType[DataType, DataTypeIdentifier, T] type `(nat)->`[T <: Type] = NatDepFunType[T] - type NatDepFunType[T <: Type] = DepFunType[NatKind, T] + type NatDepFunType[T <: Type] = DepFunType[Nat, NatIdentifier, T] type `(nat->nat)->`[T <: Type] = NatToNatDepFunType[T] - type NatToNatDepFunType[T <: Type] = DepFunType[NatToNatKind, T] + type NatToNatDepFunType[T <: Type] = DepFunType[NatToNat, NatToNatIdentifier, T] type `(nat->data)->`[T <: Type] = NatToDataDepFunType[T] - type NatToDataDepFunType[T <: Type] = DepFunType[NatToDataKind, T] + type NatToDataDepFunType[T <: Type] = DepFunType[NatToData, NatToDataIdentifier, T] } diff --git a/src/main/scala/rise/core/uniqueNames.scala b/src/main/scala/rise/core/uniqueNames.scala index b5c4a5a92..837b2a94c 100644 --- a/src/main/scala/rise/core/uniqueNames.scala +++ b/src/main/scala/rise/core/uniqueNames.scala @@ -79,26 +79,26 @@ object uniqueNames { t2 <- renameInTypes(l.t)(types) } yield Lambda(x2, b2)(t2) - case d@DepLambda(x: NatIdentifier, b) => + case d@DepLambda(NatKind, x: NatIdentifier, b) => val x2 = NatIdentifier(s"n$nextNatN", x.range, x.isExplicit) for { b2 <- renameInExpr(b)(values, types + (x -> x2)) t2 <- renameInTypes(d.t)(types + (x -> x2)) - } yield DepLambda[NatKind](x2, b2)(t2) + } yield DepLambda(NatKind, x2, b2)(t2) - case d@DepLambda(x: DataTypeIdentifier, b) => + case d@DepLambda(DataKind, x: DataTypeIdentifier, b) => val x2 = DataTypeIdentifier(s"dt$nextDtN", x.isExplicit) for { b2 <- renameInExpr(b)(values, types + (x -> x2)) t2 <- renameInTypes(d.t)(types) - } yield DepLambda[DataKind](x2, b2)(t2) + } yield DepLambda(DataKind, x2, b2)(t2) - case d@DepLambda(x: AddressSpaceIdentifier, b) => + case d@DepLambda(AddressSpaceKind, x: AddressSpaceIdentifier, b) => val x2 = AddressSpaceIdentifier(s"a$nextAN", x.isExplicit) for { b2 <- renameInExpr(b)(values, types + (x -> x2)) t2 <- renameInTypes(d.t)(types) - } yield DepLambda[AddressSpaceKind](x2, b2)(t2) + } yield DepLambda(AddressSpaceKind, x2, b2)(t2) case e => super.expr(e) } @@ -107,24 +107,24 @@ object uniqueNames { case i: DataTypeIdentifier => return_(types.getOrElse(i, i).asInstanceOf[T]) - case DepFunType(x: NatIdentifier, b) => + case DepFunType(NatKind, x: NatIdentifier, b) => val x2 = types.getOrElse(x, NatIdentifier(s"n$nextNatN", x.range, x.isExplicit)) .asInstanceOf[NatIdentifier with Kind.Explicitness] for { b2 <- renameInTypes(b)(types + (x -> x2)) } - yield DepFunType[NatKind, Type](x2, b2).asInstanceOf[T] + yield DepFunType(NatKind, x2, b2).asInstanceOf[T] - case DepFunType(x: DataTypeIdentifier, b) => + case DepFunType(DataKind, x: DataTypeIdentifier, b) => val x2 = types.getOrElse(x, DataTypeIdentifier(s"dt$nextDtN", x.isExplicit)).asInstanceOf[DataTypeIdentifier] for { b2 <- renameInTypes(b)(types + (x -> x2)) } - yield DepFunType[DataKind, Type](x2, b2).asInstanceOf[T] + yield DepFunType(DataKind, x2, b2).asInstanceOf[T] - case DepFunType(x: AddressSpaceIdentifier, b) => + case DepFunType(AddressSpaceKind, x: AddressSpaceIdentifier, b) => val x2 = types.getOrElse(x, AddressSpaceIdentifier(s"dt$nextAN", x.isExplicit)).asInstanceOf[AddressSpaceIdentifier] for { b2 <- renameInTypes(b)(types + (x -> x2)) } - yield DepFunType[AddressSpaceKind, Type](x2, b2).asInstanceOf[T] + yield DepFunType(AddressSpaceKind, x2, b2).asInstanceOf[T] case e => super.`type`(e) } diff --git a/src/main/scala/rise/elevate/rules/algorithmic.scala b/src/main/scala/rise/elevate/rules/algorithmic.scala index 24799b5b1..e6b13f64b 100644 --- a/src/main/scala/rise/elevate/rules/algorithmic.scala +++ b/src/main/scala/rise/elevate/rules/algorithmic.scala @@ -58,7 +58,7 @@ object algorithmic { // padEmpty n >> padEmpty m -> padEmpty n + m @rule def padEmptyFusion: Strategy[Rise] = { - case e @ App(DepApp(padEmpty(), m: Nat), App(DepApp(padEmpty(), n: Nat), in)) => + case e @ App(DepApp(NatKind, padEmpty(), m: Nat), App(DepApp(NatKind, padEmpty(), n: Nat), in)) => Success(padEmpty(n+m)(in) !: e.t) } @@ -135,7 +135,7 @@ object algorithmic { // constraint: n - m = u - v // v = u + m - n @rule def slideOverlap(u: Nat): Strategy[Rise] = { - case e @ DepApp(DepApp(slide(), n: Nat), m: Nat) => + case e @ DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), m: Nat) => val v = u + m - n Success((slide(u)(v) >> map(slide(n)(m)) >> join) !: e.t) } @@ -144,12 +144,12 @@ object algorithmic { // slide n 1 >> drop l -> slide (n+l) 1 >> map(drop l) @rule def dropInSlide: Strategy[Rise] = { - case e@App(DepApp(drop(), l: Nat), App(DepApp(DepApp(slide(), n: Nat), Cst(1)), in)) => + case e@App(DepApp(NatKind, drop(), l: Nat), App(DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), Cst(1)), in)) => Success(app(map(drop(l)), app(slide(n + l)(1), preserveType(in))) !: e.t) } // slide n 1 >> take (N - r) -> slide (n+r) 1 >> map(take (n - r)) @rule def takeInSlide: Strategy[Rise] = { - case e@App(t@DepApp(take(), rem: Nat), App(DepApp(DepApp(slide(), n: Nat), Cst(1)), in)) => + case e@App(t@DepApp(NatKind, take(), rem: Nat), App(DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), Cst(1)), in)) => t.t match { case FunType(ArrayType(size, _), _) => val r = size - rem @@ -159,18 +159,18 @@ object algorithmic { } @rule def dropNothing: Strategy[Rise] = { - case expr @ DepApp(drop(), Cst(0)) => Success(fun(x => x) !: expr.t) + case expr @ DepApp(NatKind, drop(), Cst(0)) => Success(fun(x => x) !: expr.t) } @rule def takeAll: Strategy[Rise] = { - case expr @ DepApp(take(), n: Nat) => expr.t match { + case expr @ DepApp(NatKind, take(), n: Nat) => expr.t match { case FunType(ArrayType(m, _), _) if n == m => Success(fun(x => x) !: expr.t) case _ => Failure(takeAll) } } @rule def padEmptyNothing: Strategy[Rise] = { - case e @ DepApp(padEmpty(), Cst(0)) => Success(fun(x => x) !: e.t) + case e @ DepApp(NatKind, padEmpty(), Cst(0)) => Success(fun(x => x) !: e.t) } @rule def mapIdentity: Strategy[Rise] = { @@ -198,7 +198,7 @@ object algorithmic { // J >> drop d -> drop (d / m) >> J >> drop (d % m) @rule def dropBeforeJoin: Strategy[Rise] = { - case e @ App(DepApp(drop(), d: Nat), App(join(), in)) => in.t match { + case e @ App(DepApp(NatKind, drop(), d: Nat), App(join(), in)) => in.t match { case ArrayType(_, ArrayType(m, _)) => Success(app(drop(d % m), join(drop(d / m)(in))) !: e.t) case _ => throw new Exception("this should not happen") @@ -209,7 +209,7 @@ object algorithmic { // -> dropLast (d / m) >> J >> dropLast (d % m) // -> take (n - d / m) >> J >> take ((n - d / m)*m - d % m) @rule def takeBeforeJoin: Strategy[Rise] = { - case e @ App(DepApp(take(), nmd: Nat), App(join(), in)) => in.t match { + case e @ App(DepApp(NatKind, take(), nmd: Nat), App(join(), in)) => in.t match { case ArrayType(n, ArrayType(m, _)) => val d = n*m - nmd val t1 = n - d / m @@ -221,7 +221,7 @@ object algorithmic { // take n >> padEmpty m -> padEmpty m' @rule def removeTakeBeforePadEmpty: Strategy[Rise] = { - case e @ App(DepApp(padEmpty(), m: Nat), App(DepApp(take(), n: Nat), in)) => + case e @ App(DepApp(NatKind, padEmpty(), m: Nat), App(DepApp(NatKind, take(), n: Nat), in)) => in.t match { case ArrayType(size, _) if ArithExpr.isSmaller(size - n, m + 1).contains(true) => @@ -370,15 +370,15 @@ object algorithmic { // zip (slide n m a) (slide n m b) -> map unzip (slide n m (zip a b)) @rule def slideOutsideZip: Strategy[Rise] = { case expr @ App(App(zip(), - App(DepApp(DepApp(slide(), n: Nat), m: Nat), a)), - App(DepApp(DepApp(slide(), n2: Nat), m2: Nat), b) + App(DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), m: Nat), a)), + App(DepApp(NatKind, DepApp(NatKind, slide(), n2: Nat), m2: Nat), b) ) if n == n2 && m == m2 => Success(map(unzip)(slide(n)(m)(zip(a)(b))) !: expr.t) } // slide n m (zip a b) -> map zip (zip (slide n m a) (slide n m b)) @rule def slideInsideZip: Strategy[Rise] = { - case expr @ App(DepApp(DepApp(slide(), n: Nat), m: Nat), + case expr @ App(DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), m: Nat), App(App(zip(), a), b) ) => Success(map(fun(p => zip(fst(p))(snd(p))))( @@ -441,8 +441,8 @@ object algorithmic { @rule def zipAsVectorUnzipSimplification: Strategy[Rise] = { case e @ App( Lambda(x, App(App(zip(), - App(DepApp(asVector(), v: Nat), App(fst(), x2))), - App(DepApp(asVector(), v2: Nat), App(snd(), x3)))), + App(DepApp(NatKind, asVector(), v: Nat), App(fst(), x2))), + App(DepApp(NatKind, asVector(), v2: Nat), App(snd(), x3)))), App(unzip(), in) ) if x =~= x2 && x =~= x3 && v == v2 => println(in.t) diff --git a/src/main/scala/rise/elevate/rules/lowering.scala b/src/main/scala/rise/elevate/rules/lowering.scala index 395a05851..24ddc027b 100644 --- a/src/main/scala/rise/elevate/rules/lowering.scala +++ b/src/main/scala/rise/elevate/rules/lowering.scala @@ -85,12 +85,12 @@ object lowering { // TODO: load identity instead, then change with other rules? @rule def circularBuffer(load: Expr): Strategy[Rise] = { - case e@DepApp(DepApp(slide(), sz: Nat), Cst(1)) => Success( + case e@DepApp(NatKind, DepApp(NatKind, slide(), sz: Nat), Cst(1)) => Success( p.circularBuffer(sz)(sz)(eraseType(load)) !: e.t) } @rule def rotateValues(write: Expr): Strategy[Rise] = { - case e@DepApp(DepApp(slide(), sz: Nat), Cst(1)) => Success( + case e@DepApp(NatKind, DepApp(NatKind, slide(), sz: Nat), Cst(1)) => Success( p.rotateValues(sz)(eraseType(write)) !: e.t) } @@ -319,7 +319,7 @@ object lowering { } @rule def circularBuffer(a: AddressSpace): Strategy[Rise] = { - case e@DepApp(DepApp(slide(), n: Nat), Cst(1)) => + case e@DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), Cst(1)) => Success( oclCircularBuffer(a)(n)(n)(fun(x => x)) !: e.t) @@ -327,14 +327,14 @@ object lowering { @rule def circularBufferLoadFusion: Strategy[Rise] = { case e@App(App( - cb @ DepApp(DepApp(DepApp(oclCircularBuffer(), _), _), _), + cb @ DepApp(NatKind, DepApp(NatKind, DepApp(AddressSpaceKind, oclCircularBuffer(), _), _), _), load), App(App(map(), f), in) ) => Success(eraseType(cb)(preserveType(f) >> load, in) !: e.t) } @rule def rotateValues(a: AddressSpace, write: Expr): Strategy[Rise] = { - case e@DepApp(DepApp(slide(), n: Nat), Cst(1)) => + case e@DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), Cst(1)) => Success( oclRotateValues(a)(n)(eraseType(write)) !: e.t) diff --git a/src/main/scala/rise/elevate/rules/movement.scala b/src/main/scala/rise/elevate/rules/movement.scala index 7aa4d475d..4034aefe0 100644 --- a/src/main/scala/rise/elevate/rules/movement.scala +++ b/src/main/scala/rise/elevate/rules/movement.scala @@ -57,9 +57,9 @@ object movement { // split/slide def isSplitOrSlide(s: Expr): Boolean = s match { - case DepApp(DepApp(slide(), _: Nat), _: Nat) => true - case DepApp(split(), _: Nat) => true - case _ => false + case DepApp(NatKind, DepApp(NatKind, slide(), _: Nat), _: Nat) => true + case DepApp(NatKind, split(), _: Nat) => true + case _ => false } def slideBeforeMapMapF: Strategy[Rise] = `S >> **f -> *f >> S` @@ -70,13 +70,13 @@ object movement { def slideBeforeMap: Strategy[Rise] = `*f >> S -> S >> **f` @rule def `*f >> S -> S >> **f`: Strategy[Rise] = { - case e@App(s @ DepApp(DepApp(slide(), _: Nat), _: Nat), App(App(map(), f), y)) => + case e@App(s @ DepApp(NatKind, DepApp(NatKind, slide(), _: Nat), _: Nat), App(App(map(), f), y)) => Success((preserveType(y) |> eraseType(s) |> map(map(f))) !: e.t) } // *f >> S -> S >> **f @rule def splitBeforeMap: Strategy[Rise] = { - case e@App(s @ DepApp(split(), _: Nat), App(App(map(), f), y)) => + case e@App(s @ DepApp(NatKind, split(), _: Nat), App(App(map(), f), y)) => Success((preserveType(y) |> eraseType(s) |> map(map(f))) !: e.t) } @@ -98,32 +98,32 @@ object movement { def dropBeforeMap: Strategy[Rise] = `*f >> drop n -> drop n >> *f` @rule def `*f >> drop n -> drop n >> *f`: Strategy[Rise] = { - case expr @ App(DepApp(drop(), n: Nat), App(App(map(), f), in)) => + case expr @ App(DepApp(NatKind, drop(), n: Nat), App(App(map(), f), in)) => Success(app(map(f), app(drop(n), preserveType(in))) !: expr.t) } def takeBeforeMap: Strategy[Rise] = `*f >> take n -> take n >> *f` @rule def `*f >> take n -> take n >> *f`: Strategy[Rise] = { - case expr @ App(DepApp(take(), n: Nat), App(App(map(), f), in)) => + case expr @ App(DepApp(NatKind, take(), n: Nat), App(App(map(), f), in)) => Success(app(map(f), app(take(n), preserveType(in))) !: expr.t) } // take n >> *f -> *f >> take n @rule def takeAfterMap: Strategy[Rise] = { - case e @ App(App(map(), f), App(DepApp(take(), n: Nat), in)) => + case e @ App(App(map(), f), App(DepApp(NatKind, take(), n: Nat), in)) => Success(take(n)(map(f)(in)) !: e.t) } def takeInZip: Strategy[Rise] = `take n (zip a b) -> zip (take n a) (take n b)` @rule def `take n (zip a b) -> zip (take n a) (take n b)`: Strategy[Rise] = { - case expr @ App(DepApp(take(), n), App(App(zip(), a), b)) => - Success(zip(depApp(take, n)(a))(depApp(take, n)(b)) !: expr.t) + case expr @ App(DepApp(NatKind, take(), n), App(App(zip(), a), b)) => + Success(zip(depApp(NatKind, take, n)(a))(depApp(NatKind, take, n)(b)) !: expr.t) } // zip (take n a) (take n b) -> take n (zip a b) @rule def takeOutisdeZip: Strategy[Rise] = { case e @ App(App(zip(), - App(DepApp(take(), n1: Nat), a)), App(DepApp(take(), n2: Nat), b) + App(DepApp(NatKind, take(), n1: Nat), a)), App(DepApp(NatKind, take(), n2: Nat), b) ) if n1 == n2 => Success(take(n1)(zip(a)(b)) !: e.t) } @@ -132,76 +132,76 @@ object movement { // TODO: can get any function out, see asScalarOutsidePair @rule def takeOutsidePair: Strategy[Rise] = { case e @ App(App(makePair(), - App(DepApp(take(), n: Nat), a)), App(DepApp(take(), m: Nat), b) + App(DepApp(NatKind, take(), n: Nat), a)), App(DepApp(NatKind, take(), m: Nat), b) ) => Success((makePair(a)(b) |> mapFst(take(n)) |> mapSnd(take(m))) !: e.t) } def dropInZip: Strategy[Rise] = `drop n (zip a b) -> zip (drop n a) (drop n b)` @rule def `drop n (zip a b) -> zip (drop n a) (drop n b)`: Strategy[Rise] = { - case expr @ App(DepApp(drop(), n), App(App(zip(), a), b)) => - Success(zip(depApp(drop, n)(a))(depApp(drop, n)(b)) !: expr.t) + case expr @ App(DepApp(NatKind, drop(), n), App(App(zip(), a), b)) => + Success(zip(depApp(NatKind, drop, n)(a))(depApp(NatKind, drop, n)(b)) !: expr.t) } def takeInSelect: Strategy[Rise] = `take n (select t a b) -> select t (take n a) (take n b)` @rule def `take n (select t a b) -> select t (take n a) (take n b)`: Strategy[Rise] = { - case expr @ App(DepApp(take(), n), App(App(App(select(), t), a), b)) => - Success(select(t)(depApp(take, n)(a), depApp(take, n)(b)) !: expr.t) + case expr @ App(DepApp(NatKind, take(), n), App(App(App(select(), t), a), b)) => + Success(select(t)(depApp(NatKind, take, n)(a), depApp(NatKind, take, n)(b)) !: expr.t) } def dropInSelect: Strategy[Rise] = `drop n (select t a b) -> select t (drop n a) (drop n b)` @rule def `drop n (select t a b) -> select t (drop n a) (drop n b)`: Strategy[Rise] = { - case expr @ App(DepApp(drop(), n), App(App(App(select(), t), a), b)) => - Success(select(t)(depApp(drop, n)(a), depApp(drop, n)(b)) !: expr.t) + case expr @ App(DepApp(NatKind, drop(), n), App(App(App(select(), t), a), b)) => + Success(select(t)(depApp(NatKind, drop, n)(a), depApp(NatKind, drop, n)(b)) !: expr.t) } def dropBeforeTake: Strategy[Rise] = `take (n+m) >> drop m -> drop m >> take n` @rule def `take (n+m) >> drop m -> drop m >> take n`: Strategy[Rise] = { - case expr @ App(DepApp(drop(), m: Nat), App(DepApp(take(), nm: Nat), in)) => + case expr @ App(DepApp(NatKind, drop(), m: Nat), App(DepApp(NatKind, take(), nm: Nat), in)) => Success(app(take(nm - m), app(drop(m), preserveType(in))) !: expr.t) } def takeBeforeDrop: Strategy[Rise] = `drop m >> take n -> take (n+m) >> drop m` @rule def `drop m >> take n -> take (n+m) >> drop m`: Strategy[Rise] = { - case expr @ App(DepApp(take(), n: Nat), App(DepApp(drop(), m: Nat), in)) => + case expr @ App(DepApp(NatKind, take(), n: Nat), App(DepApp(NatKind, drop(), m: Nat), in)) => Success(app(drop(m), app(take(n+m), preserveType(in))) !: expr.t) } def takeBeforeSlide: Strategy[Rise] = `slide n m >> take t -> take (m * (t - 1) + n) >> slide n m` @rule def `slide n m >> take t -> take (m * (t - 1) + n) >> slide n m`: Strategy[Rise] = { - case expr @ App(DepApp(take(), t: Nat), App(DepApp(DepApp(slide(), n: Nat), m: Nat), in)) => + case expr @ App(DepApp(NatKind, take(), t: Nat), App(DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), m: Nat), in)) => Success(app(slide(n)(m), take(m * (t - 1) + n)(in)) !: expr.t) } def dropBeforeSlide: Strategy[Rise] = `slide n m >> drop d -> drop (d * m) >> slide n m` @rule def `slide n m >> drop d -> drop (d * m) >> slide n m`: Strategy[Rise] = { - case expr @ App(DepApp(drop(), d: Nat), App(DepApp(DepApp(slide(), n: Nat), m: Nat), in)) => + case expr @ App(DepApp(NatKind, drop(), d: Nat), App(DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), m: Nat), in)) => Success(app(slide(n)(m), drop(d * m)(in)) !: expr.t) } // slide n m >> padEmpty p -> padEmpty (p * m) >> slide n m @rule def padEmptyBeforeSlide: Strategy[Rise] = { - case e @ App(DepApp(padEmpty(), p: Nat), - App(DepApp(DepApp(slide(), n: Nat), m: Nat), in) + case e @ App(DepApp(NatKind, padEmpty(), p: Nat), + App(DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), m: Nat), in) ) => Success(slide(n)(m)(padEmpty(p * m)(in)) !: e.t) } // map f >> padEmpty n -> padEmpty n >> map f @rule def padEmptyBeforeMap: Strategy[Rise] = { - case e @ App(DepApp(padEmpty(), n: Nat), App(App(map(), f), in)) => + case e @ App(DepApp(NatKind, padEmpty(), n: Nat), App(App(map(), f), in)) => Success(map(f)(padEmpty(n)(in)) !: e.t) } // transpose >> padEmpty n -> map (padEmpty n) >> transpose @rule def padEmptyBeforeTranspose: Strategy[Rise] = { - case e @ App(DepApp(padEmpty(), n: Nat), App(transpose(), in)) => + case e @ App(DepApp(NatKind, padEmpty(), n: Nat), App(transpose(), in)) => Success(transpose(map(padEmpty(n))(in)) !: e.t) } // padEmpty n (zip a b) -> zip (padEmpty n a) (padEmpty n b) @rule def padEmptyInsideZip: Strategy[Rise] = { - case e @ App(DepApp(padEmpty(), n: Nat), App(App(zip(), a), b)) => + case e @ App(DepApp(NatKind, padEmpty(), n: Nat), App(App(zip(), a), b)) => Success(zip(padEmpty(n)(a))(padEmpty(n)(b)) !: e.t) } @@ -209,7 +209,7 @@ object movement { // zip (fst e) (snd e) |> padEmpty n -> // (mapFst padEmpty n) (mapSnd padEmpty n) |> fun(p => zip (fst p) (snd(p)) @rule def padEmptyBeforeZip: Strategy[Rise] = { - case e @ App(DepApp(padEmpty(), n: Nat), + case e @ App(DepApp(NatKind, padEmpty(), n: Nat), App(App(zip(), App(fst(), e1)), App(snd(), e2))) if e1 =~= e2 => Success((preserveType(e1) |> @@ -282,15 +282,15 @@ object movement { def slideBeforeSplit: Strategy[Rise] = `slide(n)(s) >> split(k) -> slide(k+n-s)(k) >> map(slide(n)(s))` @rule def `slide(n)(s) >> split(k) -> slide(k+n-s)(k) >> map(slide(n)(s))`: Strategy[Rise] = { - case e@App(DepApp(split(), k: Nat), App(DepApp(DepApp(slide(), n: Nat), s: Nat), y)) => + case e@App(DepApp(NatKind, split(), k: Nat), App(DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), s: Nat), y)) => Success((preserveType(y) |> slide(k + n - s)(k) |> map(slide(n)(s))) !: e.t) } // TODO: what if s != 1? // slide(n)(s=1) >> slide(m)(k) -> slide(m+n-1)(k) >> map(slide(n)(1)) @rule def slideBeforeSlide: Strategy[Rise] = { - case e@App(DepApp(DepApp(slide(), m: Nat), k: Nat), - App(DepApp(DepApp(slide(), n: Nat), s: Nat), in) + case e@App(DepApp(NatKind, DepApp(NatKind, slide(), m: Nat), k: Nat), + App(DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), s: Nat), in) ) if s == (1: Nat) => Success((preserveType(in) |> slide(m+n-s)(k) |> map(slide(n)(s))) !: e.t) } diff --git a/src/main/scala/rise/elevate/rules/package.scala b/src/main/scala/rise/elevate/rules/package.scala index ddfa1c515..eb864473b 100644 --- a/src/main/scala/rise/elevate/rules/package.scala +++ b/src/main/scala/rise/elevate/rules/package.scala @@ -15,8 +15,8 @@ package object rules { def betaReduction: Strategy[Rise] = { case App(Lambda(x, b), v) => Success(substitute.exprInExpr(v, `for` = x, in = b)) - case DepApp(DepLambda(x, b), v) => - Success(substitute.kindInExpr(v, `for` = x, in = b)) + case DepApp(k1, DepLambda(k2, x, b), v) if k1 == k2 => + Success(substitute.kindInExpr(k1, v, `for` = x, in = b)) case _ => Failure(betaReduction) } @@ -33,8 +33,8 @@ package object rules { Success(substitute.exprInExpr(v, `for` = x, in = b)) case App(Lambda(x, b), v) if !containsAtLeast(1, x)(ev)(b) => Success(substitute.exprInExpr(v, `for` = x, in = b)) - case DepApp(DepLambda(x, b), v) => - Success(substitute.kindInExpr(v, `for` = x, in = b)) + case DepApp(k1, DepLambda(k2, x, b), v) if k1 == k2 => + Success(substitute.kindInExpr(k1, v, `for` = x, in = b)) case _ => Failure(gentleBetaReduction()) } diff --git a/src/main/scala/rise/elevate/rules/traversal.scala b/src/main/scala/rise/elevate/rules/traversal.scala index f4da6675d..556b9033f 100644 --- a/src/main/scala/rise/elevate/rules/traversal.scala +++ b/src/main/scala/rise/elevate/rules/traversal.scala @@ -13,16 +13,7 @@ object traversal { case class body(s: Strategy[Rise]) extends Strategy[Rise] { def apply(e: Rise): RewriteResult[Rise] = e match { case Lambda(x, f) => s(f).mapSuccess(Lambda(x, _)(e.t)) - case DepLambda(x: NatIdentifier, f) => - s(f).mapSuccess(DepLambda[NatKind](x, _)(e.t)) - case DepLambda(x: DataTypeIdentifier, f) => - s(f).mapSuccess(DepLambda[DataKind](x, _)(e.t)) - case DepLambda(x: AddressSpaceIdentifier, f) => - s(f).mapSuccess(DepLambda[AddressSpaceKind](x, _)(e.t)) - case DepLambda(x: NatToNatIdentifier, f) => - s(f).mapSuccess(DepLambda[NatToNatKind](x, _)(e.t)) - case DepLambda(x: NatToDataIdentifier, f) => - s(f).mapSuccess(DepLambda[NatToDataKind](x, _)(e.t)) + case DepLambda(kind, x, f) => s(f).mapSuccess(DepLambda(kind, x, _)(e.t)) case _ => Failure(s) } override def toString = s"body($s)" @@ -72,7 +63,7 @@ object traversal { // To achieve a traversal that most closely corresponds to the execution order we ... case a @ App(f, e) => e.t match { // ... traverse arguments with a function type after the called function ... - case FunType(_, _) | DepFunType(_, _) => + case FunType(_, _) | DepFunType(_, _, _) => s(f) match { case Success(x: Rise) => Success(App(x, e)(a.t)) case Failure(state) => if (carryOverState) @@ -169,25 +160,8 @@ object traversal { case App(_,_) => throw new Exception("this should not happen") case Identifier(_) => None case l @ Lambda(x, e) => Some(s(e).mapSuccess(Lambda(x, _)(l.t))) - case dl @ DepLambda(x, e) => x match { - case n: NatIdentifier => - Some(s(e).mapSuccess(DepLambda[NatKind](n, _)(dl.t))) - case dt: DataTypeIdentifier => - Some(s(e).mapSuccess(DepLambda[DataKind](dt, _)(dl.t))) - case addr: AddressSpaceIdentifier => - Some(s(e).mapSuccess(DepLambda[AddressSpaceKind](addr, _)(dl.t))) - } - case da @ DepApp(f, x)=> x match { - case n: Nat => Some(s(f).mapSuccess(DepApp[NatKind](_, n)(da.t))) - case dt: DataType => - Some(s(f).mapSuccess(DepApp[DataKind](_, dt)(da.t))) - case addr: AddressSpace => - Some(s(f).mapSuccess(DepApp[AddressSpaceKind](_, addr)(da.t))) - case n2n: NatToNat => - Some(s(f).mapSuccess(DepApp[NatToNatKind](_, n2n)(da.t))) - case n2d: NatToData => - Some(s(f).mapSuccess(DepApp[NatToDataKind](_, n2d)(da.t))) - } + case dl @ DepLambda(kind, x, e) => Some(s(e).mapSuccess(DepLambda(kind, x, _)(dl.t))) + case da @ DepApp(kind, f, x) => Some(s(f).mapSuccess(DepApp(kind, _, x)(da.t))) case Literal(_) => None case _: TypeAnnotation => throw new Exception("Type annotations should be gone.") case _: TypeAssertion => throw new Exception("Type assertions should be gone.") diff --git a/src/main/scala/rise/elevate/rules/vectorize.scala b/src/main/scala/rise/elevate/rules/vectorize.scala index 81bfaddcd..f61822bb2 100644 --- a/src/main/scala/rise/elevate/rules/vectorize.scala +++ b/src/main/scala/rise/elevate/rules/vectorize.scala @@ -90,7 +90,7 @@ object vectorize { // padEmpty (p*v) (asScalar in) -> asScalar (padEmpty p in) @rule def padEmptyBeforeAsScalar: Strategy[Rise] = { - case App(DepApp(padEmpty(), pv: Nat), App(asScalar(), in)) => + case App(DepApp(NatKind, padEmpty(), pv: Nat), App(asScalar(), in)) => in.t match { case ArrayType(_, VectorType(v, _)) if (pv % v) == (0: Nat) => Success(asScalar(padEmpty(pv / v)(in))) @@ -100,7 +100,7 @@ object vectorize { // padEmpty p (asVector v in) -> asVector v (padEmpty (p*v) in) @rule def padEmptyBeforeAsVector: Strategy[Rise] = { - case e @ App(DepApp(padEmpty(), p: Nat), App(asV @ DepApp(_, v: Nat), in)) + case e @ App(DepApp(NatKind, padEmpty(), p: Nat), App(asV @ DepApp(NatKind, _, v: Nat), in)) if isAsVector(asV) => Success(eraseType(asV)(padEmpty(p*v)(in)) !: e.t) } @@ -108,10 +108,10 @@ object vectorize { // TODO: express as a combination of smaller rules @rule def alignSlide: Strategy[Rise] = { case e @ App(transpose(), - App(App(map(), DepApp(asVector(), Cst(v))), + App(App(map(), DepApp(NatKind, asVector(), Cst(v))), App(join(), App(App(map(), transpose()), - App(App(map(), DepApp(padEmpty(), Cst(p))), - App(App(map(), DepApp(DepApp(slide(), Cst(3)), Cst(1))), + App(App(map(), DepApp(NatKind, padEmpty(), Cst(p))), + App(App(map(), DepApp(NatKind, DepApp(NatKind, slide(), Cst(3)), Cst(1))), in ) ) @@ -135,9 +135,9 @@ object vectorize { Success(r !: e.t) case e @ App(transpose(), - App(App(map(), DepApp(asVector(), Cst(v))), - App(transpose(), App(DepApp(padEmpty(), Cst(p)), - App(DepApp(DepApp(slide(), Cst(3)), Cst(1)), in) + App(App(map(), DepApp(NatKind, asVector(), Cst(v))), + App(transpose(), App(DepApp(NatKind, padEmpty(), Cst(p)), + App(DepApp(NatKind, DepApp(NatKind, slide(), Cst(3)), Cst(1)), in) )) ) ) if p <= v => @@ -156,9 +156,9 @@ object vectorize { // TODO: express as a combination of smaller rules // FIXME: function f needs to be element-wise (a hidden mapVec) @rule def mapAfterShuffle: Strategy[Rise] = { - case e @ App(DepApp(asVector(), v: Nat), - App(join(), App(DepApp(DepApp(slide(), v2: Nat), Cst(1)), - App(DepApp(take(), t: Nat), App(asScalar(), + case e @ App(DepApp(NatKind, asVector(), v: Nat), + App(join(), App(DepApp(NatKind, DepApp(NatKind, slide(), v2: Nat), Cst(1)), + App(DepApp(NatKind, take(), t: Nat), App(asScalar(), App(App(map(), f), in) )) )) @@ -172,9 +172,9 @@ object vectorize { // FIXME: this is very specific @rule def padEmptyBeforeZipAsVector: Strategy[Rise] = { - case e @ App(DepApp(padEmpty(), p: Nat), App( + case e @ App(DepApp(NatKind, padEmpty(), p: Nat), App( Lambda(x, App(App(zip(), - App(asV @ DepApp(_, v: Nat), App(fst(), x2))), + App(asV @ DepApp(NatKind, _, v: Nat), App(fst(), x2))), App(asV2, App(snd(), x3)))), in )) if x =~= x2 && x =~= x3 && isAsVector(asV) && asV =~= asV2 => @@ -186,8 +186,8 @@ object vectorize { } def isAsVector: Rise => Boolean = { - case DepApp(asVector(), _: Nat) => true - case DepApp(asVectorAligned(), _: Nat) => true + case DepApp(NatKind, asVector(), _: Nat) => true + case DepApp(NatKind, asVectorAligned(), _: Nat) => true case _ => false } diff --git a/src/main/scala/rise/elevate/strategies/lowering.scala b/src/main/scala/rise/elevate/strategies/lowering.scala index c83280cce..e1d9800f6 100644 --- a/src/main/scala/rise/elevate/strategies/lowering.scala +++ b/src/main/scala/rise/elevate/strategies/lowering.scala @@ -33,10 +33,10 @@ object lowering { def insert(expr: Rise): Strategy[Rise] = _ => Success(expr) def extract(what: Strategy[Rise]): Strategy[Rise] = (expr: Rise) => { what(expr).flatMapFailure(_ => expr match { - case App(f,e) => extract(what)(f).flatMapFailure(_ => extract(what)(e)) - case Lambda(x, e) => extract(what)(x).flatMapFailure(_ => extract(what)(e)) - case DepLambda(_, e) => extract(what)(e) - case DepApp(_, _) => Failure(extract(what)) + case App(f,e) => extract(what)(f).flatMapFailure(_ => extract(what)(e)) + case Lambda(x, e) => extract(what)(x).flatMapFailure(_ => extract(what)(e)) + case DepLambda(_, _, e) => extract(what)(e) + case DepApp(_, _, _) => Failure(extract(what)) case _: Identifier => Failure(extract(what)) case _: Literal => Failure(extract(what)) case _: TypeAnnotation => throw new Exception("Type annotations should be gone.") diff --git a/src/main/scala/rise/elevate/strategies/predicate.scala b/src/main/scala/rise/elevate/strategies/predicate.scala index 97e2a711e..467803c62 100644 --- a/src/main/scala/rise/elevate/strategies/predicate.scala +++ b/src/main/scala/rise/elevate/strategies/predicate.scala @@ -26,7 +26,7 @@ object predicate { } def isLambda: is = is(_.isInstanceOf[Lambda], "Lambda") - def isDepLambda: is = is(_.isInstanceOf[DepLambda[_]], "DepLambda") + def isDepLambda: is = is(_.isInstanceOf[DepLambda[_, _]], "DepLambda") def isIdentifier: is = is(_.isInstanceOf[Identifier], "Identifier") def isApply: is = is(_.isInstanceOf[App], "Apply") diff --git a/src/main/scala/rise/eqsat/Expr.scala b/src/main/scala/rise/eqsat/Expr.scala index 0d52a1eaf..80bac44a5 100644 --- a/src/main/scala/rise/eqsat/Expr.scala +++ b/src/main/scala/rise/eqsat/Expr.scala @@ -127,16 +127,16 @@ object Expr { case i: core.Identifier => Var(bound.indexOf(i)) case core.App(f, e) => App(fromNamed(f, bound), fromNamed(e, bound)) case core.Lambda(i, e) => Lambda(fromNamed(e, bound + i)) - case core.DepApp(f, n: rct.Nat) => + case core.DepApp(rct.NatKind, f, n: rct.Nat) => NatApp(fromNamed(f, bound), Nat.fromNamed(n, bound)) - case core.DepApp(f, dt: rct.DataType) => + case core.DepApp(rct.DataKind, f, dt: rct.DataType) => DataApp(fromNamed(f, bound), DataType.fromNamed(dt, bound)) - case core.DepApp(_, _) => ??? - case core.DepLambda(n: rct.NatIdentifier, e) => + case core.DepApp(_, _, _) => ??? + case core.DepLambda(rct.NatKind, n: rct.NatIdentifier, e) => NatLambda(fromNamed(e, bound + n)) - case core.DepLambda(dt: rct.DataTypeIdentifier, e) => + case core.DepLambda(rct.DataKind, dt: rct.DataTypeIdentifier, e) => DataLambda(fromNamed(e, bound + dt)) - case core.DepLambda(_, _) => ??? + case core.DepLambda(_, _, _) => ??? case core.Literal(d) => Literal(d) // note: we set the primitive type to a place holder here, // because we do not want type information at the node level @@ -155,15 +155,15 @@ object Expr { val i = core.Identifier(s"x${bound.expr.size}")(Type.toNamed(funT.inT, bound)) core.Lambda(i, toNamed(e, bound + i)) _ case NatApp(f, x) => - core.DepApp[rct.NatKind](toNamed(f, bound), Nat.toNamed(x, bound)) _ + core.DepApp(rct.NatKind, toNamed(f, bound), Nat.toNamed(x, bound)) _ case NatLambda(e) => val i = rct.NatIdentifier(s"n${bound.nat.size}", isExplicit = true) - core.DepLambda[rct.NatKind](i, toNamed(e, bound + i)) _ + core.DepLambda(rct.NatKind, i, toNamed(e, bound + i)) _ case DataApp(f, x) => - core.DepApp[rct.DataKind](toNamed(f, bound), DataType.toNamed(x, bound)) _ + core.DepApp(rct.DataKind, toNamed(f, bound), DataType.toNamed(x, bound)) _ case DataLambda(e) => val i = rct.DataTypeIdentifier(s"dt${bound.data.size}", isExplicit = true) - core.DepLambda[rct.DataKind](i, toNamed(e, bound + i)) _ + core.DepLambda(rct.DataKind, i, toNamed(e, bound + i)) _ case Literal(d) => core.Literal(d).setType _ case Primitive(p) => p.setType _ })(Type.toNamed(expr.t, bound)) diff --git a/src/main/scala/rise/eqsat/NamedRewrite.scala b/src/main/scala/rise/eqsat/NamedRewrite.scala index 804cff60e..638f7b924 100644 --- a/src/main/scala/rise/eqsat/NamedRewrite.scala +++ b/src/main/scala/rise/eqsat/NamedRewrite.scala @@ -73,24 +73,24 @@ object NamedRewrite { // lam(x : xt, e : et) : xt -> et case rc.Lambda(x, e) => PatternNode(Lambda(makePat(e, bound + x, isRhs, matchType = false))) - case rc.DepLambda(x: rct.NatIdentifier, e) => + case rc.DepLambda(rct.NatKind, x: rct.NatIdentifier, e) => PatternNode(NatLambda(makePat(e, bound + x, isRhs, matchType = false))) - case rc.DepLambda(x: rct.DataTypeIdentifier, e) => + case rc.DepLambda(rct.DataKind, x: rct.DataTypeIdentifier, e) => PatternNode(DataLambda(makePat(e, bound + x, isRhs, matchType = false))) - case rc.DepLambda(_, _) => ??? + case rc.DepLambda(_, _, _) => ??? // note: we do not match for the type of applied functions, as we can always infer it: // app(f : et -> at, e : et) : at case rc.App(f, e) => PatternNode(App(makePat(f, bound, isRhs, matchType = false), makePat(e, bound, isRhs))) - case rc.DepApp(f, x: rct.Nat) => + case rc.DepApp(rct.NatKind, f, x: rct.Nat) => PatternNode(NatApp( makePat(f, bound, isRhs, matchType = false), makeNPat(x, bound, isRhs))) - case rc.DepApp(f, x: rct.DataType) => + case rc.DepApp(rct.DataKind, f, x: rct.DataType) => PatternNode(DataApp( makePat(f, bound, isRhs, matchType = false), makeDTPat(x, bound, isRhs))) - case rc.DepApp(_, _) => ??? + case rc.DepApp(_, _, _) => ??? case rc.Literal(d) => PatternNode(Literal(d)) // note: we set the primitive type to a place holder here, // because we do not want type information at the node level @@ -144,7 +144,7 @@ object NamedRewrite { DataTypePatternNode(PairType(makeDTPat(dt1, bound, isRhs), makeDTPat(dt2, bound, isRhs))) case rct.ArrayType(s, et) => DataTypePatternNode(ArrayType(makeNPat(s, bound, isRhs), makeDTPat(et, bound, isRhs))) - case _: rct.DepArrayType | _: rct.DepPairType[_] | + case _: rct.DepArrayType | _: rct.DepPairType[_, _] | _: rct.NatToDataApply | _: rct.FragmentType => throw new Exception(s"did not expect $dt") } @@ -154,11 +154,11 @@ object NamedRewrite { case dt: rct.DataType => makeDTPat(dt, bound, isRhs) case rct.FunType(a, b) => TypePatternNode(FunType(makeTPat(a, bound, isRhs), makeTPat(b, bound, isRhs))) - case rct.DepFunType(x: rct.NatIdentifier, t) => + case rct.DepFunType(rct.NatKind, x: rct.NatIdentifier, t) => TypePatternNode(NatFunType(makeTPat(t, bound + x, isRhs))) - case rct.DepFunType(x: rct.DataTypeIdentifier, t) => + case rct.DepFunType(rct.DataKind, x: rct.DataTypeIdentifier, t) => TypePatternNode(DataFunType(makeTPat(t, bound + x, isRhs))) - case rct.DepFunType(_, _) => ??? + case rct.DepFunType(_, _, _) => ??? case i: rct.TypeIdentifier => assert(freeT(i)) makePatVar(i.name, (bound.nat.size, bound.data.size), @@ -425,10 +425,10 @@ object NamedRewriteDSL { def lam(name: String, e: Pattern): Pattern = rc.Lambda(rc.Identifier(name)(TypePlaceholder), e)(TypePlaceholder) def nApp(f: Pattern, x: NatPattern): Pattern = - rc.DepApp[rct.NatKind](f, x)(TypePlaceholder) + rc.DepApp(rct.NatKind, f, x)(TypePlaceholder) def nLam(name: String, e: Pattern): Pattern = { val n = rct.NatIdentifier(name, isExplicit = true) - rc.DepLambda[rct.NatKind](n, e)(TypePlaceholder) + rc.DepLambda(rct.NatKind, n, e)(TypePlaceholder) } def l(d: rc.semantics.Data): Pattern = rc.Literal(d) diff --git a/src/main/scala/rise/eqsat/TypeNode.scala b/src/main/scala/rise/eqsat/TypeNode.scala index 6119a5267..ad468d7c5 100644 --- a/src/main/scala/rise/eqsat/TypeNode.scala +++ b/src/main/scala/rise/eqsat/TypeNode.scala @@ -115,9 +115,9 @@ object Type { Type(t match { case dt: rct.DataType => DataType.fromNamed(dt, bound).node case rct.FunType(a, b) => FunType(fromNamed(a, bound), fromNamed(b, bound)) - case rct.DepFunType(x: rct.NatIdentifier, t) => NatFunType(fromNamed(t, bound + x)) - case rct.DepFunType(x: rct.DataTypeIdentifier, t) => DataFunType(fromNamed(t, bound + x)) - case rct.DepFunType(_, _) => ??? + case rct.DepFunType(rct.NatKind, x: rct.NatIdentifier, t) => NatFunType(fromNamed(t, bound + x)) + case rct.DepFunType(rct.DataKind, x: rct.DataTypeIdentifier, t) => DataFunType(fromNamed(t, bound + x)) + case rct.DepFunType(_, _, _) => ??? case rct.TypePlaceholder | rct.TypeIdentifier(_) => throw new Exception(s"did not expect $t") }) @@ -129,10 +129,10 @@ object Type { case FunType(a, b) => rct.FunType(toNamed(a, bound), toNamed(b, bound)) case NatFunType(t) => val i = rct.NatIdentifier(s"n${bound.nat.size}", isExplicit = true) - rct.DepFunType[rct.NatKind, rct.Type](i, toNamed(t, bound + i)) + rct.DepFunType(rct.NatKind, i, toNamed(t, bound + i)) case DataFunType(t) => val i = rct.DataTypeIdentifier(s"n${bound.data.size}", isExplicit = true) - rct.DepFunType[rct.DataKind, rct.Type](i, toNamed(t, bound + i)) + rct.DepFunType(rct.DataKind, i, toNamed(t, bound + i)) } } @@ -150,7 +150,7 @@ object DataType { case rct.IndexType(s) => IndexType(Nat.fromNamed(s, bound)) case rct.PairType(dt1, dt2) => PairType(fromNamed(dt1, bound), fromNamed(dt2, bound)) case rct.ArrayType(s, et) => ArrayType(Nat.fromNamed(s, bound), fromNamed(et, bound)) - case _: rct.DepArrayType | _: rct.DepPairType[_] | + case _: rct.DepArrayType | _: rct.DepPairType[_, _] | _: rct.NatToDataApply | _: rct.FragmentType | _: rct.ManagedBufferType | _: rct.OpaqueType => throw new Exception(s"did not expect $dt") }) diff --git a/src/main/scala/rise/openCL/DSL.scala b/src/main/scala/rise/openCL/DSL.scala index c75f1da76..ccde849d2 100644 --- a/src/main/scala/rise/openCL/DSL.scala +++ b/src/main/scala/rise/openCL/DSL.scala @@ -2,7 +2,7 @@ package rise.openCL import rise.core.DSL._ import rise.core.{Expr, Primitive} -import rise.core.types.AddressSpaceKind +import rise.core.types.AddressSpace import shine.OpenCL.{GlobalSize, LocalSize} object DSL { @@ -32,17 +32,17 @@ object DSL { to: ToBeTyped[A], f: ToBeTyped[B] ): ToBeTyped[rise.core.Lambda] = fun(x => to(f(x))) - val toGlobal: ToBeTyped[rise.core.DepApp[AddressSpaceKind]] = toMem( + val toGlobal: ToBeTyped[rise.core.DepApp[AddressSpace]] = toMem( rise.core.types.AddressSpace.Global ) def toGlobalFun[T <: Expr](f: ToBeTyped[T]): ToBeTyped[rise.core.Lambda] = toFun(toGlobal, f) - val toLocal: ToBeTyped[rise.core.DepApp[AddressSpaceKind]] = toMem( + val toLocal: ToBeTyped[rise.core.DepApp[AddressSpace]] = toMem( rise.core.types.AddressSpace.Local ) def toLocalFun[T <: Expr](f: ToBeTyped[T]): ToBeTyped[rise.core.Lambda] = toFun(toLocal, f) - val toPrivate: ToBeTyped[rise.core.DepApp[AddressSpaceKind]] = toMem( + val toPrivate: ToBeTyped[rise.core.DepApp[AddressSpace]] = toMem( rise.core.types.AddressSpace.Private ) def toPrivateFun[T <: Expr](f: ToBeTyped[T]): ToBeTyped[rise.core.Lambda] = diff --git a/src/main/scala/shine/C/Compilation/CodeGenerator.scala b/src/main/scala/shine/C/Compilation/CodeGenerator.scala index 05eb775d0..fb6974832 100644 --- a/src/main/scala/shine/C/Compilation/CodeGenerator.scala +++ b/src/main/scala/shine/C/Compilation/CodeGenerator.scala @@ -120,7 +120,7 @@ class CodeGenerator(val decls: CodeGenerator.Declarations, case f@ForNat(unroll) => f.loopBody match { - case shine.DPIA.Phrases.DepLambda(i, p) => + case shine.DPIA.Phrases.DepLambda(NatKind, i, p) => CCodeGen.codeGenForNat(f.n, i, p, unroll, env) case _ => throw new Exception("This should not happen") } @@ -134,11 +134,11 @@ class CodeGenerator(val decls: CodeGenerator.Declarations, // on the fly beta-reduction case Apply(fun, arg) => Lifting.liftFunction(fun).reducing(arg) |> cmd(env) - case DepApply(fun, arg) => arg match { + case DepApply(kind, fun, arg) => arg match { case a: Nat => - Lifting.liftDependentFunction(fun.asInstanceOf[Phrase[NatKind `()->:` CommType]])(a) |> cmd(env) + Lifting.liftDependentFunction(fun.asInstanceOf[Phrase[NatIdentifier `()->:` CommType]])(a) |> cmd(env) case a: DataType => - Lifting.liftDependentFunction(fun.asInstanceOf[Phrase[DataKind `()->:` CommType]])(a) |> cmd(env) + Lifting.liftDependentFunction(fun.asInstanceOf[Phrase[NatIdentifier `()->:` CommType]])(a) |> cmd(env) } case DMatchI(x, inT, _, f, dPair) => @@ -171,7 +171,7 @@ class CodeGenerator(val decls: CodeGenerator.Declarations, C.AST.ArraySubscript(C.AST.Cast(C.AST.PointerType(C.AST.Type.u32), expr), C.AST.Literal("0") ) , fst))) }) - case Apply(_, _) | DepApply(_, _) | + case Apply(_, _) | DepApply(_, _, _) | _: CommandPrimitive => error(s"Don't know how to generate code for $phrase") }, env) @@ -289,7 +289,7 @@ class CodeGenerator(val decls: CodeGenerator.Declarations, case MkDPairSndAcc(_, _, a) => a |> acc(env, DPairSnd :: path, cont) - case phrase@(Apply(_, _) | DepApply(_, _) | + case phrase@(Apply(_, _) | DepApply(_, _, _) | Phrases.IfThenElse(_, _, _) | LetNat(_, _, _) | _: AccPrimitive) => error(s"Don't know how to generate code for $phrase") } @@ -517,7 +517,7 @@ class CodeGenerator(val decls: CodeGenerator.Declarations, case Proj1(pair) => SimplifyNats.simplifyIndexAndNatExp(Lifting.liftPair(pair)._1) |> exp(env, path, cont) case Proj2(pair) => SimplifyNats.simplifyIndexAndNatExp(Lifting.liftPair(pair)._2) |> exp(env, path, cont) - case phrase@(Apply(_, _) | DepApply(_, _) | + case phrase@(Apply(_, _) | DepApply(_, _, _) | Phrases.IfThenElse(_, _, _) | LetNat(_, _, _) | _: ExpPrimitive) => error(s"Don't know how to generate code for $phrase") } @@ -638,7 +638,7 @@ class CodeGenerator(val decls: CodeGenerator.Declarations, case None => error("Parameter missing") case Some(Left(param)) => generateInlinedCall(l(param), env, args.tail, cont) } - case ndl: DepLambda[NatKind, _]@unchecked => args.headOption match { + case ndl: DepLambda[Nat, NatIdentifier, _]@unchecked => args.headOption match { case Some(Right(nat)) => generateInlinedCall(ndl(nat), env, args.tail, cont) case None => error("Parameter missing") case Some(Left(_)) => error("Expression phrase argument passed but nat expected") diff --git a/src/main/scala/shine/C/Compilation/TranslationContext.scala b/src/main/scala/shine/C/Compilation/TranslationContext.scala index e7f69ed88..76e733db1 100644 --- a/src/main/scala/shine/C/Compilation/TranslationContext.scala +++ b/src/main/scala/shine/C/Compilation/TranslationContext.scala @@ -25,7 +25,7 @@ class TranslationContext() extends shine.DPIA.Compilation.TranslationContext { //TODO makes a decision. Not allowed! case DepArrayType(n, ft) => DepMapSeqI(unroll = false)(n, ft, ft, - depFun[NatKind]()(k => + depFun(NatKind)(k => λ(ExpType(ft(k), read))(x => λ(AccType( ft(k) ))(a => assign(ft(k), a, x) ))), rhs, lhs) diff --git a/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala b/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala index ed9f10d40..5063fc189 100644 --- a/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala +++ b/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala @@ -25,13 +25,13 @@ object AcceptorTranslation { E match { // on the fly beta-reduction case Apply(fun, arg) => acc(Lifting.liftFunction(fun).reducing(arg))(A) - case DepApply(fun, arg) => arg match { + case DepApply(kind, fun, arg) => arg match { case a: Nat => - acc(Lifting.liftDependentFunction[NatKind, ExpType]( - fun.asInstanceOf[ Phrase[NatKind `()->:` ExpType]])(a))(A) + acc(Lifting.liftDependentFunction( + fun.asInstanceOf[ Phrase[NatIdentifier `()->:` ExpType]])(a))(A) case a: DataType => - acc(Lifting.liftDependentFunction[DataKind, ExpType]( - fun.asInstanceOf[Phrase[DataKind `()->:` ExpType]])(a))(A) + acc(Lifting.liftDependentFunction( + fun.asInstanceOf[Phrase[DataTypeIdentifier `()->:` ExpType]])(a))(A) } case e @@ -103,7 +103,7 @@ object AcceptorTranslation { case depMapSeq@DepMapSeq(unroll) => val (n, ft1, ft2, f, array) = depMapSeq.unwrap con(array)(λ(expT(n`.d`ft1, read))(x => - DepMapSeqI(unroll)(n, ft1, ft2, _Λ_[NatKind]()((k: NatIdentifier) => + DepMapSeqI(unroll)(n, ft1, ft2, _Λ_(NatKind)((k: NatIdentifier) => λ(expT(ft1(k), read))(x => λ(accT(ft2(k)))(o => { acc(f(k)(x))(o) }))), x, A))) @@ -115,7 +115,7 @@ object AcceptorTranslation { // Turn the f imperative by means of forwarding the acceptor translation con(input)(λ(expT(DepPairType(x, elemT), read))(pair => DMatchI(x, elemT, outT, - _Λ_[NatKind]()((fst: NatIdentifier) => + _Λ_(NatKind)((fst: NatIdentifier) => λ(expT(DataType.substitute(fst, x, elemT), read))(snd => acc(f(fst)(snd))(A) )), pair))) @@ -127,7 +127,7 @@ object AcceptorTranslation { case Iterate(n, m, k, dt, f, array) => con(array)(λ(expT((m * n.pow(k))`.`dt, read))(x => IterateIAcc(n, m, k, dt, A, - _Λ_[NatKind]()(l => λ(accT(l `.` dt))(o => + _Λ_(NatKind)(l => λ(accT(l `.` dt))(o => λ(expT((l * n)`.`dt, read))(x => acc(f(l)(x))(o)))), x))) @@ -252,7 +252,7 @@ object AcceptorTranslation { // OpenMP case omp.DepMapPar(n, ft1, ft2, f, array) => con(array)(λ(expT(n`.d`ft1, read))(x => - ompI.DepMapParI(n, ft1, ft2, _Λ_[NatKind]()((k: NatIdentifier) => + ompI.DepMapParI(n, ft1, ft2, _Λ_(NatKind)((k: NatIdentifier) => λ(expT(ft1(k), read))(x => λ(accT(ft2(k)))(o => { acc(f(k)(x))(o) }))), x, A))) @@ -271,7 +271,7 @@ object AcceptorTranslation { case depMap@ocl.DepMap(level, dim) => val (n, ft1, ft2, f, array) = depMap.unwrap con(array)(λ(expT(n`.d`ft1, read))(x => - oclI.DepMapI(level, dim)(n, ft1, ft2, _Λ_[NatKind]()((k: NatIdentifier) => + oclI.DepMapI(level, dim)(n, ft1, ft2, _Λ_(NatKind)((k: NatIdentifier) => λ(expT(ft1(k), read))(x => λ(accT(ft2(k)))(o => { acc(f(k)(x))(o) }))), x, A))) @@ -279,7 +279,7 @@ object AcceptorTranslation { case ocl.Iterate(a, n, m, k, dt, f, array) => con(array)(λ(expT({m * n.pow(k)}`.`dt, read))(x => oclI.IterateIAcc(a, n, m, k, dt, A, - _Λ_[NatKind]()(l => λ(accT(l`.`dt))(o => + _Λ_(NatKind)(l => λ(accT(l`.`dt))(o => λ(expT({l * n}`.`dt, read))(x => acc(f(l)(x))(o)))), x))) diff --git a/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala b/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala index 1153d4997..3da1212a4 100644 --- a/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala +++ b/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala @@ -44,13 +44,13 @@ object ContinuationTranslation { // on the fly beta-reduction case Apply(fun, arg) => con(Lifting.liftFunction(fun).reducing(arg))(C) - case DepApply(fun, arg) => arg match { + case DepApply(kind, fun, arg) => arg match { case a: Nat => - con(Lifting.liftDependentFunction[NatKind, ExpType]( - fun.asInstanceOf[Phrase[NatKind `()->:` ExpType]])(a))(C) + con(Lifting.liftDependentFunction( + fun.asInstanceOf[Phrase[NatIdentifier `()->:` ExpType]])(a))(C) case a: DataType => - con(Lifting.liftDependentFunction[DataKind, ExpType]( - fun.asInstanceOf[Phrase[DataKind `()->:` ExpType]])(a))(C) + con(Lifting.liftDependentFunction( + fun.asInstanceOf[Phrase[DataTypeIdentifier `()->:` ExpType]])(a))(C) } case IfThenElse(cond, thenP, elseP) => @@ -110,7 +110,7 @@ object ContinuationTranslation { // Turn the f imperative by means of forwarding the continuation translation con(input)(λ(expT(DepPairType(x, elemT), read))(pair => DMatchI(x, elemT, outT, - _Λ_[NatKind]()((fst: NatIdentifier) => + _Λ_(NatKind)((fst: NatIdentifier) => λ(expT(DataType.substitute(fst, x, elemT), read))(snd => con(f(fst)(snd))(C) )), pair))) diff --git a/src/main/scala/shine/DPIA/Compilation/FedeTranslation.scala b/src/main/scala/shine/DPIA/Compilation/FedeTranslation.scala index 018bedb6f..90037bc77 100644 --- a/src/main/scala/shine/DPIA/Compilation/FedeTranslation.scala +++ b/src/main/scala/shine/DPIA/Compilation/FedeTranslation.scala @@ -26,13 +26,13 @@ object FedeTranslation { // on the fly beta-reduction case Apply(fun, arg) => fedAcc(env)( Lifting.liftFunction(fun).reducing(arg))(C) - case DepApply(fun, arg) => arg match { + case DepApply(kind, fun, arg) => arg match { case a: Nat => fedAcc(env)( - Lifting.liftDependentFunction[NatKind, ExpType]( - fun.asInstanceOf[Phrase[NatKind `()->:` ExpType]])(a))(C) + Lifting.liftDependentFunction( + fun.asInstanceOf[Phrase[NatIdentifier `()->:` ExpType]])(a))(C) case a: DataType => fedAcc(env)( - Lifting.liftDependentFunction[DataKind, ExpType]( - fun.asInstanceOf[Phrase[DataKind `()->:` ExpType]])(a))(C) + Lifting.liftDependentFunction( + fun.asInstanceOf[Phrase[DataTypeIdentifier `()->:` ExpType]])(a))(C) } case IfThenElse(cond, thenP, elseP) => ??? diff --git a/src/main/scala/shine/DPIA/Compilation/FunDef.scala b/src/main/scala/shine/DPIA/Compilation/FunDef.scala index be302dd20..4cd5065b0 100644 --- a/src/main/scala/shine/DPIA/Compilation/FunDef.scala +++ b/src/main/scala/shine/DPIA/Compilation/FunDef.scala @@ -26,11 +26,11 @@ class FunDef(val name: String, ) = p match { case Apply(f, a) => splitBodyAndParams(Lifting.liftFunction(f).reducing(a), ps, defs) - case DepApply(f, a) => + case DepApply(_, f, a) => splitBodyAndParams(Lifting.liftDependentFunction(f)(a), ps, defs) case l: Lambda[ExpType, _]@unchecked => splitBodyAndParams(l.body, l.param +: ps, defs) - case ndl: DepLambda[_, _] => + case ndl: DepLambda[_, _, _] => splitBodyAndParams(ndl.body, Identifier(ndl.x.name, ExpType(int, read)) +: ps, defs) case ln:LetNat[ExpType, _]@unchecked => diff --git a/src/main/scala/shine/DPIA/Compilation/Passes/UnrollLoops.scala b/src/main/scala/shine/DPIA/Compilation/Passes/UnrollLoops.scala index e002382ab..4f3dbbb93 100644 --- a/src/main/scala/shine/DPIA/Compilation/Passes/UnrollLoops.scala +++ b/src/main/scala/shine/DPIA/Compilation/Passes/UnrollLoops.scala @@ -25,7 +25,7 @@ object UnrollLoops { } case f@ForNat(true) => f.loopBody match { - case shine.DPIA.Phrases.DepLambda(x, body) => + case shine.DPIA.Phrases.DepLambda(kind, x, body) => Continue(unrollLoop(f.n, init = 0, step = 1, i => PhraseType.substitute(i, `for` = x, in = body)), this) case _ => throw new Exception("This should not happen") diff --git a/src/main/scala/shine/DPIA/Compilation/StreamTranslation.scala b/src/main/scala/shine/DPIA/Compilation/StreamTranslation.scala index d6cbf4ab7..9666c08ad 100644 --- a/src/main/scala/shine/DPIA/Compilation/StreamTranslation.scala +++ b/src/main/scala/shine/DPIA/Compilation/StreamTranslation.scala @@ -23,14 +23,14 @@ object StreamTranslation { // on the fly beta-reduction case Apply(fun, arg) => str(Lifting.liftFunction(fun).reducing(arg))(C) - case DepApply(fun, arg) => arg match { + case DepApply(_, fun, arg) => arg match { case a: Nat => str( - Lifting.liftDependentFunction[NatKind, ExpType]( - fun.asInstanceOf[Phrase[NatKind `()->:` ExpType]])(a) + Lifting.liftDependentFunction( + fun.asInstanceOf[Phrase[NatIdentifier `()->:` ExpType]])(a) )(C) case a: DataType => str( - Lifting.liftDependentFunction[DataKind, ExpType]( - fun.asInstanceOf[Phrase[DataKind `()->:` ExpType]])(a) + Lifting.liftDependentFunction( + fun.asInstanceOf[Phrase[DataTypeIdentifier `()->:` ExpType]])(a) )(C) } @@ -82,9 +82,9 @@ object StreamTranslation { (expT(dt2, read) ->: (comm: CommType)) ->: (comm: CommType) )(next2 => C(nFun(i => fun(expT(dt1 x dt2, read) ->: (comm: CommType))(k => - Apply(DepApply[NatKind, (ExpType ->: CommType) ->: CommType](next1, i), + Apply(DepApply(NatKind, next1, i), fun(expT(dt1, read))(x1 => - Apply(DepApply[NatKind, (ExpType ->: CommType) ->: CommType](next2, i), + Apply(DepApply(NatKind, next2, i), fun(expT(dt2, read))(x2 => k(MakePair(dt1, dt2, read, x1, x2)) ))))), diff --git a/src/main/scala/shine/DPIA/DSL/Core.scala b/src/main/scala/shine/DPIA/DSL/Core.scala index daaae99b6..7a23c352a 100644 --- a/src/main/scala/shine/DPIA/DSL/Core.scala +++ b/src/main/scala/shine/DPIA/DSL/Core.scala @@ -26,21 +26,19 @@ object λ extends funDef object nFun { def apply[T <: PhraseType](f: NatIdentifier => Phrase[T], - range: arithexpr.arithmetic.Range): DepLambda[NatKind, T] = { + range: arithexpr.arithmetic.Range): DepLambda[Nat, NatIdentifier, T] = { val x = NatIdentifier(freshName("n"), range) - DepLambda[NatKind, T](x, f(x)) + DepLambda(NatKind, x, f(x)) } } trait depFunDef { - def apply[K <: Kind](): Object { - def apply[T <: PhraseType](f: K#I => Phrase[T]) - (implicit w: Kind.IdentifierMaker[K], kn: KindName[K]): DepLambda[K, T] + def apply[T, I <: Kind.Identifier](kind: Kind[T, I]): Object { + def apply[U <: PhraseType](f: I => Phrase[U]): DepLambda[T, I, U] } = new { - def apply[T <: PhraseType](f: K#I => Phrase[T]) - (implicit w: Kind.IdentifierMaker[K], kn: KindName[K]): DepLambda[K, T] = { - val x = w.makeIdentifier() - DepLambda(x, f(x)) + def apply[U <: PhraseType](f: I => Phrase[U]): DepLambda[T, I, U] = { + val x = kind.makeIdentifier + DepLambda(kind, x, f(x)) } } } diff --git a/src/main/scala/shine/DPIA/DSL/ImperativePrimitives.scala b/src/main/scala/shine/DPIA/DSL/ImperativePrimitives.scala index 2d63673f2..2b005c03f 100644 --- a/src/main/scala/shine/DPIA/DSL/ImperativePrimitives.scala +++ b/src/main/scala/shine/DPIA/DSL/ImperativePrimitives.scala @@ -76,7 +76,7 @@ object streamNext { f: Phrase[ExpType ->: CommType] ): Phrase[CommType] = { Phrases.Apply( - Phrases.DepApply[NatKind, (ExpType ->: CommType) ->: CommType](next, i), + Phrases.DepApply(NatKind, next, i), f ) } diff --git a/src/main/scala/shine/DPIA/DSL/package.scala b/src/main/scala/shine/DPIA/DSL/package.scala index 40e1dc323..3bf6c8bde 100644 --- a/src/main/scala/shine/DPIA/DSL/package.scala +++ b/src/main/scala/shine/DPIA/DSL/package.scala @@ -97,14 +97,14 @@ package object DSL { implicit class CallNatDependentLambda[T <: PhraseType](fun: Phrase[`(nat)->:`[T]]) { def apply(arg: Nat): Phrase[T] = - Lifting.liftDependentFunction[NatKind, T](fun)(arg) + Lifting.liftDependentFunction(fun)(arg) def $(arg: Nat): Phrase[T] = apply(arg) } implicit class CallTypeDependentLambda[T <: PhraseType](fun: Phrase[`(dt)->:`[T]]) { def apply(arg: DataType): Phrase[T] = - Lifting.liftDependentFunction[DataKind, T](fun)(arg) + Lifting.liftDependentFunction(fun)(arg) def $(arg: DataType): Phrase[T] = apply(arg) } diff --git a/src/main/scala/shine/DPIA/InferAccessAnnotation.scala b/src/main/scala/shine/DPIA/InferAccessAnnotation.scala index 20a797770..a844c3504 100644 --- a/src/main/scala/shine/DPIA/InferAccessAnnotation.scala +++ b/src/main/scala/shine/DPIA/InferAccessAnnotation.scala @@ -37,7 +37,7 @@ private class InferAccessAnnotation { @tailrec private def funOutIsWrite(ePt: PhraseType): Boolean = ePt match { - case DepFunType(_, t) => funOutIsWrite(t) + case DepFunType(_, _, t) => funOutIsWrite(t) case FunType(_, t) => funOutIsWrite(t) case expT: ExpType => expT `<=` ExpType(expT.dataType, write) case _ => throw error("This should never happen.") @@ -128,7 +128,7 @@ private class InferAccessAnnotation { ctx, isKernelParamFun) case appl: r.App => inferApp(appl, ctx, addsKernelParam(e, isKernelParamFun)) - case depL: r.DepLambda[_] => + case depL: r.DepLambda[_, _] => inferDepLambda(depL, ctx, isKernelParamFun) case depA: r.DepApp[_] => inferDepApp(depA, ctx, addsKernelParam(e, isKernelParamFun)) @@ -199,7 +199,7 @@ private class InferAccessAnnotation { } private def inferDepLambda( - depLambda: r.DepLambda[_], + depLambda: r.DepLambda[_, _], ctx: Context, kernelParamFun: Boolean ): (PhraseType, Subst) = { @@ -207,16 +207,15 @@ private class InferAccessAnnotation { val depLambdaType = depLambda.x match { case n: rt.NatIdentifier => - DepFunType[NatKind, PhraseType](natIdentifier(n), eType) + DepFunType(NatKind, natIdentifier(n), eType) case dt: rt.DataTypeIdentifier => - DepFunType[DataKind, PhraseType](dataTypeIdentifier(dt), eType) + DepFunType(DataKind, dataTypeIdentifier(dt), eType) case ad: rt.AddressSpaceIdentifier => - DepFunType[AddressSpaceKind, PhraseType]( - addressSpaceIdentifier(ad), eType) + DepFunType(AddressSpaceKind, addressSpaceIdentifier(ad), eType) case n2n: rt.NatToNatIdentifier => - DepFunType[NatToNatKind, PhraseType](natToNatIdentifier(n2n), eType) + DepFunType(NatToNatKind, natToNatIdentifier(n2n), eType) case n2d: rt.NatToDataIdentifier => - DepFunType[NatToDataKind, PhraseType](natToDataIdentifier(n2d), eType) + DepFunType(NatToDataKind, natToDataIdentifier(n2d), eType) } ptAnnotationMap.put(depLambda, depLambdaType) (depLambdaType, eSubst) @@ -230,17 +229,17 @@ private class InferAccessAnnotation { val (fType, fSubst) = inferPhraseTypes(depApp.f, ctx, kernelParamFun) val depAppType = depApp.x match { - case dt: rt.DataKind#T => - Lifting.liftDependentFunctionType[DataKind](fType)(dataType(dt)) - case addr: rt.AddressSpaceKind#T => - Lifting.liftDependentFunctionType[AddressSpaceKind](fType)( + case dt: rt.DataType => + Lifting.liftDependentFunctionType[DataType](fType)(dataType(dt)) + case addr: rt.AddressSpace => + Lifting.liftDependentFunctionType[AddressSpace](fType)( addressSpace(addr)) - case n: rt.NatKind#T => - Lifting.liftDependentFunctionType[NatKind](fType)(n) - case n2n: rt.NatToNatKind#T => - Lifting.liftDependentFunctionType[NatToNatKind](fType)(ntn(n2n)) - case n2d: rt.NatToDataKind#T => - Lifting.liftDependentFunctionType[NatToDataKind](fType)(ntd(n2d)) + case n: rt.Nat => + Lifting.liftDependentFunctionType[Nat](fType)(n) + case n2n: rt.NatToNat => + Lifting.liftDependentFunctionType[NatToNat](fType)(ntn(n2n)) + case n2d: rt.NatToData => + Lifting.liftDependentFunctionType[NatToData](fType)(ntd(n2d)) } ptAnnotationMap.put(depApp, depAppType) (depAppType, fSubst) @@ -538,7 +537,7 @@ private class InferAccessAnnotation { expT(dataType(dt), read) case rt.FunType(in: rt.DataType, out) => expT(in, read) ->: buildType(out) - case rt.DepFunType(d: rt.DataTypeIdentifier, t) => + case rt.DepFunType(rt.DataKind, d: rt.DataTypeIdentifier, t) => dFunT(d, buildType(t)) case _ => throw Exception("This should not happen") } @@ -555,9 +554,9 @@ private class InferAccessAnnotation { case rp.depMapSeq() => def buildType(t: rt.Type): PhraseType = t match { - case rt.FunType(rt.DepFunType(i, rt.FunType(elemInT:rt.DataType, elemOutT:rt.DataType)), + case rt.FunType(rt.DepFunType(rt.NatKind, i: rt.NatIdentifier, rt.FunType(elemInT:rt.DataType, elemOutT:rt.DataType)), rt.FunType(inArr@rt.DepArrayType(_, _), outArr@rt.DepArrayType(_, _))) => - val iNat = natIdentifier(i.asInstanceOf[rt.NatIdentifier]) + val iNat = natIdentifier(i) nFunT(iNat, expT(dataType(elemInT), read) ->: expT(dataType(elemOutT), write)) ->: expT(dataType(inArr), read) ->: expT(dataType(outArr), write) case _ => error("did not expect t") @@ -567,31 +566,24 @@ private class InferAccessAnnotation { case rp.dmatch() => val a = accessTypeIdentifier() def buildType(t: rt.Type): PhraseType = t match { - case rt.FunType(rt.DepPairType(x, elemT), - rt.FunType(rt.DepFunType(i, rt.FunType(app1:rt.DataType, outT:rt.DataType)), retT:rt.DataType)) => - x match { - case x:rt.NatIdentifier => - assert(i.isInstanceOf[rt.NatIdentifier]) - val i_ = natIdentifier(i.asInstanceOf[rt.NatIdentifier]) - expT(DepPairType(natIdentifier(x), dataType(elemT)), read) ->: - nFunT(i_, expT(dataType(app1), read) ->: expT(dataType(outT), a)) ->: - expT(dataType(retT), a) - case _ => ??? - } + case rt.FunType(rt.DepPairType(rt.NatKind, x: rt.NatIdentifier, elemT), + rt.FunType(rt.DepFunType(rt.NatKind, i: rt.NatIdentifier, + rt.FunType(app1:rt.DataType, outT:rt.DataType)), retT:rt.DataType)) => + + val i_ = natIdentifier(i.asInstanceOf[rt.NatIdentifier]) + expT(DepPairType(natIdentifier(x), dataType(elemT)), read) ->: + nFunT(i_, expT(dataType(app1), read) ->: expT(dataType(outT), a)) ->: + expT(dataType(retT), a) case _ => error(s"did not expect t") } buildType(p.t) case rp.makeDepPair() => def buildType(t: rt.Type): PhraseType = t match { - case rt.DepFunType(fst, rt.FunType(sndT:rt.DataType, outT:rt.DataType)) => + case rt.DepFunType(rt.NatKind, fst: rt.NatIdentifier, rt.FunType(sndT:rt.DataType, outT:rt.DataType)) => val a1 = accessTypeIdentifier() - fst match { - case fst:rt.NatIdentifier => - val fst_ = natIdentifier(fst) - nFunT(fst_, expT(dataType(sndT), a1) ->: expT(dataType(outT), a1)) - case _ => ??? - } + val fst_ = natIdentifier(fst) + nFunT(fst_, expT(dataType(sndT), a1) ->: expT(dataType(outT), a1)) case _ => error(s"did not expect $t") } @@ -641,7 +633,7 @@ private class InferAccessAnnotation { ): Boolean = if (kernelParamFun) expr.t match { - case _: rt.FunType[_, _] | _: rt.DepFunType[_, _] => true + case _: rt.FunType[_, _] | _: rt.DepFunType[_, _, _] => true case _ => false } else false @@ -663,7 +655,7 @@ private class InferAccessAnnotation { Success(outSubst(argSubst)) ) ) - case (DepFunType(lx, la), DepFunType(rx, ra)) if lx == rx => + case (DepFunType(_, lx, la), DepFunType(_, rx, ra)) if lx == rx => subUnifyPhraseType(la, ra) case _ => Try(error(s"Cannot subunify $less and $larger.")) } @@ -671,7 +663,7 @@ private class InferAccessAnnotation { def `type`(ty: rt.Type): PhraseType = ty match { case dt: rt.DataType => ExpType(dataType(dt), accessTypeIdentifier()) case rt.FunType(i, o) => `type`(i) ->: `type`(o) - case rt.DepFunType(i, t) => i match { + case rt.DepFunType(_, i, t) => i match { case dt: rt.DataTypeIdentifier => dataTypeIdentifier(dt) ->: `type`(t) case n: rt.NatIdentifier => @@ -689,7 +681,7 @@ private class InferAccessAnnotation { case (rt.FunType(inT, outT), FunType(inPT, outPT)) => checkConsistency(inT, inPT) checkConsistency(outT, outPT) - case (rt.DepFunType(x, t), DepFunType(y, pt)) => + case (rt.DepFunType(_, x, t), DepFunType(_, y, pt)) => if (x.name != y.name) error(s"Identifiers $x and $y differ") checkConsistency(t, pt) case (dt: rt.DataType, ExpType(dpt: DataType, _)) => diff --git a/src/main/scala/shine/DPIA/Lifting.scala b/src/main/scala/shine/DPIA/Lifting.scala index 0dd5c901d..78e717c0e 100644 --- a/src/main/scala/shine/DPIA/Lifting.scala +++ b/src/main/scala/shine/DPIA/Lifting.scala @@ -8,20 +8,20 @@ import scala.language.{postfixOps, reflectiveCalls} object Lifting { import rise.core.lifting.{Expanding, Reducing, Result} - def liftDependentFunction[K <: Kind, T <: PhraseType](p: Phrase[K `()->:` T]): K#T => Phrase[T] = { + def liftDependentFunction[T, I <: Kind.Identifier, U <: PhraseType](p: Phrase[I `()->:` U]): T => Phrase[U] = { p match { - case l: DepLambda[K, T] => - (arg: K#T) => PhraseType.substitute(arg, `for`=l.x, in=l.body) - case app: Apply[_, K `()->:` T] => + case l: DepLambda[T, I, U]@unchecked => + (arg: T) => PhraseType.substitute[T, I, U](l.kind, arg, `for`=l.x, in=l.body) + case app: Apply[_, I `()->:` U] => val fun = liftFunction(app.fun).reducing liftDependentFunction(fun(app.arg)) - case DepApply(f, arg) => + case DepApply(_, f, arg) => val fun = liftDependentFunction(f) liftDependentFunction(fun(arg)) - case p1: Proj1[K `()->:` T, b] => + case p1: Proj1[I `()->:` U, b] => val pair = liftPair(p1.pair) liftDependentFunction(pair._1) - case p2: Proj2[a, K `()->:` T] => + case p2: Proj2[a, I `()->:` U] => val pair = liftPair(p2.pair) liftDependentFunction(pair._2) case Identifier(_, _) | IfThenElse(_, _, _) | LetNat(_, _, _) => @@ -38,7 +38,7 @@ object Lifting { Reducing((arg: Phrase[T1]) => l.body `[` arg `/` l.param `]`) case app: Apply[_, T1 ->: T2] => chain(liftFunction(app.fun).map(lf => lf(app.arg))) - case DepApply(f, arg) => + case DepApply(_, f, arg) => val fun = liftDependentFunction(f) liftFunction(fun(arg)) case p1: Proj1[T1 ->: T2, b] => @@ -59,7 +59,7 @@ object Lifting { case app: Apply[_, ExpType ->: T] => val fun = liftFunction(app.fun).reducing liftFunctionToNatLambda(fun(app.arg)) - case DepApply(f, arg) => + case DepApply(_, f, arg) => val fun = liftDependentFunction(f) liftFunctionToNatLambda(fun(arg)) case p1: Proj1[ExpType ->: T, b] => @@ -82,7 +82,7 @@ object Lifting { case app: Apply[_, T1 x T2] => val fun = liftFunction(app.fun).reducing liftPair(fun(app.arg)) - case DepApply(f, arg) => + case DepApply(_, f, arg) => val fun = liftDependentFunction(f) liftPair(fun(arg)) case p1: Proj1[T1 x T2, b] => @@ -96,10 +96,10 @@ object Lifting { } } - def liftDependentFunctionType[K <: Kind](ty: PhraseType): K#T => PhraseType = + def liftDependentFunctionType[T](ty: PhraseType): T => PhraseType = ty match { - case DepFunType(x, t) => - (a: K#T) => PhraseType.substitute(a, x, t) + case DepFunType(kind, x, t) => + (a: T) => PhraseType.substitute(kind, a, x, t) case _ => throw new Exception(s"did not expect $ty") } } diff --git a/src/main/scala/shine/DPIA/Phrases/Phrase.scala b/src/main/scala/shine/DPIA/Phrases/Phrase.scala index 57dcfd0fb..f2156f73e 100644 --- a/src/main/scala/shine/DPIA/Phrases/Phrase.scala +++ b/src/main/scala/shine/DPIA/Phrases/Phrase.scala @@ -38,27 +38,24 @@ final case class Apply[T1 <: PhraseType, T2 <: PhraseType](fun: Phrase[T1 ->: T2 override def toString: String = s"($fun $arg)" } -final case class DepLambda[K <: Kind, T <: PhraseType](x: K#I, body: Phrase[T]) - (implicit val kn: KindName[K]) - extends Phrase[K `()->:` T] { - override val t: DepFunType[K, T] = DepFunType[K, T](x, body.t) - override def toString: String = s"Λ(${x.name} : ${kn.get}). $body" +final case class DepLambda[T, I <: Kind.Identifier, U <: PhraseType](kind: Kind[T, I], x: I, body: Phrase[U]) + extends Phrase[I `()->:` U] { + override val t: DepFunType[I, U] = DepFunType[I, U](kind, x, body.t) + override def toString: String = s"Λ(${x.name} : ${kind.name}). $body" } object DepLambda { - def apply[K <: Kind](x: K#I): Object { - def apply[T <: PhraseType](body: Phrase[T]) - (implicit kn: KindName[K]): DepLambda[K, T] + def apply[T, I <: Kind.Identifier](kind: Kind[T, I], x: I): Object { + def apply[U <: PhraseType](body: Phrase[U]): DepLambda[T, I, U] } = new { - def apply[T <: PhraseType](body: Phrase[T]) - (implicit kn: KindName[K]): DepLambda[K, T] = DepLambda(x, body) + def apply[U <: PhraseType](body: Phrase[U]): DepLambda[T, I, U] = DepLambda(kind, x, body) } } -final case class DepApply[K <: Kind, T <: PhraseType](fun: Phrase[K `()->:` T], arg: K#T) - extends Phrase[T] { +final case class DepApply[T, I <: Kind.Identifier, U <: PhraseType](kind: Kind[T, I], fun: Phrase[I `()->:` U], arg: T) + extends Phrase[U] { - override val t: T = PhraseType.substitute(arg, `for`=fun.t.x, in=fun.t.t).asInstanceOf[T] + override val t: U = PhraseType.substitute(kind, arg, `for`=fun.t.x, in=fun.t.t).asInstanceOf[U] override def toString: String = s"($fun $arg)" } @@ -144,7 +141,7 @@ object Phrase { case l @ Lambda(x, _) => val newMap = idMap + (x.name -> freshName(x.name.takeWhile(_.isLetter))) Continue(l, Renaming(newMap)) - case dl @ DepLambda(x, _) => + case dl @ DepLambda(_, x, _) => val newMap = idMap + (x.name -> freshName(x.name.takeWhile(_.isLetter))) Continue(dl, Renaming(newMap)) case _ => Continue(p, this) @@ -249,9 +246,9 @@ object Phrase { // NatData is Natural // IndexData is AsIndex } - case DepApply(fun, arg) => (fun, arg) match { + case DepApply(_, fun, arg) => (fun, arg) match { case (f, a: Nat) => - transientNatFromExpr(liftDependentFunction[NatKind, ExpType](f.asInstanceOf[Phrase[NatKind `()->:` ExpType]])(a)) + transientNatFromExpr(liftDependentFunction(f.asInstanceOf[Phrase[NatIdentifier `()->:` ExpType]])(a)) case _ => ??? } case Proj1(pair) => transientNatFromExpr(liftPair(pair)._1) diff --git a/src/main/scala/shine/DPIA/Phrases/PrettyPhrasePrinter.scala b/src/main/scala/shine/DPIA/Phrases/PrettyPhrasePrinter.scala index 1de17598d..87298c1b6 100644 --- a/src/main/scala/shine/DPIA/Phrases/PrettyPhrasePrinter.scala +++ b/src/main/scala/shine/DPIA/Phrases/PrettyPhrasePrinter.scala @@ -8,7 +8,7 @@ object PrettyPhrasePrinter { p match { case app: Apply[a, T] => s"(${apply(app.fun)})(${apply(app.arg)})" - case app: DepApply[_, T] => s"(${apply(app.fun)})(${app.arg})" + case app: DepApply[_, _, T] => s"(${apply(app.fun)})(${app.arg})" case p1: Proj1[a, b] => s"π1(${apply(p1.pair)})" @@ -25,7 +25,7 @@ object PrettyPhrasePrinter { case Lambda(param, body) => s"λ ${apply(param)}: ${param.t} -> ${apply(body)}" - case dl @ DepLambda(param, body) => s"Λ (${param.name}: ${dl.kn.get}) -> ${apply(body)}" + case DepLambda(kind, param, body) => s"Λ (${param.name}: ${kind.name}) -> ${apply(body)}" case LetNat(binder, defn, body) => s"nLet ${binder.name} = ${apply(defn)} in ${apply(body)}" diff --git a/src/main/scala/shine/DPIA/Phrases/VisitAndRebuild.scala b/src/main/scala/shine/DPIA/Phrases/VisitAndRebuild.scala index e07cc7937..003a255df 100644 --- a/src/main/scala/shine/DPIA/Phrases/VisitAndRebuild.scala +++ b/src/main/scala/shine/DPIA/Phrases/VisitAndRebuild.scala @@ -52,59 +52,59 @@ object VisitAndRebuild { case Apply(p, q) => Apply(apply(p, v), apply(q, v)) - case DepLambda(a, p) => a match { + case DepLambda(_, a, p) => a match { case n: NatIdentifier => - DepLambda[NatKind, PhraseType]( + DepLambda(NatKind, NatIdentifier( v.nat(n).asInstanceOf[arithexpr.arithmetic.NamedVar].name), apply(p, v)) case dt: DataTypeIdentifier => - DepLambda[DataKind, PhraseType]( + DepLambda(DataKind, v.data(dt).asInstanceOf[DataTypeIdentifier], apply(p, v)) case ad: AddressSpaceIdentifier => - DepLambda[AddressSpaceKind, PhraseType]( + DepLambda(AddressSpaceKind, v.addressSpace(ad).asInstanceOf[AddressSpaceIdentifier], apply(p, v)) case ac: AccessTypeIdentifier => - DepLambda[AccessKind, PhraseType]( + DepLambda(AccessKind, v.access(ac).asInstanceOf[AccessTypeIdentifier], apply(p, v)) case n2n: NatToNatIdentifier => - DepLambda[NatToNatKind, PhraseType]( + DepLambda(NatToNatKind, v.natToNat(n2n).asInstanceOf[NatToNatIdentifier], apply(p, v)) case n2d: NatToDataIdentifier => - DepLambda[NatToDataKind, PhraseType]( + DepLambda(NatToDataKind, v.natToData(n2d).asInstanceOf[NatToDataIdentifier], apply(p, v)) case _ => ??? } - case DepApply(p, a) => a match { + case DepApply(_, p, a) => a match { case n: Nat => - DepApply[NatKind, T]( - apply(p, v).asInstanceOf[Phrase[NatKind `()->:` T]], + DepApply(NatKind, + apply(p, v).asInstanceOf[Phrase[NatIdentifier `()->:` T]], v.nat(n)) case dt: DataType => - DepApply[DataKind, T]( - apply(p, v).asInstanceOf[Phrase[DataKind `()->:` T]], + DepApply(DataKind, + apply(p, v).asInstanceOf[Phrase[DataTypeIdentifier `()->:` T]], visitDataTypeAndRebuild(dt, v)) case ad: AddressSpace => - DepApply[AddressSpaceKind, T]( - apply(p, v).asInstanceOf[Phrase[AddressSpaceKind `()->:` T]], + DepApply(AddressSpaceKind, + apply(p, v).asInstanceOf[Phrase[AddressSpaceIdentifier `()->:` T]], v.addressSpace(ad)) case ac: AccessType => - DepApply[AccessKind, T]( - apply(p, v).asInstanceOf[Phrase[AccessKind `()->:` T]], + DepApply(AccessKind, + apply(p, v).asInstanceOf[Phrase[AccessTypeIdentifier `()->:` T]], v.access(ac)) case n2n: NatToNat => - DepApply[NatToNatKind, T]( - apply(p, v).asInstanceOf[Phrase[NatToNatKind `()->:` T]], + DepApply(NatToNatKind, + apply(p, v).asInstanceOf[Phrase[NatToNatIdentifier `()->:` T]], v.natToNat(n2n)) case n2d: NatToData => - DepApply[NatToDataKind, T]( - apply(p, v).asInstanceOf[Phrase[NatToDataKind `()->:` T]], + DepApply(NatToDataKind, + apply(p, v).asInstanceOf[Phrase[NatToDataIdentifier `()->:` T]], v.natToData(n2d)) case ph: PhraseType => ??? } @@ -147,30 +147,30 @@ object VisitAndRebuild { visitPhraseTypeAndRebuild(inT, v), visitPhraseTypeAndRebuild(outT, v)) case PassiveFunType(inT, outT) => PassiveFunType( visitPhraseTypeAndRebuild(inT, v), visitPhraseTypeAndRebuild(outT, v)) - case DepFunType(x, t) => x match { + case DepFunType(_, x, t) => x match { case n: NatIdentifier => - DepFunType[NatKind, PhraseType]( + DepFunType(NatKind, NatIdentifier( v.nat(n).asInstanceOf[arithexpr.arithmetic.NamedVar].name), visitPhraseTypeAndRebuild(t, v)) case dt: DataTypeIdentifier => - DepFunType[DataKind, PhraseType]( + DepFunType(DataKind, v.data(dt).asInstanceOf[DataTypeIdentifier], visitPhraseTypeAndRebuild(t, v)) case ad: AddressSpaceIdentifier => - DepFunType[AddressSpaceKind, PhraseType]( + DepFunType(AddressSpaceKind, v.addressSpace(ad).asInstanceOf[AddressSpaceIdentifier], visitPhraseTypeAndRebuild(t, v)) case ac: AccessTypeIdentifier => - DepFunType[AccessKind, PhraseType]( + DepFunType(AccessKind, v.access(ac).asInstanceOf[AccessTypeIdentifier], visitPhraseTypeAndRebuild(t, v)) case n2n: NatToNatIdentifier => - DepFunType[NatToNatKind, PhraseType]( + DepFunType(NatToNatKind, v.natToNat(n2n).asInstanceOf[NatToNatIdentifier], visitPhraseTypeAndRebuild(t, v)) case n2d: NatToDataIdentifier => - DepFunType[NatToDataKind, PhraseType]( + DepFunType(NatToDataKind, v.natToData(n2d).asInstanceOf[NatToDataIdentifier], visitPhraseTypeAndRebuild(t, v)) } diff --git a/src/main/scala/shine/DPIA/Types/Kind.scala b/src/main/scala/shine/DPIA/Types/Kind.scala index b83f67a8d..47bbf896c 100644 --- a/src/main/scala/shine/DPIA/Types/Kind.scala +++ b/src/main/scala/shine/DPIA/Types/Kind.scala @@ -3,101 +3,48 @@ package shine.DPIA.Types import shine.DPIA import shine.DPIA.NatIdentifier -sealed trait Kind { - type T - type I <: Kind.Identifier +sealed trait Kind[+T, +I <: Kind.Identifier] { + def name: String + def makeIdentifier: I } object Kind { trait Identifier { def name: String } - - trait IdentifierMaker[K <: Kind] { - def makeIdentifier(): K#I - } - - implicit object DataTypeIdentifierMaker - extends IdentifierMaker[DataKind] { - override def makeIdentifier(): DataTypeIdentifier = - DataTypeIdentifier(DPIA.freshName("dt")) - } - implicit object NatIdentifierMaker - extends IdentifierMaker[NatKind] { - override def makeIdentifier(): NatIdentifier = - NatIdentifier(DPIA.freshName("n")) - } - implicit object AddrIdentifierMaker - extends IdentifierMaker[AddressSpaceKind] { - override def makeIdentifier(): AddressSpaceIdentifier = - AddressSpaceIdentifier(DPIA.freshName("addr")) - } - implicit object AccessTypeIdentifierMaker - extends IdentifierMaker[AccessKind] { - override def makeIdentifier(): AccessTypeIdentifier = - AccessTypeIdentifier(DPIA.freshName("access")) - } -} - -sealed trait PhraseKind extends Kind { - override type T = PhraseType } -sealed trait DataKind extends Kind { - override type T = DataType - override type I = DataTypeIdentifier +case object PhraseKind extends Kind[PhraseType, Kind.Identifier] { + override def name: String = "phrase" + override def makeIdentifier: Kind.Identifier = ??? } -sealed trait NatKind extends Kind { - override type T = DPIA.Nat - override type I = DPIA.NatIdentifier +case object DataKind extends Kind[DataType, DataTypeIdentifier] { + override def name: String = "data" + override def makeIdentifier: DataTypeIdentifier = DataTypeIdentifier(DPIA.freshName("dt")) } -sealed trait AddressSpaceKind extends Kind { - override type T = AddressSpace - override type I = AddressSpaceIdentifier +case object NatKind extends Kind[DPIA.Nat, DPIA.NatIdentifier] { + override def name: String = "nat" + override def makeIdentifier: NatIdentifier = NatIdentifier(DPIA.freshName("n")) } -sealed trait AccessKind extends Kind { - override type T = AccessType - override type I = AccessTypeIdentifier +case object AddressSpaceKind extends Kind[AddressSpace, AddressSpaceIdentifier] { + override def name: String = "addressSpace" + override def makeIdentifier: AddressSpaceIdentifier = AddressSpaceIdentifier(DPIA.freshName("addr")) } -sealed trait NatToNatKind extends Kind { - override type T = NatToNat - override type I = NatToNatIdentifier +case object AccessKind extends Kind[AccessType, AccessTypeIdentifier] { + override def name: String = "access" + override def makeIdentifier: AccessTypeIdentifier = AccessTypeIdentifier(DPIA.freshName("access")) } -sealed trait NatToDataKind extends Kind { - override type T = NatToData - override type I = NatToDataIdentifier +case object NatToNatKind extends Kind[NatToNat, NatToNatIdentifier] { + override def name: String = "nat->nat" + override def makeIdentifier: NatToNatIdentifier = NatToNatIdentifier(DPIA.freshName("n2n")) } -trait KindName[K <: Kind] { - def get: String +case object NatToDataKind extends Kind[NatToData, NatToDataIdentifier] { + override def name: String = "nat->data" + override def makeIdentifier: NatToDataIdentifier = NatToDataIdentifier(DPIA.freshName("n2d")) } - -object KindName { - implicit val phraseKN: KindName[PhraseKind] = new KindName[PhraseKind] { - def get = "phrase" - } - implicit val natKN: KindName[NatKind] = new KindName[NatKind] { - def get = "nat" - } - implicit val dataKN: KindName[DataKind] = new KindName[DataKind] { - def get = "data" - } - implicit val addressSpaceKN: KindName[AddressSpaceKind] = - new KindName[AddressSpaceKind] { - def get = "addressSpace" - } - implicit val accessKN: KindName[AccessKind] = new KindName[AccessKind] { - def get = "access" - } - implicit val n2nKN: KindName[NatToNatKind] = new KindName[NatToNatKind] { - def get = "nat->nat" - } - implicit val n2dtKN: KindName[NatToDataKind] = new KindName[NatToDataKind] { - def get = "nat->data" - } -} \ No newline at end of file diff --git a/src/main/scala/shine/DPIA/Types/PhraseType.scala b/src/main/scala/shine/DPIA/Types/PhraseType.scala index 828d3a3dc..0df653798 100644 --- a/src/main/scala/shine/DPIA/Types/PhraseType.scala +++ b/src/main/scala/shine/DPIA/Types/PhraseType.scala @@ -41,33 +41,34 @@ final case class PassiveFunType[T <: PhraseType, +R <: PhraseType](inT: T, outT: override def toString = s"($inT) ->p $outT" } -final case class DepFunType[K <: Kind, +R <: PhraseType](x: K#I, t: R) - (implicit val kn: KindName[K]) +final case class DepFunType[I <: Kind.Identifier, +R <: PhraseType](kind: Kind[_, I], x: I, t: R) extends PhraseType { - override def toString = s"(${x.name}: ${kn.get}) -> $t" + override def toString = s"(${x.name}: ${kind.name}) -> $t" } object PhraseType { - def substitute[K <: Kind, T <: PhraseType](x: K#T, `for`: K#I, in: Phrase[T]): Phrase[T] =(x, `for`) match { - case (dt: DataType, forDt: DataTypeIdentifier) => substitute(dt, forDt, in) - case (n: Nat, forN: NatIdentifier) => substitute(n, forN, in) - case (a: AddressSpace, forA: AddressSpaceIdentifier) => substitute(a, forA, in) - case (a: AccessType, forA: AccessTypeIdentifier) => substitute(a, forA, in) - case (n2n: NatToNat, forN2N: NatToNatIdentifier) => substitute(n2n, forN2N, in) - case (n2d: NatToData, fotN2D: NatToDataIdentifier) => ??? //substitute(n2d, forN2D, in) - case _ => throw new Exception(s"could not substitute $x for ${`for`} in $in") - } + def substitute[T, I <: Kind.Identifier, U <: PhraseType](kind: Kind[T, I], x: T, `for`: I, in: Phrase[U]): Phrase[U] = + (x, `for`) match { + case (dt: DataType, forDt: DataTypeIdentifier) => substitute(dt, forDt, in) + case (n: Nat, forN: NatIdentifier) => substitute(n, forN, in) + case (a: AddressSpace, forA: AddressSpaceIdentifier) => substitute(a, forA, in) + case (a: AccessType, forA: AccessTypeIdentifier) => substitute(a, forA, in) + case (n2n: NatToNat, forN2N: NatToNatIdentifier) => substitute(n2n, forN2N, in) + case (n2d: NatToData, fotN2D: NatToDataIdentifier) => ??? //substitute(n2d, forN2D, in) + case _ => throw new Exception(s"could not substitute $x for ${`for`} in $in") + } - def substitute[K <: Kind](x: K#T, `for`: K#I, in: PhraseType): PhraseType = (x, `for`) match { - case (dt: DataType, forDt: DataTypeIdentifier) => substitute(dt, forDt, in) - case (n: Nat, forN: NatIdentifier) => substitute(n, forN, in) - case (a: AddressSpace, forA: AddressSpaceIdentifier) => substitute(a, forA, in) - case (a: AccessType, forA: AccessTypeIdentifier) => ??? //substitute(a, forA, in) - case (n2n: NatToNat, forN2N: NatToNatIdentifier) => substitute(n2n, forN2N, in) - case (n2d: NatToData, forN2D: NatToDataIdentifier) => ??? //substitute(n2d, forN2D, in) - case _ => throw new Exception(s"could not substitute $x for ${`for`} in $in") - } + def substitute[T, I <: Kind.Identifier](kind: Kind[T, I], x: T, `for`: I, in: PhraseType): PhraseType = + (x, `for`) match { + case (dt: DataType, forDt: DataTypeIdentifier) => substitute(dt, forDt, in) + case (n: Nat, forN: NatIdentifier) => substitute(n, forN, in) + case (a: AddressSpace, forA: AddressSpaceIdentifier) => substitute(a, forA, in) + case (a: AccessType, forA: AccessTypeIdentifier) => ??? //substitute(a, forA, in) + case (n2n: NatToNat, forN2N: NatToNatIdentifier) => substitute(n2n, forN2N, in) + case (n2d: NatToData, forN2D: NatToDataIdentifier) => ??? //substitute(n2d, forN2D, in) + case _ => throw new Exception(s"could not substitute $x for ${`for`} in $in") + } def substitute[T <: PhraseType](dt: DataType, `for`: DataTypeIdentifier, @@ -97,7 +98,7 @@ object PhraseType { case pf: PassiveFunType[_, _] => PassiveFunType(substitute(dt, `for`, pf.inT), substitute(dt, `for`, pf.outT)) case df: DepFunType[_, _] => - DepFunType(df.x, substitute(dt, `for`, df.t))(df.kn) + DepFunType(df.kind, df.x, substitute(dt, `for`, df.t)) } } @@ -149,7 +150,7 @@ object PhraseType { case pf: PassiveFunType[_, _] => PassiveFunType(substitute(n, `for`, pf.inT), substitute(n, `for`, pf.outT)) case df: DepFunType[_, _] => - DepFunType(df.x, substitute(n, `for`, df.t))(df.kn) + DepFunType(df.kind, df.x, substitute(n, `for`, df.t)) } } diff --git a/src/main/scala/shine/DPIA/Types/TypeCheck.scala b/src/main/scala/shine/DPIA/Types/TypeCheck.scala index 07f520a5f..de21e612f 100644 --- a/src/main/scala/shine/DPIA/Types/TypeCheck.scala +++ b/src/main/scala/shine/DPIA/Types/TypeCheck.scala @@ -17,9 +17,9 @@ object TypeCheck { TypeCheck(q) errorIfNotEqOrSubtype(q.t, p.t.inT) - case DepLambda(_, p) => TypeCheck(p) + case DepLambda(_, _, p) => TypeCheck(p) - case DepApply(p, _) => TypeCheck(p) + case DepApply(_, p, _) => TypeCheck(p) case LetNat(_, defn, body) => TypeCheck(defn); TypeCheck(body) @@ -116,8 +116,8 @@ object TypeCheck { accessSub == read && notContainingArrayType(bSub) case (FunType(subInT, subOutT), FunType(superInT, superOutT)) => subtypeCheck(superInT, subInT) && subtypeCheck(subOutT, superOutT) - case (DepFunType(subInT, subOutT), DepFunType(superInT, superOutT)) => - subInT == superInT && subtypeCheck(subOutT, superOutT) + case (DepFunType(kind1, subInT, subOutT), DepFunType(kind2, superInT, superOutT)) => + kind1 == kind2 && subInT == superInT && subtypeCheck(subOutT, superOutT) case _ => false } } diff --git a/src/main/scala/shine/DPIA/Types/package.scala b/src/main/scala/shine/DPIA/Types/package.scala index 0d5505364..a202fc0c9 100644 --- a/src/main/scala/shine/DPIA/Types/package.scala +++ b/src/main/scala/shine/DPIA/Types/package.scala @@ -26,42 +26,42 @@ package object Types { def `:`[T <: PhraseType](p: Phrase[T]): Unit = typeAssert(p, pt) } - type NatDependentFunctionType[T <: PhraseType] = DepFunType[NatKind, T] + type NatDependentFunctionType[T <: PhraseType] = DepFunType[NatIdentifier, T] object NatDependentFunctionType { - def apply[T <: PhraseType](n: NatIdentifier, t: T): DepFunType[NatKind, T] = - DepFunType[NatKind, T](n, t) + def apply[T <: PhraseType](n: NatIdentifier, t: T): DepFunType[NatIdentifier, T] = + DepFunType(NatKind, n, t) } - type TypeDependentFunctionType[T <: PhraseType] = DepFunType[DataKind, T] + type TypeDependentFunctionType[T <: PhraseType] = DepFunType[DataTypeIdentifier, T] object TypeDependentFunctionType { def apply[T <: PhraseType]( dt: DataTypeIdentifier, t: T - ): DepFunType[DataKind, T] = - DepFunType[DataKind, T](dt, t) + ): DepFunType[DataTypeIdentifier, T] = + DepFunType(DataKind, dt, t) } type AddrSpaceDependentFunctionType[T <: PhraseType] = - DepFunType[AddressSpaceKind, T] + DepFunType[AddressSpaceIdentifier, T] object AddrSpaceDependentFunctionType { def apply[T <: PhraseType]( addr: AddressSpaceIdentifier, t: T - ): DepFunType[AddressSpaceKind, T] = - DepFunType[AddressSpaceKind, T](addr, t) + ): DepFunType[AddressSpaceIdentifier, T] = + DepFunType(AddressSpaceKind, addr, t) } - type AccessDependentFunctionType[T <: PhraseType] = DepFunType[AccessKind, T] + type AccessDependentFunctionType[T <: PhraseType] = DepFunType[AccessTypeIdentifier, T] object AccessDependentFunctionType { def apply[T <: PhraseType]( at: AccessTypeIdentifier, t: T - ): DepFunType[AccessKind, T] = - DepFunType[AccessKind, T](at, t) + ): DepFunType[AccessTypeIdentifier, T] = + DepFunType(AccessKind, at, t) } object n2dtFun { diff --git a/src/main/scala/shine/DPIA/fromRise.scala b/src/main/scala/shine/DPIA/fromRise.scala index a362bf72f..562c24c96 100644 --- a/src/main/scala/shine/DPIA/fromRise.scala +++ b/src/main/scala/shine/DPIA/fromRise.scala @@ -40,27 +40,26 @@ object fromRise { val ee = expression(e, ptMap).asInstanceOf[Phrase[PhraseType]] Apply(ef, ee) - case r.DepLambda(x, e) => x match { + case r.DepLambda(kind, x, e) => x match { case ni: rt.NatIdentifier => - DepLambda[NatKind](natIdentifier(ni))(expression(e, ptMap)) + DepLambda(NatKind, natIdentifier(ni))(expression(e, ptMap)) case dti: rt.DataTypeIdentifier => - DepLambda[DataKind](dataTypeIdentifier(dti))(expression(e, ptMap)) + DepLambda(DataKind, dataTypeIdentifier(dti))(expression(e, ptMap)) case addri: rt.AddressSpaceIdentifier => - DepLambda[AddressSpaceKind]( - addressSpaceIdentifier(addri))(expression(e, ptMap)) + DepLambda(AddressSpaceKind, addressSpaceIdentifier(addri))(expression(e, ptMap)) } - case r.DepApp(f, x) => - def depApp[K <: Kind](f: r.Expr, arg: K#T): DepApply[K, PhraseType] = - DepApply[K, PhraseType]( - expression(f, ptMap).asInstanceOf[Phrase[DepFunType[K, PhraseType]]], + case r.DepApp(kind, f, x) => + def depApp[T, I <: Kind.Identifier](kind: Kind[T, I], f: r.Expr, arg: T): DepApply[T, I, PhraseType] = + DepApply[T, I, PhraseType](kind, + expression(f, ptMap).asInstanceOf[Phrase[DepFunType[I, PhraseType]]], arg) x match { - case n: Nat => depApp[NatKind](f, n) - case dt: rt.DataType => depApp[DataKind](f, dataType(dt)) - case a: rt.AddressSpace => depApp[AddressSpaceKind](f, addressSpace(a)) - case n2n: rt.NatToNat => depApp[NatToNatKind](f, nat2nat(n2n)) + case n: Nat => depApp(NatKind, f, n) + case dt: rt.DataType => depApp(DataKind, f, dataType(dt)) + case a: rt.AddressSpace => depApp(AddressSpaceKind, f, addressSpace(a)) + case n2n: rt.NatToNat => depApp(NatToNatKind, f, nat2nat(n2n)) } case r.Literal(d) => d match { @@ -99,12 +98,10 @@ object fromRise { } object depFun { - def apply[K <: Kind](x: K#I): Object { - def apply[T <: PhraseType](body: Phrase[T]) - (implicit kn: KindName[K]): DepLambda[K, T] + def apply[T, I <: Kind.Identifier](kind: Kind[T, I], x: I): Object { + def apply[U <: PhraseType](body: Phrase[U]): DepLambda[T, I, U] } = new { - def apply[T <: PhraseType](body: Phrase[T]) - (implicit kn: KindName[K]): DepLambda[K, T] = DepLambda(x, body) + def apply[U <: PhraseType](body: Phrase[U]): DepLambda[T, I, U] = DepLambda(kind, x, body) } } @@ -137,7 +134,7 @@ object fromRise { case nFunT(n, expT(`NatType`, `read`) ->: expT(IndexType(_), `read`)) => - depFun[NatKind](n)( + depFun(NatKind, n)( fun[ExpType](expT(NatType, read), e => NatAsIndex(n, e))) } @@ -276,7 +273,7 @@ object fromRise { expT(ArrayType(n, _), `read`) ->: expT(_, `read`)) => - depFun[AddressSpaceKind](a)( + depFun(AddressSpaceKind, a)( fun[ExpType ->: ExpType ->: ExpType]( expT(t, read) ->: expT(s, read) ->: expT(t, write), f => fun[ExpType](expT(t, write), i => @@ -291,7 +288,7 @@ object fromRise { expT(ArrayType(n, _), `read`) ->: expT(_, `read`)) => - depFun[AddressSpaceKind](a)( + depFun(AddressSpaceKind, a)( fun[ExpType ->: ExpType ->: ExpType]( expT(t, read) ->: expT(s, read) ->: expT(t, write), f => fun[ExpType](expT(t, write), i => @@ -333,7 +330,7 @@ object fromRise { case nFunT(n, expT(ArrayType(_, t), a) ->: expT(ArrayType(m, ArrayType(_, _)), _)) => - depFun[NatKind](n)( + depFun(NatKind, n)( fun[ExpType](expT({m*n}`.`t, a), e => Split(n, m, a, t, e))) } @@ -343,8 +340,8 @@ object fromRise { expT(ArrayType(insz, t), `read`) ->: expT(ArrayType(np1, ArrayType(_, _)), `read`))) => - depFun[NatKind](sz)( - depFun[NatKind](sp)( + depFun(NatKind, sz)( + depFun(NatKind, sp)( fun[ExpType](expT(insz`.`t, read), e => Slide(np1-1, sz, sp, t, e)))) } @@ -355,8 +352,8 @@ object fromRise { expT(ArrayType(insz, _), `read`) ->: expT(ArrayType(n, _), `read`))) => - depFun[NatKind](alloc)( - depFun[NatKind](sz)( + depFun(NatKind, alloc)( + depFun(NatKind, sz)( fun[ExpType ->: ExpType]( expT(s, read) ->: expT(t, write), load => fun[ExpType](expT(insz`.`s, read), e => @@ -369,7 +366,7 @@ object fromRise { (inT @ expT(ArrayType(m, s), `read`)) ->: expT(ArrayType(n, t), `write`)) => - depFun[NatKind](tile)( + depFun(NatKind, tile)( fun[ExpType ->: ExpType](fa ->: fb, f => fun[ExpType](inT, e => DepTile(n, tile, m-n, s, t, f, e)))) @@ -381,7 +378,7 @@ object fromRise { expT(ArrayType(insz, _), `read`) ->: expT(ArrayType(n, _), `read`)) => - depFun[NatKind](sz)( + depFun(NatKind, sz)( fun[ExpType ->: ExpType]( expT(s, read) ->: expT(s, write), wr => fun[ExpType](expT(insz`.`s, read), e => @@ -394,9 +391,9 @@ object fromRise { expT(ArrayType(insz, _), `read`) ->: expT(ArrayType(n, _), `read`)))) => - depFun[AddressSpaceKind](a)( - depFun[NatKind](alloc)( - depFun[NatKind](sz)( + depFun(AddressSpaceKind, a)( + depFun(NatKind, alloc)( + depFun(NatKind, sz)( fun[ExpType ->: ExpType]( expT(s, read) ->: expT(t, write), load => fun[ExpType](expT(insz`.`s, read), e => @@ -409,8 +406,8 @@ object fromRise { expT(ArrayType(insz, _), `read`) ->: expT(ArrayType(n, _), `read`))) => - depFun[AddressSpaceKind](a)( - depFun[NatKind](sz)( + depFun(AddressSpaceKind, a)( + depFun(NatKind, sz)( fun[ExpType ->: ExpType]( expT(t, read) ->: expT(t, write), write_t => fun[ExpType](expT(insz`.`t, read), e => @@ -421,9 +418,9 @@ object fromRise { case nFunT(n, n2nFunT(idxF, n2nFunT(idxFinv, expT(ArrayType(_, t), a) ->: expT(ArrayType(_, _), _)))) => - depFun[NatKind](n)( - depFun[NatToNatKind](idxF)( - depFun[NatToNatKind](idxFinv)( + depFun(NatKind, n)( + depFun(NatToNatKind, idxF)( + depFun(NatToNatKind, idxFinv)( fun[ExpType](expT(n`.`t, a), e => Reorder(n, t, a, idxF, idxFinv, e))))) } @@ -460,7 +457,7 @@ object fromRise { case nFunT(n, expT(ArrayType(nm, t), `read`) ->: expT(ArrayType(_, _), `read`)) => - depFun[NatKind](n)( + depFun(NatKind, n)( fun[ExpType](expT(nm`.`t, read), e => Take(n, nm-n, t, e))) } @@ -468,7 +465,7 @@ object fromRise { case nFunT(n, expT(ArrayType(nm, t), `read`) ->: expT(ArrayType(_, _), `read`)) => - depFun[NatKind](n)( + depFun(NatKind, n)( fun[ExpType](expT(nm`.`t, read), e => Drop(n, nm-n, t, e))) } @@ -479,8 +476,8 @@ object fromRise { expT(ArrayType(n, _), `read`) ->: expT(ArrayType(_, _), `read`))) => - depFun[NatKind](l)( - depFun[NatKind](q)( + depFun(NatKind, l)( + depFun(NatKind, q)( fun[ExpType](expT(t, read), cst => fun[ExpType](expT(n`.`t, read), e => PadCst(n, l, q, t, cst, e))))) @@ -490,7 +487,7 @@ object fromRise { case nFunT(r, expT(ArrayType(n, t), `write`) ->: _) => - depFun[NatKind](r)( + depFun(NatKind, r)( fun[ExpType](expT(n`.`t, `write`), e => PadEmpty(n, r, t, e))) } @@ -500,8 +497,8 @@ object fromRise { expT(ArrayType(n, t), `read`) ->: expT(ArrayType(_, _), `read`))) => - depFun[NatKind](l)( - depFun[NatKind](q)( + depFun(NatKind, l)( + depFun(NatKind, q)( fun[ExpType](expT(n`.`t, read), e => PadClamp(n, l, q, t, e)))) } @@ -681,7 +678,7 @@ object fromRise { case FunType(ExpType(dt: DataTypeIdentifier, `read`), out) => val (i, o) = collectTypes(out) (dt +: i, o) - case DepFunType(_, t) => collectTypes(t) + case DepFunType(_, _, t) => collectTypes(t) case _ => throw new Exception("This should not be possible") } } @@ -689,7 +686,7 @@ object fromRise { assert(inTs.length == n) inTs.foldRight[Phrase[_ <: PhraseType]]( - depFun[DataKind](outT)({ + depFun(DataKind, outT)({ val args = Seq.tabulate(n)(i => Identifier(freshName("x"), ExpType(inTs(i), read))) args.foldRight[Phrase[_ <: PhraseType]]( ForeignFunctionCall(decl, n)(inTs, outT, args) @@ -698,7 +695,7 @@ object fromRise { } }) ) { - case (t, f) => depFun[DataKind](t)(f) + case (t, f) => depFun(DataKind, t)(f) } case core.generate() => fromType { @@ -727,7 +724,7 @@ object fromRise { expT(ArrayType(insz, _), `read`) ->: expT(ArrayType(m, _), `write`) ) => - depFun[NatKind](k)( + depFun(NatKind, k)( fun[`(nat)->:`[ExpType ->: ExpType]]( l ->: (expT(ln`.`t, read) ->: expT(l`.`t, write)), f => fun[ExpType](expT(insz`.`t, read), e => @@ -741,7 +738,7 @@ object fromRise { expT(ArrayType(insz, _), `read`) ->: expT(ArrayType(m, _), `write`) )) => - depFun[AddressSpaceKind](a)(depFun[NatKind](k)( + depFun(AddressSpaceKind, a)(depFun(NatKind, k)( fun[`(nat)->:`[ExpType ->: ExpType]]( l ->: (expT(ln`.`t, read) ->: expT(l`.`t, write)), f => fun[ExpType](expT(insz`.`t, read), e => @@ -753,7 +750,7 @@ object fromRise { expT(ArrayType(mn, _), a) ->: expT(ArrayType(m, VectorType(_, t)), _)) => - depFun[NatKind](n)( + depFun(NatKind, n)( fun[ExpType](expT(mn`.`t, a), e => AsVector(n, m, t, a, e))) } @@ -763,7 +760,7 @@ object fromRise { expT(ArrayType(mn, _), a) ->: expT(ArrayType(m, VectorType(_, t)), _)) => - depFun[NatKind](n)( + depFun(NatKind, n)( fun[ExpType](expT(mn`.`t, read), e => AsVectorAligned(n, m, t, a, e))) } @@ -801,7 +798,7 @@ object fromRise { case rocl.oclToMem() => fromType { case aFunT(a, expT(t, `write`) ->: expT(_, `read`)) => - depFun[AddressSpaceKind](a)( + depFun(AddressSpaceKind, a)( fun[ExpType](expT(t, write), e => ocl.ToMem(a, t, e))) } @@ -812,8 +809,8 @@ object fromRise { expT(t, `write`) ->: _)))))) => import shine.OpenCL.{LocalSize, GlobalSize} - depFun[NatKind](ls1)(depFun[NatKind](ls2)(depFun[NatKind](ls3)( - depFun[NatKind](gs1)(depFun[NatKind](gs2)(depFun[NatKind](gs3)( + depFun(NatKind, ls1)(depFun(NatKind, ls2)(depFun(NatKind, ls3)( + depFun(NatKind, gs1)(depFun(NatKind, gs2)(depFun(NatKind, gs3)( fun[ExpType](expT(t, write), e => ocl.Run(LocalSize(ls1, ls2, ls3), GlobalSize(gs1, gs2, gs3))(t, e)))))))) } @@ -831,7 +828,7 @@ object fromRise { case core.makeDepPair() => fromType { case nFunT(fst, expT(sndT, a) ->: expT(_, _)) => - depFun[NatKind](fst)(fun[ExpType](expT(sndT, a), snd => MakeDepPair(a, fst, sndT, snd))) + depFun(NatKind, fst)(fun[ExpType](expT(sndT, a), snd => MakeDepPair(a, fst, sndT, snd))) } case rcuda.globalToShared() => fromType { @@ -965,7 +962,7 @@ object fromRise { case rt.DepArrayType(sz, f) => DepArrayType(sz, ntd(f)) case rt.PairType(a, b) => PairType(dataType(a), dataType(b)) case rt.NatToDataApply(f, n) => NatToDataApply(ntd(f), n) - case rt.DepPairType(x, t) => + case rt.DepPairType(_, x, t) => x match { case x:rt.NatIdentifier => DepPairType(natIdentifier(x), dataType(t)) case _ => ??? diff --git a/src/main/scala/shine/DPIA/package.scala b/src/main/scala/shine/DPIA/package.scala index b9a1b3d44..ca516d615 100644 --- a/src/main/scala/shine/DPIA/package.scala +++ b/src/main/scala/shine/DPIA/package.scala @@ -34,9 +34,9 @@ package object DPIA { type x[T1 <: PhraseType, T2 <: PhraseType] = PhrasePairType[T1, T2] type ->:[T <: PhraseType, R <: PhraseType] = FunType[T, R] type `->p:`[T <: PhraseType, R <: PhraseType] = PassiveFunType[T, R] - type `()->:`[K <: Kind, R <: PhraseType] = DepFunType[K, R] - type `(nat)->:`[R <: PhraseType] = DepFunType[NatKind, R] - type `(dt)->:`[R <: PhraseType] = DepFunType[DataKind, R] + type `()->:`[I <: Kind.Identifier, R <: PhraseType] = DepFunType[I, R] + type `(nat)->:`[R <: PhraseType] = DepFunType[NatIdentifier, R] + type `(dt)->:`[R <: PhraseType] = DepFunType[DataTypeIdentifier, R] type VarType = ExpType x AccType object VarType { @@ -94,10 +94,10 @@ package object DPIA { } implicit class DepFunTypeConstructor[R <: PhraseType](r: R) { - def ->:(i: DataTypeIdentifier): `()->:`[DataKind, R] = DepFunType[DataKind, R](i, r) - def ->:(n: NatIdentifier): `()->:`[NatKind, R] = DepFunType[NatKind, R](n, r) - def ->:(n: NatToNatIdentifier): `()->:`[NatToNatKind, R] = DepFunType[NatToNatKind, R](n, r) - def ->:(n: NatToDataIdentifier): `()->:`[NatToDataKind, R] = DepFunType[NatToDataKind, R](n, r) + def ->:(i: DataTypeIdentifier): `()->:`[DataTypeIdentifier, R] = DepFunType(DataKind, i, r) + def ->:(n: NatIdentifier): `()->:`[NatIdentifier, R] = DepFunType(NatKind, n, r) + def ->:(n: NatToNatIdentifier): `()->:`[NatToNatIdentifier, R] = DepFunType(NatToNatKind, n, r) + def ->:(n: NatToDataIdentifier): `()->:`[NatToDataIdentifier, R] = DepFunType(NatToDataKind, n, r) } object expT { @@ -126,15 +126,14 @@ package object DPIA { object nFunT { def apply(n: rt.NatIdentifier, t: PhraseType): PhraseType = { - DepFunType[NatKind, PhraseType](fromRise.natIdentifier(n), t) + DepFunType(NatKind, fromRise.natIdentifier(n), t) } def apply(n: NatIdentifier, t: PhraseType): PhraseType = { - DepFunType[NatKind, PhraseType](n, t) + DepFunType(NatKind, n, t) } - def unapply[K <: Kind, T <: PhraseType](funType: DepFunType[K, T] - ): Option[(NatIdentifier, T)] = { + def unapply[I <: Kind.Identifier, U <: PhraseType](funType: DepFunType[I, U]): Option[(NatIdentifier, U)] = { funType.x match { case n: NatIdentifier => Some((n, funType.t)) case _ => throw new Exception("Expected Nat DepFunType") @@ -144,23 +143,21 @@ package object DPIA { object dFunT { def apply(d: rt.DataTypeIdentifier, t: PhraseType): PhraseType = { - DepFunType[DataKind, PhraseType](fromRise.dataTypeIdentifier(d), t) + DepFunType(DataKind, fromRise.dataTypeIdentifier(d), t) } } object aFunT { def apply(a: rt.AddressSpaceIdentifier, t: PhraseType): PhraseType = { - DepFunType[AddressSpaceKind, PhraseType]( - fromRise.addressSpaceIdentifier(a), t) + DepFunType(AddressSpaceKind, fromRise.addressSpaceIdentifier(a), t) } def apply(a: AddressSpaceIdentifier, t: PhraseType): PhraseType = { - DepFunType[AddressSpaceKind, PhraseType](a, t) + DepFunType(AddressSpaceKind, a, t) } - def unapply[K <: Kind, - T <: PhraseType](funType: DepFunType[K, T] - ): Option[(AddressSpaceIdentifier, T)] = { + def unapply[I <: Kind.Identifier, T <: PhraseType](funType: DepFunType[I, T] + ): Option[(AddressSpaceIdentifier, T)] = { funType.x match { case a: AddressSpaceIdentifier => Some((a, funType.t)) case _ => throw new Exception("Expected AddressSpace DepFunType") @@ -170,15 +167,15 @@ package object DPIA { object n2nFunT { def apply(n: rt.NatToNatIdentifier, t: PhraseType): PhraseType = { - DepFunType[NatToNatKind, PhraseType](fromRise.natToNatIdentifier(n), t) + DepFunType(NatToNatKind, fromRise.natToNatIdentifier(n), t) } def apply(n: NatToNatIdentifier, t: PhraseType): PhraseType = { - DepFunType[NatToNatKind, PhraseType](n, t) + DepFunType(NatToNatKind, n, t) } - def unapply[K <: Kind, T <: PhraseType](funType: DepFunType[K, T] - ): Option[(NatToNatIdentifier, T)] = { + def unapply[I <: Kind.Identifier, T <: PhraseType](funType: DepFunType[I, T] + ): Option[(NatToNatIdentifier, T)] = { funType.x match { case n: NatToNatIdentifier => Some((n, funType.t)) case _ => throw new Exception("Expected Nat DepFunType") diff --git a/src/main/scala/shine/DPIA/primitives/functional/DepMapSeq.scala b/src/main/scala/shine/DPIA/primitives/functional/DepMapSeq.scala index 9e025b2a0..52000bca8 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/DepMapSeq.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/DepMapSeq.scala @@ -7,16 +7,16 @@ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -final case class DepMapSeq(unroll: Boolean)(val n: Nat, val ft1: NatToData, val ft2: NatToData, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { +final case class DepMapSeq(unroll: Boolean)(val n: Nat, val ft1: NatToData, val ft2: NatToData, val f: Phrase[DepFunType[NatIdentifier, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { assert { f :: ({ val k = f.t.x - DepFunType[NatKind, PhraseType](k, FunType(expT(NatToDataApply(ft1, k), read), expT(NatToDataApply(ft2, k), write))) + DepFunType(NatKind, k, FunType(expT(NatToDataApply(ft1, k), read), expT(NatToDataApply(ft2, k), write))) }) array :: expT(DepArrayType(n, ft1), read) true } override val t: ExpType = expT(DepArrayType(n, ft2), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepMapSeq = new DepMapSeq(unroll)(v.nat(n), v.natToData(ft1), v.natToData(ft2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) - def unwrap: (Nat, NatToData, NatToData, Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], Phrase[ExpType]) = (n, ft1, ft2, f, array) + def unwrap: (Nat, NatToData, NatToData, Phrase[DepFunType[NatIdentifier, FunType[ExpType, ExpType]]], Phrase[ExpType]) = (n, ft1, ft2, f, array) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Iterate.scala b/src/main/scala/shine/DPIA/primitives/functional/Iterate.scala index c1a153413..52a5efeb5 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Iterate.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Iterate.scala @@ -7,11 +7,11 @@ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -final case class Iterate(val n: Nat, val m: Nat, val k: Nat, val dt: DataType, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { +final case class Iterate(val n: Nat, val m: Nat, val k: Nat, val dt: DataType, val f: Phrase[DepFunType[NatIdentifier, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { assert { f :: ({ val l = f.t.x - DepFunType[NatKind, PhraseType](l, FunType(expT(ArrayType(l * n, dt), read), expT(ArrayType(l, dt), write))) + DepFunType(NatKind, l, FunType(expT(ArrayType(l * n, dt), read), expT(ArrayType(l, dt), write))) }) array :: expT(ArrayType(m * n.pow(k), dt), read) true diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ForNat.scala b/src/main/scala/shine/DPIA/primitives/imperative/ForNat.scala index b4615241c..e2f8c80ee 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/ForNat.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/ForNat.scala @@ -7,15 +7,15 @@ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -final case class ForNat(unroll: Boolean)(val n: Nat, val loopBody: Phrase[DepFunType[NatKind, CommType]]) extends CommandPrimitive { +final case class ForNat(unroll: Boolean)(val n: Nat, val loopBody: Phrase[DepFunType[NatIdentifier, CommType]]) extends CommandPrimitive { assert { loopBody :: ({ val i = loopBody.t.x - DepFunType[NatKind, PhraseType](i, comm) + DepFunType(NatKind, i, comm) }) true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): ForNat = new ForNat(unroll)(v.nat(n), VisitAndRebuild(loopBody, v)) - def unwrap: (Nat, Phrase[DepFunType[NatKind, CommType]]) = (n, loopBody) + def unwrap: (Nat, Phrase[DepFunType[NatIdentifier, CommType]]) = (n, loopBody) } diff --git a/src/main/scala/shine/OpenCL/AdjustArraySizesForAllocations.scala b/src/main/scala/shine/OpenCL/AdjustArraySizesForAllocations.scala index 423d1b0f2..7bfac1425 100644 --- a/src/main/scala/shine/OpenCL/AdjustArraySizesForAllocations.scala +++ b/src/main/scala/shine/OpenCL/AdjustArraySizesForAllocations.scala @@ -57,8 +57,8 @@ object AdjustArraySizesForAllocations { case Apply(f, _) => visitAndGatherInformation(f, parallInfo) case Lambda(_, p) => visitAndGatherInformation(p, parallInfo) - case DepApply(f, _) => visitAndGatherInformation(f, parallInfo) - case DepLambda(_, p) => visitAndGatherInformation(p, parallInfo) + case DepApply(_, f, _) => visitAndGatherInformation(f, parallInfo) + case DepLambda(_, _, p) => visitAndGatherInformation(p, parallInfo) case Fst(_, _, p) => visitAndGatherInformation(p, parallInfo) match { case Nil => Nil case RecordInfo(fst, _) :: Nil => fst diff --git a/src/main/scala/shine/OpenCL/Compilation/KernelCodeGenerator.scala b/src/main/scala/shine/OpenCL/Compilation/KernelCodeGenerator.scala index 0445506dc..0265e2148 100644 --- a/src/main/scala/shine/OpenCL/Compilation/KernelCodeGenerator.scala +++ b/src/main/scala/shine/OpenCL/Compilation/KernelCodeGenerator.scala @@ -46,7 +46,7 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, case f: ocl.ParForNat => f.body match { - case DepLambda(i: NatIdentifier, Lambda(o, p)) => + case DepLambda(NatKind, i: NatIdentifier, Lambda(o, p)) => OpenCLCodeGen.codeGenOpenCLParForNat(f, f.n, f.out, i, o, p, env) case _ => throw new Exception("This should not happen") } diff --git a/src/main/scala/shine/OpenCL/Compilation/Passes/FlagPrivateArrayLoops.scala b/src/main/scala/shine/OpenCL/Compilation/Passes/FlagPrivateArrayLoops.scala index ca31433ac..c08468895 100644 --- a/src/main/scala/shine/OpenCL/Compilation/Passes/FlagPrivateArrayLoops.scala +++ b/src/main/scala/shine/OpenCL/Compilation/Passes/FlagPrivateArrayLoops.scala @@ -1,7 +1,7 @@ package shine.OpenCL.Compilation.Passes import shine.DPIA.Phrases._ -import shine.DPIA.Types.{CommType, PhraseType} +import shine.DPIA.Types.{CommType, Kind, NatKind, PhraseType} import shine.DPIA.primitives.functional.{Idx, NatAsIndex} import shine.DPIA.primitives.imperative.{For, ForNat, IdxAcc} import shine.DPIA.{ArrayData, Nat, NatIdentifier} @@ -67,8 +67,8 @@ object FlagPrivateArrayLoops { val i = f.loopBody.asInstanceOf[Lambda[_, _]].param eliminateVars -= i.name Continue(For(unroll = true)(f.n, f.loopBody), this) - case f@ForNat(_) if (eliminateVars(f.loopBody.asInstanceOf[DepLambda[_, _]].x.name)) => - val i = f.loopBody.asInstanceOf[DepLambda[_, _]].x + case f@ForNat(_) if (eliminateVars(f.loopBody.asInstanceOf[DepLambda[_, _ <: Kind.Identifier, _]].x.name)) => + val i = f.loopBody.asInstanceOf[DepLambda[_, _ <: Kind.Identifier, _]].x eliminateVars -= i.name Continue(ForNat(unroll = true)(f.n, f.loopBody), this) case pf@ParFor(level, dim, _, name) if (eliminateVars(pf.body.asInstanceOf[Lambda[_, _]].param.name)) => @@ -79,9 +79,9 @@ object FlagPrivateArrayLoops { pf.init, pf.n, pf.step, pf.dt, pf.out, pf.body), this) case _ => throw new Exception("This should not happen") } - case pf@ParForNat(level, dim, _, name) if (eliminateVars(pf.body.asInstanceOf[DepLambda[_, _]].x.name)) => + case pf@ParForNat(level, dim, _, name) if (eliminateVars(pf.body.asInstanceOf[DepLambda[_, _ <: Kind.Identifier, _]].x.name)) => pf.body match { - case DepLambda(i: NatIdentifier, _) => + case DepLambda(NatKind, i: NatIdentifier, _) => eliminateVars -= i.name Continue(ParForNat(level, dim, unroll = true, name)( pf.init, pf.n, pf.step, pf.ft, pf.out, pf.body), this) diff --git a/src/main/scala/shine/OpenCL/Compilation/Passes/InsertMemoryBarriers.scala b/src/main/scala/shine/OpenCL/Compilation/Passes/InsertMemoryBarriers.scala index fbee659d0..02b2b379a 100644 --- a/src/main/scala/shine/OpenCL/Compilation/Passes/InsertMemoryBarriers.scala +++ b/src/main/scala/shine/OpenCL/Compilation/Passes/InsertMemoryBarriers.scala @@ -68,9 +68,9 @@ object InsertMemoryBarriers { } case f@ForNat(unroll) => f.loopBody match { - case DepLambda(x, body) => + case DepLambda(NatKind, x, body) => Stop(ForNat(unroll)(f.n, - DepLambda[NatKind, CommType](x, visitLoopBody(body, allocs, metadata)))) + DepLambda(NatKind, x, visitLoopBody(body, allocs, metadata)))) case _ => throw new Exception("This should not happen") } case pf@ocl.ParFor(Local, dim, unroll, name) => @@ -92,19 +92,19 @@ object InsertMemoryBarriers { } case pf@ocl.ParForNat(Local, dim, unroll, name) => pf.body match { - case DepLambda(i: NatIdentifier, Lambda(o, p)) => + case DepLambda(NatKind, i: NatIdentifier, Lambda(o, p)) => val outer_wg_writes = mutable.Map[Identifier[_ <: PhraseType], AddressSpace]() collectWrites(pf.out, allocs, outer_wg_writes) Stop(ocl.ParForNat(Local, dim, unroll, name)(pf.init, pf.n, pf.step, pf.ft, pf.out, - DepLambda[NatKind, AccType ->: CommType](i, Lambda(o, + DepLambda(NatKind, i, Lambda(o, visitLoopBody(p, allocs, metadata, outer_wg_writes))))) case _ => throw new Exception("This should not happen") } case pf@ocl.ParForNat(level, dim, unroll, name) => pf.body match { - case DepLambda(i: NatIdentifier, Lambda(o, p)) => + case DepLambda(NatKind, i: NatIdentifier, Lambda(o, p)) => Stop(ocl.ParForNat(level, dim, unroll, name)(pf.init, pf.n, pf.step, pf.ft, pf.out, - DepLambda[NatKind, AccType ->: CommType](i, Lambda(o, + DepLambda(NatKind, i, Lambda(o, visitLoopBody(p, allocs, metadata))))) case _ => throw new Exception("This should not happen") } diff --git a/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala b/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala index 9d2a6dcb9..5bc1a4aa2 100644 --- a/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala +++ b/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala @@ -33,14 +33,14 @@ object SeparateHostAndKernelCode { // on the fly beta-reduction case Apply(fun, arg) => Stop(VisitAndRebuild(Lifting.liftFunction(fun).reducing(arg), this)) - case DepApply(fun, arg) => arg match { + case DepApply(_, fun, arg) => arg match { case a: Nat => - Stop(VisitAndRebuild(Lifting.liftDependentFunction[NatKind, ExpType]( - fun.asInstanceOf[Phrase[NatKind `()->:` ExpType]])(a) + Stop(VisitAndRebuild(Lifting.liftDependentFunction( + fun.asInstanceOf[Phrase[NatIdentifier `()->:` ExpType]])(a) .asInstanceOf[Phrase[T]], this)) case a: DataType => - Stop(VisitAndRebuild(Lifting.liftDependentFunction[DataKind, ExpType]( - fun.asInstanceOf[Phrase[DataKind `()->:` ExpType]])(a) + Stop(VisitAndRebuild(Lifting.liftDependentFunction( + fun.asInstanceOf[Phrase[DataTypeIdentifier `()->:` ExpType]])(a) .asInstanceOf[Phrase[T]], this)) } @@ -59,7 +59,7 @@ object SeparateHostAndKernelCode { ): (Phrase[_ <: PhraseType], Seq[Phrase[ExpType]]) = { freeNats match { case v +: rest => iterNats( - DepLambda[NatKind](NatIdentifier(v.name, v.range))(definition), + DepLambda(NatKind, NatIdentifier(v.name, v.range))(definition), Literal(NatAsIntData(v)) +: args, rest) case Nil => (definition, args) } @@ -101,9 +101,9 @@ object SeparateHostAndKernelCode { Stop(p) case Lambda(x, _) => Continue(p, this.copy(boundV = boundV + x)) - case DepLambda(x: NatIdentifier, _) => + case DepLambda(NatKind, x: NatIdentifier, _) => Continue(p, this.copy(boundN = boundN + x)) - case DepLambda(x: DataTypeIdentifier, _) => + case DepLambda(DataKind, x: DataTypeIdentifier, _) => Continue(p, this.copy(boundT = boundT + x)) case _ => Continue(p, this) } diff --git a/src/main/scala/shine/OpenCL/DSL/package.scala b/src/main/scala/shine/OpenCL/DSL/package.scala index 4ef029e70..19d4b7310 100644 --- a/src/main/scala/shine/OpenCL/DSL/package.scala +++ b/src/main/scala/shine/OpenCL/DSL/package.scala @@ -26,7 +26,7 @@ package object DSL { def parForNat(level: ParallelismLevel, dim: Int, unroll: Boolean - ): (Nat, NatToData, Phrase[AccType], Phrase[DepFunType[NatKind, FunType[AccType, CommType]]]) => ParForNat = + ): (Nat, NatToData, Phrase[AccType], Phrase[DepFunType[NatIdentifier, FunType[AccType, CommType]]]) => ParForNat = level match { case Global => ParForNat(level, dim, unroll, "gl_id_")( get_global_id(dim), _, get_global_size(dim), _, _, _) @@ -39,7 +39,7 @@ package object DSL { private def parForBodyFunction(n:Nat, ft:NatToData, f:NatIdentifier => Phrase[AccType] => Phrase[CommType] - ): DepLambda[NatKind, AccType ->: CommType] = { + ): DepLambda[Nat, NatIdentifier, AccType ->: CommType] = { nFun(idx => λ(accT(ft(idx)))(o => f(idx)(o)), RangeAdd(0, n, 1)) } diff --git a/src/main/scala/shine/OpenCL/primitives/functional/DepMap.scala b/src/main/scala/shine/OpenCL/primitives/functional/DepMap.scala index 8fa77c246..22b0713d4 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/DepMap.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/DepMap.scala @@ -7,16 +7,16 @@ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -final case class DepMap(level: shine.OpenCL.ParallelismLevel, dim: Int)(val n: Nat, val ft1: NatToData, val ft2: NatToData, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { +final case class DepMap(level: shine.OpenCL.ParallelismLevel, dim: Int)(val n: Nat, val ft1: NatToData, val ft2: NatToData, val f: Phrase[DepFunType[NatIdentifier, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { assert { f :: ({ val m = f.t.x - DepFunType[NatKind, PhraseType](m, FunType(expT(NatToDataApply(ft1, m), read), expT(NatToDataApply(ft2, m), write))) + DepFunType(NatKind, m, FunType(expT(NatToDataApply(ft1, m), read), expT(NatToDataApply(ft2, m), write))) }) array :: expT(DepArrayType(n, ft1), read) true } override val t: ExpType = expT(DepArrayType(n, ft2), write) override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepMap = new DepMap(level, dim)(v.nat(n), v.natToData(ft1), v.natToData(ft2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) - def unwrap: (Nat, NatToData, NatToData, Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], Phrase[ExpType]) = (n, ft1, ft2, f, array) + def unwrap: (Nat, NatToData, NatToData, Phrase[DepFunType[NatIdentifier, FunType[ExpType, ExpType]]], Phrase[ExpType]) = (n, ft1, ft2, f, array) } diff --git a/src/main/scala/shine/OpenCL/primitives/functional/Iterate.scala b/src/main/scala/shine/OpenCL/primitives/functional/Iterate.scala index 3fa11ab8b..01f9943a0 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/Iterate.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/Iterate.scala @@ -7,11 +7,11 @@ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -final case class Iterate(val a: AddressSpace, val n: Nat, val m: Nat, val k: Nat, val dt: DataType, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { +final case class Iterate(val a: AddressSpace, val n: Nat, val m: Nat, val k: Nat, val dt: DataType, val f: Phrase[DepFunType[NatIdentifier, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { assert { f :: ({ val l = f.t.x - DepFunType[NatKind, PhraseType](l, FunType(expT(ArrayType(l * n, dt), read), expT(ArrayType(l, dt), write))) + DepFunType(NatKind, l, FunType(expT(ArrayType(l * n, dt), read), expT(ArrayType(l, dt), write))) }) array :: expT(ArrayType(m * n.pow(k), dt), read) true diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/ParForNat.scala b/src/main/scala/shine/OpenCL/primitives/imperative/ParForNat.scala index 3f0b066b4..8c0277de8 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/ParForNat.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/ParForNat.scala @@ -7,16 +7,16 @@ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -final case class ParForNat(level: shine.OpenCL.ParallelismLevel, dim: Int, unroll: Boolean, prefix: String)(val init: Nat, val n: Nat, val step: Nat, val ft: NatToData, val out: Phrase[AccType], val body: Phrase[DepFunType[NatKind, FunType[AccType, CommType]]]) extends CommandPrimitive { +final case class ParForNat(level: shine.OpenCL.ParallelismLevel, dim: Int, unroll: Boolean, prefix: String)(val init: Nat, val n: Nat, val step: Nat, val ft: NatToData, val out: Phrase[AccType], val body: Phrase[DepFunType[NatIdentifier, FunType[AccType, CommType]]]) extends CommandPrimitive { assert { out :: accT(DepArrayType(n, ft)) body :: ({ val i = body.t.x - DepFunType[NatKind, PhraseType](i, FunType(accT(NatToDataApply(ft, i)), comm)) + DepFunType(NatKind, i, FunType(accT(NatToDataApply(ft, i)), comm)) }) true } override val t: CommType = comm override def visitAndRebuild(v: VisitAndRebuild.Visitor): ParForNat = new ParForNat(level, dim, unroll, prefix)(v.nat(init), v.nat(n), v.nat(step), v.natToData(ft), VisitAndRebuild(out, v), VisitAndRebuild(body, v)) - def unwrap: (Nat, Nat, Nat, NatToData, Phrase[AccType], Phrase[DepFunType[NatKind, FunType[AccType, CommType]]]) = (init, n, step, ft, out, body) + def unwrap: (Nat, Nat, Nat, NatToData, Phrase[AccType], Phrase[DepFunType[NatIdentifier, FunType[AccType, CommType]]]) = (init, n, step, ft, out, body) } diff --git a/src/main/scala/shine/OpenMP/CodeGenerator.scala b/src/main/scala/shine/OpenMP/CodeGenerator.scala index 1ccb6af92..cb7dcf603 100644 --- a/src/main/scala/shine/OpenMP/CodeGenerator.scala +++ b/src/main/scala/shine/OpenMP/CodeGenerator.scala @@ -8,7 +8,7 @@ import shine.C.Compilation.{CodeGenerator => CCodeGenerator} import shine.DPIA.DSL._ import shine.DPIA.primitives.imperative._ import shine.DPIA.Phrases._ -import shine.DPIA.Types.{AccType, CommType, DataType, ExpType, PhraseType, ScalarType, VectorType} +import shine.DPIA.Types.{AccType, CommType, DataType, ExpType, NatKind, PhraseType, ScalarType, VectorType} import shine.DPIA.primitives.functional._ import shine.DPIA.{ArrayData, Compilation, Data, Nat, NatIdentifier, Phrases, VectorData, error, freshName} import shine.OpenMP.primitives.imperative.{ParFor, ParForNat} @@ -38,7 +38,7 @@ class CodeGenerator(override val decls: CCodeGenerator.Declarations, OpenMPCodeGen.codeGenParFor(n, dt, a, i, o, p, env) case ForVec(n, dt, a, Lambda(i, Lambda(o, p))) => OpenMPCodeGen.codeGenParForVec(n, dt, a, i, o, p, env) - case ParForNat(n, _, a, DepLambda(i, Lambda(o, p))) => + case ParForNat(n, _, a, DepLambda(NatKind, i, Lambda(o, p))) => OpenMPCodeGen.codeGenParForNat(n, a, i, o, p, env) case phrase => phrase |> super.cmd(env) } diff --git a/src/main/scala/shine/OpenMP/primitives/functional/DepMapPar.scala b/src/main/scala/shine/OpenMP/primitives/functional/DepMapPar.scala index 7d5550d82..892dbdd3b 100644 --- a/src/main/scala/shine/OpenMP/primitives/functional/DepMapPar.scala +++ b/src/main/scala/shine/OpenMP/primitives/functional/DepMapPar.scala @@ -7,11 +7,11 @@ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -final case class DepMapPar(val n: Nat, val ft1: NatToData, val ft2: NatToData, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { +final case class DepMapPar(val n: Nat, val ft1: NatToData, val ft2: NatToData, val f: Phrase[DepFunType[NatIdentifier, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { assert { f :: ({ val m = f.t.x - DepFunType[NatKind, PhraseType](m, FunType(expT(NatToDataApply(ft1, m), read), expT(NatToDataApply(ft2, m), write))) + DepFunType(NatKind, m, FunType(expT(NatToDataApply(ft1, m), read), expT(NatToDataApply(ft2, m), write))) }) array :: expT(DepArrayType(n, ft1), read) true diff --git a/src/main/scala/shine/OpenMP/primitives/imperative/ParForNat.scala b/src/main/scala/shine/OpenMP/primitives/imperative/ParForNat.scala index 8925bb322..8b460edca 100644 --- a/src/main/scala/shine/OpenMP/primitives/imperative/ParForNat.scala +++ b/src/main/scala/shine/OpenMP/primitives/imperative/ParForNat.scala @@ -7,12 +7,12 @@ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -final case class ParForNat(val n: Nat, val ft: NatToData, val out: Phrase[AccType], val body: Phrase[DepFunType[NatKind, FunType[AccType, CommType]]]) extends CommandPrimitive { +final case class ParForNat(val n: Nat, val ft: NatToData, val out: Phrase[AccType], val body: Phrase[DepFunType[NatIdentifier, FunType[AccType, CommType]]]) extends CommandPrimitive { assert { out :: accT(DepArrayType(n, ft)) body :: ({ val i = body.t.x - DepFunType[NatKind, PhraseType](i, FunType(accT(NatToDataApply(ft, i)), comm)) + DepFunType(NatKind, i, FunType(accT(NatToDataApply(ft, i)), comm)) }) true } diff --git a/src/test/scala/apps/asum.scala b/src/test/scala/apps/asum.scala index 3ff5602f8..2d907435d 100644 --- a/src/test/scala/apps/asum.scala +++ b/src/test/scala/apps/asum.scala @@ -32,7 +32,7 @@ class asum extends test_util.TestsWithExecutor { val typed = high_level.toExpr val N = typed.t.asInstanceOf[NatDepFunType[_ <: Type]].x - assertResult(DepFunType[NatKind, Type](N, FunType(inputT(N), f32))) { + assertResult(DepFunType(NatKind, N, FunType(inputT(N), f32))) { typed.t } } diff --git a/src/test/scala/apps/dot.scala b/src/test/scala/apps/dot.scala index cb07b260b..896cf1b7a 100644 --- a/src/test/scala/apps/dot.scala +++ b/src/test/scala/apps/dot.scala @@ -26,7 +26,7 @@ class dot extends test_util.Tests { test("Simple dot product type inference works") { val N = simpleDotProduct.t.asInstanceOf[NatDepFunType[_ <: Type]].x assertResult( - DepFunType[NatKind, Type](N, FunType(xsT(N), FunType(ysT(N), f32))) + DepFunType(NatKind, N, FunType(xsT(N), FunType(ysT(N), f32))) ) { simpleDotProduct.t } diff --git a/src/test/scala/apps/gemvCheck.scala b/src/test/scala/apps/gemvCheck.scala index 41b413c17..56f8a5501 100644 --- a/src/test/scala/apps/gemvCheck.scala +++ b/src/test/scala/apps/gemvCheck.scala @@ -19,8 +19,8 @@ class gemvCheck extends test_util.Tests { .asInstanceOf[NatDepFunType[_ <: Type]].t .asInstanceOf[NatDepFunType[_ <: Type]].x assertResult( - DepFunType(N, - DepFunType(M, + DepFunType(NatKind, N, + DepFunType(NatKind, M, ArrayType(M, ArrayType(N, f32)) ->: (ArrayType(N, f32) ->: (ArrayType(M, f32) ->: (f32 ->: (f32 ->: ArrayType(M, f32))))) diff --git a/src/test/scala/apps/separableConvolution2DNaiveEqsat.scala b/src/test/scala/apps/separableConvolution2DNaiveEqsat.scala index 0cc00d395..dcfdff9bb 100644 --- a/src/test/scala/apps/separableConvolution2DNaiveEqsat.scala +++ b/src/test/scala/apps/separableConvolution2DNaiveEqsat.scala @@ -52,15 +52,15 @@ class separableConvolution2DNaiveEqsat extends test_util.Tests { case App(f, e) => everywhere(s)(f).map(App(_, e)(p.t)) ++ everywhere(s)(e).map(App(f, _)(p.t)) case Identifier(_) => Nil case Lambda(x, e) => everywhere(s)(e).map(Lambda(x, _)(p.t)) - case DepLambda(x, e) => x match { + case DepLambda(_, x, e) => x match { case n: NatIdentifier => - everywhere(s)(e).map(DepLambda[NatKind](n, _)(p.t)) + everywhere(s)(e).map(DepLambda(NatKind, n, _)(p.t)) case n: DataTypeIdentifier => - everywhere(s)(e).map(DepLambda[DataKind](n, _)(p.t)) + everywhere(s)(e).map(DepLambda(DataKind, n, _)(p.t)) case n: AddressSpaceIdentifier => - everywhere(s)(e).map(DepLambda[AddressSpaceKind](n, _)(p.t)) + everywhere(s)(e).map(DepLambda(AddressSpaceKind, n, _)(p.t)) } - case DepApp(f, x) => everywhere(s)(f).map(DepApp(_, x)(p.t)) + case DepApp(kind, f, x) => everywhere(s)(f).map(DepApp(kind, _, x)(p.t)) case Literal(_) => Nil case _: TypeAnnotation => throw new Exception("Type annotations should be gone.") case _: TypeAssertion => throw new Exception("Type assertions should be gone.") diff --git a/src/test/scala/rise/core/showRise.scala b/src/test/scala/rise/core/showRise.scala index 416c9b98f..186ace07f 100644 --- a/src/test/scala/rise/core/showRise.scala +++ b/src/test/scala/rise/core/showRise.scala @@ -57,9 +57,9 @@ class showRise extends test_util.Tests { case i: Identifier => line(i.name) case Lambda(x, e) => block(s"λ${x.name}", drawASTSimp(e)) case App(f, e) => drawASTSimp(f) :+> drawASTSimp(e) - case dl @ DepLambda(x, e) => - block(s"Λ${x.name}:${dl.kindName}", drawASTSimp(e)) - case DepApp(f, x) => line(x.toString) <+: drawASTSimp(f) + case DepLambda(kind, x, e) => + block(s"Λ${x.name}:${kind.name}", drawASTSimp(e)) + case DepApp(_, f, x) => line(x.toString) <+: drawASTSimp(f) case Literal(d) => line(d.toString) case TypeAnnotation(e, _) => drawASTSimp(e) case TypeAssertion(e, _) => drawASTSimp(e) @@ -87,15 +87,15 @@ class showRise extends test_util.Tests { val es = lessBrackets(e, wrapped = true) if (wrapped) s"($fs $es)" else s"$fs $es" - case dl @ DepLambda(x, e) => - val xs = s"${x.name}:${dl.kindName}" + case DepLambda(kind, x, e) => + val xs = s"${x.name}:${kind.name}" val es = lessBrackets(e) if (wrapped) s"[Λ$xs. $es]" else s"Λ$xs. $es" - case DepApp(f, x) => + case DepApp(_, f, x) => val fs = f match { - case _: DepLambda[_] => lessBrackets(f, wrapped = true) - case _ => lessBrackets(f) + case _: DepLambda[_, _] => lessBrackets(f, wrapped = true) + case _ => lessBrackets(f) } if (wrapped) s"($fs $x)" else s"$fs $x" diff --git a/src/test/scala/rise/elevate/algorithmic.scala b/src/test/scala/rise/elevate/algorithmic.scala index 7715f6512..9285028ed 100644 --- a/src/test/scala/rise/elevate/algorithmic.scala +++ b/src/test/scala/rise/elevate/algorithmic.scala @@ -68,12 +68,12 @@ class algorithmic extends test_util.Tests { val addTuple = fun(x => fst(x) + snd(x)) - val mapReduce = depLambda[NatKind](M, depLambda[NatKind](N, + val mapReduce = depLambda(NatKind, M, depLambda(NatKind, N, fun(ArrayType(M, ArrayType(N, f32)))(i => map(reduce(fun(x => fun(a => x + a)))(lf32(0.0f))) $ i))) val reduceMap: Rise = - depLambda[NatKind](M, depLambda[NatKind](N, + depLambda(NatKind, M, depLambda(NatKind, N, fun(ArrayType(M, ArrayType(N, f32)))(i => reduce(fun((acc, y) => map(addTuple) $ zip(acc)(y)))(generate(fun(IndexType(M) ->: f32)(_ => lf32(0.0f)))) $ transpose(i)))) @@ -94,7 +94,7 @@ class algorithmic extends test_util.Tests { val N = NatIdentifier("N", isExplicit = true) val K = NatIdentifier("K", isExplicit = true) - val mm = depLambda[NatKind](M, depLambda[NatKind](N, depLambda[NatKind](K, + val mm = depLambda(NatKind, M, depLambda(NatKind, N, depLambda(NatKind, K, fun(ArrayType(M, ArrayType(K, f32)))(a => fun(ArrayType(K, ArrayType(N, f32)))(b => a |> map(fun(ak => @@ -104,7 +104,7 @@ class algorithmic extends test_util.Tests { lf32(0.0f))))))))))) def goldMKN(reduceFun: ToBeTyped[Rise]): ToBeTyped[Rise] = { - depLambda[NatKind](M, depLambda[NatKind](N, depLambda[NatKind](K, + depLambda(NatKind, M, depLambda(NatKind, N, depLambda(NatKind, K, fun(ArrayType(M, ArrayType(K, f32)))(a => fun(ArrayType(K, ArrayType(N, f32)))(b => a |> map(fun(ak => @@ -145,7 +145,7 @@ class algorithmic extends test_util.Tests { val K = NatIdentifier("K", isExplicit = true) val mmMKN = { - depLambda[NatKind](M, depLambda[NatKind](N, depLambda[NatKind](K, + depLambda(NatKind, M, depLambda(NatKind, N, depLambda(NatKind, K, fun(ArrayType(M, ArrayType(K, f32)))(a => fun(ArrayType(K, ArrayType(N, f32)))(b => map(fun(ak => @@ -197,7 +197,7 @@ class algorithmic extends test_util.Tests { // this one is constructed more similar to what the rewrite rules will create val goldKMNAlternative = - depLambda[NatKind](M, depLambda[NatKind](N, depLambda[NatKind](K, + depLambda(NatKind, M, depLambda(NatKind, N, depLambda(NatKind, K, fun(ArrayType(M, ArrayType(K, f32)))(a => fun(ArrayType(K, ArrayType(N, f32)))(b => reduceSeq( @@ -216,7 +216,7 @@ class algorithmic extends test_util.Tests { // unfortunately, the order of zip arguments is important val goldKMNAlternative2 = - depLambda[NatKind](M, depLambda[NatKind](N, depLambda[NatKind](K, + depLambda(NatKind, M, depLambda(NatKind, N, depLambda(NatKind, K, fun(ArrayType(M, ArrayType(K, f32)))(a => fun(ArrayType(K, ArrayType(N, f32)))(b => reduceSeq( @@ -273,7 +273,7 @@ class algorithmic extends test_util.Tests { val K = NatIdentifier("K", isExplicit = true) val mm = - DFNF(depLambda[NatKind](M, depLambda[NatKind](N, depLambda[NatKind](K, + DFNF(depLambda(NatKind, M, depLambda(NatKind, N, depLambda(NatKind, K, fun(ArrayType(M, ArrayType(K, f32)))(a => fun(ArrayType(K, ArrayType(N, f32)))(b => map(fun(ak => diff --git a/src/test/scala/rise/elevate/tiling.scala b/src/test/scala/rise/elevate/tiling.scala index cf032cdfd..3ba588e92 100644 --- a/src/test/scala/rise/elevate/tiling.scala +++ b/src/test/scala/rise/elevate/tiling.scala @@ -249,7 +249,7 @@ class tiling extends test_util.Tests { def wrapInLambda[T <: Expr](dim: Int, f: ToBeTyped[Identifier] => ToBeTyped[T], genInputType: List[Nat] => ArrayType, - natIds: List[Nat] = List()): ToBeTyped[DepLambda[NatKind]] = { + natIds: List[Nat] = List()): ToBeTyped[DepLambda[Nat, NatIdentifier]] = { dim match { case 1 => depFun((n: Nat) => fun(genInputType( natIds :+ n))(f)) case d => depFun((n: Nat) => wrapInLambda(d - 1, f, genInputType, natIds :+ n)) diff --git a/src/test/scala/rise/elevate/traversals.scala b/src/test/scala/rise/elevate/traversals.scala index ad2d80257..f43b2981c 100644 --- a/src/test/scala/rise/elevate/traversals.scala +++ b/src/test/scala/rise/elevate/traversals.scala @@ -6,7 +6,7 @@ import rise.elevate.util._ import rise.core.DSL._ import rise.core.makeClosed import rise.core.primitives._ -import rise.core.types.NatKind +import rise.core.types.{Nat, NatKind} import rise.elevate.meta.fission.bodyFission import rise.elevate.meta.traversal.inBody import rise.elevate.rules.algorithmic._ @@ -69,8 +69,8 @@ class traversals extends test_util.Tests { } test("RNF did not normalize") { - val expr2 = lambda(identifier("ee1"), lambda(identifier("ee2"), app(join, app(app(map, lambda(identifier("η125"), app(app(map, lambda(identifier("ee3"), app(join, app(app(map, lambda(identifier("η124"), app(app(map, lambda(identifier("η123"), app(identifier("ee2"), identifier("η123")))), identifier("η124")))), app(depApp[NatKind](split, 4), identifier("ee3")))))), identifier("η125")))), app(depApp[NatKind](split, 4), identifier("ee1")))))) - val expr5 = lambda(identifier("ee1"), lambda(identifier("ee2"), app(join, app(app(map, lambda(identifier("η141"), app(app(map, lambda(identifier("η140"), app(join, identifier("η140")))), identifier("η141")))), app(app(map, lambda(identifier("η145"), app(app(map, lambda(identifier("η144"), app(app(map, lambda(identifier("η143"), app(app(map, lambda(identifier("η142"), app(identifier("ee2"), identifier("η142")))), identifier("η143")))), identifier("η144")))), identifier("η145")))), app(app(map, lambda(identifier("η147"), app(app(map, lambda(identifier("η146"), app(depApp[NatKind](split, 4), identifier("η146")))), identifier("η147")))), app(depApp[NatKind](split, 4), identifier("ee1")))))))) + val expr2 = lambda(identifier("ee1"), lambda(identifier("ee2"), app(join, app(app(map, lambda(identifier("η125"), app(app(map, lambda(identifier("ee3"), app(join, app(app(map, lambda(identifier("η124"), app(app(map, lambda(identifier("η123"), app(identifier("ee2"), identifier("η123")))), identifier("η124")))), app(depApp[Nat](NatKind, split, 4), identifier("ee3")))))), identifier("η125")))), app(depApp[Nat](NatKind, split, 4), identifier("ee1")))))) + val expr5 = lambda(identifier("ee1"), lambda(identifier("ee2"), app(join, app(app(map, lambda(identifier("η141"), app(app(map, lambda(identifier("η140"), app(join, identifier("η140")))), identifier("η141")))), app(app(map, lambda(identifier("η145"), app(app(map, lambda(identifier("η144"), app(app(map, lambda(identifier("η143"), app(app(map, lambda(identifier("η142"), app(identifier("ee2"), identifier("η142")))), identifier("η143")))), identifier("η144")))), identifier("η145")))), app(app(map, lambda(identifier("η147"), app(app(map, lambda(identifier("η146"), app(depApp[Nat](NatKind, split, 4), identifier("η146")))), identifier("η147")))), app(depApp[Nat](NatKind, split, 4), identifier("ee1")))))))) assert(makeClosed(RNF(expr2).get) =~= makeClosed(toExpr(expr5))) } diff --git a/src/test/scala/rise/elevate/util/package.scala b/src/test/scala/rise/elevate/util/package.scala index 43e97f33e..d3588d668 100644 --- a/src/test/scala/rise/elevate/util/package.scala +++ b/src/test/scala/rise/elevate/util/package.scala @@ -19,7 +19,7 @@ package object util { // notation def T: ToBeTyped[Rise] = transpose - def S: ToBeTyped[DepApp[NatKind]] = split(tileSize) //slide(3)(1) + def S: ToBeTyped[DepApp[Nat]] = split(tileSize) //slide(3)(1) def J: ToBeTyped[Rise] = join def *(x: ToBeTyped[Rise]): ToBeTyped[App] = map(x) def **(x: ToBeTyped[Rise]): ToBeTyped[App] = map(map(x)) diff --git a/src/test/scala/shine/DPIA/InferAccessTypes.scala b/src/test/scala/shine/DPIA/InferAccessTypes.scala index 794bc6ff1..09feb68bf 100644 --- a/src/test/scala/shine/DPIA/InferAccessTypes.scala +++ b/src/test/scala/shine/DPIA/InferAccessTypes.scala @@ -102,7 +102,7 @@ class InferAccessTypes extends test_util.Tests { val splitArray = (depFun((n: Nat) => fun(8`.`rt.f32)(arr => arr |> split(n)))).toExpr val infPt = inferAccess(splitArray).get(splitArray).asInstanceOf[ - DepFunType[NatKind, FunType[ExpType, ExpType]] + DepFunType[NatIdentifier, FunType[ExpType, ExpType]] ] assertResult(read)(infPt.t.outT.accessType) } diff --git a/src/test/scala/shine/cuda/MMTest.scala b/src/test/scala/shine/cuda/MMTest.scala index 78bb48f74..7a720d556 100644 --- a/src/test/scala/shine/cuda/MMTest.scala +++ b/src/test/scala/shine/cuda/MMTest.scala @@ -98,7 +98,7 @@ class MMTest extends test_util.TestWithCUDA { //Kernel val simpleMatMulTile = - DepLambda[NatKind](k)( + DepLambda(NatKind, k)( Lambda[ExpType, FunType[ExpType, ExpType]](matrixATile, Lambda[ExpType, ExpType](matrixBTile, AsMatrix(mTile, nTile, kTile, f32, @@ -196,9 +196,9 @@ class MMTest extends test_util.TestWithCUDA { //Kernel val simpleMatMul = - DepLambda[NatKind](m)( - DepLambda[NatKind](n)( - DepLambda[NatKind](k)( + DepLambda(NatKind, m)( + DepLambda(NatKind, n)( + DepLambda(NatKind, k)( //Input: matrixA Lambda[ExpType, FunType[ExpType, ExpType]](matrixA, //And matrixB diff --git a/src/test/scala/shine/cuda/basic.scala b/src/test/scala/shine/cuda/basic.scala index 9ef062a41..a3a33016f 100644 --- a/src/test/scala/shine/cuda/basic.scala +++ b/src/test/scala/shine/cuda/basic.scala @@ -1,7 +1,6 @@ package shine.cuda import shine.DPIA.DSL.{depFun, λ} -import shine.DPIA.Types.Kind.NatIdentifierMaker import shine.DPIA.Types.{NatKind, _} import shine.OpenCL.{Global, Local} import util.gen @@ -9,7 +8,7 @@ import util.gen class basic extends test_util.Tests { test("id with mapThreads compiles to syntactically correct Cuda") { - val mapId = depFun[NatKind]()(n => + val mapId = depFun(NatKind)(n => λ(ExpType(ArrayType(n, f32), read))(array => shine.cuda.primitives.functional.Map(Local, 'x')(n, f32, f32, λ(ExpType(f32, read))(x => x), array)) ) @@ -19,7 +18,7 @@ class basic extends test_util.Tests { } test("id with mapGlobal compiles to syntactically correct CUDA") { - val mapId = depFun[NatKind]()(n => + val mapId = depFun(NatKind)(n => λ(ExpType(ArrayType(n, f32), read))(array => shine.cuda.primitives.functional.Map(Global, 'x')(n, f32, f32, λ(ExpType(f32, read))(x => x), array)) )