diff --git a/build.sbt b/build.sbt
index ffaf38084..4c52b490c 100644
--- a/build.sbt
+++ b/build.sbt
@@ -17,7 +17,7 @@ lazy val commonSettings = Seq(
lazy val riseAndShine = (project in file("."))
.aggregate(executor, CUexecutor)
- .dependsOn(meta, riseAndShineMacros, arithExpr, executor, CUexecutor, elevate)
+ .dependsOn(meta, arithExpr, executor, CUexecutor, elevate)
.settings(
name := "riseAndShine",
version := "1.0",
@@ -50,7 +50,16 @@ lazy val riseAndShine = (project in file("."))
lazy val generateRISEPrimitives = taskKey[Unit]("Generate RISE Primitives")
generateRISEPrimitives := {
- runner.value.run("meta.RisePrimitiveGenerator",
+ runner.value.run("meta.generator.RisePrimitives",
+ (dependencyClasspath in Compile).value.files,
+ Seq((scalaSource in Compile).value.getAbsolutePath),
+ streams.value.log).failed foreach (sys error _.getMessage)
+}
+
+lazy val generateDPIAPrimitives = taskKey[Unit]("Generate DPIA Primitives")
+
+generateDPIAPrimitives := {
+ runner.value.run("meta.generator.DPIAPrimitives",
(dependencyClasspath in Compile).value.files,
Seq((scalaSource in Compile).value.getAbsolutePath),
streams.value.log).failed foreach (sys error _.getMessage)
@@ -67,14 +76,6 @@ lazy val meta = (project in file("meta"))
libraryDependencies += "org.scalameta" %% "scalameta" % "4.4.10",
)
-lazy val riseAndShineMacros = (project in file("macros"))
- .settings(
- name := "riseAndShineMacros",
- version := "1.0",
- commonSettings,
- libraryDependencies += "org.scala-lang" % "scala-reflect" % scalaVersion.value
- )
-
lazy val arithExpr = project in file("lib/arithexpr")
lazy val executor = project in file("lib/executor")
diff --git a/macros/src/main/scala/shine/macros/Primitive.scala b/macros/src/main/scala/shine/macros/Primitive.scala
deleted file mode 100644
index 90d985a5d..000000000
--- a/macros/src/main/scala/shine/macros/Primitive.scala
+++ /dev/null
@@ -1,163 +0,0 @@
-package shine.macros
-
-import scala.annotation.{StaticAnnotation, compileTimeOnly}
-import scala.reflect.macros.blackbox
-import scala.language.experimental.macros
-
-object Primitive {
- @compileTimeOnly("ExpPrimitive macro")
- class expPrimitive extends StaticAnnotation {
- def macroTransform(annottees: Any*): Any = macro Impl.expPrimitive
- }
-
- @compileTimeOnly("AccPrimitive macro")
- class accPrimitive extends StaticAnnotation {
- def macroTransform(annottees: Any*): Any = macro Impl.accPrimitive
- }
-
- @compileTimeOnly("CommandPrimitive macro")
- class comPrimitive extends StaticAnnotation {
- def macroTransform(annottees: Any*): Any = macro Impl.comPrimitive
- }
-
- class Impl(val c: blackbox.Context) {
- import c.universe._
-
- def expPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees)
- def accPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees)
- def comPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees)
-
- def primitive(transform : ClassDef => ClassDef)(annottees: Seq[c.Expr[Any]]): c.Expr[Any] = {
- annottees.map(_.tree) match {
- case (cdef: ClassDef) :: Nil =>
- c.Expr(transform(cdef))
- case (cdef: ClassDef) :: (md: ModuleDef) :: Nil =>
- c.Expr(q"{${transform(cdef)}; $md}")
- case _ => c.abort(c.enclosingPosition, "expected a class definition")
- }
- }
-
- def makeLowerCaseName(s: String): String =
- s"${Character.toLowerCase(s.charAt(0))}${s.substring(1)}"
-
- def makeVisitAndRebuild(name: TypeName,
- additionalParams: List[ValDef],
- params: List[ValDef]): Tree = {
- val v = q"v"
- q"""
- override def visitAndRebuild(
- $v: shine.DPIA.Phrases.VisitAndRebuild.Visitor): $name
- = new ${Apply( additionalParams match {
- case List() => Ident(name)
- case _ => Apply(Ident(name), additionalParams.map {
- case ValDef(_, name, _, _) => q"$name"
- })
- }, params.map {
- case ValDef(_, name, tpt, _) => tpt match {
- case Ident(TypeName("DataType")) | Ident(TypeName("ScalarType")) |
- Ident(TypeName("BasicType")) => q"$v.data($name)"
- case Ident(TypeName("Nat")) => q"$v.nat($name)"
- case Ident(TypeName("NatIdentifier")) => q"$v.nat($name)"
- case Ident(TypeName("NatToNat")) => q"$v.natToNat($name)"
- case Ident(TypeName("NatToData")) => q"$v.natToData($name)"
- case Ident(TypeName("AccessType")) => q"$v.access($name)"
- case Ident(TypeName("AddressSpace")) => q"$v.addressSpace($name)"
- case Ident(TypeName("LocalSize")) => q"$name.visitAndRebuild($v)"
- case Ident(TypeName("GlobalSize")) => q"$name.visitAndRebuild($v)"
- // Phrase[ExpType]
- case AppliedTypeTree((Ident(TypeName("Phrase")), _)) =>
- q"shine.DPIA.Phrases.VisitAndRebuild($name, $v)"
- // Vector[Phrase[ExpType]]
- case AppliedTypeTree((Ident(TypeName("Vector")),
- List(AppliedTypeTree((Ident(TypeName("Phrase")), _)))))
- | AppliedTypeTree((Ident(TypeName("Seq")),
- List(AppliedTypeTree((Ident(TypeName("Phrase")), _)))))
- =>
- q"$name.map(shine.DPIA.Phrases.VisitAndRebuild(_, $v))"
- case _ =>
- q"$name"
- }
- })}
- """
- }
-
- case class ClassInfo(name: TypeName,
- additionalParams: List[ValDef],
- params: List[ValDef],
- body: List[Tree],
- parents: List[Tree])
-
- def primitivesFromClassDef: ClassDef => ClassInfo = {
- case q"case class $name(..$params) extends { ..$_ } with ..$parents {..$body} " =>
- ClassInfo(
- name.asInstanceOf[c.TypeName],
- List(),
- params.asInstanceOf[List[ValDef]],
- body.asInstanceOf[List[Tree]],
- parents.asInstanceOf[List[Tree]])
- case q"final case class $name(..$params) extends { ..$_ } with ..$parents {..$body} " =>
- ClassInfo(
- name.asInstanceOf[c.TypeName],
- List(),
- params.asInstanceOf[List[ValDef]],
- body.asInstanceOf[List[Tree]],
- parents.asInstanceOf[List[Tree]])
- case q"""case class $name(..$additionalParams)
- (..$params) extends { ..$_ } with ..$parents {..$body} """ =>
- ClassInfo(
- name.asInstanceOf[c.TypeName],
- additionalParams.asInstanceOf[List[ValDef]],
- params.asInstanceOf[List[ValDef]],
- body.asInstanceOf[List[Tree]],
- parents.asInstanceOf[List[Tree]])
- case q"""final case class $name(..$additionalParams)
- (..$params) extends { ..$_ } with ..$parents {..$body} """ =>
- ClassInfo(
- name.asInstanceOf[c.TypeName],
- additionalParams.asInstanceOf[List[ValDef]],
- params.asInstanceOf[List[ValDef]],
- body.asInstanceOf[List[Tree]],
- parents.asInstanceOf[List[Tree]])
- case _ =>
- c.abort(c.enclosingPosition, "expected a case class extends Primitive")
- }
-
- def makePrimitiveClass : ClassInfo => ClassDef = { case ClassInfo(name, additionalParams, params, body, parents) =>
- val visitAndRebuildMissing =
- body.collectFirst({ case DefDef(_, TermName("visitAndRebuild"), _, _, _, _) => ()}).isEmpty
-
- val generated = q"""
- ${if (visitAndRebuildMissing)
- makeVisitAndRebuild(name, additionalParams, params)
- else q""}
- """
-
- val expClass = (additionalParams match {
- case List() =>
- q"""
- final case class $name(..$params) extends {} with ..$parents {
- ..$body
- ..$generated
- }
- """
- case _ =>
- val newParams = params map {
- case ValDef(_, name, tpt, rhs) if rhs.isEmpty => q"val $name : $tpt"
- case ValDef(_, name, tpt, rhs) => q"val $name : $tpt = $rhs"
- }
- q"""
- final case class $name(..$additionalParams)
- (..$newParams) extends {} with ..$parents {
- ..$body
- ..$generated
- }
- """
- }).asInstanceOf[ClassDef]
-/*
- c.info(c.enclosingPosition,
- s"generated `${name.toString}'\n$expClass", force = false)
-*/
- expClass
- }
- }
-}
\ No newline at end of file
diff --git a/meta/src/main/scala/meta/NatParser.scala b/meta/src/main/scala/meta/NatParser.scala
deleted file mode 100644
index 9741605e2..000000000
--- a/meta/src/main/scala/meta/NatParser.scala
+++ /dev/null
@@ -1,64 +0,0 @@
-package meta
-
-import fastparse.ScalaWhitespace._
-import fastparse._
-
-import meta.TypeParser._
-
-object NatParser {
-
- sealed trait NatAST
- object NatAST {
- case class Identifier(id: TypeAST.Identifier) extends NatAST
- case class Number(n: String) extends NatAST
- case class BinaryOp(lhs: NatAST, op: String, rhs: NatAST) extends NatAST
- case class TernaryOp(cond: BinaryOp, thenN: NatAST, elseN: NatAST) extends NatAST
- case class Nat2NatApply(f: TypeAST.Identifier, n: NatAST) extends NatAST
- case class Sum(id: Identifier, from: NatAST, upTo: NatAST, body: NatAST) extends NatAST
- }
-
- def Nat[_: P]: P[NatAST] = {
- def CompOrNat: P[NatAST] = {
- def CompOp: P[String] = P("<".! | ">".!)
- P(AddSubOrNat ~ (CompOp ~/ AddSubOrNat).rep).map(asBinaryOpOrNat)
- }
-
- def AddSubOrNat: P[NatAST] = {
- def AddSubOps: P[String] = P("+".! | "-".!)
- P(DivMulPowModOrNat ~ (AddSubOps ~ DivMulPowModOrNat).rep).map(asBinaryOpOrNat)
- }
-
- def DivMulPowModOrNat: P[NatAST] = {
- def DivMulPowModOp: P[String] = P("*".! | "/".! | "^".! | "%".!)
- P(SingleNat ~ (DivMulPowModOp ~ SingleNat).rep).map(asBinaryOpOrNat)
- }
-
- def SingleNat: P[NatAST] = {
- def Number: P[NatAST.Number] = P(CharIn("0-9").rep(1).!).map(NatAST.Number)
-
- def Sum: P[NatAST.Sum] = {
- def Assignment: P[(NatAST.Identifier, NatAST)] =
- P(NatIdentifier ~ "=" ~ Nat | "(" ~ Assignment ~ ")")
- P("sum" ~ "_" ~ Assignment ~ "^" ~ Nat ~ Nat).map(NatAST.Sum.tupled)
- }
-
- def Nat2NatApply: P[NatAST.Nat2NatApply] =
- P(TypeIdentifier ~ "(" ~ Nat ~ ")").map(NatAST.Nat2NatApply.tupled)
-
- def NatIdentifier: P[NatAST.Identifier] = P(TypeIdentifier).map(NatAST.Identifier)
-
- def Parens: P[NatAST] = P("(" ~ Nat ~ ")")
-
- P(Number | Sum | Nat2NatApply | NatIdentifier | Parens)
- }
-
- P(CompOrNat)
- }
-
- private def asBinaryOpOrNat: ((NatAST, Seq[(String, NatAST)])) => NatAST = {
- case (n, ns) => ns.foldLeft(n){
- case (lhs, (op, rhs)) => NatAST.BinaryOp(lhs, op, rhs)
- }
- }
-
-}
diff --git a/meta/src/main/scala/meta/TypeParser.scala b/meta/src/main/scala/meta/TypeParser.scala
deleted file mode 100644
index 7a31480f3..000000000
--- a/meta/src/main/scala/meta/TypeParser.scala
+++ /dev/null
@@ -1,328 +0,0 @@
-package meta
-
-import fastparse.ScalaWhitespace._
-import fastparse._
-import meta.NatParser._
-import meta.TypeParser.TypeAST.{FragmentAST, MatrixLayoutAST}
-
-object TypeParser {
-
- sealed trait TypeAST
- object TypeAST {
- case class Identifier(name: String) extends TypeAST
- case class FunType(inT: TypeAST, outT: TypeAST) extends TypeAST
- case class DepFunType(id: Identifier, kind: String, t: TypeAST) extends TypeAST
- case class ImplicitDepFunType(id: Identifier, kind: String, t: TypeAST) extends TypeAST
-
- case class ScalarType(t: String) extends TypeAST
- case object NatType extends TypeAST
- case class VectorType(size: NatAST, elemType: TypeAST) extends TypeAST
- case class IndexType(size: NatAST) extends TypeAST
- case class PairType(lhs: TypeAST, rhs: TypeAST) extends TypeAST
- case class DepPairType(id: Identifier, kind: String, t: TypeAST) extends TypeAST
- case class NatToDataApply(f: TypeAST, n: NatAST) extends TypeAST
- case class NatToDataLambda(id: Identifier, t: TypeAST) extends TypeAST
- case class ArrayType(size: NatAST, elemType: TypeAST) extends TypeAST
- case class DepArrayType(size: NatAST, fdt: TypeAST) extends TypeAST
- case class FragmentType(n: NatAST, m: NatAST, k: NatAST, elemType: TypeAST,
- fKind: FragmentAST, mLayout: MatrixLayoutAST) extends TypeAST
-
- sealed trait FragmentAST
- object FragmentAST {
- case class Identifier(id: TypeAST.Identifier) extends FragmentAST
- object ACC extends FragmentAST
- object A extends FragmentAST
- object B extends FragmentAST
- }
-
- sealed trait MatrixLayoutAST
- object MatrixLayoutAST {
- case class Identifier(id: TypeAST.Identifier) extends MatrixLayoutAST
- object ROW_MAJOR extends MatrixLayoutAST
- object COL_MAJOR extends MatrixLayoutAST
- }
-
- object Kind extends Enumeration {
- val Data, Nat, Nat2Nat, Nat2Data, Address, Fragment, MatrixLayout, Function = Value
-
- def fromString(s: String): Value = s match {
- case "data" => Data
- case "nat" => Nat
- case "nat2nat" => Nat2Nat
- case "nat2data" => Nat2Data
- case "address" => Address
- case "fragment" => Fragment
- case "matrixLayout" => MatrixLayout
- }
- }
- }
-
- def PrimitiveDeclarations[_: P]: P[Seq[(String, Option[(Int, Int)], TypeAST)]] =
- P(Start ~ PrimitiveDeclaration.rep(1) ~ End)
-
- def PrimitiveDeclaration[_: P]: P[(String, Option[(Int, Int)], TypeAST)] = {
- def ScalaFunArgs: P[(Int, Int)] = {
- import scalaparse.Scala.TrailingCommaOps
- P("(" ~ Index ~
- (scalaparse.Scala.Id ~ scalaparse.syntax.Key.O(":") ~ scalaparse.Scala.Type).repTC(1) ~
- Index ~ ")")
- }
-
- P("def" ~ Identifier ~ ScalaFunArgs.? ~ ":" ~ TypeSignature)
- }
-
- def Identifier[_: P]: P[String] = {
- def Keywords: P[Unit] =
- P(("def" | (Kind: P[Unit]) | DataType.TypeName) ~~ CharPred(_.isWhitespace))
-
- val LowerChar = scalaparse.syntax.Identifiers.NamedFunction(CharPredicates.isLower)
- val IdCharacter = scalaparse.syntax.Identifiers.NamedFunction(c =>
- CharPredicates.isLetter(c) || CharPredicates.isDigit(c))
-
- P((!Keywords ~ CharPred(LowerChar).! ~~ CharsWhile(IdCharacter).!.?).
- map(t => t._1 ++ t._2.getOrElse("")))
- }
-
- def TypeSignature[_: P]: P[TypeAST] = {
- def DepFunType: P[TypeAST.DepFunType] =
- P("(" ~ IdentifierKindPair ~ ")" ~ "->" ~/ TypeSignature).map(TypeAST.DepFunType.tupled)
-
- def ImplicitDepFunType: P[TypeAST.ImplicitDepFunType] =
- P("{" ~ IdentifierKindPair ~ "}" ~ "->" ~/ TypeSignature).
- map(TypeAST.ImplicitDepFunType.tupled)
-
- def FunType: P[TypeAST.FunType] =
- P(NoCut(LeftTypeSignature) ~ "->" ~/ TypeSignature).map(TypeAST.FunType.tupled)
-
- // Types that can appear at the left of an function arrow
- def LeftTypeSignature: P[TypeAST] = P(DataType.DataType | ("(" ~ TypeSignature ~ ")"))
-
- P(DepFunType | ImplicitDepFunType | FunType | LeftTypeSignature)
- }
-
- def TypeIdentifier[_: P]: P[TypeAST.Identifier] = P(Identifier).map(TypeAST.Identifier)
-
- def IdentifierKindPair[_: P]: P[(TypeAST.Identifier, String)] = P(TypeIdentifier ~ ":" ~ Kind)
-
- def Kind[_: P]: P[String] =
- P("data".! | "address".! | "nat2nat".! | "nat2data".! | "nat".! |
- "fragment".! | "matrixLayout".!)
-
- object DataType {
- def ScalarType[_: P]: P[TypeAST.ScalarType] =
- P("bool".! | "int".! |
- "i8".! | "i16".! | "i32".! | "i64".! |
- "u8".! | "u16".! | "u32".! | "u64".! |
- "f16".! | "f32".! | "f64".!).map(TypeAST.ScalarType)
-
- def NatType[_: P]: P[TypeAST.NatType.type] = P("natType").map(_ => TypeAST.NatType)
-
- def IndexType[_: P]: P[TypeAST.IndexType] = P("idx[" ~ Nat ~ "]").map(TypeAST.IndexType)
-
- def VectorType[_: P]: P[TypeAST.VectorType] =
- P("vec[" ~ DataType ~ "," ~ Nat ~ "]").map(t => TypeAST.VectorType(t._2, t._1))
-
- def FragmentType[_: P]: P[TypeAST.FragmentType] = {
- def FragmentKind: P[FragmentAST] =
- P(("fragment." ~~ (
- "ACC".!.map(_ => FragmentAST.ACC) |
- "A".!.map(_ => FragmentAST.A) |
- "B".!.map(_ => FragmentAST.B))
- ) | TypeIdentifier.map(FragmentAST.Identifier))
-
- def MatrixLayout: P[MatrixLayoutAST] =
- P(("matrixLayout." ~~ (
- "ROW_MAJOR".!.map(_ => MatrixLayoutAST.ROW_MAJOR) |
- "COL_MAJOR".!.map(_ => MatrixLayoutAST.COL_MAJOR))
- ) | TypeIdentifier.map(MatrixLayoutAST.Identifier))
-
- P("fragment["~ Nat ~","~ Nat ~","~ Nat ~","~ DataType ~","~ FragmentKind ~
- ","~ MatrixLayout ~"]").map(TypeAST.FragmentType.tupled)
- }
-
- def DepArrayType[_: P]: P[TypeAST.DepArrayType] =
- P(Nat ~ ".." ~/ NatToData).map(TypeAST.DepArrayType.tupled)
-
- def ArrayType[_: P]: P[TypeAST.ArrayType] =
- P(Nat ~ "." ~~ !"." ~/ DataType).map(TypeAST.ArrayType.tupled)
-
- def DepPairType[_: P]: P[TypeAST.DepPairType] =
- P("(" ~ IdentifierKindPair ~ "**" ~/ DataType ~ ")").map(TypeAST.DepPairType.tupled)
-
- def NatToDataApply[_: P]: P[TypeAST.NatToDataApply] =
- P(NatToData ~ "(" ~ Nat ~ ")").map(TypeAST.NatToDataApply.tupled)
-
- def PairType[_: P]: P[TypeAST.PairType] =
- P("(" ~ NoCut(DataType) ~ "," ~/ DataType ~ ")").map(TypeAST.PairType.tupled)
-
- def DataType[_: P]: P[TypeAST] =
- P(ScalarType | NatType | IndexType | VectorType | FragmentType | DepArrayType |
- ArrayType | DepPairType | NatToDataApply | PairType | TypeIdentifier |
- ("(" ~ DataType ~ ")"))
-
- def TypeName[_: P]: P[Unit] =
- P(ScalarType | NatType | "idx" | "vec" | "fragment" | "matrixLayout")
- }
-
- def NatToData[_: P]: P[TypeAST] = {
- def NatToDataLambda: P[TypeAST.NatToDataLambda] =
- P("(" ~ IdentifierKindPair.filter(_._2 == "nat").map(_._1) ~
- "|->" ~/ DataType.DataType ~ ")").map(TypeAST.NatToDataLambda.tupled)
-
- P(TypeIdentifier | NatToDataLambda)
- }
-
- object isWellKindedType {
- import TypeAST.Kind._
-
- def apply(typeAST: TypeAST): Boolean = {
- kindOf(typeAST, Map.empty).isDefined
- }
-
- def kindOf(typeAST: TypeAST,
- env: Map[TypeAST.Identifier, TypeAST.Kind.Value]): Option[TypeAST.Kind.Value] = {
- typeAST match {
- case id: TypeAST.Identifier =>
- env.get(id)
- case TypeAST.FunType(inT, outT) =>
- for {
- _ <- kindOf(inT, env)
- _ <- kindOf(outT, env)
- } yield Function
- case TypeAST.DepFunType(id, kind, t) =>
- if (env.isDefinedAt(id)) {
- None // we forbid shadowing
- } else {
- kindOf(t, env.updated(id, TypeAST.Kind.fromString(kind)))
- }
- case TypeAST.ImplicitDepFunType(id, kind, t) =>
- if (env.isDefinedAt(id)) {
- None // we forbid shadowing
- } else {
- kindOf(t, env.updated(id, TypeAST.Kind.fromString(kind)))
- }
- case TypeAST.VectorType(size, elemType) =>
- for {
- k1 <- kindOf(size, env)
- k2 <- kindOf(elemType, env)
- if k1 == Nat && k2 == Data
- } yield Data
- case TypeAST.IndexType(size) =>
- for {
- k <- kindOf(size, env)
- if k == Nat
- } yield Data
- case TypeAST.PairType(lhs, rhs) =>
- for {
- k1 <- kindOf(lhs, env)
- k2 <- kindOf(rhs, env)
- if k1 == Data && k2 == Data
- } yield Data
- case TypeAST.DepPairType(id, kind, t) =>
- if (env.isDefinedAt(id)) {
- None // we forbid shadowing
- } else {
- kindOf(t, env.updated(id, TypeAST.Kind.fromString(kind)))
- }
- case TypeAST.NatToDataApply(f, n) =>
- for {
- k1 <- kindOf(f, env)
- k2 <- kindOf(n, env)
- if k1 == Nat2Data && k2 == Nat
- } yield Data
- case TypeAST.NatToDataLambda(id, t) =>
- if (env.isDefinedAt(id)) {
- None // we forbid shadowing
- } else {
- for {
- k <- kindOf(t, env.updated(id, Nat))
- if k == Data
- } yield Nat2Data
- }
- case TypeAST.ArrayType(size, elemType) =>
- for {
- k1 <- kindOf(size, env)
- k2 <- kindOf(elemType, env)
- if k1 == Nat && k2 == Data
- } yield Data
- case TypeAST.DepArrayType(size, fdt) =>
- for {
- k1 <- kindOf(size, env)
- k2 <- kindOf(fdt, env)
- if k1 == Nat && k2 == Nat2Data
- } yield Data
- case TypeAST.FragmentType(n, m, k, elemType, fKind, mLayout) =>
- for {
- k1 <- kindOf(n, env)
- k2 <- kindOf(m, env)
- k3 <- kindOf(k, env)
- k4 <- kindOf(elemType, env)
- k5 <- kindOf(fKind, env)
- k6 <- kindOf(mLayout, env)
- if k1 == Nat && k2 == Nat && k3 == Nat && k4 == Data &&
- k5 == Fragment && k6 == TypeAST.Kind.MatrixLayout
- } yield Data
- case _: TypeAST.ScalarType | TypeAST.NatType =>
- Some(Data)
- }
- }
-
- def kindOf(natAST: NatAST,
- env: Map[TypeAST.Identifier, TypeAST.Kind.Value]): Option[TypeAST.Kind.Value] = {
- natAST match {
- case NatAST.Identifier(id) =>
- env.get(id)
- case NatAST.Number(_) =>
- Some(Nat)
- case NatAST.BinaryOp(lhs, _, rhs) =>
- for {
- k1 <- kindOf(lhs, env)
- k2 <- kindOf(rhs, env)
- if k1 == Nat && k2 == Nat
- } yield Nat
-
- case NatAST.TernaryOp(_, thenN, elseN) =>
- for {
- k1 <- kindOf(thenN, env)
- k2 <- kindOf(elseN, env)
- if k1 == Nat && k2 == Nat
- } yield Nat
-
- case NatAST.Nat2NatApply(f, n) =>
- for {
- k1 <- kindOf(f, env)
- k2 <- kindOf(n, env)
- if k1 == Nat2Nat && k2 == Nat
- } yield Nat
-
- case NatAST.Sum(id, from, upTo, body) =>
- val nEnv = env.updated(id.id, Nat)
- for {
- k1 <- kindOf(from, nEnv)
- k2 <- kindOf(upTo, nEnv)
- k3 <- kindOf(body, nEnv)
- if k1 == Nat && k2 == Nat && k3 == Nat
- } yield Nat
- }
- }
-
- def kindOf(fragmentAST: TypeAST.FragmentAST,
- env: Map[TypeAST.Identifier, TypeAST.Kind.Value]): Option[TypeAST.Kind.Value] = {
- fragmentAST match {
- case FragmentAST.Identifier(id) =>
- env.get(id)
- case FragmentAST.ACC | FragmentAST.A | FragmentAST.B => Some(Fragment)
- }
- }
-
- def kindOf(matrixLayout: TypeAST.MatrixLayoutAST,
- env: Map[TypeAST.Identifier, TypeAST.Kind.Value]): Option[TypeAST.Kind.Value] = {
- matrixLayout match {
- case MatrixLayoutAST.Identifier(id) =>
- env.get(id)
- case MatrixLayoutAST.ROW_MAJOR |
- MatrixLayoutAST.COL_MAJOR => Some(TypeAST.Kind.MatrixLayout)
- }
- }
- }
-}
diff --git a/meta/src/main/scala/meta/generator/DPIAPrimitives.scala b/meta/src/main/scala/meta/generator/DPIAPrimitives.scala
new file mode 100644
index 000000000..7e6f14729
--- /dev/null
+++ b/meta/src/main/scala/meta/generator/DPIAPrimitives.scala
@@ -0,0 +1,288 @@
+package meta.generator
+
+import fastparse.{Parsed, parse}
+import meta.parser._
+
+object DPIAPrimitives {
+ def main(args: Array[String]): Unit = {
+ val sourceDir = args.head
+ val shinePath = os.Path(sourceDir) / "shine"
+ os.walk.stream(shinePath).filter(_.ext == "dpia").foreach(path => {
+
+ import DPIA.Decl.AST._
+
+ val definition = os.read(path)
+ parse(definition, DPIA.Decl.PrimitiveDeclarations(_)) match {
+ case failure: Parsed.Failure =>
+ println(s"Failed to parse `${failure.extra.input}'")
+ println(s" $failure")
+ case Parsed.Success(seq, _) =>
+ seq.foreach {
+ case PrimitiveDeclaration(Identifier(originalName), scalaParams, params, returnType)
+ if DPIA.isWellKindedDefinition(params, returnType) =>
+ val name = originalName.capitalize
+
+ val outputPath = (path / os.up) / s"$name.scala"
+ println(s"Generate $outputPath")
+
+ import scala.meta._
+ val packageName = path.relativeTo(shinePath).segments.dropRight(1).foldLeft[Term.Ref](Term.Name("shine")) {
+ case (t, name) => Term.Select(t, Term.Name(name))
+ }
+ val scalaParamsString = scalaParams match {
+ case Some((start, end)) => definition.substring(start, end)
+ case None => ""
+ }
+ val code = s"""// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+ |// This file is automatically generated and should not be changed manually //
+ |// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+ |${q"""
+package $packageName {
+
+import arithexpr.arithmetic._
+import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
+import shine.DPIA.Types._
+import shine.DPIA._
+
+${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)}
+
+}""".toString()}
+ |""".stripMargin
+
+ os.write.over(outputPath, code)
+ case PrimitiveDeclaration(Identifier(name), _, params, returnType) =>
+ println(s"Could not generate code for `$name' as parameters `$params' and/or `$returnType' are not well kinded.")
+ }
+ }
+ })
+ }
+
+ def generateCaseClass(name: scala.meta.Type.Name,
+ scalaParamsString: String,
+ params: Seq[DPIA.Decl.AST.Param],
+ returnType: DPIA.Type.AST): scala.meta.Defn.Class = {
+ import scala.meta._
+ import meta.parser.DPIA.Type.AST
+ val (scalaReturnType, superClass) = returnType match {
+ case AST.ExpType(_, _) => (t"ExpType", init"ExpPrimitive")
+ case AST.AccType(_) => (t"AccType", init"AccPrimitive")
+ case AST.CommType => (t"CommType", init"CommandPrimitive")
+ case _ => throw new Exception(s"Expected `exp', `acc' or `comm' as return type for ${name.value}")
+ }
+ val generatedParams = generateParams(scalaParamsString, params)
+ q"""
+ final case class $name(...$generatedParams) extends $superClass {
+ {
+ ..${generateTypeChecks(params).stats}
+ }
+
+ ..${if (scalaReturnType != t"CommType") {
+ List(q"override val t: $scalaReturnType = ${generateTerm(returnType)}")
+ } else List() }
+
+ ${generateVisitAndRebuild(name, generatedParams)}
+
+ ..${if (scalaParamsString.nonEmpty && generatedParams.last.size > 1) {
+ List(generateUnwrap(generatedParams.last))
+ } else List() }
+ }
+ """
+ }
+
+ def generateParams(scalaParamsString: String,
+ params: Seq[DPIA.Decl.AST.Param]): List[List[scala.meta.Term.Param]] = {
+ import scala.meta._
+
+ val scalaParams = if (scalaParamsString.nonEmpty) {
+ s"def foo($scalaParamsString)".parse[Stat].get match {
+ case declDef: Decl.Def => declDef.paramss
+ }
+ } else {
+ List()
+ }
+
+ scalaParams ++ List(params.map(generateParam).toList)
+ }
+
+ def generateParam(param: DPIA.Decl.AST.Param): scala.meta.Term.Param = {
+ import scala.meta._
+ import _root_.meta.parser.DPIA.Kind
+ param"val ${Term.Name(param.id.name)}: ${
+ param.ty match {
+ case Left(kindAST) => generateType(kindAST)
+ case Right(typeAST) => t"Phrase[${generatePhraseType(typeAST)}]"
+ }
+ }"
+ }
+
+ def generateTypeChecks(params: Seq[DPIA.Decl.AST.Param]): scala.meta.Term.Block = {
+ import scala.meta._
+ q"""{
+ ..${params.
+ filter(param => param.ty.isRight). // only check types for parameters with phrase types
+ map(param =>
+ q"${Term.Name(param.id.name)} :: ${
+ param.ty match {
+ case Right(typeAST@DPIA.Type.AST.DepFunType(id, kind, t)) =>
+ q"""{
+ ${Defn.Val(
+ mods = Nil,
+ pats = List(Pat.Var(name = Term.Name(id.name))),
+ decltpe = None,
+ rhs = q"${Term.Name(param.id.name)}.t.x"
+ )}
+ ${generateTerm(typeAST)}
+ }"""
+ case Right(typeAST) => generateTerm(typeAST)
+ case Left(kindAST) => throw new Exception("This should not happen")
+ }}"
+ ).toList}
+ }"""
+ }
+
+ def generatePhraseType(typeAST: DPIA.Type.AST): scala.meta.Type = {
+ import scala.meta._
+ import meta.parser.DPIA.Type.AST
+ typeAST match {
+ case AST.ExpType(_, _) => t"ExpType"
+ case AST.AccType(_) => t"AccType"
+ case AST.CommType => t"CommType"
+ case AST.PairType(lhs, rhs) => t"PhrasePairType[${generatePhraseType(lhs)}, ${generatePhraseType(rhs)}]"
+ case AST.FunType(inT, outT) => t"FunType[${generatePhraseType(inT)}, ${generatePhraseType(outT)}]"
+ case AST.DepFunType(id, kind, t) => t"DepFunType[${generateKindType(kind)}, ${generatePhraseType(t)}]"
+ case AST.Identifier(name) => Type.Name(name)
+ }
+ }
+
+ def generateType(kindAST: DPIA.Kind.AST): scala.meta.Type = {
+ import scala.meta._
+ import meta.parser.DPIA.Kind.AST
+ kindAST match {
+ case AST.RiseKind(riseKind) =>
+ import meta.parser.rise.Kind.AST
+ riseKind match {
+ case AST.Data => Type.Name("DataType")
+ case AST.Address => Type.Name("AddressSpace")
+ case AST.Nat2Nat => Type.Name("NatToNat")
+ case AST.Nat2Data => Type.Name("NatToData")
+ case AST.Nat => Type.Name("Nat")
+ case AST.Fragment => Type.Name("FragmentKind")
+ case AST.MatrixLayout => Type.Name("MatrixLayout")
+ }
+ case AST.Access => Type.Name("AccessType")
+ }
+ }
+
+ def generateKindType(kindAST: DPIA.Kind.AST): scala.meta.Type = {
+ import scala.meta._
+ import meta.parser.DPIA.Kind.AST
+ kindAST match {
+ case AST.RiseKind(riseKind) =>
+ import meta.parser.rise.Kind.AST
+ riseKind match {
+ case AST.Data => Type.Name("DataKind")
+ case AST.Address => Type.Name("AddressSpaceKind")
+ case AST.Nat2Nat => Type.Name("NatToNatKind")
+ case AST.Nat2Data => Type.Name("NatToDataKind")
+ case AST.Nat => Type.Name("NatKind")
+ case AST.Fragment => ???
+ case AST.MatrixLayout => ???
+ }
+ case AST.Access => Type.Name("AccessKind")
+ }
+ }
+
+ def generateTerm(typeAST: DPIA.Type.AST): scala.meta.Term = {
+ import scala.meta._
+ import meta.parser.DPIA.Type.AST
+ typeAST match {
+ case AST.ExpType(dataType, access) =>
+ q"expT(${RisePrimitives.generateDataType(dataType)}, ${generateTerm(access)})"
+ case AST.AccType(dataType) =>
+ q"accT(${RisePrimitives.generateDataType(dataType)})"
+ case AST.CommType =>
+ q"comm"
+ case AST.PairType(lhs, rhs) =>
+ q"PhrasePairType(${generateTerm(lhs)}, ${generateTerm(rhs)})"
+ case AST.FunType(inT, outT) =>
+ q"FunType(${generateTerm(inT)}, ${generateTerm(outT)})"
+ case AST.DepFunType(id, kind, t) =>
+ q"DepFunType[${generateKindType(kind)}, PhraseType](${Term.Name(id.name)}, ${generateTerm(t)})"
+ case AST.Identifier(name) => Term.Name(name)
+ }
+ }
+
+ def generateTerm(accessAST: DPIA.Type.Access.AST): scala.meta.Term = {
+ import scala.meta._
+ import meta.parser.DPIA.Type.Access.AST
+ accessAST match {
+ case AST.Identifier(name) => Term.Name(name)
+ case AST.Read => Term.Name("read")
+ case AST.Write =>Term.Name("write")
+ }
+ }
+
+ def generateVisitAndRebuild(name: scala.meta.Type.Name,
+ paramLists: List[List[scala.meta.Term.Param]]): scala.meta.Defn.Def = {
+ import scala.meta._
+
+ object TypeIs {
+ def unapply(ty: Type): Option[String] = ty match {
+ case Type.Name(name) => Some(name)
+ case Type.Select(_, Type.Name(name)) => Some(name)
+ case _ => None
+ }
+ }
+
+ def injectVisitCall(param: Term.Param): Term = {
+ param.decltpe match {
+ case Some(ty) => ty match {
+ case TypeIs("Nat") | TypeIs("NatIdentifier") =>
+ q"v.nat(${Term.Name(param.name.value)})"
+ case TypeIs("DataType") | TypeIs("ScalarType") | TypeIs("BasicType") =>
+ q"v.data(${Term.Name(param.name.value)})"
+ case TypeIs("NatToNat") =>
+ q"v.natToNat(${Term.Name(param.name.value)})"
+ case TypeIs("NatToData") =>
+ q"v.natToData(${Term.Name(param.name.value)})"
+ case TypeIs("AccessType") =>
+ q"v.access(${Term.Name(param.name.value)})"
+ case TypeIs("AddressSpace") =>
+ q"v.addressSpace(${Term.Name(param.name.value)})"
+ case TypeIs("LocalSize") | TypeIs("GlobalSize") =>
+ q"${Term.Name(param.name.value)}.visitAndRebuild(v)"
+ case t"Phrase[$_]" => q"VisitAndRebuild(${Term.Name(param.name.value)}, v)"
+ case t"Vector[Phrase[$_]]" => q"${Term.Name(param.name.value)}.map(VisitAndRebuild(_, v))"
+ case t"Seq[Phrase[$_]]" => q"${Term.Name(param.name.value)}.map(VisitAndRebuild(_, v))"
+
+ case t"Map[Identifier[_ <: PhraseType], $_]" =>
+ q"""${Term.Name(param.name.value)}.map{ case (key, value) =>
+ VisitAndRebuild(key, v).asInstanceOf[Identifier[_ <: PhraseType]] -> value
+ }"""
+
+ case Type.Apply(Type.Name("Vector"), List(TypeIs("DataType"))) // Vector[DataType]
+ | Type.Apply(Type.Name("Seq"), List(TypeIs("DataType"))) => // Seq[DataType]
+ q"${Term.Name(param.name.value)}.map(v.data)"
+ case _ =>
+ Term.Name(param.name.value)
+ }
+ case None => throw new Exception(s"Expected type declaration")
+ }
+ }
+
+ q"""override def visitAndRebuild(v: VisitAndRebuild.Visitor): $name =
+ new $name(...${paramLists.map(_.map(injectVisitCall))})
+ """
+ }
+
+ def generateUnwrap(paramList: List[scala.meta.Term.Param]): scala.meta.Defn.Def = {
+ import scala.meta._
+ val (types, names) = paramList.map({
+ case Term.Param(_, name, Some(typ), _) => (typ, Term.Name(name.value))
+ }).unzip
+ q"""
+ def unwrap: (..$types) = (..$names)
+ """
+ }
+}
diff --git a/meta/src/main/scala/meta/RisePrimitiveGenerator.scala b/meta/src/main/scala/meta/generator/RisePrimitives.scala
similarity index 50%
rename from meta/src/main/scala/meta/RisePrimitiveGenerator.scala
rename to meta/src/main/scala/meta/generator/RisePrimitives.scala
index b61efb55c..f48f7e0a0 100644
--- a/meta/src/main/scala/meta/RisePrimitiveGenerator.scala
+++ b/meta/src/main/scala/meta/generator/RisePrimitives.scala
@@ -1,28 +1,28 @@
-package meta
+package meta.generator
import fastparse.{Parsed, parse}
-import meta.NatParser.NatAST
-import meta.TypeParser.TypeAST
-import meta.TypeParser.TypeAST.{FragmentAST, MatrixLayoutAST}
+import meta.parser._
+import meta.parser.rise.Kind
-object RisePrimitiveGenerator {
+object RisePrimitives {
def main(args: Array[String]): Unit = {
val sourceDir = args.head
- val rise = os.Path(sourceDir) / "rise"
- os.walk.stream(rise).filter(_.ext == "rise").foreach(path => {
+ val risePath = os.Path(sourceDir) / "rise"
+ os.walk.stream(risePath).filter(_.ext == "rise").foreach(path => {
val definition = os.read(path)
- parse(definition, TypeParser.PrimitiveDeclarations(_)) match {
+ parse(definition, rise.Decl.PrimitiveDeclarations(_)) match {
case failure: Parsed.Failure =>
println(s"Failed to parse `${failure.extra.input}'")
println(s" $failure")
case Parsed.Success(seq, _) =>
seq.foreach {
- case (name, args, typeSignature) if TypeParser.isWellKindedType(typeSignature) =>
+ case rise.Decl.AST.PrimitiveDeclaration(rise.Decl.AST.Identifier(name), scalaParams, typeSignature)
+ if rise.isWellKindedType(typeSignature) =>
val outputPath = (path / os.up) / s"$name.scala"
println(s"Generate $outputPath")
- val generatedDef = args match {
+ val generatedDef = scalaParams match {
case None =>
generateObject(name, typeSignature)
case Some((start, end)) =>
@@ -30,7 +30,7 @@ object RisePrimitiveGenerator {
}
import scala.meta._
- val packageName = path.relativeTo(rise).segments.dropRight(1).foldLeft[Term.Ref](Term.Name("rise")) {
+ val packageName = path.relativeTo(risePath).segments.dropRight(1).foldLeft[Term.Ref](Term.Name("rise")) {
case (t, name) => Term.Select(t, Term.Name(name))
}
val code = s"""// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
@@ -51,14 +51,14 @@ import arithexpr.arithmetic._
|""".stripMargin
os.write.over(outputPath, code)
- case (name, _, typeSignature) =>
+ case rise.Decl.AST.PrimitiveDeclaration(name, _, typeSignature) =>
println(s"Could not generate code for `$name' as type signature `$typeSignature' is not well kinded.")
}
}
})
}
- def generateObject(name: String, typeSignature: TypeAST): scala.meta.Term.Block = {
+ def generateObject(name: String, typeSignature: rise.Type.AST): scala.meta.Term.Block = {
import scala.meta._
val generated = q"""{
object ${Term.Name{name}} extends Builder {
@@ -68,7 +68,7 @@ import arithexpr.arithmetic._
override val name: String = ${Lit.String(name)}
override def setType(ty: Type): Primitive = Primitive()(ty)
override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass
- override def typeScheme: Type = ${generateTypeScheme(typeSignature, Map.empty)}
+ override def typeScheme: Type = ${generateTypeScheme(typeSignature)}
}
override def toString: String = ${Lit.String(name)}
@@ -83,13 +83,12 @@ import arithexpr.arithmetic._
generated
}
- def generateCaseClass(name: String, paramsString: String, typeSignature: TypeAST): scala.meta.Term.Block = {
+ def generateCaseClass(name: String, paramsString: String, typeSignature: rise.Type.AST): scala.meta.Term.Block = {
import scala.meta._
- val params = paramsString.split(",").map(param => {
- val parts = param.split(":").map(_.trim)
- param"${Term.Name(parts(0))}: ${Type.Name(parts(1))}"
- } ).toList
+ val params = s"def foo($paramsString)".parse[Stat].get match {
+ case declDef: Decl.Def => declDef.paramss.head
+ }
val args: List[Term.Name] = params.map(p => Term.Name(p.name.value))
val types: List[Type] = params.map(p => p.decltpe.get)
@@ -112,7 +111,7 @@ import arithexpr.arithmetic._
{
override val name: String = ${Lit.String(name)}
override def setType(ty: Type): Primitive = Primitive(..$args)(ty)
- override def typeScheme: Type = ${generateTypeScheme(typeSignature, Map.empty)}
+ override def typeScheme: Type = ${generateTypeScheme(typeSignature)}
override def primEq(obj: rise.core.Primitive): Boolean = obj match {
case p: Primitive => ${generateComparisonChain(args)}
@@ -130,115 +129,129 @@ import arithexpr.arithmetic._
generated
}
- def generateTypeScheme(typeAST: TypeAST, env: Map[TypeAST.Identifier, String]): scala.meta.Term = {
+ def generateTypeScheme(typeAST: rise.Type.AST): scala.meta.Term = {
import scala.meta._
typeAST match {
- case id@TypeAST.Identifier(name) =>
- assert(env.contains(id), s"$id is not in $env")
+ case rise.Type.AST.FunType(inT, outT) =>
+ q"(${generateTypeScheme(inT)}) ->: (${generateTypeScheme(outT)})"
+ case rise.Type.AST.DepFunType(id, kind, t) =>
+ q"expl((${Term.Name(id.name)}: ${Type.Name(kindName(kind))}) => ${generateTypeScheme(t)})"
+ case rise.Type.AST.ImplicitDepFunType(id, kind, t) =>
+ q"impl((${Term.Name(id.name)}: ${Type.Name(kindName(kind))}) => ${generateTypeScheme(t)})"
+ case _ => generateDataType(typeAST)
+ }
+ }
+
+ def generateDataType(typeAST: rise.Type.AST): scala.meta.Term = {
+ import scala.meta._
+ typeAST match {
+ case rise.Type.AST.Identifier(name) =>
Term.Name(name)
- case TypeAST.FunType(inT, outT) =>
- q"(${generateTypeScheme(inT, env)}) ->: (${generateTypeScheme(outT, env)})"
- case TypeAST.DepFunType(id, kind, t) =>
- q"expl((${Term.Name(id.name)}: ${Type.Name(kindName(kind))}) => ${generateTypeScheme(t, env.updated(id, kind))})"
- case TypeAST.ImplicitDepFunType(id, kind, t) =>
- q"impl((${Term.Name(id.name)}: ${Type.Name(kindName(kind))}) => ${generateTypeScheme(t, env.updated(id, kind))})"
- case TypeAST.ScalarType(t) =>
+ case rise.Type.AST.ScalarType(t) =>
t.parse[Term].get
- case TypeAST.NatType =>
- q"rise.core.types.NatType"
- case TypeAST.VectorType(size, elemType) =>
- q"rise.core.types.VectorType(${generateNat(size, env)}, ${generateTypeScheme(elemType, env)})"
- case TypeAST.IndexType(size) =>
- q"rise.core.types.IndexType(${generateNat(size, env)})"
- case TypeAST.PairType(lhs, rhs) =>
- q"rise.core.types.PairType(${generateTypeScheme(lhs, env)}, ${generateTypeScheme(rhs, env)})"
- case TypeAST.DepPairType(id, kind, t) => kind match {
- case "nat" =>
- q"Nat `**` ((${Term.Name(id.name)}: Nat) => ${generateTypeScheme(t, env.updated(id, kind))})"
+ case rise.Type.AST.NatType =>
+ q"NatType"
+ case rise.Type.AST.OpaqueType(name) =>
+ q"OpaqueType($name)"
+ case rise.Type.AST.VectorType(size, elemType) =>
+ q"VectorType(${generateNat(size)}, ${generateDataType(elemType)})"
+ case rise.Type.AST.IndexType(size) =>
+ q"IndexType(${generateNat(size)})"
+ case rise.Type.AST.PairType(lhs, rhs) =>
+ q"PairType(${generateDataType(lhs)}, ${generateDataType(rhs)})"
+ case rise.Type.AST.DepPairType(id, kind, t) => kind match {
+ case Kind.AST.Nat =>
+ q"Nat `**` ((${Term.Name(id.name)}: Nat) => ${generateDataType(t)})"
case _ => ???
}
- case TypeAST.NatToDataApply(f, n) =>
- q"rise.core.types.NatToDataApply(${generateTypeScheme(f, env)}, ${generateNat(n, env)})"
- case TypeAST.NatToDataLambda(id, t) =>
- q"n2dtFun((${Term.Name(id.name)}: NatIdentifier) => ${generateTypeScheme(t, env.updated(id, "nat"))})"
- case TypeAST.ArrayType(size, elemType) =>
- q"rise.core.types.ArrayType(${generateNat(size, env)}, ${generateTypeScheme(elemType, env)})"
- case TypeAST.DepArrayType(size, fdt) =>
- q"rise.core.types.DepArrayType(${generateNat(size, env)}, ${generateTypeScheme(fdt, env)})"
- case TypeAST.FragmentType(n, m, k, elemType, fKind, mLayout) =>
- q"rise.core.types.FragmentType(${generateNat(n, env)}, ${generateNat(m, env)}, ${generateNat(k, env)}, ${generateTypeScheme(elemType, env)}, ${generateFragment(fKind, env)}, ${generateMatrixLayout(mLayout, env)})"
+ case rise.Type.AST.NatToDataApply(f, n) =>
+ q"NatToDataApply(${generateDataType(f)}, ${generateNat(n)})"
+ case rise.Type.AST.NatToDataLambda(id, t) =>
+ q"n2dtFun((${Term.Name(id.name)}: NatIdentifier) => ${generateDataType(t)})"
+ case rise.Type.AST.ArrayType(size, elemType) =>
+ q"ArrayType(${generateNat(size)}, ${generateDataType(elemType)})"
+ case rise.Type.AST.DepArrayType(size, fdt) =>
+ q"DepArrayType(${generateNat(size)}, ${generateDataType(fdt)})"
+ case rise.Type.AST.FragmentType(n, m, k, elemType, fKind, mLayout) =>
+ q"FragmentType(${generateNat(n)}, ${generateNat(m)}, ${generateNat(k)}, ${generateDataType(elemType)}, ${generateFragment(fKind)}, ${generateMatrixLayout(mLayout)})"
+ case rise.Type.AST.ManagedBufferType(dt) =>
+ q"ManagedBufferType(${generateDataType(dt)})"
+ case rise.Type.AST.FunType(_, _) | rise.Type.AST.DepFunType(_, _, _) |
+ rise.Type.AST.ImplicitDepFunType(_, _, _) => ???
}
}
- def kindName(kind: String): String = kind match {
- case "nat" => "Nat"
- case "data" => "DataType"
- case "nat2nat" => "NatToNat"
- case "nat2data" => "NatToData"
- case "address" => "AddressSpace"
- case "fragment" => "FragmentKind"
- case "matrixLayout" => "MatrixLayout"
+ def kindName(kind: Kind.AST): String = {
+ import meta.parser.rise.Kind.AST
+ kind match {
+ case AST.Data => "DataType"
+ case AST.Address => "AddressSpace"
+ case AST.Nat2Nat => "NatToNat"
+ case AST.Nat2Data => "NatToData"
+ case AST.Nat => "Nat"
+ case AST.Fragment => "FragmentKind"
+ case AST.MatrixLayout => "MatrixLayout"
+ }
}
- def generateNat(n: NatAST, env: Map[TypeAST.Identifier, String]): scala.meta.Term = {
+ def generateNat(n: Nat.AST): scala.meta.Term = {
import scala.meta._
n match {
- case NatAST.Identifier(id) =>
- assert(env.contains(id), s"$id is not in $env")
- Term.Name(id.name)
- case NatAST.Number(n) =>
+ case Nat.AST.Identifier(id) =>
+ Term.Name(id)
+ case Nat.AST.Number(n) =>
n.parse[Term].get
- case NatAST.BinaryOp(lhs, "^", rhs) =>
- q"${generateNat(lhs, env)}.pow(${generateNat(rhs, env)})"
- case NatAST.BinaryOp(lhs, op, rhs) =>
- q"${generateNat(lhs, env)} ${Term.Name(op)} ${generateNat(rhs, env)}"
- case NatAST.TernaryOp(cond, thenN, elseN) =>
+ case Nat.AST.BinaryOp(lhs, "^", rhs) =>
+ q"${generateNat(lhs)}.pow(${generateNat(rhs)})"
+ case Nat.AST.BinaryOp(lhs, op, rhs) =>
+ q"${generateNat(lhs)} ${Term.Name(op)} ${generateNat(rhs)}"
+ case Nat.AST.TernaryOp(cond, thenN, elseN) =>
val operator: Term = cond.op match {
case "<" => q"Operator.<"
case ">" => q"Operator.>"
}
q"""
IfThenElse(
- arithPredicate(${generateNat(cond.lhs, env)}, ${generateNat(cond.rhs, env)}, $operator),
- ${generateNat(thenN, env)},
- ${generateNat(elseN, env)})
+ arithPredicate(${generateNat(cond.lhs)}, ${generateNat(cond.rhs)}, $operator),
+ ${generateNat(thenN)},
+ ${generateNat(elseN)})
"""
- case NatAST.Nat2NatApply(f, n) =>
- q"${generateTypeScheme(f, env)}(${generateNat(n, env)})"
- case NatAST.Sum(id, from, upTo, body) =>
+ case Nat.AST.Nat2NatApply(f, n) =>
+ q"${generateDataType(f)}(${generateNat(n)})"
+ case Nat.AST.Sum(id, from, upTo, body) =>
q"""BigSum(
- from = ${generateNat(from, env)},
- upTo = ${generateNat(upTo, env)},
- (${Term.Name(id.id.name)}: Nat) => ${generateNat(body, env.updated(id.id, "nat"))})
+ from = ${generateNat(from)},
+ upTo = ${generateNat(upTo)},
+ (${Term.Name(id.name)}: Nat) => ${generateNat(body)})
"""
}
}
- def generateFragment(fragmentAST: FragmentAST, env: Map[TypeAST.Identifier, String]): scala.meta.Term = {
+ def generateFragment(fragmentAST: rise.Type.Fragment.AST): scala.meta.Term = {
import scala.meta._
fragmentAST match {
- case FragmentAST.Identifier(id) =>
- assert(env.contains(id), s"$id is not in $env")
- Term.Name(id.name)
- case FragmentAST.ACC =>
- q"rise.core.types.FragmentKind.Acuumulator"
- case FragmentAST.A =>
- q"rise.core.types.FragmentKind.AMatrix"
- case FragmentAST.B =>
- q"rise.core.types.FragmentKind.BMatrix"
+ case rise.Type.Fragment.AST.Identifier(name) =>
+ Term.Name(name)
+ case rise.Type.Fragment.AST.ACC =>
+ q"FragmentKind.Accumulator"
+ case rise.Type.Fragment.AST.A =>
+ q"FragmentKind.AMatrix"
+ case rise.Type.Fragment.AST.B =>
+ q"FragmentKind.BMatrix"
}
}
- def generateMatrixLayout(matrixLayoutAST: MatrixLayoutAST, env: Map[TypeAST.Identifier, String]): scala.meta.Term = {
+ def generateMatrixLayout(matrixLayoutAST: rise.Type.MatrixLayout.AST): scala.meta.Term = {
import scala.meta._
matrixLayoutAST match {
- case MatrixLayoutAST.Identifier(id) =>
- assert(env.contains(id), s"$id is not in $env")
- Term.Name(id.name)
- case MatrixLayoutAST.ROW_MAJOR =>
- q"rise.core.types.MatrixLayout.Row_Major"
- case MatrixLayoutAST.COL_MAJOR =>
- q"rise.core.types.MatrixLayout.Col_Major"
+ case rise.Type.MatrixLayout.AST.Identifier(name) =>
+ Term.Name(name)
+ case rise.Type.MatrixLayout.AST.ROW_MAJOR =>
+ q"MatrixLayout.Row_Major"
+ case rise.Type.MatrixLayout.AST.COL_MAJOR =>
+ q"MatrixLayout.Col_Major"
+ case rise.Type.MatrixLayout.AST.NONE =>
+ q"MatrixLayout.None"
}
}
diff --git a/meta/src/main/scala/meta/parser/DPIA/Decl.scala b/meta/src/main/scala/meta/parser/DPIA/Decl.scala
new file mode 100644
index 000000000..505a508c9
--- /dev/null
+++ b/meta/src/main/scala/meta/parser/DPIA/Decl.scala
@@ -0,0 +1,41 @@
+package meta.parser.DPIA
+
+import fastparse.ScalaWhitespace._
+import fastparse._
+import meta.parser.shared.Identifier
+
+object Decl {
+ sealed trait AST
+ object AST {
+ case class Identifier(name: String) extends AST
+ case class Param(id: Identifier, ty: Either[Kind.AST, Type.AST]) extends AST
+ case class PrimitiveDeclaration(id: Identifier,
+ scalaParams: Option[(Int, Int)],
+ params: Seq[Param],
+ returnType: Type.AST) extends AST
+ }
+
+ def PrimitiveDeclarations[_: P]: P[Seq[AST.PrimitiveDeclaration]] =
+ P(Start ~ PrimitiveDeclaration.rep(1) ~ End)
+
+ // def drop(n: nat, m: nat, t: data, input: exp[n+m.t, read]): exp[m.t, read]
+ // def mapGlobal[dim: Int](n: nat, s: data, t: data, f: exp[s, read] -> exp[t, read], array: exp[n.s, read]): exp[n.t, read]
+ def PrimitiveDeclaration[_: P]: P[AST.PrimitiveDeclaration] = {
+ import scalaparse.Scala.TrailingCommaOps
+ def ScalaParams: P[(Int, Int)] = {
+ P("{" ~ Index ~
+ (scalaparse.Scala.Id ~ scalaparse.syntax.Key.O(":") ~ scalaparse.Scala.Type).repTC(1) ~
+ Index ~ "}")
+ }
+
+ def Param: P[AST.Param] = (
+ (Identifier.map(AST.Identifier) ~ ":" ~ Kind.Kind).map(pair => AST.Param(pair._1, Left(pair._2)))
+ | (Identifier.map(AST.Identifier) ~ ":" ~ Type.PhraseType).map(pair => AST.Param(pair._1, Right(pair._2)))
+ )
+
+ def Params: P[Seq[AST.Param]] = Param.repTC(0)
+
+ P("def" ~ Identifier.map(AST.Identifier) ~ ScalaParams.? ~ "(" ~ Params ~ ")" ~ ":" ~ Type.PhraseType)
+ .map(AST.PrimitiveDeclaration.tupled)
+ }
+}
diff --git a/meta/src/main/scala/meta/parser/DPIA/Kind.scala b/meta/src/main/scala/meta/parser/DPIA/Kind.scala
new file mode 100644
index 000000000..6e3c0544e
--- /dev/null
+++ b/meta/src/main/scala/meta/parser/DPIA/Kind.scala
@@ -0,0 +1,17 @@
+package meta.parser.DPIA
+
+import fastparse._
+import meta.parser.rise
+
+object Kind {
+ sealed trait AST
+ object AST {
+ case class RiseKind(riseKind: rise.Kind.AST) extends AST
+ case object Access extends AST
+ }
+
+ def Kind[_: P]: P[AST] = P(
+ rise.Kind.Kind.map(AST.RiseKind) |
+ "access".!.map(_ => AST.Access)
+ )
+}
diff --git a/meta/src/main/scala/meta/parser/DPIA/Type.scala b/meta/src/main/scala/meta/parser/DPIA/Type.scala
new file mode 100644
index 000000000..c998ed5d4
--- /dev/null
+++ b/meta/src/main/scala/meta/parser/DPIA/Type.scala
@@ -0,0 +1,65 @@
+package meta.parser.DPIA
+
+import fastparse.ScalaWhitespace._
+import fastparse._
+import meta.parser._
+import shared._
+
+object Type {
+ sealed trait AST
+ object AST {
+ case class ExpType(dataType: rise.Type.AST, access: Access.AST) extends AST
+ case class AccType(dataType: rise.Type.AST) extends AST
+ case object CommType extends AST
+
+ case class PairType(lhs: AST, rhs: AST) extends AST
+ case class FunType(inT: AST, outT: AST) extends AST
+ case class DepFunType(id: Identifier, kind: Kind.AST, t: AST) extends AST
+ case class Identifier(name: String) extends AST
+ }
+
+ object Access {
+ sealed trait AST
+ object AST {
+ case class Identifier(name: String) extends AST
+ case object Read extends AST
+ case object Write extends AST
+ }
+ }
+
+ def PhraseType[_: P]: P[AST] = {
+ def DataType: P[rise.Type.AST] = rise.Type.DataType.DataType
+
+ def AccessType: P[Access.AST] = P(
+ "read".!.map(_ => Access.AST.Read) |
+ "write".!.map(_ => Access.AST.Write) |
+ Identifier.map(Access.AST.Identifier)
+ )
+
+ def ExpType: P[AST.ExpType] = P("exp[" ~ DataType ~ "," ~ AccessType ~ "]").map(AST.ExpType.tupled)
+
+ def AccType: P[AST.AccType] = P("acc[" ~ DataType ~ "]").map(AST.AccType)
+
+ def VarType: P[AST.PairType] = P("var[" ~ DataType ~ "]").
+ map(dt => AST.PairType(AST.ExpType(dt, Access.AST.Read), AST.AccType(dt)))
+
+ def CommType: P[AST.CommType.type] = P("comm".!.map(_ => AST.CommType))
+
+ def PairType: P[AST.PairType] =
+ P("(" ~ NoCut(PhraseType) ~ "," ~/ PhraseType ~ ")").map(AST.PairType.tupled)
+
+ def FunType: P[AST.FunType] =
+ P(NoCut(NonFunPhraseType | ("(" ~ PhraseType ~ ")")) ~ "->" ~/ PhraseType).map(AST.FunType.tupled)
+
+ def DepFunType: P[AST.DepFunType] = {
+ def IdentifierKindPair: P[(AST.Identifier, Kind.AST)] =
+ P(Identifier.map(AST.Identifier) ~ ":" ~ Kind.Kind)
+
+ P("(" ~ IdentifierKindPair ~ ")" ~ "->" ~/ PhraseType).map(AST.DepFunType.tupled)
+ }
+
+ def NonFunPhraseType: P[AST] = P( ExpType | AccType | VarType | CommType | PairType | DepFunType )
+
+ P( FunType | NonFunPhraseType )
+ }
+}
diff --git a/meta/src/main/scala/meta/parser/DPIA/isWellKindedDefinition.scala b/meta/src/main/scala/meta/parser/DPIA/isWellKindedDefinition.scala
new file mode 100644
index 000000000..297978ea2
--- /dev/null
+++ b/meta/src/main/scala/meta/parser/DPIA/isWellKindedDefinition.scala
@@ -0,0 +1,63 @@
+package meta.parser.DPIA
+
+import meta.parser._
+
+object isWellKindedDefinition {
+
+ def apply(params: Seq[Decl.AST.Param], returnType: Type.AST): Boolean = {
+ import Decl.AST._
+ var isWellKindedFlag = true
+ val env = params.foldLeft(Map.empty[String, Kind.AST]) {
+ case (env, Param(Identifier(name), Left(kind))) =>
+ env.updated(name, kind)
+ case (env, Param(Identifier(_), Right(typeAST))) =>
+ if (!isWellKinded(typeAST, env)) isWellKindedFlag = false
+ env
+ }
+ isWellKindedFlag && isWellKinded(returnType, env)
+ }
+
+ def isWellKinded(typeAST: Type.AST, env: Map[String, Kind.AST]): Boolean = {
+ import Type._
+ import rise.isWellKindedType._
+ typeAST match {
+ case AST.ExpType(dataType, access) =>
+ val nenv = env.flatMap {
+ case (string, DPIA.Kind.AST.RiseKind(riseKind)) =>
+ Some((string, riseKind))
+ case _ => None
+ }
+ kindOf(dataType, nenv).isDefined && isWellKinded(access, env)
+ case AST.AccType(dataType) =>
+ val nenv = env.flatMap {
+ case (string, DPIA.Kind.AST.RiseKind(riseKind)) =>
+ Some((string, riseKind))
+ case _ => None
+ }
+ kindOf(dataType, nenv).isDefined
+ case AST.CommType => true
+ case AST.PairType(lhs, rhs) =>
+ isWellKinded(lhs, env) && isWellKinded(rhs, env)
+ case AST.FunType(inT, outT) =>
+ isWellKinded(inT, env) && isWellKinded(outT, env)
+ case AST.DepFunType(id, kind, t) =>
+ if (env.isDefinedAt(id.name)) {
+ // we forbid shadowing
+ false
+ } else {
+ isWellKinded(t, env.updated(id.name, kind))
+ }
+ case AST.Identifier(name) =>
+ env.contains(name)
+ }
+ }
+
+ def isWellKinded(accessAST: Type.Access.AST, env: Map[String, Kind.AST]): Boolean = {
+ import Type.Access._
+ accessAST match {
+ case AST.Identifier(name) => env.isDefinedAt(name)
+ case AST.Read => true
+ case AST.Write => true
+ }
+ }
+}
diff --git a/meta/src/main/scala/meta/parser/Nat.scala b/meta/src/main/scala/meta/parser/Nat.scala
new file mode 100644
index 000000000..051241818
--- /dev/null
+++ b/meta/src/main/scala/meta/parser/Nat.scala
@@ -0,0 +1,62 @@
+package meta.parser
+
+import fastparse.ScalaWhitespace._
+import fastparse._
+
+object Nat {
+
+ sealed trait AST
+ object AST {
+ case class Identifier(name: String) extends AST
+ case class Number(n: String) extends AST
+ case class BinaryOp(lhs: AST, op: String, rhs: AST) extends AST
+ case class TernaryOp(cond: BinaryOp, thenN: AST, elseN: AST) extends AST
+ case class Nat2NatApply(f: rise.Type.AST.Identifier, n: AST) extends AST
+ case class Sum(id: Identifier, from: AST, upTo: AST, body: AST) extends AST
+ }
+
+ def Nat[_: P]: P[AST] = {
+ def CompOrNat: P[AST] = {
+ def CompOp: P[String] = P("<".! | ">".!)
+ P(AddSubOrNat ~ (CompOp ~/ AddSubOrNat).rep).map(asBinaryOpOrNat)
+ }
+
+ def AddSubOrNat: P[AST] = {
+ def AddSubOps: P[String] = P("+".! | "-".!)
+ P(DivMulPowModOrNat ~ (AddSubOps ~ DivMulPowModOrNat).rep).map(asBinaryOpOrNat)
+ }
+
+ def DivMulPowModOrNat: P[AST] = {
+ def DivMulPowModOp: P[String] = P("*".! | "/".! | "^".! | "%".!)
+ P(SingleNat ~ (DivMulPowModOp ~ SingleNat).rep).map(asBinaryOpOrNat)
+ }
+
+ def SingleNat: P[AST] = {
+ def Number: P[AST.Number] = P(CharIn("0-9").rep(1).!).map(AST.Number)
+
+ def Sum: P[AST.Sum] = {
+ def Assignment: P[(AST.Identifier, AST)] =
+ P(NatIdentifier ~ "=" ~ Nat | "(" ~ Assignment ~ ")")
+ P("sum" ~ "_" ~ Assignment ~ "^" ~ Nat ~ Nat).map(AST.Sum.tupled)
+ }
+
+ def Nat2NatApply: P[AST.Nat2NatApply] =
+ P(rise.Type.TypeIdentifier ~ "(" ~ Nat ~ ")").map(AST.Nat2NatApply.tupled)
+
+ def NatIdentifier: P[AST.Identifier] = P(rise.Type.TypeIdentifier).map(i => AST.Identifier(i.name))
+
+ def Parens: P[AST] = P("(" ~ Nat ~ ")")
+
+ P(Number | Sum | Nat2NatApply | NatIdentifier | Parens)
+ }
+
+ P(CompOrNat)
+ }
+
+ private def asBinaryOpOrNat: ((AST, Seq[(String, AST)])) => AST = {
+ case (n, ns) => ns.foldLeft(n){
+ case (lhs, (op, rhs)) => AST.BinaryOp(lhs, op, rhs)
+ }
+ }
+
+}
diff --git a/meta/src/main/scala/meta/parser/rise/Decl.scala b/meta/src/main/scala/meta/parser/rise/Decl.scala
new file mode 100644
index 000000000..d4a191565
--- /dev/null
+++ b/meta/src/main/scala/meta/parser/rise/Decl.scala
@@ -0,0 +1,30 @@
+package meta.parser.rise
+
+import fastparse.ScalaWhitespace._
+import fastparse._
+import meta.parser.shared.Identifier
+
+object Decl {
+ sealed trait AST
+ object AST {
+ case class Identifier(name: String) extends AST
+ case class PrimitiveDeclaration(id: Identifier,
+ scalaParams: Option[(Int, Int)],
+ typeSignature: Type.AST) extends AST
+ }
+
+ def PrimitiveDeclarations[_: P]: P[Seq[AST.PrimitiveDeclaration]] =
+ P(Start ~ PrimitiveDeclaration.rep(1) ~ End)
+
+ def PrimitiveDeclaration[_: P]: P[AST.PrimitiveDeclaration] = {
+ def ScalaParams: P[(Int, Int)] = {
+ import scalaparse.Scala.TrailingCommaOps
+ P("(" ~ Index ~
+ (scalaparse.Scala.Id ~ scalaparse.syntax.Key.O(":") ~ scalaparse.Scala.Type).repTC(1) ~
+ Index ~ ")")
+ }
+
+ P("def" ~ Identifier.map(AST.Identifier) ~ ScalaParams.? ~ ":" ~ Type.TypeSignature)
+ .map(AST.PrimitiveDeclaration.tupled)
+ }
+}
diff --git a/meta/src/main/scala/meta/parser/rise/Kind.scala b/meta/src/main/scala/meta/parser/rise/Kind.scala
new file mode 100644
index 000000000..6333bc698
--- /dev/null
+++ b/meta/src/main/scala/meta/parser/rise/Kind.scala
@@ -0,0 +1,26 @@
+package meta.parser.rise
+
+import fastparse._
+
+object Kind {
+ sealed trait AST
+ object AST {
+ case object Data extends AST
+ case object Address extends AST
+ case object Nat2Nat extends AST
+ case object Nat2Data extends AST
+ case object Nat extends AST
+ case object Fragment extends AST
+ case object MatrixLayout extends AST
+ }
+
+ def Kind[_: P]: P[AST] = P(
+ "data".!.map(_ => AST.Data) |
+ "address".!.map(_ => AST.Address) |
+ "nat2nat".!.map(_ => AST.Nat2Nat) |
+ "nat2data".!.map(_ => AST.Nat2Data) |
+ "nat".!.map(_ => AST.Nat) |
+ "fragment".!.map(_ => AST.Fragment) |
+ "matrixLayout".!.map(_ => AST.MatrixLayout)
+ )
+}
diff --git a/meta/src/main/scala/meta/parser/rise/Type.scala b/meta/src/main/scala/meta/parser/rise/Type.scala
new file mode 100644
index 000000000..6cd066db8
--- /dev/null
+++ b/meta/src/main/scala/meta/parser/rise/Type.scala
@@ -0,0 +1,143 @@
+package meta.parser.rise
+
+import fastparse.ScalaWhitespace._
+import fastparse._
+import meta.parser.shared.Identifier
+import meta.parser._
+
+object Type {
+ sealed trait AST
+ object AST {
+ case class Identifier(name: String) extends AST
+ case class FunType(inT: AST, outT: AST) extends AST
+ case class DepFunType(id: Identifier, kind: Kind.AST, t: AST) extends AST
+ case class ImplicitDepFunType(id: Identifier, kind: Kind.AST, t: AST) extends AST
+
+ case class ScalarType(t: String) extends AST
+ case object NatType extends AST
+ case class OpaqueType(name: String) extends AST
+ case class VectorType(size: Nat.AST, elemType: AST) extends AST
+ case class IndexType(size: Nat.AST) extends AST
+ case class PairType(lhs: AST, rhs: AST) extends AST
+ case class DepPairType(id: Identifier, kind: Kind.AST, t: AST) extends AST
+ case class NatToDataApply(f: AST, n: Nat.AST) extends AST
+ case class NatToDataLambda(id: Identifier, t: AST) extends AST
+ case class ArrayType(size: Nat.AST, elemType: AST) extends AST
+ case class DepArrayType(size: Nat.AST, fdt: AST) extends AST
+ case class FragmentType(n: Nat.AST, m: Nat.AST, k: Nat.AST, elemType: AST,
+ fKind: Fragment.AST, mLayoutKind: MatrixLayout.AST) extends AST
+ case class ManagedBufferType(t: AST) extends AST
+ }
+
+ object Fragment {
+ sealed trait AST
+ object AST {
+ case class Identifier(name: String) extends AST
+ object ACC extends AST
+ object A extends AST
+ object B extends AST
+ }
+ }
+
+ object MatrixLayout {
+ sealed trait AST
+ object AST {
+ case class Identifier(name: String) extends AST
+ object ROW_MAJOR extends AST
+ object COL_MAJOR extends AST
+ object NONE extends AST
+ }
+ }
+
+ def TypeSignature[_: P]: P[AST] = {
+ def DepFunType: P[AST.DepFunType] =
+ P("(" ~ IdentifierKindPair ~ ")" ~ "->" ~/ TypeSignature).map(AST.DepFunType.tupled)
+
+ def ImplicitDepFunType: P[AST.ImplicitDepFunType] =
+ P("{" ~ IdentifierKindPair ~ "}" ~ "->" ~/ TypeSignature).
+ map(AST.ImplicitDepFunType.tupled)
+
+ def FunType: P[AST.FunType] =
+ P(NoCut(LeftTypeSignature) ~ "->" ~/ TypeSignature).map(AST.FunType.tupled)
+
+ // Types that can appear at the left of an function arrow
+ def LeftTypeSignature: P[AST] = P(DataType.DataType | ("(" ~ TypeSignature ~ ")"))
+
+ P(DepFunType | ImplicitDepFunType | FunType | LeftTypeSignature)
+ }
+
+ def TypeIdentifier[_: P]: P[AST.Identifier] = P(Identifier).map(AST.Identifier)
+
+ def IdentifierKindPair[_: P]: P[(AST.Identifier, Kind.AST)] =
+ P(Identifier.map(AST.Identifier) ~ ":" ~ Kind.Kind)
+
+ object DataType {
+ def ScalarType[_: P]: P[AST.ScalarType] =
+ P("bool".! | "int".! |
+ "i8".! | "i16".! | "i32".! | "i64".! |
+ "u8".! | "u16".! | "u32".! | "u64".! |
+ "f16".! | "f32".! | "f64".!).map(AST.ScalarType)
+
+ def NatType[_: P]: P[AST.NatType.type] = P("natType").map(_ => AST.NatType)
+
+ def OpaqueType[_: P]: P[AST.OpaqueType] = P( "\"" ~~ Identifier ~~ "\"" ).map(AST.OpaqueType)
+
+ def IndexType[_: P]: P[AST.IndexType] = P("idx[" ~ Nat.Nat ~ "]").map(AST.IndexType)
+
+ def VectorType[_: P]: P[AST.VectorType] =
+ P("vec[" ~ DataType ~ "," ~ Nat.Nat ~ "]").map(t => AST.VectorType(t._2, t._1))
+
+ def FragmentType[_: P]: P[AST.FragmentType] = {
+ def FragmentKind: P[Fragment.AST] =
+ P(("fragment." ~~ (
+ "ACC".!.map(_ => Fragment.AST.ACC) |
+ "A".!.map(_ => Fragment.AST.A) |
+ "B".!.map(_ => Fragment.AST.B))
+ ) | TypeIdentifier.map(i => Fragment.AST.Identifier(i.name)))
+
+ def MatrixLayoutKind: P[MatrixLayout.AST] =
+ P(("matrixLayout." ~~ (
+ "ROW_MAJOR".!.map(_ => MatrixLayout.AST.ROW_MAJOR) |
+ "COL_MAJOR".!.map(_ => MatrixLayout.AST.COL_MAJOR) |
+ "NONE".!.map(_ => MatrixLayout.AST.NONE))
+ ) | TypeIdentifier.map(i => MatrixLayout.AST.Identifier(i.name)))
+
+ P("fragment[" ~ Nat.Nat ~ "," ~ Nat.Nat ~ "," ~ Nat.Nat ~ "," ~ DataType ~ "," ~ FragmentKind ~
+ "," ~ MatrixLayoutKind ~ "]").map(AST.FragmentType.tupled)
+ }
+
+ def ManagedBufferType[_: P]: P[AST.ManagedBufferType] =
+ P("managed[" ~ DataType ~ "]").map(AST.ManagedBufferType)
+
+ def DepArrayType[_: P]: P[AST.DepArrayType] =
+ P(Nat.Nat ~ ".." ~/ NatToData).map(AST.DepArrayType.tupled)
+
+ def ArrayType[_: P]: P[AST.ArrayType] =
+ P(Nat.Nat ~ "." ~~ !"." ~/ DataType).map(AST.ArrayType.tupled)
+
+ def DepPairType[_: P]: P[AST.DepPairType] =
+ P("(" ~ IdentifierKindPair ~ "**" ~/ DataType ~ ")").map(AST.DepPairType.tupled)
+
+ def NatToDataApply[_: P]: P[AST.NatToDataApply] =
+ P(NatToData ~ "(" ~ Nat.Nat ~ ")").map(AST.NatToDataApply.tupled)
+
+ def PairType[_: P]: P[AST.PairType] =
+ P("(" ~ NoCut(DataType) ~ "," ~/ DataType ~ ")").map(AST.PairType.tupled)
+
+ def DataType[_: P]: P[AST] =
+ P(ScalarType | NatType | OpaqueType | IndexType | VectorType | FragmentType |
+ ManagedBufferType | DepArrayType | ArrayType | DepPairType | NatToDataApply |
+ PairType | TypeIdentifier | ("(" ~ DataType ~ ")"))
+
+ def TypeName[_: P]: P[Unit] =
+ P(ScalarType | NatType | "idx" | "vec" | "fragment" | "matrixLayout")
+ }
+
+ def NatToData[_: P]: P[AST] = {
+ def NatToDataLambda: P[AST.NatToDataLambda] =
+ P("(" ~ IdentifierKindPair.filter(_._2 == Kind.AST.Nat).map(_._1) ~
+ "|->" ~/ DataType.DataType ~ ")").map(AST.NatToDataLambda.tupled)
+
+ P(TypeIdentifier | NatToDataLambda)
+ }
+}
diff --git a/meta/src/main/scala/meta/parser/rise/isWellKindedType.scala b/meta/src/main/scala/meta/parser/rise/isWellKindedType.scala
new file mode 100644
index 000000000..afb868305
--- /dev/null
+++ b/meta/src/main/scala/meta/parser/rise/isWellKindedType.scala
@@ -0,0 +1,176 @@
+package meta.parser.rise
+
+import meta.parser._
+
+object isWellKindedType {
+
+ def apply(typeAST: Type.AST): Boolean = {
+ kindOf(typeAST, Map.empty).isDefined
+ }
+
+ sealed trait DataTypeOrFunctionKind
+ case class DataTypeKind(kind: Kind.AST) extends DataTypeOrFunctionKind
+ case object FunctionKind extends DataTypeOrFunctionKind
+
+ def kindOf(typeAST: Type.AST,
+ env: Map[String, Kind.AST]): Option[DataTypeOrFunctionKind] = {
+ import Type._
+ typeAST match {
+ case AST.Identifier(name) =>
+ env.get(name).map(DataTypeKind)
+ case AST.FunType(inT, outT) =>
+ for {
+ _ <- kindOf(inT, env)
+ _ <- kindOf(outT, env)
+ } yield FunctionKind
+ case AST.DepFunType(id, kind, t) =>
+ if (env.isDefinedAt(id.name)) {
+ None // we forbid shadowing
+ } else {
+ for {
+ _ <- kindOf(t, env.updated(id.name, kind))
+ } yield FunctionKind
+ }
+ case AST.ImplicitDepFunType(id, kind, t) =>
+ if (env.isDefinedAt(id.name)) {
+ None // we forbid shadowing
+ } else {
+ for {
+ _ <- kindOf(t, env.updated(id.name, kind))
+ } yield FunctionKind
+ }
+ case AST.VectorType(size, elemType) =>
+ for {
+ k1 <- kindOf(size, env)
+ k2 <- kindOf(elemType, env)
+ if k1 == Kind.AST.Nat && k2 == DataTypeKind(Kind.AST.Data)
+ } yield DataTypeKind(Kind.AST.Data)
+ case AST.IndexType(size) =>
+ for {
+ k <- kindOf(size, env)
+ if k == Kind.AST.Nat
+ } yield DataTypeKind(Kind.AST.Data)
+ case AST.PairType(lhs, rhs) =>
+ for {
+ k1 <- kindOf(lhs, env)
+ k2 <- kindOf(rhs, env)
+ if k1 == DataTypeKind(Kind.AST.Data) && k2 == DataTypeKind(Kind.AST.Data)
+ } yield DataTypeKind(Kind.AST.Data)
+ case AST.DepPairType(id, kind, t) =>
+ if (env.isDefinedAt(id.name)) {
+ None // we forbid shadowing
+ } else {
+ kindOf(t, env.updated(id.name, kind))
+ }
+ case AST.NatToDataApply(f, n) =>
+ for {
+ k1 <- kindOf(f, env)
+ k2 <- kindOf(n, env)
+ if k1 == DataTypeKind(Kind.AST.Nat2Data) && k2 == Kind.AST.Nat
+ } yield DataTypeKind(Kind.AST.Data)
+ case AST.NatToDataLambda(id, t) =>
+ if (env.isDefinedAt(id.name)) {
+ None // we forbid shadowing
+ } else {
+ for {
+ k <- kindOf(t, env.updated(id.name, Kind.AST.Nat))
+ if k == DataTypeKind(Kind.AST.Data)
+ } yield DataTypeKind(Kind.AST.Nat2Data)
+ }
+ case AST.ArrayType(size, elemType) =>
+ for {
+ k1 <- kindOf(size, env)
+ k2 <- kindOf(elemType, env)
+ if k1 == Kind.AST.Nat && k2 == DataTypeKind(Kind.AST.Data)
+ } yield DataTypeKind(Kind.AST.Data)
+ case AST.DepArrayType(size, fdt) =>
+ for {
+ k1 <- kindOf(size, env)
+ k2 <- kindOf(fdt, env)
+ if k1 == Kind.AST.Nat && k2 == DataTypeKind(Kind.AST.Nat2Data)
+ } yield DataTypeKind(Kind.AST.Data)
+ case AST.FragmentType(n, m, k, elemType, fKind, mLayout) =>
+ for {
+ k1 <- kindOf(n, env)
+ k2 <- kindOf(m, env)
+ k3 <- kindOf(k, env)
+ k4 <- kindOf(elemType, env)
+ k5 <- kindOf(fKind, env)
+ k6 <- kindOf(mLayout, env)
+ if k1 == Kind.AST.Nat && k2 == Kind.AST.Nat && k3 == Kind.AST.Nat && k4 == DataTypeKind(Kind.AST.Data) &&
+ k5 == Kind.AST.Fragment && k6 == Kind.AST.MatrixLayout
+ } yield DataTypeKind(Kind.AST.Data)
+ case AST.ManagedBufferType(dt) =>
+ for {
+ k1 <- kindOf(dt, env)
+ if k1 == DataTypeKind(Kind.AST.Data)
+ } yield DataTypeKind(Kind.AST.Data)
+ case _: AST.ScalarType | AST.NatType | _: AST.OpaqueType =>
+ Some(DataTypeKind(Kind.AST.Data))
+ }
+ }
+
+ def kindOf(natAST: Nat.AST,
+ env: Map[String, Kind.AST]
+ ): Option[Kind.AST] = {
+ natAST match {
+ case Nat.AST.Identifier(id) =>
+ env.get(id)
+ case Nat.AST.Number(_) =>
+ Some(Kind.AST.Nat)
+ case Nat.AST.BinaryOp(lhs, _, rhs) =>
+ for {
+ k1 <- kindOf(lhs, env)
+ k2 <- kindOf(rhs, env)
+ if k1 == Kind.AST.Nat && k2 == Kind.AST.Nat
+ } yield Kind.AST.Nat
+
+ case Nat.AST.TernaryOp(_, thenN, elseN) =>
+ for {
+ k1 <- kindOf(thenN, env)
+ k2 <- kindOf(elseN, env)
+ if k1 == Kind.AST.Nat && k2 == Kind.AST.Nat
+ } yield Kind.AST.Nat
+
+ case Nat.AST.Nat2NatApply(f, n) =>
+ for {
+ k1 <- kindOf(f, env)
+ k2 <- kindOf(n, env)
+ if k1 == DataTypeKind(Kind.AST.Nat2Nat) && k2 == Kind.AST.Nat
+ } yield Kind.AST.Nat
+
+ case Nat.AST.Sum(id, from, upTo, body) =>
+ val nEnv = env.updated(id.name, Kind.AST.Nat)
+ for {
+ k1 <- kindOf(from, nEnv)
+ k2 <- kindOf(upTo, nEnv)
+ k3 <- kindOf(body, nEnv)
+ if k1 == Kind.AST.Nat && k2 == Kind.AST.Nat && k3 == Kind.AST.Nat
+ } yield Kind.AST.Nat
+ }
+ }
+
+ def kindOf(fragmentAST: Type.Fragment.AST,
+ env: Map[String, Kind.AST]
+ ): Option[Kind.AST] = {
+ import Type._
+ fragmentAST match {
+ case Fragment.AST.Identifier(id) =>
+ env.get(id)
+ case Fragment.AST.ACC | Fragment.AST.A | Fragment.AST.B => Some(Kind.AST.Fragment)
+ }
+ }
+
+ def kindOf(matrixLayout: Type.MatrixLayout.AST,
+ env: Map[String, Kind.AST]
+ ): Option[Kind.AST] = {
+ import Type._
+ matrixLayout match {
+ case MatrixLayout.AST.Identifier(id) =>
+ env.get(id)
+ case MatrixLayout.AST.ROW_MAJOR |
+ MatrixLayout.AST.COL_MAJOR |
+ MatrixLayout.AST.NONE => Some(Kind.AST.MatrixLayout)
+ }
+ }
+}
diff --git a/meta/src/main/scala/meta/parser/shared/package.scala b/meta/src/main/scala/meta/parser/shared/package.scala
new file mode 100644
index 000000000..0b59a89e1
--- /dev/null
+++ b/meta/src/main/scala/meta/parser/shared/package.scala
@@ -0,0 +1,21 @@
+package meta.parser
+
+import fastparse.ScalaWhitespace._
+import fastparse._
+
+package object shared {
+ def Identifier[_: P]: P[String] = {
+ def Keywords: P[Unit] =
+ P(( "def" |
+ (rise.Kind.Kind: P[Unit]) | rise.Type.DataType.TypeName |
+ (DPIA.Kind.Kind: P[Unit])
+ ) ~~ CharPred(_.isWhitespace))
+
+ val LowerChar = scalaparse.syntax.Identifiers.NamedFunction(CharPredicates.isLower)
+ val IdCharacter = scalaparse.syntax.Identifiers.NamedFunction(c =>
+ CharPredicates.isLetter(c) || CharPredicates.isDigit(c))
+
+ P((!Keywords ~ CharPred(LowerChar).! ~~ CharsWhile(IdCharacter).!.?).
+ map(t => t._1 ++ t._2.getOrElse("")))
+ }
+}
diff --git a/src/main/scala/rise/core/primitives/primitives.rise b/src/main/scala/rise/core/primitives/primitives.rise
index 2737c9d3c..5e7895adc 100644
--- a/src/main/scala/rise/core/primitives/primitives.rise
+++ b/src/main/scala/rise/core/primitives/primitives.rise
@@ -37,7 +37,7 @@ def generate: {n: nat} -> {t: data} -> (idx[n] -> t) -> n.t
def idx: {n: nat} -> {t: data} -> idx[n] -> n.t -> t
def take: (n: nat) -> {m: nat} -> {t: data} -> (n+m).t -> n.t
- def drop: (n: nat) -> {m: nat} -> {t: data} -> (n+m).t -> m.t
+def drop: (n: nat) -> {m: nat} -> {t: data} -> (n+m).t -> m.t
def concat: {n: nat} -> {m: nat} -> {t: data} -> n.t -> m.t -> (n+m).t
def split: (n: nat) -> {m: nat} -> {t: data} -> (m*n).t -> m.n.t
@@ -56,8 +56,8 @@ def scatter: {n: nat} -> {m: nat} -> {t: data} -> n.idx[m] -> n.t -> m.t
def reorder: {t: data} -> (n: nat) -> (idxF: nat2nat) -> (idxFinv: nat2nat) -> n.t -> n.t
def padCst: {n: nat} -> (l: nat) -> (r: nat) -> {t: data} -> t -> n.t -> (l+n+r).t
-def padClamp: {n: nat} -> (l: nat) -> (r: nat) -> {t: data} -> n.t -> (l+n+r).t
-def padEmpty: {n: nat} -> (r: nat) -> {t: data} -> n.t -> (n+r).t
+def padClamp: {n: nat} -> (l: nat) -> (r: nat) -> {t: data} -> n.t -> (l+n+r).t
+def padEmpty: {n: nat} -> (r: nat) -> {t: data} -> n.t -> (n+r).t
def zip: {n: nat} -> {s: data} -> {t: data} -> n.s -> n.t -> n.(s, t)
def unzip: {n: nat} -> {s: data} -> {t: data} -> n.(s, t) -> (n.s, n.t)
diff --git a/src/main/scala/rise/core/traverse.scala b/src/main/scala/rise/core/traverse.scala
index bbc038368..724fe1f4f 100644
--- a/src/main/scala/rise/core/traverse.scala
+++ b/src/main/scala/rise/core/traverse.scala
@@ -63,6 +63,10 @@ object traverse {
case VectorType(n, e) =>
for {n1 <- natDispatch(Reference)(n); e1 <- `type`(e)}
yield VectorType(n1, e1)
+ case ManagedBufferType(dt) =>
+ for {dt1 <- datatype(dt)}
+ yield ManagedBufferType(dt1)
+ case o: OpaqueType => return_(o: DataType)
case FragmentType(rows, columns, d3, dt, fragKind, layout) =>
for {rows1 <- nat(rows); columns1 <- nat(columns); d31 <- nat(d3); dt1 <- datatype(dt);
fragKind1 <- fragmentKind(fragKind); layout1 <- matrixLayout(layout)}
diff --git a/src/main/scala/rise/core/types/Type.scala b/src/main/scala/rise/core/types/Type.scala
index 679a7c1c0..1ada7f7ad 100644
--- a/src/main/scala/rise/core/types/Type.scala
+++ b/src/main/scala/rise/core/types/Type.scala
@@ -68,6 +68,10 @@ object f64 extends ScalarType { override def toString: String = "f64" }
object NatType extends DataType { override def toString: String = "nat" }
+final case class OpaqueType(name: String) extends DataType {
+ override def toString: String = name
+}
+
// TODO: enforce ScalarType
sealed case class VectorType(size: Nat, elemType: DataType) extends DataType {
override def toString: String = s"<$size>$elemType"
@@ -91,11 +95,11 @@ sealed trait MatrixLayout
object MatrixLayout {
object Row_Major extends MatrixLayout { override def toString = "Row_Major" }
object Col_Major extends MatrixLayout { override def toString = "Col_Major" }
+ object None extends MatrixLayout
}
-final case class MatrixLayoutIdentifier(
- name: String,
- override val isExplicit: Boolean = false
+final case class MatrixLayoutIdentifier(name: String,
+ override val isExplicit: Boolean = false
) extends MatrixLayout
with Kind.Identifier
with Kind.Explicitness {
@@ -110,7 +114,7 @@ sealed trait FragmentKind
object FragmentKind {
object AMatrix extends FragmentKind { override def toString = "AMatrix"}
object BMatrix extends FragmentKind { override def toString = "BMatrix"}
- object Acuumulator extends FragmentKind { override def toString = "Acuumulator"}
+ object Accumulator extends FragmentKind { override def toString = "Accumulator"}
}
final case class FragmentKindIdentifier(name: String,
@@ -126,7 +130,7 @@ final case class FragmentKindIdentifier(name: String,
object FragmentType {
def apply(rows: Nat, columns:Nat, d3: Nat, dataType: DataType): FragmentType = {
- FragmentType(rows, columns, d3, dataType, FragmentKind.Acuumulator, MatrixLayout.Row_Major)
+ FragmentType(rows, columns, d3, dataType, FragmentKind.Accumulator, MatrixLayout.None)
}
}
@@ -137,13 +141,17 @@ final case class FragmentType(rows: Nat,
fragmentKind: FragmentKind,
layout: MatrixLayout) extends DataType {
override def toString: String =
- if (fragmentKind == FragmentKind.Acuumulator)
+ if (fragmentKind == FragmentKind.Accumulator)
s"Fragment[$rows,$columns,$d3,$dataType,$fragmentKind]"
else
s"Fragment[$rows,$columns,$d3,$dataType,$fragmentKind,$layout]"
}
+final case class ManagedBufferType(dt: DataType) extends DataType {
+ override def toString: String = s"managed[$dt]"
+
+}
final case class DepPairType[K <: Kind: KindName](
x: K#I,
diff --git a/src/main/scala/rise/eqsat/TypeNode.scala b/src/main/scala/rise/eqsat/TypeNode.scala
index 0e8f90db5..6119a5267 100644
--- a/src/main/scala/rise/eqsat/TypeNode.scala
+++ b/src/main/scala/rise/eqsat/TypeNode.scala
@@ -151,7 +151,7 @@ object DataType {
case rct.PairType(dt1, dt2) => PairType(fromNamed(dt1, bound), fromNamed(dt2, bound))
case rct.ArrayType(s, et) => ArrayType(Nat.fromNamed(s, bound), fromNamed(et, bound))
case _: rct.DepArrayType | _: rct.DepPairType[_] |
- _: rct.NatToDataApply | _: rct.FragmentType =>
+ _: rct.NatToDataApply | _: rct.FragmentType | _: rct.ManagedBufferType | _: rct.OpaqueType =>
throw new Exception(s"did not expect $dt")
})
}
diff --git a/src/main/scala/shine/C/Compilation/CodeGenerator.scala b/src/main/scala/shine/C/Compilation/CodeGenerator.scala
index 94b037e15..e5577d20b 100644
--- a/src/main/scala/shine/C/Compilation/CodeGenerator.scala
+++ b/src/main/scala/shine/C/Compilation/CodeGenerator.scala
@@ -112,12 +112,18 @@ class CodeGenerator(val decls: CodeGenerator.Declarations,
CCodeGen.codeGenNewDoubleBuffer(ArrayType(n, dt), in, out, ps, p, env)
case f@For(unroll) =>
- val (i, p) = f.unwrapBody
- CCodeGen.codeGenFor(f.n, i, p, unroll, env)
+ f.loopBody match {
+ case Lambda(i, p) =>
+ CCodeGen.codeGenFor(f.n, i, p, unroll, env)
+ case _ => throw new Exception("This should not happen")
+ }
case f@ForNat(unroll) =>
- val (i, p) = f.unwrapBody
- CCodeGen.codeGenForNat(f.n, i, p, unroll, env)
+ f.loopBody match {
+ case shine.DPIA.Phrases.DepLambda(i, p) =>
+ CCodeGen.codeGenForNat(f.n, i, p, unroll, env)
+ case _ => throw new Exception("This should not happen")
+ }
case Proj1(pair) => Lifting.liftPair(pair)._1 |> cmd(env)
case Proj2(pair) => Lifting.liftPair(pair)._2 |> cmd(env)
@@ -438,7 +444,7 @@ class CodeGenerator(val decls: CodeGenerator.Declarations,
case _ => error(s"Expected path to be not empty")
}
- case Pad(n, l, r, _, pad, array) => path match {
+ case PadCst(n, l, r, _, pad, array) => path match {
case (i: CIntExpr) :: ps =>
pad |> exp(env, ps, padExpr =>
genPad(n, l, r, padExpr, padExpr, i, ps, array, env, cont))
@@ -490,7 +496,7 @@ class CodeGenerator(val decls: CodeGenerator.Declarations,
case _ => error(s"Expected path to be not empty")
}
- case MakeArray(_, elems) => path match {
+ case MakeArray(elems) => path match {
case (i: CIntExpr) :: ps => try {
elems(i.eval) |> exp(env, ps, cont)
} catch {
@@ -503,9 +509,9 @@ class CodeGenerator(val decls: CodeGenerator.Declarations,
case DepIdx(_, _, i, e) => e |> exp(env, CIntExpr(i) :: path, cont)
- case ForeignFunctionCall(f, inTs, outT, args) =>
- CCodeGen.codeGenForeignFunctionCall(f, inTs, outT, args, env, fe =>
- generateAccess(outT, fe, path, env, cont)
+ case ffc@ForeignFunctionCall(f, inTs, args) =>
+ CCodeGen.codeGenForeignFunctionCall(f, inTs, ffc.outT, args, env, fe =>
+ generateAccess(ffc.outT, fe, path, env, cont)
)
case Proj1(pair) => SimplifyNats.simplifyIndexAndNatExp(Lifting.liftPair(pair)._1) |> exp(env, path, cont)
@@ -543,7 +549,7 @@ class CodeGenerator(val decls: CodeGenerator.Declarations,
case shine.DPIA.Types.f32 => C.AST.Type.float
case shine.DPIA.Types.f64 => C.AST.Type.double
case _: shine.DPIA.Types.IndexType => C.AST.Type.int
- case _: shine.DPIA.Types.VectorType | _: FragmentType | _: pipeline.type =>
+ case _: shine.DPIA.Types.VectorType | _: FragmentType =>
throw new Exception(s"$b types in C are not supported")
}
case a: shine.DPIA.Types.ArrayType => C.AST.ArrayType(typ(a.elemType), Some(a.size))
@@ -856,7 +862,7 @@ class CodeGenerator(val decls: CodeGenerator.Declarations,
})
}
- def codeGenForeignFunctionCall(funDecl: ForeignFunction.Declaration,
+ def codeGenForeignFunctionCall(funDecl: rise.core.ForeignFunction.Decl,
inTs: collection.Seq[DataType],
outT: DataType,
args: collection.Seq[Phrase[ExpType]],
diff --git a/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala b/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala
index 9a7ac7c4b..32994bf80 100644
--- a/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala
+++ b/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala
@@ -90,7 +90,7 @@ object AcceptorTranslation {
case AsVector(n, m, dt, access, array) =>
acc(array)(AsVectorAcc(n, m, dt, A))
- case AsVectorAligned(n, m, w, dt, array) =>
+ case AsVectorAligned(n, m, dt, access, array) =>
acc(array)(AsVectorAcc(n, m, dt, A))
case DepIdx(n, ft, index, array) =>
@@ -225,7 +225,7 @@ object AcceptorTranslation {
y, x, A)))))
case Scatter(n, m, dt, indices, input) =>
- con(indices)(fun(expT(m`.`idx(n), read))(y =>
+ con(indices)(fun(expT(n`.`idx(m), read))(y =>
acc(input)(ScatterAcc(n, m, dt, y, A))))
case slide@Slide(n, sz, sp, dt, input) =>
@@ -283,12 +283,12 @@ object AcceptorTranslation {
λ(expT({l * n}`.`dt, read))(x => acc(f(l)(x))(o)))),
x)))
- case ocl.KernelCall(name, localSize, globalSize, inTs, outT, args) =>
+ case ocl.KernelCall(name, localSize, globalSize, _, args) =>
def rec(ts: Seq[Phrase[ExpType]],
- es: Seq[Phrase[ExpType]]): Phrase[CommType] = {
+ es: Seq[Phrase[ExpType]]): Phrase[CommType] = {
ts match {
case Nil =>
- oclImp.KernelCallCmd(name, localSize, globalSize, A, es)
+ oclImp.KernelCallCmd(name, localSize, globalSize, es)(A.t.dataType, A)
case Seq(arg, tail@_*) =>
con(arg)(λ(expT(arg.t.dataType, read))(e => rec(tail, es :+ e)))
}
@@ -303,7 +303,7 @@ object AcceptorTranslation {
λ(expT(dt1, read))(x => λ(accT(dt2))(o => acc(f(x))(o))),
x, A)))
- case ocl.OpenCLFunctionCall(name, inTs, outT, args) =>
+ case fc@ocl.OpenCLFunctionCall(name, inTs, args) =>
def rec(ts: Seq[(Phrase[ExpType], DataType)],
exps: Seq[Phrase[ExpType]],
inTs: Seq[DataType]): Phrase[CommType] = {
@@ -311,7 +311,7 @@ object AcceptorTranslation {
// with only one argument left to process return the assignment of the OpenCLFunction call
case Seq( (arg, inT) ) =>
con(arg)(λ(expT(inT, read))(e =>
- A :=|outT| ocl.OpenCLFunctionCall(name, inTs :+ inT, outT, exps :+ e) ))
+ A :=|fc.outT| ocl.OpenCLFunctionCall(name, inTs :+ inT, exps :+ e)(fc.outT) ))
// with a `tail` of arguments left, recurse
case Seq( (arg, inT), tail@_* ) =>
con(arg)(λ(expT(inT, read))(e => rec(tail, exps :+ e, inTs :+ inT) ))
@@ -321,17 +321,17 @@ object AcceptorTranslation {
rec(args zip inTs, Seq(), Seq())
// CUDA
- case cuda.AsFragment(rows, columns, d3, dataType, fragmentKind, matrix, layout) =>
+ case cuda.AsFragment(rows, columns, layers, dataType, fragmentKind, layout, matrix) =>
con(matrix)(λ(ExpType(ArrayType(rows, ArrayType(columns, dataType)), read))(matrix =>
- cudaImp.WmmaLoad(rows, columns, d3, dataType, fragmentKind, layout, matrix, A)))
+ cudaImp.WmmaLoad(rows, columns, layers, dataType, fragmentKind, layout, matrix, A)))
- case cuda.AsMatrix(rows, columns, d3, dataType, fragment) =>
+ case cuda.AsMatrix(rows, columns, layers, dataType, fragment) =>
con(fragment)(λ(ExpType(fragment.t.dataType, read))(fragment =>
- cudaImp.WmmaStore(rows, columns, d3, dataType, fragment, A)))
+ cudaImp.WmmaStore(rows, columns, layers, dataType, fragment, A)))
- case cuda.GenerateFragment(rows, columns, d3, dataType, fill, fragmentKind, layout) =>
+ case cuda.GenerateFragment(rows, columns, layers, dataType, frag, layout, fill) =>
con(fill)(λ(ExpType(dataType, read))(fill =>
- cudaImp.WmmaFill(rows, columns, d3, dataType, fill, fragmentKind, layout, A)))
+ cudaImp.WmmaFill(rows, columns, layers, dataType, frag, layout, fill, A)))
case map@cuda.Map(level, dim) =>
val (n, dt1, dt2, f, array) = map.unwrap
@@ -340,17 +340,17 @@ object AcceptorTranslation {
λ(expT(dt1, read))(x => λ(accT(dt2))(o => acc(f(x))(o))),
x, A)))
- case cuda.MapFragmentElements(fragType, fragment, fun) =>
- con(fragment)(λ(expT(fragType, read))(input =>
- shine.cuda.primitives.imperative.ForFragmentElements(fragType, input, A,
- λ(expT(fragType.dataType, read))(x =>
- λ(accT(fragType.dataType))(o =>
+ case cuda.MapFragment(rows, columns, layers, dt, frag, layout, fun, input) =>
+ con(input)(λ(expT(FragmentType(rows, columns, layers, dt, frag, layout), read))(input =>
+ shine.cuda.primitives.imperative.ForFragment(rows, columns, layers, dt, frag, layout, input, A,
+ λ(expT(dt, read))(x =>
+ λ(accT(dt))(o =>
acc(fun(x))(o))))))
case cuda.TensorMatMultAdd(m, n, k, layoutA, layoutB, dataType, dataTypeAcc, aMatrix, bMatrix, cMatrix) =>
con(aMatrix)(λ(ExpType(FragmentType(m, n, k, dataType, FragmentKind.AMatrix, layoutA), read))(aMatrix =>
con(bMatrix)(λ(ExpType(FragmentType(m, n, k, dataType, FragmentKind.BMatrix, layoutB), read))(bMatrix =>
- con(cMatrix)(λ(ExpType(FragmentType(m, n, k, dataTypeAcc), read))(cMatrix =>
+ con(cMatrix)(λ(ExpType(FragmentType(m, n, k, dataTypeAcc, FragmentKind.Accumulator, MatrixLayout.None), read))(cMatrix =>
cudaImp.WmmaMMA(m, n, k, layoutA, layoutB, dataType, dataTypeAcc, aMatrix, bMatrix, cMatrix, A)))))))
}
}
diff --git a/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala b/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala
index b9d5338d3..2617c22a7 100644
--- a/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala
+++ b/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala
@@ -119,7 +119,7 @@ object ContinuationTranslation {
con(array)(λ(expT((n + m)`.` dt, read))(x =>
C(Drop(n, m, dt, x))))
- case ForeignFunctionCall(funDecl, inTs, outT, args) =>
+ case ffc@ForeignFunctionCall(funDecl, inTs, args) =>
def rec(ts: Seq[(Phrase[ExpType], DataType)],
exps: Seq[Phrase[ExpType]],
inTs: Seq[DataType]): Phrase[CommType] = {
@@ -127,7 +127,7 @@ object ContinuationTranslation {
// with only one argument left to process return the assignment of the function call
case Seq( (arg, inT) ) =>
con(arg)(λ(expT(inT, read))(e =>
- outT match {
+ ffc.outT match {
// TODO: this is an ugly fix to avoid calling the function multiple times
// for pair assignment, see:
// https://github.com/rise-lang/shine/issues/58
@@ -140,10 +140,10 @@ object ContinuationTranslation {
case _ =>
`new`.apply
}
- backendNew(outT, tmp =>
- Assign(outT, tmp.wr, ForeignFunctionCall(funDecl, inTs :+ inT, outT, exps :+ e)) `;`
+ backendNew(ffc.outT, tmp =>
+ Assign(ffc.outT, tmp.wr, ForeignFunctionCall(funDecl, inTs :+ inT, exps :+ e)(ffc.outT)) `;`
C(tmp.rd))
- case _ => C( ForeignFunctionCall(funDecl, inTs :+ inT, outT, exps :+ e) )
+ case _ => C( ForeignFunctionCall(funDecl, inTs :+ inT, exps :+ e)(ffc.outT) )
}))
// with a `tail` of arguments left, rec
case Seq( (arg, inT), tail@_* ) =>
@@ -191,13 +191,13 @@ object ContinuationTranslation {
con(value)(fun(value.t)(x =>
con(f(x))(C)))
- case MakeArray(dt, elements) =>
+ case ma@MakeArray(elements) =>
def rec(func: Vector[Phrase[ExpType]], imp: Vector[Phrase[ExpType]]): Phrase[CommType] = {
func match {
- case xf +: func => con(xf)(fun(expT(dt, read))(xi =>
+ case xf +: func => con(xf)(fun(expT(ma.dt, read))(xi =>
rec(func, imp :+ xi)
))
- case _ => C(MakeArray(dt, imp))
+ case _ => C(MakeArray(imp)(ma.n, ma.dt))
}
}
@@ -257,10 +257,10 @@ object ContinuationTranslation {
con(e)(λ(e.t)(x =>
C(NatAsIndex(n, x))))
- case Pad(n, l, r, dt, padExp, array) =>
+ case PadCst(n, l, r, dt, padExp, array) =>
con(array)(λ(expT(n`.`dt, read))(x =>
con(padExp)(λ(expT(dt, read))(p =>
- C(Pad(n, l, r, dt, p, x))))))
+ C(PadCst(n, l, r, dt, p, x))))))
case PadClamp(n, l, r, dt, array) =>
con(array)(λ(expT(n`.`dt, read))(x =>
@@ -291,7 +291,8 @@ object ContinuationTranslation {
acc(scanSeq)(tmp.wr) `;` C(tmp.rd) ))
case slide@Slide(n, sz, sp, dt, input) =>
- con(input)(λ(expT(slide.inputSize`.`dt, read))(x =>
+ val inputSize = sp*n+sz
+ con(input)(λ(expT(inputSize`.`dt, read))(x =>
C(Slide(n, sz, sp, dt, x)) ))
case Snd(dt1, dt2, pair) =>
@@ -361,7 +362,7 @@ object ContinuationTranslation {
`new`(map.n `.` map.dt2, λ(varT(map.n `.` map.dt2))(tmp =>
acc(map)(tmp.wr) `;` C(tmp.rd)))
- case ocl.OpenCLFunctionCall(name, inTs, outT, args) =>
+ case fc@ocl.OpenCLFunctionCall(name, inTs, args) =>
def rec(ts: Seq[(Phrase[ExpType], DataType)],
es: Seq[Phrase[ExpType]],
inTs: Seq[DataType]): Phrase[CommType] = {
@@ -369,7 +370,7 @@ object ContinuationTranslation {
// with only one argument left to process continue with the OpenCLFunction call
case Seq( (arg, inT) ) =>
con(arg)(λ(expT(inT, read))(e =>
- C(ocl.OpenCLFunctionCall(name, inTs :+ inT, outT, es :+ e)) ))
+ C(ocl.OpenCLFunctionCall(name, inTs :+ inT, es :+ e)(fc.outT)) ))
// with a `tail` of arguments left, rec
case Seq( (arg, inT), tail@_* ) =>
con(arg)(λ(expT(inT, read))(e => rec(tail, es :+ e, inTs :+ inT) ))
@@ -379,12 +380,10 @@ object ContinuationTranslation {
rec(args zip inTs, Seq(), Seq())
case reduceSeq@ocl.ReduceSeq(unroll) =>
- val (n, initAddrSpace, dt1, dt2, f, init, array) =
- ( reduceSeq.n, reduceSeq.initAddrSpace, reduceSeq.dt1, reduceSeq.dt2,
- reduceSeq.f, reduceSeq.init, reduceSeq.array )
+ val (n, a, dt1, dt2, f, init, array) = reduceSeq.unwrap
con(array)(λ(expT(n`.`dt1, read))(X =>
- oclI.ReduceSeqI(n, initAddrSpace, dt1, dt2,
+ oclI.ReduceSeqI(n, a, dt1, dt2,
λ(expT(dt2, read))(x =>
λ(expT(dt1, read))(y =>
λ(accT(dt2))(o => acc( f(x)(y) )( o )))),
@@ -400,7 +399,7 @@ object ContinuationTranslation {
case cuda.GlobalToShared(dt, inputGlobal) =>
val adj = AdjustArraySizesForAllocations(inputGlobal, dt, AddressSpace.Local)
- shine.OpenCL.DSL.`new` (AddressSpace.Private) (pipeline, pipeline =>
+ shine.OpenCL.DSL.`new` (AddressSpace.Private) (OpaqueType("pipeline"), pipeline =>
shine.OpenCL.DSL.`new` (AddressSpace.Local) (adj.dt, tmp =>
acc(inputGlobal)(cudaIm.GlobalToSharedAcc(dt, pipeline.rd, tmp.wr)) `;`
cudaIm.SyncPipeline(pipeline.rd) `;`
@@ -414,19 +413,18 @@ object ContinuationTranslation {
`new`(n `.` dt2, λ(varT(n `.` dt2))(tmp =>
acc(map)(tmp.wr) `;` C(tmp.rd)))
- case mapFragmentElements@cuda.MapFragmentElements(fragType, fragment, fun) =>
- val dt = fragType.dataType
-
+ case m@cuda.MapFragment(rows, columns, layers, dt, frag, layout, fun, input) =>
+ val fragType = FragmentType(rows, columns, layers, dt, frag, layout)
shine.OpenCL.primitives.imperative.New(AddressSpace.Private, fragType,
λ(VarType(fragType))(fragmentAcc =>
- (if (fragment.t.accessType.toString == write.toString)
- acc(fragment)(fragmentAcc.wr) `;`
- cudaIm.ForFragmentElements(fragType, fragmentAcc.rd, fragmentAcc.wr,
+ (if (input.t.accessType.toString == write.toString)
+ acc(input)(fragmentAcc.wr) `;`
+ cudaIm.ForFragment(rows, columns, layers, dt, frag, layout, fragmentAcc.rd, fragmentAcc.wr,
λ(expT(dt, read))(x =>
λ(accT(dt))(o =>
acc(fun(x))(o))))
else
- acc(mapFragmentElements)(fragmentAcc.wr)) `;`
+ acc(m)(fragmentAcc.wr)) `;`
C(fragmentAcc.rd)))
}
}
diff --git a/src/main/scala/shine/DPIA/Compilation/Passes/UnrollLoops.scala b/src/main/scala/shine/DPIA/Compilation/Passes/UnrollLoops.scala
index b618f5644..e002382ab 100644
--- a/src/main/scala/shine/DPIA/Compilation/Passes/UnrollLoops.scala
+++ b/src/main/scala/shine/DPIA/Compilation/Passes/UnrollLoops.scala
@@ -16,27 +16,36 @@ object UnrollLoops {
override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] =
p match {
case f@For(true) =>
- val (ident, body) = f.unwrapBody
- Continue(unrollLoop(f.n, init = 0, step = 1, i =>
- Phrase.substitute(NatAsIndex(f.n, Natural(i)),
- `for` = ident, in = body)), this)
+ f.loopBody match {
+ case shine.DPIA.Phrases.Lambda(x, body) =>
+ Continue(unrollLoop(f.n, init = 0, step = 1, i =>
+ Phrase.substitute(NatAsIndex(f.n, Natural(i)),
+ `for` = x, in = body)), this)
+ case _ => throw new Exception("This should not happen")
+ }
case f@ForNat(true) =>
- val (ident, body) = f.unwrapBody
- Continue(unrollLoop(f.n, init = 0, step = 1, i =>
- PhraseType.substitute(i, `for` = ident, in = body)), this)
- case pf@ParFor(_, _, true) =>
- val (ident, identOut, body) = pf.unwrapBody
- pf.out.t.dataType match {
- case ArrayType(_, elemType) =>
- Continue(unrollLoop(pf.n, pf.init, pf.step, i =>
- Phrase.substitute(
- IdxAcc(pf.n, elemType,
- NatAsIndex(pf.n, Natural(i)), pf.out),
- `for` = identOut,
- Phrase.substitute(NatAsIndex(pf.n, Natural(i)),
- `for` = ident, in = body))), this)
- case _ =>
- throw new Exception("OpenCLParFor acceptor has to be of ArrayType.")
+ f.loopBody match {
+ case shine.DPIA.Phrases.DepLambda(x, body) =>
+ Continue(unrollLoop(f.n, init = 0, step = 1, i =>
+ PhraseType.substitute(i, `for` = x, in = body)), this)
+ case _ => throw new Exception("This should not happen")
+ }
+ case pf@ParFor(_, _, true, _) =>
+ pf.body match {
+ case shine.DPIA.Phrases.Lambda(ident, shine.DPIA.Phrases.Lambda(identOut, body)) =>
+ pf.out.t.dataType match {
+ case ArrayType(_, elemType) =>
+ Continue(unrollLoop(pf.n, pf.init, pf.step, i =>
+ Phrase.substitute(
+ IdxAcc(pf.n, elemType,
+ NatAsIndex(pf.n, Natural(i)), pf.out),
+ `for` = identOut,
+ Phrase.substitute(NatAsIndex(pf.n, Natural(i)),
+ `for` = ident, in = body))), this)
+ case _ =>
+ throw new Exception("OpenCLParFor acceptor has to be of ArrayType.")
+ }
+ case _ => throw new Exception("This should not happen")
}
case _ =>
Continue(p, this)
@@ -73,7 +82,7 @@ object UnrollLoops {
val numIter = ceilDiv(stopMax - startMin, incr)
val tmp = (0 until numIter).foldLeft[Phrase[CommType]](
- Comment(s"unrolling loop of $numIter"))({
+ shine.DPIA.DSL.comment(s"unrolling loop of $numIter"))({
case (prev, i) =>
val index = init + Cst(i * incr)
assert(isSmaller(index, n).contains(true)) //TODO add if-guards otherwise.
diff --git a/src/main/scala/shine/DPIA/DSL/ImperativePrimitives.scala b/src/main/scala/shine/DPIA/DSL/ImperativePrimitives.scala
index 0014f37b3..2d63673f2 100644
--- a/src/main/scala/shine/DPIA/DSL/ImperativePrimitives.scala
+++ b/src/main/scala/shine/DPIA/DSL/ImperativePrimitives.scala
@@ -7,6 +7,10 @@ import shine.DPIA.Types.DataType._
import shine.DPIA._
import shine.DPIA.primitives.functional.{Fst, Snd}
+object ImperativePrimitives {
+ def skip: Skip = Skip()
+}
+
object `new` {
def apply(dt: DataType,
f: Phrase[VarType ->: CommType]): New =
@@ -79,7 +83,7 @@ object streamNext {
}
object comment {
- def apply(comment: String): Comment = Comment(comment)
+ def apply(comment: String): Comment = Comment(comment)()
}
object fst {
@@ -109,5 +113,3 @@ object pairAcc2 {
def apply(fstT: DataType, sndT: DataType, record: Phrase[AccType]): PairAcc2 =
PairAcc2(fstT, sndT, record)
}
-
-object skip extends Skip
diff --git a/src/main/scala/shine/DPIA/DSL/package.scala b/src/main/scala/shine/DPIA/DSL/package.scala
index fe0c214f5..40e1dc323 100644
--- a/src/main/scala/shine/DPIA/DSL/package.scala
+++ b/src/main/scala/shine/DPIA/DSL/package.scala
@@ -11,15 +11,15 @@ import scala.language.implicitConversions
package object DSL {
implicit class BinOps(lhs: Phrase[ExpType]) {
- def +(rhs: Phrase[ExpType]) = BinOp(Operators.Binary.ADD, lhs, rhs)
- def -(rhs: Phrase[ExpType]) = BinOp(Operators.Binary.SUB, lhs, rhs)
- def *(rhs: Phrase[ExpType]) = BinOp(Operators.Binary.MUL, lhs, rhs)
- def /(rhs: Phrase[ExpType]) = BinOp(Operators.Binary.DIV, lhs, rhs)
- def %(rhs: Phrase[ExpType]) = BinOp(Operators.Binary.MOD, lhs, rhs)
- def >(rhs: Phrase[ExpType]) = BinOp(Operators.Binary.GT, lhs, rhs)
- def <(rhs: Phrase[ExpType]) = BinOp(Operators.Binary.LT, lhs, rhs)
- def =:=(rhs: Phrase[ExpType]) = BinOp(Operators.Binary.EQ, lhs, rhs)
- def unary_- = UnaryOp(Operators.Unary.NEG, lhs)
+ def +(rhs: Phrase[ExpType]): BinOp = BinOp(Operators.Binary.ADD, lhs, rhs)
+ def -(rhs: Phrase[ExpType]): BinOp = BinOp(Operators.Binary.SUB, lhs, rhs)
+ def *(rhs: Phrase[ExpType]): BinOp = BinOp(Operators.Binary.MUL, lhs, rhs)
+ def /(rhs: Phrase[ExpType]): BinOp = BinOp(Operators.Binary.DIV, lhs, rhs)
+ def %(rhs: Phrase[ExpType]): BinOp = BinOp(Operators.Binary.MOD, lhs, rhs)
+ def >(rhs: Phrase[ExpType]): BinOp = BinOp(Operators.Binary.GT, lhs, rhs)
+ def <(rhs: Phrase[ExpType]): BinOp = BinOp(Operators.Binary.LT, lhs, rhs)
+ def =:=(rhs: Phrase[ExpType]): BinOp = BinOp(Operators.Binary.EQ, lhs, rhs)
+ def unary_- : UnaryOp = UnaryOp(Operators.Unary.NEG, lhs)
}
implicit class ExpPhraseExtensions(e: Phrase[ExpType]) {
diff --git a/src/main/scala/shine/DPIA/Phrases/PrettyPhrasePrinter.scala b/src/main/scala/shine/DPIA/Phrases/PrettyPhrasePrinter.scala
index 3c9f53126..1de17598d 100644
--- a/src/main/scala/shine/DPIA/Phrases/PrettyPhrasePrinter.scala
+++ b/src/main/scala/shine/DPIA/Phrases/PrettyPhrasePrinter.scala
@@ -35,6 +35,18 @@ object PrettyPhrasePrinter {
case PhrasePair(fst, snd) => s"(${apply(fst)}, ${apply(snd)})"
+ case shine.DPIA.primitives.imperative.Comment(comment) => s"\n//$comment\n"
+
+ case shine.OpenCL.primitives.imperative.Barrier(local, global) =>
+ s"""barrier( ${if(local) "CLK_LOCAL_MEM_FENCE" else ""} ${if(global && local) "|" else ""}
+ ${if(global) "CLK_GLOBAL_MEM_FENCE" else ""})"""
+
+ case shine.cuda.primitives.imperative.SyncThreads() => "__syncthreads()"
+
+ case shine.cuda.primitives.imperative.SyncWarp() => "__syncwarp()"
+
+ case shine.cuda.primitives.imperative.SyncPipeline(pipe) => s"$pipe.commit_and_wait()"
+
case c: Primitive[_] => c.prettyPrint
}
}
diff --git a/src/main/scala/shine/DPIA/Phrases/VisitAndRebuild.scala b/src/main/scala/shine/DPIA/Phrases/VisitAndRebuild.scala
index 76e005bb9..e07cc7937 100644
--- a/src/main/scala/shine/DPIA/Phrases/VisitAndRebuild.scala
+++ b/src/main/scala/shine/DPIA/Phrases/VisitAndRebuild.scala
@@ -186,6 +186,11 @@ object VisitAndRebuild {
case r: PairType =>
PairType(visitDataTypeAndRebuild(r.fst, v),
visitDataTypeAndRebuild(r.snd, v))
- case d => d
+ case ManagedBufferType(dt) =>
+ ManagedBufferType(visitDataTypeAndRebuild(dt, v))
+ case d => d match {
+ case _: ComposedType | _: BasicType | _: OpaqueType |
+ _: NatToDataApply | _: DataTypeIdentifier => d
+ }
}
}
diff --git a/src/main/scala/shine/DPIA/Types/DataType.scala b/src/main/scala/shine/DPIA/Types/DataType.scala
index 311ea4897..2b18660d8 100644
--- a/src/main/scala/shine/DPIA/Types/DataType.scala
+++ b/src/main/scala/shine/DPIA/Types/DataType.scala
@@ -17,18 +17,19 @@ sealed trait MatrixLayout
object MatrixLayout {
object Row_Major extends MatrixLayout { override def toString = "Row_Major" }
object Col_Major extends MatrixLayout { override def toString = "Col_Major" }
+ object None extends MatrixLayout
}
final case class MatrixLayoutIdentifier(name: String) extends MatrixLayout with Kind.Identifier {
- var layout: Option[MatrixLayout] = None
+ var layout: MatrixLayout = MatrixLayout.None
override def toString: String = name
def setLayout(matrixLayout: MatrixLayout): Unit = {
- if (layout.isEmpty)
- layout = Some(matrixLayout)
- else if (layout.get != matrixLayout)
- throw new Exception(s"could not unify ${layout.get} and $matrixLayout")
+ if (layout == MatrixLayout.None)
+ layout = matrixLayout
+ else if (layout != matrixLayout)
+ throw new Exception(s"could not unify $layout and $matrixLayout")
}
}
@@ -40,43 +41,34 @@ object FragmentKind {
object Accumulator extends FragmentKind { override def toString = "Accumulator"}
}
-//This can be used to create fragments of kind `Accumulator` which does not need a layout
-object FragmentType {
- def apply(rows: Nat, columns:Nat, d3: Nat, dataType: DataType): FragmentType =
- FragmentType(rows, columns, d3, dataType, FragmentKind.Accumulator, null)
-}
-
/**
* Represents a CUDA-fragment which represents a tile of a matrix which is stored in registers of a warp.
* Fragments of kind `Accumulator` does not have a layout. So the `layout` of fragments of kind `Accumulator`
* can be ignored.
* @param rows number of rows
* @param columns number of columns
- * @param d3 third dimension which is used in the MMA operation
+ * @param layers third dimension which is used in the MMA operation
* @param dataType dataType of the elements
* @param fragmentKind kind of the fragment {@link FragmentKind}
* @param layout layout of the fragment {@link MatrixLayout}
*/
final case class FragmentType(rows: Nat,
columns: Nat,
- d3: Nat,
+ layers: Nat,
dataType: DataType,
fragmentKind: FragmentKind,
layout: MatrixLayout) extends BasicType {
override def toString: String =
- if (fragmentKind == FragmentKind.Accumulator)
- s"Fragment[$rows,$columns,$d3,$dataType,$fragmentKind]"
- else
- s"Fragment[$rows,$columns,$d3,$dataType,$fragmentKind,$layout]"
+ s"Fragment[$rows,$columns,$layers,$dataType,$fragmentKind,$layout]"
override def equals(o: Any): Boolean = {
o match {
case f: FragmentType =>
f.fragmentKind match {
case FragmentKind.Accumulator =>
- f.rows.equals(rows) && f.columns.equals(columns) && f.d3.equals(d3) && f.dataType.equals(dataType)
+ f.rows.equals(rows) && f.columns.equals(columns) && f.layers.equals(layers) && f.dataType.equals(dataType)
case _ =>
- f.rows.equals(rows) && f.columns.equals(columns) && f.d3.equals(d3) && f.dataType.equals(dataType) &&
+ f.rows.equals(rows) && f.columns.equals(columns) && f.layers.equals(layers) && f.dataType.equals(dataType) &&
f.fragmentKind.equals(fragmentKind) && f.layout.equals(layout)
}
case _ => false
@@ -84,8 +76,6 @@ final case class FragmentType(rows: Nat,
}
}
-object pipeline extends BasicType { override def toString = "pipeline" }
-
object bool extends ScalarType { override def toString: String = "bool" }
object int extends ScalarType { override def toString: String = "int" }
@@ -153,7 +143,7 @@ final case class PairType(fst: DataType, snd: DataType) extends ComposedType {
override def toString: String = s"($fst x $snd)"
}
-sealed case class VectorType(size: Nat, elemType: ScalarType)
+sealed case class VectorType(size: Nat, elemType: DataType)
extends BasicType
{
override def toString: String = s"<$size>$elemType"
@@ -161,7 +151,7 @@ sealed case class VectorType(size: Nat, elemType: ScalarType)
object vec {
@inline
- def apply(size: Nat, elemType: ScalarType): VectorType =
+ def apply(size: Nat, elemType: DataType): VectorType =
VectorType(size, elemType)
}
@@ -211,7 +201,7 @@ object DataType {
case f: FragmentType =>
FragmentType(ArithExpr.substitute(f.rows, Map((`for`, ae))),
ArithExpr.substitute(f.columns, Map((`for`, ae))),
- ArithExpr.substitute(f.d3, Map((`for`, ae))),
+ ArithExpr.substitute(f.layers, Map((`for`, ae))),
substitute(ae, `for`, f.dataType), f.fragmentKind, f.layout)
case a: DepArrayType =>
val subMap = Map((`for`, ae))
@@ -267,7 +257,7 @@ object DataType {
case ArrayType(size, _) => size
case DepArrayType(size, _) => size
case _: DataTypeIdentifier | _: NatToDataApply | _: OpaqueType |
- _: FragmentType | _: pipeline.type =>
+ _: FragmentType =>
throw new Exception("This should not happen")
}
diff --git a/src/main/scala/shine/DPIA/Types/package.scala b/src/main/scala/shine/DPIA/Types/package.scala
index ab1603fa9..5d1a7b630 100644
--- a/src/main/scala/shine/DPIA/Types/package.scala
+++ b/src/main/scala/shine/DPIA/Types/package.scala
@@ -1,13 +1,20 @@
package shine.DPIA
+import arithexpr.arithmetic.RangeAdd
import shine.DPIA.Phrases.Phrase
import shine.DPIA.Types.TypeCheck._
package object Types {
implicit class ReverseInferenceHelper(pt: PhraseType) {
- def ::[T <: PhraseType](p: Phrase[T]): Unit = p checkTypeEqOrSubtype pt
- def `:`[T <: PhraseType](p: Phrase[T]): Unit = p checkTypeEqOrSubtype pt
+ def ::[T <: PhraseType](p: Phrase[T]): Unit =
+ if (!(p checkTypeEqOrSubtype pt)) {
+ throw new Exception(s"Type error: found ${p.t}, expected $pt")
+ }
+ def `:`[T <: PhraseType](p: Phrase[T]): Unit =
+ if (!(p checkTypeEqOrSubtype pt)) {
+ throw new Exception(s"Type error: found ${p.t}, expected $pt")
+ }
}
type NatDependentFunctionType[T <: PhraseType] = DepFunType[NatKind, T]
@@ -47,4 +54,22 @@ package object Types {
): DepFunType[AccessKind, T] =
DepFunType[AccessKind, T](at, t)
}
+
+ object n2dtFun {
+ def apply(f: NatIdentifier => DataType): NatToDataLambda = {
+ val x = NatIdentifier(freshName("n"))
+ NatToDataLambda(x, f(x))
+ }
+
+ def apply(r: arithexpr.arithmetic.Range)
+ (f: NatIdentifier => DataType): NatToDataLambda = {
+ val x = NatIdentifier(freshName("n"), r)
+ NatToDataLambda(x, f(x))
+ }
+
+ def apply(upperBound: Nat)
+ (f: NatIdentifier => DataType): NatToDataLambda = {
+ apply(RangeAdd(0, upperBound, 1))(f)
+ }
+ }
}
diff --git a/src/main/scala/shine/DPIA/fromRise.scala b/src/main/scala/shine/DPIA/fromRise.scala
index 285c24129..a9c487b83 100644
--- a/src/main/scala/shine/DPIA/fromRise.scala
+++ b/src/main/scala/shine/DPIA/fromRise.scala
@@ -340,12 +340,12 @@ object fromRise {
case core.slide() => fromType {
case nFunT(sz, nFunT(sp,
expT(ArrayType(insz, t), `read`) ->:
- expT(ArrayType(n, ArrayType(_, _)), `read`)))
+ expT(ArrayType(np1, ArrayType(_, _)), `read`)))
=>
depFun[NatKind](sz)(
depFun[NatKind](sp)(
fun[ExpType](expT(insz`.`t, read), e =>
- Slide(n, sz, sp, t, e))))
+ Slide(np1-1, sz, sp, t, e))))
}
case core.circularBuffer() => fromType {
@@ -482,7 +482,7 @@ object fromRise {
depFun[NatKind](q)(
fun[ExpType](expT(t, read), cst =>
fun[ExpType](expT(n`.`t, read), e =>
- Pad(n, l, q, t, cst, e)))))
+ PadCst(n, l, q, t, cst, e)))))
}
case core.padEmpty() => fromType {
@@ -694,7 +694,7 @@ object fromRise {
fun[ExpType](ExpType(inTs(i), read), a =>
buildFFCall(args :+ a))
} else {
- ForeignFunctionCall(decl, inTs, outT, args)
+ ForeignFunctionCall(decl, inTs, args)(outT)
}
}
buildFFCall(Vector())
@@ -713,7 +713,7 @@ object fromRise {
): Phrase[_ <: PhraseType] = t match {
case FunType(in: ExpType, out) =>
fun[ExpType](in, e => buildArrayPrimitive(out, elements :+ e))
- case ExpType(ArrayType(_, et), _) => MakeArray(et, elements)
+ case ExpType(ArrayType(_, et), _) => MakeArray(elements)(elements.size, et)
case _ => error(s"did not expect $t")
}
buildArrayPrimitive(t, Vector())
@@ -763,7 +763,7 @@ object fromRise {
=>
depFun[NatKind](n)(
fun[ExpType](expT(mn`.`t, read), e =>
- AsVectorAligned(n, m, a, t, e)))
+ AsVectorAligned(n, m, t, a, e)))
}
case core.asScalar() => fromType {
@@ -771,7 +771,7 @@ object fromRise {
expT(ArrayType(_, _), _)
=>
fun[ExpType](expT(m`.`vec(n, t), a), e =>
- AsScalar(m, n, t, a, e))
+ AsScalar(n, m, t, a, e))
}
case core.vectorFromScalar() => fromType {
@@ -813,7 +813,7 @@ object fromRise {
depFun[NatKind](ls1)(depFun[NatKind](ls2)(depFun[NatKind](ls3)(
depFun[NatKind](gs1)(depFun[NatKind](gs2)(depFun[NatKind](gs3)(
fun[ExpType](expT(t, write), e =>
- ocl.Run(LocalSize(ls1, ls2, ls3), GlobalSize(gs1, gs2, gs3), t, e))))))))
+ ocl.Run(LocalSize(ls1, ls2, ls3), GlobalSize(gs1, gs2, gs3))(t, e))))))))
}
case core.dmatch() => fromType {
@@ -842,20 +842,20 @@ object fromRise {
case expT(ArrayType(rows, ArrayType(columns, dt)), `read`) ->:
expT(FragmentType(_, _, d3, _, fragType, layout), _) =>
fun[ExpType](expT(ArrayType(rows, ArrayType(columns, dt)), read), a =>
- cuda.AsFragment(rows, columns, d3, dt, fragType, a, layout))
+ cuda.AsFragment(rows, columns, d3, dt, fragType, layout, a))
}
case rcuda.asMatrix() => fromType {
- case expT(FragmentType(rows, columns, d3, dt, FragmentKind.Accumulator, _), `read`) ->:
+ case expT(FragmentType(rows, columns, layers, dt, FragmentKind.Accumulator, _), `read`) ->:
expT(ArrayType(_, ArrayType(_, _)), `write`) =>
- fun[ExpType](expT(FragmentType(rows, columns, d3, dt), read), dFrag =>
- cuda.AsMatrix(rows, columns, d3, dt, dFrag))
+ fun[ExpType](expT(FragmentType(rows, columns, layers, dt, FragmentKind.Accumulator, MatrixLayout.None), read), dFrag =>
+ cuda.AsMatrix(rows, columns, layers, dt, dFrag))
}
case rcuda.generateFragment() => fromType {
- case expT(dt, `read`) ->: expT(FragmentType(rows, columns, d3, _, fragType, layout), read) =>
+ case expT(dt, `read`) ->: expT(FragmentType(rows, columns, layers, _, frag, layout), read) =>
fun[ExpType](expT(dt, read), fill =>
- cuda.GenerateFragment(rows, columns, d3, dt, fill, fragType, layout))
+ cuda.GenerateFragment(rows, columns, layers, dt, frag, layout, fill))
}
case rcuda.tensorMMA() => fromType {
@@ -865,7 +865,7 @@ object fromRise {
expT(FragmentType(_, _, _, _, FragmentKind.Accumulator, _), `write`) =>
fun[ExpType](expT(FragmentType(m, k, n, dt, FragmentKind.AMatrix, layoutA), read), a =>
fun[ExpType](expT(FragmentType(k, n, m, dt, FragmentKind.BMatrix, layoutB), read), b =>
- fun[ExpType](expT(FragmentType(m, n, k, dtResult), read), c =>
+ fun[ExpType](expT(FragmentType(m, n, k, dtResult, FragmentKind.Accumulator, MatrixLayout.None), read), c =>
cuda.TensorMatMultAdd(m, n, k, layoutA, layoutB, dt, dtResult, a, b, c))))
}
@@ -873,8 +873,9 @@ object fromRise {
case (expT(dt: DataType, `read`) ->: expT(_, `write`)) ->:
expT(fragType : FragmentType, `read`) ->: expT(_, _) =>
fun[ExpType ->: ExpType](ExpType(dt, read) ->: ExpType(dt, write), f =>
- fun[ExpType](ExpType(fragType, read), fragment =>
- cuda.MapFragmentElements(fragType.asInstanceOf[FragmentType], fragment, f)))
+ fun[ExpType](ExpType(fragType, read), input =>
+ cuda.MapFragment(fragType.rows, fragType.columns, fragType.layers, fragType.dataType,
+ fragType.fragmentKind, fragType.layout, f, input)))
}
case rcuda.mapGlobal(dim) => fromType {
@@ -973,10 +974,12 @@ object fromRise {
FragmentType(f.rows, f.d3, f.columns, dataType(f.dataType), FragmentKind.AMatrix, layout(f.layout))
case rt.FragmentKind.BMatrix =>
FragmentType(f.d3, f.columns, f.rows, dataType(f.dataType), FragmentKind.BMatrix, layout(f.layout))
- case rt.FragmentKind.Acuumulator =>
+ case rt.FragmentKind.Accumulator =>
FragmentType(f.rows, f.columns, f.d3, dataType(f.dataType), FragmentKind.Accumulator, layout(f.layout))
case _ => throw new Exception("this should not happen")
}
+ case rt.OpaqueType(name) => OpaqueType(name)
+ case rt.ManagedBufferType(dt) => ManagedBufferType(dataType(dt))
}
private val layouts: mutable.HashMap[String, MatrixLayoutIdentifier] = mutable.HashMap.empty
@@ -984,6 +987,7 @@ object fromRise {
def layout(layout: rt.MatrixLayout): MatrixLayout = layout match {
case rt.MatrixLayout.Row_Major => MatrixLayout.Row_Major
case rt.MatrixLayout.Col_Major => MatrixLayout.Col_Major
+ case rt.MatrixLayout.None => MatrixLayout.None
case rt.MatrixLayoutIdentifier(name, _) => layouts.getOrElseUpdate(name, MatrixLayoutIdentifier(name))
case _ => throw new Exception("this should not happen")
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/AsScalar.scala b/src/main/scala/shine/DPIA/primitives/functional/AsScalar.scala
index 61c3e521a..9d62a3b9c 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/AsScalar.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/AsScalar.scala
@@ -1,18 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class AsScalar(n: Nat,
- m: Nat,
- dt: ScalarType,
- access: AccessType,
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- array :: expT(n `.` vec(m, dt), access)
- override val t: ExpType = expT((n * m) `.` dt, access)
+final case class AsScalar(val n: Nat, val m: Nat, val dt: DataType, val a: AccessType, val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ array :: expT(ArrayType(m, VectorType(n, dt)), a)
+ }
+ override val t: ExpType = expT(ArrayType(m * n, dt), a)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsScalar = new AsScalar(v.nat(n), v.nat(m), v.data(dt), v.access(a), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/AsVector.scala b/src/main/scala/shine/DPIA/primitives/functional/AsVector.scala
index 8ccb6e9cf..dbe82e1ac 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/AsVector.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/AsVector.scala
@@ -1,18 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class AsVector(n: Nat,
- m: Nat,
- dt: ScalarType,
- access: AccessType,
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- array :: expT((m * n) `.` dt, access)
- override val t: ExpType = expT(m `.` vec(n, dt), access)
+final case class AsVector(val n: Nat, val m: Nat, val dt: DataType, val a: AccessType, val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ array :: expT(ArrayType(m * n, dt), a)
+ }
+ override val t: ExpType = expT(ArrayType(m, VectorType(n, dt)), a)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsVector = new AsVector(v.nat(n), v.nat(m), v.data(dt), v.access(a), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/AsVectorAligned.scala b/src/main/scala/shine/DPIA/primitives/functional/AsVectorAligned.scala
index e8af77637..f84602328 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/AsVectorAligned.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/AsVectorAligned.scala
@@ -1,18 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class AsVectorAligned(n: Nat,
- m: Nat,
- w: AccessType,
- dt: ScalarType,
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- array :: expT((m * n)`.`dt, w)
- override val t: ExpType = expT(m`.`vec(n, dt), w)
+final case class AsVectorAligned(val n: Nat, val m: Nat, val dt: DataType, val a: AccessType, val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ array :: expT(ArrayType(m * n, dt), a)
+ }
+ override val t: ExpType = expT(ArrayType(m, VectorType(n, dt)), a)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsVectorAligned = new AsVectorAligned(v.nat(n), v.nat(m), v.data(dt), v.access(a), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Cast.scala b/src/main/scala/shine/DPIA/primitives/functional/Cast.scala
index 0d0e40c40..52fc46cb2 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Cast.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Cast.scala
@@ -1,15 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Cast(dt1: BasicType,
- dt2: BasicType,
- e: Phrase[ExpType]
- )extends ExpPrimitive {
- e :: expT(dt1, read)
+final case class Cast(val dt1: DataType, val dt2: DataType, val e: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ e :: expT(dt1, read)
+ }
override val t: ExpType = expT(dt2, read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Cast = new Cast(v.data(dt1), v.data(dt2), VisitAndRebuild(e, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/CircularBuffer.scala b/src/main/scala/shine/DPIA/primitives/functional/CircularBuffer.scala
index 3798fbfca..2154283c7 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/CircularBuffer.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/CircularBuffer.scala
@@ -1,21 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class CircularBuffer(n: Nat,
- alloc: Nat,
- sz: Nat,
- dt1: DataType,
- dt2: DataType,
- load: Phrase[ExpType ->: ExpType],
- input: Phrase[ExpType]
- ) extends ExpPrimitive {
- load :: expT(dt1, read) ->: expT(dt2, write)
- input :: expT((n - 1 + sz)`.`dt1, read)
- override val t: ExpType = expT(n`.`(sz`.`dt2), read)
+final case class CircularBuffer(val n: Nat, val alloc: Nat, val sz: Nat, val dt1: DataType, val dt2: DataType, val load: Phrase[FunType[ExpType, ExpType]], val input: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ load :: FunType(expT(dt1, read), expT(dt2, write))
+ input :: expT(ArrayType(n - 1 + sz, dt1), read)
+ }
+ override val t: ExpType = expT(ArrayType(n, ArrayType(sz, dt2)), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): CircularBuffer = new CircularBuffer(v.nat(n), v.nat(alloc), v.nat(sz), v.data(dt1), v.data(dt2), VisitAndRebuild(load, v), VisitAndRebuild(input, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Cycle.scala b/src/main/scala/shine/DPIA/primitives/functional/Cycle.scala
index 8a1e4525f..995185f97 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Cycle.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Cycle.scala
@@ -1,18 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-// cycles on the m elements of an array (modulo indexing) to produce an array of n elements
-@expPrimitive
-final case class Cycle(n: Nat,
- m: Nat,
- dt: DataType,
- input: Phrase[ExpType]
- ) extends ExpPrimitive {
- input :: expT(m`.`dt, read)
- override val t: ExpType = expT(n`.`dt, read)
+final case class Cycle(val n: Nat, val m: Nat, val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ input :: expT(ArrayType(m, dt), read)
+ }
+ override val t: ExpType = expT(ArrayType(n, dt), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Cycle = new Cycle(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(input, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/DMatch.scala b/src/main/scala/shine/DPIA/primitives/functional/DMatch.scala
index 647b8857a..074a6d781 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/DMatch.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/DMatch.scala
@@ -3,9 +3,7 @@ package shine.DPIA.primitives.functional
import shine.DPIA.Phrases._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-@expPrimitive
final case class DMatch(x: NatIdentifier,
elemT: DataType,
outT: DataType,
@@ -14,4 +12,7 @@ final case class DMatch(x: NatIdentifier,
input: Phrase[ExpType]
) extends ExpPrimitive {
override val t: ExpType = expT(outT, a)
+
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Phrase[ExpType] =
+ DMatch(v.nat(x), v.data(elemT), v.data(outT), v.access(a), VisitAndRebuild(f, v), VisitAndRebuild(input, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/DepIdx.scala b/src/main/scala/shine/DPIA/primitives/functional/DepIdx.scala
index 4e7aee8ff..be23be676 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/DepIdx.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/DepIdx.scala
@@ -1,17 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class DepIdx(n: Nat,
- ft: NatToData,
- index: Nat,
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- array :: expT(n`.d`ft, read)
- override val t: ExpType = expT(ft(index), read)
+final case class DepIdx(val n: Nat, val ft: NatToData, val index: Nat, val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ array :: expT(DepArrayType(n, ft), read)
+ }
+ override val t: ExpType = expT(NatToDataApply(ft, index), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepIdx = new DepIdx(v.nat(n), v.natToData(ft), v.nat(index), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/DepJoin.scala b/src/main/scala/shine/DPIA/primitives/functional/DepJoin.scala
index 8e63454d3..a853edac8 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/DepJoin.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/DepJoin.scala
@@ -1,18 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
-import arithexpr.arithmetic.BigSum
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class DepJoin(n: Nat,
- lenF: NatToNat,
- dt: DataType,
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- array :: expT(n`.d`{ i => lenF(i) `.` dt }, read)
- override val t: ExpType = expT(BigSum(from = 0, upTo = n - 1, i => lenF(i))`.`dt, read)
+final case class DepJoin(val n: Nat, val lenF: NatToNat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ array :: expT(DepArrayType(n, n2dtFun { (i: NatIdentifier) => ArrayType(lenF(i), dt) }), read)
+ }
+ override val t: ExpType = expT(ArrayType(BigSum(from = 0, upTo = n - 1, (i: Nat) => lenF(i)), dt), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepJoin = new DepJoin(v.nat(n), v.natToNat(lenF), v.data(dt), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/DepMapSeq.scala b/src/main/scala/shine/DPIA/primitives/functional/DepMapSeq.scala
index a9d41121e..2dede1eb0 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/DepMapSeq.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/DepMapSeq.scala
@@ -1,24 +1,21 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-//noinspection TypeAnnotation
-@expPrimitive
-final case class DepMapSeq(unroll: Boolean)
- (val n: Nat,
- val ft1: NatToData,
- val ft2: NatToData,
- val f: Phrase[`(nat)->:`[ExpType ->: ExpType]],
- val array: Phrase[ExpType]
- ) extends ExpPrimitive {
- f :: f.t.x ->: expT(ft1(f.t.x), read) ->: expT(ft2(f.t.x), write)
- array :: expT(n `.d` ft1, read)
- override val t: ExpType = expT(n`.d`ft2, write)
-
- def unwrap: (Nat, NatToData, NatToData, Phrase[`(nat)->:`[ExpType ->: ExpType]], Phrase[ExpType]) =
- (n, ft1, ft2, f, array)
+final case class DepMapSeq(unroll: Boolean)(val n: Nat, val ft1: NatToData, val ft2: NatToData, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: ({
+ val k = f.t.x
+ DepFunType[NatKind, PhraseType](k, FunType(expT(NatToDataApply(ft1, k), read), expT(NatToDataApply(ft2, k), write)))
+ })
+ array :: expT(DepArrayType(n, ft1), read)
+ }
+ override val t: ExpType = expT(DepArrayType(n, ft2), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepMapSeq = new DepMapSeq(unroll)(v.nat(n), v.natToData(ft1), v.natToData(ft2), VisitAndRebuild(f, v), VisitAndRebuild(array, v))
+ def unwrap: (Nat, NatToData, NatToData, Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], Phrase[ExpType]) = (n, ft1, ft2, f, array)
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/DepTile.scala b/src/main/scala/shine/DPIA/primitives/functional/DepTile.scala
index 8d379e32e..be95a8fab 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/DepTile.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/DepTile.scala
@@ -8,9 +8,7 @@ import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-@expPrimitive
final case class DepTile(n: Nat, tileSize: Nat, haloSize: Nat,
dt1: DataType, dt2: DataType,
processTiles: Phrase[ExpType ->: ExpType],
@@ -28,4 +26,8 @@ final case class DepTile(n: Nat, tileSize: Nat, haloSize: Nat,
expT(allTiles `.d` (i => (depSize(i) `.` dt2)), write))
array :: expT((n + haloSize)`.`dt1, read)
override val t: ExpType = expT(n`.`dt2, write)
+
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Phrase[ExpType] =
+ DepTile(v.nat(n), v.nat(tileSize), v.nat(haloSize), v.data(dt1), v.data(dt2),
+ VisitAndRebuild(processTiles, v), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/DepZip.scala b/src/main/scala/shine/DPIA/primitives/functional/DepZip.scala
index b5b84244f..8864e226e 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/DepZip.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/DepZip.scala
@@ -1,19 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class DepZip(n: Nat,
- ft1: NatToData,
- ft2: NatToData,
- e1: Phrase[ExpType],
- e2: Phrase[ExpType]
- ) extends ExpPrimitive {
- e1 :: expT(n`.d`ft1, read)
- e2 :: expT(n`.d`ft2, read)
- override val t: ExpType = expT(n`.d`{ i => PairType(ft1(i), ft2(i)) }, read)
+final case class DepZip(val n: Nat, val ft1: NatToData, val ft2: NatToData, val e1: Phrase[ExpType], val e2: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ e1 :: expT(DepArrayType(n, ft1), read)
+ e2 :: expT(DepArrayType(n, ft2), read)
+ }
+ override val t: ExpType = expT(DepArrayType(n, n2dtFun { (i: NatIdentifier) => PairType(NatToDataApply(ft1, i), NatToDataApply(ft2, i)) }), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepZip = new DepZip(v.nat(n), v.natToData(ft1), v.natToData(ft2), VisitAndRebuild(e1, v), VisitAndRebuild(e2, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Drop.scala b/src/main/scala/shine/DPIA/primitives/functional/Drop.scala
index 81f3961cc..c90942c95 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Drop.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Drop.scala
@@ -1,18 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-// this drops n many elements from an array of n + m elements
-@expPrimitive
-final case class Drop(n: Nat,
- m: Nat,
- dt: DataType,
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- array :: expT((n + m)`.`dt, read)
- override val t: ExpType = expT(m`.`dt, read)
+final case class Drop(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ array :: expT(ArrayType(n + m, dt), read)
+ }
+ override val t: ExpType = expT(ArrayType(m, dt), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Drop = new Drop(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/ForeignFunctionCall.scala b/src/main/scala/shine/DPIA/primitives/functional/ForeignFunctionCall.scala
index a2dea107a..714cd3b23 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/ForeignFunctionCall.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/ForeignFunctionCall.scala
@@ -1,28 +1,14 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
-import rise.{core => lc}
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-object ForeignFunction {
- val Declaration: lc.ForeignFunction.Decl.type = lc.ForeignFunction.Decl
- val Definition: lc.ForeignFunction.Def.type = lc.ForeignFunction.Def
- type Declaration = lc.ForeignFunction.Decl
- type Definition = lc.ForeignFunction.Def
-}
-
-@expPrimitive
-final case class ForeignFunctionCall(funDecl: ForeignFunction.Declaration,
- inTs: Seq[DataType],
- outT: DataType,
- args: Seq[Phrase[ExpType]]
- ) extends ExpPrimitive {
- (inTs zip args).foreach {
- case (inT, arg) => arg :: expT(inT, read)
- }
+final case class ForeignFunctionCall(funDecl: rise.core.ForeignFunction.Decl, inTs: Seq[DataType], args: Seq[Phrase[ExpType]])(val outT: DataType) extends ExpPrimitive {
+ {}
override val t: ExpType = expT(outT, read)
-
- override def prettyPrint: String = s"${funDecl.name}(${args.map(PrettyPhrasePrinter(_)).mkString(",")})"
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ForeignFunctionCall = new ForeignFunctionCall(funDecl, inTs.map(v.data), args.map(VisitAndRebuild(_, v)))(v.data(outT))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Fst.scala b/src/main/scala/shine/DPIA/primitives/functional/Fst.scala
index 41c6d331f..8230c972f 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Fst.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Fst.scala
@@ -1,15 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Fst(dt1: DataType,
- dt2: DataType,
- pair: Phrase[ExpType]
- ) extends ExpPrimitive {
- pair :: expT(dt1 x dt2, read)
+final case class Fst(val dt1: DataType, val dt2: DataType, val pair: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ pair :: expT(PairType(dt1, dt2), read)
+ }
override val t: ExpType = expT(dt1, read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Fst = new Fst(v.data(dt1), v.data(dt2), VisitAndRebuild(pair, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Gather.scala b/src/main/scala/shine/DPIA/primitives/functional/Gather.scala
index 966a34524..343491501 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Gather.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Gather.scala
@@ -1,19 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Gather(n: Nat,
- m: Nat,
- dt: DataType,
- indices: Phrase[ExpType],
- input: Phrase[ExpType]
- ) extends ExpPrimitive {
- indices :: expT(m`.`idx(n), read)
- input :: expT(n`.`dt, read)
- override val t: ExpType = expT(m`.`dt, read)
+final case class Gather(val n: Nat, val m: Nat, val dt: DataType, val indices: Phrase[ExpType], val input: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ indices :: expT(ArrayType(m, IndexType(n)), read)
+ input :: expT(ArrayType(n, dt), read)
+ }
+ override val t: ExpType = expT(ArrayType(m, dt), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Gather = new Gather(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(indices, v), VisitAndRebuild(input, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Generate.scala b/src/main/scala/shine/DPIA/primitives/functional/Generate.scala
index b94552cfc..7cb88033c 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Generate.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Generate.scala
@@ -1,16 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Generate(n: Nat,
- dt: DataType,
- f: Phrase[ExpType ->: ExpType]
- ) extends ExpPrimitive {
- f :: expT(idx(n), read) ->: expT(dt, read)
- override val t: ExpType = expT(n`.`dt, read)
+final case class Generate(val n: Nat, val dt: DataType, val f: Phrase[FunType[ExpType, ExpType]]) extends ExpPrimitive {
+ {
+ f :: FunType(expT(IndexType(n), read), expT(dt, read))
+ }
+ override val t: ExpType = expT(ArrayType(n, dt), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Generate = new Generate(v.nat(n), v.data(dt), VisitAndRebuild(f, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Idx.scala b/src/main/scala/shine/DPIA/primitives/functional/Idx.scala
index acf1d6067..8ccfab1ea 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Idx.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Idx.scala
@@ -1,18 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Idx(n: Nat,
- dt: DataType,
- index: Phrase[ExpType],
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- index :: expT(idx(n), read)
- array :: expT(n`.`dt, read)
+final case class Idx(val n: Nat, val dt: DataType, val index: Phrase[ExpType], val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ index :: expT(IndexType(n), read)
+ array :: expT(ArrayType(n, dt), read)
+ }
override val t: ExpType = expT(dt, read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Idx = new Idx(v.nat(n), v.data(dt), VisitAndRebuild(index, v), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/IdxVec.scala b/src/main/scala/shine/DPIA/primitives/functional/IdxVec.scala
index f1211e82b..d2e4c6229 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/IdxVec.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/IdxVec.scala
@@ -1,20 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-import scala.language.reflectiveCalls
-
-@expPrimitive
-final case class IdxVec(n: Nat,
- st: ScalarType,
- index: Phrase[ExpType],
- vector: Phrase[ExpType]
- ) extends ExpPrimitive {
- index :: expT(idx(n), read)
- vector :: expT(vec(n, st), read)
- override val t: ExpType = expT(st, read)
+final case class IdxVec(val n: Nat, val dt: DataType, val index: Phrase[ExpType], val vector: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ index :: expT(IndexType(n), read)
+ vector :: expT(VectorType(n, dt), read)
+ }
+ override val t: ExpType = expT(dt, read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): IdxVec = new IdxVec(v.nat(n), v.data(dt), VisitAndRebuild(index, v), VisitAndRebuild(vector, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/IndexAsNat.scala b/src/main/scala/shine/DPIA/primitives/functional/IndexAsNat.scala
index 64466878d..94528a350 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/IndexAsNat.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/IndexAsNat.scala
@@ -1,15 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class IndexAsNat(n: Nat,
- e: Phrase[ExpType]
- ) extends ExpPrimitive {
- e :: expT(idx(n), read)
+final case class IndexAsNat(val n: Nat, val e: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ e :: expT(IndexType(n), read)
+ }
override val t: ExpType = expT(NatType, read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): IndexAsNat = new IndexAsNat(v.nat(n), VisitAndRebuild(e, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Iterate.scala b/src/main/scala/shine/DPIA/primitives/functional/Iterate.scala
index 612e25af0..6b20d3e4a 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Iterate.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Iterate.scala
@@ -1,23 +1,20 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Iterate(n: Nat,
- m: Nat,
- k: Nat,
- dt: DataType,
- f: Phrase[`(nat)->:`[ExpType ->: ExpType]],
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
+final case class Iterate(val n: Nat, val m: Nat, val k: Nat, val dt: DataType, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive {
{
- val l = f.t.x
- f :: l ->: expT((l * n)`.`dt, read) ->: expT(l`.`dt, write)
- array :: expT((m * n.pow(k))`.`dt, read)
+ f :: ({
+ val l = f.t.x
+ DepFunType[NatKind, PhraseType](l, FunType(expT(ArrayType(l * n, dt), read), expT(ArrayType(l, dt), write)))
+ })
+ array :: expT(ArrayType(m * n.pow(k), dt), read)
}
- override val t: ExpType = expT(m`.`dt, write)
+ override val t: ExpType = expT(ArrayType(m, dt), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Iterate = new Iterate(v.nat(n), v.nat(m), v.nat(k), v.data(dt), VisitAndRebuild(f, v), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/IterateStream.scala b/src/main/scala/shine/DPIA/primitives/functional/IterateStream.scala
index 481ab7866..531bca06f 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/IterateStream.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/IterateStream.scala
@@ -1,19 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class IterateStream(n: Nat,
- dt1: DataType,
- dt2: DataType,
- f: Phrase[ExpType ->: ExpType],
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- f :: expT(dt1, read) ->: expT(dt2, write)
- array :: expT(n`.`dt1, read)
- override val t: ExpType = expT(n`.`dt2, write)
+final case class IterateStream(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: FunType(expT(dt1, read), expT(dt2, write))
+ array :: expT(ArrayType(n, dt1), read)
+ }
+ override val t: ExpType = expT(ArrayType(n, dt2), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): IterateStream = new IterateStream(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Join.scala b/src/main/scala/shine/DPIA/primitives/functional/Join.scala
index 593772603..c42697619 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Join.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Join.scala
@@ -1,18 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Join(n: Nat,
- m: Nat,
- w: AccessType,
- dt: DataType,
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- array :: expT(n`.`(m`.`dt), w)
- override val t: ExpType = expT((n * m)`.`dt, w)
+final case class Join(val n: Nat, val m: Nat, val a: AccessType, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ array :: expT(ArrayType(n, ArrayType(m, dt)), a)
+ }
+ override val t: ExpType = expT(ArrayType(n * m, dt), a)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Join = new Join(v.nat(n), v.nat(m), v.access(a), v.data(dt), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Let.scala b/src/main/scala/shine/DPIA/primitives/functional/Let.scala
index ae66302cb..cf627e092 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Let.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Let.scala
@@ -1,18 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Let(dt1: DataType,
- dt2: DataType,
- access: AccessType,
- value: Phrase[ExpType],
- f: Phrase[ExpType ->: ExpType]
- ) extends ExpPrimitive {
- value :: expT(dt1, read)
- f :: expT(dt1, read) ->: expT(dt2, access)
- override val t: ExpType = expT(dt2, access)
+final case class Let(val dt1: DataType, val dt2: DataType, val a: AccessType, val value: Phrase[ExpType], val f: Phrase[FunType[ExpType, ExpType]]) extends ExpPrimitive {
+ {
+ value :: expT(dt1, read)
+ f :: FunType(expT(dt1, read), expT(dt2, a))
+ }
+ override val t: ExpType = expT(dt2, a)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Let = new Let(v.data(dt1), v.data(dt2), v.access(a), VisitAndRebuild(value, v), VisitAndRebuild(f, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/MakeArray.scala b/src/main/scala/shine/DPIA/primitives/functional/MakeArray.scala
index 0a2e9fe1a..528f9a149 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/MakeArray.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/MakeArray.scala
@@ -1,14 +1,15 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class MakeArray(dt: DataType,
- elements: Vector[Phrase[ExpType]]
- ) extends ExpPrimitive {
- override val t: ExpType = expT((elements.length: Nat)`.`dt, read)
+final case class MakeArray(elements: Vector[Phrase[ExpType]])(val n: Nat, val dt: DataType) extends ExpPrimitive {
+ {}
+ override val t: ExpType = expT(ArrayType(n, dt), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): MakeArray = new MakeArray(elements.map(VisitAndRebuild(_, v)))(v.nat(n), v.data(dt))
+ def unwrap: (Nat, DataType) = (n, dt)
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/MakeDepPair.scala b/src/main/scala/shine/DPIA/primitives/functional/MakeDepPair.scala
index f3fa0dfb1..eae14229e 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/MakeDepPair.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/MakeDepPair.scala
@@ -3,13 +3,14 @@ package shine.DPIA.primitives.functional
import shine.DPIA.Phrases._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-@expPrimitive
final case class MakeDepPair(a: AccessType,
fst: NatIdentifier,
sndT: DataType,
snd: Phrase[ExpType]
) extends ExpPrimitive {
override val t: ExpType = expT(DepPairType(fst, sndT), a)
+
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Phrase[ExpType] =
+ MakeDepPair(v.access(a), v.nat(fst), v.data(sndT), VisitAndRebuild(snd, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/MakePair.scala b/src/main/scala/shine/DPIA/primitives/functional/MakePair.scala
index 278197c8a..38b54d25d 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/MakePair.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/MakePair.scala
@@ -1,18 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class MakePair(dt1: DataType,
- dt2: DataType,
- access: AccessType,
- fst: Phrase[ExpType],
- snd: Phrase[ExpType]
- ) extends ExpPrimitive {
- fst :: expT(dt1, access)
- snd :: expT(dt2, access)
- override val t: ExpType = expT(dt1 x dt2, access)
+final case class MakePair(val dt1: DataType, val dt2: DataType, val a: AccessType, val fst: Phrase[ExpType], val snd: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ fst :: expT(dt1, a)
+ snd :: expT(dt2, a)
+ }
+ override val t: ExpType = expT(PairType(dt1, dt2), a)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): MakePair = new MakePair(v.data(dt1), v.data(dt2), v.access(a), VisitAndRebuild(fst, v), VisitAndRebuild(snd, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Map.scala b/src/main/scala/shine/DPIA/primitives/functional/Map.scala
index b2c1b6868..64896fd85 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Map.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Map.scala
@@ -1,20 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Map(n: Nat,
- dt1: DataType,
- dt2: DataType,
- access: AccessType,
- f: Phrase[ExpType ->: ExpType],
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- array :: expT(n`.`dt1, access)
- f :: expT(dt1, access) ->: expT(dt2, access)
- override val t: ExpType = expT(n`.`dt2, access)
+final case class Map(val n: Nat, val dt1: DataType, val dt2: DataType, val a: AccessType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: FunType(expT(dt1, a), expT(dt2, a))
+ array :: expT(ArrayType(n, dt1), a)
+ }
+ override val t: ExpType = expT(ArrayType(n, dt2), a)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Map = new Map(v.nat(n), v.data(dt1), v.data(dt2), v.access(a), VisitAndRebuild(f, v), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/MapFst.scala b/src/main/scala/shine/DPIA/primitives/functional/MapFst.scala
index b809e174d..5005b6e07 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/MapFst.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/MapFst.scala
@@ -1,19 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class MapFst(w: AccessType,
- dt1: DataType,
- dt2: DataType,
- dt3: DataType,
- f: Phrase[ExpType ->: ExpType],
- record: Phrase[ExpType]
- ) extends ExpPrimitive {
- f :: expT(dt1, w) ->: expT(dt3, w)
- record :: expT(dt1 x dt2, w)
- override val t: ExpType = expT(dt3 x dt2, w)
+final case class MapFst(val a: AccessType, val dt1: DataType, val dt2: DataType, val dt3: DataType, val f: Phrase[FunType[ExpType, ExpType]], val pair: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: FunType(expT(dt1, a), expT(dt3, a))
+ pair :: expT(PairType(dt1, dt2), a)
+ }
+ override val t: ExpType = expT(PairType(dt3, dt2), a)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapFst = new MapFst(v.access(a), v.data(dt1), v.data(dt2), v.data(dt3), VisitAndRebuild(f, v), VisitAndRebuild(pair, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/MapSeq.scala b/src/main/scala/shine/DPIA/primitives/functional/MapSeq.scala
index ace4952e6..e18bf1a2b 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/MapSeq.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/MapSeq.scala
@@ -1,23 +1,18 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class MapSeq(unroll: Boolean)
- (val n: Nat,
- val dt1: DataType,
- val dt2: DataType,
- val f: Phrase[ExpType ->: ExpType],
- val array: Phrase[ExpType]
- ) extends ExpPrimitive {
- f :: expT(dt1, read) ->: expT(dt2, write)
- array :: expT(n`.`dt1, read)
- override val t: ExpType = expT(n`.`dt2, write)
-
- def unwrap: (Nat, DataType, DataType, Phrase[ExpType ->: ExpType], Phrase[ExpType]) =
- (n, dt1, dt2, f, array)
+final case class MapSeq(unroll: Boolean)(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: FunType(expT(dt1, read), expT(dt2, write))
+ array :: expT(ArrayType(n, dt1), read)
+ }
+ override val t: ExpType = expT(ArrayType(n, dt2), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapSeq = new MapSeq(unroll)(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v))
+ def unwrap: (Nat, DataType, DataType, Phrase[FunType[ExpType, ExpType]], Phrase[ExpType]) = (n, dt1, dt2, f, array)
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/MapSnd.scala b/src/main/scala/shine/DPIA/primitives/functional/MapSnd.scala
index 46e11c666..410b3dc86 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/MapSnd.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/MapSnd.scala
@@ -1,19 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class MapSnd(w: AccessType,
- dt1: DataType,
- dt2: DataType,
- dt3: DataType,
- f: Phrase[ExpType ->: ExpType],
- record: Phrase[ExpType]
- ) extends ExpPrimitive {
- f :: expT(dt2, w) ->: expT(dt3, w)
- record :: expT(dt1 x dt2, w)
- override val t: ExpType = expT(dt1 x dt3, w)
+final case class MapSnd(val a: AccessType, val dt1: DataType, val dt2: DataType, val dt3: DataType, val f: Phrase[FunType[ExpType, ExpType]], val pair: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: FunType(expT(dt2, a), expT(dt3, a))
+ pair :: expT(PairType(dt1, dt2), a)
+ }
+ override val t: ExpType = expT(PairType(dt1, dt3), a)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapSnd = new MapSnd(v.access(a), v.data(dt1), v.data(dt2), v.data(dt3), VisitAndRebuild(f, v), VisitAndRebuild(pair, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/MapStream.scala b/src/main/scala/shine/DPIA/primitives/functional/MapStream.scala
index 755fd7d79..c79d554c6 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/MapStream.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/MapStream.scala
@@ -1,19 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class MapStream(n: Nat,
- dt1: DataType,
- dt2: DataType,
- f: Phrase[ExpType ->: ExpType],
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- f :: expT(dt1, read) ->: expT(dt2, write)
- array :: expT(n`.`dt1, read)
- override val t: ExpType = expT(n`.`dt2, write)
+final case class MapStream(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: FunType(expT(dt1, read), expT(dt2, write))
+ array :: expT(ArrayType(n, dt1), read)
+ }
+ override val t: ExpType = expT(ArrayType(n, dt2), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapStream = new MapStream(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/MapVec.scala b/src/main/scala/shine/DPIA/primitives/functional/MapVec.scala
index d79c5c142..8583c637d 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/MapVec.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/MapVec.scala
@@ -1,18 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class MapVec(n: Nat,
- dt1: ScalarType,
- dt2: ScalarType,
- f: Phrase[ExpType ->: ExpType],
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- f :: expT(dt1, read) ->: expT(dt2, write)
- array :: expT(vec(n, dt1), read)
- override val t: ExpType = expT(vec(n, dt2), write)
+final case class MapVec(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: FunType(expT(dt1, read), expT(dt2, write))
+ array :: expT(VectorType(n, dt1), read)
+ }
+ override val t: ExpType = expT(VectorType(n, dt2), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapVec = new MapVec(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/NatAsIndex.scala b/src/main/scala/shine/DPIA/primitives/functional/NatAsIndex.scala
index 2c9f03482..a7a69904d 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/NatAsIndex.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/NatAsIndex.scala
@@ -1,15 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class NatAsIndex(n: Nat,
- e: Phrase[ExpType]
- ) extends ExpPrimitive {
- e :: expT(NatType, read)
- override val t: ExpType = expT(idx(n), read)
+final case class NatAsIndex(val n: Nat, val e: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ e :: expT(NatType, read)
+ }
+ override val t: ExpType = expT(IndexType(n), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): NatAsIndex = new NatAsIndex(v.nat(n), VisitAndRebuild(e, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Pad.scala b/src/main/scala/shine/DPIA/primitives/functional/Pad.scala
deleted file mode 100644
index e5d9f69b4..000000000
--- a/src/main/scala/shine/DPIA/primitives/functional/Pad.scala
+++ /dev/null
@@ -1,20 +0,0 @@
-package shine.DPIA.primitives.functional
-
-import shine.DPIA.Phrases._
-import shine.DPIA.Types.DataType._
-import shine.DPIA.Types._
-import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Pad(n: Nat,
- l: Nat,
- r: Nat,
- dt: DataType,
- padExp: Phrase[ExpType],
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- padExp :: expT(dt, read)
- array :: expT(n `.` dt, read)
- override val t: ExpType = expT((l + n + r)`.`dt, read)
-}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/PadClamp.scala b/src/main/scala/shine/DPIA/primitives/functional/PadClamp.scala
index f211bdfd6..d8f44a1b4 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/PadClamp.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/PadClamp.scala
@@ -1,19 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-// TODO: invalid for empty array
-@expPrimitive
-final case class PadClamp(n: Nat,
- l: Nat,
- r: Nat,
- dt: DataType,
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- array :: expT(n `.` dt, read)
- override val t: ExpType = expT((l + n + r)`.`dt, read)
+final case class PadClamp(val n: Nat, val l: Nat, val r: Nat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ array :: expT(ArrayType(n, dt), read)
+ }
+ override val t: ExpType = expT(ArrayType(l + n + r, dt), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): PadClamp = new PadClamp(v.nat(n), v.nat(l), v.nat(r), v.data(dt), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/PadCst.scala b/src/main/scala/shine/DPIA/primitives/functional/PadCst.scala
new file mode 100644
index 000000000..18cf95c23
--- /dev/null
+++ b/src/main/scala/shine/DPIA/primitives/functional/PadCst.scala
@@ -0,0 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+package shine.DPIA.primitives.functional
+import arithexpr.arithmetic._
+import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
+import shine.DPIA.Types._
+import shine.DPIA._
+final case class PadCst(val n: Nat, val l: Nat, val r: Nat, val dt: DataType, val padExp: Phrase[ExpType], val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ padExp :: expT(dt, read)
+ array :: expT(ArrayType(n, dt), read)
+ }
+ override val t: ExpType = expT(ArrayType(l + n + r, dt), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): PadCst = new PadCst(v.nat(n), v.nat(l), v.nat(r), v.data(dt), VisitAndRebuild(padExp, v), VisitAndRebuild(array, v))
+}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/PadEmpty.scala b/src/main/scala/shine/DPIA/primitives/functional/PadEmpty.scala
index 3b8baf59e..ed4d9d75e 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/PadEmpty.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/PadEmpty.scala
@@ -1,17 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class PadEmpty(n: Nat,
- r: Nat,
- dt: DataType,
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- array :: expT(n `.` dt, write)
- override val t: ExpType = expT((n + r)`.`dt, write)
+final case class PadEmpty(val n: Nat, val r: Nat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ array :: expT(ArrayType(n, dt), write)
+ }
+ override val t: ExpType = expT(ArrayType(n + r, dt), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): PadEmpty = new PadEmpty(v.nat(n), v.nat(r), v.data(dt), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Partition.scala b/src/main/scala/shine/DPIA/primitives/functional/Partition.scala
index 8a55e17e3..74bdf5622 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Partition.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Partition.scala
@@ -1,18 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Partition(n: Nat,
- m: Nat,
- lenF: NatToNat,
- dt: DataType,
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- array :: expT(n`.`dt, read)
- override val t: ExpType = expT(m`.d`{ i => lenF(i)`.`dt }, read)
+final case class Partition(val n: Nat, val m: Nat, val lenF: NatToNat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ array :: expT(ArrayType(n, dt), read)
+ }
+ override val t: ExpType = expT(DepArrayType(m, n2dtFun { (i: NatIdentifier) => ArrayType(lenF(i), dt) }), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Partition = new Partition(v.nat(n), v.nat(m), v.natToNat(lenF), v.data(dt), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/PrintType.scala b/src/main/scala/shine/DPIA/primitives/functional/PrintType.scala
index 8f70bfe60..0fd80c754 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/PrintType.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/PrintType.scala
@@ -3,9 +3,7 @@ package shine.DPIA.primitives.functional
import shine.DPIA.Phrases._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-@expPrimitive
final case class PrintType(msg: String,
dt: DataType,
access: AccessType,
@@ -15,4 +13,7 @@ final case class PrintType(msg: String,
input :: expT(dt, access)
override val t: ExpType = expT(dt, access)
+
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Phrase[ExpType] =
+ PrintType(msg, v.data(dt), v.access(access), VisitAndRebuild(input, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/ReduceSeq.scala b/src/main/scala/shine/DPIA/primitives/functional/ReduceSeq.scala
index cee2f2a89..bb07ef410 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/ReduceSeq.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/ReduceSeq.scala
@@ -1,25 +1,19 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class ReduceSeq(unroll: Boolean)
- (val n: Nat,
- val dt1: DataType,
- val dt2: DataType,
- val f: Phrase[ExpType ->: ExpType ->: ExpType],
- val init: Phrase[ExpType],
- val array: Phrase[ExpType]
- ) extends ExpPrimitive {
- f :: expT(dt2, read) ->: expT(dt1, read) ->: expT(dt2, write)
- init :: expT(dt2, write)
- array :: expT(n`.`dt1, read)
+final case class ReduceSeq(unroll: Boolean)(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, FunType[ExpType, ExpType]]], val init: Phrase[ExpType], val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: FunType(expT(dt2, read), FunType(expT(dt1, read), expT(dt2, write)))
+ init :: expT(dt2, write)
+ array :: expT(ArrayType(n, dt1), read)
+ }
override val t: ExpType = expT(dt2, read)
-
- def unwrap: (Nat, DataType, DataType, Phrase[ExpType ->: ExpType ->: ExpType], Phrase[ExpType], Phrase[ExpType]) =
- (n, dt1, dt2, f, init, array)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ReduceSeq = new ReduceSeq(unroll)(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(init, v), VisitAndRebuild(array, v))
+ def unwrap: (Nat, DataType, DataType, Phrase[FunType[ExpType, FunType[ExpType, ExpType]]], Phrase[ExpType], Phrase[ExpType]) = (n, dt1, dt2, f, init, array)
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Reorder.scala b/src/main/scala/shine/DPIA/primitives/functional/Reorder.scala
index 2dce81633..9c9457bf1 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Reorder.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Reorder.scala
@@ -1,19 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Reorder(n: Nat,
- dt: DataType,
- access: AccessType,
- idxF: NatToNat,
- idxFinv: NatToNat,
- input: Phrase[ExpType]
- ) extends ExpPrimitive {
- input :: expT(n`.`dt, access)
- override val t: ExpType = expT(n`.`dt, access)
+final case class Reorder(val n: Nat, val dt: DataType, val a: AccessType, val idxF: NatToNat, val idxFiv: NatToNat, val input: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ input :: expT(ArrayType(n, dt), a)
+ }
+ override val t: ExpType = expT(ArrayType(n, dt), a)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Reorder = new Reorder(v.nat(n), v.data(dt), v.access(a), v.natToNat(idxF), v.natToNat(idxFiv), VisitAndRebuild(input, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/RotateValues.scala b/src/main/scala/shine/DPIA/primitives/functional/RotateValues.scala
index a8d2f0b07..04cbf6714 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/RotateValues.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/RotateValues.scala
@@ -1,19 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class RotateValues(n: Nat,
- sz: Nat,
- dt: DataType,
- write: Phrase[ExpType ->: ExpType],
- input: Phrase[ExpType]
- ) extends ExpPrimitive {
- write :: expT(dt, read) ->: expT(dt, shine.DPIA.Types.write)
- input :: expT((n - 1 + sz)`.`dt, read)
- override val t: ExpType = expT(n`.`(sz`.`dt), read)
+final case class RotateValues(val n: Nat, val sz: Nat, val dt: DataType, val wrt: Phrase[FunType[ExpType, ExpType]], val input: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ wrt :: FunType(expT(dt, read), expT(dt, write))
+ input :: expT(ArrayType(n - 1 + sz, dt), read)
+ }
+ override val t: ExpType = expT(ArrayType(n, ArrayType(sz, dt)), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): RotateValues = new RotateValues(v.nat(n), v.nat(sz), v.data(dt), VisitAndRebuild(wrt, v), VisitAndRebuild(input, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/ScanSeq.scala b/src/main/scala/shine/DPIA/primitives/functional/ScanSeq.scala
index 48acab9ca..6be6db44b 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/ScanSeq.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/ScanSeq.scala
@@ -1,22 +1,18 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-//noinspection TypeAnnotation
-@expPrimitive
-final case class ScanSeq(n: Nat,
- dt1: DataType,
- dt2: DataType,
- f: Phrase[ExpType ->: ExpType ->: ExpType],
- init: Phrase[ExpType],
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- f :: expT(dt1, read) ->: expT(dt2, read) ->: expT(dt2, write)
- init :: expT(dt2, write)
- array :: expT(n`.`dt1, read)
- override val t: ExpType = expT(n`.`dt2, read)
+final case class ScanSeq(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, FunType[ExpType, ExpType]]], val init: Phrase[ExpType], val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: FunType(expT(dt1, read), FunType(expT(dt2, read), expT(dt2, write)))
+ init :: expT(dt2, write)
+ array :: expT(ArrayType(n, dt1), read)
+ }
+ override val t: ExpType = expT(ArrayType(n, dt2), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ScanSeq = new ScanSeq(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(init, v), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Scatter.scala b/src/main/scala/shine/DPIA/primitives/functional/Scatter.scala
index 9f77dd186..85c2bbbf7 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Scatter.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Scatter.scala
@@ -1,17 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Scatter(n: Nat, m: Nat, dt: DataType,
- indices: Phrase[ExpType],
- input: Phrase[ExpType]
- ) extends ExpPrimitive {
- indices :: expT(n`.`idx(m), read)
- input :: expT(n`.`dt, write)
- override val t: ExpType = expT(m`.`dt, write)
+final case class Scatter(val n: Nat, val m: Nat, val dt: DataType, val indices: Phrase[ExpType], val input: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ indices :: expT(ArrayType(n, IndexType(m)), read)
+ input :: expT(ArrayType(n, dt), write)
+ }
+ override val t: ExpType = expT(ArrayType(m, dt), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Scatter = new Scatter(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(indices, v), VisitAndRebuild(input, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Slide.scala b/src/main/scala/shine/DPIA/primitives/functional/Slide.scala
index 4a71ba49a..92fdf0bbf 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Slide.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Slide.scala
@@ -1,23 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
-import arithexpr.arithmetic.SimplifiedExpr
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-import scala.language.reflectiveCalls
-
-@expPrimitive
-final case class Slide(n: Nat,
- sz: Nat,
- sp: Nat,
- dt: DataType,
- input: Phrase[ExpType]
- ) extends ExpPrimitive {
- val inputSize: Nat with SimplifiedExpr = sp * n + sz - sp
-
- input :: expT(inputSize`.`dt, read)
- override val t: ExpType = expT(n`.`(sz`.`dt), read)
+final case class Slide(val n: Nat, val sz: Nat, val sp: Nat, val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ input :: expT(ArrayType(sp * n + sz, dt), read)
+ }
+ override val t: ExpType = expT(ArrayType(1 + n, ArrayType(sz, dt)), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Slide = new Slide(v.nat(n), v.nat(sz), v.nat(sp), v.data(dt), VisitAndRebuild(input, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Snd.scala b/src/main/scala/shine/DPIA/primitives/functional/Snd.scala
index 214151764..77a2e9488 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Snd.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Snd.scala
@@ -1,15 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Snd(dt1: DataType,
- dt2: DataType,
- pair: Phrase[ExpType]
- ) extends ExpPrimitive {
- pair :: expT(dt1 x dt2, read)
+final case class Snd(val dt1: DataType, val dt2: DataType, val pair: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ pair :: expT(PairType(dt1, dt2), read)
+ }
override val t: ExpType = expT(dt2, read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Snd = new Snd(v.data(dt1), v.data(dt2), VisitAndRebuild(pair, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Split.scala b/src/main/scala/shine/DPIA/primitives/functional/Split.scala
index 3b5ead8fd..426e6af3a 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Split.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Split.scala
@@ -1,18 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Split(n: Nat,
- m: Nat,
- w: AccessType,
- dt: DataType,
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- array :: expT((m * n)`.`dt, w)
- override val t: ExpType = expT(m`.`(n`.`dt), w)
+final case class Split(val n: Nat, val m: Nat, val a: AccessType, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ array :: expT(ArrayType(m * n, dt), a)
+ }
+ override val t: ExpType = expT(ArrayType(m, ArrayType(n, dt)), a)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Split = new Split(v.nat(n), v.nat(m), v.access(a), v.data(dt), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Take.scala b/src/main/scala/shine/DPIA/primitives/functional/Take.scala
index 8a9c4ba7f..ffd625320 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Take.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Take.scala
@@ -1,18 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-// this takes n many elements from an array of n + m elements
-@expPrimitive
-final case class Take(n: Nat,
- m: Nat,
- dt: DataType,
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- array :: expT((n + m)`.`dt, read)
- override val t: ExpType = expT(n`.`dt, read)
+final case class Take(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ array :: expT(ArrayType(n + m, dt), read)
+ }
+ override val t: ExpType = expT(ArrayType(n, dt), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Take = new Take(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/ToMem.scala b/src/main/scala/shine/DPIA/primitives/functional/ToMem.scala
index fd675f837..afd8d1bdf 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/ToMem.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/ToMem.scala
@@ -1,14 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class ToMem(dt: DataType,
- input: Phrase[ExpType]
- ) extends ExpPrimitive {
- input :: expT(dt, write)
+final case class ToMem(val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ input :: expT(dt, write)
+ }
override val t: ExpType = expT(dt, read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ToMem = new ToMem(v.data(dt), VisitAndRebuild(input, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Transpose.scala b/src/main/scala/shine/DPIA/primitives/functional/Transpose.scala
index b72b27e8b..ecdf22dc1 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Transpose.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Transpose.scala
@@ -1,18 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Transpose(n: Nat,
- m: Nat,
- dt: DataType,
- access: AccessType,
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- array :: expT(n`.`(m`.`dt), access)
- override val t: ExpType = expT(m`.`(n`.`dt), access)
+final case class Transpose(val n: Nat, val m: Nat, val dt: DataType, val a: AccessType, val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ array :: expT(ArrayType(n, ArrayType(m, dt)), a)
+ }
+ override val t: ExpType = expT(ArrayType(m, ArrayType(n, dt)), a)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Transpose = new Transpose(v.nat(n), v.nat(m), v.data(dt), v.access(a), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/TransposeDepArray.scala b/src/main/scala/shine/DPIA/primitives/functional/TransposeDepArray.scala
index 1891bd369..276a3f223 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/TransposeDepArray.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/TransposeDepArray.scala
@@ -1,17 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class TransposeDepArray(n: Nat,
- m: Nat,
- f: NatToData,
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- array :: expT(n`.`(m`.d`f), read)
- override val t: ExpType = expT(m`.d`{ k => n`.`f(k) }, read)
+final case class TransposeDepArray(val n: Nat, val m: Nat, val ft: NatToData, val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ array :: expT(ArrayType(n, DepArrayType(m, ft)), read)
+ }
+ override val t: ExpType = expT(DepArrayType(n, n2dtFun { (i: NatIdentifier) => ArrayType(n, NatToDataApply(ft, i)) }), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): TransposeDepArray = new TransposeDepArray(v.nat(n), v.nat(m), v.natToData(ft), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Unzip.scala b/src/main/scala/shine/DPIA/primitives/functional/Unzip.scala
index 2254ef12c..12d39f4bf 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Unzip.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Unzip.scala
@@ -1,18 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Unzip(n: Nat,
- dt1: DataType,
- dt2: DataType,
- access: AccessType,
- e: Phrase[ExpType]
- ) extends ExpPrimitive {
- e :: expT(n`.`(dt1 x dt2), access)
- override val t: ExpType = expT((n`.`dt1) x (n`.`dt2), access)
+final case class Unzip(val n: Nat, val dt1: DataType, val dt2: DataType, val a: AccessType, val e: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ e :: expT(ArrayType(n, PairType(dt1, dt2)), a)
+ }
+ override val t: ExpType = expT(PairType(ArrayType(n, dt1), ArrayType(n, dt2)), a)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Unzip = new Unzip(v.nat(n), v.data(dt1), v.data(dt2), v.access(a), VisitAndRebuild(e, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/VectorFromScalar.scala b/src/main/scala/shine/DPIA/primitives/functional/VectorFromScalar.scala
index 4f480c8f4..91b183b13 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/VectorFromScalar.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/VectorFromScalar.scala
@@ -1,15 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class VectorFromScalar(n: Nat,
- dt: ScalarType,
- arg: Phrase[ExpType]
- ) extends ExpPrimitive {
- arg :: expT(dt, read)
- override val t: ExpType = expT(vec(n, dt), read)
+final case class VectorFromScalar(val n: Nat, val dt: DataType, val arg: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ arg :: expT(dt, read)
+ }
+ override val t: ExpType = expT(VectorType(n, dt), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): VectorFromScalar = new VectorFromScalar(v.nat(n), v.data(dt), VisitAndRebuild(arg, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/Zip.scala b/src/main/scala/shine/DPIA/primitives/functional/Zip.scala
index d90cc6ece..3f0c328df 100644
--- a/src/main/scala/shine/DPIA/primitives/functional/Zip.scala
+++ b/src/main/scala/shine/DPIA/primitives/functional/Zip.scala
@@ -1,20 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Zip(n: Nat,
- dt1: DataType,
- dt2: DataType,
- access: AccessType,
- e1: Phrase[ExpType],
- e2: Phrase[ExpType]
- ) extends ExpPrimitive {
- e1 :: expT(n`.`dt1, access)
- e2 :: expT(n`.`dt2, access)
- override val t: ExpType = expT(n`.`(dt1 x dt2), access)
+final case class Zip(val n: Nat, val dt1: DataType, val dt2: DataType, val a: AccessType, val e1: Phrase[ExpType], val e2: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ e1 :: expT(ArrayType(n, dt1), a)
+ e2 :: expT(ArrayType(n, dt2), a)
+ }
+ override val t: ExpType = expT(ArrayType(n, PairType(dt1, dt2)), a)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Zip = new Zip(v.nat(n), v.data(dt1), v.data(dt2), v.access(a), VisitAndRebuild(e1, v), VisitAndRebuild(e2, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/functional/primitives.dpia b/src/main/scala/shine/DPIA/primitives/functional/primitives.dpia
new file mode 100644
index 000000000..37e8ede1c
--- /dev/null
+++ b/src/main/scala/shine/DPIA/primitives/functional/primitives.dpia
@@ -0,0 +1,136 @@
+def asScalar(n: nat, m: nat, dt: data, a: access, array: exp[m.vec[dt, n], a]): exp[(m*n).dt, a]
+def asVector(n: nat, m: nat, dt: data, a: access, array: exp[(m*n).dt, a]): exp[m.vec[dt, n], a]
+def asVectorAligned(n: nat, m: nat, dt: data, a: access, array: exp[(m*n).dt, a]): exp[m.vec[dt, n], a]
+
+def cast(dt1: data, dt2: data, e: exp[dt1, read]): exp[dt2, read]
+
+def circularBuffer(n: nat, alloc: nat, sz: nat, dt1: data, dt2: data,
+ load: exp[dt1, read] -> exp[dt2, write],
+ input: exp[(n-1+sz).dt1, read]): exp[n.sz.dt2, read]
+
+def cycle(n: nat, m: nat, dt: data, input: exp[m.dt, read]): exp[n.dt, read]
+
+def depIdx(n: nat, ft: nat2data, index: nat, array: exp[n..ft, read]): exp[ft(index), read]
+
+def depJoin(n: nat, lenF: nat2nat, dt: data,
+ array: exp[n..(i: nat |-> lenF(i).dt), read]): exp[(sum_(i=0)^(n-1) lenF(i)).dt, read]
+
+def depMapSeq{unroll: Boolean}
+ (n: nat, ft1: nat2data, ft2: nat2data,
+ f: (k: nat) -> exp[ft1(k), read] -> exp[ft2(k), write],
+ array: exp[n..ft1, read]): exp[n..ft2, write]
+
+// def depTile(...)
+
+def depZip(n: nat, ft1: nat2data, ft2: nat2data,
+ e1: exp[n..ft1, read], e2: exp[n..ft2, read]): exp[n..(i: nat |-> (ft1(i), ft2(i)) ), read]
+
+// def dmatch(...)
+
+def drop(n: nat, m: nat, dt: data, array: exp[n+m.dt, read]): exp[m.dt, read]
+
+def foreignFunctionCall{funDecl: rise.core.ForeignFunction.Decl,
+ inTs: Seq[DataType],
+ args: Seq[Phrase[ExpType]]}(outT: data): exp[outT, read]
+
+def fst(dt1: data, dt2: data, pair: exp[(dt1, dt2), read]): exp[dt1, read]
+
+def gather(n: nat, m: nat, dt: data, indices: exp[m.idx[n], read], input: exp[n.dt, read]): exp[m.dt, read]
+
+def generate(n: nat, dt: data, f: exp[idx[n], read] -> exp[dt, read]): exp[n.dt, read]
+
+def idx(n: nat, dt: data, index: exp[idx[n], read], array: exp[n.dt, read]): exp[dt, read]
+
+def idxVec(n: nat, dt: data, index: exp[idx[n], read], vector: exp[vec[dt, n], read]): exp[dt, read]
+
+def indexAsNat(n: nat, e: exp[idx[n], read]): exp[natType, read]
+
+def iterate(n: nat, m: nat, k: nat, dt: data,
+ f: (l: nat) -> exp[(l*n).dt, read] -> exp[l.dt, write],
+ array: exp[(m*(n^k)).dt, read]): exp[m.dt, write]
+
+def iterateStream(n: nat, dt1: data, dt2: data,
+ f: exp[dt1, read] -> exp[dt2, write],
+ array: exp[n.dt1, read]): exp[n.dt2, write]
+
+def join(n: nat, m: nat, a: access, dt: data, array: exp[n.m.dt, a]): exp[(n*m).dt, a]
+
+def let(dt1: data, dt2: data, a: access, value: exp[dt1, read],
+ f: exp[dt1, read] -> exp[dt2, a]): exp[dt2, a]
+
+def makeArray{elements: Vector[Phrase[ExpType]]}(n: nat, dt: data): exp[n.dt, read]
+
+// def makeDepPair(...)
+
+def makePair(dt1: data, dt2: data, a: access, fst: exp[dt1, a], snd: exp[dt2, a]): exp[(dt1, dt2), a]
+
+def map(n: nat, dt1: data, dt2: data, a: access,
+ f: exp[dt1, a] -> exp[dt2, a], array: exp[n.dt1, a]): exp[n.dt2, a]
+
+def mapFst(a: access, dt1: data, dt2: data, dt3: data,
+ f: exp[dt1, a] -> exp[dt3, a],
+ pair: exp[(dt1, dt2), a]): exp[(dt3, dt2), a]
+
+def mapSeq{unroll: Boolean}
+ (n: nat, dt1: data, dt2: data,
+ f: exp[dt1, read] -> exp[dt2, write], array: exp[n.dt1, read]): exp[n.dt2, write]
+
+def mapSnd(a: access, dt1: data, dt2: data, dt3: data,
+ f: exp[dt2, a] -> exp[dt3, a],
+ pair: exp[(dt1, dt2), a]): exp[(dt1, dt3), a]
+
+def mapStream(n: nat, dt1: data, dt2: data,
+ f: exp[dt1, read] -> exp[dt2, write], array: exp[n.dt1, read]): exp[n.dt2, write]
+
+def mapVec(n: nat, dt1: data, dt2: data,
+ f: exp[dt1, read] -> exp[dt2, write], array: exp[vec[dt1, n], read]): exp[vec[dt2, n], write]
+
+def natAsIndex(n: nat, e: exp[natType, read]): exp[idx[n], read]
+
+def padClamp(n: nat, l: nat, r: nat, dt: data, array: exp[n.dt, read]): exp[(l+n+r).dt, read]
+
+def padCst(n: nat, l: nat, r: nat, dt: data, padExp: exp[dt, read], array: exp[n.dt, read]): exp[(l+n+r).dt, read]
+
+def padEmpty(n: nat, r: nat, dt: data, array: exp[n.dt, write]): exp[(n+r).dt, write]
+
+def partition(n: nat, m: nat, lenF: nat2nat, dt: data, array: exp[n.dt, read]): exp[m..(i: nat |-> lenF(i).dt), read]
+
+// def printType{msg: String}(dt: data, a: access, input: exp[dt, a]): exp[dt, a]
+
+def reduceSeq{unroll: Boolean}
+ (n: nat, dt1: data, dt2: data,
+ f: exp[dt2, read] -> exp[dt1, read] -> exp[dt2, write],
+ init: exp[dt2, write],
+ array: exp[n.dt1, read]): exp[dt2, read]
+
+def reorder(n: nat, dt: data, a: access, idxF: nat2nat, idxFiv: nat2nat, input: exp[n.dt, a]): exp[n.dt, a]
+
+def rotateValues(n: nat, sz: nat, dt: data,
+ wrt: exp[dt, read] -> exp[dt, write],
+ input: exp[(n-1+sz).dt, read]): exp[n.sz.dt, read]
+
+def scanSeq(n: nat, dt1: data, dt2: data,
+ f: exp[dt1, read] -> exp[dt2, read] -> exp[dt2, write],
+ init: exp[dt2, write], array: exp[n.dt1, read]): exp[n.dt2, read]
+
+def scatter(n: nat, m: nat, dt: data, indices: exp[n.idx[m], read], input: exp[n.dt, write]): exp[m.dt, write]
+
+def slide(n: nat, sz: nat, sp: nat, dt: data, input: exp[(sp*n+sz).dt, read]): exp[(1+n).sz.dt, read]
+
+def snd(dt1: data, dt2: data, pair: exp[(dt1, dt2), read]): exp[dt2, read]
+
+def split(n: nat, m: nat, a: access, dt: data, array: exp[(m*n).dt, a]): exp[m.n.dt, a]
+
+def take(n: nat, m: nat, dt: data, array: exp[(n+m).dt, read]): exp[n.dt, read]
+
+def toMem(dt: data, input: exp[dt, write]): exp[dt, read]
+
+def transpose(n: nat, m: nat, dt: data, a: access, array: exp[n.m.dt, a]): exp[m.n.dt, a]
+
+def transposeDepArray(n: nat, m: nat, ft: nat2data, array: exp[n.m..ft, read]): exp[n..(i: nat |-> n.ft(i)), read]
+
+def unzip(n: nat, dt1: data, dt2: data, a: access, e: exp[n.(dt1, dt2), a]): exp[(n.dt1, n.dt2), a]
+
+def vectorFromScalar(n: nat, dt: data, arg: exp[dt, read]): exp[vec[dt, n], read]
+
+def zip(n: nat, dt1: data, dt2: data, a: access, e1: exp[n.dt1, a], e2: exp[n.dt2, a]): exp[n.(dt1, dt2), a]
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/AsScalarAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/AsScalarAcc.scala
index 98f0960b6..7c064f934 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/AsScalarAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/AsScalarAcc.scala
@@ -1,17 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class AsScalarAcc(n: Nat,
- m: Nat,
- dt: ScalarType,
- array: Phrase[AccType]
- )extends AccPrimitive {
- array :: accT((m * n)`.`dt)
- override val t: AccType = accT(n`.`vec(m, dt))
+final case class AsScalarAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive {
+ {
+ array :: accT(ArrayType(m * n, dt))
+ }
+ override val t: AccType = accT(ArrayType(m, VectorType(n, dt)))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsScalarAcc = new AsScalarAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/AsVectorAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/AsVectorAcc.scala
index a8307b313..d68e07768 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/AsVectorAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/AsVectorAcc.scala
@@ -1,17 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class AsVectorAcc(n: Nat,
- m: Nat,
- dt: ScalarType,
- array: Phrase[AccType]
- ) extends AccPrimitive {
- array :: accT(n`.`vec(m, dt))
- override val t: AccType = accT((n * m)`.`dt)
+final case class AsVectorAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive {
+ {
+ array :: accT(ArrayType(m, VectorType(n, dt)))
+ }
+ override val t: AccType = accT(ArrayType(n * m, dt))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsVectorAcc = new AsVectorAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/Assign.scala b/src/main/scala/shine/DPIA/primitives/imperative/Assign.scala
index b2ee75675..0451dcffa 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/Assign.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/Assign.scala
@@ -1,18 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.comPrimitive
-
-@comPrimitive
-case class Assign(dt: DataType,
- lhs: Phrase[AccType],
- rhs: Phrase[ExpType]
- ) extends CommandPrimitive {
- lhs :: accT(dt)
- rhs :: expT(dt, read)
-
- override def prettyPrint: String = s"(${PrettyPhrasePrinter(lhs)} := ${PrettyPhrasePrinter(rhs)})"
+final case class Assign(val dt: DataType, val lhs: Phrase[AccType], val rhs: Phrase[ExpType]) extends CommandPrimitive {
+ {
+ lhs :: accT(dt)
+ rhs :: expT(dt, read)
+ }
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Assign = new Assign(v.data(dt), VisitAndRebuild(lhs, v), VisitAndRebuild(rhs, v))
}
-
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/Comment.scala b/src/main/scala/shine/DPIA/primitives/imperative/Comment.scala
index 3cba9d039..c1416c259 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/Comment.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/Comment.scala
@@ -1,9 +1,14 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
-import shine.DPIA.Phrases.CommandPrimitive
-import shine.macros.Primitive.comPrimitive
-
-@comPrimitive
-final case class Comment(comment : String) extends CommandPrimitive {
- override def prettyPrint: String = s"\n//$comment\n"
+import arithexpr.arithmetic._
+import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
+import shine.DPIA.Types._
+import shine.DPIA._
+final case class Comment(comment: String)() extends CommandPrimitive {
+ {}
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Comment = new Comment(comment)()
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/CycleAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/CycleAcc.scala
index ab55902fa..06be5aec8 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/CycleAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/CycleAcc.scala
@@ -1,17 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class CycleAcc(n: Nat,
- m: Nat,
- dt: DataType,
- input: Phrase[AccType]
- ) extends AccPrimitive {
- input :: accT(m`.`dt)
- override val t: AccType = accT(n`.`dt)
+final case class CycleAcc(val n: Nat, val m: Nat, val dt: DataType, val input: Phrase[AccType]) extends AccPrimitive {
+ {
+ input :: accT(ArrayType(m, dt))
+ }
+ override val t: AccType = accT(ArrayType(n, dt))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): CycleAcc = new CycleAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(input, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/DMatchI.scala b/src/main/scala/shine/DPIA/primitives/imperative/DMatchI.scala
index bf43207d2..374f16fae 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/DMatchI.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/DMatchI.scala
@@ -3,13 +3,13 @@ package shine.DPIA.primitives.imperative
import shine.DPIA.Phrases._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.comPrimitive
-@comPrimitive
final case class DMatchI(x: NatIdentifier,
elemT: DataType,
outT: DataType,
f: Phrase[`(nat)->:`[ExpType ->: CommType]],
input: Phrase[ExpType]
) extends CommandPrimitive {
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Phrase[CommType] =
+ DMatchI(v.nat(x), v.data(elemT), v.data(outT), VisitAndRebuild(f, v), VisitAndRebuild(input, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/DepIdxAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/DepIdxAcc.scala
index c80913c2d..a887fef02 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/DepIdxAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/DepIdxAcc.scala
@@ -1,17 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class DepIdxAcc(n: Nat,
- ft:NatToData,
- index: Nat,
- array: Phrase[AccType]
- ) extends AccPrimitive {
- array :: accT(n`.d`ft)
- override val t: AccType = accT(ft(index))
+final case class DepIdxAcc(val n: Nat, val ft: NatToData, val index: Nat, val array: Phrase[AccType]) extends AccPrimitive {
+ {
+ array :: accT(DepArrayType(n, ft))
+ }
+ override val t: AccType = accT(NatToDataApply(ft, index))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepIdxAcc = new DepIdxAcc(v.nat(n), v.natToData(ft), v.nat(index), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/DepJoinAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/DepJoinAcc.scala
index 0491484d4..011ead569 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/DepJoinAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/DepJoinAcc.scala
@@ -1,18 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
-import arithexpr.arithmetic.BigSum
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class DepJoinAcc(n: Nat,
- lenF:NatToNat,
- dt: DataType,
- array: Phrase[AccType]
- ) extends AccPrimitive {
- array :: accT(BigSum(from=0, upTo = n-1, i => lenF(i))`.`dt)
- override val t: AccType = accT(n`.d`{ i => ArrayType(lenF(i), dt) })
+final case class DepJoinAcc(val n: Nat, val lenF: NatToNat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive {
+ {
+ array :: accT(ArrayType(BigSum(from = 0, upTo = n - 1, (i: Nat) => lenF(i)), dt))
+ }
+ override val t: AccType = accT(DepArrayType(n, n2dtFun { (i: NatIdentifier) => ArrayType(lenF(i), dt) }))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepJoinAcc = new DepJoinAcc(v.nat(n), v.natToNat(lenF), v.data(dt), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/DropAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/DropAcc.scala
index c50e0555c..a09617ffc 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/DropAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/DropAcc.scala
@@ -1,18 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
-import shine.DPIA.Types._
import shine.DPIA.Types.DataType._
+import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-// this drops n many elements from an array of m elements
-@accPrimitive
-final case class DropAcc(n: Nat,
- m: Nat,
- dt: DataType,
- array: Phrase[AccType]
- ) extends AccPrimitive {
- array :: accT({n + m}`.`dt)
- override val t: AccType = accT({m - n}`.`dt)
+final case class DropAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive {
+ {
+ array :: accT(ArrayType(n + m, dt))
+ }
+ override val t: AccType = accT(ArrayType(m - n, dt))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): DropAcc = new DropAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/For.scala b/src/main/scala/shine/DPIA/primitives/imperative/For.scala
index a1aeb6675..3d5d4ef1c 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/For.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/For.scala
@@ -1,20 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.comPrimitive
-
-@comPrimitive
-final case class For(unroll: Boolean)
- (val n: Nat,
- val loopBody: Phrase[ExpType ->: CommType]
- ) extends CommandPrimitive {
- loopBody :: expT(idx(n), read) ->: comm
-
- lazy val unwrapBody: (Identifier[ExpType], Phrase[CommType]) = loopBody match {
- case Lambda(i, body) => (i, body)
- case _ => throw new Exception("This should not happen")
+final case class For(unroll: Boolean)(val n: Nat, val loopBody: Phrase[FunType[ExpType, CommType]]) extends CommandPrimitive {
+ {
+ loopBody :: FunType(expT(IndexType(n), read), comm)
}
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): For = new For(unroll)(v.nat(n), VisitAndRebuild(loopBody, v))
+ def unwrap: (Nat, Phrase[FunType[ExpType, CommType]]) = (n, loopBody)
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ForNat.scala b/src/main/scala/shine/DPIA/primitives/imperative/ForNat.scala
index 5bace44b0..6bc08c265 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/ForNat.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/ForNat.scala
@@ -1,19 +1,20 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.comPrimitive
-
-@comPrimitive
-final case class ForNat(unroll: Boolean)
- (val n: Nat,
- val loopBody: Phrase[`(nat)->:`[CommType]]
- ) extends CommandPrimitive {
- loopBody :: loopBody.t.x ->: comm
-
- lazy val unwrapBody: (NatIdentifier, Phrase[CommType]) = loopBody match {
- case DepLambda(i, body) => (i, body)
- case _ => throw new Exception("This should not happen")
+final case class ForNat(unroll: Boolean)(val n: Nat, val loopBody: Phrase[DepFunType[NatKind, CommType]]) extends CommandPrimitive {
+ {
+ loopBody :: ({
+ val i = loopBody.t.x
+ DepFunType[NatKind, PhraseType](i, comm)
+ })
}
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ForNat = new ForNat(unroll)(v.nat(n), VisitAndRebuild(loopBody, v))
+ def unwrap: (Nat, Phrase[DepFunType[NatKind, CommType]]) = (n, loopBody)
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ForVec.scala b/src/main/scala/shine/DPIA/primitives/imperative/ForVec.scala
index a4e9d8cf7..2d4967f8a 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/ForVec.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/ForVec.scala
@@ -1,17 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.comPrimitive
-
-@comPrimitive
-final case class ForVec(n: Nat,
- dt: ScalarType,
- out: Phrase[AccType],
- loopBody: Phrase[ExpType ->: AccType ->: CommType]
- ) extends CommandPrimitive {
- out :: accT(vec(n, dt))
- loopBody :: expT(idx(n), read) ->: accT(dt) ->: comm
+final case class ForVec(val n: Nat, val dt: DataType, val out: Phrase[AccType], val loopBody: Phrase[FunType[ExpType, FunType[AccType, CommType]]]) extends CommandPrimitive {
+ {
+ out :: accT(VectorType(n, dt))
+ loopBody :: FunType(expT(IndexType(n), read), FunType(accT(dt), comm))
+ }
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ForVec = new ForVec(v.nat(n), v.data(dt), VisitAndRebuild(out, v), VisitAndRebuild(loopBody, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/GenerateCont.scala b/src/main/scala/shine/DPIA/primitives/imperative/GenerateCont.scala
index 6bc8952ac..029fc9810 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/GenerateCont.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/GenerateCont.scala
@@ -1,17 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-// note: would not be necessary if generate was defined as indices + map
-@expPrimitive
-final case class GenerateCont(n: Nat,
- dt: DataType,
- f: Phrase[ExpType ->: ((ExpType ->: CommType) ->: CommType)]
- ) extends ExpPrimitive {
- f :: expT(idx(n), read) ->: (expT(dt, read) ->: comm) ->: comm
- override val t: ExpType = expT(n`.`dt, read)
+final case class GenerateCont(val n: Nat, val dt: DataType, val f: Phrase[FunType[ExpType, FunType[FunType[ExpType, CommType], CommType]]]) extends ExpPrimitive {
+ {
+ f :: FunType(expT(IndexType(n), read), FunType(FunType(expT(dt, read), comm), comm))
+ }
+ override val t: ExpType = expT(ArrayType(n, dt), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): GenerateCont = new GenerateCont(v.nat(n), v.data(dt), VisitAndRebuild(f, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/IdxAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/IdxAcc.scala
index 3cdfb04a0..09a1942a5 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/IdxAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/IdxAcc.scala
@@ -1,18 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class IdxAcc(n: Nat,
- dt: DataType,
- index: Phrase[ExpType],
- array: Phrase[AccType]
- )extends AccPrimitive {
- index :: expT(idx(n), read)
- array :: accT(n`.`dt)
+final case class IdxAcc(val n: Nat, val dt: DataType, val index: Phrase[ExpType], val array: Phrase[AccType]) extends AccPrimitive {
+ {
+ index :: expT(IndexType(n), read)
+ array :: accT(ArrayType(n, dt))
+ }
override val t: AccType = accT(dt)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): IdxAcc = new IdxAcc(v.nat(n), v.data(dt), VisitAndRebuild(index, v), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/IdxVecAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/IdxVecAcc.scala
index eb051e3b1..0566caa11 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/IdxVecAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/IdxVecAcc.scala
@@ -1,18 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class IdxVecAcc(n: Nat,
- st: ScalarType,
- index: Phrase[ExpType],
- vector: Phrase[AccType]
- ) extends AccPrimitive {
- index :: expT(idx(n), read)
- vector :: accT(vec(n, st))
- override val t: AccType = accT(st)
+final case class IdxVecAcc(val n: Nat, val dt: DataType, val index: Phrase[ExpType], val vector: Phrase[AccType]) extends AccPrimitive {
+ {
+ index :: expT(IndexType(n), read)
+ vector :: accT(VectorType(n, dt))
+ }
+ override val t: AccType = accT(dt)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): IdxVecAcc = new IdxVecAcc(v.nat(n), v.data(dt), VisitAndRebuild(index, v), VisitAndRebuild(vector, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/JoinAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/JoinAcc.scala
index 3ba5223b0..fe0923988 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/JoinAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/JoinAcc.scala
@@ -1,17 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class JoinAcc(n: Nat,
- m: Nat,
- dt: DataType,
- array: Phrase[AccType]
- ) extends AccPrimitive {
- array :: accT({n * m}`.`dt)
- override val t: AccType = accT(n`.`(m`.`dt))
+final case class JoinAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive {
+ {
+ array :: accT(ArrayType(n * m, dt))
+ }
+ override val t: AccType = accT(ArrayType(n, ArrayType(m, dt)))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): JoinAcc = new JoinAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/MapAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/MapAcc.scala
index 9ca3a5aaa..b4cd17cec 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/MapAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/MapAcc.scala
@@ -1,19 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class MapAcc(n: Nat,
- dt1: DataType,
- dt2: DataType,
- f: Phrase[AccType ->: AccType],
- array: Phrase[AccType]
- ) extends AccPrimitive {
- f :: accT(dt1) ->: accT(dt2)
- array :: accT(n`.`dt1)
- override val t: AccType = accT(n`.`dt2)
+final case class MapAcc(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[AccType, AccType]], val array: Phrase[AccType]) extends AccPrimitive {
+ {
+ f :: FunType(accT(dt1), accT(dt2))
+ array :: accT(ArrayType(n, dt1))
+ }
+ override val t: AccType = accT(ArrayType(n, dt2))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapAcc = new MapAcc(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/MapFstAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/MapFstAcc.scala
index b7ff02336..4e90c0d42 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/MapFstAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/MapFstAcc.scala
@@ -1,17 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class MapFstAcc(dt1: DataType,
- dt2: DataType,
- dt3: DataType,
- f: Phrase[AccType ->: AccType],
- record: Phrase[AccType]) extends AccPrimitive {
- f :: accT(dt3) ->: accT(dt1)
- record :: accT(dt3 x dt2)
- override val t: AccType = accT(dt1 x dt2)
+final case class MapFstAcc(val dt1: DataType, val dt2: DataType, val dt3: DataType, val f: Phrase[FunType[AccType, AccType]], val record: Phrase[AccType]) extends AccPrimitive {
+ {
+ f :: FunType(accT(dt3), accT(dt1))
+ record :: accT(PairType(dt3, dt2))
+ }
+ override val t: AccType = accT(PairType(dt1, dt2))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapFstAcc = new MapFstAcc(v.data(dt1), v.data(dt2), v.data(dt3), VisitAndRebuild(f, v), VisitAndRebuild(record, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/MapRead.scala b/src/main/scala/shine/DPIA/primitives/imperative/MapRead.scala
index 5a829a2fd..67fe5ad67 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/MapRead.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/MapRead.scala
@@ -1,19 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class MapRead(n: Nat,
- dt1: DataType,
- dt2: DataType,
- f: Phrase[ExpType ->: (ExpType ->: CommType) ->: CommType],
- input: Phrase[ExpType]
- ) extends ExpPrimitive {
- f :: expT(dt1, read) ->: (expT(dt2, read) ->: comm) ->: comm
- input :: expT(n`.`dt1, read)
- override val t: ExpType = expT(n`.`dt2, read)
+final case class MapRead(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, FunType[FunType[ExpType, CommType], CommType]]], val input: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: FunType(expT(dt1, read), FunType(FunType(expT(dt2, read), comm), comm))
+ input :: expT(ArrayType(n, dt1), read)
+ }
+ override val t: ExpType = expT(ArrayType(n, dt2), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapRead = new MapRead(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(input, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/MapSndAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/MapSndAcc.scala
index 3d19844ce..3c163d1cf 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/MapSndAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/MapSndAcc.scala
@@ -1,17 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class MapSndAcc(dt1: DataType,
- dt2: DataType,
- dt3: DataType,
- f: Phrase[AccType ->: AccType],
- record: Phrase[AccType]) extends AccPrimitive {
- f :: accT(dt3) ->: accT(dt2)
- record :: accT(dt1 x dt3)
- override val t: AccType = accT(dt1 x dt2)
+final case class MapSndAcc(val dt1: DataType, val dt2: DataType, val dt3: DataType, val f: Phrase[FunType[AccType, AccType]], val record: Phrase[AccType]) extends AccPrimitive {
+ {
+ f :: FunType(accT(dt3), accT(dt2))
+ record :: accT(PairType(dt1, dt3))
+ }
+ override val t: AccType = accT(PairType(dt1, dt2))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapSndAcc = new MapSndAcc(v.data(dt1), v.data(dt2), v.data(dt3), VisitAndRebuild(f, v), VisitAndRebuild(record, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/MkDPairFstI.scala b/src/main/scala/shine/DPIA/primitives/imperative/MkDPairFstI.scala
index af1f88de1..c85cbce16 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/MkDPairFstI.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/MkDPairFstI.scala
@@ -3,9 +3,9 @@ package shine.DPIA.primitives.imperative
import shine.DPIA.Phrases._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.comPrimitive
-@comPrimitive
final case class MkDPairFstI(fst: Nat,
A: Phrase[AccType]) extends CommandPrimitive {
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Phrase[CommType] =
+ MkDPairFstI(v.nat(fst), VisitAndRebuild(A, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/MkDPairSndAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/MkDPairSndAcc.scala
index b6c4c865f..8df8914c7 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/MkDPairSndAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/MkDPairSndAcc.scala
@@ -3,11 +3,12 @@ package shine.DPIA.primitives.imperative
import shine.DPIA.NatIdentifier
import shine.DPIA.Phrases._
import shine.DPIA.Types.{AccType, DataType}
-import shine.macros.Primitive.accPrimitive
-@accPrimitive
final case class MkDPairSndAcc(fst: NatIdentifier,
sndT: DataType,
A: Phrase[AccType]) extends AccPrimitive {
override val t = AccType(sndT)
+
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Phrase[AccType] =
+ MkDPairSndAcc(v.nat(fst), v.data(sndT), VisitAndRebuild(A, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/New.scala b/src/main/scala/shine/DPIA/primitives/imperative/New.scala
index 265549e24..b376a80a9 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/New.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/New.scala
@@ -1,13 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.comPrimitive
-
-@comPrimitive
-final case class New(dt: DataType,
- f: Phrase[VarType ->: CommType]
- ) extends CommandPrimitive {
- f :: varT(dt) ->: comm
+final case class New(val dt: DataType, val f: Phrase[FunType[PhrasePairType[ExpType, AccType], CommType]]) extends CommandPrimitive {
+ {
+ f :: FunType(PhrasePairType(expT(dt, read), accT(dt)), comm)
+ }
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): New = new New(v.data(dt), VisitAndRebuild(f, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/NewDoubleBuffer.scala b/src/main/scala/shine/DPIA/primitives/imperative/NewDoubleBuffer.scala
index 6c4ceec3b..7001ed75c 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/NewDoubleBuffer.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/NewDoubleBuffer.scala
@@ -1,21 +1,18 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.comPrimitive
-
-@comPrimitive
-final case class NewDoubleBuffer(dt1: DataType,
- dt2: DataType,
- dt3: DataType,
- n: Nat,
- in: Phrase[ExpType],
- out: Phrase[AccType],
- f: Phrase[(ExpType x AccType x CommType x CommType) ->: CommType]
- )extends CommandPrimitive {
- in :: expT(dt1, read)
- out :: accT(dt2)
- f :: (((varT(n`.`dt3) x comm) x comm) ->: comm)
+final case class NewDoubleBuffer(val dt1: DataType, val dt2: DataType, val dt3: DataType, val n: Nat, val in: Phrase[ExpType], val out: Phrase[AccType], val f: Phrase[FunType[PhrasePairType[PhrasePairType[PhrasePairType[ExpType, AccType], CommType], CommType], CommType]]) extends CommandPrimitive {
+ {
+ in :: expT(dt1, read)
+ out :: accT(dt2)
+ f :: FunType(PhrasePairType(PhrasePairType(PhrasePairType(expT(ArrayType(n, dt3), read), accT(ArrayType(n, dt3))), comm), comm), comm)
+ }
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): NewDoubleBuffer = new NewDoubleBuffer(v.data(dt1), v.data(dt2), v.data(dt3), v.nat(n), VisitAndRebuild(in, v), VisitAndRebuild(out, v), VisitAndRebuild(f, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/PairAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/PairAcc.scala
index 09d9b11fc..b6f03c5c7 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/PairAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/PairAcc.scala
@@ -1,17 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class PairAcc(dt1: DataType,
- dt2: DataType,
- fst: Phrase[AccType],
- snd: Phrase[AccType]
- ) extends AccPrimitive {
- fst :: accT(dt1)
- snd :: accT(dt2)
- override val t: AccType = accT(dt1 x dt2)
+final case class PairAcc(val dt1: DataType, val dt2: DataType, val fst: Phrase[AccType], val snd: Phrase[AccType]) extends AccPrimitive {
+ {
+ fst :: accT(dt1)
+ snd :: accT(dt2)
+ }
+ override val t: AccType = accT(PairType(dt1, dt2))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): PairAcc = new PairAcc(v.data(dt1), v.data(dt2), VisitAndRebuild(fst, v), VisitAndRebuild(snd, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/PairAcc1.scala b/src/main/scala/shine/DPIA/primitives/imperative/PairAcc1.scala
index af0e85003..924410406 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/PairAcc1.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/PairAcc1.scala
@@ -1,15 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class PairAcc1(dt1: DataType,
- dt2: DataType,
- pair: Phrase[AccType]
- ) extends AccPrimitive {
- pair :: accT(dt1 x dt2)
+final case class PairAcc1(val dt1: DataType, val dt2: DataType, val pair: Phrase[AccType]) extends AccPrimitive {
+ {
+ pair :: accT(PairType(dt1, dt2))
+ }
override val t: AccType = accT(dt1)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): PairAcc1 = new PairAcc1(v.data(dt1), v.data(dt2), VisitAndRebuild(pair, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/PairAcc2.scala b/src/main/scala/shine/DPIA/primitives/imperative/PairAcc2.scala
index dd8370269..f49ccd845 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/PairAcc2.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/PairAcc2.scala
@@ -1,15 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class PairAcc2(dt1: DataType,
- dt2: DataType,
- pair: Phrase[AccType]
- ) extends AccPrimitive {
- pair :: accT(dt1 x dt2)
+final case class PairAcc2(val dt1: DataType, val dt2: DataType, val pair: Phrase[AccType]) extends AccPrimitive {
+ {
+ pair :: accT(PairType(dt1, dt2))
+ }
override val t: AccType = accT(dt2)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): PairAcc2 = new PairAcc2(v.data(dt1), v.data(dt2), VisitAndRebuild(pair, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ReorderAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/ReorderAcc.scala
index a30bd8efc..7820bfcf0 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/ReorderAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/ReorderAcc.scala
@@ -1,17 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class ReorderAcc(n: Nat,
- dt: DataType,
- idxF: NatToNat,
- array: Phrase[AccType]
- ) extends AccPrimitive {
- array :: accT(n`.`dt)
- override val t: AccType = accT(n`.`dt)
+final case class ReorderAcc(val n: Nat, val dt: DataType, val idxF: NatToNat, val array: Phrase[AccType]) extends AccPrimitive {
+ {
+ array :: accT(ArrayType(n, dt))
+ }
+ override val t: AccType = accT(ArrayType(n, dt))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ReorderAcc = new ReorderAcc(v.nat(n), v.data(dt), v.natToNat(idxF), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ScatterAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/ScatterAcc.scala
index c86146cc3..483b94e19 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/ScatterAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/ScatterAcc.scala
@@ -1,16 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class ScatterAcc(n: Nat, m: Nat, dt: DataType,
- indices: Phrase[ExpType],
- array: Phrase[AccType]) extends AccPrimitive {
- indices :: expT(n`.`idx(m), read)
- array :: accT(n`.`dt)
- override val t: AccType = accT(m`.`dt)
+final case class ScatterAcc(val n: Nat, val m: Nat, val dt: DataType, val indices: Phrase[ExpType], val array: Phrase[AccType]) extends AccPrimitive {
+ {
+ indices :: expT(ArrayType(n, IndexType(m)), read)
+ array :: accT(ArrayType(m, dt))
+ }
+ override val t: AccType = accT(ArrayType(n, dt))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ScatterAcc = new ScatterAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(indices, v), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/Seq.scala b/src/main/scala/shine/DPIA/primitives/imperative/Seq.scala
index 579940f89..15500208b 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/Seq.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/Seq.scala
@@ -1,13 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
-import shine.macros.Primitive.comPrimitive
-
-@comPrimitive
-final case class Seq(c1: Phrase[CommType],
- c2: Phrase[CommType])
- extends CommandPrimitive {
- c1 :: comm
- c2 :: comm
+import shine.DPIA._
+final case class Seq(val c1: Phrase[CommType], val c2: Phrase[CommType]) extends CommandPrimitive {
+ {
+ c1 :: comm
+ c2 :: comm
+ }
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Seq = new Seq(VisitAndRebuild(c1, v), VisitAndRebuild(c2, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/Skip.scala b/src/main/scala/shine/DPIA/primitives/imperative/Skip.scala
index 569d140b4..1d8ac8dec 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/Skip.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/Skip.scala
@@ -1,14 +1,14 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
-
-// not final (and not using the macro) because of DSL.typed.skip
-case class Skip() extends CommandPrimitive {
-
+import shine.DPIA._
+final case class Skip() extends CommandPrimitive {
+ {}
override val t: CommType = comm
-
- override def visitAndRebuild(f: VisitAndRebuild.Visitor): Phrase[CommType] = this
-
- override def prettyPrint: String = "skip"
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Skip = new Skip()
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/SplitAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/SplitAcc.scala
index 07b2b0b01..4ad7b97c6 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/SplitAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/SplitAcc.scala
@@ -1,17 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class SplitAcc(n: Nat,
- m: Nat,
- dt: DataType,
- array: Phrase[AccType]
- ) extends AccPrimitive {
- array :: accT(m`.`(n`.`dt))
- override val t: AccType = accT({n * m}`.`dt)
+final case class SplitAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive {
+ {
+ array :: accT(ArrayType(m, ArrayType(n, dt)))
+ }
+ override val t: AccType = accT(ArrayType(n * m, dt))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): SplitAcc = new SplitAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/TakeAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/TakeAcc.scala
index ccea96f58..923cb6ded 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/TakeAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/TakeAcc.scala
@@ -1,17 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class TakeAcc(n: Nat,
- m: Nat,
- dt: DataType,
- array: Phrase[AccType]
- ) extends AccPrimitive {
- array :: accT({n + m}`.`dt)
- override val t: AccType = accT(n`.`dt)
+final case class TakeAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive {
+ {
+ array :: accT(ArrayType(n + m, dt))
+ }
+ override val t: AccType = accT(ArrayType(n, dt))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): TakeAcc = new TakeAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/TransposeAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/TransposeAcc.scala
index 31430d627..3944e3d8a 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/TransposeAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/TransposeAcc.scala
@@ -1,17 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class TransposeAcc(n: Nat,
- m: Nat,
- dt: DataType,
- array: Phrase[AccType]
- ) extends AccPrimitive {
- array :: accT(m`.`(n`.`dt))
- override val t: AccType = accT(n`.`(m`.`dt))
+final case class TransposeAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive {
+ {
+ array :: accT(ArrayType(m, ArrayType(n, dt)))
+ }
+ override val t: AccType = accT(ArrayType(n, ArrayType(m, dt)))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): TransposeAcc = new TransposeAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/UnzipAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/UnzipAcc.scala
index 8e5dc8fb2..9e0069f74 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/UnzipAcc.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/UnzipAcc.scala
@@ -1,17 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class UnzipAcc(n: Nat,
- dt1: DataType,
- dt2: DataType,
- a: Phrase[AccType]
- ) extends AccPrimitive {
- a :: accT((n`.`dt1) x (n`.`dt2))
- override val t: AccType = accT(n`.`(dt1 x dt2))
+final case class UnzipAcc(val n: Nat, val dt1: DataType, val dt2: DataType, val a: Phrase[AccType]) extends AccPrimitive {
+ {
+ a :: accT(PairType(ArrayType(n, dt1), ArrayType(n, dt2)))
+ }
+ override val t: AccType = accT(ArrayType(n, PairType(dt1, dt2)))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): UnzipAcc = new UnzipAcc(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(a, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc1.scala b/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc1.scala
index 3e4821950..28343629d 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc1.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc1.scala
@@ -1,17 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class ZipAcc1(n: Nat,
- dt1: DataType,
- dt2: DataType,
- array: Phrase[AccType]
- ) extends AccPrimitive {
- array :: accT(n`.`(dt1 x dt2))
- override val t: AccType = accT(n`.`dt1)
+final case class ZipAcc1(val n: Nat, val dt1: DataType, val dt2: DataType, val array: Phrase[AccType]) extends AccPrimitive {
+ {
+ array :: accT(ArrayType(n, PairType(dt1, dt2)))
+ }
+ override val t: AccType = accT(ArrayType(n, dt1))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ZipAcc1 = new ZipAcc1(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc2.scala b/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc2.scala
index b18d7e34e..f84abc66e 100644
--- a/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc2.scala
+++ b/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc2.scala
@@ -1,17 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.DPIA.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class ZipAcc2(n: Nat,
- dt1: DataType,
- dt2: DataType,
- array: Phrase[AccType]
- ) extends AccPrimitive {
- array :: accT(n`.`(dt1 x dt2))
- override val t: AccType = accT(n`.`dt2)
+final case class ZipAcc2(val n: Nat, val dt1: DataType, val dt2: DataType, val array: Phrase[AccType]) extends AccPrimitive {
+ {
+ array :: accT(ArrayType(n, PairType(dt1, dt2)))
+ }
+ override val t: AccType = accT(ArrayType(n, dt2))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ZipAcc2 = new ZipAcc2(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/DPIA/primitives/imperative/primitives.dpia b/src/main/scala/shine/DPIA/primitives/imperative/primitives.dpia
new file mode 100644
index 000000000..9069428fb
--- /dev/null
+++ b/src/main/scala/shine/DPIA/primitives/imperative/primitives.dpia
@@ -0,0 +1,69 @@
+def asScalarAcc(n: nat, m: nat, dt: data, array: acc[(m*n).dt]): acc[m.vec[dt, n]]
+
+def assign(dt: data, lhs: acc[dt], rhs: exp[dt, read]): comm
+
+def asVectorAcc(n: nat, m: nat, dt: data, array: acc[m.vec[dt, n]]): acc[(n*m).dt]
+
+def comment{comment: String}(): comm
+
+def cycleAcc(n: nat, m: nat, dt: data, input: acc[m.dt]): acc[n.dt]
+
+def depIdxAcc(n: nat, ft: nat2data, index: nat, array: acc[n..ft]): acc[ft(index)]
+
+def depJoinAcc(n: nat, lenF: nat2nat, dt: data, array: acc[(sum_(i=0)^(n-1) lenF(i)).dt]): acc[n..(i: nat |-> lenF(i).dt)]
+
+// def dMatchI(...)
+
+// this drops n many elements from an array of m elements
+def dropAcc(n: nat, m: nat, dt: data, array: acc[(n+m).dt]): acc[(m-n).dt]
+
+def for{unroll: Boolean}(n: nat, loopBody: exp[idx[n], read] -> comm): comm
+def forNat{unroll: Boolean}(n: nat, loopBody: (i: nat) -> comm): comm
+def forVec(n: nat, dt: data, out: acc[vec[dt, n]], loopBody: exp[idx[n], read] -> acc[dt] -> comm): comm
+
+// note: would not be necessary if generate was defined as indices + map
+def generateCont(n: nat, dt: data, f: exp[idx[n], read] -> (exp[dt, read] -> comm) -> comm): exp[n.dt, read]
+
+def idxAcc(n: nat, dt: data, index: exp[idx[n], read], array: acc[n.dt]): acc[dt]
+def idxVecAcc(n: nat, dt: data, index: exp[idx[n], read], vector: acc[vec[dt, n]]): acc[dt]
+
+def joinAcc(n: nat, m: nat, dt: data, array: acc[(n*m).dt]): acc[n.m.dt]
+
+def mapAcc(n: nat, dt1: data, dt2: data, f: acc[dt1] -> acc[dt2], array: acc[n.dt1]): acc[n.dt2]
+
+def mapFstAcc(dt1: data, dt2: data, dt3: data, f: acc[dt3] -> acc[dt1], record: acc[(dt3, dt2)]): acc[(dt1, dt2)]
+def mapRead(n: nat, dt1: data, dt2: data,
+ f: exp[dt1, read] -> (exp[dt2, read] -> comm) -> comm, input: exp[n.dt1, read]): exp[n.dt2, read]
+def mapSndAcc(dt1: data, dt2: data, dt3: data, f: acc[dt3] -> acc[dt2], record: acc[(dt1, dt3)]): acc[(dt1, dt2)]
+
+// def mkDPairFstI(...)
+// def mkDPairSndAcc(...)
+
+def new(dt: data, f: var[dt] -> comm): comm
+
+def newDoubleBuffer(dt1: data, dt2: data, dt3: data, n: nat,
+ in: exp[dt1, read], out: acc[dt2],
+ f: ((var[n.dt3], comm), comm) -> comm): comm
+
+def pairAcc(dt1: data, dt2: data, fst: acc[dt1], snd: acc[dt2]): acc[(dt1, dt2)]
+def pairAcc1(dt1: data, dt2: data, pair: acc[(dt1, dt2)]): acc[dt1]
+def pairAcc2(dt1: data, dt2: data, pair: acc[(dt1, dt2)]): acc[dt2]
+
+def reorderAcc(n: nat, dt: data, idxF: nat2nat, array: acc[n.dt]): acc[n.dt]
+
+def scatterAcc(n: nat, m: nat, dt: data, indices: exp[n.idx[m], read], array: acc[m.dt]): acc[n.dt]
+
+def seq(c1: comm, c2: comm): comm
+
+def skip(): comm
+
+def splitAcc(n: nat, m: nat, dt: data, array: acc[m.n.dt]): acc[(n*m).dt]
+
+def takeAcc(n: nat, m: nat, dt: data, array: acc[(n+m).dt]): acc[n.dt]
+
+def transposeAcc(n: nat, m: nat, dt: data, array: acc[m.n.dt]): acc[n.m.dt]
+
+def unzipAcc(n: nat, dt1: data, dt2: data, a: acc[(n.dt1, n.dt2)]): acc[n.(dt1, dt2)]
+
+def zipAcc1(n: nat, dt1: data, dt2: data, array: acc[n.(dt1, dt2)]): acc[n.dt1]
+def zipAcc2(n: nat, dt1: data, dt2: data, array: acc[n.(dt1, dt2)]): acc[n.dt2]
diff --git a/src/main/scala/shine/DPIA/primitives/intermediate/MapVecI.scala b/src/main/scala/shine/DPIA/primitives/intermediate/MapVecI.scala
index 323afad08..25db6c81d 100644
--- a/src/main/scala/shine/DPIA/primitives/intermediate/MapVecI.scala
+++ b/src/main/scala/shine/DPIA/primitives/intermediate/MapVecI.scala
@@ -7,7 +7,7 @@ import shine.DPIA._
import shine.OpenMP.DSL.parForVec
object MapVecI {
- def apply(n: Nat, st1: ScalarType, st2: ScalarType,
+ def apply(n: Nat, st1: DataType, st2: DataType,
f: Phrase[ExpType ->: AccType ->: CommType],
in: Phrase[ExpType],
out: Phrase[AccType]): Phrase[CommType] =
diff --git a/src/main/scala/shine/OpenCL/AdjustArraySizesForAllocations.scala b/src/main/scala/shine/OpenCL/AdjustArraySizesForAllocations.scala
index 55efcd53d..423d1b0f2 100644
--- a/src/main/scala/shine/OpenCL/AdjustArraySizesForAllocations.scala
+++ b/src/main/scala/shine/OpenCL/AdjustArraySizesForAllocations.scala
@@ -8,7 +8,7 @@ import shine.DPIA._
import shine.DPIA.primitives.functional._
import shine.OpenCL.{primitives => ocl}
import shine.cuda.{primitives => cuda}
-import shine.cuda.primitives.functional.{AsMatrix, GenerateFragment, MapFragmentElements, TensorMatMultAdd, AsFragment}
+import shine.cuda.primitives.functional.{AsMatrix, GenerateFragment, MapFragment, TensorMatMultAdd, AsFragment}
import shine.cuda.warpDim
object AdjustArraySizesForAllocations {
@@ -85,7 +85,7 @@ object AdjustArraySizesForAllocations {
case _: Identifier[_] | _: Literal | _: Natural |
_: VectorFromScalar | _: Cast | _: ForeignFunctionCall |
_: BinOp | _: UnaryOp | _: GenerateFragment |
- _: AsMatrix | _: AsFragment | _: MapFragmentElements |
+ _: AsMatrix | _: AsFragment | _: MapFragment |
_: TensorMatMultAdd => parallInfo
//TODO visit value first?
@@ -113,7 +113,7 @@ object AdjustArraySizesForAllocations {
}
val stride = determineStride(parallLevel, dim, addrSpace)
- val outerDimension = ocl.imperative.IdxDistributeAcc(adjSize, oldSize, stride, parallLevel, adjElemT, A)
+ val outerDimension = ocl.imperative.IdxDistributeAcc(parallLevel)(adjSize, oldSize, stride, adjElemT, A)
val arr = identifier(freshName("x"), accT(adjElemT))
val mapFunBody = adjustedAcceptor(parallInfo.tail, adjElemT, oldElemT, addrSpace)(arr)
@@ -149,7 +149,7 @@ object AdjustArraySizesForAllocations {
}
val stride = determineStride(parallLevel, dim, addrSpace)
- val outerDimension = ocl.imperative.IdxDistribute(adjSize, oldSize, stride, parallLevel, adjElemT, E)
+ val outerDimension = ocl.imperative.IdxDistribute(parallLevel)(adjSize, oldSize, stride, adjElemT, E)
val arr = identifier(freshName("arr"), expT(adjElemT, read))
val mapFunBody = adjustedExpr(parallInfo.tail, adjElemT, oldElemT, addrSpace)(arr)
diff --git a/src/main/scala/shine/OpenCL/Compilation/HostCodeGenerator.scala b/src/main/scala/shine/OpenCL/Compilation/HostCodeGenerator.scala
index 23f5d5767..a2bb26544 100644
--- a/src/main/scala/shine/OpenCL/Compilation/HostCodeGenerator.scala
+++ b/src/main/scala/shine/OpenCL/Compilation/HostCodeGenerator.scala
@@ -25,12 +25,13 @@ case class HostCodeGenerator(override val decls: C.Compilation.CodeGenerator.Dec
override def name: String = "OpenCL Host"
override def cmd(env: Environment): Phrase[CommType] => Stmt = {
- case KernelCallCmd(name, LocalSize(ls), GlobalSize(gs), output, args) =>
- kernelCallCmd(name, ls, gs, output, args, env)
- case NewManagedBuffer(dt, access, Lambda(v, p)) =>
+ case k@KernelCallCmd(name, LocalSize(ls), GlobalSize(gs), args) =>
+ kernelCallCmd(name, ls, gs, k.output, args, env)
+ case n@NewManagedBuffer(access) =>
+ val (dt, Lambda(v, p)) = n.unwrap
newManagedBuffer(dt, access, v, p, env)
- case HostExecution(params, body) =>
- hostExecution(params, body, env)
+ case h@HostExecution(params) =>
+ hostExecution(params, h.body, env)
case phrase => phrase |> super.cmd(env)
}
@@ -202,7 +203,7 @@ case class HostCodeGenerator(override val decls: C.Compilation.CodeGenerator.Dec
C.AST.BinaryExpr(C.AST.ArithmeticExpr(a.size), BinaryOperator.*, bufferSize(a.elemType))
case a: DepArrayType => ??? // TODO
case _: DepPairType | _: NatToDataApply | _: DataTypeIdentifier | _: OpaqueType |
- _: shine.DPIA.Types.FragmentType | _: pipeline.type =>
+ _: shine.DPIA.Types.FragmentType =>
throw new Exception(s"did not expect ${dt}")
}
diff --git a/src/main/scala/shine/OpenCL/Compilation/HostManagedBuffers.scala b/src/main/scala/shine/OpenCL/Compilation/HostManagedBuffers.scala
index c9da200e2..51a2f8e8f 100644
--- a/src/main/scala/shine/OpenCL/Compilation/HostManagedBuffers.scala
+++ b/src/main/scala/shine/OpenCL/Compilation/HostManagedBuffers.scala
@@ -72,7 +72,7 @@ object HostManagedBuffers {
i -> access
}).toMap
env.foreach { case (i, a) => recordManagedAccess(managed, i, a) }
- (HostExecution(env, p2), ReadsAndWrites.empty)
+ (HostExecution(env)(p2), ReadsAndWrites.empty)
}
}
@@ -97,10 +97,10 @@ object HostManagedBuffers {
collectWrites(lhs, metadata.host_writes)
collectReads(rhs, allocs, metadata.host_reads)
Stop(p)
- case ocl.KernelCallCmd(_, _, _, out, in) =>
+ case k@ocl.KernelCallCmd(_, _, _, in) =>
in.foreach(collectReads(_, allocs, metadata.device_reads))
- collectWrites(out, metadata.device_writes)
- ((out, DEVICE_WRITE) +: in.map(_ -> DEVICE_READ)).foreach {
+ collectWrites(k.output, metadata.device_writes)
+ ((k.output, DEVICE_WRITE) +: in.map(_ -> DEVICE_READ)).foreach {
case (i: Identifier[_], a) => recordManagedAccess(managed, i, a)
case (Proj1(i: Identifier[_]), a) => recordManagedAccess(managed, i, a)
case (Proj2(i: Identifier[_]), a) => recordManagedAccess(managed, i, a)
@@ -152,11 +152,14 @@ object HostManagedBuffers {
case dpia.New(dt, Lambda(x, body)) if managed.contains(x) =>
val access = managed(x)._1
val x2 = managed(x)._2.asInstanceOf[Identifier[VarType]]
- Continue(ocl.NewManagedBuffer(dt, access, Lambda(x2, body)), this)
+ Continue(ocl.NewManagedBuffer(access)(dt, Lambda(x2, body)), this)
case _: dpia.New | _: Lambda[_, _] | _: dpia.Seq |
_: Proj2[_, _] | _: Proj1[_, _] | Natural(_) =>
Continue(p, this)
- case _: ocl.KernelCallCmd => Continue(p, this)
+ case k@ocl.KernelCallCmd(name, ls, gs, args) =>
+ val newOutput = VisitAndRebuild(k.output, this)
+ Stop(ocl.KernelCallCmd(name, ls, gs, args.map(VisitAndRebuild(_, this)))(
+ newOutput.t.dataType, newOutput))
case _: HostExecution => Stop(p)
case unexpected => throw new Exception(s"did not expect $unexpected")
}
@@ -180,7 +183,7 @@ object HostManagedBuffers {
case JoinAcc(_, _, _, a) => collectWrites(a, writes)
case SplitAcc(_, _, _, a) => collectWrites(a, writes)
case AsScalarAcc(_, _, _, a) => collectWrites(a, writes)
- case ocl.IdxDistributeAcc(_, _, _, _, _, a) => collectWrites(a, writes)
+ case idx:ocl.IdxDistributeAcc => collectWrites(idx.array, writes)
case PairAcc1(_, _, a) => collectWrites(a, writes)
case PairAcc2(_, _, a) => collectWrites(a, writes)
case TakeAcc(_, _, _, a) => collectWrites(a, writes)
@@ -212,7 +215,7 @@ object HostManagedBuffers {
collectReads(e1, allocs, reads); collectReads(e2, allocs, reads)
case Slide(_, _, _, _, e) => collectReads(e, allocs, reads)
case Map(_, _, _, _, _, e) => collectReads(e, allocs, reads)
- case ocl.IdxDistribute(_, _, _, _, _, e) => collectReads(e, allocs, reads)
+ case idx: ocl.IdxDistribute => collectReads(idx.array, allocs, reads)
case dpia.MapRead(_, _, _, _, e) => collectReads(e, allocs, reads)
case dpia.GenerateCont(_, _, _) => giveUp()
case AsScalar(_, _, _, _, e) => collectReads(e, allocs, reads)
@@ -226,12 +229,12 @@ object HostManagedBuffers {
case Split(_, _, _, _, e) => collectReads(e, allocs, reads)
case Zip(_, _, _, _, e1, e2) =>
collectReads(e1, allocs, reads); collectReads(e2, allocs, reads)
- case Pad(_, _, _, _, e1, e2) =>
+ case PadCst(_, _, _, _, e1, e2) =>
collectReads(e1, allocs, reads); collectReads(e2, allocs, reads)
case PadClamp(_, _, _, _, e) =>
collectReads(e, allocs, reads)
case Cast(_, _, e) => collectReads(e, allocs, reads)
- case ForeignFunctionCall(_, _, _, es) =>
+ case ForeignFunctionCall(_, _, es) =>
es.foreach {
collectReads(_, allocs, reads)
}
@@ -242,7 +245,7 @@ object HostManagedBuffers {
case MakePair(_, _, _, e1, e2) =>
collectReads(e1, allocs, reads); collectReads(e2, allocs, reads)
case Reorder(_, _, _, _, _, e) => collectReads(e, allocs, reads)
- case MakeArray(_, es) =>
+ case MakeArray(es) =>
es.foreach {
collectReads(_, allocs, reads)
}
diff --git a/src/main/scala/shine/OpenCL/Compilation/KernelCodeGenerator.scala b/src/main/scala/shine/OpenCL/Compilation/KernelCodeGenerator.scala
index c2dffa049..af5ede7dd 100644
--- a/src/main/scala/shine/OpenCL/Compilation/KernelCodeGenerator.scala
+++ b/src/main/scala/shine/OpenCL/Compilation/KernelCodeGenerator.scala
@@ -38,12 +38,18 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations,
override def cmd(env: Environment): Phrase[CommType] => Stmt = {
case f: ocl.ParFor =>
- val (i, o, p) = f.unwrapBody
- OpenCLCodeGen.codeGenOpenCLParFor(f, f.n, f.dt, f.out, i, o, p, env)
+ f.body match {
+ case Lambda(i, Lambda(o, p)) =>
+ OpenCLCodeGen.codeGenOpenCLParFor(f, f.n, f.dt, f.out, i, o, p, env)
+ case _ => throw new Exception("This should not happen")
+ }
case f: ocl.ParForNat =>
- val (i, o, p) = f.unwrapBody
- OpenCLCodeGen.codeGenOpenCLParForNat(f, f.n, f.out, i, o, p, env)
+ f.body match {
+ case DepLambda(i: NatIdentifier, Lambda(o, p)) =>
+ OpenCLCodeGen.codeGenOpenCLParForNat(f, f.n, f.out, i, o, p, env)
+ case _ => throw new Exception("This should not happen")
+ }
case phrase@Assign(dt, a, e) => dt match {
case VectorType(_, _) =>
@@ -84,20 +90,20 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations,
override def acc(env: Environment,
path: Path,
cont: Expr => Stmt): Phrase[AccType] => Stmt = {
- case AsVectorAcc(n, _, _, a) => path match {
- case (i: CIntExpr) :: ps => a |> acc(env, CIntExpr(i / n) :: ps, cont)
+ case AsVectorAcc(_, m, _, a) => path match {
+ case (i: CIntExpr) :: ps => a |> acc(env, CIntExpr(i / m) :: ps, cont)
case _ => error(s"Expected path to be not empty")
}
- case AsScalarAcc(_, m, dt, a) => path match {
+ case AsScalarAcc(n, _, dt, a) => path match {
case (i: CIntExpr) :: (j: CIntExpr) :: ps =>
- a |> acc(env, CIntExpr((i * m) + j) :: ps, cont)
+ a |> acc(env, CIntExpr((i * n) + j) :: ps, cont)
case (i: CIntExpr) :: Nil =>
// TODO: check alignment and use pointer with correct address space
- a |> acc(env, CIntExpr(i * m) :: Nil, array => {
+ a |> acc(env, CIntExpr(i * n) :: Nil, array => {
// the continuation has to add the value ...
val ptr = C.AST.UnaryExpr(C.AST.UnaryOperator.&, array)
- cont(C.AST.FunCall(C.AST.DeclRef(s"vstore$m"),
+ cont(C.AST.FunCall(C.AST.DeclRef(s"vstore$n"),
immutable.Seq(C.AST.Literal("0"), ptr)))
})
case _ => error(s"Expected path to be not empty")
@@ -105,10 +111,10 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations,
case IdxVecAcc(_, _, i, a) =>
CCodeGen.codeGenIdxAcc(i, a, env, path, cont)
- case ocl.IdxDistributeAcc(_, _, stride, _, _, a) => path match {
+ case idx: ocl.IdxDistributeAcc => path match {
// TODO: ensure that i % stride == init ?
case (i: CIntExpr) :: ps =>
- a |> acc(env, CIntExpr(i / stride) :: ps, cont)
+ idx.array |> acc(env, CIntExpr(i / idx.stride) :: ps, cont)
case _ => error(s"Expected a C-Integer-Expression on the path.")
}
@@ -161,9 +167,9 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations,
})
case _ => error(s"unexpected $path")
}
- case AsScalar(_, m, _, _, e) => path match {
+ case AsScalar(n, _, _, _, e) => path match {
case (i: CIntExpr) :: ps =>
- e |> exp(env, CIntExpr(i / m) :: CIntExpr(i % m) :: ps, cont)
+ e |> exp(env, CIntExpr(i / n) :: CIntExpr(i % n) :: ps, cont)
case _ => error(s"Expected path to be not empty")
}
// TODO: this has to be refactored
@@ -185,13 +191,13 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations,
case IdxVec(_, _, i, e) => CCodeGen.codeGenIdx(i, e, env, path, cont)
- case OpenCLFunctionCall(name, _, _, args) =>
+ case OpenCLFunctionCall(name, _, args) =>
CCodeGen.codeGenForeignCall(name, args, env, Nil, cont)
- case ocl.IdxDistribute(_, _, stride, _, _, e) => path match {
+ case idx: ocl.IdxDistribute => path match {
// TODO: ensure that i % stride == init ?
case (i: CIntExpr) :: ps
- => e |> exp(env, CIntExpr(i / stride) :: ps, cont)
+ => idx.array |> exp(env, CIntExpr(i / idx.stride) :: ps, cont)
case _ => error(s"Expected a C-Integer-Expression on the path.")
}
@@ -307,7 +313,7 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations,
p: Phrase[CommType],
env: Environment): Stmt = {
assert(!f.unroll)
- val cI = C.AST.DeclRef(f.name)
+ val cI = C.AST.DeclRef(freshName(f.prefix))
val range = RangeAdd(f.init, n, f.step)
val updatedGen = updatedRanges(cI.name, range)
@@ -363,7 +369,7 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations,
p: Phrase[CommType],
env: Environment): Stmt = {
assert(!f.unroll)
- val cI = C.AST.DeclRef(f.name)
+ val cI = C.AST.DeclRef(freshName(f.prefix))
val range = RangeAdd(f.init, n, f.step)
val updatedGen = updatedRanges(cI.name, range)
@@ -436,7 +442,7 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations,
AddressSpace.Private, Some(expr)))
}
- def codeGenVectorLiteral(n: Int, dt: ScalarType,
+ def codeGenVectorLiteral(n: Int, dt: DataType,
f: Int => Phrase[ExpType],
env: Environment,
cont: Expr => Stmt,
diff --git a/src/main/scala/shine/OpenCL/Compilation/Passes/FlagPrivateArrayLoops.scala b/src/main/scala/shine/OpenCL/Compilation/Passes/FlagPrivateArrayLoops.scala
index 808339789..ca31433ac 100644
--- a/src/main/scala/shine/OpenCL/Compilation/Passes/FlagPrivateArrayLoops.scala
+++ b/src/main/scala/shine/OpenCL/Compilation/Passes/FlagPrivateArrayLoops.scala
@@ -4,7 +4,7 @@ import shine.DPIA.Phrases._
import shine.DPIA.Types.{CommType, PhraseType}
import shine.DPIA.primitives.functional.{Idx, NatAsIndex}
import shine.DPIA.primitives.imperative.{For, ForNat, IdxAcc}
-import shine.DPIA.{ArrayData, Nat}
+import shine.DPIA.{ArrayData, Nat, NatIdentifier}
import shine.OpenCL.AddressSpace
import shine.OpenCL.primitives.imperative.{New, ParFor, ParForNat}
@@ -44,9 +44,12 @@ object FlagPrivateArrayLoops {
eliminateVars ++= indexingIdents
Stop(p)
case pf: ParFor if collectIdents(pf.out).exists(privMemIdents(_)) =>
- val (i, o, _) = pf.unwrapBody
- eliminateVars += i.name
- Continue(p, this.copy(privMemIdents = privMemIdents + o))
+ pf.body match {
+ case Lambda(i, Lambda(o, _)) =>
+ eliminateVars += i.name
+ Continue(p, this.copy(privMemIdents = privMemIdents + o))
+ case _ => throw new Exception("This should not happen")
+ }
case _ =>
Continue(p, this)
}
@@ -60,24 +63,30 @@ object FlagPrivateArrayLoops {
eliminateVars: mutable.Set[String]): Phrase[CommType] = {
VisitAndRebuild(p, new VisitAndRebuild.Visitor {
override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = p match {
- case f@For(_) if (eliminateVars(f.unwrapBody._1.name)) =>
- val (i, _) = f.unwrapBody
+ case f@For(_) if (eliminateVars(f.loopBody.asInstanceOf[Lambda[_, _]].param.name)) =>
+ val i = f.loopBody.asInstanceOf[Lambda[_, _]].param
eliminateVars -= i.name
Continue(For(unroll = true)(f.n, f.loopBody), this)
- case f@ForNat(_) if (eliminateVars(f.unwrapBody._1.name)) =>
- val (i, _) = f.unwrapBody
+ case f@ForNat(_) if (eliminateVars(f.loopBody.asInstanceOf[DepLambda[_, _]].x.name)) =>
+ val i = f.loopBody.asInstanceOf[DepLambda[_, _]].x
eliminateVars -= i.name
Continue(ForNat(unroll = true)(f.n, f.loopBody), this)
- case pf@ParFor(level, dim, _) if (eliminateVars(pf.unwrapBody._1.name)) =>
- val (i, _, _) = pf.unwrapBody
- eliminateVars -= i.name
- Continue(ParFor(level, dim, unroll = true)
- (pf.n, pf.dt, pf.out, pf.loopBody, pf.init, pf.step), this)
- case pf@ParForNat(level, dim, _) if (eliminateVars(pf.unwrapBody._1.name)) =>
- val (i, _, _) = pf.unwrapBody
- eliminateVars -= i.name
- Continue(ParForNat(level, dim, unroll = true)
- (pf.n, pf.ft, pf.out, pf.loopBody, pf.init, pf.step), this)
+ case pf@ParFor(level, dim, _, name) if (eliminateVars(pf.body.asInstanceOf[Lambda[_, _]].param.name)) =>
+ pf.body match {
+ case Lambda(i, _) =>
+ eliminateVars -= i.name
+ Continue(ParFor(level, dim, unroll = true, name)(
+ pf.init, pf.n, pf.step, pf.dt, pf.out, pf.body), this)
+ case _ => throw new Exception("This should not happen")
+ }
+ case pf@ParForNat(level, dim, _, name) if (eliminateVars(pf.body.asInstanceOf[DepLambda[_, _]].x.name)) =>
+ pf.body match {
+ case DepLambda(i: NatIdentifier, _) =>
+ eliminateVars -= i.name
+ Continue(ParForNat(level, dim, unroll = true, name)(
+ pf.init, pf.n, pf.step, pf.ft, pf.out, pf.body), this)
+ case _ => throw new Exception("This should not happen")
+ }
case _ =>
Continue(p, this)
}
diff --git a/src/main/scala/shine/OpenCL/Compilation/Passes/InsertMemoryBarriers.scala b/src/main/scala/shine/OpenCL/Compilation/Passes/InsertMemoryBarriers.scala
index 4ef760f11..e772d1d50 100644
--- a/src/main/scala/shine/OpenCL/Compilation/Passes/InsertMemoryBarriers.scala
+++ b/src/main/scala/shine/OpenCL/Compilation/Passes/InsertMemoryBarriers.scala
@@ -1,6 +1,6 @@
package shine.OpenCL.Compilation.Passes
-import shine.DPIA.->:
+import shine.DPIA.{->:, NatIdentifier}
import shine.DPIA.Phrases._
import shine.DPIA.Types._
import shine.DPIA.primitives.functional
@@ -61,35 +61,53 @@ object InsertMemoryBarriers {
override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = {
p match {
case f@For(unroll) =>
- val (x, body) = f.unwrapBody
- Stop(For(unroll)(f.n, Lambda(x, visitLoopBody(body, allocs, metadata))))
+ f.loopBody match {
+ case Lambda(x, body) =>
+ Stop(For(unroll)(f.n, Lambda(x, visitLoopBody(body, allocs, metadata))))
+ case _ => throw new Exception("This should not happen")
+ }
case f@ForNat(unroll) =>
- val (x, body) = f.unwrapBody
- Stop(ForNat(unroll)(f.n,
- DepLambda[NatKind, CommType](x, visitLoopBody(body, allocs, metadata))))
- case pf@ocl.ParFor(Local, dim, unroll) =>
- val (x, o, body) = pf.unwrapBody
- val outer_wg_writes = mutable.Map[Identifier[_ <: PhraseType], AddressSpace]()
- collectWrites(pf.out, allocs, outer_wg_writes)
- Stop(ocl.ParFor(Local, dim, unroll)(pf.n, pf.dt, pf.out,
- Lambda(x, Lambda(o,
- visitLoopBody(body, allocs, metadata, outer_wg_writes))), pf.init, pf.step))
- case pf@ocl.ParFor(level, dim, unroll) =>
- val (x, o, body) = pf.unwrapBody
- Stop(ocl.ParFor(level, dim, unroll)(pf.n, pf.dt, pf.out,
- Lambda(x, Lambda(o, visitLoopBody(body, allocs, metadata))), pf.init, pf.step))
- case pf@ocl.ParForNat(Local, dim, unroll) =>
- val (x, o, body) = pf.unwrapBody
- val outer_wg_writes = mutable.Map[Identifier[_ <: PhraseType], AddressSpace]()
- collectWrites(pf.out, allocs, outer_wg_writes)
- Stop(ocl.ParForNat(Local, dim, unroll)(pf.n, pf.ft, pf.out,
- DepLambda[NatKind, AccType ->: CommType](x, Lambda(o,
- visitLoopBody(body, allocs, metadata, outer_wg_writes))), pf.init, pf.step))
- case pf@ocl.ParForNat(level, dim, unroll) =>
- val (x, o, body) = pf.unwrapBody
- Stop(ocl.ParForNat(level, dim, unroll)(pf.n, pf.ft, pf.out,
- DepLambda[NatKind, AccType ->: CommType](x, Lambda(o,
- visitLoopBody(body, allocs, metadata))), pf.init, pf.step))
+ f.loopBody match {
+ case DepLambda(x, body) =>
+ Stop(ForNat(unroll)(f.n,
+ DepLambda[NatKind, CommType](x, visitLoopBody(body, allocs, metadata))))
+ case _ => throw new Exception("This should not happen")
+ }
+ case pf@ocl.ParFor(Local, dim, unroll, name) =>
+ pf.body match {
+ case Lambda(x, Lambda(o, body)) =>
+ val outer_wg_writes = mutable.Map[Identifier[_ <: PhraseType], AddressSpace]()
+ collectWrites(pf.out, allocs, outer_wg_writes)
+ Stop(ocl.ParFor(Local, dim, unroll, name)(pf.init, pf.n, pf.step, pf.dt, pf.out,
+ Lambda(x, Lambda(o,
+ visitLoopBody(body, allocs, metadata, outer_wg_writes)))))
+ case _ => throw new Exception("This should not happen")
+ }
+ case pf@ocl.ParFor(level, dim, unroll, name) =>
+ pf.body match {
+ case Lambda(x, Lambda(o, body)) =>
+ Stop(ocl.ParFor(level, dim, unroll, name)(pf.init, pf.n, pf.step, pf.dt, pf.out,
+ Lambda(x, Lambda(o, visitLoopBody(body, allocs, metadata)))))
+ case _ => throw new Exception("This should not happen")
+ }
+ case pf@ocl.ParForNat(Local, dim, unroll, name) =>
+ pf.body match {
+ case DepLambda(i: NatIdentifier, Lambda(o, p)) =>
+ val outer_wg_writes = mutable.Map[Identifier[_ <: PhraseType], AddressSpace]()
+ collectWrites(pf.out, allocs, outer_wg_writes)
+ Stop(ocl.ParForNat(Local, dim, unroll, name)(pf.init, pf.n, pf.step, pf.ft, pf.out,
+ DepLambda[NatKind, AccType ->: CommType](i, Lambda(o,
+ visitLoopBody(p, allocs, metadata, outer_wg_writes)))))
+ case _ => throw new Exception("This should not happen")
+ }
+ case pf@ocl.ParForNat(level, dim, unroll, name) =>
+ pf.body match {
+ case DepLambda(i: NatIdentifier, Lambda(o, p)) =>
+ Stop(ocl.ParForNat(level, dim, unroll, name)(pf.init, pf.n, pf.step, pf.ft, pf.out,
+ DepLambda[NatKind, AccType ->: CommType](i, Lambda(o,
+ visitLoopBody(p, allocs, metadata)))))
+ case _ => throw new Exception("This should not happen")
+ }
case ocl.New(addr, _, Lambda(x, _)) if addr != AddressSpace.Private =>
Continue(p, Visitor(allocs + (x -> addr), metadata))
case ocl.NewDoubleBuffer(addr, dt1, dt2, dt3, n, in, out, Lambda(x, body))
@@ -141,7 +159,7 @@ object InsertMemoryBarriers {
case JoinAcc(_, _, _, a) => collectWrites(a, allocs, writes)
case SplitAcc(_, _, _, a) => collectWrites(a, allocs, writes)
case AsScalarAcc(_, _, _, a) => collectWrites(a, allocs, writes)
- case ocl.IdxDistributeAcc(_, _, _, _, _, a) => collectWrites(a, allocs, writes)
+ case idx: ocl.IdxDistributeAcc => collectWrites(idx.array, allocs, writes)
case PairAcc1(_, _, a) => collectWrites(a, allocs, writes)
case PairAcc2(_, _, a) => collectWrites(a, allocs, writes)
case PairAcc(_, _, a, b) =>
@@ -175,7 +193,7 @@ object InsertMemoryBarriers {
collectReads(e1, allocs, reads); collectReads(e2, allocs, reads)
case Slide(_, _, _, _, e) => collectReads(e, allocs, reads)
case functional.Map(_, _, _, _, _, e) => collectReads(e, allocs, reads)
- case ocl.IdxDistribute(_, _, _, _, _, e) => collectReads(e, allocs, reads)
+ case idx: ocl.IdxDistribute => collectReads(idx.array, allocs, reads)
case MapRead(_, _, _, _, e) => collectReads(e, allocs, reads)
case GenerateCont(_, _, _) => giveUp()
case AsScalar(_, _, _, _, e) => collectReads(e, allocs, reads)
@@ -189,12 +207,12 @@ object InsertMemoryBarriers {
case Split(_, _, _, _, e) => collectReads(e, allocs, reads)
case Zip(_, _, _, _, e1, e2) =>
collectReads(e1, allocs, reads); collectReads(e2, allocs, reads)
- case Pad(_, _, _, _, e1, e2) =>
+ case PadCst(_, _, _, _, e1, e2) =>
collectReads(e1, allocs, reads); collectReads(e2, allocs, reads)
case PadClamp(_, _, _, _, e) =>
collectReads(e, allocs, reads)
case Cast(_, _, e) => collectReads(e, allocs, reads)
- case ForeignFunctionCall(_, _, _, es) =>
+ case ForeignFunctionCall(_, _, es) =>
es.foreach {
collectReads(_, allocs, reads)
}
@@ -205,7 +223,7 @@ object InsertMemoryBarriers {
case MakePair(_, _, _, e1, e2) =>
collectReads(e1, allocs, reads); collectReads(e2, allocs, reads)
case Reorder(_, _, _, _, _, e) => collectReads(e, allocs, reads)
- case MakeArray(_, es) =>
+ case MakeArray(es) =>
es.foreach {
collectReads(_, allocs, reads)
}
diff --git a/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala b/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala
index 25c57d303..d6622ff82 100644
--- a/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala
+++ b/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala
@@ -19,16 +19,15 @@ object SeparateHostAndKernelCode {
var kernelDefinitions = mutable.ArrayBuffer[KernelDef]()
val hostDefinition = VisitAndRebuild(p, new VisitAndRebuild.Visitor {
override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = p match {
- case Run(localSize, globalSize, _, value) =>
+ case r@Run(localSize, globalSize) =>
val name = s"k$kernelNum"
kernelNum += 1
- val (closedDef, args) = closeDefinition(value)
+ val (closedDef, args) = closeDefinition(r.input)
val kernelDef = KernelDef(name, closedDef, localSize, globalSize)
kernelDefinitions += kernelDef
Stop(KernelCall(name, localSize, globalSize,
kernelDef.paramTypes.map(_.dataType),
- kernelDef.returnType.dataType,
- args).asInstanceOf[Phrase[T]])
+ args)(kernelDef.returnType.dataType).asInstanceOf[Phrase[T]])
// on the fly beta-reduction
case Apply(fun, arg) =>
diff --git a/src/main/scala/shine/OpenCL/DSL/package.scala b/src/main/scala/shine/OpenCL/DSL/package.scala
index 2196f4af3..4ef029e70 100644
--- a/src/main/scala/shine/OpenCL/DSL/package.scala
+++ b/src/main/scala/shine/OpenCL/DSL/package.scala
@@ -9,32 +9,51 @@ import shine.OpenCL.primitives.imperative._
package object DSL {
+ def parFor(level: ParallelismLevel,
+ dim: Int,
+ unroll: Boolean
+ ): (Nat, DataType, Phrase[AccType], Phrase[FunType[ExpType, FunType[AccType, CommType]]]) => ParFor =
+ level match {
+ case Global => ParFor(level, dim, unroll, "gl_id_")(
+ get_global_id(dim), _, get_global_size(dim), _, _, _)
+ case Local => ParFor(level, dim, unroll, "l_id_")(
+ get_local_id(dim), _, get_local_size(dim), _, _, _)
+ case WorkGroup => ParFor(level, dim, unroll, "wg_id_")(
+ get_group_id(dim), _, get_num_groups(dim), _, _, _)
+ case Sequential | Warp | Lane => throw new Exception("This should not happen")
+ }
+
+ def parForNat(level: ParallelismLevel,
+ dim: Int,
+ unroll: Boolean
+ ): (Nat, NatToData, Phrase[AccType], Phrase[DepFunType[NatKind, FunType[AccType, CommType]]]) => ParForNat =
+ level match {
+ case Global => ParForNat(level, dim, unroll, "gl_id_")(
+ get_global_id(dim), _, get_global_size(dim), _, _, _)
+ case Local => ParForNat(level, dim, unroll, "l_id_")(
+ get_local_id(dim), _, get_local_size(dim), _, _, _)
+ case WorkGroup => ParForNat(level, dim, unroll, "wg_id_")(
+ get_group_id(dim), _, get_num_groups(dim), _, _, _)
+ case Sequential | Warp | Lane => throw new Exception("This should not happen")
+ }
+
private def parForBodyFunction(n:Nat, ft:NatToData,
f:NatIdentifier => Phrase[AccType] => Phrase[CommType]
): DepLambda[NatKind, AccType ->: CommType] = {
nFun(idx => λ(accT(ft(idx)))(o => f(idx)(o)), RangeAdd(0, n, 1))
}
- object parForNatGlobal {
- def apply(dim:Int)(n:Nat, ft:NatToData, out:Phrase[AccType],
- f:NatIdentifier => Phrase[AccType] => Phrase[CommType]): ParForNat = {
- ParForNat(Global, dim, unroll = false)(n, ft, out, parForBodyFunction(n, ft, f))
- }
- }
+ def parForNatGlobal(dim:Int)(n:Nat, ft:NatToData, out:Phrase[AccType],
+ f: NatIdentifier => Phrase[AccType] => Phrase[CommType]): ParForNat =
+ parForNat(Global, dim, unroll = false)(n, ft, out, parForBodyFunction(n, ft, f))
- object parForNatWorkGroup {
- def apply(dim:Int)(n:Nat, ft:NatToData, out:Phrase[AccType],
- f:NatIdentifier => Phrase[AccType] => Phrase[CommType]): ParForNat = {
- ParForNat(WorkGroup, dim, unroll = false)(n, ft, out, parForBodyFunction(n, ft, f))
- }
- }
+ def parForNatWorkGroup(dim:Int)(n:Nat, ft:NatToData, out:Phrase[AccType],
+ f:NatIdentifier => Phrase[AccType] => Phrase[CommType]): ParForNat =
+ parForNat(WorkGroup, dim, unroll = false)(n, ft, out, parForBodyFunction(n, ft, f))
- object parForNatLocal {
- def apply(dim:Int)(n:Nat, ft:NatToData, out:Phrase[AccType],
- f:NatIdentifier => Phrase[AccType] => Phrase[CommType]): ParForNat = {
- ParForNat(Local, dim, unroll = false)(n, ft, out, parForBodyFunction(n, ft, f))
- }
- }
+ def parForNatLocal(dim:Int)(n:Nat, ft:NatToData, out:Phrase[AccType],
+ f:NatIdentifier => Phrase[AccType] => Phrase[CommType]): ParForNat =
+ parForNat(Local, dim, unroll = false)(n, ft, out, parForBodyFunction(n, ft, f))
object `new` {
def apply(addrSpace: shine.DPIA.Types.AddressSpace)
@@ -60,6 +79,6 @@ package object DSL {
}
object barrier {
- def apply(local: Boolean = true, global: Boolean = true) = Barrier(local, global)
+ def apply(local: Boolean = true, global: Boolean = true) = Barrier(local, global)()
}
}
diff --git a/src/main/scala/shine/OpenCL/KernelExecutor.scala b/src/main/scala/shine/OpenCL/KernelExecutor.scala
index a00dfeb6f..9f97232b0 100644
--- a/src/main/scala/shine/OpenCL/KernelExecutor.scala
+++ b/src/main/scala/shine/OpenCL/KernelExecutor.scala
@@ -378,8 +378,7 @@ object KernelExecutor {
case DepArrayType(_, NatToDataLambda(_, elemType)) =>
getOutputType(elemType)
case DepArrayType(_, _) | _: NatToDataApply => throw new Exception("This should not happen")
- case _: DepPairType | _: ManagedBufferType | _: OpaqueType |
- _: FragmentType | _: pipeline.type =>
+ case _: DepPairType | _: ManagedBufferType | _: OpaqueType | _: FragmentType =>
throw new Exception(s"${dt} not supported as output type")
}
@@ -408,7 +407,7 @@ object KernelExecutor {
throw new Exception("This should not happen")
}
case _: DepPairType | _: NatToDataApply | _: DataTypeIdentifier |
- _: ManagedBufferType | _: OpaqueType | _: FragmentType | _: pipeline.type =>
+ _: ManagedBufferType | _: OpaqueType | _: FragmentType =>
throw new Exception(s"the byte size of ${dt} should not be requested")
}
diff --git a/src/main/scala/shine/OpenCL/primitives/functional/CircularBuffer.scala b/src/main/scala/shine/OpenCL/primitives/functional/CircularBuffer.scala
index 9cc0e4305..f4b206cbf 100644
--- a/src/main/scala/shine/OpenCL/primitives/functional/CircularBuffer.scala
+++ b/src/main/scala/shine/OpenCL/primitives/functional/CircularBuffer.scala
@@ -1,22 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class CircularBuffer(a: AddressSpace,
- n: Nat,
- alloc: Nat,
- sz: Nat,
- dt1: DataType,
- dt2: DataType,
- load: Phrase[ExpType ->: ExpType],
- input: Phrase[ExpType]
- ) extends ExpPrimitive {
- load :: expT(dt1, read) ->: expT(dt2, write)
- input :: expT((n - 1 + sz)`.`dt1, read)
- override val t: ExpType = expT(n`.`(sz`.`dt2), read)
-}
\ No newline at end of file
+final case class CircularBuffer(val a: AddressSpace, val n: Nat, val alloc: Nat, val sz: Nat, val dt1: DataType, val dt2: DataType, val load: Phrase[FunType[ExpType, ExpType]], val input: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ load :: FunType(expT(dt1, read), expT(dt2, write))
+ input :: expT(ArrayType(n - 1 + sz, dt1), read)
+ }
+ override val t: ExpType = expT(ArrayType(n, ArrayType(sz, dt2)), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): CircularBuffer = new CircularBuffer(v.addressSpace(a), v.nat(n), v.nat(alloc), v.nat(sz), v.data(dt1), v.data(dt2), VisitAndRebuild(load, v), VisitAndRebuild(input, v))
+}
diff --git a/src/main/scala/shine/OpenCL/primitives/functional/DepMap.scala b/src/main/scala/shine/OpenCL/primitives/functional/DepMap.scala
index 3ffe8829a..adfe41170 100644
--- a/src/main/scala/shine/OpenCL/primitives/functional/DepMap.scala
+++ b/src/main/scala/shine/OpenCL/primitives/functional/DepMap.scala
@@ -1,25 +1,21 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.OpenCL.ParallelismLevel
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class DepMap(level: ParallelismLevel,
- dim: Int)
- (val n: Nat,
- val ft1:NatToData,
- val ft2:NatToData,
- val f: Phrase[`(nat)->:`[ExpType ->: ExpType]],
- val array: Phrase[ExpType]
- ) extends ExpPrimitive {
- f :: f.t.x ->: expT(ft1(f.t.x), read) ->: expT(ft2(f.t.x), write)
- array :: expT(n `.d` ft1, read)
- override val t: ExpType = expT(n`.d`ft2, write)
-
- def unwrap: (Nat, NatToData, NatToData, Phrase[`(nat)->:`[ExpType ->: ExpType]], Phrase[ExpType]) =
- (n, ft1, ft2, f, array)
+final case class DepMap(level: shine.OpenCL.ParallelismLevel, dim: Int)(val n: Nat, val ft1: NatToData, val ft2: NatToData, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: ({
+ val m = f.t.x
+ DepFunType[NatKind, PhraseType](m, FunType(expT(NatToDataApply(ft1, m), read), expT(NatToDataApply(ft2, m), write)))
+ })
+ array :: expT(DepArrayType(n, ft1), read)
+ }
+ override val t: ExpType = expT(DepArrayType(n, ft2), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepMap = new DepMap(level, dim)(v.nat(n), v.natToData(ft1), v.natToData(ft2), VisitAndRebuild(f, v), VisitAndRebuild(array, v))
+ def unwrap: (Nat, NatToData, NatToData, Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], Phrase[ExpType]) = (n, ft1, ft2, f, array)
}
diff --git a/src/main/scala/shine/OpenCL/primitives/functional/Iterate.scala b/src/main/scala/shine/OpenCL/primitives/functional/Iterate.scala
index 2016c5556..0387931b5 100644
--- a/src/main/scala/shine/OpenCL/primitives/functional/Iterate.scala
+++ b/src/main/scala/shine/OpenCL/primitives/functional/Iterate.scala
@@ -1,24 +1,20 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Iterate(a: AddressSpace,
- n: Nat,
- m: Nat,
- k: Nat,
- dt: DataType,
- f: Phrase[`(nat)->:`[ExpType ->: ExpType]],
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
+final case class Iterate(val a: AddressSpace, val n: Nat, val m: Nat, val k: Nat, val dt: DataType, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive {
{
- val l = f.t.x
- f :: l ->: expT({l * n}`.`dt, read) ->: expT(l`.`dt, write)
- array :: expT({m * n.pow(k)}`.`dt, read)
+ f :: ({
+ val l = f.t.x
+ DepFunType[NatKind, PhraseType](l, FunType(expT(ArrayType(l * n, dt), read), expT(ArrayType(l, dt), write)))
+ })
+ array :: expT(ArrayType(m * n.pow(k), dt), read)
}
- override val t: ExpType = expT(m`.`dt, write)
+ override val t: ExpType = expT(ArrayType(m, dt), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Iterate = new Iterate(v.addressSpace(a), v.nat(n), v.nat(m), v.nat(k), v.data(dt), VisitAndRebuild(f, v), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/OpenCL/primitives/functional/KernelCall.scala b/src/main/scala/shine/OpenCL/primitives/functional/KernelCall.scala
index edff14958..d2f150c9a 100644
--- a/src/main/scala/shine/OpenCL/primitives/functional/KernelCall.scala
+++ b/src/main/scala/shine/OpenCL/primitives/functional/KernelCall.scala
@@ -1,20 +1,14 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.OpenCL.{GlobalSize, LocalSize}
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-case class KernelCall(name: String,
- localSize: LocalSize,
- globalSize: GlobalSize,
- inTs: Seq[DataType],
- outT: DataType,
- args: Seq[Phrase[ExpType]]) extends ExpPrimitive {
- (inTs zip args).foreach{
- case (inT, arg) => arg :: expT(inT, read)
- }
+final case class KernelCall(name: String, localSize: shine.OpenCL.LocalSize, globalSize: shine.OpenCL.GlobalSize, inTs: Seq[DataType], args: Seq[Phrase[ExpType]])(val outT: DataType) extends ExpPrimitive {
+ {}
override val t: ExpType = expT(outT, write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): KernelCall = new KernelCall(name, localSize.visitAndRebuild(v), globalSize.visitAndRebuild(v), inTs.map(v.data), args.map(VisitAndRebuild(_, v)))(v.data(outT))
}
diff --git a/src/main/scala/shine/OpenCL/primitives/functional/Map.scala b/src/main/scala/shine/OpenCL/primitives/functional/Map.scala
index 45137ca96..a99170788 100644
--- a/src/main/scala/shine/OpenCL/primitives/functional/Map.scala
+++ b/src/main/scala/shine/OpenCL/primitives/functional/Map.scala
@@ -1,25 +1,18 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.OpenCL.ParallelismLevel
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Map(level: ParallelismLevel,
- dim: Int)
- (val n: Nat,
- val dt1: DataType,
- val dt2: DataType,
- val f: Phrase[ExpType ->: ExpType],
- val array: Phrase[ExpType]
- ) extends ExpPrimitive {
- f :: expT(dt1, read) ->: expT(dt2, write)
- array :: expT(n `.` dt1, read)
- override val t: ExpType = expT(n `.` dt2, write)
-
- def unwrap: (Nat, DataType, DataType, Phrase[ExpType ->: ExpType], Phrase[ExpType]) =
- (n, dt1, dt2, f, array)
+final case class Map(level: shine.OpenCL.ParallelismLevel, dim: Int)(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: FunType(expT(dt1, read), expT(dt2, write))
+ array :: expT(ArrayType(n, dt1), read)
+ }
+ override val t: ExpType = expT(ArrayType(n, dt2), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Map = new Map(level, dim)(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v))
+ def unwrap: (Nat, DataType, DataType, Phrase[FunType[ExpType, ExpType]], Phrase[ExpType]) = (n, dt1, dt2, f, array)
}
diff --git a/src/main/scala/shine/OpenCL/primitives/functional/OpenCLFunctionCall.scala b/src/main/scala/shine/OpenCL/primitives/functional/OpenCLFunctionCall.scala
index 2c47bfe8a..a3f2ec8a0 100644
--- a/src/main/scala/shine/OpenCL/primitives/functional/OpenCLFunctionCall.scala
+++ b/src/main/scala/shine/OpenCL/primitives/functional/OpenCLFunctionCall.scala
@@ -1,20 +1,14 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-import scala.language.reflectiveCalls
-
-@expPrimitive
-final case class OpenCLFunctionCall(name: String,
- inTs: Seq[DataType],
- outT: DataType,
- args: Seq[Phrase[ExpType]]
- ) extends ExpPrimitive {
- (inTs zip args).foreach{
- case (inT, arg) => arg :: expT(inT, read)
- }
- override val t: ExpType = expT(outT, read)
+final case class OpenCLFunctionCall(name: String, inTs: Seq[DataType], args: Seq[Phrase[ExpType]])(val outT: DataType) extends ExpPrimitive {
+ {}
+ override val t: ExpType = expT(outT, write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): OpenCLFunctionCall = new OpenCLFunctionCall(name, inTs.map(v.data), args.map(VisitAndRebuild(_, v)))(v.data(outT))
}
diff --git a/src/main/scala/shine/OpenCL/primitives/functional/ReduceSeq.scala b/src/main/scala/shine/OpenCL/primitives/functional/ReduceSeq.scala
index 78e416668..f00544913 100644
--- a/src/main/scala/shine/OpenCL/primitives/functional/ReduceSeq.scala
+++ b/src/main/scala/shine/OpenCL/primitives/functional/ReduceSeq.scala
@@ -1,23 +1,19 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class ReduceSeq(unroll: Boolean)
- (val n: Nat,
- val initAddrSpace: AddressSpace,
- val dt1: DataType,
- val dt2: DataType,
- val f: Phrase[ExpType ->: ExpType ->: ExpType],
- val init: Phrase[ExpType],
- val array: Phrase[ExpType]
- ) extends ExpPrimitive {
- f :: expT(dt2, read) ->: expT(dt1, read) ->: expT(dt2, write)
- init :: expT(dt2, write)
- array :: expT(n`.`dt1, read)
+final case class ReduceSeq(unroll: Boolean)(val n: Nat, val a: AddressSpace, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, FunType[ExpType, ExpType]]], val init: Phrase[ExpType], val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: FunType(expT(dt2, read), FunType(expT(dt1, read), expT(dt2, write)))
+ init :: expT(dt2, write)
+ array :: expT(ArrayType(n, dt1), read)
+ }
override val t: ExpType = expT(dt2, read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ReduceSeq = new ReduceSeq(unroll)(v.nat(n), v.addressSpace(a), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(init, v), VisitAndRebuild(array, v))
+ def unwrap: (Nat, AddressSpace, DataType, DataType, Phrase[FunType[ExpType, FunType[ExpType, ExpType]]], Phrase[ExpType], Phrase[ExpType]) = (n, a, dt1, dt2, f, init, array)
}
diff --git a/src/main/scala/shine/OpenCL/primitives/functional/RotateValues.scala b/src/main/scala/shine/OpenCL/primitives/functional/RotateValues.scala
index bf8f2fe27..093afa594 100644
--- a/src/main/scala/shine/OpenCL/primitives/functional/RotateValues.scala
+++ b/src/main/scala/shine/OpenCL/primitives/functional/RotateValues.scala
@@ -1,20 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class RotateValues(a: AddressSpace,
- n: Nat,
- sz: Nat,
- dt: DataType,
- write: Phrase[ExpType ->: ExpType],
- input: Phrase[ExpType]
- ) extends ExpPrimitive {
- write :: expT(dt, read) ->: expT(dt, shine.DPIA.Types.write)
- input :: expT((n - 1 + sz)`.`dt, read)
- override val t: ExpType = expT(n`.`(sz`.`dt), read)
-}
\ No newline at end of file
+final case class RotateValues(val a: AddressSpace, val n: Nat, val sz: Nat, val dt: DataType, val wrt: Phrase[FunType[ExpType, ExpType]], val input: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ wrt :: FunType(expT(dt, read), expT(dt, write))
+ input :: expT(ArrayType(n - 1 + sz, dt), read)
+ }
+ override val t: ExpType = expT(ArrayType(n, ArrayType(sz, dt)), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): RotateValues = new RotateValues(v.addressSpace(a), v.nat(n), v.nat(sz), v.data(dt), VisitAndRebuild(wrt, v), VisitAndRebuild(input, v))
+}
diff --git a/src/main/scala/shine/OpenCL/primitives/functional/Run.scala b/src/main/scala/shine/OpenCL/primitives/functional/Run.scala
index 25e276c2d..6dc810e2c 100644
--- a/src/main/scala/shine/OpenCL/primitives/functional/Run.scala
+++ b/src/main/scala/shine/OpenCL/primitives/functional/Run.scala
@@ -1,17 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.functional
-
-import shine.DPIA._
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
-import shine.OpenCL.{GlobalSize, LocalSize}
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Run(localSize: LocalSize,
- globalSize: GlobalSize,
- dt: DataType,
- input: Phrase[ExpType]
- ) extends ExpPrimitive {
- input :: expT(dt, write)
+import shine.DPIA._
+final case class Run(localSize: shine.OpenCL.LocalSize, globalSize: shine.OpenCL.GlobalSize)(val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ input :: expT(dt, write)
+ }
override val t: ExpType = expT(dt, write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Run = new Run(localSize.visitAndRebuild(v), globalSize.visitAndRebuild(v))(v.data(dt), VisitAndRebuild(input, v))
+ def unwrap: (DataType, Phrase[ExpType]) = (dt, input)
}
diff --git a/src/main/scala/shine/OpenCL/primitives/functional/ToMem.scala b/src/main/scala/shine/OpenCL/primitives/functional/ToMem.scala
index e7cd552bc..26cd17bb2 100644
--- a/src/main/scala/shine/OpenCL/primitives/functional/ToMem.scala
+++ b/src/main/scala/shine/OpenCL/primitives/functional/ToMem.scala
@@ -1,15 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class ToMem(addrSpace: AddressSpace,
- dt: DataType,
- input: Phrase[ExpType]
- ) extends ExpPrimitive {
- input :: expT(dt, write)
+final case class ToMem(val a: AddressSpace, val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ input :: expT(dt, write)
+ }
override val t: ExpType = expT(dt, read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ToMem = new ToMem(v.addressSpace(a), v.data(dt), VisitAndRebuild(input, v))
}
diff --git a/src/main/scala/shine/OpenCL/primitives/functional/primitives.dpia b/src/main/scala/shine/OpenCL/primitives/functional/primitives.dpia
new file mode 100644
index 000000000..7f24dbfd8
--- /dev/null
+++ b/src/main/scala/shine/OpenCL/primitives/functional/primitives.dpia
@@ -0,0 +1,45 @@
+def circularBuffer(a: address, n: nat, alloc: nat, sz: nat, dt1: data, dt2: data,
+ load: exp[dt1, read] -> exp[dt2, write],
+ input: exp[(n-1+sz).dt1, read]): exp[n.sz.dt2, read]
+
+def depMap{level: shine.OpenCL.ParallelismLevel, dim: Int}
+ (n: nat, ft1: nat2data, ft2: nat2data,
+ f: (m: nat) -> exp[ft1(m), read] -> exp[ft2(m), write],
+ array: exp[n..ft1, read]): exp[n..ft2, write]
+
+def iterate(a: address, n: nat, m: nat, k: nat, dt: data,
+ f: (l: nat) -> exp[(l*n).dt, read] -> exp[l.dt, write],
+ array: exp[(m*(n^k)).dt, read]): exp[m.dt, write]
+
+def kernelCall{name: String,
+ localSize: shine.OpenCL.LocalSize,
+ globalSize: shine.OpenCL.GlobalSize,
+ inTs: Seq[DataType],
+ args: Seq[Phrase[ExpType]]}
+ (outT: data): exp[outT, write]
+
+def map{level: shine.OpenCL.ParallelismLevel, dim: Int}
+ (n: nat, dt1: data, dt2: data,
+ f: exp[dt1, read] -> exp[dt2, write],
+ array: exp[n.dt1, read]): exp[n.dt2, write]
+
+def openCLFunctionCall{name: String,
+ inTs: Seq[DataType],
+ args: Seq[Phrase[ExpType]]}
+ (outT: data): exp[outT, write]
+
+def reduceSeq{unroll: Boolean}
+ (n: nat, a: address, dt1: data, dt2: data,
+ f: exp[dt2, read] -> exp[dt1, read] -> exp[dt2, write],
+ init: exp[dt2, write],
+ array: exp[n.dt1, read]): exp[dt2, read]
+
+def rotateValues(a: address, n: nat, sz: nat, dt: data,
+ wrt: exp[dt, read] -> exp[dt, write],
+ input: exp[(n-1+sz).dt, read]): exp[n.sz.dt, read]
+
+def run{localSize: shine.OpenCL.LocalSize,
+ globalSize: shine.OpenCL.GlobalSize}
+ (dt: data, input: exp[dt, write]): exp[dt, write]
+
+def toMem(a: address, dt: data, input: exp[dt, write]): exp[dt, read]
diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/Barrier.scala b/src/main/scala/shine/OpenCL/primitives/imperative/Barrier.scala
index 9f0520381..cf222f8b7 100644
--- a/src/main/scala/shine/OpenCL/primitives/imperative/Barrier.scala
+++ b/src/main/scala/shine/OpenCL/primitives/imperative/Barrier.scala
@@ -1,11 +1,14 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.imperative
-
-import shine.DPIA.Phrases.CommandPrimitive
-import shine.macros.Primitive.comPrimitive
-
-@comPrimitive
-final case class Barrier(local: Boolean, global: Boolean) extends CommandPrimitive {
- override def prettyPrint: String =
- s"""barrier( ${if(local) "CLK_LOCAL_MEM_FENCE" else ""} ${if(global && local) "|" else ""}
- ${if(global) "CLK_GLOBAL_MEM_FENCE" else ""})"""
+import arithexpr.arithmetic._
+import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
+import shine.DPIA.Types._
+import shine.DPIA._
+final case class Barrier(local: Boolean, global: Boolean)() extends CommandPrimitive {
+ {}
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Barrier = new Barrier(local, global)()
}
diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/HostExecution.scala b/src/main/scala/shine/OpenCL/primitives/imperative/HostExecution.scala
index a3054141c..54f977635 100644
--- a/src/main/scala/shine/OpenCL/primitives/imperative/HostExecution.scala
+++ b/src/main/scala/shine/OpenCL/primitives/imperative/HostExecution.scala
@@ -1,20 +1,19 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
-import shine.OpenCL.AccessFlags
-import shine.macros.Primitive.comPrimitive
-
-// In a host program which contains host managed buffers,
-// A host execution is a section where only plain arrays are used
-@comPrimitive
-final case class HostExecution(params: Map[Identifier[_ <: PhraseType], AccessFlags],
- body: Phrase[CommType]) extends CommandPrimitive {
- body :: comm
-
- override def visitAndRebuild(f: VisitAndRebuild.Visitor): Phrase[CommType] =
- HostExecution(
- params.map({ case (k, v) =>
- VisitAndRebuild(k, f).asInstanceOf[Identifier[_ <: PhraseType]] -> v }),
- VisitAndRebuild(body, f))
-}
\ No newline at end of file
+import shine.DPIA._
+final case class HostExecution(params: Map[Identifier[_ <: PhraseType], shine.OpenCL.AccessFlags])(val body: Phrase[CommType]) extends CommandPrimitive {
+ {
+ body :: comm
+ }
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): HostExecution = new HostExecution(params.map({
+ case (key, value) =>
+ VisitAndRebuild(key, v).asInstanceOf[Identifier[_ <: PhraseType]] -> value
+ }))(VisitAndRebuild(body, v))
+}
diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistribute.scala b/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistribute.scala
index 09c65e30f..c046d1597 100644
--- a/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistribute.scala
+++ b/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistribute.scala
@@ -1,20 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.OpenCL.ParallelismLevel
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class IdxDistribute(m: Nat,
- n: Nat,
- stride: Nat,
- parallelismLevel: ParallelismLevel,
- dt: DataType,
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- array :: expT(m`.`dt, read)
- override val t: ExpType = expT(n`.`dt, read)
+final case class IdxDistribute(parallelismLevel: shine.OpenCL.ParallelismLevel)(val m: Nat, val n: Nat, val stride: Nat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ array :: expT(ArrayType(m, dt), read)
+ }
+ override val t: ExpType = expT(ArrayType(n, dt), read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): IdxDistribute = new IdxDistribute(parallelismLevel)(v.nat(m), v.nat(n), v.nat(stride), v.data(dt), VisitAndRebuild(array, v))
+ def unwrap: (Nat, Nat, Nat, DataType, Phrase[ExpType]) = (m, n, stride, dt, array)
}
diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistributeAcc.scala b/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistributeAcc.scala
index 0e7de8392..0de4a9e08 100644
--- a/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistributeAcc.scala
+++ b/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistributeAcc.scala
@@ -1,20 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
-import shine.DPIA.{Nat, _}
-import shine.OpenCL.ParallelismLevel
-import shine.macros.Primitive.accPrimitive
-
-@accPrimitive
-final case class IdxDistributeAcc(m: Nat,
- n: Nat,
- stride: Nat,
- parallelismLevel: ParallelismLevel,
- dt: DataType,
- array: Phrase[AccType]
- ) extends AccPrimitive {
- array :: accT(m`.`dt)
- override val t: AccType = accT(n`.`dt)
+import shine.DPIA._
+final case class IdxDistributeAcc(parallelismLevel: shine.OpenCL.ParallelismLevel)(val m: Nat, val n: Nat, val stride: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive {
+ {
+ array :: accT(ArrayType(m, dt))
+ }
+ override val t: AccType = accT(ArrayType(n, dt))
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): IdxDistributeAcc = new IdxDistributeAcc(parallelismLevel)(v.nat(m), v.nat(n), v.nat(stride), v.data(dt), VisitAndRebuild(array, v))
+ def unwrap: (Nat, Nat, Nat, DataType, Phrase[AccType]) = (m, n, stride, dt, array)
}
diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/KernelCallCmd.scala b/src/main/scala/shine/OpenCL/primitives/imperative/KernelCallCmd.scala
index 4864c7b20..2e420b356 100644
--- a/src/main/scala/shine/OpenCL/primitives/imperative/KernelCallCmd.scala
+++ b/src/main/scala/shine/OpenCL/primitives/imperative/KernelCallCmd.scala
@@ -1,14 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
-import shine.OpenCL.{GlobalSize, LocalSize}
-import shine.macros.Primitive.comPrimitive
-
-@comPrimitive
-case class KernelCallCmd(name: String,
- localSize: LocalSize,
- globalSize: GlobalSize,
- output: Phrase[AccType],
- args: Seq[Phrase[ExpType]]) extends CommandPrimitive
-
+import shine.DPIA._
+final case class KernelCallCmd(name: String, localSize: shine.OpenCL.LocalSize, globalSize: shine.OpenCL.GlobalSize, args: Seq[Phrase[ExpType]])(val dt: DataType, val output: Phrase[AccType]) extends CommandPrimitive {
+ {
+ output :: accT(dt)
+ }
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): KernelCallCmd = new KernelCallCmd(name, localSize.visitAndRebuild(v), globalSize.visitAndRebuild(v), args.map(VisitAndRebuild(_, v)))(v.data(dt), VisitAndRebuild(output, v))
+ def unwrap: (DataType, Phrase[AccType]) = (dt, output)
+}
diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/New.scala b/src/main/scala/shine/OpenCL/primitives/imperative/New.scala
index 4af7b56f6..4b455cb27 100644
--- a/src/main/scala/shine/OpenCL/primitives/imperative/New.scala
+++ b/src/main/scala/shine/OpenCL/primitives/imperative/New.scala
@@ -1,14 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.comPrimitive
-
-@comPrimitive
-final case class New(a: AddressSpace,
- dt: DataType,
- f: Phrase[VarType ->: CommType]
- ) extends CommandPrimitive {
- f :: varT(dt) ->: comm
+final case class New(val a: AddressSpace, val dt: DataType, val f: Phrase[FunType[PhrasePairType[ExpType, AccType], CommType]]) extends CommandPrimitive {
+ {
+ f :: FunType(PhrasePairType(expT(dt, read), accT(dt)), comm)
+ }
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): New = new New(v.addressSpace(a), v.data(dt), VisitAndRebuild(f, v))
}
diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/NewDoubleBuffer.scala b/src/main/scala/shine/OpenCL/primitives/imperative/NewDoubleBuffer.scala
index b96f98d74..9e50f0b6a 100644
--- a/src/main/scala/shine/OpenCL/primitives/imperative/NewDoubleBuffer.scala
+++ b/src/main/scala/shine/OpenCL/primitives/imperative/NewDoubleBuffer.scala
@@ -1,22 +1,18 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.imperative
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.comPrimitive
-
-@comPrimitive
-final case class NewDoubleBuffer(a: AddressSpace,
- dt1: DataType,
- dt2: DataType,
- dt3: DataType,
- n: Nat,
- in: Phrase[ExpType],
- out: Phrase[AccType],
- f: Phrase[(ExpType x AccType x CommType x CommType) ->: CommType]
- ) extends CommandPrimitive {
- in :: expT(dt1, read)
- out :: accT(dt2)
- f :: (((varT(n`.`dt3) x comm) x comm) ->: comm)
+final case class NewDoubleBuffer(val a: AddressSpace, val dt1: DataType, val dt2: DataType, val dt3: DataType, val n: Nat, val in: Phrase[ExpType], val out: Phrase[AccType], val f: Phrase[FunType[PhrasePairType[PhrasePairType[PhrasePairType[ExpType, AccType], CommType], CommType], CommType]]) extends CommandPrimitive {
+ {
+ in :: expT(dt1, read)
+ out :: accT(dt2)
+ f :: FunType(PhrasePairType(PhrasePairType(PhrasePairType(expT(ArrayType(n, dt3), read), accT(ArrayType(n, dt3))), comm), comm), comm)
+ }
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): NewDoubleBuffer = new NewDoubleBuffer(v.addressSpace(a), v.data(dt1), v.data(dt2), v.data(dt3), v.nat(n), VisitAndRebuild(in, v), VisitAndRebuild(out, v), VisitAndRebuild(f, v))
}
diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/NewManagedBuffer.scala b/src/main/scala/shine/OpenCL/primitives/imperative/NewManagedBuffer.scala
index bfdc31ede..5c439ae17 100644
--- a/src/main/scala/shine/OpenCL/primitives/imperative/NewManagedBuffer.scala
+++ b/src/main/scala/shine/OpenCL/primitives/imperative/NewManagedBuffer.scala
@@ -1,14 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.imperative
-
-import shine.DPIA._
-import shine.DPIA.Types._
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
-import shine.OpenCL.AccessFlags
-import shine.macros.Primitive.comPrimitive
-
-@comPrimitive
-final case class NewManagedBuffer(dt: DataType,
- access: AccessFlags,
- k: Phrase[VarType ->: CommType]) extends CommandPrimitive {
- k :: varT(ManagedBufferType(dt)) ->: comm
+import shine.DPIA.Types.DataType._
+import shine.DPIA.Types._
+import shine.DPIA._
+final case class NewManagedBuffer(access: shine.OpenCL.AccessFlags)(val dt: DataType, val k: Phrase[FunType[PhrasePairType[ExpType, AccType], CommType]]) extends CommandPrimitive {
+ {
+ k :: FunType(PhrasePairType(expT(ManagedBufferType(dt), read), accT(ManagedBufferType(dt))), comm)
+ }
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): NewManagedBuffer = new NewManagedBuffer(access)(v.data(dt), VisitAndRebuild(k, v))
+ def unwrap: (DataType, Phrase[FunType[PhrasePairType[ExpType, AccType], CommType]]) = (dt, k)
}
diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/ParFor.scala b/src/main/scala/shine/OpenCL/primitives/imperative/ParFor.scala
index 52b1759a5..47efa67bb 100644
--- a/src/main/scala/shine/OpenCL/primitives/imperative/ParFor.scala
+++ b/src/main/scala/shine/OpenCL/primitives/imperative/ParFor.scala
@@ -1,53 +1,18 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.imperative
-
-import shine.DPIA.Phrases.{CommandPrimitive, _}
+import arithexpr.arithmetic._
+import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.OpenCL._
-import shine.macros.Primitive.comPrimitive
-
-@comPrimitive
-final case class ParFor(level: ParallelismLevel,
- dim: Int,
- unroll: Boolean)
- (val n: Nat,
- val dt: DataType,
- val out: Phrase[AccType],
- val loopBody: Phrase[ExpType ->: AccType ->: CommType],
- val init: Nat = ParFor.initInit(level, dim),
- val step: Nat = ParFor.initStep(level, dim)
- ) extends CommandPrimitive {
- val name: String = ParFor.initName(level)
-
- out :: accT(n`.`dt)
- loopBody :: expT(idx(n), read) ->: accT(dt) ->: comm
-
- lazy val unwrapBody: (Identifier[ExpType], Identifier[AccType], Phrase[CommType]) = loopBody match {
- case Lambda(i, Lambda(o, body)) => (i, o, body)
- case _ => throw new Exception("This should not happen")
- }
-}
-
-object ParFor {
- def initInit(level: ParallelismLevel, dim: Int): Nat = level match {
- case Global => get_global_id(dim)
- case Local => get_local_id (dim)
- case WorkGroup => get_group_id (dim)
- case Sequential | Warp | Lane => throw new Exception("This should not happen")
- }
-
- def initStep(level: ParallelismLevel, dim: Int): Nat = level match {
- case Global => get_global_size(dim)
- case Local => get_local_size (dim)
- case WorkGroup => get_num_groups (dim)
- case Sequential | Warp | Lane => throw new Exception("This should not happen")
- }
-
- def initName(level: ParallelismLevel): String = level match {
- case Global => freshName("gl_id_")
- case Local => freshName( "l_id_")
- case WorkGroup => freshName("wg_id_")
- case Sequential | Warp | Lane => throw new Exception("This should not happen")
+final case class ParFor(level: shine.OpenCL.ParallelismLevel, dim: Int, unroll: Boolean, prefix: String)(val init: Nat, val n: Nat, val step: Nat, val dt: DataType, val out: Phrase[AccType], val body: Phrase[FunType[ExpType, FunType[AccType, CommType]]]) extends CommandPrimitive {
+ {
+ out :: accT(ArrayType(n, dt))
+ body :: FunType(expT(IndexType(n), read), FunType(accT(dt), comm))
}
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ParFor = new ParFor(level, dim, unroll, prefix)(v.nat(init), v.nat(n), v.nat(step), v.data(dt), VisitAndRebuild(out, v), VisitAndRebuild(body, v))
+ def unwrap: (Nat, Nat, Nat, DataType, Phrase[AccType], Phrase[FunType[ExpType, FunType[AccType, CommType]]]) = (init, n, step, dt, out, body)
}
diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/ParForNat.scala b/src/main/scala/shine/OpenCL/primitives/imperative/ParForNat.scala
index a7e94c82b..bd45c3fd0 100644
--- a/src/main/scala/shine/OpenCL/primitives/imperative/ParForNat.scala
+++ b/src/main/scala/shine/OpenCL/primitives/imperative/ParForNat.scala
@@ -1,52 +1,21 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenCL.primitives.imperative
-
-import shine.DPIA.Phrases.{CommandPrimitive, _}
+import arithexpr.arithmetic._
+import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.OpenCL._
-import shine.macros.Primitive.comPrimitive
-
-@comPrimitive
-final case class ParForNat(level: ParallelismLevel,
- dim: Int,
- unroll: Boolean)
- (val n: Nat,
- val ft: NatToData,
- val out: Phrase[AccType],
- val loopBody: Phrase[`(nat)->:`[AccType ->: CommType]],
- val init: Nat = ParForNat.initInit(level, dim),
- val step: Nat = ParForNat.initStep(level, dim),
- val name: String = ParForNat.initName(level)
- ) extends CommandPrimitive {
- out :: accT(n`.d`ft)
- loopBody :: loopBody.t.x ->: accT(ft(loopBody.t.x)) ->: comm
-
- lazy val unwrapBody: (NatIdentifier, Identifier[AccType], Phrase[CommType]) = loopBody match {
- case DepLambda(i: NatIdentifier, Lambda(o, body)) => (i, o, body)
- case _ => throw new Exception("This should not happen")
- }
-}
-
-object ParForNat {
- def initInit(level: ParallelismLevel, dim: Int): Nat = level match {
- case Global => get_global_id(dim)
- case Local => get_local_id(dim)
- case WorkGroup => get_group_id(dim)
- case Sequential | Warp | Lane => throw new Exception("This should not happen")
- }
-
- def initStep(level: ParallelismLevel, dim: Int): Nat = level match {
- case Global => get_global_size(dim)
- case Local => get_local_size(dim)
- case WorkGroup => get_num_groups(dim)
- case Sequential | Warp | Lane => throw new Exception("This should not happen")
- }
-
- def initName(level: ParallelismLevel): String = level match {
- case Global => freshName("gl_id_")
- case Local => freshName("l_id_")
- case WorkGroup => freshName("wg_id_")
- case Sequential | Warp | Lane => throw new Exception("This should not happen")
+final case class ParForNat(level: shine.OpenCL.ParallelismLevel, dim: Int, unroll: Boolean, prefix: String)(val init: Nat, val n: Nat, val step: Nat, val ft: NatToData, val out: Phrase[AccType], val body: Phrase[DepFunType[NatKind, FunType[AccType, CommType]]]) extends CommandPrimitive {
+ {
+ out :: accT(DepArrayType(n, ft))
+ body :: ({
+ val i = body.t.x
+ DepFunType[NatKind, PhraseType](i, FunType(accT(NatToDataApply(ft, i)), comm))
+ })
}
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ParForNat = new ParForNat(level, dim, unroll, prefix)(v.nat(init), v.nat(n), v.nat(step), v.natToData(ft), VisitAndRebuild(out, v), VisitAndRebuild(body, v))
+ def unwrap: (Nat, Nat, Nat, NatToData, Phrase[AccType], Phrase[DepFunType[NatKind, FunType[AccType, CommType]]]) = (init, n, step, ft, out, body)
}
diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/primitives.dpia b/src/main/scala/shine/OpenCL/primitives/imperative/primitives.dpia
new file mode 100644
index 000000000..5b78dbaa0
--- /dev/null
+++ b/src/main/scala/shine/OpenCL/primitives/imperative/primitives.dpia
@@ -0,0 +1,45 @@
+def barrier{local: Boolean, global: Boolean}(): comm
+
+// In a host program which contains host managed buffers,
+// A host execution is a section where only plain arrays are used
+def hostExecution{params: Map[Identifier[_ <: PhraseType], shine.OpenCL.AccessFlags]}
+ (body: comm): comm
+
+def idxDistribute{parallelismLevel: shine.OpenCL.ParallelismLevel}
+ (m: nat, n: nat, stride: nat, dt: data, array: exp[m.dt, read]): exp[n.dt, read]
+
+def idxDistributeAcc{parallelismLevel: shine.OpenCL.ParallelismLevel}
+ (m: nat, n: nat, stride: nat, dt: data, array: acc[m.dt]): acc[n.dt]
+
+def kernelCallCmd{name: String,
+ localSize: shine.OpenCL.LocalSize,
+ globalSize: shine.OpenCL.GlobalSize,
+ args: Seq[Phrase[ExpType]]}
+ (dt: data, output: acc[dt]): comm
+
+def new(a: address, dt: data, f: var[dt] -> comm): comm
+
+def newDoubleBuffer(a: address, dt1: data, dt2: data, dt3: data, n: nat,
+ in: exp[dt1, read], out: acc[dt2],
+ f: ((var[n.dt3], comm), comm) -> comm): comm
+
+def newManagedBuffer{access: shine.OpenCL.AccessFlags}
+ (dt: data, k: var[managed[dt]] -> comm): comm
+
+def parFor{level: shine.OpenCL.ParallelismLevel,
+ dim: Int,
+ unroll: Boolean,
+ prefix: String}
+ (init: nat, n: nat, step: nat,
+ dt: data,
+ out: acc[n.dt],
+ body: exp[idx[n], read] -> acc[dt] -> comm): comm
+
+def parForNat{level: shine.OpenCL.ParallelismLevel,
+ dim: Int,
+ unroll: Boolean,
+ prefix: String}
+ (init: nat, n: nat, step: nat,
+ ft: nat2data,
+ out: acc[n..ft],
+ body: (i: nat) -> acc[ft(i)] -> comm): comm
diff --git a/src/main/scala/shine/OpenCL/primitives/intermediate/MapI.scala b/src/main/scala/shine/OpenCL/primitives/intermediate/MapI.scala
index 2489221dc..0618afeec 100644
--- a/src/main/scala/shine/OpenCL/primitives/intermediate/MapI.scala
+++ b/src/main/scala/shine/OpenCL/primitives/intermediate/MapI.scala
@@ -6,7 +6,6 @@ import shine.DPIA.Types.DataType.idx
import shine.DPIA.Types._
import shine.DPIA._
import shine.OpenCL._
-import shine.OpenCL.primitives.imperative.ParFor
final case class MapI(level: ParallelismLevel, dim: Int) {
def apply(n: Nat, dt1: DataType, dt2: DataType,
@@ -14,7 +13,7 @@ final case class MapI(level: ParallelismLevel, dim: Int) {
in: Phrase[ExpType],
out: Phrase[AccType]): Phrase[CommType] = {
comment(s"map${level.toString}") `;`
- ParFor(level, dim, unroll = false)(n, dt2, out,
+ shine.OpenCL.DSL.parFor(level, dim, unroll = false)(n, dt2, out,
λ(expT(idx(n), read))(i => λ(accT(dt2))(a => f(in `@` i)(a))))
}
}
diff --git a/src/main/scala/shine/OpenMP/CodeGenerator.scala b/src/main/scala/shine/OpenMP/CodeGenerator.scala
index ff9271f18..e567531c9 100644
--- a/src/main/scala/shine/OpenMP/CodeGenerator.scala
+++ b/src/main/scala/shine/OpenMP/CodeGenerator.scala
@@ -46,19 +46,19 @@ class CodeGenerator(override val decls: CCodeGenerator.Declarations,
override def acc(env: Environment,
path: Path,
cont: Expr => Stmt): Phrase[AccType] => Stmt = {
- case AsVectorAcc(n, _, _, a) => path match {
- case (i: CIntExpr) :: ps => a |> acc(env, CIntExpr(i / n) :: ps, cont)
+ case AsVectorAcc(_, m, _, a) => path match {
+ case (i: CIntExpr) :: ps => a |> acc(env, CIntExpr(i / m) :: ps, cont)
case _ => error(s"Expected path to be not empty")
}
- case AsScalarAcc(_, m, dt, a) => path match {
+ case AsScalarAcc(n, _, dt, a) => path match {
case (i: CIntExpr) :: (j: CIntExpr) :: ps =>
- a |> acc(env, CIntExpr((i * m) + j) :: ps, cont)
+ a |> acc(env, CIntExpr((i * n) + j) :: ps, cont)
case (i: CIntExpr) :: Nil =>
- a |> acc(env, CIntExpr(i * m) :: Nil, {
+ a |> acc(env, CIntExpr(i * n) :: Nil, {
case ArraySubscript(v, idx) =>
// emit something like: ((struct float4 *)v)[idx]
- val ptrType = C.AST.PointerType(typ(VectorType(m, dt)))
+ val ptrType = C.AST.PointerType(typ(VectorType(n, dt)))
cont(C.AST.ArraySubscript(C.AST.Cast(ptrType, v), idx))
})
case _ => error(s"Expected path to be not empty")
@@ -93,9 +93,9 @@ class CodeGenerator(override val decls: CCodeGenerator.Declarations,
}
case _ => phrase |> super.exp(env, path, cont)
}
- case ForeignFunctionCall(f, inTs, outT, args) =>
- OpenMPCodeGen.codeGenForeignFunctionCall(f, inTs, outT, args, env, path, cont)
- case AsVectorAligned(n, _, _, dt, e) => path match {
+ case ffc@ForeignFunctionCall(f, inTs, args) =>
+ OpenMPCodeGen.codeGenForeignFunctionCall(f, inTs, ffc.outT, args, env, path, cont)
+ case AsVectorAligned(n, _, dt, _, e) => path match {
case (i: CIntExpr) :: (j: CIntExpr) :: ps =>
e |> exp(env, CIntExpr((i * n) + j) :: ps, cont)
@@ -277,7 +277,7 @@ class CodeGenerator(override val decls: CCodeGenerator.Declarations,
}
}
- def codeGenForeignFunctionCall(funDecl: ForeignFunction.Declaration,
+ def codeGenForeignFunctionCall(funDecl: rise.core.ForeignFunction.Decl,
inTs: collection.Seq[DataType],
outT: DataType,
args: collection.Seq[Phrase[ExpType]],
diff --git a/src/main/scala/shine/OpenMP/DSL/ImperativePrimitives.scala b/src/main/scala/shine/OpenMP/DSL/ImperativePrimitives.scala
index 9b21e9972..9fb9b594e 100644
--- a/src/main/scala/shine/OpenMP/DSL/ImperativePrimitives.scala
+++ b/src/main/scala/shine/OpenMP/DSL/ImperativePrimitives.scala
@@ -19,7 +19,7 @@ object parFor {
object `parForVec` {
def apply(n: Nat,
- st: ScalarType,
+ st: DataType,
out: Phrase[AccType],
f: Phrase[ExpType] => Phrase[AccType] => Phrase[CommType]): ForVec =
ForVec(n, st, out, λ(expT(idx(n), read))(i => λ(accT(st))(o => f(i)(o) )))
diff --git a/src/main/scala/shine/OpenMP/primitives/functional/DepMapPar.scala b/src/main/scala/shine/OpenMP/primitives/functional/DepMapPar.scala
index e6f2afea6..9035e541e 100644
--- a/src/main/scala/shine/OpenMP/primitives/functional/DepMapPar.scala
+++ b/src/main/scala/shine/OpenMP/primitives/functional/DepMapPar.scala
@@ -1,19 +1,20 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenMP.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class DepMapPar(n: Nat,
- ft1: NatToData,
- ft2: NatToData,
- f: Phrase[`(nat)->:`[ExpType ->: ExpType]],
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- f :: f.t.x ->: expT(ft1(f.t.x), read) ->: expT(ft2(f.t.x), write)
- array :: expT(n `.d` ft1, read)
- override val t: ExpType = expT(n`.d`ft2, write)
+final case class DepMapPar(val n: Nat, val ft1: NatToData, val ft2: NatToData, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: ({
+ val m = f.t.x
+ DepFunType[NatKind, PhraseType](m, FunType(expT(NatToDataApply(ft1, m), read), expT(NatToDataApply(ft2, m), write)))
+ })
+ array :: expT(DepArrayType(n, ft1), read)
+ }
+ override val t: ExpType = expT(DepArrayType(n, ft2), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepMapPar = new DepMapPar(v.nat(n), v.natToData(ft1), v.natToData(ft2), VisitAndRebuild(f, v), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/OpenMP/primitives/functional/MapPar.scala b/src/main/scala/shine/OpenMP/primitives/functional/MapPar.scala
index c84edc469..1b622ffd4 100644
--- a/src/main/scala/shine/OpenMP/primitives/functional/MapPar.scala
+++ b/src/main/scala/shine/OpenMP/primitives/functional/MapPar.scala
@@ -1,19 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenMP.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class MapPar(n: Nat,
- dt1: DataType,
- dt2: DataType,
- f: Phrase[ExpType ->: ExpType],
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- f :: expT(dt1, read) ->: expT(dt2, write)
- array :: expT(n`.`dt1, read)
- override val t: ExpType = expT(n`.`dt2, write)
+final case class MapPar(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: FunType(expT(dt1, read), expT(dt2, write))
+ array :: expT(ArrayType(n, dt1), read)
+ }
+ override val t: ExpType = expT(ArrayType(n, dt2), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapPar = new MapPar(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/OpenMP/primitives/functional/ReducePar.scala b/src/main/scala/shine/OpenMP/primitives/functional/ReducePar.scala
index 37f148228..9cc4bf236 100644
--- a/src/main/scala/shine/OpenMP/primitives/functional/ReducePar.scala
+++ b/src/main/scala/shine/OpenMP/primitives/functional/ReducePar.scala
@@ -1,20 +1,18 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenMP.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class ReducePar(n: Nat,
- dt1: DataType, dt2: DataType,
- f: Phrase[ExpType ->: ExpType ->: ExpType],
- init: Phrase[ExpType],
- array: Phrase[ExpType]
- ) extends ExpPrimitive {
- f :: expT(dt2, read) ->: expT(dt1, read) ->: expT(dt2, write)
- init :: expT(dt2, write)
- array :: expT(n`.`dt1, read)
+final case class ReducePar(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, FunType[ExpType, ExpType]]], val init: Phrase[ExpType], val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: FunType(expT(dt2, read), FunType(expT(dt1, read), expT(dt2, write)))
+ init :: expT(dt2, write)
+ array :: expT(ArrayType(n, dt1), read)
+ }
override val t: ExpType = expT(dt2, read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ReducePar = new ReducePar(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(init, v), VisitAndRebuild(array, v))
}
diff --git a/src/main/scala/shine/OpenMP/primitives/functional/primitives.dpia b/src/main/scala/shine/OpenMP/primitives/functional/primitives.dpia
new file mode 100644
index 000000000..ffd7dde02
--- /dev/null
+++ b/src/main/scala/shine/OpenMP/primitives/functional/primitives.dpia
@@ -0,0 +1,12 @@
+def depMapPar(n: nat, ft1: nat2data, ft2: nat2data,
+ f: (m: nat) -> exp[ft1(m), read] -> exp[ft2(m), write],
+ array: exp[n..ft1, read]): exp[n..ft2, write]
+
+def mapPar(n: nat, dt1: data, dt2: data,
+ f: exp[dt1, read] -> exp[dt2, write],
+ array: exp[n.dt1, read]): exp[n.dt2, write]
+
+def reducePar(n: nat, dt1: data, dt2: data,
+ f: exp[dt2, read] -> exp[dt1, read] -> exp[dt2, write],
+ init: exp[dt2, write],
+ array: exp[n.dt1, read]): exp[dt2, read]
\ No newline at end of file
diff --git a/src/main/scala/shine/OpenMP/primitives/imperative/ParFor.scala b/src/main/scala/shine/OpenMP/primitives/imperative/ParFor.scala
index 3c1a0006b..d467a4c29 100644
--- a/src/main/scala/shine/OpenMP/primitives/imperative/ParFor.scala
+++ b/src/main/scala/shine/OpenMP/primitives/imperative/ParFor.scala
@@ -1,22 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenMP.primitives.imperative
-
-import shine.DPIA.Phrases.{CommandPrimitive, Phrase, _}
+import arithexpr.arithmetic._
+import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.comPrimitive
-
-@comPrimitive
-final case class ParFor(n: Nat,
- dt: DataType,
- out: Phrase[AccType],
- body: Phrase[ExpType ->: AccType ->: CommType]
- ) extends CommandPrimitive {
- out :: accT(n`.`dt)
- body :: expT(idx(n), read) ->: accT(dt) ->: comm
-
- lazy val unwrapBody: (Identifier[ExpType], Identifier[AccType], Phrase[CommType]) = body match {
- case Lambda(i, Lambda(o, body)) => (i, o, body)
- case _ => throw new Exception("This should not happen")
+final case class ParFor(val n: Nat, val dt: DataType, val out: Phrase[AccType], val body: Phrase[FunType[ExpType, FunType[AccType, CommType]]]) extends CommandPrimitive {
+ {
+ out :: accT(ArrayType(n, dt))
+ body :: FunType(expT(IndexType(n), read), FunType(accT(dt), comm))
}
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ParFor = new ParFor(v.nat(n), v.data(dt), VisitAndRebuild(out, v), VisitAndRebuild(body, v))
}
diff --git a/src/main/scala/shine/OpenMP/primitives/imperative/ParForNat.scala b/src/main/scala/shine/OpenMP/primitives/imperative/ParForNat.scala
index fc6eed716..92b4f87b3 100644
--- a/src/main/scala/shine/OpenMP/primitives/imperative/ParForNat.scala
+++ b/src/main/scala/shine/OpenMP/primitives/imperative/ParForNat.scala
@@ -1,22 +1,20 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.OpenMP.primitives.imperative
-
-import shine.DPIA.Phrases.{CommandPrimitive, Phrase, _}
+import arithexpr.arithmetic._
+import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.comPrimitive
-
-@comPrimitive
-final case class ParForNat(n: Nat,
- ft: NatToData,
- out: Phrase[AccType],
- body: Phrase[`(nat)->:`[AccType ->: CommType]]
- ) extends CommandPrimitive {
- out :: accT(n`.d`ft)
- body :: body.t.x ->: accT(ft(body.t.x)) ->: comm
-
- lazy val unwrapBody: (NatIdentifier, Identifier[AccType], Phrase[CommType]) = body match {
- case DepLambda(n, Lambda(o, body)) => (n, o, body)
- case _ => throw new Exception("This should not happen")
+final case class ParForNat(val n: Nat, val ft: NatToData, val out: Phrase[AccType], val body: Phrase[DepFunType[NatKind, FunType[AccType, CommType]]]) extends CommandPrimitive {
+ {
+ out :: accT(DepArrayType(n, ft))
+ body :: ({
+ val i = body.t.x
+ DepFunType[NatKind, PhraseType](i, FunType(accT(NatToDataApply(ft, i)), comm))
+ })
}
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ParForNat = new ParForNat(v.nat(n), v.natToData(ft), VisitAndRebuild(out, v), VisitAndRebuild(body, v))
}
diff --git a/src/main/scala/shine/OpenMP/primitives/imperative/primitives.dpia b/src/main/scala/shine/OpenMP/primitives/imperative/primitives.dpia
new file mode 100644
index 000000000..4ceca55c9
--- /dev/null
+++ b/src/main/scala/shine/OpenMP/primitives/imperative/primitives.dpia
@@ -0,0 +1,7 @@
+def parFor(n: nat, dt: data,
+ out: acc[n.dt],
+ body: exp[idx[n], read] -> acc[dt] -> comm): comm
+
+def parForNat(n: nat, ft: nat2data,
+ out: acc[n..ft],
+ body: (i: nat) -> acc[ft(i)] -> comm): comm
\ No newline at end of file
diff --git a/src/main/scala/shine/cuda/AST/Types.scala b/src/main/scala/shine/cuda/AST/Types.scala
index eb683847b..34c41d417 100644
--- a/src/main/scala/shine/cuda/AST/Types.scala
+++ b/src/main/scala/shine/cuda/AST/Types.scala
@@ -10,8 +10,8 @@ object Wmma {
case MatrixLayout.Row_Major => "nvcuda::wmma::row_major"
case MatrixLayout.Col_Major => "nvcuda::wmma::col_major"
case i: MatrixLayoutIdentifier =>
- if (i.layout.isDefined)
- toString(i.layout.get)
+ if (i.layout != MatrixLayout.None)
+ toString(i.layout)
else
throw new Exception(s"layout $i not infered!")
case _ => throw new Exception("this should not happen")
diff --git a/src/main/scala/shine/cuda/Compilation/KernelCodeGenerator.scala b/src/main/scala/shine/cuda/Compilation/KernelCodeGenerator.scala
index 28b37898d..9f9b76eff 100644
--- a/src/main/scala/shine/cuda/Compilation/KernelCodeGenerator.scala
+++ b/src/main/scala/shine/cuda/Compilation/KernelCodeGenerator.scala
@@ -48,8 +48,11 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations,
override def cmd(env: Environment): Phrase[CommType] => Stmt = {
case f: ParFor =>
- val (i, o, p) = f.unwrapBody
- codeGenCudaParFor(f, f.n, f.dt, f.out, i, o, p, env)
+ f.body match {
+ case Lambda(i, Lambda(o, p)) =>
+ codeGenCudaParFor(f, f.n, f.dt, f.out, i, o, p, env)
+ case _ => throw new Exception("This should not happen")
+ }
case WmmaLoad(m, n, k, _, fragType, layoutIdentifier, matrix, fragmentAcc) =>
matrix |> exp(env, List(CIntExpr(0), CIntExpr(0)), matrixTile => {
@@ -92,7 +95,7 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations,
C.AST.ArithmeticExpr(ldm),
C.AST.DeclRef(toString(layout)))))}))
- case WmmaFill(_, _, _, _, fill, _, _, fragmentAcc) =>
+ case WmmaFill(_, _, _, _, _, _, fill, fragmentAcc) =>
fill |> exp(env, Nil, fill =>
fragmentAcc |> acc(env, Nil, fragment =>
C.AST.ExprStmt(C.AST.FunCall(
@@ -114,7 +117,7 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations,
bMatrix,
cMatrix)))))))
- case ForFragmentElements(fragType, inFragment, outFragmemt, Lambda(in, Lambda(out, p))) =>
+ case ForFragment(_, _, _, dt, _, _, inFragment, outFragmemt, Lambda(in, Lambda(out, p))) =>
inFragment |> exp(env, Nil, fragmentIn =>
outFragmemt |> acc(env, Nil, fragmentOut => {
val n = C.AST.StructMemberAccess(fragmentOut, C.AST.DeclRef("num_elements"))
@@ -142,8 +145,8 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations,
C.AST.ForLoop(C.AST.DeclStmt(init), cond, increment,
C.AST.Block(
immutable.Seq(
- C.AST.DeclStmt(C.AST.VarDecl(xIInPointer.name, PointerType(typ(fragType.dataType)), init = Some(xIInDecl))),
- C.AST.DeclStmt(C.AST.VarDecl(xIOutPointer.name, PointerType(typ(fragType.dataType)), init = Some(xIOutDecl))),
+ C.AST.DeclStmt(C.AST.VarDecl(xIInPointer.name, PointerType(typ(dt)), init = Some(xIInDecl))),
+ C.AST.DeclStmt(C.AST.VarDecl(xIOutPointer.name, PointerType(typ(dt)), init = Some(xIOutDecl))),
p |> updatedGen.cmd(env updatedIdentEnv (in -> xIIn) updatedIdentEnv (out -> xIOut))))))}))
case Assign(_, lhsAcc, rhs) =>
@@ -230,18 +233,18 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations,
case _ => error(s"Expected path to be not empty")
}
- case AsScalarAcc(_, m, dt, a) => path match {
+ case AsScalarAcc(n, _, dt, a) => path match {
case (i : CIntExpr) :: (j : CIntExpr) :: ps =>
- a |> acc(env, CIntExpr((i * m) + j) :: ps, cont)
+ a |> acc(env, CIntExpr((i * n) + j) :: ps, cont)
case (i : CIntExpr) :: Nil =>
- a |> acc(env,CIntExpr(i * m) :: Nil, array => {
+ a |> acc(env,CIntExpr(i * n) :: Nil, array => {
cont(
//acces first (vector-)element pointed by the pointer
C.AST.ArraySubscript(
//cast pointer to array to pointer of vectorType
C.AST.Cast(
- C.AST.PointerType(getVectorType(dt, m)),
+ C.AST.PointerType(getVectorType(dt, n)),
C.AST.UnaryExpr(C.AST.UnaryOperator.&, array)),
C.AST.ArithmeticExpr(0)))})
@@ -254,7 +257,7 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations,
override def exp(env: Environment,
path: Path,
cont: Expr => Stmt): Phrase[ExpType] => Stmt = {
- case phrase@AsVectorAligned(n, _, _, dt, e) => path match {
+ case phrase@AsVectorAligned(n, _, dt, _, e) => path match {
case (i : CIntExpr) :: (j : CIntExpr) :: ps =>
e |> exp(env, CIntExpr((i * n) + j) :: ps, cont)
@@ -288,14 +291,14 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations,
C.AST.FragmentType(m, n, k, typ(dataType), fragmentKind, layout)
case shine.DPIA.Types.f16 =>
cuda.AST.Type.half
- case shine.DPIA.Types.pipeline =>
+ case shine.DPIA.Types.OpaqueType("pipeline") =>
cuda.AST.Type.pipeline
case _ =>
super.typ(dt)
}
- private def getVectorType(dt: ScalarType, n: Nat): Type = {
+ private def getVectorType(dt: DataType, n: Nat): Type = {
if (n.eval > 0 && n.eval <= 4)
dt match {
case shine.DPIA.Types.u8 => BasicType(s"uchar$n")
@@ -306,13 +309,13 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations,
case shine.DPIA.Types.u32 => BasicType(s"uint$n")
case shine.DPIA.Types.f32 => BasicType(s"float$n")
case shine.DPIA.Types.f64 => BasicType(s"double$n")
- case _ => ???
+ case _ => throw new Exception(s"Can't create vector type from: ($dt, $n)")
}
else
dt match {
case shine.DPIA.Types.f16 if (n.eval > 0 && n.eval <= 8) => BasicType(s"float${n/2}")
case shine.DPIA.Types.f16 if (n.eval > 0 && n.eval <= 16) => BasicType(s"double${n/4}")
- case _ => ???
+ case _ => throw new Exception(s"Can't create vector type from: ($dt, $n)")
}
}
@@ -331,7 +334,7 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations,
p: Phrase[CommType],
env: Environment): Stmt = {
assert(!f.unroll)
- val cI = C.AST.DeclRef(f.name)
+ val cI = C.AST.DeclRef(freshName(f.prefix))
val range = RangeAdd(f.init, n, f.step)
val updatedGen = updatedRanges(cI.name, range)
diff --git a/src/main/scala/shine/cuda/Compilation/TranslationContext.scala b/src/main/scala/shine/cuda/Compilation/TranslationContext.scala
index 92ea837aa..41f8b7c5a 100644
--- a/src/main/scala/shine/cuda/Compilation/TranslationContext.scala
+++ b/src/main/scala/shine/cuda/Compilation/TranslationContext.scala
@@ -5,15 +5,15 @@ import shine.DPIA.Phrases.Phrase
import shine.DPIA.Types.{AccType, CommType, DataType, ExpType, FragmentType, read}
import shine.DPIA.primitives.imperative.Assign
import shine.DPIA.{accT, expT}
-import shine.cuda.primitives.imperative.ForFragmentElements
+import shine.cuda.primitives.imperative.ForFragment
class TranslationContext() extends shine.OpenCL.Compilation.TranslationContext {
override def assign(dt: DataType,
lhs: Phrase[AccType],
rhs: Phrase[ExpType]): Phrase[CommType] = {
dt match {
- case f: FragmentType =>
- ForFragmentElements(f, rhs, lhs,
+ case FragmentType(rows, columns, layers, dt, frag, layout) =>
+ ForFragment(rows, columns, layers, dt, frag, layout, rhs, lhs,
λ(expT(dt, read))(x =>
λ(accT(dt))(o =>
Assign(dt, o, x))))
diff --git a/src/main/scala/shine/cuda/DSL/package.scala b/src/main/scala/shine/cuda/DSL/package.scala
new file mode 100644
index 000000000..5223be1e1
--- /dev/null
+++ b/src/main/scala/shine/cuda/DSL/package.scala
@@ -0,0 +1,27 @@
+package shine.cuda
+
+import shine.DPIA.Nat
+import shine.DPIA.Phrases._
+import shine.DPIA.Types._
+import shine.OpenCL._
+import shine.cuda.primitives.imperative.ParFor
+
+package object DSL {
+ def parFor(level: ParallelismLevel,
+ dim: Int,
+ unroll: Boolean
+ ): (Nat, DataType, Phrase[AccType], Phrase[FunType[ExpType, FunType[AccType, CommType]]]) => ParFor =
+ level match {
+ case Global => ParFor(level, dim, unroll, "gl_id_")(
+ globalId(dim), _, globalDim(dim), _, _, _)
+ case Local => ParFor(level, dim, unroll, "tid_")(
+ threadId(dim), _, blockDim(dim), _, _, _)
+ case WorkGroup => ParFor(level, dim, unroll, "block_id_")(
+ blockId(dim), _, gridDim(dim), _, _, _)
+ case Warp => ParFor(level, dim, unroll, "warp_id_")(
+ warpId(dim), _, warpDim(dim), _, _, _)
+ case Lane => ParFor(level, dim, unroll, "lane_id_")(
+ laneId(dim), _, warpSize, _, _, _)
+ case Sequential => throw new Exception("This should not happen")
+ }
+}
diff --git a/src/main/scala/shine/cuda/package.scala b/src/main/scala/shine/cuda/package.scala
index 306b1d282..39cc51db1 100644
--- a/src/main/scala/shine/cuda/package.scala
+++ b/src/main/scala/shine/cuda/package.scala
@@ -1,7 +1,9 @@
package shine
-import arithexpr.arithmetic.{ArithExpr, ContinuousRange, PosInf, Range, SimplifiedExpr}
-import shine.OpenCL.BuiltInFunctionCall
+import arithexpr.arithmetic.{ArithExpr, SimplifiedExpr}
+import shine.DPIA.Nat
+import shine.DPIA.Phrases.Phrase
+import shine.DPIA.Types.{DataType, ExpType, FragmentKind, MatrixLayoutIdentifier}
package object cuda {
@@ -39,4 +41,11 @@ package object cuda {
def apply(param: Int): ArithExpr with SimplifiedExpr =
blockDim(param) * gridDim(param)
}
+
+ object AsFragment{
+ def apply(rows: Nat, columns: Nat, layers: Nat, dataType: DataType,
+ frag: FragmentKind, matrix: Phrase[ExpType]): shine.cuda.primitives.functional.AsFragment =
+ shine.cuda.primitives.functional.AsFragment(rows, columns, layers, dataType,
+ frag, MatrixLayoutIdentifier("ml"), matrix)
+ }
}
diff --git a/src/main/scala/shine/cuda/primitives/functional/AsFragment.scala b/src/main/scala/shine/cuda/primitives/functional/AsFragment.scala
index 20000579c..31d151cf2 100644
--- a/src/main/scala/shine/cuda/primitives/functional/AsFragment.scala
+++ b/src/main/scala/shine/cuda/primitives/functional/AsFragment.scala
@@ -1,36 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.cuda.primitives.functional
-
-import shine.DPIA.Nat
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
-import shine.macros.Primitive.expPrimitive
-
-object AsFragment{
- def apply(rows: Nat, columns: Nat, d3: Nat, dataType: DataType, fragmentType: FragmentKind, matrix: Phrase[ExpType]):
- AsFragment = AsFragment(rows, columns, d3, dataType, fragmentType, matrix, MatrixLayoutIdentifier("ml"))
-}
-
-/**
- * Returns a fragment from a matrix tile which must resides in shared or global memory ({@link WmmaLoad}).
- * This primitive needs to be executed by a full warp!
- * @param rows number of rows of the fragment ({@link FragmentType#rows})
- * @param columns number of columns of the fragment ({@link FragmentType#columns})
- * @param d3 third dimension which is used in the MMA operation ({@link FragmentType#d3})
- * @param dataType dataType of the elements in the fragment ({@link FragmentType#datatype})
- * @param fragmentKind kind of the fragment ({@link FragmentType#fragmentKind})
- * @param layout layout of the fragment ({@link FragmentType#layout}).
- * The layout will be infered in Codegeneration. Hence a `MatrixLayoutIdentifier` can be
- * used as layout.
- */
-@expPrimitive
-case class AsFragment(rows: Nat,
- columns: Nat,
- d3: Nat,
- dataType: DataType,
- fragmentKind: FragmentKind,
- matrix: Phrase[ExpType],
- layout: MatrixLayout) extends ExpPrimitive{
-
- matrix :: ExpType(ArrayType(rows, ArrayType(columns, dataType)), write)
- override val t: ExpType = ExpType(FragmentType(rows, columns, d3, dataType, fragmentKind, layout), write)
+import shine.DPIA._
+final case class AsFragment(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val frag: FragmentKind, val layout: MatrixLayout, val input: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ input :: expT(ArrayType(rows, ArrayType(columns, dt)), read)
+ }
+ override val t: ExpType = expT(FragmentType(rows, columns, layers, dt, frag, layout), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsFragment = new AsFragment(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), frag, layout, VisitAndRebuild(input, v))
}
diff --git a/src/main/scala/shine/cuda/primitives/functional/AsMatrix.scala b/src/main/scala/shine/cuda/primitives/functional/AsMatrix.scala
index fbf2c3c96..c3ea35b06 100644
--- a/src/main/scala/shine/cuda/primitives/functional/AsMatrix.scala
+++ b/src/main/scala/shine/cuda/primitives/functional/AsMatrix.scala
@@ -1,27 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.cuda.primitives.functional
-
-import shine.DPIA.Nat
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
-import shine.macros.Primitive.expPrimitive
-
-/**
- * Returns a matrix tile with the elements from the `Accumulator`-fragment. The matrix tile must resides
- * in shared or global memory ({@link WmmaStore}).
- * This primitive needs to be executed by a full warp!
- * @param rows number of rows of the fragment ({@link FragmentType#rows})
- * @param columns number of columns of the fragment ({@link FragmentType#columns})
- * @param d3 third dimension which is used in the MMA operation ({@link FragmentType#d3})
- * @param dataType dataType of the elements in the fragment ({@link FragmentType#datatype})
- * @param fragment fragment from which the elements should be stored
- */
-@expPrimitive
-case class AsMatrix(rows: Nat,
- columns: Nat,
- d3: Nat,
- dataType: DataType,
- fragment: Phrase[ExpType]) extends ExpPrimitive {
-
- fragment :: ExpType(FragmentType(rows, columns, d3, dataType), read)
- override val t: ExpType = ExpType(ArrayType(rows, ArrayType(columns, dataType)), write)
+import shine.DPIA._
+final case class AsMatrix(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ input :: expT(FragmentType(rows, columns, layers, dt, FragmentKind.Accumulator, MatrixLayout.None), read)
+ }
+ override val t: ExpType = expT(ArrayType(rows, ArrayType(columns, dt)), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsMatrix = new AsMatrix(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), VisitAndRebuild(input, v))
}
diff --git a/src/main/scala/shine/cuda/primitives/functional/GenerateFragment.scala b/src/main/scala/shine/cuda/primitives/functional/GenerateFragment.scala
index 7bb57180f..fa0144660 100644
--- a/src/main/scala/shine/cuda/primitives/functional/GenerateFragment.scala
+++ b/src/main/scala/shine/cuda/primitives/functional/GenerateFragment.scala
@@ -1,31 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.cuda.primitives.functional
-
-import shine.DPIA.Nat
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
-import shine.macros.Primitive.expPrimitive
-
-/**
- * Returns a fragment in which all elements have a specific value ({@link WmmaFill}).
- * This primitive needs to be executed by a full warp!
- * @param rows number of rows of the fragment ({@link FragmentType#rows})
- * @param columns number of columns of the fragment ({@link FragmentType#columns})
- * @param d3 third dimension which is used in the MMA operation ({@link FragmentType#d3})
- * @param dataType dataType of the elements in the fragment ({@link FragmentType#datatype})
- * @param fill new value of all elements in the fragment (type of fill: `dataType`)
- * @param fragmentKind kind of the fragment ({@link FragmentType#fragmentKind})
- * @param layout layout of the fragment ({@link FragmentType#layout})
- */
-@expPrimitive
-case class GenerateFragment(rows: Nat,
- columns: Nat,
- d3: Nat,
- dataType: DataType,
- fill: Phrase[ExpType],
- fragmentKind: FragmentKind,
- layout: MatrixLayout) extends ExpPrimitive {
-
- fill :: ExpType(dataType, read)
-
- override val t: ExpType = ExpType(FragmentType(rows, columns, d3, dataType), write)
+import shine.DPIA._
+final case class GenerateFragment(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val frag: FragmentKind, val layout: MatrixLayout, val fill: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ fill :: expT(dt, read)
+ }
+ override val t: ExpType = expT(FragmentType(rows, columns, layers, dt, frag, layout), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): GenerateFragment = new GenerateFragment(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), frag, layout, VisitAndRebuild(fill, v))
}
diff --git a/src/main/scala/shine/cuda/primitives/functional/GlobalToShared.scala b/src/main/scala/shine/cuda/primitives/functional/GlobalToShared.scala
index caecc088b..cc4f06188 100644
--- a/src/main/scala/shine/cuda/primitives/functional/GlobalToShared.scala
+++ b/src/main/scala/shine/cuda/primitives/functional/GlobalToShared.scala
@@ -1,21 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.cuda.primitives.functional
-
-import shine.DPIA.DSL.{`new` => _}
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.expPrimitive
-
-/**
- * Returns a copy in shared memory of data in global memory ({@link GlobalToSharedAcc}).
- * @param dt datatype of data which should be copied
- * @param inputGlobal data in global memory which should be copied to shared memory
- */
-@expPrimitive
-final case class GlobalToShared(dt: DataType,
- inputGlobal: Phrase[ExpType]) extends ExpPrimitive {
-
- inputGlobal :: expT(dt, write)
+final case class GlobalToShared(val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ input :: expT(dt, write)
+ }
override val t: ExpType = expT(dt, read)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): GlobalToShared = new GlobalToShared(v.data(dt), VisitAndRebuild(input, v))
}
-
diff --git a/src/main/scala/shine/cuda/primitives/functional/Map.scala b/src/main/scala/shine/cuda/primitives/functional/Map.scala
index 8805e60a7..7695393cd 100644
--- a/src/main/scala/shine/cuda/primitives/functional/Map.scala
+++ b/src/main/scala/shine/cuda/primitives/functional/Map.scala
@@ -1,25 +1,18 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.cuda.primitives.functional
-
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.OpenCL.ParallelismLevel
-import shine.macros.Primitive.expPrimitive
-
-@expPrimitive
-final case class Map(level: ParallelismLevel,
- dim: Int)
- (val n: Nat,
- val dt1: DataType,
- val dt2: DataType,
- val f: Phrase[ExpType ->: ExpType],
- val array: Phrase[ExpType]
- ) extends ExpPrimitive {
- f :: expT(dt1, read) ->: expT(dt2, write)
- array :: expT(n `.` dt1, read)
- override val t: ExpType = expT(n `.` dt2, write)
-
- def unwrap: (Nat, DataType, DataType, Phrase[ExpType ->: ExpType], Phrase[ExpType]) =
- (n, dt1, dt2, f, array)
+final case class Map(level: shine.OpenCL.ParallelismLevel, dim: Int)(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: FunType(expT(dt1, read), expT(dt2, write))
+ array :: expT(ArrayType(n, dt1), read)
+ }
+ override val t: ExpType = expT(ArrayType(n, dt2), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): Map = new Map(level, dim)(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v))
+ def unwrap: (Nat, DataType, DataType, Phrase[FunType[ExpType, ExpType]], Phrase[ExpType]) = (n, dt1, dt2, f, array)
}
diff --git a/src/main/scala/shine/cuda/primitives/functional/MapFragment.scala b/src/main/scala/shine/cuda/primitives/functional/MapFragment.scala
new file mode 100644
index 000000000..9757cf271
--- /dev/null
+++ b/src/main/scala/shine/cuda/primitives/functional/MapFragment.scala
@@ -0,0 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+package shine.cuda.primitives.functional
+import arithexpr.arithmetic._
+import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
+import shine.DPIA.Types._
+import shine.DPIA._
+final case class MapFragment(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val frag: FragmentKind, val layout: MatrixLayout, val f: Phrase[FunType[ExpType, ExpType]], val input: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ f :: FunType(expT(dt, read), expT(dt, write))
+ input :: expT(FragmentType(rows, columns, layers, dt, frag, layout), read)
+ }
+ override val t: ExpType = expT(FragmentType(rows, columns, layers, dt, frag, layout), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapFragment = new MapFragment(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), frag, layout, VisitAndRebuild(f, v), VisitAndRebuild(input, v))
+}
diff --git a/src/main/scala/shine/cuda/primitives/functional/MapFragmentElements.scala b/src/main/scala/shine/cuda/primitives/functional/MapFragmentElements.scala
deleted file mode 100644
index 0ce195bec..000000000
--- a/src/main/scala/shine/cuda/primitives/functional/MapFragmentElements.scala
+++ /dev/null
@@ -1,26 +0,0 @@
-package shine.cuda.primitives.functional
-
-import shine.DPIA.->:
-import shine.DPIA.Phrases._
-import shine.DPIA.Types._
-import shine.macros.Primitive.expPrimitive
-
-/**
- * Returns a fragment with the values of an applied function to elements of a fragment.
- * This primitive needs to be executed by a full warp!
- * @param fragType type of the fragment
- * @param fragment fragment of type `fragType` on whose elements the function should be applied
- * @param fun function which takes an element of type `fragType.dataType` and
- * returns an element of type `fragType.dataType`
- */
-@expPrimitive
-case class MapFragmentElements(fragType: FragmentType,
- fragment: Phrase[ExpType],
- fun: Phrase[ExpType ->: ExpType],
- ) extends ExpPrimitive {
-
- fragment :: ExpType(fragType, read)
- fun :: ExpType(fragType.dataType, read) ->: ExpType(fragType.dataType, write)
-
- override val t: ExpType = ExpType(fragType, write)
-}
diff --git a/src/main/scala/shine/cuda/primitives/functional/TensorMatMultAdd.scala b/src/main/scala/shine/cuda/primitives/functional/TensorMatMultAdd.scala
index 88a61459e..c35cf09a7 100644
--- a/src/main/scala/shine/cuda/primitives/functional/TensorMatMultAdd.scala
+++ b/src/main/scala/shine/cuda/primitives/functional/TensorMatMultAdd.scala
@@ -1,41 +1,18 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.cuda.primitives.functional
-
-import shine.DPIA.Nat
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
-import shine.macros.Primitive.expPrimitive
-
-/**
- * Executes an MMA instruction using (multiple) Tensor Cores ({@link WmmaMMA}).
- * Returns a `Accumulator`-fragment as result of: aMatrix * bMatrix + cMatrix
- * (inplace operations using the same variable as `Acceptor` in `acceptorTranslation` and
- * as `cMatrix` are possible).
- * This primitive needs to be executed by a full warp!
- * @param m number of rows of the `aMatrix`
- * @param n number of columns of the `bMatrix` and the `cMatrix`
- * @param k number of columns of the `aMatrix` and number of rows of the `bMatrix`
- * @param layoutA layout of the `aMatrix`
- * @param layoutB layout of the `bMatrix`
- * @param dataType datatype of elements of `aMatrix` and `bMatrix` ({@link FragmentType#datatype})
- * @param dataTypeAcc datatype of elements of `cMatrix` and the resultMatrix ({@link FragmentType#datatype})
- * @param aMatrix first factor of type fragment
- * @param bMatrix second factor of type fragment
- * @param cMatrix accumulator of type fragment which is added to the product of `aMatrix` * `bMatrix`
- */
-@expPrimitive
-final case class TensorMatMultAdd(m: Nat,
- n: Nat,
- k: Nat,
- layoutA: MatrixLayout,
- layoutB: MatrixLayout,
- dataType: DataType,
- dataTypeAcc: DataType,
- aMatrix: Phrase[ExpType],
- bMatrix: Phrase[ExpType],
- cMatrix: Phrase[ExpType]) extends ExpPrimitive {
- aMatrix :: ExpType(FragmentType(m, k, n, dataType, FragmentKind.AMatrix, layoutA), read)
- bMatrix :: ExpType(FragmentType(k, n, m, dataType, FragmentKind.BMatrix, layoutB), read)
- cMatrix :: ExpType(FragmentType(m, n, k, dataTypeAcc), read)
-
- override val t: ExpType = ExpType(FragmentType(m, n, k, dataTypeAcc), write)
+import shine.DPIA._
+final case class TensorMatMultAdd(val m: Nat, val n: Nat, val k: Nat, val layoutA: MatrixLayout, val layoutB: MatrixLayout, val dt1: DataType, val dt2: DataType, val aMatrix: Phrase[ExpType], val bMatrix: Phrase[ExpType], val cMatrix: Phrase[ExpType]) extends ExpPrimitive {
+ {
+ aMatrix :: expT(FragmentType(m, k, n, dt1, FragmentKind.AMatrix, layoutA), read)
+ bMatrix :: expT(FragmentType(k, n, m, dt1, FragmentKind.BMatrix, layoutB), read)
+ cMatrix :: expT(FragmentType(m, n, k, dt2, FragmentKind.Accumulator, MatrixLayout.None), read)
+ }
+ override val t: ExpType = expT(FragmentType(m, n, k, dt2, FragmentKind.Accumulator, MatrixLayout.None), write)
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): TensorMatMultAdd = new TensorMatMultAdd(v.nat(m), v.nat(n), v.nat(k), layoutA, layoutB, v.data(dt1), v.data(dt2), VisitAndRebuild(aMatrix, v), VisitAndRebuild(bMatrix, v), VisitAndRebuild(cMatrix, v))
}
diff --git a/src/main/scala/shine/cuda/primitives/functional/primitives.dpia b/src/main/scala/shine/cuda/primitives/functional/primitives.dpia
new file mode 100644
index 000000000..c89d64a7c
--- /dev/null
+++ b/src/main/scala/shine/cuda/primitives/functional/primitives.dpia
@@ -0,0 +1,96 @@
+/**
+ * Returns a fragment from a matrix tile which must resides in shared or global memory ({@link WmmaLoad}).
+ * This primitive needs to be executed by a full warp!
+ * @param rows number of rows of the fragment ({@link FragmentType#rows})
+ * @param columns number of columns of the fragment ({@link FragmentType#columns})
+ * @param layers third dimension which is used in the MMA operation ({@link FragmentType#d3})
+ * @param dt dataType of the elements in the fragment ({@link FragmentType#datatype})
+ * @param frag kind of the fragment ({@link FragmentType#fragmentKind})
+ * @param layout layout of the fragment ({@link FragmentType#layout}).
+ * The layout will be infered in Codegeneration. Hence a `MatrixLayoutIdentifier` can be
+ * used as layout.
+ */
+def asFragment(rows: nat, columns: nat, layers: nat, dt: data,
+ frag: fragment, layout: matrixLayout,
+ input: exp[rows.columns.dt, read]): exp[fragment[rows, columns, layers, dt, frag, layout], write]
+
+/**
+ * Returns a matrix tile with the elements from the `Accumulator`-fragment. The matrix tile must resides
+ * in shared or global memory ({@link WmmaStore}).
+ * This primitive needs to be executed by a full warp!
+ * @param rows number of rows of the fragment ({@link FragmentType#rows})
+ * @param columns number of columns of the fragment ({@link FragmentType#columns})
+ * @param layers third dimension which is used in the MMA operation ({@link FragmentType#d3})
+ * @param dt dataType of the elements in the fragment ({@link FragmentType#datatype})
+ * @param input fragment from which the elements should be stored
+ */
+def asMatrix(rows: nat, columns: nat, layers: nat, dt: data,
+ input: exp[fragment[rows, columns, layers, dt, fragment.ACC, matrixLayout.NONE], read]
+ ): exp[rows.columns.dt, write]
+
+/**
+ * Returns a fragment in which all elements have a specific value ({@link WmmaFill}).
+ * This primitive needs to be executed by a full warp!
+ * @param rows number of rows of the fragment ({@link FragmentType#rows})
+ * @param columns number of columns of the fragment ({@link FragmentType#columns})
+ * @param layers third dimension which is used in the MMA operation ({@link FragmentType#d3})
+ * @param dt dataType of the elements in the fragment ({@link FragmentType#datatype})
+ * @param frag kind of the fragment ({@link FragmentType#fragmentKind})
+ * @param layout layout of the fragment ({@link FragmentType#layout})
+ * @param fill new value of all elements in the fragment (type of fill: `dataType`)
+ */
+def generateFragment(rows: nat, columns: nat, layers: nat, dt: data,
+ frag: fragment, layout: matrixLayout,
+ fill: exp[dt, read]
+ ): exp[fragment[rows, columns, layers, dt, frag, layout], write]
+
+/**
+ * Returns a copy in shared memory of data in global memory ({@link GlobalToSharedAcc}).
+ * @param dt datatype of data which should be copied
+ * @param input data in global memory which should be copied to shared memory
+ */
+def globalToShared(dt: data, input: exp[dt, write]): exp[dt, read]
+
+def map{level: shine.OpenCL.ParallelismLevel, dim: Int}
+ (n: nat, dt1: data, dt2: data,
+ f: exp[dt1, read] -> exp[dt2, write],
+ array: exp[n.dt1, read]): exp[n.dt2, write]
+
+/**
+ * Returns a fragment with the values of an applied function to elements of a fragment.
+ * This primitive needs to be executed by a full warp!
+ * @param fragType type of the fragment
+ * @param fragment fragment of type `fragType` on whose elements the function should be applied
+ * @param fun function which takes an element of type `fragType.dataType` and
+ * returns an element of type `fragType.dataType`
+ */
+def mapFragment(rows: nat, columns: nat, layers: nat, dt: data,
+ frag: fragment, layout: matrixLayout,
+ f: exp[dt, read] -> exp[dt, write],
+ input: exp[fragment[rows, columns, layers, dt, frag, layout], read]
+ ): exp[fragment[rows, columns, layers, dt, frag, layout], write]
+
+/**
+ * Executes an MMA instruction using (multiple) Tensor Cores ({@link WmmaMMA}).
+ * Returns a `Accumulator`-fragment as result of: aMatrix * bMatrix + cMatrix
+ * (inplace operations using the same variable as `Acceptor` in `acceptorTranslation` and
+ * as `cMatrix` are possible).
+ * This primitive needs to be executed by a full warp!
+ * @param m number of rows of the `aMatrix`
+ * @param n number of columns of the `bMatrix` and the `cMatrix`
+ * @param k number of columns of the `aMatrix` and number of rows of the `bMatrix`
+ * @param layoutA layout of the `aMatrix`
+ * @param layoutB layout of the `bMatrix`
+ * @param dataType datatype of elements of `aMatrix` and `bMatrix` ({@link FragmentType#datatype})
+ * @param dataTypeAcc datatype of elements of `cMatrix` and the resultMatrix ({@link FragmentType#datatype})
+ * @param aMatrix first factor of type fragment
+ * @param bMatrix second factor of type fragment
+ * @param cMatrix accumulator of type fragment which is added to the product of `aMatrix` * `bMatrix`
+ */
+def tensorMatMultAdd(m: nat, n: nat, k: nat,
+ layoutA: matrixLayout, layoutB: matrixLayout,
+ dt1: data, dt2: data,
+ aMatrix: exp[fragment[m, k, n, dt1, fragment.A, layoutA], read],
+ bMatrix: exp[fragment[k, n, m, dt1, fragment.B, layoutB], read],
+ cMatrix: exp[fragment[m, n, k, dt2, fragment.ACC, matrixLayout.NONE], read]
+ ): exp[fragment[m, n, k, dt2, fragment.ACC, matrixLayout.NONE], write]
diff --git a/src/main/scala/shine/cuda/primitives/imperative/ForFragment.scala b/src/main/scala/shine/cuda/primitives/imperative/ForFragment.scala
new file mode 100644
index 000000000..283476433
--- /dev/null
+++ b/src/main/scala/shine/cuda/primitives/imperative/ForFragment.scala
@@ -0,0 +1,18 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+package shine.cuda.primitives.imperative
+import arithexpr.arithmetic._
+import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
+import shine.DPIA.Types._
+import shine.DPIA._
+final case class ForFragment(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val frag: FragmentKind, val layout: MatrixLayout, val in: Phrase[ExpType], val out: Phrase[AccType], val fun: Phrase[FunType[ExpType, FunType[AccType, CommType]]]) extends CommandPrimitive {
+ {
+ in :: expT(FragmentType(rows, columns, layers, dt, frag, layout), read)
+ out :: accT(FragmentType(rows, columns, layers, dt, frag, layout))
+ fun :: FunType(expT(dt, read), FunType(accT(dt), comm))
+ }
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ForFragment = new ForFragment(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), frag, layout, VisitAndRebuild(in, v), VisitAndRebuild(out, v), VisitAndRebuild(fun, v))
+}
diff --git a/src/main/scala/shine/cuda/primitives/imperative/ForFragmentElements.scala b/src/main/scala/shine/cuda/primitives/imperative/ForFragmentElements.scala
deleted file mode 100644
index 5a6a1cfbf..000000000
--- a/src/main/scala/shine/cuda/primitives/imperative/ForFragmentElements.scala
+++ /dev/null
@@ -1,28 +0,0 @@
-package shine.cuda.primitives.imperative
-
-import shine.DPIA.->:
-import shine.DPIA.Phrases._
-import shine.DPIA.Types._
-import shine.macros.Primitive.comPrimitive
-
-/**
- * Applies a function on every element of a fragment.
- * This primitive needs to be executed by a full warp!
- * @param fragType type of the fragment
- * @param in fragment of type `fragType` whose elements should be iterated
- * @param out fragment-Acceptor of type `fragType` which is used to store the result
- * @param fun function which takes an element of type `fragType.dataType` from `in` and
- * an element-Acceptor of type `fragType.dataType` from `out` and returns a command
- */
-@comPrimitive
-final case class ForFragmentElements(fragType: FragmentType,
- in: Phrase[ExpType],
- out: Phrase[AccType],
- fun: Phrase[ExpType ->: AccType ->: CommType],
- ) extends CommandPrimitive {
- in :: ExpType(fragType, read)
- out :: AccType(fragType)
- fun :: FunType(ExpType(fragType.dataType, read), FunType(AccType(fragType.dataType), comm))
-
- override def prettyPrint: String = s"ForFragmentElements(${PrettyPhrasePrinter(in)}, ${PrettyPhrasePrinter(out)}, ${PrettyPhrasePrinter(fun)})"
-}
diff --git a/src/main/scala/shine/cuda/primitives/imperative/GlobalToSharedAcc.scala b/src/main/scala/shine/cuda/primitives/imperative/GlobalToSharedAcc.scala
index a0ba8cd68..dba030d80 100644
--- a/src/main/scala/shine/cuda/primitives/imperative/GlobalToSharedAcc.scala
+++ b/src/main/scala/shine/cuda/primitives/imperative/GlobalToSharedAcc.scala
@@ -1,30 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.cuda.primitives.imperative
-
-import shine.DPIA.DSL.{`new` => _}
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.macros.Primitive.accPrimitive
-
-/**
- * Copy a element from global memory to shared memory without using registers
- * (faster than normal copy operations using the =-operator).
- * This requires CUDA 11 and compute capability >= 8. For devices with ompute capability
- * smaller than 8 this will be compiled by the CUDA-Compiler to the same as normal copy
- * operations using the =-operator).
- * @param dt datatype of element which should be copied
- * @param pipe pipeline which should be used to execute this copy instruction
- * @param outputShared output-Acceptor in shared memory of type `dt`
- */
-@accPrimitive
-final case class GlobalToSharedAcc(dt: DataType,
- pipe: Phrase[ExpType],
- outputShared: Phrase[AccType]
- ) extends AccPrimitive {
- pipe :: expT(pipeline, read)
- outputShared :: accT(dt)
+final case class GlobalToSharedAcc(val dt: DataType, val pipe: Phrase[ExpType], val outputShared: Phrase[AccType]) extends AccPrimitive {
+ {
+ pipe :: expT(OpaqueType("pipeline"), read)
+ outputShared :: accT(dt)
+ }
override val t: AccType = accT(dt)
-
- override def prettyPrint: String =
- s"(GlobalToSharedAcc $pipe, ${PrettyPhrasePrinter(outputShared)})"
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): GlobalToSharedAcc = new GlobalToSharedAcc(v.data(dt), VisitAndRebuild(pipe, v), VisitAndRebuild(outputShared, v))
}
diff --git a/src/main/scala/shine/cuda/primitives/imperative/ParFor.scala b/src/main/scala/shine/cuda/primitives/imperative/ParFor.scala
index 8d05be355..8a471f68d 100644
--- a/src/main/scala/shine/cuda/primitives/imperative/ParFor.scala
+++ b/src/main/scala/shine/cuda/primitives/imperative/ParFor.scala
@@ -1,60 +1,18 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.cuda.primitives.imperative
-
-import shine.DPIA.Phrases.{CommandPrimitive, _}
+import arithexpr.arithmetic._
+import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA._
-import shine.OpenCL._
-import shine.cuda.{blockDim, blockId, globalDim, globalId, gridDim, laneId, threadId, warpDim, warpId, warpSize}
-import shine.macros.Primitive.comPrimitive
-
-@comPrimitive
-final case class ParFor(level: ParallelismLevel,
- dim: Int,
- unroll: Boolean)
- (val n: Nat,
- val dt: DataType,
- val out: Phrase[AccType],
- val loopBody: Phrase[ExpType ->: AccType ->: CommType],
- val init: Nat = ParFor.initInit(level, dim),
- val step: Nat = ParFor.initStep(level, dim)
- ) extends CommandPrimitive {
- val name: String = ParFor.initName(level)
-
- out :: accT(n`.`dt)
- loopBody :: expT(idx(n), read) ->: accT(dt) ->: comm
-
- lazy val unwrapBody: (Identifier[ExpType], Identifier[AccType], Phrase[CommType]) = loopBody match {
- case Lambda(i, Lambda(o, body)) => (i, o, body)
- case _ => throw new Exception("This should not happen")
- }
-}
-
-object ParFor {
- def initInit(level: ParallelismLevel, dim: Int): Nat = level match {
- case Global => globalId(dim)
- case Local => threadId(dim)
- case WorkGroup => blockId(dim)
- case Warp => warpId(dim)
- case Lane => laneId(dim)
- case Sequential => throw new Exception("This should not happen")
- }
-
- def initStep(level: ParallelismLevel, dim: Int): Nat = level match {
- case Global => globalDim(dim)
- case Local => blockDim(dim)
- case WorkGroup => gridDim(dim)
- case Warp => warpDim(dim)
- case Lane => warpSize
- case Sequential => throw new Exception("This should not happen")
- }
-
- def initName(level: ParallelismLevel): String = level match {
- case Global => freshName("gl_id_")
- case Local => freshName("tid_")
- case WorkGroup => freshName("block_id_")
- case Warp => freshName("warp_id_")
- case Lane => freshName(s"lane_id_")
- case Sequential => throw new Exception("This should not happen")
+final case class ParFor(level: shine.OpenCL.ParallelismLevel, dim: Int, unroll: Boolean, prefix: String)(val init: Nat, val n: Nat, val step: Nat, val dt: DataType, val out: Phrase[AccType], val body: Phrase[FunType[ExpType, FunType[AccType, CommType]]]) extends CommandPrimitive {
+ {
+ out :: accT(ArrayType(n, dt))
+ body :: FunType(expT(IndexType(n), read), FunType(accT(dt), comm))
}
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): ParFor = new ParFor(level, dim, unroll, prefix)(v.nat(init), v.nat(n), v.nat(step), v.data(dt), VisitAndRebuild(out, v), VisitAndRebuild(body, v))
+ def unwrap: (Nat, Nat, Nat, DataType, Phrase[AccType], Phrase[FunType[ExpType, FunType[AccType, CommType]]]) = (init, n, step, dt, out, body)
}
diff --git a/src/main/scala/shine/cuda/primitives/imperative/SyncPipeline.scala b/src/main/scala/shine/cuda/primitives/imperative/SyncPipeline.scala
index d22b59bf2..c1f06a010 100644
--- a/src/main/scala/shine/cuda/primitives/imperative/SyncPipeline.scala
+++ b/src/main/scala/shine/cuda/primitives/imperative/SyncPipeline.scala
@@ -1,16 +1,16 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.cuda.primitives.imperative
-
-import shine.DPIA.Phrases.{CommandPrimitive, Phrase}
-import shine.DPIA.Types.{ExpType, pipeline, read}
-import shine.DPIA.expT
-import shine.macros.Primitive.comPrimitive
-
-/**
- * Execute and wait for all asynchronous memory transactions (used by the {@link GlobalToSharedAcc})
- */
-@comPrimitive
-final case class SyncPipeline(pipe: Phrase[ExpType]) extends CommandPrimitive {
- pipe :: expT(pipeline, read)
-
- override def prettyPrint: String = s"$pipe.commit_and_wait()"
+import arithexpr.arithmetic._
+import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
+import shine.DPIA.Types._
+import shine.DPIA._
+final case class SyncPipeline(val pipe: Phrase[ExpType]) extends CommandPrimitive {
+ {
+ pipe :: expT(OpaqueType("pipeline"), read)
+ }
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): SyncPipeline = new SyncPipeline(VisitAndRebuild(pipe, v))
}
diff --git a/src/main/scala/shine/cuda/primitives/imperative/SyncThreads.scala b/src/main/scala/shine/cuda/primitives/imperative/SyncThreads.scala
index 523217d10..d7d7acafc 100644
--- a/src/main/scala/shine/cuda/primitives/imperative/SyncThreads.scala
+++ b/src/main/scala/shine/cuda/primitives/imperative/SyncThreads.scala
@@ -1,12 +1,14 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.cuda.primitives.imperative
-
-import shine.DPIA.Phrases.CommandPrimitive
-import shine.macros.Primitive.comPrimitive
-
-/**
- * Synchronize all thread in a single thread block.
- */
-@comPrimitive
+import arithexpr.arithmetic._
+import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
+import shine.DPIA.Types._
+import shine.DPIA._
final case class SyncThreads() extends CommandPrimitive {
- override def prettyPrint: String = "__syncthreads()"
+ {}
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): SyncThreads = new SyncThreads()
}
diff --git a/src/main/scala/shine/cuda/primitives/imperative/SyncWarp.scala b/src/main/scala/shine/cuda/primitives/imperative/SyncWarp.scala
index b058c1a8b..31718237a 100644
--- a/src/main/scala/shine/cuda/primitives/imperative/SyncWarp.scala
+++ b/src/main/scala/shine/cuda/primitives/imperative/SyncWarp.scala
@@ -1,12 +1,14 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.cuda.primitives.imperative
-
-import shine.DPIA.Phrases.CommandPrimitive
-import shine.macros.Primitive.comPrimitive
-
-/**
- * Synchronize all elements in a single warp.
- */
-@comPrimitive
+import arithexpr.arithmetic._
+import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
+import shine.DPIA.Types._
+import shine.DPIA._
final case class SyncWarp() extends CommandPrimitive {
- override def prettyPrint: String = "__syncwarp()"
+ {}
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): SyncWarp = new SyncWarp()
}
diff --git a/src/main/scala/shine/cuda/primitives/imperative/WmmaFill.scala b/src/main/scala/shine/cuda/primitives/imperative/WmmaFill.scala
index cfbea2664..cf250936a 100644
--- a/src/main/scala/shine/cuda/primitives/imperative/WmmaFill.scala
+++ b/src/main/scala/shine/cuda/primitives/imperative/WmmaFill.scala
@@ -1,35 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.cuda.primitives.imperative
-
-import shine.DPIA.Nat
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
-import shine.macros.Primitive.comPrimitive
-
-/**
- * Fills a fragment with a specific value.
- * This primitive needs to be executed by a full warp!
- * @param rows number of rows of the fragment ({@link FragmentType#rows})
- * @param columns number of columns of the fragment ({@link FragmentType#columns})
- * @param d3 third dimension which is used in the MMA operation ({@link FragmentType#d3})
- * @param dataType dataType of the elements in the fragment ({@link FragmentType#datatype})
- * @param fill new value of all elements in the fragment (type of fill: `dataType`)
- * @param fragmentKind kind of the fragment ({@link FragmentType#fragmentKind})
- * @param layout layout of the fragment ({@link FragmentType#layout})
- * @param fragment fragment-Acceptor whose elements should be changed to `fill`
- */
-@comPrimitive
-case class WmmaFill(rows: Nat,
- columns: Nat,
- d3: Nat,
- dataType: DataType,
- fill: Phrase[ExpType],
- fragmentKind: FragmentKind,
- layout: MatrixLayout,
- fragment: Phrase[AccType]
- ) extends CommandPrimitive {
- fill :: ExpType(dataType, read)
- fragment :: AccType(FragmentType(rows, columns, d3, dataType, fragmentKind, layout))
-
- override def prettyPrint: String =
- s"WmmaFill(${PrettyPhrasePrinter(fill)}, ${PrettyPhrasePrinter(fragment)})"
+import shine.DPIA._
+final case class WmmaFill(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val frag: FragmentKind, val layout: MatrixLayout, val fill: Phrase[ExpType], val target: Phrase[AccType]) extends CommandPrimitive {
+ {
+ fill :: expT(dt, read)
+ target :: accT(FragmentType(rows, columns, layers, dt, frag, layout))
+ }
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): WmmaFill = new WmmaFill(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), frag, layout, VisitAndRebuild(fill, v), VisitAndRebuild(target, v))
}
diff --git a/src/main/scala/shine/cuda/primitives/imperative/WmmaLoad.scala b/src/main/scala/shine/cuda/primitives/imperative/WmmaLoad.scala
index 00cb5479a..61b33bcba 100644
--- a/src/main/scala/shine/cuda/primitives/imperative/WmmaLoad.scala
+++ b/src/main/scala/shine/cuda/primitives/imperative/WmmaLoad.scala
@@ -1,36 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.cuda.primitives.imperative
-
-import shine.DPIA.Nat
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
-import shine.macros.Primitive.comPrimitive
-
-/**
- * Loads a tile of a matrix into a fragment.
- * This primitive needs to be executed by a full warp!
- * @param rows number of rows of the fragment ({@link FragmentType#rows})
- * @param columns number of columns of the fragment ({@link FragmentType#columns})
- * @param d3 third dimension which is used in the MMA operation ({@link FragmentType#d3})
- * @param dataType dataType of the elements in the fragment ({@link FragmentType#datatype})
- * @param fragmentKind kind of the fragment ({@link FragmentType#fragmentKind})
- * @param layout layout of the fragment ({@link FragmentType#layout})
- * @param matrixTile matrix tile which should be loaded into the fragment
- * @param fragment fragment-Acceptor into which the `matrixTile` should be loaded
- */
-@comPrimitive
-final case class WmmaLoad(rows: Nat,
- columns: Nat,
- d3: Nat,
- dataType: DataType,
- fragmentKind: FragmentKind,
- layout: MatrixLayout,
- matrixTile: Phrase[ExpType],
- fragment: Phrase[AccType]
- ) extends CommandPrimitive {
- fragment :: ExpType(FragmentType(rows, columns, d3, dataType, fragmentKind, layout), write)
- matrixTile :: ExpType(ArrayType(rows, ArrayType(columns, dataType)), read)
-
- override def prettyPrint: String = {
- s"wmmaLoad(${PrettyPhrasePrinter(matrixTile)}, ${PrettyPhrasePrinter(fragment)})"
+import shine.DPIA._
+final case class WmmaLoad(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val frag: FragmentKind, val layout: MatrixLayout, val matrixTile: Phrase[ExpType], val target: Phrase[AccType]) extends CommandPrimitive {
+ {
+ matrixTile :: expT(ArrayType(rows, ArrayType(columns, dt)), read)
+ target :: accT(FragmentType(rows, columns, layers, dt, frag, layout))
}
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): WmmaLoad = new WmmaLoad(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), frag, layout, VisitAndRebuild(matrixTile, v), VisitAndRebuild(target, v))
}
diff --git a/src/main/scala/shine/cuda/primitives/imperative/WmmaMMA.scala b/src/main/scala/shine/cuda/primitives/imperative/WmmaMMA.scala
index 4f4bb8c9d..793a562f5 100644
--- a/src/main/scala/shine/cuda/primitives/imperative/WmmaMMA.scala
+++ b/src/main/scala/shine/cuda/primitives/imperative/WmmaMMA.scala
@@ -1,46 +1,19 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.cuda.primitives.imperative
-
-import shine.DPIA.Nat
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
-import shine.macros.Primitive.comPrimitive
-
-/**
- * Executes an MMA instruction using (multiple) Tensor Cores.
- * Calculates: aMatrix * bMatrix + cMatrix
- * This primitive needs to be executed by a full warp!
- * @param m number of rows of the `aMatrix`
- * @param n number of columns of the `bMatrix` and the `cMatrix`
- * @param k number of columns of the `aMatrix` and number of rows of the `bMatrix`
- * @param layoutA layout of the `aMatrix`
- * @param layoutB layout of the `bMatrix`
- * @param dataType datatype of elements of `aMatrix` and `bMatrix` ({@link FragmentType#datatype})
- * @param dataTypeAcc datatype of elements of `cMatrix` and the resultMatrix ({@link FragmentType#datatype})
- * @param aMatrix first factor of type fragment
- * @param bMatrix second factor of type fragment
- * @param cMatrix accumulator of type fragment which is added to the product of `aMatrix` * `bMatrix`
- * @param resultMatrix fragment-Accumulator in which the result is stored
- * (inplace operations using the `cMatrix` as resultMatrix is possible)
- */
-@comPrimitive
-case class WmmaMMA(m: Nat,
- n: Nat,
- k: Nat,
- layoutA : MatrixLayout,
- layoutB : MatrixLayout,
- dataType: DataType,
- dataTypeAcc: DataType,
- aMatrix: Phrase[ExpType],
- bMatrix: Phrase[ExpType],
- cMatrix: Phrase[ExpType],
- resultMatrix: Phrase[AccType]
- ) extends CommandPrimitive {
- aMatrix :: ExpType(FragmentType(m, k, n, dataType, FragmentKind.AMatrix, layoutA), read)
- bMatrix :: ExpType(FragmentType(k, n, m, dataType, FragmentKind.BMatrix, layoutB), read)
- cMatrix :: ExpType(FragmentType(m, n, k, dataTypeAcc), read)
- resultMatrix :: AccType(FragmentType(m, n, k, dataTypeAcc))
-
- override def prettyPrint: String =
- s"WmmaMMA(${PrettyPhrasePrinter(aMatrix)}, ${PrettyPhrasePrinter(bMatrix)}," +
- s"${PrettyPhrasePrinter(cMatrix)}, ${PrettyPhrasePrinter(resultMatrix)})"
+import shine.DPIA._
+final case class WmmaMMA(val m: Nat, val n: Nat, val k: Nat, val layoutA: MatrixLayout, val layoutB: MatrixLayout, val dt1: DataType, val dt2: DataType, val aMatrix: Phrase[ExpType], val bMatrix: Phrase[ExpType], val cMatrix: Phrase[ExpType], val resultMatrix: Phrase[AccType]) extends CommandPrimitive {
+ {
+ aMatrix :: expT(FragmentType(m, k, n, dt1, FragmentKind.AMatrix, layoutA), read)
+ bMatrix :: expT(FragmentType(k, n, m, dt1, FragmentKind.BMatrix, layoutB), read)
+ cMatrix :: expT(FragmentType(m, n, k, dt2, FragmentKind.Accumulator, MatrixLayout.None), read)
+ resultMatrix :: accT(FragmentType(m, n, k, dt2, FragmentKind.Accumulator, MatrixLayout.None))
+ }
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): WmmaMMA = new WmmaMMA(v.nat(m), v.nat(n), v.nat(k), layoutA, layoutB, v.data(dt1), v.data(dt2), VisitAndRebuild(aMatrix, v), VisitAndRebuild(bMatrix, v), VisitAndRebuild(cMatrix, v), VisitAndRebuild(resultMatrix, v))
}
diff --git a/src/main/scala/shine/cuda/primitives/imperative/WmmaStore.scala b/src/main/scala/shine/cuda/primitives/imperative/WmmaStore.scala
index a463fc53b..fb5e7ffd6 100644
--- a/src/main/scala/shine/cuda/primitives/imperative/WmmaStore.scala
+++ b/src/main/scala/shine/cuda/primitives/imperative/WmmaStore.scala
@@ -1,34 +1,17 @@
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
+// This file is automatically generated and should not be changed manually //
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //
package shine.cuda.primitives.imperative
-
-import shine.DPIA.Nat
+import arithexpr.arithmetic._
import shine.DPIA.Phrases._
+import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
-import shine.macros.Primitive.comPrimitive
-
-/**
- * Stores the elements from a fragment with fragmentKind `Accumulator` into a
- * matrix tile which resides in shared or global memory.
- * This primitive needs to be executed by a full warp!
- * @param rows number of rows of the fragment ({@link FragmentType#rows})
- * @param columns number of columns of the fragment ({@link FragmentType#columns})
- * @param d3 third dimension which is used in the MMA operation ({@link FragmentType#d3})
- * @param dataType dataType of the elements in the fragment ({@link FragmentType#datatype})
- * @param fragment fragment from which the elements should be stored
- * @param matrixTile matrixTile-Acceptor in which the elements should be stored
- */
-@comPrimitive
-final case class WmmaStore(rows: Nat,
- columns: Nat,
- d3: Nat,
- dataType: DataType,
- fragment: Phrase[ExpType],
- matrixTile: Phrase[AccType]
- ) extends CommandPrimitive {
- fragment :: ExpType(FragmentType(rows, columns, d3, dataType), read)
- matrixTile :: AccType(ArrayType(rows, ArrayType(columns, dataType)))
-
- override def prettyPrint: String = {
- s"wmmaStore(${PrettyPhrasePrinter(fragment)}, ${PrettyPhrasePrinter(matrixTile)})"
+import shine.DPIA._
+final case class WmmaStore(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val value: Phrase[ExpType], val matrixTile: Phrase[AccType]) extends CommandPrimitive {
+ {
+ value :: expT(FragmentType(rows, columns, layers, dt, FragmentKind.Accumulator, MatrixLayout.None), read)
+ matrixTile :: accT(ArrayType(rows, ArrayType(columns, dt)))
}
+ override val t: CommType = comm
+ override def visitAndRebuild(v: VisitAndRebuild.Visitor): WmmaStore = new WmmaStore(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), VisitAndRebuild(value, v), VisitAndRebuild(matrixTile, v))
}
-
diff --git a/src/main/scala/shine/cuda/primitives/imperative/primitives.dpia b/src/main/scala/shine/cuda/primitives/imperative/primitives.dpia
new file mode 100644
index 000000000..4922f06ef
--- /dev/null
+++ b/src/main/scala/shine/cuda/primitives/imperative/primitives.dpia
@@ -0,0 +1,123 @@
+/**
+ * Applies a function on every element of a fragment.
+ * This primitive needs to be executed by a full warp!
+ * @param fragType type of the fragment
+ * @param in fragment of type `fragType` whose elements should be iterated
+ * @param out fragment-Acceptor of type `fragType` which is used to store the result
+ * @param fun function which takes an element of type `fragType.dataType` from `in` and
+ * an element-Acceptor of type `fragType.dataType` from `out` and returns a command
+ */
+def forFragment(rows: nat, columns: nat, layers: nat, dt: data,
+ frag: fragment, layout: matrixLayout,
+ in: exp[fragment[rows, columns, layers, dt, frag, layout], read],
+ out: acc[fragment[rows, columns, layers, dt, frag, layout]],
+ fun: exp[dt, read] -> acc[dt] -> comm): comm
+
+/**
+ * Copy a element from global memory to shared memory without using registers
+ * (faster than normal copy operations using the =-operator).
+ * This requires CUDA 11 and compute capability >= 8. For devices with ompute capability
+ * smaller than 8 this will be compiled by the CUDA-Compiler to the same as normal copy
+ * operations using the =-operator).
+ * @param dt datatype of element which should be copied
+ * @param pipe pipeline which should be used to execute this copy instruction
+ * @param outputShared output-Acceptor in shared memory of type `dt`
+ */
+def globalToSharedAcc(dt: data, pipe: exp["pipeline", read], outputShared: acc[dt]): acc[dt]
+
+def parFor{level: shine.OpenCL.ParallelismLevel,
+ dim: Int,
+ unroll: Boolean,
+ prefix: String}
+ (init: nat, n: nat, step: nat,
+ dt: data,
+ out: acc[n.dt],
+ body: exp[idx[n], read] -> acc[dt] -> comm): comm
+
+/**
+ * Execute and wait for all asynchronous memory transactions (used by the {@link GlobalToSharedAcc})
+ */
+def syncPipeline(pipe: exp["pipeline", read]): comm
+
+/**
+ * Synchronize all thread in a single thread block.
+ */
+def syncThreads(): comm
+
+/**
+ * Synchronize all elements in a single warp.
+ */
+def syncWarp(): comm
+
+/**
+ * Fills a fragment with a specific value.
+ * This primitive needs to be executed by a full warp!
+ * @param rows number of rows of the fragment ({@link FragmentType#rows})
+ * @param columns number of columns of the fragment ({@link FragmentType#columns})
+ * @param layers third dimension which is used in the MMA operation ({@link FragmentType#d3})
+ * @param dt dataType of the elements in the fragment ({@link FragmentType#datatype})
+ * @param fill new value of all elements in the fragment (type of fill: `dataType`)
+ * @param frag kind of the fragment ({@link FragmentType#fragmentKind})
+ * @param layout layout of the fragment ({@link FragmentType#layout})
+ * @param target fragment-Acceptor whose elements should be changed to `fill`
+ */
+def wmmaFill(rows: nat, columns: nat, layers: nat, dt: data, frag: fragment, layout: matrixLayout,
+ fill: exp[dt, read],
+ target: acc[fragment[rows, columns, layers, dt, frag, layout]]): comm
+
+/**
+ * Loads a tile of a matrix into a fragment.
+ * This primitive needs to be executed by a full warp!
+ * @param rows number of rows of the fragment ({@link FragmentType#rows})
+ * @param columns number of columns of the fragment ({@link FragmentType#columns})
+ * @param d3 third dimension which is used in the MMA operation ({@link FragmentType#d3})
+ * @param dataType dataType of the elements in the fragment ({@link FragmentType#datatype})
+ * @param fragmentKind kind of the fragment ({@link FragmentType#fragmentKind})
+ * @param layout layout of the fragment ({@link FragmentType#layout})
+ * @param matrixTile matrix tile which should be loaded into the fragment
+ * @param fragment fragment-Acceptor into which the `matrixTile` should be loaded
+ */
+def wmmaLoad(rows: nat, columns: nat, layers: nat, dt: data, frag: fragment, layout: matrixLayout,
+ matrixTile: exp[rows.columns.dt, read],
+ target: acc[fragment[rows, columns, layers, dt, frag, layout]]): comm
+
+/**
+ * Executes an MMA instruction using (multiple) Tensor Cores.
+ * Calculates: aMatrix * bMatrix + cMatrix
+ * This primitive needs to be executed by a full warp!
+ * @param m number of rows of the `aMatrix`
+ * @param n number of columns of the `bMatrix` and the `cMatrix`
+ * @param k number of columns of the `aMatrix` and number of rows of the `bMatrix`
+ * @param layoutA layout of the `aMatrix`
+ * @param layoutB layout of the `bMatrix`
+ * @param dt1 datatype of elements of `aMatrix` and `bMatrix` ({@link FragmentType#datatype})
+ * @param dt2 datatype of elements of `cMatrix` and the resultMatrix ({@link FragmentType#datatype})
+ * @param aMatrix first factor of type fragment
+ * @param bMatrix second factor of type fragment
+ * @param cMatrix accumulator of type fragment which is added to the product of `aMatrix` * `bMatrix`
+ * @param resultMatrix fragment-Accumulator in which the result is stored
+ * (inplace operations using the `cMatrix` as resultMatrix is possible)
+ */
+def wmmaMMA(m: nat, n: nat, k: nat,
+ layoutA: matrixLayout,
+ layoutB: matrixLayout,
+ dt1: data, dt2: data,
+ aMatrix: exp[fragment[m, k, n, dt1, fragment.A, layoutA], read],
+ bMatrix: exp[fragment[k, n, m, dt1, fragment.B, layoutB], read],
+ cMatrix: exp[fragment[m, n, k, dt2, fragment.ACC, matrixLayout.NONE], read],
+ resultMatrix: acc[fragment[m, n, k, dt2, fragment.ACC, matrixLayout.NONE]]): comm
+
+/**
+ * Stores the elements from a fragment with fragmentKind `Accumulator` into a
+ * matrix tile which resides in shared or global memory.
+ * This primitive needs to be executed by a full warp!
+ * @param rows number of rows of the fragment ({@link FragmentType#rows})
+ * @param columns number of columns of the fragment ({@link FragmentType#columns})
+ * @param layers third dimension which is used in the MMA operation ({@link FragmentType#d3})
+ * @param dt dataType of the elements in the fragment ({@link FragmentType#datatype})
+ * @param value fragment from which the elements should be stored
+ * @param matrixTile matrixTile-Acceptor in which the elements should be stored
+ */
+def wmmaStore(rows: nat, columns: nat, layers: nat, dt: data,
+ value: exp[fragment[rows, columns, layers, dt, fragment.ACC, matrixLayout.NONE], read],
+ matrixTile: acc[rows.columns.dt]): comm
diff --git a/src/main/scala/shine/cuda/primitives/intermediate/MapI.scala b/src/main/scala/shine/cuda/primitives/intermediate/MapI.scala
index c2a1f9401..32cdbea62 100644
--- a/src/main/scala/shine/cuda/primitives/intermediate/MapI.scala
+++ b/src/main/scala/shine/cuda/primitives/intermediate/MapI.scala
@@ -14,7 +14,7 @@ final case class MapI(level: ParallelismLevel, dim: Int) {
in: Phrase[ExpType],
out: Phrase[AccType]): Phrase[CommType] = {
val imperativ = comment(s"map${level.toString}") `;`
- ParFor(level, dim, unroll = false)(n, dt2, out,
+ shine.cuda.DSL.parFor(level, dim, unroll = false)(n, dt2, out,
λ(expT(idx(n), read))(i => λ(accT(dt2))(a => f(in `@` i)(a))))
//TODO use other InsertMemoryBarrieres-mechanism
level match {
diff --git a/src/test/scala/shine/cuda/MMTest.scala b/src/test/scala/shine/cuda/MMTest.scala
index 40ef9e0bc..78bb48f74 100644
--- a/src/test/scala/shine/cuda/MMTest.scala
+++ b/src/test/scala/shine/cuda/MMTest.scala
@@ -37,22 +37,20 @@ class MMTest extends test_util.TestWithCUDA {
AsMatrix(mTile, nTile, kTile, f32,
//do matrix multiplication
- ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, nTile, kTile, f16),
- TensorMatMultAdd(mTile, nTile, kTile, Row_Major, Row_Major, f16, f32,
+ ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, nTile, kTile, f32, FragmentKind.Accumulator, MatrixLayout.None),
+ TensorMatMultAdd(mTile, nTile, kTile, MatrixLayoutIdentifier("ml"), MatrixLayoutIdentifier("ml"), f16, f32,
//load aMatrix into a fragment
- ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, kTile, nTile, f16, FragmentKind.AMatrix, Row_Major),
- AsFragment(mTile, kTile, nTile, f16, FragmentKind.AMatrix,
- matrixATile)),
+ ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, kTile, nTile, f16, FragmentKind.AMatrix, MatrixLayoutIdentifier("ml")),
+ shine.cuda.AsFragment(mTile, kTile, nTile, f16, FragmentKind.AMatrix, matrixATile)),
//load bMatrix into a fragment
- ToMem(shine.cuda.AddressSpace.Private, FragmentType(kTile, nTile, mTile, f16, FragmentKind.BMatrix, Row_Major),
- AsFragment(kTile, nTile, mTile, f16, FragmentKind.BMatrix,
- matrixBTile)),
+ ToMem(shine.cuda.AddressSpace.Private, FragmentType(kTile, nTile, mTile, f16, FragmentKind.BMatrix, MatrixLayoutIdentifier("ml")),
+ shine.cuda.AsFragment(kTile, nTile, mTile, f16, FragmentKind.BMatrix, matrixBTile)),
//add fragment with zeros
- ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, nTile, kTile, f16),
- GenerateFragment(mTile, nTile, kTile, f32, Literal(FloatData(0.0f)), FragmentKind.Accumulator, Row_Major))))))
+ ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, nTile, kTile, f32, FragmentKind.Accumulator, MatrixLayout.None),
+ GenerateFragment(mTile, nTile, kTile, f32, FragmentKind.Accumulator, MatrixLayout.None, Literal(FloatData(0.0f))))))))
)
val kernel = gen.cuda.kernel("matrixMult").fromPhrase(simpleMatMulTile)
@@ -93,7 +91,7 @@ class MMTest extends test_util.TestWithCUDA {
val matrixATile = Identifier(freshName("MatrixATile"), ExpType(ArrayType(mTile, ArrayType(k, f16)), read))
val matrixBTile = Identifier(freshName("MatrixBTile"), ExpType(ArrayType(k, ArrayType(nTile, f16)), read))
- val matrixCFrag = Identifier(freshName("MatrixCFrag"), ExpType(FragmentType(mTile, nTile, kTile, f32), read))
+ val matrixCFrag = Identifier(freshName("MatrixCFrag"), ExpType(FragmentType(mTile, nTile, kTile, f32, FragmentKind.Accumulator, MatrixLayout.None), read))
val matrixABTiles = Identifier(freshName("MatrixABTiles"), ExpType(PairType(
ArrayType(mTile, ArrayType(kTile, f16)),
ArrayType(kTile, ArrayType(nTile, f16))), read))
@@ -108,21 +106,21 @@ class MMTest extends test_util.TestWithCUDA {
PairType(
ArrayType(mTile, ArrayType(kTile, f16)),
ArrayType(kTile, ArrayType(nTile, f16))),
- FragmentType(mTile, nTile, kTile, f32),
+ FragmentType(mTile, nTile, kTile, f32, FragmentKind.Accumulator, MatrixLayout.None),
Lambda[ExpType, FunType[ExpType, ExpType]](matrixCFrag,
Lambda[ExpType, ExpType](matrixABTiles,
- TensorMatMultAdd(mTile, nTile, kTile, Row_Major, Row_Major, f16, f32,
- ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, kTile, nTile, f16, FragmentKind.AMatrix, Row_Major),
- AsFragment(mTile, kTile, nTile, f16, FragmentKind.AMatrix,
+ TensorMatMultAdd(mTile, nTile, kTile, MatrixLayoutIdentifier("ml"), MatrixLayoutIdentifier("ml"), f16, f32,
+ ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, kTile, nTile, f16, FragmentKind.AMatrix, MatrixLayoutIdentifier("ml")),
+ shine.cuda.AsFragment(mTile, kTile, nTile, f16, FragmentKind.AMatrix,
Transpose(kTile, mTile, f16, read,
Fst(
ArrayType(kTile, ArrayType(mTile, f16)),
ArrayType(kTile, ArrayType(nTile, f16)),
matrixABTiles)))),
- ToMem(shine.cuda.AddressSpace.Private, FragmentType(kTile, nTile, mTile, f16, FragmentKind.BMatrix, Row_Major),
- AsFragment(kTile, nTile, mTile, f16, FragmentKind.BMatrix,
+ ToMem(shine.cuda.AddressSpace.Private, FragmentType(kTile, nTile, mTile, f16, FragmentKind.BMatrix, MatrixLayoutIdentifier("ml")),
+ shine.cuda.AsFragment(kTile, nTile, mTile, f16, FragmentKind.BMatrix,
Snd(
ArrayType(mTile, ArrayType(kTile, f16)),
ArrayType(kTile, ArrayType(nTile, f16)),
@@ -130,7 +128,7 @@ class MMTest extends test_util.TestWithCUDA {
matrixCFrag))),
- GenerateFragment(mTile, nTile, kTile, f32, Literal(FloatData(0.0f)), FragmentKind.Accumulator, Row_Major),
+ GenerateFragment(mTile, nTile, kTile, f32, FragmentKind.Accumulator, Row_Major, Literal(FloatData(0.0f))),
Zip(k /^ kTile,
ArrayType(mTile, ArrayType(kTile, f16)),
@@ -190,7 +188,7 @@ class MMTest extends test_util.TestWithCUDA {
val matrixARow = Identifier(freshName("MatrixARow"), ExpType(ArrayType(mTile, ArrayType(k, f16)), read))
val matrixBColumnT = Identifier(freshName("MatrixBColumn"), ExpType(ArrayType(nTile, ArrayType(k, f16)), read))
- val matrixCFrag = Identifier(freshName("MatrixCFrag"), ExpType(FragmentType(mTile, nTile, kTile, f32), read))
+ val matrixCFrag = Identifier(freshName("MatrixCFrag"), ExpType(FragmentType(mTile, nTile, kTile, f32, FragmentKind.Accumulator, MatrixLayout.None), read))
val matrixABTiles = Identifier(freshName("MatrixABTiles"), ExpType(PairType(
ArrayType(kTile, ArrayType(mTile, f16)),
@@ -205,8 +203,8 @@ class MMTest extends test_util.TestWithCUDA {
Lambda[ExpType, FunType[ExpType, ExpType]](matrixA,
//And matrixB
Lambda[ExpType, ExpType](matrixB,
- Transpose(n, m, f32, read,
- Join(n /^ nTile, nTile, read, ArrayType(m, f32),
+ Transpose(n, m, f32, write,
+ Join(n /^ nTile, nTile, write, ArrayType(m, f32),
//Map over nTile-column-block of matrixB
Map(Local, 1)(n /^ nTile,
//A transposed column of matrixB
@@ -237,17 +235,17 @@ class MMTest extends test_util.TestWithCUDA {
ArrayType(kTile, ArrayType(nTile, f16))),
//Result: tile of cMatrix as fragment
- FragmentType(mTile, nTile, kTile, f32),
+ FragmentType(mTile, nTile, kTile, f32, FragmentKind.Accumulator, MatrixLayout.None),
//Multiply matrixATile and matrixBTile
Lambda[ExpType, FunType[ExpType, ExpType]](matrixCFrag,
Lambda[ExpType, ExpType](matrixABTiles,
//matrix multiply and accumulate
- TensorMatMultAdd(mTile, nTile, kTile, Row_Major, Row_Major, f16, f32,
+ TensorMatMultAdd(mTile, nTile, kTile, MatrixLayoutIdentifier("ml"), MatrixLayoutIdentifier("ml"), f16, f32,
//matrixATile as fragment
- ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, kTile, nTile, f16, FragmentKind.AMatrix, Row_Major),
- AsFragment(mTile, kTile, nTile, f16, FragmentKind.AMatrix,
+ ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, kTile, nTile, f16, FragmentKind.AMatrix, MatrixLayoutIdentifier("ml")),
+ shine.cuda.AsFragment(mTile, kTile, nTile, f16, FragmentKind.AMatrix,
Transpose(kTile, mTile, f16, read,
Fst(
ArrayType(mTile, ArrayType(kTile, f16)),
@@ -255,8 +253,8 @@ class MMTest extends test_util.TestWithCUDA {
matrixABTiles)))),
//matrixBTile as fragment
- ToMem(shine.cuda.AddressSpace.Private, FragmentType(kTile, nTile, mTile, f16, FragmentKind.BMatrix, Row_Major),
- AsFragment(kTile, nTile, mTile, f16, FragmentKind.BMatrix,
+ ToMem(shine.cuda.AddressSpace.Private, FragmentType(kTile, nTile, mTile, f16, FragmentKind.BMatrix, MatrixLayoutIdentifier("ml")),
+ shine.cuda.AsFragment(kTile, nTile, mTile, f16, FragmentKind.BMatrix,
Snd(
ArrayType(mTile, ArrayType(kTile, f16)),
ArrayType(kTile, ArrayType(nTile, f16)),
@@ -265,7 +263,7 @@ class MMTest extends test_util.TestWithCUDA {
matrixCFrag))),
//Neutral Element for Reduce: fragment initialized with zeros
- GenerateFragment(mTile, nTile, kTile, f32, Literal(FloatData(0.0f)), FragmentKind.Accumulator, Row_Major),
+ GenerateFragment(mTile, nTile, kTile, f32, FragmentKind.Accumulator, Row_Major, Literal(FloatData(0.0f))),
//Zip transposed, splited row of matrixA and splited column of matrixB
Zip(k /^ kTile,