diff --git a/build.sbt b/build.sbt index ffaf38084..4c52b490c 100644 --- a/build.sbt +++ b/build.sbt @@ -17,7 +17,7 @@ lazy val commonSettings = Seq( lazy val riseAndShine = (project in file(".")) .aggregate(executor, CUexecutor) - .dependsOn(meta, riseAndShineMacros, arithExpr, executor, CUexecutor, elevate) + .dependsOn(meta, arithExpr, executor, CUexecutor, elevate) .settings( name := "riseAndShine", version := "1.0", @@ -50,7 +50,16 @@ lazy val riseAndShine = (project in file(".")) lazy val generateRISEPrimitives = taskKey[Unit]("Generate RISE Primitives") generateRISEPrimitives := { - runner.value.run("meta.RisePrimitiveGenerator", + runner.value.run("meta.generator.RisePrimitives", + (dependencyClasspath in Compile).value.files, + Seq((scalaSource in Compile).value.getAbsolutePath), + streams.value.log).failed foreach (sys error _.getMessage) +} + +lazy val generateDPIAPrimitives = taskKey[Unit]("Generate DPIA Primitives") + +generateDPIAPrimitives := { + runner.value.run("meta.generator.DPIAPrimitives", (dependencyClasspath in Compile).value.files, Seq((scalaSource in Compile).value.getAbsolutePath), streams.value.log).failed foreach (sys error _.getMessage) @@ -67,14 +76,6 @@ lazy val meta = (project in file("meta")) libraryDependencies += "org.scalameta" %% "scalameta" % "4.4.10", ) -lazy val riseAndShineMacros = (project in file("macros")) - .settings( - name := "riseAndShineMacros", - version := "1.0", - commonSettings, - libraryDependencies += "org.scala-lang" % "scala-reflect" % scalaVersion.value - ) - lazy val arithExpr = project in file("lib/arithexpr") lazy val executor = project in file("lib/executor") diff --git a/macros/src/main/scala/shine/macros/Primitive.scala b/macros/src/main/scala/shine/macros/Primitive.scala deleted file mode 100644 index 90d985a5d..000000000 --- a/macros/src/main/scala/shine/macros/Primitive.scala +++ /dev/null @@ -1,163 +0,0 @@ -package shine.macros - -import scala.annotation.{StaticAnnotation, compileTimeOnly} -import scala.reflect.macros.blackbox -import scala.language.experimental.macros - -object Primitive { - @compileTimeOnly("ExpPrimitive macro") - class expPrimitive extends StaticAnnotation { - def macroTransform(annottees: Any*): Any = macro Impl.expPrimitive - } - - @compileTimeOnly("AccPrimitive macro") - class accPrimitive extends StaticAnnotation { - def macroTransform(annottees: Any*): Any = macro Impl.accPrimitive - } - - @compileTimeOnly("CommandPrimitive macro") - class comPrimitive extends StaticAnnotation { - def macroTransform(annottees: Any*): Any = macro Impl.comPrimitive - } - - class Impl(val c: blackbox.Context) { - import c.universe._ - - def expPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees) - def accPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees) - def comPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees) - - def primitive(transform : ClassDef => ClassDef)(annottees: Seq[c.Expr[Any]]): c.Expr[Any] = { - annottees.map(_.tree) match { - case (cdef: ClassDef) :: Nil => - c.Expr(transform(cdef)) - case (cdef: ClassDef) :: (md: ModuleDef) :: Nil => - c.Expr(q"{${transform(cdef)}; $md}") - case _ => c.abort(c.enclosingPosition, "expected a class definition") - } - } - - def makeLowerCaseName(s: String): String = - s"${Character.toLowerCase(s.charAt(0))}${s.substring(1)}" - - def makeVisitAndRebuild(name: TypeName, - additionalParams: List[ValDef], - params: List[ValDef]): Tree = { - val v = q"v" - q""" - override def visitAndRebuild( - $v: shine.DPIA.Phrases.VisitAndRebuild.Visitor): $name - = new ${Apply( additionalParams match { - case List() => Ident(name) - case _ => Apply(Ident(name), additionalParams.map { - case ValDef(_, name, _, _) => q"$name" - }) - }, params.map { - case ValDef(_, name, tpt, _) => tpt match { - case Ident(TypeName("DataType")) | Ident(TypeName("ScalarType")) | - Ident(TypeName("BasicType")) => q"$v.data($name)" - case Ident(TypeName("Nat")) => q"$v.nat($name)" - case Ident(TypeName("NatIdentifier")) => q"$v.nat($name)" - case Ident(TypeName("NatToNat")) => q"$v.natToNat($name)" - case Ident(TypeName("NatToData")) => q"$v.natToData($name)" - case Ident(TypeName("AccessType")) => q"$v.access($name)" - case Ident(TypeName("AddressSpace")) => q"$v.addressSpace($name)" - case Ident(TypeName("LocalSize")) => q"$name.visitAndRebuild($v)" - case Ident(TypeName("GlobalSize")) => q"$name.visitAndRebuild($v)" - // Phrase[ExpType] - case AppliedTypeTree((Ident(TypeName("Phrase")), _)) => - q"shine.DPIA.Phrases.VisitAndRebuild($name, $v)" - // Vector[Phrase[ExpType]] - case AppliedTypeTree((Ident(TypeName("Vector")), - List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) - | AppliedTypeTree((Ident(TypeName("Seq")), - List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) - => - q"$name.map(shine.DPIA.Phrases.VisitAndRebuild(_, $v))" - case _ => - q"$name" - } - })} - """ - } - - case class ClassInfo(name: TypeName, - additionalParams: List[ValDef], - params: List[ValDef], - body: List[Tree], - parents: List[Tree]) - - def primitivesFromClassDef: ClassDef => ClassInfo = { - case q"case class $name(..$params) extends { ..$_ } with ..$parents {..$body} " => - ClassInfo( - name.asInstanceOf[c.TypeName], - List(), - params.asInstanceOf[List[ValDef]], - body.asInstanceOf[List[Tree]], - parents.asInstanceOf[List[Tree]]) - case q"final case class $name(..$params) extends { ..$_ } with ..$parents {..$body} " => - ClassInfo( - name.asInstanceOf[c.TypeName], - List(), - params.asInstanceOf[List[ValDef]], - body.asInstanceOf[List[Tree]], - parents.asInstanceOf[List[Tree]]) - case q"""case class $name(..$additionalParams) - (..$params) extends { ..$_ } with ..$parents {..$body} """ => - ClassInfo( - name.asInstanceOf[c.TypeName], - additionalParams.asInstanceOf[List[ValDef]], - params.asInstanceOf[List[ValDef]], - body.asInstanceOf[List[Tree]], - parents.asInstanceOf[List[Tree]]) - case q"""final case class $name(..$additionalParams) - (..$params) extends { ..$_ } with ..$parents {..$body} """ => - ClassInfo( - name.asInstanceOf[c.TypeName], - additionalParams.asInstanceOf[List[ValDef]], - params.asInstanceOf[List[ValDef]], - body.asInstanceOf[List[Tree]], - parents.asInstanceOf[List[Tree]]) - case _ => - c.abort(c.enclosingPosition, "expected a case class extends Primitive") - } - - def makePrimitiveClass : ClassInfo => ClassDef = { case ClassInfo(name, additionalParams, params, body, parents) => - val visitAndRebuildMissing = - body.collectFirst({ case DefDef(_, TermName("visitAndRebuild"), _, _, _, _) => ()}).isEmpty - - val generated = q""" - ${if (visitAndRebuildMissing) - makeVisitAndRebuild(name, additionalParams, params) - else q""} - """ - - val expClass = (additionalParams match { - case List() => - q""" - final case class $name(..$params) extends {} with ..$parents { - ..$body - ..$generated - } - """ - case _ => - val newParams = params map { - case ValDef(_, name, tpt, rhs) if rhs.isEmpty => q"val $name : $tpt" - case ValDef(_, name, tpt, rhs) => q"val $name : $tpt = $rhs" - } - q""" - final case class $name(..$additionalParams) - (..$newParams) extends {} with ..$parents { - ..$body - ..$generated - } - """ - }).asInstanceOf[ClassDef] -/* - c.info(c.enclosingPosition, - s"generated `${name.toString}'\n$expClass", force = false) -*/ - expClass - } - } -} \ No newline at end of file diff --git a/meta/src/main/scala/meta/NatParser.scala b/meta/src/main/scala/meta/NatParser.scala deleted file mode 100644 index 9741605e2..000000000 --- a/meta/src/main/scala/meta/NatParser.scala +++ /dev/null @@ -1,64 +0,0 @@ -package meta - -import fastparse.ScalaWhitespace._ -import fastparse._ - -import meta.TypeParser._ - -object NatParser { - - sealed trait NatAST - object NatAST { - case class Identifier(id: TypeAST.Identifier) extends NatAST - case class Number(n: String) extends NatAST - case class BinaryOp(lhs: NatAST, op: String, rhs: NatAST) extends NatAST - case class TernaryOp(cond: BinaryOp, thenN: NatAST, elseN: NatAST) extends NatAST - case class Nat2NatApply(f: TypeAST.Identifier, n: NatAST) extends NatAST - case class Sum(id: Identifier, from: NatAST, upTo: NatAST, body: NatAST) extends NatAST - } - - def Nat[_: P]: P[NatAST] = { - def CompOrNat: P[NatAST] = { - def CompOp: P[String] = P("<".! | ">".!) - P(AddSubOrNat ~ (CompOp ~/ AddSubOrNat).rep).map(asBinaryOpOrNat) - } - - def AddSubOrNat: P[NatAST] = { - def AddSubOps: P[String] = P("+".! | "-".!) - P(DivMulPowModOrNat ~ (AddSubOps ~ DivMulPowModOrNat).rep).map(asBinaryOpOrNat) - } - - def DivMulPowModOrNat: P[NatAST] = { - def DivMulPowModOp: P[String] = P("*".! | "/".! | "^".! | "%".!) - P(SingleNat ~ (DivMulPowModOp ~ SingleNat).rep).map(asBinaryOpOrNat) - } - - def SingleNat: P[NatAST] = { - def Number: P[NatAST.Number] = P(CharIn("0-9").rep(1).!).map(NatAST.Number) - - def Sum: P[NatAST.Sum] = { - def Assignment: P[(NatAST.Identifier, NatAST)] = - P(NatIdentifier ~ "=" ~ Nat | "(" ~ Assignment ~ ")") - P("sum" ~ "_" ~ Assignment ~ "^" ~ Nat ~ Nat).map(NatAST.Sum.tupled) - } - - def Nat2NatApply: P[NatAST.Nat2NatApply] = - P(TypeIdentifier ~ "(" ~ Nat ~ ")").map(NatAST.Nat2NatApply.tupled) - - def NatIdentifier: P[NatAST.Identifier] = P(TypeIdentifier).map(NatAST.Identifier) - - def Parens: P[NatAST] = P("(" ~ Nat ~ ")") - - P(Number | Sum | Nat2NatApply | NatIdentifier | Parens) - } - - P(CompOrNat) - } - - private def asBinaryOpOrNat: ((NatAST, Seq[(String, NatAST)])) => NatAST = { - case (n, ns) => ns.foldLeft(n){ - case (lhs, (op, rhs)) => NatAST.BinaryOp(lhs, op, rhs) - } - } - -} diff --git a/meta/src/main/scala/meta/TypeParser.scala b/meta/src/main/scala/meta/TypeParser.scala deleted file mode 100644 index 7a31480f3..000000000 --- a/meta/src/main/scala/meta/TypeParser.scala +++ /dev/null @@ -1,328 +0,0 @@ -package meta - -import fastparse.ScalaWhitespace._ -import fastparse._ -import meta.NatParser._ -import meta.TypeParser.TypeAST.{FragmentAST, MatrixLayoutAST} - -object TypeParser { - - sealed trait TypeAST - object TypeAST { - case class Identifier(name: String) extends TypeAST - case class FunType(inT: TypeAST, outT: TypeAST) extends TypeAST - case class DepFunType(id: Identifier, kind: String, t: TypeAST) extends TypeAST - case class ImplicitDepFunType(id: Identifier, kind: String, t: TypeAST) extends TypeAST - - case class ScalarType(t: String) extends TypeAST - case object NatType extends TypeAST - case class VectorType(size: NatAST, elemType: TypeAST) extends TypeAST - case class IndexType(size: NatAST) extends TypeAST - case class PairType(lhs: TypeAST, rhs: TypeAST) extends TypeAST - case class DepPairType(id: Identifier, kind: String, t: TypeAST) extends TypeAST - case class NatToDataApply(f: TypeAST, n: NatAST) extends TypeAST - case class NatToDataLambda(id: Identifier, t: TypeAST) extends TypeAST - case class ArrayType(size: NatAST, elemType: TypeAST) extends TypeAST - case class DepArrayType(size: NatAST, fdt: TypeAST) extends TypeAST - case class FragmentType(n: NatAST, m: NatAST, k: NatAST, elemType: TypeAST, - fKind: FragmentAST, mLayout: MatrixLayoutAST) extends TypeAST - - sealed trait FragmentAST - object FragmentAST { - case class Identifier(id: TypeAST.Identifier) extends FragmentAST - object ACC extends FragmentAST - object A extends FragmentAST - object B extends FragmentAST - } - - sealed trait MatrixLayoutAST - object MatrixLayoutAST { - case class Identifier(id: TypeAST.Identifier) extends MatrixLayoutAST - object ROW_MAJOR extends MatrixLayoutAST - object COL_MAJOR extends MatrixLayoutAST - } - - object Kind extends Enumeration { - val Data, Nat, Nat2Nat, Nat2Data, Address, Fragment, MatrixLayout, Function = Value - - def fromString(s: String): Value = s match { - case "data" => Data - case "nat" => Nat - case "nat2nat" => Nat2Nat - case "nat2data" => Nat2Data - case "address" => Address - case "fragment" => Fragment - case "matrixLayout" => MatrixLayout - } - } - } - - def PrimitiveDeclarations[_: P]: P[Seq[(String, Option[(Int, Int)], TypeAST)]] = - P(Start ~ PrimitiveDeclaration.rep(1) ~ End) - - def PrimitiveDeclaration[_: P]: P[(String, Option[(Int, Int)], TypeAST)] = { - def ScalaFunArgs: P[(Int, Int)] = { - import scalaparse.Scala.TrailingCommaOps - P("(" ~ Index ~ - (scalaparse.Scala.Id ~ scalaparse.syntax.Key.O(":") ~ scalaparse.Scala.Type).repTC(1) ~ - Index ~ ")") - } - - P("def" ~ Identifier ~ ScalaFunArgs.? ~ ":" ~ TypeSignature) - } - - def Identifier[_: P]: P[String] = { - def Keywords: P[Unit] = - P(("def" | (Kind: P[Unit]) | DataType.TypeName) ~~ CharPred(_.isWhitespace)) - - val LowerChar = scalaparse.syntax.Identifiers.NamedFunction(CharPredicates.isLower) - val IdCharacter = scalaparse.syntax.Identifiers.NamedFunction(c => - CharPredicates.isLetter(c) || CharPredicates.isDigit(c)) - - P((!Keywords ~ CharPred(LowerChar).! ~~ CharsWhile(IdCharacter).!.?). - map(t => t._1 ++ t._2.getOrElse(""))) - } - - def TypeSignature[_: P]: P[TypeAST] = { - def DepFunType: P[TypeAST.DepFunType] = - P("(" ~ IdentifierKindPair ~ ")" ~ "->" ~/ TypeSignature).map(TypeAST.DepFunType.tupled) - - def ImplicitDepFunType: P[TypeAST.ImplicitDepFunType] = - P("{" ~ IdentifierKindPair ~ "}" ~ "->" ~/ TypeSignature). - map(TypeAST.ImplicitDepFunType.tupled) - - def FunType: P[TypeAST.FunType] = - P(NoCut(LeftTypeSignature) ~ "->" ~/ TypeSignature).map(TypeAST.FunType.tupled) - - // Types that can appear at the left of an function arrow - def LeftTypeSignature: P[TypeAST] = P(DataType.DataType | ("(" ~ TypeSignature ~ ")")) - - P(DepFunType | ImplicitDepFunType | FunType | LeftTypeSignature) - } - - def TypeIdentifier[_: P]: P[TypeAST.Identifier] = P(Identifier).map(TypeAST.Identifier) - - def IdentifierKindPair[_: P]: P[(TypeAST.Identifier, String)] = P(TypeIdentifier ~ ":" ~ Kind) - - def Kind[_: P]: P[String] = - P("data".! | "address".! | "nat2nat".! | "nat2data".! | "nat".! | - "fragment".! | "matrixLayout".!) - - object DataType { - def ScalarType[_: P]: P[TypeAST.ScalarType] = - P("bool".! | "int".! | - "i8".! | "i16".! | "i32".! | "i64".! | - "u8".! | "u16".! | "u32".! | "u64".! | - "f16".! | "f32".! | "f64".!).map(TypeAST.ScalarType) - - def NatType[_: P]: P[TypeAST.NatType.type] = P("natType").map(_ => TypeAST.NatType) - - def IndexType[_: P]: P[TypeAST.IndexType] = P("idx[" ~ Nat ~ "]").map(TypeAST.IndexType) - - def VectorType[_: P]: P[TypeAST.VectorType] = - P("vec[" ~ DataType ~ "," ~ Nat ~ "]").map(t => TypeAST.VectorType(t._2, t._1)) - - def FragmentType[_: P]: P[TypeAST.FragmentType] = { - def FragmentKind: P[FragmentAST] = - P(("fragment." ~~ ( - "ACC".!.map(_ => FragmentAST.ACC) | - "A".!.map(_ => FragmentAST.A) | - "B".!.map(_ => FragmentAST.B)) - ) | TypeIdentifier.map(FragmentAST.Identifier)) - - def MatrixLayout: P[MatrixLayoutAST] = - P(("matrixLayout." ~~ ( - "ROW_MAJOR".!.map(_ => MatrixLayoutAST.ROW_MAJOR) | - "COL_MAJOR".!.map(_ => MatrixLayoutAST.COL_MAJOR)) - ) | TypeIdentifier.map(MatrixLayoutAST.Identifier)) - - P("fragment["~ Nat ~","~ Nat ~","~ Nat ~","~ DataType ~","~ FragmentKind ~ - ","~ MatrixLayout ~"]").map(TypeAST.FragmentType.tupled) - } - - def DepArrayType[_: P]: P[TypeAST.DepArrayType] = - P(Nat ~ ".." ~/ NatToData).map(TypeAST.DepArrayType.tupled) - - def ArrayType[_: P]: P[TypeAST.ArrayType] = - P(Nat ~ "." ~~ !"." ~/ DataType).map(TypeAST.ArrayType.tupled) - - def DepPairType[_: P]: P[TypeAST.DepPairType] = - P("(" ~ IdentifierKindPair ~ "**" ~/ DataType ~ ")").map(TypeAST.DepPairType.tupled) - - def NatToDataApply[_: P]: P[TypeAST.NatToDataApply] = - P(NatToData ~ "(" ~ Nat ~ ")").map(TypeAST.NatToDataApply.tupled) - - def PairType[_: P]: P[TypeAST.PairType] = - P("(" ~ NoCut(DataType) ~ "," ~/ DataType ~ ")").map(TypeAST.PairType.tupled) - - def DataType[_: P]: P[TypeAST] = - P(ScalarType | NatType | IndexType | VectorType | FragmentType | DepArrayType | - ArrayType | DepPairType | NatToDataApply | PairType | TypeIdentifier | - ("(" ~ DataType ~ ")")) - - def TypeName[_: P]: P[Unit] = - P(ScalarType | NatType | "idx" | "vec" | "fragment" | "matrixLayout") - } - - def NatToData[_: P]: P[TypeAST] = { - def NatToDataLambda: P[TypeAST.NatToDataLambda] = - P("(" ~ IdentifierKindPair.filter(_._2 == "nat").map(_._1) ~ - "|->" ~/ DataType.DataType ~ ")").map(TypeAST.NatToDataLambda.tupled) - - P(TypeIdentifier | NatToDataLambda) - } - - object isWellKindedType { - import TypeAST.Kind._ - - def apply(typeAST: TypeAST): Boolean = { - kindOf(typeAST, Map.empty).isDefined - } - - def kindOf(typeAST: TypeAST, - env: Map[TypeAST.Identifier, TypeAST.Kind.Value]): Option[TypeAST.Kind.Value] = { - typeAST match { - case id: TypeAST.Identifier => - env.get(id) - case TypeAST.FunType(inT, outT) => - for { - _ <- kindOf(inT, env) - _ <- kindOf(outT, env) - } yield Function - case TypeAST.DepFunType(id, kind, t) => - if (env.isDefinedAt(id)) { - None // we forbid shadowing - } else { - kindOf(t, env.updated(id, TypeAST.Kind.fromString(kind))) - } - case TypeAST.ImplicitDepFunType(id, kind, t) => - if (env.isDefinedAt(id)) { - None // we forbid shadowing - } else { - kindOf(t, env.updated(id, TypeAST.Kind.fromString(kind))) - } - case TypeAST.VectorType(size, elemType) => - for { - k1 <- kindOf(size, env) - k2 <- kindOf(elemType, env) - if k1 == Nat && k2 == Data - } yield Data - case TypeAST.IndexType(size) => - for { - k <- kindOf(size, env) - if k == Nat - } yield Data - case TypeAST.PairType(lhs, rhs) => - for { - k1 <- kindOf(lhs, env) - k2 <- kindOf(rhs, env) - if k1 == Data && k2 == Data - } yield Data - case TypeAST.DepPairType(id, kind, t) => - if (env.isDefinedAt(id)) { - None // we forbid shadowing - } else { - kindOf(t, env.updated(id, TypeAST.Kind.fromString(kind))) - } - case TypeAST.NatToDataApply(f, n) => - for { - k1 <- kindOf(f, env) - k2 <- kindOf(n, env) - if k1 == Nat2Data && k2 == Nat - } yield Data - case TypeAST.NatToDataLambda(id, t) => - if (env.isDefinedAt(id)) { - None // we forbid shadowing - } else { - for { - k <- kindOf(t, env.updated(id, Nat)) - if k == Data - } yield Nat2Data - } - case TypeAST.ArrayType(size, elemType) => - for { - k1 <- kindOf(size, env) - k2 <- kindOf(elemType, env) - if k1 == Nat && k2 == Data - } yield Data - case TypeAST.DepArrayType(size, fdt) => - for { - k1 <- kindOf(size, env) - k2 <- kindOf(fdt, env) - if k1 == Nat && k2 == Nat2Data - } yield Data - case TypeAST.FragmentType(n, m, k, elemType, fKind, mLayout) => - for { - k1 <- kindOf(n, env) - k2 <- kindOf(m, env) - k3 <- kindOf(k, env) - k4 <- kindOf(elemType, env) - k5 <- kindOf(fKind, env) - k6 <- kindOf(mLayout, env) - if k1 == Nat && k2 == Nat && k3 == Nat && k4 == Data && - k5 == Fragment && k6 == TypeAST.Kind.MatrixLayout - } yield Data - case _: TypeAST.ScalarType | TypeAST.NatType => - Some(Data) - } - } - - def kindOf(natAST: NatAST, - env: Map[TypeAST.Identifier, TypeAST.Kind.Value]): Option[TypeAST.Kind.Value] = { - natAST match { - case NatAST.Identifier(id) => - env.get(id) - case NatAST.Number(_) => - Some(Nat) - case NatAST.BinaryOp(lhs, _, rhs) => - for { - k1 <- kindOf(lhs, env) - k2 <- kindOf(rhs, env) - if k1 == Nat && k2 == Nat - } yield Nat - - case NatAST.TernaryOp(_, thenN, elseN) => - for { - k1 <- kindOf(thenN, env) - k2 <- kindOf(elseN, env) - if k1 == Nat && k2 == Nat - } yield Nat - - case NatAST.Nat2NatApply(f, n) => - for { - k1 <- kindOf(f, env) - k2 <- kindOf(n, env) - if k1 == Nat2Nat && k2 == Nat - } yield Nat - - case NatAST.Sum(id, from, upTo, body) => - val nEnv = env.updated(id.id, Nat) - for { - k1 <- kindOf(from, nEnv) - k2 <- kindOf(upTo, nEnv) - k3 <- kindOf(body, nEnv) - if k1 == Nat && k2 == Nat && k3 == Nat - } yield Nat - } - } - - def kindOf(fragmentAST: TypeAST.FragmentAST, - env: Map[TypeAST.Identifier, TypeAST.Kind.Value]): Option[TypeAST.Kind.Value] = { - fragmentAST match { - case FragmentAST.Identifier(id) => - env.get(id) - case FragmentAST.ACC | FragmentAST.A | FragmentAST.B => Some(Fragment) - } - } - - def kindOf(matrixLayout: TypeAST.MatrixLayoutAST, - env: Map[TypeAST.Identifier, TypeAST.Kind.Value]): Option[TypeAST.Kind.Value] = { - matrixLayout match { - case MatrixLayoutAST.Identifier(id) => - env.get(id) - case MatrixLayoutAST.ROW_MAJOR | - MatrixLayoutAST.COL_MAJOR => Some(TypeAST.Kind.MatrixLayout) - } - } - } -} diff --git a/meta/src/main/scala/meta/generator/DPIAPrimitives.scala b/meta/src/main/scala/meta/generator/DPIAPrimitives.scala new file mode 100644 index 000000000..7e6f14729 --- /dev/null +++ b/meta/src/main/scala/meta/generator/DPIAPrimitives.scala @@ -0,0 +1,288 @@ +package meta.generator + +import fastparse.{Parsed, parse} +import meta.parser._ + +object DPIAPrimitives { + def main(args: Array[String]): Unit = { + val sourceDir = args.head + val shinePath = os.Path(sourceDir) / "shine" + os.walk.stream(shinePath).filter(_.ext == "dpia").foreach(path => { + + import DPIA.Decl.AST._ + + val definition = os.read(path) + parse(definition, DPIA.Decl.PrimitiveDeclarations(_)) match { + case failure: Parsed.Failure => + println(s"Failed to parse `${failure.extra.input}'") + println(s" $failure") + case Parsed.Success(seq, _) => + seq.foreach { + case PrimitiveDeclaration(Identifier(originalName), scalaParams, params, returnType) + if DPIA.isWellKindedDefinition(params, returnType) => + val name = originalName.capitalize + + val outputPath = (path / os.up) / s"$name.scala" + println(s"Generate $outputPath") + + import scala.meta._ + val packageName = path.relativeTo(shinePath).segments.dropRight(1).foldLeft[Term.Ref](Term.Name("shine")) { + case (t, name) => Term.Select(t, Term.Name(name)) + } + val scalaParamsString = scalaParams match { + case Some((start, end)) => definition.substring(start, end) + case None => "" + } + val code = s"""// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // + |// This file is automatically generated and should not be changed manually // + |// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // + |${q""" +package $packageName { + +import arithexpr.arithmetic._ +import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ +import shine.DPIA.Types._ +import shine.DPIA._ + +${generateCaseClass(Type.Name(name), scalaParamsString, params, returnType)} + +}""".toString()} + |""".stripMargin + + os.write.over(outputPath, code) + case PrimitiveDeclaration(Identifier(name), _, params, returnType) => + println(s"Could not generate code for `$name' as parameters `$params' and/or `$returnType' are not well kinded.") + } + } + }) + } + + def generateCaseClass(name: scala.meta.Type.Name, + scalaParamsString: String, + params: Seq[DPIA.Decl.AST.Param], + returnType: DPIA.Type.AST): scala.meta.Defn.Class = { + import scala.meta._ + import meta.parser.DPIA.Type.AST + val (scalaReturnType, superClass) = returnType match { + case AST.ExpType(_, _) => (t"ExpType", init"ExpPrimitive") + case AST.AccType(_) => (t"AccType", init"AccPrimitive") + case AST.CommType => (t"CommType", init"CommandPrimitive") + case _ => throw new Exception(s"Expected `exp', `acc' or `comm' as return type for ${name.value}") + } + val generatedParams = generateParams(scalaParamsString, params) + q""" + final case class $name(...$generatedParams) extends $superClass { + { + ..${generateTypeChecks(params).stats} + } + + ..${if (scalaReturnType != t"CommType") { + List(q"override val t: $scalaReturnType = ${generateTerm(returnType)}") + } else List() } + + ${generateVisitAndRebuild(name, generatedParams)} + + ..${if (scalaParamsString.nonEmpty && generatedParams.last.size > 1) { + List(generateUnwrap(generatedParams.last)) + } else List() } + } + """ + } + + def generateParams(scalaParamsString: String, + params: Seq[DPIA.Decl.AST.Param]): List[List[scala.meta.Term.Param]] = { + import scala.meta._ + + val scalaParams = if (scalaParamsString.nonEmpty) { + s"def foo($scalaParamsString)".parse[Stat].get match { + case declDef: Decl.Def => declDef.paramss + } + } else { + List() + } + + scalaParams ++ List(params.map(generateParam).toList) + } + + def generateParam(param: DPIA.Decl.AST.Param): scala.meta.Term.Param = { + import scala.meta._ + import _root_.meta.parser.DPIA.Kind + param"val ${Term.Name(param.id.name)}: ${ + param.ty match { + case Left(kindAST) => generateType(kindAST) + case Right(typeAST) => t"Phrase[${generatePhraseType(typeAST)}]" + } + }" + } + + def generateTypeChecks(params: Seq[DPIA.Decl.AST.Param]): scala.meta.Term.Block = { + import scala.meta._ + q"""{ + ..${params. + filter(param => param.ty.isRight). // only check types for parameters with phrase types + map(param => + q"${Term.Name(param.id.name)} :: ${ + param.ty match { + case Right(typeAST@DPIA.Type.AST.DepFunType(id, kind, t)) => + q"""{ + ${Defn.Val( + mods = Nil, + pats = List(Pat.Var(name = Term.Name(id.name))), + decltpe = None, + rhs = q"${Term.Name(param.id.name)}.t.x" + )} + ${generateTerm(typeAST)} + }""" + case Right(typeAST) => generateTerm(typeAST) + case Left(kindAST) => throw new Exception("This should not happen") + }}" + ).toList} + }""" + } + + def generatePhraseType(typeAST: DPIA.Type.AST): scala.meta.Type = { + import scala.meta._ + import meta.parser.DPIA.Type.AST + typeAST match { + case AST.ExpType(_, _) => t"ExpType" + case AST.AccType(_) => t"AccType" + case AST.CommType => t"CommType" + case AST.PairType(lhs, rhs) => t"PhrasePairType[${generatePhraseType(lhs)}, ${generatePhraseType(rhs)}]" + case AST.FunType(inT, outT) => t"FunType[${generatePhraseType(inT)}, ${generatePhraseType(outT)}]" + case AST.DepFunType(id, kind, t) => t"DepFunType[${generateKindType(kind)}, ${generatePhraseType(t)}]" + case AST.Identifier(name) => Type.Name(name) + } + } + + def generateType(kindAST: DPIA.Kind.AST): scala.meta.Type = { + import scala.meta._ + import meta.parser.DPIA.Kind.AST + kindAST match { + case AST.RiseKind(riseKind) => + import meta.parser.rise.Kind.AST + riseKind match { + case AST.Data => Type.Name("DataType") + case AST.Address => Type.Name("AddressSpace") + case AST.Nat2Nat => Type.Name("NatToNat") + case AST.Nat2Data => Type.Name("NatToData") + case AST.Nat => Type.Name("Nat") + case AST.Fragment => Type.Name("FragmentKind") + case AST.MatrixLayout => Type.Name("MatrixLayout") + } + case AST.Access => Type.Name("AccessType") + } + } + + def generateKindType(kindAST: DPIA.Kind.AST): scala.meta.Type = { + import scala.meta._ + import meta.parser.DPIA.Kind.AST + kindAST match { + case AST.RiseKind(riseKind) => + import meta.parser.rise.Kind.AST + riseKind match { + case AST.Data => Type.Name("DataKind") + case AST.Address => Type.Name("AddressSpaceKind") + case AST.Nat2Nat => Type.Name("NatToNatKind") + case AST.Nat2Data => Type.Name("NatToDataKind") + case AST.Nat => Type.Name("NatKind") + case AST.Fragment => ??? + case AST.MatrixLayout => ??? + } + case AST.Access => Type.Name("AccessKind") + } + } + + def generateTerm(typeAST: DPIA.Type.AST): scala.meta.Term = { + import scala.meta._ + import meta.parser.DPIA.Type.AST + typeAST match { + case AST.ExpType(dataType, access) => + q"expT(${RisePrimitives.generateDataType(dataType)}, ${generateTerm(access)})" + case AST.AccType(dataType) => + q"accT(${RisePrimitives.generateDataType(dataType)})" + case AST.CommType => + q"comm" + case AST.PairType(lhs, rhs) => + q"PhrasePairType(${generateTerm(lhs)}, ${generateTerm(rhs)})" + case AST.FunType(inT, outT) => + q"FunType(${generateTerm(inT)}, ${generateTerm(outT)})" + case AST.DepFunType(id, kind, t) => + q"DepFunType[${generateKindType(kind)}, PhraseType](${Term.Name(id.name)}, ${generateTerm(t)})" + case AST.Identifier(name) => Term.Name(name) + } + } + + def generateTerm(accessAST: DPIA.Type.Access.AST): scala.meta.Term = { + import scala.meta._ + import meta.parser.DPIA.Type.Access.AST + accessAST match { + case AST.Identifier(name) => Term.Name(name) + case AST.Read => Term.Name("read") + case AST.Write =>Term.Name("write") + } + } + + def generateVisitAndRebuild(name: scala.meta.Type.Name, + paramLists: List[List[scala.meta.Term.Param]]): scala.meta.Defn.Def = { + import scala.meta._ + + object TypeIs { + def unapply(ty: Type): Option[String] = ty match { + case Type.Name(name) => Some(name) + case Type.Select(_, Type.Name(name)) => Some(name) + case _ => None + } + } + + def injectVisitCall(param: Term.Param): Term = { + param.decltpe match { + case Some(ty) => ty match { + case TypeIs("Nat") | TypeIs("NatIdentifier") => + q"v.nat(${Term.Name(param.name.value)})" + case TypeIs("DataType") | TypeIs("ScalarType") | TypeIs("BasicType") => + q"v.data(${Term.Name(param.name.value)})" + case TypeIs("NatToNat") => + q"v.natToNat(${Term.Name(param.name.value)})" + case TypeIs("NatToData") => + q"v.natToData(${Term.Name(param.name.value)})" + case TypeIs("AccessType") => + q"v.access(${Term.Name(param.name.value)})" + case TypeIs("AddressSpace") => + q"v.addressSpace(${Term.Name(param.name.value)})" + case TypeIs("LocalSize") | TypeIs("GlobalSize") => + q"${Term.Name(param.name.value)}.visitAndRebuild(v)" + case t"Phrase[$_]" => q"VisitAndRebuild(${Term.Name(param.name.value)}, v)" + case t"Vector[Phrase[$_]]" => q"${Term.Name(param.name.value)}.map(VisitAndRebuild(_, v))" + case t"Seq[Phrase[$_]]" => q"${Term.Name(param.name.value)}.map(VisitAndRebuild(_, v))" + + case t"Map[Identifier[_ <: PhraseType], $_]" => + q"""${Term.Name(param.name.value)}.map{ case (key, value) => + VisitAndRebuild(key, v).asInstanceOf[Identifier[_ <: PhraseType]] -> value + }""" + + case Type.Apply(Type.Name("Vector"), List(TypeIs("DataType"))) // Vector[DataType] + | Type.Apply(Type.Name("Seq"), List(TypeIs("DataType"))) => // Seq[DataType] + q"${Term.Name(param.name.value)}.map(v.data)" + case _ => + Term.Name(param.name.value) + } + case None => throw new Exception(s"Expected type declaration") + } + } + + q"""override def visitAndRebuild(v: VisitAndRebuild.Visitor): $name = + new $name(...${paramLists.map(_.map(injectVisitCall))}) + """ + } + + def generateUnwrap(paramList: List[scala.meta.Term.Param]): scala.meta.Defn.Def = { + import scala.meta._ + val (types, names) = paramList.map({ + case Term.Param(_, name, Some(typ), _) => (typ, Term.Name(name.value)) + }).unzip + q""" + def unwrap: (..$types) = (..$names) + """ + } +} diff --git a/meta/src/main/scala/meta/RisePrimitiveGenerator.scala b/meta/src/main/scala/meta/generator/RisePrimitives.scala similarity index 50% rename from meta/src/main/scala/meta/RisePrimitiveGenerator.scala rename to meta/src/main/scala/meta/generator/RisePrimitives.scala index b61efb55c..f48f7e0a0 100644 --- a/meta/src/main/scala/meta/RisePrimitiveGenerator.scala +++ b/meta/src/main/scala/meta/generator/RisePrimitives.scala @@ -1,28 +1,28 @@ -package meta +package meta.generator import fastparse.{Parsed, parse} -import meta.NatParser.NatAST -import meta.TypeParser.TypeAST -import meta.TypeParser.TypeAST.{FragmentAST, MatrixLayoutAST} +import meta.parser._ +import meta.parser.rise.Kind -object RisePrimitiveGenerator { +object RisePrimitives { def main(args: Array[String]): Unit = { val sourceDir = args.head - val rise = os.Path(sourceDir) / "rise" - os.walk.stream(rise).filter(_.ext == "rise").foreach(path => { + val risePath = os.Path(sourceDir) / "rise" + os.walk.stream(risePath).filter(_.ext == "rise").foreach(path => { val definition = os.read(path) - parse(definition, TypeParser.PrimitiveDeclarations(_)) match { + parse(definition, rise.Decl.PrimitiveDeclarations(_)) match { case failure: Parsed.Failure => println(s"Failed to parse `${failure.extra.input}'") println(s" $failure") case Parsed.Success(seq, _) => seq.foreach { - case (name, args, typeSignature) if TypeParser.isWellKindedType(typeSignature) => + case rise.Decl.AST.PrimitiveDeclaration(rise.Decl.AST.Identifier(name), scalaParams, typeSignature) + if rise.isWellKindedType(typeSignature) => val outputPath = (path / os.up) / s"$name.scala" println(s"Generate $outputPath") - val generatedDef = args match { + val generatedDef = scalaParams match { case None => generateObject(name, typeSignature) case Some((start, end)) => @@ -30,7 +30,7 @@ object RisePrimitiveGenerator { } import scala.meta._ - val packageName = path.relativeTo(rise).segments.dropRight(1).foldLeft[Term.Ref](Term.Name("rise")) { + val packageName = path.relativeTo(risePath).segments.dropRight(1).foldLeft[Term.Ref](Term.Name("rise")) { case (t, name) => Term.Select(t, Term.Name(name)) } val code = s"""// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // @@ -51,14 +51,14 @@ import arithexpr.arithmetic._ |""".stripMargin os.write.over(outputPath, code) - case (name, _, typeSignature) => + case rise.Decl.AST.PrimitiveDeclaration(name, _, typeSignature) => println(s"Could not generate code for `$name' as type signature `$typeSignature' is not well kinded.") } } }) } - def generateObject(name: String, typeSignature: TypeAST): scala.meta.Term.Block = { + def generateObject(name: String, typeSignature: rise.Type.AST): scala.meta.Term.Block = { import scala.meta._ val generated = q"""{ object ${Term.Name{name}} extends Builder { @@ -68,7 +68,7 @@ import arithexpr.arithmetic._ override val name: String = ${Lit.String(name)} override def setType(ty: Type): Primitive = Primitive()(ty) override def primEq(obj: rise.core.Primitive): Boolean = obj.getClass == getClass - override def typeScheme: Type = ${generateTypeScheme(typeSignature, Map.empty)} + override def typeScheme: Type = ${generateTypeScheme(typeSignature)} } override def toString: String = ${Lit.String(name)} @@ -83,13 +83,12 @@ import arithexpr.arithmetic._ generated } - def generateCaseClass(name: String, paramsString: String, typeSignature: TypeAST): scala.meta.Term.Block = { + def generateCaseClass(name: String, paramsString: String, typeSignature: rise.Type.AST): scala.meta.Term.Block = { import scala.meta._ - val params = paramsString.split(",").map(param => { - val parts = param.split(":").map(_.trim) - param"${Term.Name(parts(0))}: ${Type.Name(parts(1))}" - } ).toList + val params = s"def foo($paramsString)".parse[Stat].get match { + case declDef: Decl.Def => declDef.paramss.head + } val args: List[Term.Name] = params.map(p => Term.Name(p.name.value)) val types: List[Type] = params.map(p => p.decltpe.get) @@ -112,7 +111,7 @@ import arithexpr.arithmetic._ { override val name: String = ${Lit.String(name)} override def setType(ty: Type): Primitive = Primitive(..$args)(ty) - override def typeScheme: Type = ${generateTypeScheme(typeSignature, Map.empty)} + override def typeScheme: Type = ${generateTypeScheme(typeSignature)} override def primEq(obj: rise.core.Primitive): Boolean = obj match { case p: Primitive => ${generateComparisonChain(args)} @@ -130,115 +129,129 @@ import arithexpr.arithmetic._ generated } - def generateTypeScheme(typeAST: TypeAST, env: Map[TypeAST.Identifier, String]): scala.meta.Term = { + def generateTypeScheme(typeAST: rise.Type.AST): scala.meta.Term = { import scala.meta._ typeAST match { - case id@TypeAST.Identifier(name) => - assert(env.contains(id), s"$id is not in $env") + case rise.Type.AST.FunType(inT, outT) => + q"(${generateTypeScheme(inT)}) ->: (${generateTypeScheme(outT)})" + case rise.Type.AST.DepFunType(id, kind, t) => + q"expl((${Term.Name(id.name)}: ${Type.Name(kindName(kind))}) => ${generateTypeScheme(t)})" + case rise.Type.AST.ImplicitDepFunType(id, kind, t) => + q"impl((${Term.Name(id.name)}: ${Type.Name(kindName(kind))}) => ${generateTypeScheme(t)})" + case _ => generateDataType(typeAST) + } + } + + def generateDataType(typeAST: rise.Type.AST): scala.meta.Term = { + import scala.meta._ + typeAST match { + case rise.Type.AST.Identifier(name) => Term.Name(name) - case TypeAST.FunType(inT, outT) => - q"(${generateTypeScheme(inT, env)}) ->: (${generateTypeScheme(outT, env)})" - case TypeAST.DepFunType(id, kind, t) => - q"expl((${Term.Name(id.name)}: ${Type.Name(kindName(kind))}) => ${generateTypeScheme(t, env.updated(id, kind))})" - case TypeAST.ImplicitDepFunType(id, kind, t) => - q"impl((${Term.Name(id.name)}: ${Type.Name(kindName(kind))}) => ${generateTypeScheme(t, env.updated(id, kind))})" - case TypeAST.ScalarType(t) => + case rise.Type.AST.ScalarType(t) => t.parse[Term].get - case TypeAST.NatType => - q"rise.core.types.NatType" - case TypeAST.VectorType(size, elemType) => - q"rise.core.types.VectorType(${generateNat(size, env)}, ${generateTypeScheme(elemType, env)})" - case TypeAST.IndexType(size) => - q"rise.core.types.IndexType(${generateNat(size, env)})" - case TypeAST.PairType(lhs, rhs) => - q"rise.core.types.PairType(${generateTypeScheme(lhs, env)}, ${generateTypeScheme(rhs, env)})" - case TypeAST.DepPairType(id, kind, t) => kind match { - case "nat" => - q"Nat `**` ((${Term.Name(id.name)}: Nat) => ${generateTypeScheme(t, env.updated(id, kind))})" + case rise.Type.AST.NatType => + q"NatType" + case rise.Type.AST.OpaqueType(name) => + q"OpaqueType($name)" + case rise.Type.AST.VectorType(size, elemType) => + q"VectorType(${generateNat(size)}, ${generateDataType(elemType)})" + case rise.Type.AST.IndexType(size) => + q"IndexType(${generateNat(size)})" + case rise.Type.AST.PairType(lhs, rhs) => + q"PairType(${generateDataType(lhs)}, ${generateDataType(rhs)})" + case rise.Type.AST.DepPairType(id, kind, t) => kind match { + case Kind.AST.Nat => + q"Nat `**` ((${Term.Name(id.name)}: Nat) => ${generateDataType(t)})" case _ => ??? } - case TypeAST.NatToDataApply(f, n) => - q"rise.core.types.NatToDataApply(${generateTypeScheme(f, env)}, ${generateNat(n, env)})" - case TypeAST.NatToDataLambda(id, t) => - q"n2dtFun((${Term.Name(id.name)}: NatIdentifier) => ${generateTypeScheme(t, env.updated(id, "nat"))})" - case TypeAST.ArrayType(size, elemType) => - q"rise.core.types.ArrayType(${generateNat(size, env)}, ${generateTypeScheme(elemType, env)})" - case TypeAST.DepArrayType(size, fdt) => - q"rise.core.types.DepArrayType(${generateNat(size, env)}, ${generateTypeScheme(fdt, env)})" - case TypeAST.FragmentType(n, m, k, elemType, fKind, mLayout) => - q"rise.core.types.FragmentType(${generateNat(n, env)}, ${generateNat(m, env)}, ${generateNat(k, env)}, ${generateTypeScheme(elemType, env)}, ${generateFragment(fKind, env)}, ${generateMatrixLayout(mLayout, env)})" + case rise.Type.AST.NatToDataApply(f, n) => + q"NatToDataApply(${generateDataType(f)}, ${generateNat(n)})" + case rise.Type.AST.NatToDataLambda(id, t) => + q"n2dtFun((${Term.Name(id.name)}: NatIdentifier) => ${generateDataType(t)})" + case rise.Type.AST.ArrayType(size, elemType) => + q"ArrayType(${generateNat(size)}, ${generateDataType(elemType)})" + case rise.Type.AST.DepArrayType(size, fdt) => + q"DepArrayType(${generateNat(size)}, ${generateDataType(fdt)})" + case rise.Type.AST.FragmentType(n, m, k, elemType, fKind, mLayout) => + q"FragmentType(${generateNat(n)}, ${generateNat(m)}, ${generateNat(k)}, ${generateDataType(elemType)}, ${generateFragment(fKind)}, ${generateMatrixLayout(mLayout)})" + case rise.Type.AST.ManagedBufferType(dt) => + q"ManagedBufferType(${generateDataType(dt)})" + case rise.Type.AST.FunType(_, _) | rise.Type.AST.DepFunType(_, _, _) | + rise.Type.AST.ImplicitDepFunType(_, _, _) => ??? } } - def kindName(kind: String): String = kind match { - case "nat" => "Nat" - case "data" => "DataType" - case "nat2nat" => "NatToNat" - case "nat2data" => "NatToData" - case "address" => "AddressSpace" - case "fragment" => "FragmentKind" - case "matrixLayout" => "MatrixLayout" + def kindName(kind: Kind.AST): String = { + import meta.parser.rise.Kind.AST + kind match { + case AST.Data => "DataType" + case AST.Address => "AddressSpace" + case AST.Nat2Nat => "NatToNat" + case AST.Nat2Data => "NatToData" + case AST.Nat => "Nat" + case AST.Fragment => "FragmentKind" + case AST.MatrixLayout => "MatrixLayout" + } } - def generateNat(n: NatAST, env: Map[TypeAST.Identifier, String]): scala.meta.Term = { + def generateNat(n: Nat.AST): scala.meta.Term = { import scala.meta._ n match { - case NatAST.Identifier(id) => - assert(env.contains(id), s"$id is not in $env") - Term.Name(id.name) - case NatAST.Number(n) => + case Nat.AST.Identifier(id) => + Term.Name(id) + case Nat.AST.Number(n) => n.parse[Term].get - case NatAST.BinaryOp(lhs, "^", rhs) => - q"${generateNat(lhs, env)}.pow(${generateNat(rhs, env)})" - case NatAST.BinaryOp(lhs, op, rhs) => - q"${generateNat(lhs, env)} ${Term.Name(op)} ${generateNat(rhs, env)}" - case NatAST.TernaryOp(cond, thenN, elseN) => + case Nat.AST.BinaryOp(lhs, "^", rhs) => + q"${generateNat(lhs)}.pow(${generateNat(rhs)})" + case Nat.AST.BinaryOp(lhs, op, rhs) => + q"${generateNat(lhs)} ${Term.Name(op)} ${generateNat(rhs)}" + case Nat.AST.TernaryOp(cond, thenN, elseN) => val operator: Term = cond.op match { case "<" => q"Operator.<" case ">" => q"Operator.>" } q""" IfThenElse( - arithPredicate(${generateNat(cond.lhs, env)}, ${generateNat(cond.rhs, env)}, $operator), - ${generateNat(thenN, env)}, - ${generateNat(elseN, env)}) + arithPredicate(${generateNat(cond.lhs)}, ${generateNat(cond.rhs)}, $operator), + ${generateNat(thenN)}, + ${generateNat(elseN)}) """ - case NatAST.Nat2NatApply(f, n) => - q"${generateTypeScheme(f, env)}(${generateNat(n, env)})" - case NatAST.Sum(id, from, upTo, body) => + case Nat.AST.Nat2NatApply(f, n) => + q"${generateDataType(f)}(${generateNat(n)})" + case Nat.AST.Sum(id, from, upTo, body) => q"""BigSum( - from = ${generateNat(from, env)}, - upTo = ${generateNat(upTo, env)}, - (${Term.Name(id.id.name)}: Nat) => ${generateNat(body, env.updated(id.id, "nat"))}) + from = ${generateNat(from)}, + upTo = ${generateNat(upTo)}, + (${Term.Name(id.name)}: Nat) => ${generateNat(body)}) """ } } - def generateFragment(fragmentAST: FragmentAST, env: Map[TypeAST.Identifier, String]): scala.meta.Term = { + def generateFragment(fragmentAST: rise.Type.Fragment.AST): scala.meta.Term = { import scala.meta._ fragmentAST match { - case FragmentAST.Identifier(id) => - assert(env.contains(id), s"$id is not in $env") - Term.Name(id.name) - case FragmentAST.ACC => - q"rise.core.types.FragmentKind.Acuumulator" - case FragmentAST.A => - q"rise.core.types.FragmentKind.AMatrix" - case FragmentAST.B => - q"rise.core.types.FragmentKind.BMatrix" + case rise.Type.Fragment.AST.Identifier(name) => + Term.Name(name) + case rise.Type.Fragment.AST.ACC => + q"FragmentKind.Accumulator" + case rise.Type.Fragment.AST.A => + q"FragmentKind.AMatrix" + case rise.Type.Fragment.AST.B => + q"FragmentKind.BMatrix" } } - def generateMatrixLayout(matrixLayoutAST: MatrixLayoutAST, env: Map[TypeAST.Identifier, String]): scala.meta.Term = { + def generateMatrixLayout(matrixLayoutAST: rise.Type.MatrixLayout.AST): scala.meta.Term = { import scala.meta._ matrixLayoutAST match { - case MatrixLayoutAST.Identifier(id) => - assert(env.contains(id), s"$id is not in $env") - Term.Name(id.name) - case MatrixLayoutAST.ROW_MAJOR => - q"rise.core.types.MatrixLayout.Row_Major" - case MatrixLayoutAST.COL_MAJOR => - q"rise.core.types.MatrixLayout.Col_Major" + case rise.Type.MatrixLayout.AST.Identifier(name) => + Term.Name(name) + case rise.Type.MatrixLayout.AST.ROW_MAJOR => + q"MatrixLayout.Row_Major" + case rise.Type.MatrixLayout.AST.COL_MAJOR => + q"MatrixLayout.Col_Major" + case rise.Type.MatrixLayout.AST.NONE => + q"MatrixLayout.None" } } diff --git a/meta/src/main/scala/meta/parser/DPIA/Decl.scala b/meta/src/main/scala/meta/parser/DPIA/Decl.scala new file mode 100644 index 000000000..505a508c9 --- /dev/null +++ b/meta/src/main/scala/meta/parser/DPIA/Decl.scala @@ -0,0 +1,41 @@ +package meta.parser.DPIA + +import fastparse.ScalaWhitespace._ +import fastparse._ +import meta.parser.shared.Identifier + +object Decl { + sealed trait AST + object AST { + case class Identifier(name: String) extends AST + case class Param(id: Identifier, ty: Either[Kind.AST, Type.AST]) extends AST + case class PrimitiveDeclaration(id: Identifier, + scalaParams: Option[(Int, Int)], + params: Seq[Param], + returnType: Type.AST) extends AST + } + + def PrimitiveDeclarations[_: P]: P[Seq[AST.PrimitiveDeclaration]] = + P(Start ~ PrimitiveDeclaration.rep(1) ~ End) + + // def drop(n: nat, m: nat, t: data, input: exp[n+m.t, read]): exp[m.t, read] + // def mapGlobal[dim: Int](n: nat, s: data, t: data, f: exp[s, read] -> exp[t, read], array: exp[n.s, read]): exp[n.t, read] + def PrimitiveDeclaration[_: P]: P[AST.PrimitiveDeclaration] = { + import scalaparse.Scala.TrailingCommaOps + def ScalaParams: P[(Int, Int)] = { + P("{" ~ Index ~ + (scalaparse.Scala.Id ~ scalaparse.syntax.Key.O(":") ~ scalaparse.Scala.Type).repTC(1) ~ + Index ~ "}") + } + + def Param: P[AST.Param] = ( + (Identifier.map(AST.Identifier) ~ ":" ~ Kind.Kind).map(pair => AST.Param(pair._1, Left(pair._2))) + | (Identifier.map(AST.Identifier) ~ ":" ~ Type.PhraseType).map(pair => AST.Param(pair._1, Right(pair._2))) + ) + + def Params: P[Seq[AST.Param]] = Param.repTC(0) + + P("def" ~ Identifier.map(AST.Identifier) ~ ScalaParams.? ~ "(" ~ Params ~ ")" ~ ":" ~ Type.PhraseType) + .map(AST.PrimitiveDeclaration.tupled) + } +} diff --git a/meta/src/main/scala/meta/parser/DPIA/Kind.scala b/meta/src/main/scala/meta/parser/DPIA/Kind.scala new file mode 100644 index 000000000..6e3c0544e --- /dev/null +++ b/meta/src/main/scala/meta/parser/DPIA/Kind.scala @@ -0,0 +1,17 @@ +package meta.parser.DPIA + +import fastparse._ +import meta.parser.rise + +object Kind { + sealed trait AST + object AST { + case class RiseKind(riseKind: rise.Kind.AST) extends AST + case object Access extends AST + } + + def Kind[_: P]: P[AST] = P( + rise.Kind.Kind.map(AST.RiseKind) | + "access".!.map(_ => AST.Access) + ) +} diff --git a/meta/src/main/scala/meta/parser/DPIA/Type.scala b/meta/src/main/scala/meta/parser/DPIA/Type.scala new file mode 100644 index 000000000..c998ed5d4 --- /dev/null +++ b/meta/src/main/scala/meta/parser/DPIA/Type.scala @@ -0,0 +1,65 @@ +package meta.parser.DPIA + +import fastparse.ScalaWhitespace._ +import fastparse._ +import meta.parser._ +import shared._ + +object Type { + sealed trait AST + object AST { + case class ExpType(dataType: rise.Type.AST, access: Access.AST) extends AST + case class AccType(dataType: rise.Type.AST) extends AST + case object CommType extends AST + + case class PairType(lhs: AST, rhs: AST) extends AST + case class FunType(inT: AST, outT: AST) extends AST + case class DepFunType(id: Identifier, kind: Kind.AST, t: AST) extends AST + case class Identifier(name: String) extends AST + } + + object Access { + sealed trait AST + object AST { + case class Identifier(name: String) extends AST + case object Read extends AST + case object Write extends AST + } + } + + def PhraseType[_: P]: P[AST] = { + def DataType: P[rise.Type.AST] = rise.Type.DataType.DataType + + def AccessType: P[Access.AST] = P( + "read".!.map(_ => Access.AST.Read) | + "write".!.map(_ => Access.AST.Write) | + Identifier.map(Access.AST.Identifier) + ) + + def ExpType: P[AST.ExpType] = P("exp[" ~ DataType ~ "," ~ AccessType ~ "]").map(AST.ExpType.tupled) + + def AccType: P[AST.AccType] = P("acc[" ~ DataType ~ "]").map(AST.AccType) + + def VarType: P[AST.PairType] = P("var[" ~ DataType ~ "]"). + map(dt => AST.PairType(AST.ExpType(dt, Access.AST.Read), AST.AccType(dt))) + + def CommType: P[AST.CommType.type] = P("comm".!.map(_ => AST.CommType)) + + def PairType: P[AST.PairType] = + P("(" ~ NoCut(PhraseType) ~ "," ~/ PhraseType ~ ")").map(AST.PairType.tupled) + + def FunType: P[AST.FunType] = + P(NoCut(NonFunPhraseType | ("(" ~ PhraseType ~ ")")) ~ "->" ~/ PhraseType).map(AST.FunType.tupled) + + def DepFunType: P[AST.DepFunType] = { + def IdentifierKindPair: P[(AST.Identifier, Kind.AST)] = + P(Identifier.map(AST.Identifier) ~ ":" ~ Kind.Kind) + + P("(" ~ IdentifierKindPair ~ ")" ~ "->" ~/ PhraseType).map(AST.DepFunType.tupled) + } + + def NonFunPhraseType: P[AST] = P( ExpType | AccType | VarType | CommType | PairType | DepFunType ) + + P( FunType | NonFunPhraseType ) + } +} diff --git a/meta/src/main/scala/meta/parser/DPIA/isWellKindedDefinition.scala b/meta/src/main/scala/meta/parser/DPIA/isWellKindedDefinition.scala new file mode 100644 index 000000000..297978ea2 --- /dev/null +++ b/meta/src/main/scala/meta/parser/DPIA/isWellKindedDefinition.scala @@ -0,0 +1,63 @@ +package meta.parser.DPIA + +import meta.parser._ + +object isWellKindedDefinition { + + def apply(params: Seq[Decl.AST.Param], returnType: Type.AST): Boolean = { + import Decl.AST._ + var isWellKindedFlag = true + val env = params.foldLeft(Map.empty[String, Kind.AST]) { + case (env, Param(Identifier(name), Left(kind))) => + env.updated(name, kind) + case (env, Param(Identifier(_), Right(typeAST))) => + if (!isWellKinded(typeAST, env)) isWellKindedFlag = false + env + } + isWellKindedFlag && isWellKinded(returnType, env) + } + + def isWellKinded(typeAST: Type.AST, env: Map[String, Kind.AST]): Boolean = { + import Type._ + import rise.isWellKindedType._ + typeAST match { + case AST.ExpType(dataType, access) => + val nenv = env.flatMap { + case (string, DPIA.Kind.AST.RiseKind(riseKind)) => + Some((string, riseKind)) + case _ => None + } + kindOf(dataType, nenv).isDefined && isWellKinded(access, env) + case AST.AccType(dataType) => + val nenv = env.flatMap { + case (string, DPIA.Kind.AST.RiseKind(riseKind)) => + Some((string, riseKind)) + case _ => None + } + kindOf(dataType, nenv).isDefined + case AST.CommType => true + case AST.PairType(lhs, rhs) => + isWellKinded(lhs, env) && isWellKinded(rhs, env) + case AST.FunType(inT, outT) => + isWellKinded(inT, env) && isWellKinded(outT, env) + case AST.DepFunType(id, kind, t) => + if (env.isDefinedAt(id.name)) { + // we forbid shadowing + false + } else { + isWellKinded(t, env.updated(id.name, kind)) + } + case AST.Identifier(name) => + env.contains(name) + } + } + + def isWellKinded(accessAST: Type.Access.AST, env: Map[String, Kind.AST]): Boolean = { + import Type.Access._ + accessAST match { + case AST.Identifier(name) => env.isDefinedAt(name) + case AST.Read => true + case AST.Write => true + } + } +} diff --git a/meta/src/main/scala/meta/parser/Nat.scala b/meta/src/main/scala/meta/parser/Nat.scala new file mode 100644 index 000000000..051241818 --- /dev/null +++ b/meta/src/main/scala/meta/parser/Nat.scala @@ -0,0 +1,62 @@ +package meta.parser + +import fastparse.ScalaWhitespace._ +import fastparse._ + +object Nat { + + sealed trait AST + object AST { + case class Identifier(name: String) extends AST + case class Number(n: String) extends AST + case class BinaryOp(lhs: AST, op: String, rhs: AST) extends AST + case class TernaryOp(cond: BinaryOp, thenN: AST, elseN: AST) extends AST + case class Nat2NatApply(f: rise.Type.AST.Identifier, n: AST) extends AST + case class Sum(id: Identifier, from: AST, upTo: AST, body: AST) extends AST + } + + def Nat[_: P]: P[AST] = { + def CompOrNat: P[AST] = { + def CompOp: P[String] = P("<".! | ">".!) + P(AddSubOrNat ~ (CompOp ~/ AddSubOrNat).rep).map(asBinaryOpOrNat) + } + + def AddSubOrNat: P[AST] = { + def AddSubOps: P[String] = P("+".! | "-".!) + P(DivMulPowModOrNat ~ (AddSubOps ~ DivMulPowModOrNat).rep).map(asBinaryOpOrNat) + } + + def DivMulPowModOrNat: P[AST] = { + def DivMulPowModOp: P[String] = P("*".! | "/".! | "^".! | "%".!) + P(SingleNat ~ (DivMulPowModOp ~ SingleNat).rep).map(asBinaryOpOrNat) + } + + def SingleNat: P[AST] = { + def Number: P[AST.Number] = P(CharIn("0-9").rep(1).!).map(AST.Number) + + def Sum: P[AST.Sum] = { + def Assignment: P[(AST.Identifier, AST)] = + P(NatIdentifier ~ "=" ~ Nat | "(" ~ Assignment ~ ")") + P("sum" ~ "_" ~ Assignment ~ "^" ~ Nat ~ Nat).map(AST.Sum.tupled) + } + + def Nat2NatApply: P[AST.Nat2NatApply] = + P(rise.Type.TypeIdentifier ~ "(" ~ Nat ~ ")").map(AST.Nat2NatApply.tupled) + + def NatIdentifier: P[AST.Identifier] = P(rise.Type.TypeIdentifier).map(i => AST.Identifier(i.name)) + + def Parens: P[AST] = P("(" ~ Nat ~ ")") + + P(Number | Sum | Nat2NatApply | NatIdentifier | Parens) + } + + P(CompOrNat) + } + + private def asBinaryOpOrNat: ((AST, Seq[(String, AST)])) => AST = { + case (n, ns) => ns.foldLeft(n){ + case (lhs, (op, rhs)) => AST.BinaryOp(lhs, op, rhs) + } + } + +} diff --git a/meta/src/main/scala/meta/parser/rise/Decl.scala b/meta/src/main/scala/meta/parser/rise/Decl.scala new file mode 100644 index 000000000..d4a191565 --- /dev/null +++ b/meta/src/main/scala/meta/parser/rise/Decl.scala @@ -0,0 +1,30 @@ +package meta.parser.rise + +import fastparse.ScalaWhitespace._ +import fastparse._ +import meta.parser.shared.Identifier + +object Decl { + sealed trait AST + object AST { + case class Identifier(name: String) extends AST + case class PrimitiveDeclaration(id: Identifier, + scalaParams: Option[(Int, Int)], + typeSignature: Type.AST) extends AST + } + + def PrimitiveDeclarations[_: P]: P[Seq[AST.PrimitiveDeclaration]] = + P(Start ~ PrimitiveDeclaration.rep(1) ~ End) + + def PrimitiveDeclaration[_: P]: P[AST.PrimitiveDeclaration] = { + def ScalaParams: P[(Int, Int)] = { + import scalaparse.Scala.TrailingCommaOps + P("(" ~ Index ~ + (scalaparse.Scala.Id ~ scalaparse.syntax.Key.O(":") ~ scalaparse.Scala.Type).repTC(1) ~ + Index ~ ")") + } + + P("def" ~ Identifier.map(AST.Identifier) ~ ScalaParams.? ~ ":" ~ Type.TypeSignature) + .map(AST.PrimitiveDeclaration.tupled) + } +} diff --git a/meta/src/main/scala/meta/parser/rise/Kind.scala b/meta/src/main/scala/meta/parser/rise/Kind.scala new file mode 100644 index 000000000..6333bc698 --- /dev/null +++ b/meta/src/main/scala/meta/parser/rise/Kind.scala @@ -0,0 +1,26 @@ +package meta.parser.rise + +import fastparse._ + +object Kind { + sealed trait AST + object AST { + case object Data extends AST + case object Address extends AST + case object Nat2Nat extends AST + case object Nat2Data extends AST + case object Nat extends AST + case object Fragment extends AST + case object MatrixLayout extends AST + } + + def Kind[_: P]: P[AST] = P( + "data".!.map(_ => AST.Data) | + "address".!.map(_ => AST.Address) | + "nat2nat".!.map(_ => AST.Nat2Nat) | + "nat2data".!.map(_ => AST.Nat2Data) | + "nat".!.map(_ => AST.Nat) | + "fragment".!.map(_ => AST.Fragment) | + "matrixLayout".!.map(_ => AST.MatrixLayout) + ) +} diff --git a/meta/src/main/scala/meta/parser/rise/Type.scala b/meta/src/main/scala/meta/parser/rise/Type.scala new file mode 100644 index 000000000..6cd066db8 --- /dev/null +++ b/meta/src/main/scala/meta/parser/rise/Type.scala @@ -0,0 +1,143 @@ +package meta.parser.rise + +import fastparse.ScalaWhitespace._ +import fastparse._ +import meta.parser.shared.Identifier +import meta.parser._ + +object Type { + sealed trait AST + object AST { + case class Identifier(name: String) extends AST + case class FunType(inT: AST, outT: AST) extends AST + case class DepFunType(id: Identifier, kind: Kind.AST, t: AST) extends AST + case class ImplicitDepFunType(id: Identifier, kind: Kind.AST, t: AST) extends AST + + case class ScalarType(t: String) extends AST + case object NatType extends AST + case class OpaqueType(name: String) extends AST + case class VectorType(size: Nat.AST, elemType: AST) extends AST + case class IndexType(size: Nat.AST) extends AST + case class PairType(lhs: AST, rhs: AST) extends AST + case class DepPairType(id: Identifier, kind: Kind.AST, t: AST) extends AST + case class NatToDataApply(f: AST, n: Nat.AST) extends AST + case class NatToDataLambda(id: Identifier, t: AST) extends AST + case class ArrayType(size: Nat.AST, elemType: AST) extends AST + case class DepArrayType(size: Nat.AST, fdt: AST) extends AST + case class FragmentType(n: Nat.AST, m: Nat.AST, k: Nat.AST, elemType: AST, + fKind: Fragment.AST, mLayoutKind: MatrixLayout.AST) extends AST + case class ManagedBufferType(t: AST) extends AST + } + + object Fragment { + sealed trait AST + object AST { + case class Identifier(name: String) extends AST + object ACC extends AST + object A extends AST + object B extends AST + } + } + + object MatrixLayout { + sealed trait AST + object AST { + case class Identifier(name: String) extends AST + object ROW_MAJOR extends AST + object COL_MAJOR extends AST + object NONE extends AST + } + } + + def TypeSignature[_: P]: P[AST] = { + def DepFunType: P[AST.DepFunType] = + P("(" ~ IdentifierKindPair ~ ")" ~ "->" ~/ TypeSignature).map(AST.DepFunType.tupled) + + def ImplicitDepFunType: P[AST.ImplicitDepFunType] = + P("{" ~ IdentifierKindPair ~ "}" ~ "->" ~/ TypeSignature). + map(AST.ImplicitDepFunType.tupled) + + def FunType: P[AST.FunType] = + P(NoCut(LeftTypeSignature) ~ "->" ~/ TypeSignature).map(AST.FunType.tupled) + + // Types that can appear at the left of an function arrow + def LeftTypeSignature: P[AST] = P(DataType.DataType | ("(" ~ TypeSignature ~ ")")) + + P(DepFunType | ImplicitDepFunType | FunType | LeftTypeSignature) + } + + def TypeIdentifier[_: P]: P[AST.Identifier] = P(Identifier).map(AST.Identifier) + + def IdentifierKindPair[_: P]: P[(AST.Identifier, Kind.AST)] = + P(Identifier.map(AST.Identifier) ~ ":" ~ Kind.Kind) + + object DataType { + def ScalarType[_: P]: P[AST.ScalarType] = + P("bool".! | "int".! | + "i8".! | "i16".! | "i32".! | "i64".! | + "u8".! | "u16".! | "u32".! | "u64".! | + "f16".! | "f32".! | "f64".!).map(AST.ScalarType) + + def NatType[_: P]: P[AST.NatType.type] = P("natType").map(_ => AST.NatType) + + def OpaqueType[_: P]: P[AST.OpaqueType] = P( "\"" ~~ Identifier ~~ "\"" ).map(AST.OpaqueType) + + def IndexType[_: P]: P[AST.IndexType] = P("idx[" ~ Nat.Nat ~ "]").map(AST.IndexType) + + def VectorType[_: P]: P[AST.VectorType] = + P("vec[" ~ DataType ~ "," ~ Nat.Nat ~ "]").map(t => AST.VectorType(t._2, t._1)) + + def FragmentType[_: P]: P[AST.FragmentType] = { + def FragmentKind: P[Fragment.AST] = + P(("fragment." ~~ ( + "ACC".!.map(_ => Fragment.AST.ACC) | + "A".!.map(_ => Fragment.AST.A) | + "B".!.map(_ => Fragment.AST.B)) + ) | TypeIdentifier.map(i => Fragment.AST.Identifier(i.name))) + + def MatrixLayoutKind: P[MatrixLayout.AST] = + P(("matrixLayout." ~~ ( + "ROW_MAJOR".!.map(_ => MatrixLayout.AST.ROW_MAJOR) | + "COL_MAJOR".!.map(_ => MatrixLayout.AST.COL_MAJOR) | + "NONE".!.map(_ => MatrixLayout.AST.NONE)) + ) | TypeIdentifier.map(i => MatrixLayout.AST.Identifier(i.name))) + + P("fragment[" ~ Nat.Nat ~ "," ~ Nat.Nat ~ "," ~ Nat.Nat ~ "," ~ DataType ~ "," ~ FragmentKind ~ + "," ~ MatrixLayoutKind ~ "]").map(AST.FragmentType.tupled) + } + + def ManagedBufferType[_: P]: P[AST.ManagedBufferType] = + P("managed[" ~ DataType ~ "]").map(AST.ManagedBufferType) + + def DepArrayType[_: P]: P[AST.DepArrayType] = + P(Nat.Nat ~ ".." ~/ NatToData).map(AST.DepArrayType.tupled) + + def ArrayType[_: P]: P[AST.ArrayType] = + P(Nat.Nat ~ "." ~~ !"." ~/ DataType).map(AST.ArrayType.tupled) + + def DepPairType[_: P]: P[AST.DepPairType] = + P("(" ~ IdentifierKindPair ~ "**" ~/ DataType ~ ")").map(AST.DepPairType.tupled) + + def NatToDataApply[_: P]: P[AST.NatToDataApply] = + P(NatToData ~ "(" ~ Nat.Nat ~ ")").map(AST.NatToDataApply.tupled) + + def PairType[_: P]: P[AST.PairType] = + P("(" ~ NoCut(DataType) ~ "," ~/ DataType ~ ")").map(AST.PairType.tupled) + + def DataType[_: P]: P[AST] = + P(ScalarType | NatType | OpaqueType | IndexType | VectorType | FragmentType | + ManagedBufferType | DepArrayType | ArrayType | DepPairType | NatToDataApply | + PairType | TypeIdentifier | ("(" ~ DataType ~ ")")) + + def TypeName[_: P]: P[Unit] = + P(ScalarType | NatType | "idx" | "vec" | "fragment" | "matrixLayout") + } + + def NatToData[_: P]: P[AST] = { + def NatToDataLambda: P[AST.NatToDataLambda] = + P("(" ~ IdentifierKindPair.filter(_._2 == Kind.AST.Nat).map(_._1) ~ + "|->" ~/ DataType.DataType ~ ")").map(AST.NatToDataLambda.tupled) + + P(TypeIdentifier | NatToDataLambda) + } +} diff --git a/meta/src/main/scala/meta/parser/rise/isWellKindedType.scala b/meta/src/main/scala/meta/parser/rise/isWellKindedType.scala new file mode 100644 index 000000000..afb868305 --- /dev/null +++ b/meta/src/main/scala/meta/parser/rise/isWellKindedType.scala @@ -0,0 +1,176 @@ +package meta.parser.rise + +import meta.parser._ + +object isWellKindedType { + + def apply(typeAST: Type.AST): Boolean = { + kindOf(typeAST, Map.empty).isDefined + } + + sealed trait DataTypeOrFunctionKind + case class DataTypeKind(kind: Kind.AST) extends DataTypeOrFunctionKind + case object FunctionKind extends DataTypeOrFunctionKind + + def kindOf(typeAST: Type.AST, + env: Map[String, Kind.AST]): Option[DataTypeOrFunctionKind] = { + import Type._ + typeAST match { + case AST.Identifier(name) => + env.get(name).map(DataTypeKind) + case AST.FunType(inT, outT) => + for { + _ <- kindOf(inT, env) + _ <- kindOf(outT, env) + } yield FunctionKind + case AST.DepFunType(id, kind, t) => + if (env.isDefinedAt(id.name)) { + None // we forbid shadowing + } else { + for { + _ <- kindOf(t, env.updated(id.name, kind)) + } yield FunctionKind + } + case AST.ImplicitDepFunType(id, kind, t) => + if (env.isDefinedAt(id.name)) { + None // we forbid shadowing + } else { + for { + _ <- kindOf(t, env.updated(id.name, kind)) + } yield FunctionKind + } + case AST.VectorType(size, elemType) => + for { + k1 <- kindOf(size, env) + k2 <- kindOf(elemType, env) + if k1 == Kind.AST.Nat && k2 == DataTypeKind(Kind.AST.Data) + } yield DataTypeKind(Kind.AST.Data) + case AST.IndexType(size) => + for { + k <- kindOf(size, env) + if k == Kind.AST.Nat + } yield DataTypeKind(Kind.AST.Data) + case AST.PairType(lhs, rhs) => + for { + k1 <- kindOf(lhs, env) + k2 <- kindOf(rhs, env) + if k1 == DataTypeKind(Kind.AST.Data) && k2 == DataTypeKind(Kind.AST.Data) + } yield DataTypeKind(Kind.AST.Data) + case AST.DepPairType(id, kind, t) => + if (env.isDefinedAt(id.name)) { + None // we forbid shadowing + } else { + kindOf(t, env.updated(id.name, kind)) + } + case AST.NatToDataApply(f, n) => + for { + k1 <- kindOf(f, env) + k2 <- kindOf(n, env) + if k1 == DataTypeKind(Kind.AST.Nat2Data) && k2 == Kind.AST.Nat + } yield DataTypeKind(Kind.AST.Data) + case AST.NatToDataLambda(id, t) => + if (env.isDefinedAt(id.name)) { + None // we forbid shadowing + } else { + for { + k <- kindOf(t, env.updated(id.name, Kind.AST.Nat)) + if k == DataTypeKind(Kind.AST.Data) + } yield DataTypeKind(Kind.AST.Nat2Data) + } + case AST.ArrayType(size, elemType) => + for { + k1 <- kindOf(size, env) + k2 <- kindOf(elemType, env) + if k1 == Kind.AST.Nat && k2 == DataTypeKind(Kind.AST.Data) + } yield DataTypeKind(Kind.AST.Data) + case AST.DepArrayType(size, fdt) => + for { + k1 <- kindOf(size, env) + k2 <- kindOf(fdt, env) + if k1 == Kind.AST.Nat && k2 == DataTypeKind(Kind.AST.Nat2Data) + } yield DataTypeKind(Kind.AST.Data) + case AST.FragmentType(n, m, k, elemType, fKind, mLayout) => + for { + k1 <- kindOf(n, env) + k2 <- kindOf(m, env) + k3 <- kindOf(k, env) + k4 <- kindOf(elemType, env) + k5 <- kindOf(fKind, env) + k6 <- kindOf(mLayout, env) + if k1 == Kind.AST.Nat && k2 == Kind.AST.Nat && k3 == Kind.AST.Nat && k4 == DataTypeKind(Kind.AST.Data) && + k5 == Kind.AST.Fragment && k6 == Kind.AST.MatrixLayout + } yield DataTypeKind(Kind.AST.Data) + case AST.ManagedBufferType(dt) => + for { + k1 <- kindOf(dt, env) + if k1 == DataTypeKind(Kind.AST.Data) + } yield DataTypeKind(Kind.AST.Data) + case _: AST.ScalarType | AST.NatType | _: AST.OpaqueType => + Some(DataTypeKind(Kind.AST.Data)) + } + } + + def kindOf(natAST: Nat.AST, + env: Map[String, Kind.AST] + ): Option[Kind.AST] = { + natAST match { + case Nat.AST.Identifier(id) => + env.get(id) + case Nat.AST.Number(_) => + Some(Kind.AST.Nat) + case Nat.AST.BinaryOp(lhs, _, rhs) => + for { + k1 <- kindOf(lhs, env) + k2 <- kindOf(rhs, env) + if k1 == Kind.AST.Nat && k2 == Kind.AST.Nat + } yield Kind.AST.Nat + + case Nat.AST.TernaryOp(_, thenN, elseN) => + for { + k1 <- kindOf(thenN, env) + k2 <- kindOf(elseN, env) + if k1 == Kind.AST.Nat && k2 == Kind.AST.Nat + } yield Kind.AST.Nat + + case Nat.AST.Nat2NatApply(f, n) => + for { + k1 <- kindOf(f, env) + k2 <- kindOf(n, env) + if k1 == DataTypeKind(Kind.AST.Nat2Nat) && k2 == Kind.AST.Nat + } yield Kind.AST.Nat + + case Nat.AST.Sum(id, from, upTo, body) => + val nEnv = env.updated(id.name, Kind.AST.Nat) + for { + k1 <- kindOf(from, nEnv) + k2 <- kindOf(upTo, nEnv) + k3 <- kindOf(body, nEnv) + if k1 == Kind.AST.Nat && k2 == Kind.AST.Nat && k3 == Kind.AST.Nat + } yield Kind.AST.Nat + } + } + + def kindOf(fragmentAST: Type.Fragment.AST, + env: Map[String, Kind.AST] + ): Option[Kind.AST] = { + import Type._ + fragmentAST match { + case Fragment.AST.Identifier(id) => + env.get(id) + case Fragment.AST.ACC | Fragment.AST.A | Fragment.AST.B => Some(Kind.AST.Fragment) + } + } + + def kindOf(matrixLayout: Type.MatrixLayout.AST, + env: Map[String, Kind.AST] + ): Option[Kind.AST] = { + import Type._ + matrixLayout match { + case MatrixLayout.AST.Identifier(id) => + env.get(id) + case MatrixLayout.AST.ROW_MAJOR | + MatrixLayout.AST.COL_MAJOR | + MatrixLayout.AST.NONE => Some(Kind.AST.MatrixLayout) + } + } +} diff --git a/meta/src/main/scala/meta/parser/shared/package.scala b/meta/src/main/scala/meta/parser/shared/package.scala new file mode 100644 index 000000000..0b59a89e1 --- /dev/null +++ b/meta/src/main/scala/meta/parser/shared/package.scala @@ -0,0 +1,21 @@ +package meta.parser + +import fastparse.ScalaWhitespace._ +import fastparse._ + +package object shared { + def Identifier[_: P]: P[String] = { + def Keywords: P[Unit] = + P(( "def" | + (rise.Kind.Kind: P[Unit]) | rise.Type.DataType.TypeName | + (DPIA.Kind.Kind: P[Unit]) + ) ~~ CharPred(_.isWhitespace)) + + val LowerChar = scalaparse.syntax.Identifiers.NamedFunction(CharPredicates.isLower) + val IdCharacter = scalaparse.syntax.Identifiers.NamedFunction(c => + CharPredicates.isLetter(c) || CharPredicates.isDigit(c)) + + P((!Keywords ~ CharPred(LowerChar).! ~~ CharsWhile(IdCharacter).!.?). + map(t => t._1 ++ t._2.getOrElse(""))) + } +} diff --git a/src/main/scala/rise/core/primitives/primitives.rise b/src/main/scala/rise/core/primitives/primitives.rise index 2737c9d3c..5e7895adc 100644 --- a/src/main/scala/rise/core/primitives/primitives.rise +++ b/src/main/scala/rise/core/primitives/primitives.rise @@ -37,7 +37,7 @@ def generate: {n: nat} -> {t: data} -> (idx[n] -> t) -> n.t def idx: {n: nat} -> {t: data} -> idx[n] -> n.t -> t def take: (n: nat) -> {m: nat} -> {t: data} -> (n+m).t -> n.t - def drop: (n: nat) -> {m: nat} -> {t: data} -> (n+m).t -> m.t +def drop: (n: nat) -> {m: nat} -> {t: data} -> (n+m).t -> m.t def concat: {n: nat} -> {m: nat} -> {t: data} -> n.t -> m.t -> (n+m).t def split: (n: nat) -> {m: nat} -> {t: data} -> (m*n).t -> m.n.t @@ -56,8 +56,8 @@ def scatter: {n: nat} -> {m: nat} -> {t: data} -> n.idx[m] -> n.t -> m.t def reorder: {t: data} -> (n: nat) -> (idxF: nat2nat) -> (idxFinv: nat2nat) -> n.t -> n.t def padCst: {n: nat} -> (l: nat) -> (r: nat) -> {t: data} -> t -> n.t -> (l+n+r).t -def padClamp: {n: nat} -> (l: nat) -> (r: nat) -> {t: data} -> n.t -> (l+n+r).t -def padEmpty: {n: nat} -> (r: nat) -> {t: data} -> n.t -> (n+r).t +def padClamp: {n: nat} -> (l: nat) -> (r: nat) -> {t: data} -> n.t -> (l+n+r).t +def padEmpty: {n: nat} -> (r: nat) -> {t: data} -> n.t -> (n+r).t def zip: {n: nat} -> {s: data} -> {t: data} -> n.s -> n.t -> n.(s, t) def unzip: {n: nat} -> {s: data} -> {t: data} -> n.(s, t) -> (n.s, n.t) diff --git a/src/main/scala/rise/core/traverse.scala b/src/main/scala/rise/core/traverse.scala index bbc038368..724fe1f4f 100644 --- a/src/main/scala/rise/core/traverse.scala +++ b/src/main/scala/rise/core/traverse.scala @@ -63,6 +63,10 @@ object traverse { case VectorType(n, e) => for {n1 <- natDispatch(Reference)(n); e1 <- `type`(e)} yield VectorType(n1, e1) + case ManagedBufferType(dt) => + for {dt1 <- datatype(dt)} + yield ManagedBufferType(dt1) + case o: OpaqueType => return_(o: DataType) case FragmentType(rows, columns, d3, dt, fragKind, layout) => for {rows1 <- nat(rows); columns1 <- nat(columns); d31 <- nat(d3); dt1 <- datatype(dt); fragKind1 <- fragmentKind(fragKind); layout1 <- matrixLayout(layout)} diff --git a/src/main/scala/rise/core/types/Type.scala b/src/main/scala/rise/core/types/Type.scala index 679a7c1c0..1ada7f7ad 100644 --- a/src/main/scala/rise/core/types/Type.scala +++ b/src/main/scala/rise/core/types/Type.scala @@ -68,6 +68,10 @@ object f64 extends ScalarType { override def toString: String = "f64" } object NatType extends DataType { override def toString: String = "nat" } +final case class OpaqueType(name: String) extends DataType { + override def toString: String = name +} + // TODO: enforce ScalarType sealed case class VectorType(size: Nat, elemType: DataType) extends DataType { override def toString: String = s"<$size>$elemType" @@ -91,11 +95,11 @@ sealed trait MatrixLayout object MatrixLayout { object Row_Major extends MatrixLayout { override def toString = "Row_Major" } object Col_Major extends MatrixLayout { override def toString = "Col_Major" } + object None extends MatrixLayout } -final case class MatrixLayoutIdentifier( - name: String, - override val isExplicit: Boolean = false +final case class MatrixLayoutIdentifier(name: String, + override val isExplicit: Boolean = false ) extends MatrixLayout with Kind.Identifier with Kind.Explicitness { @@ -110,7 +114,7 @@ sealed trait FragmentKind object FragmentKind { object AMatrix extends FragmentKind { override def toString = "AMatrix"} object BMatrix extends FragmentKind { override def toString = "BMatrix"} - object Acuumulator extends FragmentKind { override def toString = "Acuumulator"} + object Accumulator extends FragmentKind { override def toString = "Accumulator"} } final case class FragmentKindIdentifier(name: String, @@ -126,7 +130,7 @@ final case class FragmentKindIdentifier(name: String, object FragmentType { def apply(rows: Nat, columns:Nat, d3: Nat, dataType: DataType): FragmentType = { - FragmentType(rows, columns, d3, dataType, FragmentKind.Acuumulator, MatrixLayout.Row_Major) + FragmentType(rows, columns, d3, dataType, FragmentKind.Accumulator, MatrixLayout.None) } } @@ -137,13 +141,17 @@ final case class FragmentType(rows: Nat, fragmentKind: FragmentKind, layout: MatrixLayout) extends DataType { override def toString: String = - if (fragmentKind == FragmentKind.Acuumulator) + if (fragmentKind == FragmentKind.Accumulator) s"Fragment[$rows,$columns,$d3,$dataType,$fragmentKind]" else s"Fragment[$rows,$columns,$d3,$dataType,$fragmentKind,$layout]" } +final case class ManagedBufferType(dt: DataType) extends DataType { + override def toString: String = s"managed[$dt]" + +} final case class DepPairType[K <: Kind: KindName]( x: K#I, diff --git a/src/main/scala/rise/eqsat/TypeNode.scala b/src/main/scala/rise/eqsat/TypeNode.scala index 0e8f90db5..6119a5267 100644 --- a/src/main/scala/rise/eqsat/TypeNode.scala +++ b/src/main/scala/rise/eqsat/TypeNode.scala @@ -151,7 +151,7 @@ object DataType { case rct.PairType(dt1, dt2) => PairType(fromNamed(dt1, bound), fromNamed(dt2, bound)) case rct.ArrayType(s, et) => ArrayType(Nat.fromNamed(s, bound), fromNamed(et, bound)) case _: rct.DepArrayType | _: rct.DepPairType[_] | - _: rct.NatToDataApply | _: rct.FragmentType => + _: rct.NatToDataApply | _: rct.FragmentType | _: rct.ManagedBufferType | _: rct.OpaqueType => throw new Exception(s"did not expect $dt") }) } diff --git a/src/main/scala/shine/C/Compilation/CodeGenerator.scala b/src/main/scala/shine/C/Compilation/CodeGenerator.scala index 94b037e15..e5577d20b 100644 --- a/src/main/scala/shine/C/Compilation/CodeGenerator.scala +++ b/src/main/scala/shine/C/Compilation/CodeGenerator.scala @@ -112,12 +112,18 @@ class CodeGenerator(val decls: CodeGenerator.Declarations, CCodeGen.codeGenNewDoubleBuffer(ArrayType(n, dt), in, out, ps, p, env) case f@For(unroll) => - val (i, p) = f.unwrapBody - CCodeGen.codeGenFor(f.n, i, p, unroll, env) + f.loopBody match { + case Lambda(i, p) => + CCodeGen.codeGenFor(f.n, i, p, unroll, env) + case _ => throw new Exception("This should not happen") + } case f@ForNat(unroll) => - val (i, p) = f.unwrapBody - CCodeGen.codeGenForNat(f.n, i, p, unroll, env) + f.loopBody match { + case shine.DPIA.Phrases.DepLambda(i, p) => + CCodeGen.codeGenForNat(f.n, i, p, unroll, env) + case _ => throw new Exception("This should not happen") + } case Proj1(pair) => Lifting.liftPair(pair)._1 |> cmd(env) case Proj2(pair) => Lifting.liftPair(pair)._2 |> cmd(env) @@ -438,7 +444,7 @@ class CodeGenerator(val decls: CodeGenerator.Declarations, case _ => error(s"Expected path to be not empty") } - case Pad(n, l, r, _, pad, array) => path match { + case PadCst(n, l, r, _, pad, array) => path match { case (i: CIntExpr) :: ps => pad |> exp(env, ps, padExpr => genPad(n, l, r, padExpr, padExpr, i, ps, array, env, cont)) @@ -490,7 +496,7 @@ class CodeGenerator(val decls: CodeGenerator.Declarations, case _ => error(s"Expected path to be not empty") } - case MakeArray(_, elems) => path match { + case MakeArray(elems) => path match { case (i: CIntExpr) :: ps => try { elems(i.eval) |> exp(env, ps, cont) } catch { @@ -503,9 +509,9 @@ class CodeGenerator(val decls: CodeGenerator.Declarations, case DepIdx(_, _, i, e) => e |> exp(env, CIntExpr(i) :: path, cont) - case ForeignFunctionCall(f, inTs, outT, args) => - CCodeGen.codeGenForeignFunctionCall(f, inTs, outT, args, env, fe => - generateAccess(outT, fe, path, env, cont) + case ffc@ForeignFunctionCall(f, inTs, args) => + CCodeGen.codeGenForeignFunctionCall(f, inTs, ffc.outT, args, env, fe => + generateAccess(ffc.outT, fe, path, env, cont) ) case Proj1(pair) => SimplifyNats.simplifyIndexAndNatExp(Lifting.liftPair(pair)._1) |> exp(env, path, cont) @@ -543,7 +549,7 @@ class CodeGenerator(val decls: CodeGenerator.Declarations, case shine.DPIA.Types.f32 => C.AST.Type.float case shine.DPIA.Types.f64 => C.AST.Type.double case _: shine.DPIA.Types.IndexType => C.AST.Type.int - case _: shine.DPIA.Types.VectorType | _: FragmentType | _: pipeline.type => + case _: shine.DPIA.Types.VectorType | _: FragmentType => throw new Exception(s"$b types in C are not supported") } case a: shine.DPIA.Types.ArrayType => C.AST.ArrayType(typ(a.elemType), Some(a.size)) @@ -856,7 +862,7 @@ class CodeGenerator(val decls: CodeGenerator.Declarations, }) } - def codeGenForeignFunctionCall(funDecl: ForeignFunction.Declaration, + def codeGenForeignFunctionCall(funDecl: rise.core.ForeignFunction.Decl, inTs: collection.Seq[DataType], outT: DataType, args: collection.Seq[Phrase[ExpType]], diff --git a/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala b/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala index 9a7ac7c4b..32994bf80 100644 --- a/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala +++ b/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala @@ -90,7 +90,7 @@ object AcceptorTranslation { case AsVector(n, m, dt, access, array) => acc(array)(AsVectorAcc(n, m, dt, A)) - case AsVectorAligned(n, m, w, dt, array) => + case AsVectorAligned(n, m, dt, access, array) => acc(array)(AsVectorAcc(n, m, dt, A)) case DepIdx(n, ft, index, array) => @@ -225,7 +225,7 @@ object AcceptorTranslation { y, x, A))))) case Scatter(n, m, dt, indices, input) => - con(indices)(fun(expT(m`.`idx(n), read))(y => + con(indices)(fun(expT(n`.`idx(m), read))(y => acc(input)(ScatterAcc(n, m, dt, y, A)))) case slide@Slide(n, sz, sp, dt, input) => @@ -283,12 +283,12 @@ object AcceptorTranslation { λ(expT({l * n}`.`dt, read))(x => acc(f(l)(x))(o)))), x))) - case ocl.KernelCall(name, localSize, globalSize, inTs, outT, args) => + case ocl.KernelCall(name, localSize, globalSize, _, args) => def rec(ts: Seq[Phrase[ExpType]], - es: Seq[Phrase[ExpType]]): Phrase[CommType] = { + es: Seq[Phrase[ExpType]]): Phrase[CommType] = { ts match { case Nil => - oclImp.KernelCallCmd(name, localSize, globalSize, A, es) + oclImp.KernelCallCmd(name, localSize, globalSize, es)(A.t.dataType, A) case Seq(arg, tail@_*) => con(arg)(λ(expT(arg.t.dataType, read))(e => rec(tail, es :+ e))) } @@ -303,7 +303,7 @@ object AcceptorTranslation { λ(expT(dt1, read))(x => λ(accT(dt2))(o => acc(f(x))(o))), x, A))) - case ocl.OpenCLFunctionCall(name, inTs, outT, args) => + case fc@ocl.OpenCLFunctionCall(name, inTs, args) => def rec(ts: Seq[(Phrase[ExpType], DataType)], exps: Seq[Phrase[ExpType]], inTs: Seq[DataType]): Phrase[CommType] = { @@ -311,7 +311,7 @@ object AcceptorTranslation { // with only one argument left to process return the assignment of the OpenCLFunction call case Seq( (arg, inT) ) => con(arg)(λ(expT(inT, read))(e => - A :=|outT| ocl.OpenCLFunctionCall(name, inTs :+ inT, outT, exps :+ e) )) + A :=|fc.outT| ocl.OpenCLFunctionCall(name, inTs :+ inT, exps :+ e)(fc.outT) )) // with a `tail` of arguments left, recurse case Seq( (arg, inT), tail@_* ) => con(arg)(λ(expT(inT, read))(e => rec(tail, exps :+ e, inTs :+ inT) )) @@ -321,17 +321,17 @@ object AcceptorTranslation { rec(args zip inTs, Seq(), Seq()) // CUDA - case cuda.AsFragment(rows, columns, d3, dataType, fragmentKind, matrix, layout) => + case cuda.AsFragment(rows, columns, layers, dataType, fragmentKind, layout, matrix) => con(matrix)(λ(ExpType(ArrayType(rows, ArrayType(columns, dataType)), read))(matrix => - cudaImp.WmmaLoad(rows, columns, d3, dataType, fragmentKind, layout, matrix, A))) + cudaImp.WmmaLoad(rows, columns, layers, dataType, fragmentKind, layout, matrix, A))) - case cuda.AsMatrix(rows, columns, d3, dataType, fragment) => + case cuda.AsMatrix(rows, columns, layers, dataType, fragment) => con(fragment)(λ(ExpType(fragment.t.dataType, read))(fragment => - cudaImp.WmmaStore(rows, columns, d3, dataType, fragment, A))) + cudaImp.WmmaStore(rows, columns, layers, dataType, fragment, A))) - case cuda.GenerateFragment(rows, columns, d3, dataType, fill, fragmentKind, layout) => + case cuda.GenerateFragment(rows, columns, layers, dataType, frag, layout, fill) => con(fill)(λ(ExpType(dataType, read))(fill => - cudaImp.WmmaFill(rows, columns, d3, dataType, fill, fragmentKind, layout, A))) + cudaImp.WmmaFill(rows, columns, layers, dataType, frag, layout, fill, A))) case map@cuda.Map(level, dim) => val (n, dt1, dt2, f, array) = map.unwrap @@ -340,17 +340,17 @@ object AcceptorTranslation { λ(expT(dt1, read))(x => λ(accT(dt2))(o => acc(f(x))(o))), x, A))) - case cuda.MapFragmentElements(fragType, fragment, fun) => - con(fragment)(λ(expT(fragType, read))(input => - shine.cuda.primitives.imperative.ForFragmentElements(fragType, input, A, - λ(expT(fragType.dataType, read))(x => - λ(accT(fragType.dataType))(o => + case cuda.MapFragment(rows, columns, layers, dt, frag, layout, fun, input) => + con(input)(λ(expT(FragmentType(rows, columns, layers, dt, frag, layout), read))(input => + shine.cuda.primitives.imperative.ForFragment(rows, columns, layers, dt, frag, layout, input, A, + λ(expT(dt, read))(x => + λ(accT(dt))(o => acc(fun(x))(o)))))) case cuda.TensorMatMultAdd(m, n, k, layoutA, layoutB, dataType, dataTypeAcc, aMatrix, bMatrix, cMatrix) => con(aMatrix)(λ(ExpType(FragmentType(m, n, k, dataType, FragmentKind.AMatrix, layoutA), read))(aMatrix => con(bMatrix)(λ(ExpType(FragmentType(m, n, k, dataType, FragmentKind.BMatrix, layoutB), read))(bMatrix => - con(cMatrix)(λ(ExpType(FragmentType(m, n, k, dataTypeAcc), read))(cMatrix => + con(cMatrix)(λ(ExpType(FragmentType(m, n, k, dataTypeAcc, FragmentKind.Accumulator, MatrixLayout.None), read))(cMatrix => cudaImp.WmmaMMA(m, n, k, layoutA, layoutB, dataType, dataTypeAcc, aMatrix, bMatrix, cMatrix, A))))))) } } diff --git a/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala b/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala index b9d5338d3..2617c22a7 100644 --- a/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala +++ b/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala @@ -119,7 +119,7 @@ object ContinuationTranslation { con(array)(λ(expT((n + m)`.` dt, read))(x => C(Drop(n, m, dt, x)))) - case ForeignFunctionCall(funDecl, inTs, outT, args) => + case ffc@ForeignFunctionCall(funDecl, inTs, args) => def rec(ts: Seq[(Phrase[ExpType], DataType)], exps: Seq[Phrase[ExpType]], inTs: Seq[DataType]): Phrase[CommType] = { @@ -127,7 +127,7 @@ object ContinuationTranslation { // with only one argument left to process return the assignment of the function call case Seq( (arg, inT) ) => con(arg)(λ(expT(inT, read))(e => - outT match { + ffc.outT match { // TODO: this is an ugly fix to avoid calling the function multiple times // for pair assignment, see: // https://github.com/rise-lang/shine/issues/58 @@ -140,10 +140,10 @@ object ContinuationTranslation { case _ => `new`.apply } - backendNew(outT, tmp => - Assign(outT, tmp.wr, ForeignFunctionCall(funDecl, inTs :+ inT, outT, exps :+ e)) `;` + backendNew(ffc.outT, tmp => + Assign(ffc.outT, tmp.wr, ForeignFunctionCall(funDecl, inTs :+ inT, exps :+ e)(ffc.outT)) `;` C(tmp.rd)) - case _ => C( ForeignFunctionCall(funDecl, inTs :+ inT, outT, exps :+ e) ) + case _ => C( ForeignFunctionCall(funDecl, inTs :+ inT, exps :+ e)(ffc.outT) ) })) // with a `tail` of arguments left, rec case Seq( (arg, inT), tail@_* ) => @@ -191,13 +191,13 @@ object ContinuationTranslation { con(value)(fun(value.t)(x => con(f(x))(C))) - case MakeArray(dt, elements) => + case ma@MakeArray(elements) => def rec(func: Vector[Phrase[ExpType]], imp: Vector[Phrase[ExpType]]): Phrase[CommType] = { func match { - case xf +: func => con(xf)(fun(expT(dt, read))(xi => + case xf +: func => con(xf)(fun(expT(ma.dt, read))(xi => rec(func, imp :+ xi) )) - case _ => C(MakeArray(dt, imp)) + case _ => C(MakeArray(imp)(ma.n, ma.dt)) } } @@ -257,10 +257,10 @@ object ContinuationTranslation { con(e)(λ(e.t)(x => C(NatAsIndex(n, x)))) - case Pad(n, l, r, dt, padExp, array) => + case PadCst(n, l, r, dt, padExp, array) => con(array)(λ(expT(n`.`dt, read))(x => con(padExp)(λ(expT(dt, read))(p => - C(Pad(n, l, r, dt, p, x)))))) + C(PadCst(n, l, r, dt, p, x)))))) case PadClamp(n, l, r, dt, array) => con(array)(λ(expT(n`.`dt, read))(x => @@ -291,7 +291,8 @@ object ContinuationTranslation { acc(scanSeq)(tmp.wr) `;` C(tmp.rd) )) case slide@Slide(n, sz, sp, dt, input) => - con(input)(λ(expT(slide.inputSize`.`dt, read))(x => + val inputSize = sp*n+sz + con(input)(λ(expT(inputSize`.`dt, read))(x => C(Slide(n, sz, sp, dt, x)) )) case Snd(dt1, dt2, pair) => @@ -361,7 +362,7 @@ object ContinuationTranslation { `new`(map.n `.` map.dt2, λ(varT(map.n `.` map.dt2))(tmp => acc(map)(tmp.wr) `;` C(tmp.rd))) - case ocl.OpenCLFunctionCall(name, inTs, outT, args) => + case fc@ocl.OpenCLFunctionCall(name, inTs, args) => def rec(ts: Seq[(Phrase[ExpType], DataType)], es: Seq[Phrase[ExpType]], inTs: Seq[DataType]): Phrase[CommType] = { @@ -369,7 +370,7 @@ object ContinuationTranslation { // with only one argument left to process continue with the OpenCLFunction call case Seq( (arg, inT) ) => con(arg)(λ(expT(inT, read))(e => - C(ocl.OpenCLFunctionCall(name, inTs :+ inT, outT, es :+ e)) )) + C(ocl.OpenCLFunctionCall(name, inTs :+ inT, es :+ e)(fc.outT)) )) // with a `tail` of arguments left, rec case Seq( (arg, inT), tail@_* ) => con(arg)(λ(expT(inT, read))(e => rec(tail, es :+ e, inTs :+ inT) )) @@ -379,12 +380,10 @@ object ContinuationTranslation { rec(args zip inTs, Seq(), Seq()) case reduceSeq@ocl.ReduceSeq(unroll) => - val (n, initAddrSpace, dt1, dt2, f, init, array) = - ( reduceSeq.n, reduceSeq.initAddrSpace, reduceSeq.dt1, reduceSeq.dt2, - reduceSeq.f, reduceSeq.init, reduceSeq.array ) + val (n, a, dt1, dt2, f, init, array) = reduceSeq.unwrap con(array)(λ(expT(n`.`dt1, read))(X => - oclI.ReduceSeqI(n, initAddrSpace, dt1, dt2, + oclI.ReduceSeqI(n, a, dt1, dt2, λ(expT(dt2, read))(x => λ(expT(dt1, read))(y => λ(accT(dt2))(o => acc( f(x)(y) )( o )))), @@ -400,7 +399,7 @@ object ContinuationTranslation { case cuda.GlobalToShared(dt, inputGlobal) => val adj = AdjustArraySizesForAllocations(inputGlobal, dt, AddressSpace.Local) - shine.OpenCL.DSL.`new` (AddressSpace.Private) (pipeline, pipeline => + shine.OpenCL.DSL.`new` (AddressSpace.Private) (OpaqueType("pipeline"), pipeline => shine.OpenCL.DSL.`new` (AddressSpace.Local) (adj.dt, tmp => acc(inputGlobal)(cudaIm.GlobalToSharedAcc(dt, pipeline.rd, tmp.wr)) `;` cudaIm.SyncPipeline(pipeline.rd) `;` @@ -414,19 +413,18 @@ object ContinuationTranslation { `new`(n `.` dt2, λ(varT(n `.` dt2))(tmp => acc(map)(tmp.wr) `;` C(tmp.rd))) - case mapFragmentElements@cuda.MapFragmentElements(fragType, fragment, fun) => - val dt = fragType.dataType - + case m@cuda.MapFragment(rows, columns, layers, dt, frag, layout, fun, input) => + val fragType = FragmentType(rows, columns, layers, dt, frag, layout) shine.OpenCL.primitives.imperative.New(AddressSpace.Private, fragType, λ(VarType(fragType))(fragmentAcc => - (if (fragment.t.accessType.toString == write.toString) - acc(fragment)(fragmentAcc.wr) `;` - cudaIm.ForFragmentElements(fragType, fragmentAcc.rd, fragmentAcc.wr, + (if (input.t.accessType.toString == write.toString) + acc(input)(fragmentAcc.wr) `;` + cudaIm.ForFragment(rows, columns, layers, dt, frag, layout, fragmentAcc.rd, fragmentAcc.wr, λ(expT(dt, read))(x => λ(accT(dt))(o => acc(fun(x))(o)))) else - acc(mapFragmentElements)(fragmentAcc.wr)) `;` + acc(m)(fragmentAcc.wr)) `;` C(fragmentAcc.rd))) } } diff --git a/src/main/scala/shine/DPIA/Compilation/Passes/UnrollLoops.scala b/src/main/scala/shine/DPIA/Compilation/Passes/UnrollLoops.scala index b618f5644..e002382ab 100644 --- a/src/main/scala/shine/DPIA/Compilation/Passes/UnrollLoops.scala +++ b/src/main/scala/shine/DPIA/Compilation/Passes/UnrollLoops.scala @@ -16,27 +16,36 @@ object UnrollLoops { override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = p match { case f@For(true) => - val (ident, body) = f.unwrapBody - Continue(unrollLoop(f.n, init = 0, step = 1, i => - Phrase.substitute(NatAsIndex(f.n, Natural(i)), - `for` = ident, in = body)), this) + f.loopBody match { + case shine.DPIA.Phrases.Lambda(x, body) => + Continue(unrollLoop(f.n, init = 0, step = 1, i => + Phrase.substitute(NatAsIndex(f.n, Natural(i)), + `for` = x, in = body)), this) + case _ => throw new Exception("This should not happen") + } case f@ForNat(true) => - val (ident, body) = f.unwrapBody - Continue(unrollLoop(f.n, init = 0, step = 1, i => - PhraseType.substitute(i, `for` = ident, in = body)), this) - case pf@ParFor(_, _, true) => - val (ident, identOut, body) = pf.unwrapBody - pf.out.t.dataType match { - case ArrayType(_, elemType) => - Continue(unrollLoop(pf.n, pf.init, pf.step, i => - Phrase.substitute( - IdxAcc(pf.n, elemType, - NatAsIndex(pf.n, Natural(i)), pf.out), - `for` = identOut, - Phrase.substitute(NatAsIndex(pf.n, Natural(i)), - `for` = ident, in = body))), this) - case _ => - throw new Exception("OpenCLParFor acceptor has to be of ArrayType.") + f.loopBody match { + case shine.DPIA.Phrases.DepLambda(x, body) => + Continue(unrollLoop(f.n, init = 0, step = 1, i => + PhraseType.substitute(i, `for` = x, in = body)), this) + case _ => throw new Exception("This should not happen") + } + case pf@ParFor(_, _, true, _) => + pf.body match { + case shine.DPIA.Phrases.Lambda(ident, shine.DPIA.Phrases.Lambda(identOut, body)) => + pf.out.t.dataType match { + case ArrayType(_, elemType) => + Continue(unrollLoop(pf.n, pf.init, pf.step, i => + Phrase.substitute( + IdxAcc(pf.n, elemType, + NatAsIndex(pf.n, Natural(i)), pf.out), + `for` = identOut, + Phrase.substitute(NatAsIndex(pf.n, Natural(i)), + `for` = ident, in = body))), this) + case _ => + throw new Exception("OpenCLParFor acceptor has to be of ArrayType.") + } + case _ => throw new Exception("This should not happen") } case _ => Continue(p, this) @@ -73,7 +82,7 @@ object UnrollLoops { val numIter = ceilDiv(stopMax - startMin, incr) val tmp = (0 until numIter).foldLeft[Phrase[CommType]]( - Comment(s"unrolling loop of $numIter"))({ + shine.DPIA.DSL.comment(s"unrolling loop of $numIter"))({ case (prev, i) => val index = init + Cst(i * incr) assert(isSmaller(index, n).contains(true)) //TODO add if-guards otherwise. diff --git a/src/main/scala/shine/DPIA/DSL/ImperativePrimitives.scala b/src/main/scala/shine/DPIA/DSL/ImperativePrimitives.scala index 0014f37b3..2d63673f2 100644 --- a/src/main/scala/shine/DPIA/DSL/ImperativePrimitives.scala +++ b/src/main/scala/shine/DPIA/DSL/ImperativePrimitives.scala @@ -7,6 +7,10 @@ import shine.DPIA.Types.DataType._ import shine.DPIA._ import shine.DPIA.primitives.functional.{Fst, Snd} +object ImperativePrimitives { + def skip: Skip = Skip() +} + object `new` { def apply(dt: DataType, f: Phrase[VarType ->: CommType]): New = @@ -79,7 +83,7 @@ object streamNext { } object comment { - def apply(comment: String): Comment = Comment(comment) + def apply(comment: String): Comment = Comment(comment)() } object fst { @@ -109,5 +113,3 @@ object pairAcc2 { def apply(fstT: DataType, sndT: DataType, record: Phrase[AccType]): PairAcc2 = PairAcc2(fstT, sndT, record) } - -object skip extends Skip diff --git a/src/main/scala/shine/DPIA/DSL/package.scala b/src/main/scala/shine/DPIA/DSL/package.scala index fe0c214f5..40e1dc323 100644 --- a/src/main/scala/shine/DPIA/DSL/package.scala +++ b/src/main/scala/shine/DPIA/DSL/package.scala @@ -11,15 +11,15 @@ import scala.language.implicitConversions package object DSL { implicit class BinOps(lhs: Phrase[ExpType]) { - def +(rhs: Phrase[ExpType]) = BinOp(Operators.Binary.ADD, lhs, rhs) - def -(rhs: Phrase[ExpType]) = BinOp(Operators.Binary.SUB, lhs, rhs) - def *(rhs: Phrase[ExpType]) = BinOp(Operators.Binary.MUL, lhs, rhs) - def /(rhs: Phrase[ExpType]) = BinOp(Operators.Binary.DIV, lhs, rhs) - def %(rhs: Phrase[ExpType]) = BinOp(Operators.Binary.MOD, lhs, rhs) - def >(rhs: Phrase[ExpType]) = BinOp(Operators.Binary.GT, lhs, rhs) - def <(rhs: Phrase[ExpType]) = BinOp(Operators.Binary.LT, lhs, rhs) - def =:=(rhs: Phrase[ExpType]) = BinOp(Operators.Binary.EQ, lhs, rhs) - def unary_- = UnaryOp(Operators.Unary.NEG, lhs) + def +(rhs: Phrase[ExpType]): BinOp = BinOp(Operators.Binary.ADD, lhs, rhs) + def -(rhs: Phrase[ExpType]): BinOp = BinOp(Operators.Binary.SUB, lhs, rhs) + def *(rhs: Phrase[ExpType]): BinOp = BinOp(Operators.Binary.MUL, lhs, rhs) + def /(rhs: Phrase[ExpType]): BinOp = BinOp(Operators.Binary.DIV, lhs, rhs) + def %(rhs: Phrase[ExpType]): BinOp = BinOp(Operators.Binary.MOD, lhs, rhs) + def >(rhs: Phrase[ExpType]): BinOp = BinOp(Operators.Binary.GT, lhs, rhs) + def <(rhs: Phrase[ExpType]): BinOp = BinOp(Operators.Binary.LT, lhs, rhs) + def =:=(rhs: Phrase[ExpType]): BinOp = BinOp(Operators.Binary.EQ, lhs, rhs) + def unary_- : UnaryOp = UnaryOp(Operators.Unary.NEG, lhs) } implicit class ExpPhraseExtensions(e: Phrase[ExpType]) { diff --git a/src/main/scala/shine/DPIA/Phrases/PrettyPhrasePrinter.scala b/src/main/scala/shine/DPIA/Phrases/PrettyPhrasePrinter.scala index 3c9f53126..1de17598d 100644 --- a/src/main/scala/shine/DPIA/Phrases/PrettyPhrasePrinter.scala +++ b/src/main/scala/shine/DPIA/Phrases/PrettyPhrasePrinter.scala @@ -35,6 +35,18 @@ object PrettyPhrasePrinter { case PhrasePair(fst, snd) => s"(${apply(fst)}, ${apply(snd)})" + case shine.DPIA.primitives.imperative.Comment(comment) => s"\n//$comment\n" + + case shine.OpenCL.primitives.imperative.Barrier(local, global) => + s"""barrier( ${if(local) "CLK_LOCAL_MEM_FENCE" else ""} ${if(global && local) "|" else ""} + ${if(global) "CLK_GLOBAL_MEM_FENCE" else ""})""" + + case shine.cuda.primitives.imperative.SyncThreads() => "__syncthreads()" + + case shine.cuda.primitives.imperative.SyncWarp() => "__syncwarp()" + + case shine.cuda.primitives.imperative.SyncPipeline(pipe) => s"$pipe.commit_and_wait()" + case c: Primitive[_] => c.prettyPrint } } diff --git a/src/main/scala/shine/DPIA/Phrases/VisitAndRebuild.scala b/src/main/scala/shine/DPIA/Phrases/VisitAndRebuild.scala index 76e005bb9..e07cc7937 100644 --- a/src/main/scala/shine/DPIA/Phrases/VisitAndRebuild.scala +++ b/src/main/scala/shine/DPIA/Phrases/VisitAndRebuild.scala @@ -186,6 +186,11 @@ object VisitAndRebuild { case r: PairType => PairType(visitDataTypeAndRebuild(r.fst, v), visitDataTypeAndRebuild(r.snd, v)) - case d => d + case ManagedBufferType(dt) => + ManagedBufferType(visitDataTypeAndRebuild(dt, v)) + case d => d match { + case _: ComposedType | _: BasicType | _: OpaqueType | + _: NatToDataApply | _: DataTypeIdentifier => d + } } } diff --git a/src/main/scala/shine/DPIA/Types/DataType.scala b/src/main/scala/shine/DPIA/Types/DataType.scala index 311ea4897..2b18660d8 100644 --- a/src/main/scala/shine/DPIA/Types/DataType.scala +++ b/src/main/scala/shine/DPIA/Types/DataType.scala @@ -17,18 +17,19 @@ sealed trait MatrixLayout object MatrixLayout { object Row_Major extends MatrixLayout { override def toString = "Row_Major" } object Col_Major extends MatrixLayout { override def toString = "Col_Major" } + object None extends MatrixLayout } final case class MatrixLayoutIdentifier(name: String) extends MatrixLayout with Kind.Identifier { - var layout: Option[MatrixLayout] = None + var layout: MatrixLayout = MatrixLayout.None override def toString: String = name def setLayout(matrixLayout: MatrixLayout): Unit = { - if (layout.isEmpty) - layout = Some(matrixLayout) - else if (layout.get != matrixLayout) - throw new Exception(s"could not unify ${layout.get} and $matrixLayout") + if (layout == MatrixLayout.None) + layout = matrixLayout + else if (layout != matrixLayout) + throw new Exception(s"could not unify $layout and $matrixLayout") } } @@ -40,43 +41,34 @@ object FragmentKind { object Accumulator extends FragmentKind { override def toString = "Accumulator"} } -//This can be used to create fragments of kind `Accumulator` which does not need a layout -object FragmentType { - def apply(rows: Nat, columns:Nat, d3: Nat, dataType: DataType): FragmentType = - FragmentType(rows, columns, d3, dataType, FragmentKind.Accumulator, null) -} - /** * Represents a CUDA-fragment which represents a tile of a matrix which is stored in registers of a warp.
* Fragments of kind `Accumulator` does not have a layout. So the `layout` of fragments of kind `Accumulator` * can be ignored. * @param rows number of rows * @param columns number of columns - * @param d3 third dimension which is used in the MMA operation + * @param layers third dimension which is used in the MMA operation * @param dataType dataType of the elements * @param fragmentKind kind of the fragment {@link FragmentKind} * @param layout layout of the fragment {@link MatrixLayout} */ final case class FragmentType(rows: Nat, columns: Nat, - d3: Nat, + layers: Nat, dataType: DataType, fragmentKind: FragmentKind, layout: MatrixLayout) extends BasicType { override def toString: String = - if (fragmentKind == FragmentKind.Accumulator) - s"Fragment[$rows,$columns,$d3,$dataType,$fragmentKind]" - else - s"Fragment[$rows,$columns,$d3,$dataType,$fragmentKind,$layout]" + s"Fragment[$rows,$columns,$layers,$dataType,$fragmentKind,$layout]" override def equals(o: Any): Boolean = { o match { case f: FragmentType => f.fragmentKind match { case FragmentKind.Accumulator => - f.rows.equals(rows) && f.columns.equals(columns) && f.d3.equals(d3) && f.dataType.equals(dataType) + f.rows.equals(rows) && f.columns.equals(columns) && f.layers.equals(layers) && f.dataType.equals(dataType) case _ => - f.rows.equals(rows) && f.columns.equals(columns) && f.d3.equals(d3) && f.dataType.equals(dataType) && + f.rows.equals(rows) && f.columns.equals(columns) && f.layers.equals(layers) && f.dataType.equals(dataType) && f.fragmentKind.equals(fragmentKind) && f.layout.equals(layout) } case _ => false @@ -84,8 +76,6 @@ final case class FragmentType(rows: Nat, } } -object pipeline extends BasicType { override def toString = "pipeline" } - object bool extends ScalarType { override def toString: String = "bool" } object int extends ScalarType { override def toString: String = "int" } @@ -153,7 +143,7 @@ final case class PairType(fst: DataType, snd: DataType) extends ComposedType { override def toString: String = s"($fst x $snd)" } -sealed case class VectorType(size: Nat, elemType: ScalarType) +sealed case class VectorType(size: Nat, elemType: DataType) extends BasicType { override def toString: String = s"<$size>$elemType" @@ -161,7 +151,7 @@ sealed case class VectorType(size: Nat, elemType: ScalarType) object vec { @inline - def apply(size: Nat, elemType: ScalarType): VectorType = + def apply(size: Nat, elemType: DataType): VectorType = VectorType(size, elemType) } @@ -211,7 +201,7 @@ object DataType { case f: FragmentType => FragmentType(ArithExpr.substitute(f.rows, Map((`for`, ae))), ArithExpr.substitute(f.columns, Map((`for`, ae))), - ArithExpr.substitute(f.d3, Map((`for`, ae))), + ArithExpr.substitute(f.layers, Map((`for`, ae))), substitute(ae, `for`, f.dataType), f.fragmentKind, f.layout) case a: DepArrayType => val subMap = Map((`for`, ae)) @@ -267,7 +257,7 @@ object DataType { case ArrayType(size, _) => size case DepArrayType(size, _) => size case _: DataTypeIdentifier | _: NatToDataApply | _: OpaqueType | - _: FragmentType | _: pipeline.type => + _: FragmentType => throw new Exception("This should not happen") } diff --git a/src/main/scala/shine/DPIA/Types/package.scala b/src/main/scala/shine/DPIA/Types/package.scala index ab1603fa9..5d1a7b630 100644 --- a/src/main/scala/shine/DPIA/Types/package.scala +++ b/src/main/scala/shine/DPIA/Types/package.scala @@ -1,13 +1,20 @@ package shine.DPIA +import arithexpr.arithmetic.RangeAdd import shine.DPIA.Phrases.Phrase import shine.DPIA.Types.TypeCheck._ package object Types { implicit class ReverseInferenceHelper(pt: PhraseType) { - def ::[T <: PhraseType](p: Phrase[T]): Unit = p checkTypeEqOrSubtype pt - def `:`[T <: PhraseType](p: Phrase[T]): Unit = p checkTypeEqOrSubtype pt + def ::[T <: PhraseType](p: Phrase[T]): Unit = + if (!(p checkTypeEqOrSubtype pt)) { + throw new Exception(s"Type error: found ${p.t}, expected $pt") + } + def `:`[T <: PhraseType](p: Phrase[T]): Unit = + if (!(p checkTypeEqOrSubtype pt)) { + throw new Exception(s"Type error: found ${p.t}, expected $pt") + } } type NatDependentFunctionType[T <: PhraseType] = DepFunType[NatKind, T] @@ -47,4 +54,22 @@ package object Types { ): DepFunType[AccessKind, T] = DepFunType[AccessKind, T](at, t) } + + object n2dtFun { + def apply(f: NatIdentifier => DataType): NatToDataLambda = { + val x = NatIdentifier(freshName("n")) + NatToDataLambda(x, f(x)) + } + + def apply(r: arithexpr.arithmetic.Range) + (f: NatIdentifier => DataType): NatToDataLambda = { + val x = NatIdentifier(freshName("n"), r) + NatToDataLambda(x, f(x)) + } + + def apply(upperBound: Nat) + (f: NatIdentifier => DataType): NatToDataLambda = { + apply(RangeAdd(0, upperBound, 1))(f) + } + } } diff --git a/src/main/scala/shine/DPIA/fromRise.scala b/src/main/scala/shine/DPIA/fromRise.scala index 285c24129..a9c487b83 100644 --- a/src/main/scala/shine/DPIA/fromRise.scala +++ b/src/main/scala/shine/DPIA/fromRise.scala @@ -340,12 +340,12 @@ object fromRise { case core.slide() => fromType { case nFunT(sz, nFunT(sp, expT(ArrayType(insz, t), `read`) ->: - expT(ArrayType(n, ArrayType(_, _)), `read`))) + expT(ArrayType(np1, ArrayType(_, _)), `read`))) => depFun[NatKind](sz)( depFun[NatKind](sp)( fun[ExpType](expT(insz`.`t, read), e => - Slide(n, sz, sp, t, e)))) + Slide(np1-1, sz, sp, t, e)))) } case core.circularBuffer() => fromType { @@ -482,7 +482,7 @@ object fromRise { depFun[NatKind](q)( fun[ExpType](expT(t, read), cst => fun[ExpType](expT(n`.`t, read), e => - Pad(n, l, q, t, cst, e))))) + PadCst(n, l, q, t, cst, e))))) } case core.padEmpty() => fromType { @@ -694,7 +694,7 @@ object fromRise { fun[ExpType](ExpType(inTs(i), read), a => buildFFCall(args :+ a)) } else { - ForeignFunctionCall(decl, inTs, outT, args) + ForeignFunctionCall(decl, inTs, args)(outT) } } buildFFCall(Vector()) @@ -713,7 +713,7 @@ object fromRise { ): Phrase[_ <: PhraseType] = t match { case FunType(in: ExpType, out) => fun[ExpType](in, e => buildArrayPrimitive(out, elements :+ e)) - case ExpType(ArrayType(_, et), _) => MakeArray(et, elements) + case ExpType(ArrayType(_, et), _) => MakeArray(elements)(elements.size, et) case _ => error(s"did not expect $t") } buildArrayPrimitive(t, Vector()) @@ -763,7 +763,7 @@ object fromRise { => depFun[NatKind](n)( fun[ExpType](expT(mn`.`t, read), e => - AsVectorAligned(n, m, a, t, e))) + AsVectorAligned(n, m, t, a, e))) } case core.asScalar() => fromType { @@ -771,7 +771,7 @@ object fromRise { expT(ArrayType(_, _), _) => fun[ExpType](expT(m`.`vec(n, t), a), e => - AsScalar(m, n, t, a, e)) + AsScalar(n, m, t, a, e)) } case core.vectorFromScalar() => fromType { @@ -813,7 +813,7 @@ object fromRise { depFun[NatKind](ls1)(depFun[NatKind](ls2)(depFun[NatKind](ls3)( depFun[NatKind](gs1)(depFun[NatKind](gs2)(depFun[NatKind](gs3)( fun[ExpType](expT(t, write), e => - ocl.Run(LocalSize(ls1, ls2, ls3), GlobalSize(gs1, gs2, gs3), t, e)))))))) + ocl.Run(LocalSize(ls1, ls2, ls3), GlobalSize(gs1, gs2, gs3))(t, e)))))))) } case core.dmatch() => fromType { @@ -842,20 +842,20 @@ object fromRise { case expT(ArrayType(rows, ArrayType(columns, dt)), `read`) ->: expT(FragmentType(_, _, d3, _, fragType, layout), _) => fun[ExpType](expT(ArrayType(rows, ArrayType(columns, dt)), read), a => - cuda.AsFragment(rows, columns, d3, dt, fragType, a, layout)) + cuda.AsFragment(rows, columns, d3, dt, fragType, layout, a)) } case rcuda.asMatrix() => fromType { - case expT(FragmentType(rows, columns, d3, dt, FragmentKind.Accumulator, _), `read`) ->: + case expT(FragmentType(rows, columns, layers, dt, FragmentKind.Accumulator, _), `read`) ->: expT(ArrayType(_, ArrayType(_, _)), `write`) => - fun[ExpType](expT(FragmentType(rows, columns, d3, dt), read), dFrag => - cuda.AsMatrix(rows, columns, d3, dt, dFrag)) + fun[ExpType](expT(FragmentType(rows, columns, layers, dt, FragmentKind.Accumulator, MatrixLayout.None), read), dFrag => + cuda.AsMatrix(rows, columns, layers, dt, dFrag)) } case rcuda.generateFragment() => fromType { - case expT(dt, `read`) ->: expT(FragmentType(rows, columns, d3, _, fragType, layout), read) => + case expT(dt, `read`) ->: expT(FragmentType(rows, columns, layers, _, frag, layout), read) => fun[ExpType](expT(dt, read), fill => - cuda.GenerateFragment(rows, columns, d3, dt, fill, fragType, layout)) + cuda.GenerateFragment(rows, columns, layers, dt, frag, layout, fill)) } case rcuda.tensorMMA() => fromType { @@ -865,7 +865,7 @@ object fromRise { expT(FragmentType(_, _, _, _, FragmentKind.Accumulator, _), `write`) => fun[ExpType](expT(FragmentType(m, k, n, dt, FragmentKind.AMatrix, layoutA), read), a => fun[ExpType](expT(FragmentType(k, n, m, dt, FragmentKind.BMatrix, layoutB), read), b => - fun[ExpType](expT(FragmentType(m, n, k, dtResult), read), c => + fun[ExpType](expT(FragmentType(m, n, k, dtResult, FragmentKind.Accumulator, MatrixLayout.None), read), c => cuda.TensorMatMultAdd(m, n, k, layoutA, layoutB, dt, dtResult, a, b, c)))) } @@ -873,8 +873,9 @@ object fromRise { case (expT(dt: DataType, `read`) ->: expT(_, `write`)) ->: expT(fragType : FragmentType, `read`) ->: expT(_, _) => fun[ExpType ->: ExpType](ExpType(dt, read) ->: ExpType(dt, write), f => - fun[ExpType](ExpType(fragType, read), fragment => - cuda.MapFragmentElements(fragType.asInstanceOf[FragmentType], fragment, f))) + fun[ExpType](ExpType(fragType, read), input => + cuda.MapFragment(fragType.rows, fragType.columns, fragType.layers, fragType.dataType, + fragType.fragmentKind, fragType.layout, f, input))) } case rcuda.mapGlobal(dim) => fromType { @@ -973,10 +974,12 @@ object fromRise { FragmentType(f.rows, f.d3, f.columns, dataType(f.dataType), FragmentKind.AMatrix, layout(f.layout)) case rt.FragmentKind.BMatrix => FragmentType(f.d3, f.columns, f.rows, dataType(f.dataType), FragmentKind.BMatrix, layout(f.layout)) - case rt.FragmentKind.Acuumulator => + case rt.FragmentKind.Accumulator => FragmentType(f.rows, f.columns, f.d3, dataType(f.dataType), FragmentKind.Accumulator, layout(f.layout)) case _ => throw new Exception("this should not happen") } + case rt.OpaqueType(name) => OpaqueType(name) + case rt.ManagedBufferType(dt) => ManagedBufferType(dataType(dt)) } private val layouts: mutable.HashMap[String, MatrixLayoutIdentifier] = mutable.HashMap.empty @@ -984,6 +987,7 @@ object fromRise { def layout(layout: rt.MatrixLayout): MatrixLayout = layout match { case rt.MatrixLayout.Row_Major => MatrixLayout.Row_Major case rt.MatrixLayout.Col_Major => MatrixLayout.Col_Major + case rt.MatrixLayout.None => MatrixLayout.None case rt.MatrixLayoutIdentifier(name, _) => layouts.getOrElseUpdate(name, MatrixLayoutIdentifier(name)) case _ => throw new Exception("this should not happen") } diff --git a/src/main/scala/shine/DPIA/primitives/functional/AsScalar.scala b/src/main/scala/shine/DPIA/primitives/functional/AsScalar.scala index 61c3e521a..9d62a3b9c 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/AsScalar.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/AsScalar.scala @@ -1,18 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class AsScalar(n: Nat, - m: Nat, - dt: ScalarType, - access: AccessType, - array: Phrase[ExpType] - ) extends ExpPrimitive { - array :: expT(n `.` vec(m, dt), access) - override val t: ExpType = expT((n * m) `.` dt, access) +final case class AsScalar(val n: Nat, val m: Nat, val dt: DataType, val a: AccessType, val array: Phrase[ExpType]) extends ExpPrimitive { + { + array :: expT(ArrayType(m, VectorType(n, dt)), a) + } + override val t: ExpType = expT(ArrayType(m * n, dt), a) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsScalar = new AsScalar(v.nat(n), v.nat(m), v.data(dt), v.access(a), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/AsVector.scala b/src/main/scala/shine/DPIA/primitives/functional/AsVector.scala index 8ccb6e9cf..dbe82e1ac 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/AsVector.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/AsVector.scala @@ -1,18 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class AsVector(n: Nat, - m: Nat, - dt: ScalarType, - access: AccessType, - array: Phrase[ExpType] - ) extends ExpPrimitive { - array :: expT((m * n) `.` dt, access) - override val t: ExpType = expT(m `.` vec(n, dt), access) +final case class AsVector(val n: Nat, val m: Nat, val dt: DataType, val a: AccessType, val array: Phrase[ExpType]) extends ExpPrimitive { + { + array :: expT(ArrayType(m * n, dt), a) + } + override val t: ExpType = expT(ArrayType(m, VectorType(n, dt)), a) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsVector = new AsVector(v.nat(n), v.nat(m), v.data(dt), v.access(a), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/AsVectorAligned.scala b/src/main/scala/shine/DPIA/primitives/functional/AsVectorAligned.scala index e8af77637..f84602328 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/AsVectorAligned.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/AsVectorAligned.scala @@ -1,18 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class AsVectorAligned(n: Nat, - m: Nat, - w: AccessType, - dt: ScalarType, - array: Phrase[ExpType] - ) extends ExpPrimitive { - array :: expT((m * n)`.`dt, w) - override val t: ExpType = expT(m`.`vec(n, dt), w) +final case class AsVectorAligned(val n: Nat, val m: Nat, val dt: DataType, val a: AccessType, val array: Phrase[ExpType]) extends ExpPrimitive { + { + array :: expT(ArrayType(m * n, dt), a) + } + override val t: ExpType = expT(ArrayType(m, VectorType(n, dt)), a) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsVectorAligned = new AsVectorAligned(v.nat(n), v.nat(m), v.data(dt), v.access(a), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Cast.scala b/src/main/scala/shine/DPIA/primitives/functional/Cast.scala index 0d0e40c40..52fc46cb2 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Cast.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Cast.scala @@ -1,15 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Cast(dt1: BasicType, - dt2: BasicType, - e: Phrase[ExpType] - )extends ExpPrimitive { - e :: expT(dt1, read) +final case class Cast(val dt1: DataType, val dt2: DataType, val e: Phrase[ExpType]) extends ExpPrimitive { + { + e :: expT(dt1, read) + } override val t: ExpType = expT(dt2, read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Cast = new Cast(v.data(dt1), v.data(dt2), VisitAndRebuild(e, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/CircularBuffer.scala b/src/main/scala/shine/DPIA/primitives/functional/CircularBuffer.scala index 3798fbfca..2154283c7 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/CircularBuffer.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/CircularBuffer.scala @@ -1,21 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class CircularBuffer(n: Nat, - alloc: Nat, - sz: Nat, - dt1: DataType, - dt2: DataType, - load: Phrase[ExpType ->: ExpType], - input: Phrase[ExpType] - ) extends ExpPrimitive { - load :: expT(dt1, read) ->: expT(dt2, write) - input :: expT((n - 1 + sz)`.`dt1, read) - override val t: ExpType = expT(n`.`(sz`.`dt2), read) +final case class CircularBuffer(val n: Nat, val alloc: Nat, val sz: Nat, val dt1: DataType, val dt2: DataType, val load: Phrase[FunType[ExpType, ExpType]], val input: Phrase[ExpType]) extends ExpPrimitive { + { + load :: FunType(expT(dt1, read), expT(dt2, write)) + input :: expT(ArrayType(n - 1 + sz, dt1), read) + } + override val t: ExpType = expT(ArrayType(n, ArrayType(sz, dt2)), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): CircularBuffer = new CircularBuffer(v.nat(n), v.nat(alloc), v.nat(sz), v.data(dt1), v.data(dt2), VisitAndRebuild(load, v), VisitAndRebuild(input, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Cycle.scala b/src/main/scala/shine/DPIA/primitives/functional/Cycle.scala index 8a1e4525f..995185f97 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Cycle.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Cycle.scala @@ -1,18 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -// cycles on the m elements of an array (modulo indexing) to produce an array of n elements -@expPrimitive -final case class Cycle(n: Nat, - m: Nat, - dt: DataType, - input: Phrase[ExpType] - ) extends ExpPrimitive { - input :: expT(m`.`dt, read) - override val t: ExpType = expT(n`.`dt, read) +final case class Cycle(val n: Nat, val m: Nat, val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive { + { + input :: expT(ArrayType(m, dt), read) + } + override val t: ExpType = expT(ArrayType(n, dt), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Cycle = new Cycle(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(input, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/DMatch.scala b/src/main/scala/shine/DPIA/primitives/functional/DMatch.scala index 647b8857a..074a6d781 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/DMatch.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/DMatch.scala @@ -3,9 +3,7 @@ package shine.DPIA.primitives.functional import shine.DPIA.Phrases._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive -@expPrimitive final case class DMatch(x: NatIdentifier, elemT: DataType, outT: DataType, @@ -14,4 +12,7 @@ final case class DMatch(x: NatIdentifier, input: Phrase[ExpType] ) extends ExpPrimitive { override val t: ExpType = expT(outT, a) + + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Phrase[ExpType] = + DMatch(v.nat(x), v.data(elemT), v.data(outT), v.access(a), VisitAndRebuild(f, v), VisitAndRebuild(input, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/DepIdx.scala b/src/main/scala/shine/DPIA/primitives/functional/DepIdx.scala index 4e7aee8ff..be23be676 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/DepIdx.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/DepIdx.scala @@ -1,17 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class DepIdx(n: Nat, - ft: NatToData, - index: Nat, - array: Phrase[ExpType] - ) extends ExpPrimitive { - array :: expT(n`.d`ft, read) - override val t: ExpType = expT(ft(index), read) +final case class DepIdx(val n: Nat, val ft: NatToData, val index: Nat, val array: Phrase[ExpType]) extends ExpPrimitive { + { + array :: expT(DepArrayType(n, ft), read) + } + override val t: ExpType = expT(NatToDataApply(ft, index), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepIdx = new DepIdx(v.nat(n), v.natToData(ft), v.nat(index), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/DepJoin.scala b/src/main/scala/shine/DPIA/primitives/functional/DepJoin.scala index 8e63454d3..a853edac8 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/DepJoin.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/DepJoin.scala @@ -1,18 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - -import arithexpr.arithmetic.BigSum +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class DepJoin(n: Nat, - lenF: NatToNat, - dt: DataType, - array: Phrase[ExpType] - ) extends ExpPrimitive { - array :: expT(n`.d`{ i => lenF(i) `.` dt }, read) - override val t: ExpType = expT(BigSum(from = 0, upTo = n - 1, i => lenF(i))`.`dt, read) +final case class DepJoin(val n: Nat, val lenF: NatToNat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive { + { + array :: expT(DepArrayType(n, n2dtFun { (i: NatIdentifier) => ArrayType(lenF(i), dt) }), read) + } + override val t: ExpType = expT(ArrayType(BigSum(from = 0, upTo = n - 1, (i: Nat) => lenF(i)), dt), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepJoin = new DepJoin(v.nat(n), v.natToNat(lenF), v.data(dt), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/DepMapSeq.scala b/src/main/scala/shine/DPIA/primitives/functional/DepMapSeq.scala index a9d41121e..2dede1eb0 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/DepMapSeq.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/DepMapSeq.scala @@ -1,24 +1,21 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -//noinspection TypeAnnotation -@expPrimitive -final case class DepMapSeq(unroll: Boolean) - (val n: Nat, - val ft1: NatToData, - val ft2: NatToData, - val f: Phrase[`(nat)->:`[ExpType ->: ExpType]], - val array: Phrase[ExpType] - ) extends ExpPrimitive { - f :: f.t.x ->: expT(ft1(f.t.x), read) ->: expT(ft2(f.t.x), write) - array :: expT(n `.d` ft1, read) - override val t: ExpType = expT(n`.d`ft2, write) - - def unwrap: (Nat, NatToData, NatToData, Phrase[`(nat)->:`[ExpType ->: ExpType]], Phrase[ExpType]) = - (n, ft1, ft2, f, array) +final case class DepMapSeq(unroll: Boolean)(val n: Nat, val ft1: NatToData, val ft2: NatToData, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { + { + f :: ({ + val k = f.t.x + DepFunType[NatKind, PhraseType](k, FunType(expT(NatToDataApply(ft1, k), read), expT(NatToDataApply(ft2, k), write))) + }) + array :: expT(DepArrayType(n, ft1), read) + } + override val t: ExpType = expT(DepArrayType(n, ft2), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepMapSeq = new DepMapSeq(unroll)(v.nat(n), v.natToData(ft1), v.natToData(ft2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) + def unwrap: (Nat, NatToData, NatToData, Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], Phrase[ExpType]) = (n, ft1, ft2, f, array) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/DepTile.scala b/src/main/scala/shine/DPIA/primitives/functional/DepTile.scala index 8d379e32e..be95a8fab 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/DepTile.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/DepTile.scala @@ -8,9 +8,7 @@ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive -@expPrimitive final case class DepTile(n: Nat, tileSize: Nat, haloSize: Nat, dt1: DataType, dt2: DataType, processTiles: Phrase[ExpType ->: ExpType], @@ -28,4 +26,8 @@ final case class DepTile(n: Nat, tileSize: Nat, haloSize: Nat, expT(allTiles `.d` (i => (depSize(i) `.` dt2)), write)) array :: expT((n + haloSize)`.`dt1, read) override val t: ExpType = expT(n`.`dt2, write) + + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Phrase[ExpType] = + DepTile(v.nat(n), v.nat(tileSize), v.nat(haloSize), v.data(dt1), v.data(dt2), + VisitAndRebuild(processTiles, v), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/DepZip.scala b/src/main/scala/shine/DPIA/primitives/functional/DepZip.scala index b5b84244f..8864e226e 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/DepZip.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/DepZip.scala @@ -1,19 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class DepZip(n: Nat, - ft1: NatToData, - ft2: NatToData, - e1: Phrase[ExpType], - e2: Phrase[ExpType] - ) extends ExpPrimitive { - e1 :: expT(n`.d`ft1, read) - e2 :: expT(n`.d`ft2, read) - override val t: ExpType = expT(n`.d`{ i => PairType(ft1(i), ft2(i)) }, read) +final case class DepZip(val n: Nat, val ft1: NatToData, val ft2: NatToData, val e1: Phrase[ExpType], val e2: Phrase[ExpType]) extends ExpPrimitive { + { + e1 :: expT(DepArrayType(n, ft1), read) + e2 :: expT(DepArrayType(n, ft2), read) + } + override val t: ExpType = expT(DepArrayType(n, n2dtFun { (i: NatIdentifier) => PairType(NatToDataApply(ft1, i), NatToDataApply(ft2, i)) }), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepZip = new DepZip(v.nat(n), v.natToData(ft1), v.natToData(ft2), VisitAndRebuild(e1, v), VisitAndRebuild(e2, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Drop.scala b/src/main/scala/shine/DPIA/primitives/functional/Drop.scala index 81f3961cc..c90942c95 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Drop.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Drop.scala @@ -1,18 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -// this drops n many elements from an array of n + m elements -@expPrimitive -final case class Drop(n: Nat, - m: Nat, - dt: DataType, - array: Phrase[ExpType] - ) extends ExpPrimitive { - array :: expT((n + m)`.`dt, read) - override val t: ExpType = expT(m`.`dt, read) +final case class Drop(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive { + { + array :: expT(ArrayType(n + m, dt), read) + } + override val t: ExpType = expT(ArrayType(m, dt), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Drop = new Drop(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/ForeignFunctionCall.scala b/src/main/scala/shine/DPIA/primitives/functional/ForeignFunctionCall.scala index a2dea107a..714cd3b23 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/ForeignFunctionCall.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/ForeignFunctionCall.scala @@ -1,28 +1,14 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - -import rise.{core => lc} +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -object ForeignFunction { - val Declaration: lc.ForeignFunction.Decl.type = lc.ForeignFunction.Decl - val Definition: lc.ForeignFunction.Def.type = lc.ForeignFunction.Def - type Declaration = lc.ForeignFunction.Decl - type Definition = lc.ForeignFunction.Def -} - -@expPrimitive -final case class ForeignFunctionCall(funDecl: ForeignFunction.Declaration, - inTs: Seq[DataType], - outT: DataType, - args: Seq[Phrase[ExpType]] - ) extends ExpPrimitive { - (inTs zip args).foreach { - case (inT, arg) => arg :: expT(inT, read) - } +final case class ForeignFunctionCall(funDecl: rise.core.ForeignFunction.Decl, inTs: Seq[DataType], args: Seq[Phrase[ExpType]])(val outT: DataType) extends ExpPrimitive { + {} override val t: ExpType = expT(outT, read) - - override def prettyPrint: String = s"${funDecl.name}(${args.map(PrettyPhrasePrinter(_)).mkString(",")})" + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ForeignFunctionCall = new ForeignFunctionCall(funDecl, inTs.map(v.data), args.map(VisitAndRebuild(_, v)))(v.data(outT)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Fst.scala b/src/main/scala/shine/DPIA/primitives/functional/Fst.scala index 41c6d331f..8230c972f 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Fst.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Fst.scala @@ -1,15 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Fst(dt1: DataType, - dt2: DataType, - pair: Phrase[ExpType] - ) extends ExpPrimitive { - pair :: expT(dt1 x dt2, read) +final case class Fst(val dt1: DataType, val dt2: DataType, val pair: Phrase[ExpType]) extends ExpPrimitive { + { + pair :: expT(PairType(dt1, dt2), read) + } override val t: ExpType = expT(dt1, read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Fst = new Fst(v.data(dt1), v.data(dt2), VisitAndRebuild(pair, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Gather.scala b/src/main/scala/shine/DPIA/primitives/functional/Gather.scala index 966a34524..343491501 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Gather.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Gather.scala @@ -1,19 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Gather(n: Nat, - m: Nat, - dt: DataType, - indices: Phrase[ExpType], - input: Phrase[ExpType] - ) extends ExpPrimitive { - indices :: expT(m`.`idx(n), read) - input :: expT(n`.`dt, read) - override val t: ExpType = expT(m`.`dt, read) +final case class Gather(val n: Nat, val m: Nat, val dt: DataType, val indices: Phrase[ExpType], val input: Phrase[ExpType]) extends ExpPrimitive { + { + indices :: expT(ArrayType(m, IndexType(n)), read) + input :: expT(ArrayType(n, dt), read) + } + override val t: ExpType = expT(ArrayType(m, dt), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Gather = new Gather(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(indices, v), VisitAndRebuild(input, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Generate.scala b/src/main/scala/shine/DPIA/primitives/functional/Generate.scala index b94552cfc..7cb88033c 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Generate.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Generate.scala @@ -1,16 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Generate(n: Nat, - dt: DataType, - f: Phrase[ExpType ->: ExpType] - ) extends ExpPrimitive { - f :: expT(idx(n), read) ->: expT(dt, read) - override val t: ExpType = expT(n`.`dt, read) +final case class Generate(val n: Nat, val dt: DataType, val f: Phrase[FunType[ExpType, ExpType]]) extends ExpPrimitive { + { + f :: FunType(expT(IndexType(n), read), expT(dt, read)) + } + override val t: ExpType = expT(ArrayType(n, dt), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Generate = new Generate(v.nat(n), v.data(dt), VisitAndRebuild(f, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Idx.scala b/src/main/scala/shine/DPIA/primitives/functional/Idx.scala index acf1d6067..8ccfab1ea 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Idx.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Idx.scala @@ -1,18 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Idx(n: Nat, - dt: DataType, - index: Phrase[ExpType], - array: Phrase[ExpType] - ) extends ExpPrimitive { - index :: expT(idx(n), read) - array :: expT(n`.`dt, read) +final case class Idx(val n: Nat, val dt: DataType, val index: Phrase[ExpType], val array: Phrase[ExpType]) extends ExpPrimitive { + { + index :: expT(IndexType(n), read) + array :: expT(ArrayType(n, dt), read) + } override val t: ExpType = expT(dt, read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Idx = new Idx(v.nat(n), v.data(dt), VisitAndRebuild(index, v), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/IdxVec.scala b/src/main/scala/shine/DPIA/primitives/functional/IdxVec.scala index f1211e82b..d2e4c6229 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/IdxVec.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/IdxVec.scala @@ -1,20 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -import scala.language.reflectiveCalls - -@expPrimitive -final case class IdxVec(n: Nat, - st: ScalarType, - index: Phrase[ExpType], - vector: Phrase[ExpType] - ) extends ExpPrimitive { - index :: expT(idx(n), read) - vector :: expT(vec(n, st), read) - override val t: ExpType = expT(st, read) +final case class IdxVec(val n: Nat, val dt: DataType, val index: Phrase[ExpType], val vector: Phrase[ExpType]) extends ExpPrimitive { + { + index :: expT(IndexType(n), read) + vector :: expT(VectorType(n, dt), read) + } + override val t: ExpType = expT(dt, read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): IdxVec = new IdxVec(v.nat(n), v.data(dt), VisitAndRebuild(index, v), VisitAndRebuild(vector, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/IndexAsNat.scala b/src/main/scala/shine/DPIA/primitives/functional/IndexAsNat.scala index 64466878d..94528a350 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/IndexAsNat.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/IndexAsNat.scala @@ -1,15 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class IndexAsNat(n: Nat, - e: Phrase[ExpType] - ) extends ExpPrimitive { - e :: expT(idx(n), read) +final case class IndexAsNat(val n: Nat, val e: Phrase[ExpType]) extends ExpPrimitive { + { + e :: expT(IndexType(n), read) + } override val t: ExpType = expT(NatType, read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): IndexAsNat = new IndexAsNat(v.nat(n), VisitAndRebuild(e, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Iterate.scala b/src/main/scala/shine/DPIA/primitives/functional/Iterate.scala index 612e25af0..6b20d3e4a 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Iterate.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Iterate.scala @@ -1,23 +1,20 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Iterate(n: Nat, - m: Nat, - k: Nat, - dt: DataType, - f: Phrase[`(nat)->:`[ExpType ->: ExpType]], - array: Phrase[ExpType] - ) extends ExpPrimitive { +final case class Iterate(val n: Nat, val m: Nat, val k: Nat, val dt: DataType, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { { - val l = f.t.x - f :: l ->: expT((l * n)`.`dt, read) ->: expT(l`.`dt, write) - array :: expT((m * n.pow(k))`.`dt, read) + f :: ({ + val l = f.t.x + DepFunType[NatKind, PhraseType](l, FunType(expT(ArrayType(l * n, dt), read), expT(ArrayType(l, dt), write))) + }) + array :: expT(ArrayType(m * n.pow(k), dt), read) } - override val t: ExpType = expT(m`.`dt, write) + override val t: ExpType = expT(ArrayType(m, dt), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Iterate = new Iterate(v.nat(n), v.nat(m), v.nat(k), v.data(dt), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/IterateStream.scala b/src/main/scala/shine/DPIA/primitives/functional/IterateStream.scala index 481ab7866..531bca06f 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/IterateStream.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/IterateStream.scala @@ -1,19 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class IterateStream(n: Nat, - dt1: DataType, - dt2: DataType, - f: Phrase[ExpType ->: ExpType], - array: Phrase[ExpType] - ) extends ExpPrimitive { - f :: expT(dt1, read) ->: expT(dt2, write) - array :: expT(n`.`dt1, read) - override val t: ExpType = expT(n`.`dt2, write) +final case class IterateStream(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive { + { + f :: FunType(expT(dt1, read), expT(dt2, write)) + array :: expT(ArrayType(n, dt1), read) + } + override val t: ExpType = expT(ArrayType(n, dt2), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): IterateStream = new IterateStream(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Join.scala b/src/main/scala/shine/DPIA/primitives/functional/Join.scala index 593772603..c42697619 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Join.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Join.scala @@ -1,18 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Join(n: Nat, - m: Nat, - w: AccessType, - dt: DataType, - array: Phrase[ExpType] - ) extends ExpPrimitive { - array :: expT(n`.`(m`.`dt), w) - override val t: ExpType = expT((n * m)`.`dt, w) +final case class Join(val n: Nat, val m: Nat, val a: AccessType, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive { + { + array :: expT(ArrayType(n, ArrayType(m, dt)), a) + } + override val t: ExpType = expT(ArrayType(n * m, dt), a) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Join = new Join(v.nat(n), v.nat(m), v.access(a), v.data(dt), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Let.scala b/src/main/scala/shine/DPIA/primitives/functional/Let.scala index ae66302cb..cf627e092 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Let.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Let.scala @@ -1,18 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Let(dt1: DataType, - dt2: DataType, - access: AccessType, - value: Phrase[ExpType], - f: Phrase[ExpType ->: ExpType] - ) extends ExpPrimitive { - value :: expT(dt1, read) - f :: expT(dt1, read) ->: expT(dt2, access) - override val t: ExpType = expT(dt2, access) +final case class Let(val dt1: DataType, val dt2: DataType, val a: AccessType, val value: Phrase[ExpType], val f: Phrase[FunType[ExpType, ExpType]]) extends ExpPrimitive { + { + value :: expT(dt1, read) + f :: FunType(expT(dt1, read), expT(dt2, a)) + } + override val t: ExpType = expT(dt2, a) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Let = new Let(v.data(dt1), v.data(dt2), v.access(a), VisitAndRebuild(value, v), VisitAndRebuild(f, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/MakeArray.scala b/src/main/scala/shine/DPIA/primitives/functional/MakeArray.scala index 0a2e9fe1a..528f9a149 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/MakeArray.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/MakeArray.scala @@ -1,14 +1,15 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class MakeArray(dt: DataType, - elements: Vector[Phrase[ExpType]] - ) extends ExpPrimitive { - override val t: ExpType = expT((elements.length: Nat)`.`dt, read) +final case class MakeArray(elements: Vector[Phrase[ExpType]])(val n: Nat, val dt: DataType) extends ExpPrimitive { + {} + override val t: ExpType = expT(ArrayType(n, dt), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): MakeArray = new MakeArray(elements.map(VisitAndRebuild(_, v)))(v.nat(n), v.data(dt)) + def unwrap: (Nat, DataType) = (n, dt) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/MakeDepPair.scala b/src/main/scala/shine/DPIA/primitives/functional/MakeDepPair.scala index f3fa0dfb1..eae14229e 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/MakeDepPair.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/MakeDepPair.scala @@ -3,13 +3,14 @@ package shine.DPIA.primitives.functional import shine.DPIA.Phrases._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive -@expPrimitive final case class MakeDepPair(a: AccessType, fst: NatIdentifier, sndT: DataType, snd: Phrase[ExpType] ) extends ExpPrimitive { override val t: ExpType = expT(DepPairType(fst, sndT), a) + + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Phrase[ExpType] = + MakeDepPair(v.access(a), v.nat(fst), v.data(sndT), VisitAndRebuild(snd, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/MakePair.scala b/src/main/scala/shine/DPIA/primitives/functional/MakePair.scala index 278197c8a..38b54d25d 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/MakePair.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/MakePair.scala @@ -1,18 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class MakePair(dt1: DataType, - dt2: DataType, - access: AccessType, - fst: Phrase[ExpType], - snd: Phrase[ExpType] - ) extends ExpPrimitive { - fst :: expT(dt1, access) - snd :: expT(dt2, access) - override val t: ExpType = expT(dt1 x dt2, access) +final case class MakePair(val dt1: DataType, val dt2: DataType, val a: AccessType, val fst: Phrase[ExpType], val snd: Phrase[ExpType]) extends ExpPrimitive { + { + fst :: expT(dt1, a) + snd :: expT(dt2, a) + } + override val t: ExpType = expT(PairType(dt1, dt2), a) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): MakePair = new MakePair(v.data(dt1), v.data(dt2), v.access(a), VisitAndRebuild(fst, v), VisitAndRebuild(snd, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Map.scala b/src/main/scala/shine/DPIA/primitives/functional/Map.scala index b2c1b6868..64896fd85 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Map.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Map.scala @@ -1,20 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Map(n: Nat, - dt1: DataType, - dt2: DataType, - access: AccessType, - f: Phrase[ExpType ->: ExpType], - array: Phrase[ExpType] - ) extends ExpPrimitive { - array :: expT(n`.`dt1, access) - f :: expT(dt1, access) ->: expT(dt2, access) - override val t: ExpType = expT(n`.`dt2, access) +final case class Map(val n: Nat, val dt1: DataType, val dt2: DataType, val a: AccessType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive { + { + f :: FunType(expT(dt1, a), expT(dt2, a)) + array :: expT(ArrayType(n, dt1), a) + } + override val t: ExpType = expT(ArrayType(n, dt2), a) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Map = new Map(v.nat(n), v.data(dt1), v.data(dt2), v.access(a), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/MapFst.scala b/src/main/scala/shine/DPIA/primitives/functional/MapFst.scala index b809e174d..5005b6e07 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/MapFst.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/MapFst.scala @@ -1,19 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class MapFst(w: AccessType, - dt1: DataType, - dt2: DataType, - dt3: DataType, - f: Phrase[ExpType ->: ExpType], - record: Phrase[ExpType] - ) extends ExpPrimitive { - f :: expT(dt1, w) ->: expT(dt3, w) - record :: expT(dt1 x dt2, w) - override val t: ExpType = expT(dt3 x dt2, w) +final case class MapFst(val a: AccessType, val dt1: DataType, val dt2: DataType, val dt3: DataType, val f: Phrase[FunType[ExpType, ExpType]], val pair: Phrase[ExpType]) extends ExpPrimitive { + { + f :: FunType(expT(dt1, a), expT(dt3, a)) + pair :: expT(PairType(dt1, dt2), a) + } + override val t: ExpType = expT(PairType(dt3, dt2), a) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapFst = new MapFst(v.access(a), v.data(dt1), v.data(dt2), v.data(dt3), VisitAndRebuild(f, v), VisitAndRebuild(pair, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/MapSeq.scala b/src/main/scala/shine/DPIA/primitives/functional/MapSeq.scala index ace4952e6..e18bf1a2b 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/MapSeq.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/MapSeq.scala @@ -1,23 +1,18 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class MapSeq(unroll: Boolean) - (val n: Nat, - val dt1: DataType, - val dt2: DataType, - val f: Phrase[ExpType ->: ExpType], - val array: Phrase[ExpType] - ) extends ExpPrimitive { - f :: expT(dt1, read) ->: expT(dt2, write) - array :: expT(n`.`dt1, read) - override val t: ExpType = expT(n`.`dt2, write) - - def unwrap: (Nat, DataType, DataType, Phrase[ExpType ->: ExpType], Phrase[ExpType]) = - (n, dt1, dt2, f, array) +final case class MapSeq(unroll: Boolean)(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive { + { + f :: FunType(expT(dt1, read), expT(dt2, write)) + array :: expT(ArrayType(n, dt1), read) + } + override val t: ExpType = expT(ArrayType(n, dt2), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapSeq = new MapSeq(unroll)(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) + def unwrap: (Nat, DataType, DataType, Phrase[FunType[ExpType, ExpType]], Phrase[ExpType]) = (n, dt1, dt2, f, array) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/MapSnd.scala b/src/main/scala/shine/DPIA/primitives/functional/MapSnd.scala index 46e11c666..410b3dc86 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/MapSnd.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/MapSnd.scala @@ -1,19 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class MapSnd(w: AccessType, - dt1: DataType, - dt2: DataType, - dt3: DataType, - f: Phrase[ExpType ->: ExpType], - record: Phrase[ExpType] - ) extends ExpPrimitive { - f :: expT(dt2, w) ->: expT(dt3, w) - record :: expT(dt1 x dt2, w) - override val t: ExpType = expT(dt1 x dt3, w) +final case class MapSnd(val a: AccessType, val dt1: DataType, val dt2: DataType, val dt3: DataType, val f: Phrase[FunType[ExpType, ExpType]], val pair: Phrase[ExpType]) extends ExpPrimitive { + { + f :: FunType(expT(dt2, a), expT(dt3, a)) + pair :: expT(PairType(dt1, dt2), a) + } + override val t: ExpType = expT(PairType(dt1, dt3), a) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapSnd = new MapSnd(v.access(a), v.data(dt1), v.data(dt2), v.data(dt3), VisitAndRebuild(f, v), VisitAndRebuild(pair, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/MapStream.scala b/src/main/scala/shine/DPIA/primitives/functional/MapStream.scala index 755fd7d79..c79d554c6 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/MapStream.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/MapStream.scala @@ -1,19 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class MapStream(n: Nat, - dt1: DataType, - dt2: DataType, - f: Phrase[ExpType ->: ExpType], - array: Phrase[ExpType] - ) extends ExpPrimitive { - f :: expT(dt1, read) ->: expT(dt2, write) - array :: expT(n`.`dt1, read) - override val t: ExpType = expT(n`.`dt2, write) +final case class MapStream(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive { + { + f :: FunType(expT(dt1, read), expT(dt2, write)) + array :: expT(ArrayType(n, dt1), read) + } + override val t: ExpType = expT(ArrayType(n, dt2), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapStream = new MapStream(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/MapVec.scala b/src/main/scala/shine/DPIA/primitives/functional/MapVec.scala index d79c5c142..8583c637d 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/MapVec.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/MapVec.scala @@ -1,18 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class MapVec(n: Nat, - dt1: ScalarType, - dt2: ScalarType, - f: Phrase[ExpType ->: ExpType], - array: Phrase[ExpType] - ) extends ExpPrimitive { - f :: expT(dt1, read) ->: expT(dt2, write) - array :: expT(vec(n, dt1), read) - override val t: ExpType = expT(vec(n, dt2), write) +final case class MapVec(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive { + { + f :: FunType(expT(dt1, read), expT(dt2, write)) + array :: expT(VectorType(n, dt1), read) + } + override val t: ExpType = expT(VectorType(n, dt2), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapVec = new MapVec(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/NatAsIndex.scala b/src/main/scala/shine/DPIA/primitives/functional/NatAsIndex.scala index 2c9f03482..a7a69904d 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/NatAsIndex.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/NatAsIndex.scala @@ -1,15 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class NatAsIndex(n: Nat, - e: Phrase[ExpType] - ) extends ExpPrimitive { - e :: expT(NatType, read) - override val t: ExpType = expT(idx(n), read) +final case class NatAsIndex(val n: Nat, val e: Phrase[ExpType]) extends ExpPrimitive { + { + e :: expT(NatType, read) + } + override val t: ExpType = expT(IndexType(n), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): NatAsIndex = new NatAsIndex(v.nat(n), VisitAndRebuild(e, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Pad.scala b/src/main/scala/shine/DPIA/primitives/functional/Pad.scala deleted file mode 100644 index e5d9f69b4..000000000 --- a/src/main/scala/shine/DPIA/primitives/functional/Pad.scala +++ /dev/null @@ -1,20 +0,0 @@ -package shine.DPIA.primitives.functional - -import shine.DPIA.Phrases._ -import shine.DPIA.Types.DataType._ -import shine.DPIA.Types._ -import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Pad(n: Nat, - l: Nat, - r: Nat, - dt: DataType, - padExp: Phrase[ExpType], - array: Phrase[ExpType] - ) extends ExpPrimitive { - padExp :: expT(dt, read) - array :: expT(n `.` dt, read) - override val t: ExpType = expT((l + n + r)`.`dt, read) -} diff --git a/src/main/scala/shine/DPIA/primitives/functional/PadClamp.scala b/src/main/scala/shine/DPIA/primitives/functional/PadClamp.scala index f211bdfd6..d8f44a1b4 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/PadClamp.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/PadClamp.scala @@ -1,19 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -// TODO: invalid for empty array -@expPrimitive -final case class PadClamp(n: Nat, - l: Nat, - r: Nat, - dt: DataType, - array: Phrase[ExpType] - ) extends ExpPrimitive { - array :: expT(n `.` dt, read) - override val t: ExpType = expT((l + n + r)`.`dt, read) +final case class PadClamp(val n: Nat, val l: Nat, val r: Nat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive { + { + array :: expT(ArrayType(n, dt), read) + } + override val t: ExpType = expT(ArrayType(l + n + r, dt), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): PadClamp = new PadClamp(v.nat(n), v.nat(l), v.nat(r), v.data(dt), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/PadCst.scala b/src/main/scala/shine/DPIA/primitives/functional/PadCst.scala new file mode 100644 index 000000000..18cf95c23 --- /dev/null +++ b/src/main/scala/shine/DPIA/primitives/functional/PadCst.scala @@ -0,0 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package shine.DPIA.primitives.functional +import arithexpr.arithmetic._ +import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ +import shine.DPIA.Types._ +import shine.DPIA._ +final case class PadCst(val n: Nat, val l: Nat, val r: Nat, val dt: DataType, val padExp: Phrase[ExpType], val array: Phrase[ExpType]) extends ExpPrimitive { + { + padExp :: expT(dt, read) + array :: expT(ArrayType(n, dt), read) + } + override val t: ExpType = expT(ArrayType(l + n + r, dt), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): PadCst = new PadCst(v.nat(n), v.nat(l), v.nat(r), v.data(dt), VisitAndRebuild(padExp, v), VisitAndRebuild(array, v)) +} diff --git a/src/main/scala/shine/DPIA/primitives/functional/PadEmpty.scala b/src/main/scala/shine/DPIA/primitives/functional/PadEmpty.scala index 3b8baf59e..ed4d9d75e 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/PadEmpty.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/PadEmpty.scala @@ -1,17 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class PadEmpty(n: Nat, - r: Nat, - dt: DataType, - array: Phrase[ExpType] - ) extends ExpPrimitive { - array :: expT(n `.` dt, write) - override val t: ExpType = expT((n + r)`.`dt, write) +final case class PadEmpty(val n: Nat, val r: Nat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive { + { + array :: expT(ArrayType(n, dt), write) + } + override val t: ExpType = expT(ArrayType(n + r, dt), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): PadEmpty = new PadEmpty(v.nat(n), v.nat(r), v.data(dt), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Partition.scala b/src/main/scala/shine/DPIA/primitives/functional/Partition.scala index 8a55e17e3..74bdf5622 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Partition.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Partition.scala @@ -1,18 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Partition(n: Nat, - m: Nat, - lenF: NatToNat, - dt: DataType, - array: Phrase[ExpType] - ) extends ExpPrimitive { - array :: expT(n`.`dt, read) - override val t: ExpType = expT(m`.d`{ i => lenF(i)`.`dt }, read) +final case class Partition(val n: Nat, val m: Nat, val lenF: NatToNat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive { + { + array :: expT(ArrayType(n, dt), read) + } + override val t: ExpType = expT(DepArrayType(m, n2dtFun { (i: NatIdentifier) => ArrayType(lenF(i), dt) }), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Partition = new Partition(v.nat(n), v.nat(m), v.natToNat(lenF), v.data(dt), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/PrintType.scala b/src/main/scala/shine/DPIA/primitives/functional/PrintType.scala index 8f70bfe60..0fd80c754 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/PrintType.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/PrintType.scala @@ -3,9 +3,7 @@ package shine.DPIA.primitives.functional import shine.DPIA.Phrases._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive -@expPrimitive final case class PrintType(msg: String, dt: DataType, access: AccessType, @@ -15,4 +13,7 @@ final case class PrintType(msg: String, input :: expT(dt, access) override val t: ExpType = expT(dt, access) + + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Phrase[ExpType] = + PrintType(msg, v.data(dt), v.access(access), VisitAndRebuild(input, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/ReduceSeq.scala b/src/main/scala/shine/DPIA/primitives/functional/ReduceSeq.scala index cee2f2a89..bb07ef410 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/ReduceSeq.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/ReduceSeq.scala @@ -1,25 +1,19 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class ReduceSeq(unroll: Boolean) - (val n: Nat, - val dt1: DataType, - val dt2: DataType, - val f: Phrase[ExpType ->: ExpType ->: ExpType], - val init: Phrase[ExpType], - val array: Phrase[ExpType] - ) extends ExpPrimitive { - f :: expT(dt2, read) ->: expT(dt1, read) ->: expT(dt2, write) - init :: expT(dt2, write) - array :: expT(n`.`dt1, read) +final case class ReduceSeq(unroll: Boolean)(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, FunType[ExpType, ExpType]]], val init: Phrase[ExpType], val array: Phrase[ExpType]) extends ExpPrimitive { + { + f :: FunType(expT(dt2, read), FunType(expT(dt1, read), expT(dt2, write))) + init :: expT(dt2, write) + array :: expT(ArrayType(n, dt1), read) + } override val t: ExpType = expT(dt2, read) - - def unwrap: (Nat, DataType, DataType, Phrase[ExpType ->: ExpType ->: ExpType], Phrase[ExpType], Phrase[ExpType]) = - (n, dt1, dt2, f, init, array) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ReduceSeq = new ReduceSeq(unroll)(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(init, v), VisitAndRebuild(array, v)) + def unwrap: (Nat, DataType, DataType, Phrase[FunType[ExpType, FunType[ExpType, ExpType]]], Phrase[ExpType], Phrase[ExpType]) = (n, dt1, dt2, f, init, array) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Reorder.scala b/src/main/scala/shine/DPIA/primitives/functional/Reorder.scala index 2dce81633..9c9457bf1 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Reorder.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Reorder.scala @@ -1,19 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Reorder(n: Nat, - dt: DataType, - access: AccessType, - idxF: NatToNat, - idxFinv: NatToNat, - input: Phrase[ExpType] - ) extends ExpPrimitive { - input :: expT(n`.`dt, access) - override val t: ExpType = expT(n`.`dt, access) +final case class Reorder(val n: Nat, val dt: DataType, val a: AccessType, val idxF: NatToNat, val idxFiv: NatToNat, val input: Phrase[ExpType]) extends ExpPrimitive { + { + input :: expT(ArrayType(n, dt), a) + } + override val t: ExpType = expT(ArrayType(n, dt), a) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Reorder = new Reorder(v.nat(n), v.data(dt), v.access(a), v.natToNat(idxF), v.natToNat(idxFiv), VisitAndRebuild(input, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/RotateValues.scala b/src/main/scala/shine/DPIA/primitives/functional/RotateValues.scala index a8d2f0b07..04cbf6714 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/RotateValues.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/RotateValues.scala @@ -1,19 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class RotateValues(n: Nat, - sz: Nat, - dt: DataType, - write: Phrase[ExpType ->: ExpType], - input: Phrase[ExpType] - ) extends ExpPrimitive { - write :: expT(dt, read) ->: expT(dt, shine.DPIA.Types.write) - input :: expT((n - 1 + sz)`.`dt, read) - override val t: ExpType = expT(n`.`(sz`.`dt), read) +final case class RotateValues(val n: Nat, val sz: Nat, val dt: DataType, val wrt: Phrase[FunType[ExpType, ExpType]], val input: Phrase[ExpType]) extends ExpPrimitive { + { + wrt :: FunType(expT(dt, read), expT(dt, write)) + input :: expT(ArrayType(n - 1 + sz, dt), read) + } + override val t: ExpType = expT(ArrayType(n, ArrayType(sz, dt)), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): RotateValues = new RotateValues(v.nat(n), v.nat(sz), v.data(dt), VisitAndRebuild(wrt, v), VisitAndRebuild(input, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/ScanSeq.scala b/src/main/scala/shine/DPIA/primitives/functional/ScanSeq.scala index 48acab9ca..6be6db44b 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/ScanSeq.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/ScanSeq.scala @@ -1,22 +1,18 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -//noinspection TypeAnnotation -@expPrimitive -final case class ScanSeq(n: Nat, - dt1: DataType, - dt2: DataType, - f: Phrase[ExpType ->: ExpType ->: ExpType], - init: Phrase[ExpType], - array: Phrase[ExpType] - ) extends ExpPrimitive { - f :: expT(dt1, read) ->: expT(dt2, read) ->: expT(dt2, write) - init :: expT(dt2, write) - array :: expT(n`.`dt1, read) - override val t: ExpType = expT(n`.`dt2, read) +final case class ScanSeq(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, FunType[ExpType, ExpType]]], val init: Phrase[ExpType], val array: Phrase[ExpType]) extends ExpPrimitive { + { + f :: FunType(expT(dt1, read), FunType(expT(dt2, read), expT(dt2, write))) + init :: expT(dt2, write) + array :: expT(ArrayType(n, dt1), read) + } + override val t: ExpType = expT(ArrayType(n, dt2), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ScanSeq = new ScanSeq(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(init, v), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Scatter.scala b/src/main/scala/shine/DPIA/primitives/functional/Scatter.scala index 9f77dd186..85c2bbbf7 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Scatter.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Scatter.scala @@ -1,17 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Scatter(n: Nat, m: Nat, dt: DataType, - indices: Phrase[ExpType], - input: Phrase[ExpType] - ) extends ExpPrimitive { - indices :: expT(n`.`idx(m), read) - input :: expT(n`.`dt, write) - override val t: ExpType = expT(m`.`dt, write) +final case class Scatter(val n: Nat, val m: Nat, val dt: DataType, val indices: Phrase[ExpType], val input: Phrase[ExpType]) extends ExpPrimitive { + { + indices :: expT(ArrayType(n, IndexType(m)), read) + input :: expT(ArrayType(n, dt), write) + } + override val t: ExpType = expT(ArrayType(m, dt), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Scatter = new Scatter(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(indices, v), VisitAndRebuild(input, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Slide.scala b/src/main/scala/shine/DPIA/primitives/functional/Slide.scala index 4a71ba49a..92fdf0bbf 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Slide.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Slide.scala @@ -1,23 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - -import arithexpr.arithmetic.SimplifiedExpr +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -import scala.language.reflectiveCalls - -@expPrimitive -final case class Slide(n: Nat, - sz: Nat, - sp: Nat, - dt: DataType, - input: Phrase[ExpType] - ) extends ExpPrimitive { - val inputSize: Nat with SimplifiedExpr = sp * n + sz - sp - - input :: expT(inputSize`.`dt, read) - override val t: ExpType = expT(n`.`(sz`.`dt), read) +final case class Slide(val n: Nat, val sz: Nat, val sp: Nat, val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive { + { + input :: expT(ArrayType(sp * n + sz, dt), read) + } + override val t: ExpType = expT(ArrayType(1 + n, ArrayType(sz, dt)), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Slide = new Slide(v.nat(n), v.nat(sz), v.nat(sp), v.data(dt), VisitAndRebuild(input, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Snd.scala b/src/main/scala/shine/DPIA/primitives/functional/Snd.scala index 214151764..77a2e9488 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Snd.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Snd.scala @@ -1,15 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Snd(dt1: DataType, - dt2: DataType, - pair: Phrase[ExpType] - ) extends ExpPrimitive { - pair :: expT(dt1 x dt2, read) +final case class Snd(val dt1: DataType, val dt2: DataType, val pair: Phrase[ExpType]) extends ExpPrimitive { + { + pair :: expT(PairType(dt1, dt2), read) + } override val t: ExpType = expT(dt2, read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Snd = new Snd(v.data(dt1), v.data(dt2), VisitAndRebuild(pair, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Split.scala b/src/main/scala/shine/DPIA/primitives/functional/Split.scala index 3b5ead8fd..426e6af3a 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Split.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Split.scala @@ -1,18 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Split(n: Nat, - m: Nat, - w: AccessType, - dt: DataType, - array: Phrase[ExpType] - ) extends ExpPrimitive { - array :: expT((m * n)`.`dt, w) - override val t: ExpType = expT(m`.`(n`.`dt), w) +final case class Split(val n: Nat, val m: Nat, val a: AccessType, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive { + { + array :: expT(ArrayType(m * n, dt), a) + } + override val t: ExpType = expT(ArrayType(m, ArrayType(n, dt)), a) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Split = new Split(v.nat(n), v.nat(m), v.access(a), v.data(dt), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Take.scala b/src/main/scala/shine/DPIA/primitives/functional/Take.scala index 8a9c4ba7f..ffd625320 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Take.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Take.scala @@ -1,18 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -// this takes n many elements from an array of n + m elements -@expPrimitive -final case class Take(n: Nat, - m: Nat, - dt: DataType, - array: Phrase[ExpType] - ) extends ExpPrimitive { - array :: expT((n + m)`.`dt, read) - override val t: ExpType = expT(n`.`dt, read) +final case class Take(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive { + { + array :: expT(ArrayType(n + m, dt), read) + } + override val t: ExpType = expT(ArrayType(n, dt), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Take = new Take(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/ToMem.scala b/src/main/scala/shine/DPIA/primitives/functional/ToMem.scala index fd675f837..afd8d1bdf 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/ToMem.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/ToMem.scala @@ -1,14 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class ToMem(dt: DataType, - input: Phrase[ExpType] - ) extends ExpPrimitive { - input :: expT(dt, write) +final case class ToMem(val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive { + { + input :: expT(dt, write) + } override val t: ExpType = expT(dt, read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ToMem = new ToMem(v.data(dt), VisitAndRebuild(input, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Transpose.scala b/src/main/scala/shine/DPIA/primitives/functional/Transpose.scala index b72b27e8b..ecdf22dc1 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Transpose.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Transpose.scala @@ -1,18 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Transpose(n: Nat, - m: Nat, - dt: DataType, - access: AccessType, - array: Phrase[ExpType] - ) extends ExpPrimitive { - array :: expT(n`.`(m`.`dt), access) - override val t: ExpType = expT(m`.`(n`.`dt), access) +final case class Transpose(val n: Nat, val m: Nat, val dt: DataType, val a: AccessType, val array: Phrase[ExpType]) extends ExpPrimitive { + { + array :: expT(ArrayType(n, ArrayType(m, dt)), a) + } + override val t: ExpType = expT(ArrayType(m, ArrayType(n, dt)), a) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Transpose = new Transpose(v.nat(n), v.nat(m), v.data(dt), v.access(a), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/TransposeDepArray.scala b/src/main/scala/shine/DPIA/primitives/functional/TransposeDepArray.scala index 1891bd369..276a3f223 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/TransposeDepArray.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/TransposeDepArray.scala @@ -1,17 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class TransposeDepArray(n: Nat, - m: Nat, - f: NatToData, - array: Phrase[ExpType] - ) extends ExpPrimitive { - array :: expT(n`.`(m`.d`f), read) - override val t: ExpType = expT(m`.d`{ k => n`.`f(k) }, read) +final case class TransposeDepArray(val n: Nat, val m: Nat, val ft: NatToData, val array: Phrase[ExpType]) extends ExpPrimitive { + { + array :: expT(ArrayType(n, DepArrayType(m, ft)), read) + } + override val t: ExpType = expT(DepArrayType(n, n2dtFun { (i: NatIdentifier) => ArrayType(n, NatToDataApply(ft, i)) }), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): TransposeDepArray = new TransposeDepArray(v.nat(n), v.nat(m), v.natToData(ft), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Unzip.scala b/src/main/scala/shine/DPIA/primitives/functional/Unzip.scala index 2254ef12c..12d39f4bf 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Unzip.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Unzip.scala @@ -1,18 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Unzip(n: Nat, - dt1: DataType, - dt2: DataType, - access: AccessType, - e: Phrase[ExpType] - ) extends ExpPrimitive { - e :: expT(n`.`(dt1 x dt2), access) - override val t: ExpType = expT((n`.`dt1) x (n`.`dt2), access) +final case class Unzip(val n: Nat, val dt1: DataType, val dt2: DataType, val a: AccessType, val e: Phrase[ExpType]) extends ExpPrimitive { + { + e :: expT(ArrayType(n, PairType(dt1, dt2)), a) + } + override val t: ExpType = expT(PairType(ArrayType(n, dt1), ArrayType(n, dt2)), a) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Unzip = new Unzip(v.nat(n), v.data(dt1), v.data(dt2), v.access(a), VisitAndRebuild(e, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/VectorFromScalar.scala b/src/main/scala/shine/DPIA/primitives/functional/VectorFromScalar.scala index 4f480c8f4..91b183b13 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/VectorFromScalar.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/VectorFromScalar.scala @@ -1,15 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class VectorFromScalar(n: Nat, - dt: ScalarType, - arg: Phrase[ExpType] - ) extends ExpPrimitive { - arg :: expT(dt, read) - override val t: ExpType = expT(vec(n, dt), read) +final case class VectorFromScalar(val n: Nat, val dt: DataType, val arg: Phrase[ExpType]) extends ExpPrimitive { + { + arg :: expT(dt, read) + } + override val t: ExpType = expT(VectorType(n, dt), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): VectorFromScalar = new VectorFromScalar(v.nat(n), v.data(dt), VisitAndRebuild(arg, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/Zip.scala b/src/main/scala/shine/DPIA/primitives/functional/Zip.scala index d90cc6ece..3f0c328df 100644 --- a/src/main/scala/shine/DPIA/primitives/functional/Zip.scala +++ b/src/main/scala/shine/DPIA/primitives/functional/Zip.scala @@ -1,20 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Zip(n: Nat, - dt1: DataType, - dt2: DataType, - access: AccessType, - e1: Phrase[ExpType], - e2: Phrase[ExpType] - ) extends ExpPrimitive { - e1 :: expT(n`.`dt1, access) - e2 :: expT(n`.`dt2, access) - override val t: ExpType = expT(n`.`(dt1 x dt2), access) +final case class Zip(val n: Nat, val dt1: DataType, val dt2: DataType, val a: AccessType, val e1: Phrase[ExpType], val e2: Phrase[ExpType]) extends ExpPrimitive { + { + e1 :: expT(ArrayType(n, dt1), a) + e2 :: expT(ArrayType(n, dt2), a) + } + override val t: ExpType = expT(ArrayType(n, PairType(dt1, dt2)), a) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Zip = new Zip(v.nat(n), v.data(dt1), v.data(dt2), v.access(a), VisitAndRebuild(e1, v), VisitAndRebuild(e2, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/functional/primitives.dpia b/src/main/scala/shine/DPIA/primitives/functional/primitives.dpia new file mode 100644 index 000000000..37e8ede1c --- /dev/null +++ b/src/main/scala/shine/DPIA/primitives/functional/primitives.dpia @@ -0,0 +1,136 @@ +def asScalar(n: nat, m: nat, dt: data, a: access, array: exp[m.vec[dt, n], a]): exp[(m*n).dt, a] +def asVector(n: nat, m: nat, dt: data, a: access, array: exp[(m*n).dt, a]): exp[m.vec[dt, n], a] +def asVectorAligned(n: nat, m: nat, dt: data, a: access, array: exp[(m*n).dt, a]): exp[m.vec[dt, n], a] + +def cast(dt1: data, dt2: data, e: exp[dt1, read]): exp[dt2, read] + +def circularBuffer(n: nat, alloc: nat, sz: nat, dt1: data, dt2: data, + load: exp[dt1, read] -> exp[dt2, write], + input: exp[(n-1+sz).dt1, read]): exp[n.sz.dt2, read] + +def cycle(n: nat, m: nat, dt: data, input: exp[m.dt, read]): exp[n.dt, read] + +def depIdx(n: nat, ft: nat2data, index: nat, array: exp[n..ft, read]): exp[ft(index), read] + +def depJoin(n: nat, lenF: nat2nat, dt: data, + array: exp[n..(i: nat |-> lenF(i).dt), read]): exp[(sum_(i=0)^(n-1) lenF(i)).dt, read] + +def depMapSeq{unroll: Boolean} + (n: nat, ft1: nat2data, ft2: nat2data, + f: (k: nat) -> exp[ft1(k), read] -> exp[ft2(k), write], + array: exp[n..ft1, read]): exp[n..ft2, write] + +// def depTile(...) + +def depZip(n: nat, ft1: nat2data, ft2: nat2data, + e1: exp[n..ft1, read], e2: exp[n..ft2, read]): exp[n..(i: nat |-> (ft1(i), ft2(i)) ), read] + +// def dmatch(...) + +def drop(n: nat, m: nat, dt: data, array: exp[n+m.dt, read]): exp[m.dt, read] + +def foreignFunctionCall{funDecl: rise.core.ForeignFunction.Decl, + inTs: Seq[DataType], + args: Seq[Phrase[ExpType]]}(outT: data): exp[outT, read] + +def fst(dt1: data, dt2: data, pair: exp[(dt1, dt2), read]): exp[dt1, read] + +def gather(n: nat, m: nat, dt: data, indices: exp[m.idx[n], read], input: exp[n.dt, read]): exp[m.dt, read] + +def generate(n: nat, dt: data, f: exp[idx[n], read] -> exp[dt, read]): exp[n.dt, read] + +def idx(n: nat, dt: data, index: exp[idx[n], read], array: exp[n.dt, read]): exp[dt, read] + +def idxVec(n: nat, dt: data, index: exp[idx[n], read], vector: exp[vec[dt, n], read]): exp[dt, read] + +def indexAsNat(n: nat, e: exp[idx[n], read]): exp[natType, read] + +def iterate(n: nat, m: nat, k: nat, dt: data, + f: (l: nat) -> exp[(l*n).dt, read] -> exp[l.dt, write], + array: exp[(m*(n^k)).dt, read]): exp[m.dt, write] + +def iterateStream(n: nat, dt1: data, dt2: data, + f: exp[dt1, read] -> exp[dt2, write], + array: exp[n.dt1, read]): exp[n.dt2, write] + +def join(n: nat, m: nat, a: access, dt: data, array: exp[n.m.dt, a]): exp[(n*m).dt, a] + +def let(dt1: data, dt2: data, a: access, value: exp[dt1, read], + f: exp[dt1, read] -> exp[dt2, a]): exp[dt2, a] + +def makeArray{elements: Vector[Phrase[ExpType]]}(n: nat, dt: data): exp[n.dt, read] + +// def makeDepPair(...) + +def makePair(dt1: data, dt2: data, a: access, fst: exp[dt1, a], snd: exp[dt2, a]): exp[(dt1, dt2), a] + +def map(n: nat, dt1: data, dt2: data, a: access, + f: exp[dt1, a] -> exp[dt2, a], array: exp[n.dt1, a]): exp[n.dt2, a] + +def mapFst(a: access, dt1: data, dt2: data, dt3: data, + f: exp[dt1, a] -> exp[dt3, a], + pair: exp[(dt1, dt2), a]): exp[(dt3, dt2), a] + +def mapSeq{unroll: Boolean} + (n: nat, dt1: data, dt2: data, + f: exp[dt1, read] -> exp[dt2, write], array: exp[n.dt1, read]): exp[n.dt2, write] + +def mapSnd(a: access, dt1: data, dt2: data, dt3: data, + f: exp[dt2, a] -> exp[dt3, a], + pair: exp[(dt1, dt2), a]): exp[(dt1, dt3), a] + +def mapStream(n: nat, dt1: data, dt2: data, + f: exp[dt1, read] -> exp[dt2, write], array: exp[n.dt1, read]): exp[n.dt2, write] + +def mapVec(n: nat, dt1: data, dt2: data, + f: exp[dt1, read] -> exp[dt2, write], array: exp[vec[dt1, n], read]): exp[vec[dt2, n], write] + +def natAsIndex(n: nat, e: exp[natType, read]): exp[idx[n], read] + +def padClamp(n: nat, l: nat, r: nat, dt: data, array: exp[n.dt, read]): exp[(l+n+r).dt, read] + +def padCst(n: nat, l: nat, r: nat, dt: data, padExp: exp[dt, read], array: exp[n.dt, read]): exp[(l+n+r).dt, read] + +def padEmpty(n: nat, r: nat, dt: data, array: exp[n.dt, write]): exp[(n+r).dt, write] + +def partition(n: nat, m: nat, lenF: nat2nat, dt: data, array: exp[n.dt, read]): exp[m..(i: nat |-> lenF(i).dt), read] + +// def printType{msg: String}(dt: data, a: access, input: exp[dt, a]): exp[dt, a] + +def reduceSeq{unroll: Boolean} + (n: nat, dt1: data, dt2: data, + f: exp[dt2, read] -> exp[dt1, read] -> exp[dt2, write], + init: exp[dt2, write], + array: exp[n.dt1, read]): exp[dt2, read] + +def reorder(n: nat, dt: data, a: access, idxF: nat2nat, idxFiv: nat2nat, input: exp[n.dt, a]): exp[n.dt, a] + +def rotateValues(n: nat, sz: nat, dt: data, + wrt: exp[dt, read] -> exp[dt, write], + input: exp[(n-1+sz).dt, read]): exp[n.sz.dt, read] + +def scanSeq(n: nat, dt1: data, dt2: data, + f: exp[dt1, read] -> exp[dt2, read] -> exp[dt2, write], + init: exp[dt2, write], array: exp[n.dt1, read]): exp[n.dt2, read] + +def scatter(n: nat, m: nat, dt: data, indices: exp[n.idx[m], read], input: exp[n.dt, write]): exp[m.dt, write] + +def slide(n: nat, sz: nat, sp: nat, dt: data, input: exp[(sp*n+sz).dt, read]): exp[(1+n).sz.dt, read] + +def snd(dt1: data, dt2: data, pair: exp[(dt1, dt2), read]): exp[dt2, read] + +def split(n: nat, m: nat, a: access, dt: data, array: exp[(m*n).dt, a]): exp[m.n.dt, a] + +def take(n: nat, m: nat, dt: data, array: exp[(n+m).dt, read]): exp[n.dt, read] + +def toMem(dt: data, input: exp[dt, write]): exp[dt, read] + +def transpose(n: nat, m: nat, dt: data, a: access, array: exp[n.m.dt, a]): exp[m.n.dt, a] + +def transposeDepArray(n: nat, m: nat, ft: nat2data, array: exp[n.m..ft, read]): exp[n..(i: nat |-> n.ft(i)), read] + +def unzip(n: nat, dt1: data, dt2: data, a: access, e: exp[n.(dt1, dt2), a]): exp[(n.dt1, n.dt2), a] + +def vectorFromScalar(n: nat, dt: data, arg: exp[dt, read]): exp[vec[dt, n], read] + +def zip(n: nat, dt1: data, dt2: data, a: access, e1: exp[n.dt1, a], e2: exp[n.dt2, a]): exp[n.(dt1, dt2), a] diff --git a/src/main/scala/shine/DPIA/primitives/imperative/AsScalarAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/AsScalarAcc.scala index 98f0960b6..7c064f934 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/AsScalarAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/AsScalarAcc.scala @@ -1,17 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class AsScalarAcc(n: Nat, - m: Nat, - dt: ScalarType, - array: Phrase[AccType] - )extends AccPrimitive { - array :: accT((m * n)`.`dt) - override val t: AccType = accT(n`.`vec(m, dt)) +final case class AsScalarAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive { + { + array :: accT(ArrayType(m * n, dt)) + } + override val t: AccType = accT(ArrayType(m, VectorType(n, dt))) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsScalarAcc = new AsScalarAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/AsVectorAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/AsVectorAcc.scala index a8307b313..d68e07768 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/AsVectorAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/AsVectorAcc.scala @@ -1,17 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class AsVectorAcc(n: Nat, - m: Nat, - dt: ScalarType, - array: Phrase[AccType] - ) extends AccPrimitive { - array :: accT(n`.`vec(m, dt)) - override val t: AccType = accT((n * m)`.`dt) +final case class AsVectorAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive { + { + array :: accT(ArrayType(m, VectorType(n, dt))) + } + override val t: AccType = accT(ArrayType(n * m, dt)) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsVectorAcc = new AsVectorAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/Assign.scala b/src/main/scala/shine/DPIA/primitives/imperative/Assign.scala index b2ee75675..0451dcffa 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/Assign.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/Assign.scala @@ -1,18 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.comPrimitive - -@comPrimitive -case class Assign(dt: DataType, - lhs: Phrase[AccType], - rhs: Phrase[ExpType] - ) extends CommandPrimitive { - lhs :: accT(dt) - rhs :: expT(dt, read) - - override def prettyPrint: String = s"(${PrettyPhrasePrinter(lhs)} := ${PrettyPhrasePrinter(rhs)})" +final case class Assign(val dt: DataType, val lhs: Phrase[AccType], val rhs: Phrase[ExpType]) extends CommandPrimitive { + { + lhs :: accT(dt) + rhs :: expT(dt, read) + } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Assign = new Assign(v.data(dt), VisitAndRebuild(lhs, v), VisitAndRebuild(rhs, v)) } - diff --git a/src/main/scala/shine/DPIA/primitives/imperative/Comment.scala b/src/main/scala/shine/DPIA/primitives/imperative/Comment.scala index 3cba9d039..c1416c259 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/Comment.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/Comment.scala @@ -1,9 +1,14 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - -import shine.DPIA.Phrases.CommandPrimitive -import shine.macros.Primitive.comPrimitive - -@comPrimitive -final case class Comment(comment : String) extends CommandPrimitive { - override def prettyPrint: String = s"\n//$comment\n" +import arithexpr.arithmetic._ +import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ +import shine.DPIA.Types._ +import shine.DPIA._ +final case class Comment(comment: String)() extends CommandPrimitive { + {} + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Comment = new Comment(comment)() } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/CycleAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/CycleAcc.scala index ab55902fa..06be5aec8 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/CycleAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/CycleAcc.scala @@ -1,17 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class CycleAcc(n: Nat, - m: Nat, - dt: DataType, - input: Phrase[AccType] - ) extends AccPrimitive { - input :: accT(m`.`dt) - override val t: AccType = accT(n`.`dt) +final case class CycleAcc(val n: Nat, val m: Nat, val dt: DataType, val input: Phrase[AccType]) extends AccPrimitive { + { + input :: accT(ArrayType(m, dt)) + } + override val t: AccType = accT(ArrayType(n, dt)) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): CycleAcc = new CycleAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(input, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/DMatchI.scala b/src/main/scala/shine/DPIA/primitives/imperative/DMatchI.scala index bf43207d2..374f16fae 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/DMatchI.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/DMatchI.scala @@ -3,13 +3,13 @@ package shine.DPIA.primitives.imperative import shine.DPIA.Phrases._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.comPrimitive -@comPrimitive final case class DMatchI(x: NatIdentifier, elemT: DataType, outT: DataType, f: Phrase[`(nat)->:`[ExpType ->: CommType]], input: Phrase[ExpType] ) extends CommandPrimitive { + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Phrase[CommType] = + DMatchI(v.nat(x), v.data(elemT), v.data(outT), VisitAndRebuild(f, v), VisitAndRebuild(input, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/DepIdxAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/DepIdxAcc.scala index c80913c2d..a887fef02 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/DepIdxAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/DepIdxAcc.scala @@ -1,17 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class DepIdxAcc(n: Nat, - ft:NatToData, - index: Nat, - array: Phrase[AccType] - ) extends AccPrimitive { - array :: accT(n`.d`ft) - override val t: AccType = accT(ft(index)) +final case class DepIdxAcc(val n: Nat, val ft: NatToData, val index: Nat, val array: Phrase[AccType]) extends AccPrimitive { + { + array :: accT(DepArrayType(n, ft)) + } + override val t: AccType = accT(NatToDataApply(ft, index)) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepIdxAcc = new DepIdxAcc(v.nat(n), v.natToData(ft), v.nat(index), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/DepJoinAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/DepJoinAcc.scala index 0491484d4..011ead569 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/DepJoinAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/DepJoinAcc.scala @@ -1,18 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - -import arithexpr.arithmetic.BigSum +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class DepJoinAcc(n: Nat, - lenF:NatToNat, - dt: DataType, - array: Phrase[AccType] - ) extends AccPrimitive { - array :: accT(BigSum(from=0, upTo = n-1, i => lenF(i))`.`dt) - override val t: AccType = accT(n`.d`{ i => ArrayType(lenF(i), dt) }) +final case class DepJoinAcc(val n: Nat, val lenF: NatToNat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive { + { + array :: accT(ArrayType(BigSum(from = 0, upTo = n - 1, (i: Nat) => lenF(i)), dt)) + } + override val t: AccType = accT(DepArrayType(n, n2dtFun { (i: NatIdentifier) => ArrayType(lenF(i), dt) })) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepJoinAcc = new DepJoinAcc(v.nat(n), v.natToNat(lenF), v.data(dt), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/DropAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/DropAcc.scala index c50e0555c..a09617ffc 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/DropAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/DropAcc.scala @@ -1,18 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ -import shine.DPIA.Types._ import shine.DPIA.Types.DataType._ +import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -// this drops n many elements from an array of m elements -@accPrimitive -final case class DropAcc(n: Nat, - m: Nat, - dt: DataType, - array: Phrase[AccType] - ) extends AccPrimitive { - array :: accT({n + m}`.`dt) - override val t: AccType = accT({m - n}`.`dt) +final case class DropAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive { + { + array :: accT(ArrayType(n + m, dt)) + } + override val t: AccType = accT(ArrayType(m - n, dt)) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): DropAcc = new DropAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/For.scala b/src/main/scala/shine/DPIA/primitives/imperative/For.scala index a1aeb6675..3d5d4ef1c 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/For.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/For.scala @@ -1,20 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.comPrimitive - -@comPrimitive -final case class For(unroll: Boolean) - (val n: Nat, - val loopBody: Phrase[ExpType ->: CommType] - ) extends CommandPrimitive { - loopBody :: expT(idx(n), read) ->: comm - - lazy val unwrapBody: (Identifier[ExpType], Phrase[CommType]) = loopBody match { - case Lambda(i, body) => (i, body) - case _ => throw new Exception("This should not happen") +final case class For(unroll: Boolean)(val n: Nat, val loopBody: Phrase[FunType[ExpType, CommType]]) extends CommandPrimitive { + { + loopBody :: FunType(expT(IndexType(n), read), comm) } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): For = new For(unroll)(v.nat(n), VisitAndRebuild(loopBody, v)) + def unwrap: (Nat, Phrase[FunType[ExpType, CommType]]) = (n, loopBody) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ForNat.scala b/src/main/scala/shine/DPIA/primitives/imperative/ForNat.scala index 5bace44b0..6bc08c265 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/ForNat.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/ForNat.scala @@ -1,19 +1,20 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.comPrimitive - -@comPrimitive -final case class ForNat(unroll: Boolean) - (val n: Nat, - val loopBody: Phrase[`(nat)->:`[CommType]] - ) extends CommandPrimitive { - loopBody :: loopBody.t.x ->: comm - - lazy val unwrapBody: (NatIdentifier, Phrase[CommType]) = loopBody match { - case DepLambda(i, body) => (i, body) - case _ => throw new Exception("This should not happen") +final case class ForNat(unroll: Boolean)(val n: Nat, val loopBody: Phrase[DepFunType[NatKind, CommType]]) extends CommandPrimitive { + { + loopBody :: ({ + val i = loopBody.t.x + DepFunType[NatKind, PhraseType](i, comm) + }) } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ForNat = new ForNat(unroll)(v.nat(n), VisitAndRebuild(loopBody, v)) + def unwrap: (Nat, Phrase[DepFunType[NatKind, CommType]]) = (n, loopBody) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ForVec.scala b/src/main/scala/shine/DPIA/primitives/imperative/ForVec.scala index a4e9d8cf7..2d4967f8a 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/ForVec.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/ForVec.scala @@ -1,17 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.comPrimitive - -@comPrimitive -final case class ForVec(n: Nat, - dt: ScalarType, - out: Phrase[AccType], - loopBody: Phrase[ExpType ->: AccType ->: CommType] - ) extends CommandPrimitive { - out :: accT(vec(n, dt)) - loopBody :: expT(idx(n), read) ->: accT(dt) ->: comm +final case class ForVec(val n: Nat, val dt: DataType, val out: Phrase[AccType], val loopBody: Phrase[FunType[ExpType, FunType[AccType, CommType]]]) extends CommandPrimitive { + { + out :: accT(VectorType(n, dt)) + loopBody :: FunType(expT(IndexType(n), read), FunType(accT(dt), comm)) + } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ForVec = new ForVec(v.nat(n), v.data(dt), VisitAndRebuild(out, v), VisitAndRebuild(loopBody, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/GenerateCont.scala b/src/main/scala/shine/DPIA/primitives/imperative/GenerateCont.scala index 6bc8952ac..029fc9810 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/GenerateCont.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/GenerateCont.scala @@ -1,17 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -// note: would not be necessary if generate was defined as indices + map -@expPrimitive -final case class GenerateCont(n: Nat, - dt: DataType, - f: Phrase[ExpType ->: ((ExpType ->: CommType) ->: CommType)] - ) extends ExpPrimitive { - f :: expT(idx(n), read) ->: (expT(dt, read) ->: comm) ->: comm - override val t: ExpType = expT(n`.`dt, read) +final case class GenerateCont(val n: Nat, val dt: DataType, val f: Phrase[FunType[ExpType, FunType[FunType[ExpType, CommType], CommType]]]) extends ExpPrimitive { + { + f :: FunType(expT(IndexType(n), read), FunType(FunType(expT(dt, read), comm), comm)) + } + override val t: ExpType = expT(ArrayType(n, dt), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): GenerateCont = new GenerateCont(v.nat(n), v.data(dt), VisitAndRebuild(f, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/IdxAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/IdxAcc.scala index 3cdfb04a0..09a1942a5 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/IdxAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/IdxAcc.scala @@ -1,18 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class IdxAcc(n: Nat, - dt: DataType, - index: Phrase[ExpType], - array: Phrase[AccType] - )extends AccPrimitive { - index :: expT(idx(n), read) - array :: accT(n`.`dt) +final case class IdxAcc(val n: Nat, val dt: DataType, val index: Phrase[ExpType], val array: Phrase[AccType]) extends AccPrimitive { + { + index :: expT(IndexType(n), read) + array :: accT(ArrayType(n, dt)) + } override val t: AccType = accT(dt) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): IdxAcc = new IdxAcc(v.nat(n), v.data(dt), VisitAndRebuild(index, v), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/IdxVecAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/IdxVecAcc.scala index eb051e3b1..0566caa11 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/IdxVecAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/IdxVecAcc.scala @@ -1,18 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class IdxVecAcc(n: Nat, - st: ScalarType, - index: Phrase[ExpType], - vector: Phrase[AccType] - ) extends AccPrimitive { - index :: expT(idx(n), read) - vector :: accT(vec(n, st)) - override val t: AccType = accT(st) +final case class IdxVecAcc(val n: Nat, val dt: DataType, val index: Phrase[ExpType], val vector: Phrase[AccType]) extends AccPrimitive { + { + index :: expT(IndexType(n), read) + vector :: accT(VectorType(n, dt)) + } + override val t: AccType = accT(dt) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): IdxVecAcc = new IdxVecAcc(v.nat(n), v.data(dt), VisitAndRebuild(index, v), VisitAndRebuild(vector, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/JoinAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/JoinAcc.scala index 3ba5223b0..fe0923988 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/JoinAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/JoinAcc.scala @@ -1,17 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class JoinAcc(n: Nat, - m: Nat, - dt: DataType, - array: Phrase[AccType] - ) extends AccPrimitive { - array :: accT({n * m}`.`dt) - override val t: AccType = accT(n`.`(m`.`dt)) +final case class JoinAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive { + { + array :: accT(ArrayType(n * m, dt)) + } + override val t: AccType = accT(ArrayType(n, ArrayType(m, dt))) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): JoinAcc = new JoinAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/MapAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/MapAcc.scala index 9ca3a5aaa..b4cd17cec 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/MapAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/MapAcc.scala @@ -1,19 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class MapAcc(n: Nat, - dt1: DataType, - dt2: DataType, - f: Phrase[AccType ->: AccType], - array: Phrase[AccType] - ) extends AccPrimitive { - f :: accT(dt1) ->: accT(dt2) - array :: accT(n`.`dt1) - override val t: AccType = accT(n`.`dt2) +final case class MapAcc(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[AccType, AccType]], val array: Phrase[AccType]) extends AccPrimitive { + { + f :: FunType(accT(dt1), accT(dt2)) + array :: accT(ArrayType(n, dt1)) + } + override val t: AccType = accT(ArrayType(n, dt2)) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapAcc = new MapAcc(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/MapFstAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/MapFstAcc.scala index b7ff02336..4e90c0d42 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/MapFstAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/MapFstAcc.scala @@ -1,17 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class MapFstAcc(dt1: DataType, - dt2: DataType, - dt3: DataType, - f: Phrase[AccType ->: AccType], - record: Phrase[AccType]) extends AccPrimitive { - f :: accT(dt3) ->: accT(dt1) - record :: accT(dt3 x dt2) - override val t: AccType = accT(dt1 x dt2) +final case class MapFstAcc(val dt1: DataType, val dt2: DataType, val dt3: DataType, val f: Phrase[FunType[AccType, AccType]], val record: Phrase[AccType]) extends AccPrimitive { + { + f :: FunType(accT(dt3), accT(dt1)) + record :: accT(PairType(dt3, dt2)) + } + override val t: AccType = accT(PairType(dt1, dt2)) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapFstAcc = new MapFstAcc(v.data(dt1), v.data(dt2), v.data(dt3), VisitAndRebuild(f, v), VisitAndRebuild(record, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/MapRead.scala b/src/main/scala/shine/DPIA/primitives/imperative/MapRead.scala index 5a829a2fd..67fe5ad67 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/MapRead.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/MapRead.scala @@ -1,19 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class MapRead(n: Nat, - dt1: DataType, - dt2: DataType, - f: Phrase[ExpType ->: (ExpType ->: CommType) ->: CommType], - input: Phrase[ExpType] - ) extends ExpPrimitive { - f :: expT(dt1, read) ->: (expT(dt2, read) ->: comm) ->: comm - input :: expT(n`.`dt1, read) - override val t: ExpType = expT(n`.`dt2, read) +final case class MapRead(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, FunType[FunType[ExpType, CommType], CommType]]], val input: Phrase[ExpType]) extends ExpPrimitive { + { + f :: FunType(expT(dt1, read), FunType(FunType(expT(dt2, read), comm), comm)) + input :: expT(ArrayType(n, dt1), read) + } + override val t: ExpType = expT(ArrayType(n, dt2), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapRead = new MapRead(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(input, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/MapSndAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/MapSndAcc.scala index 3d19844ce..3c163d1cf 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/MapSndAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/MapSndAcc.scala @@ -1,17 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class MapSndAcc(dt1: DataType, - dt2: DataType, - dt3: DataType, - f: Phrase[AccType ->: AccType], - record: Phrase[AccType]) extends AccPrimitive { - f :: accT(dt3) ->: accT(dt2) - record :: accT(dt1 x dt3) - override val t: AccType = accT(dt1 x dt2) +final case class MapSndAcc(val dt1: DataType, val dt2: DataType, val dt3: DataType, val f: Phrase[FunType[AccType, AccType]], val record: Phrase[AccType]) extends AccPrimitive { + { + f :: FunType(accT(dt3), accT(dt2)) + record :: accT(PairType(dt1, dt3)) + } + override val t: AccType = accT(PairType(dt1, dt2)) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapSndAcc = new MapSndAcc(v.data(dt1), v.data(dt2), v.data(dt3), VisitAndRebuild(f, v), VisitAndRebuild(record, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/MkDPairFstI.scala b/src/main/scala/shine/DPIA/primitives/imperative/MkDPairFstI.scala index af1f88de1..c85cbce16 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/MkDPairFstI.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/MkDPairFstI.scala @@ -3,9 +3,9 @@ package shine.DPIA.primitives.imperative import shine.DPIA.Phrases._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.comPrimitive -@comPrimitive final case class MkDPairFstI(fst: Nat, A: Phrase[AccType]) extends CommandPrimitive { + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Phrase[CommType] = + MkDPairFstI(v.nat(fst), VisitAndRebuild(A, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/MkDPairSndAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/MkDPairSndAcc.scala index b6c4c865f..8df8914c7 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/MkDPairSndAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/MkDPairSndAcc.scala @@ -3,11 +3,12 @@ package shine.DPIA.primitives.imperative import shine.DPIA.NatIdentifier import shine.DPIA.Phrases._ import shine.DPIA.Types.{AccType, DataType} -import shine.macros.Primitive.accPrimitive -@accPrimitive final case class MkDPairSndAcc(fst: NatIdentifier, sndT: DataType, A: Phrase[AccType]) extends AccPrimitive { override val t = AccType(sndT) + + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Phrase[AccType] = + MkDPairSndAcc(v.nat(fst), v.data(sndT), VisitAndRebuild(A, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/New.scala b/src/main/scala/shine/DPIA/primitives/imperative/New.scala index 265549e24..b376a80a9 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/New.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/New.scala @@ -1,13 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.comPrimitive - -@comPrimitive -final case class New(dt: DataType, - f: Phrase[VarType ->: CommType] - ) extends CommandPrimitive { - f :: varT(dt) ->: comm +final case class New(val dt: DataType, val f: Phrase[FunType[PhrasePairType[ExpType, AccType], CommType]]) extends CommandPrimitive { + { + f :: FunType(PhrasePairType(expT(dt, read), accT(dt)), comm) + } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): New = new New(v.data(dt), VisitAndRebuild(f, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/NewDoubleBuffer.scala b/src/main/scala/shine/DPIA/primitives/imperative/NewDoubleBuffer.scala index 6c4ceec3b..7001ed75c 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/NewDoubleBuffer.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/NewDoubleBuffer.scala @@ -1,21 +1,18 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.comPrimitive - -@comPrimitive -final case class NewDoubleBuffer(dt1: DataType, - dt2: DataType, - dt3: DataType, - n: Nat, - in: Phrase[ExpType], - out: Phrase[AccType], - f: Phrase[(ExpType x AccType x CommType x CommType) ->: CommType] - )extends CommandPrimitive { - in :: expT(dt1, read) - out :: accT(dt2) - f :: (((varT(n`.`dt3) x comm) x comm) ->: comm) +final case class NewDoubleBuffer(val dt1: DataType, val dt2: DataType, val dt3: DataType, val n: Nat, val in: Phrase[ExpType], val out: Phrase[AccType], val f: Phrase[FunType[PhrasePairType[PhrasePairType[PhrasePairType[ExpType, AccType], CommType], CommType], CommType]]) extends CommandPrimitive { + { + in :: expT(dt1, read) + out :: accT(dt2) + f :: FunType(PhrasePairType(PhrasePairType(PhrasePairType(expT(ArrayType(n, dt3), read), accT(ArrayType(n, dt3))), comm), comm), comm) + } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): NewDoubleBuffer = new NewDoubleBuffer(v.data(dt1), v.data(dt2), v.data(dt3), v.nat(n), VisitAndRebuild(in, v), VisitAndRebuild(out, v), VisitAndRebuild(f, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/PairAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/PairAcc.scala index 09d9b11fc..b6f03c5c7 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/PairAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/PairAcc.scala @@ -1,17 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class PairAcc(dt1: DataType, - dt2: DataType, - fst: Phrase[AccType], - snd: Phrase[AccType] - ) extends AccPrimitive { - fst :: accT(dt1) - snd :: accT(dt2) - override val t: AccType = accT(dt1 x dt2) +final case class PairAcc(val dt1: DataType, val dt2: DataType, val fst: Phrase[AccType], val snd: Phrase[AccType]) extends AccPrimitive { + { + fst :: accT(dt1) + snd :: accT(dt2) + } + override val t: AccType = accT(PairType(dt1, dt2)) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): PairAcc = new PairAcc(v.data(dt1), v.data(dt2), VisitAndRebuild(fst, v), VisitAndRebuild(snd, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/PairAcc1.scala b/src/main/scala/shine/DPIA/primitives/imperative/PairAcc1.scala index af0e85003..924410406 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/PairAcc1.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/PairAcc1.scala @@ -1,15 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class PairAcc1(dt1: DataType, - dt2: DataType, - pair: Phrase[AccType] - ) extends AccPrimitive { - pair :: accT(dt1 x dt2) +final case class PairAcc1(val dt1: DataType, val dt2: DataType, val pair: Phrase[AccType]) extends AccPrimitive { + { + pair :: accT(PairType(dt1, dt2)) + } override val t: AccType = accT(dt1) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): PairAcc1 = new PairAcc1(v.data(dt1), v.data(dt2), VisitAndRebuild(pair, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/PairAcc2.scala b/src/main/scala/shine/DPIA/primitives/imperative/PairAcc2.scala index dd8370269..f49ccd845 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/PairAcc2.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/PairAcc2.scala @@ -1,15 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class PairAcc2(dt1: DataType, - dt2: DataType, - pair: Phrase[AccType] - ) extends AccPrimitive { - pair :: accT(dt1 x dt2) +final case class PairAcc2(val dt1: DataType, val dt2: DataType, val pair: Phrase[AccType]) extends AccPrimitive { + { + pair :: accT(PairType(dt1, dt2)) + } override val t: AccType = accT(dt2) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): PairAcc2 = new PairAcc2(v.data(dt1), v.data(dt2), VisitAndRebuild(pair, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ReorderAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/ReorderAcc.scala index a30bd8efc..7820bfcf0 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/ReorderAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/ReorderAcc.scala @@ -1,17 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class ReorderAcc(n: Nat, - dt: DataType, - idxF: NatToNat, - array: Phrase[AccType] - ) extends AccPrimitive { - array :: accT(n`.`dt) - override val t: AccType = accT(n`.`dt) +final case class ReorderAcc(val n: Nat, val dt: DataType, val idxF: NatToNat, val array: Phrase[AccType]) extends AccPrimitive { + { + array :: accT(ArrayType(n, dt)) + } + override val t: AccType = accT(ArrayType(n, dt)) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ReorderAcc = new ReorderAcc(v.nat(n), v.data(dt), v.natToNat(idxF), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ScatterAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/ScatterAcc.scala index c86146cc3..483b94e19 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/ScatterAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/ScatterAcc.scala @@ -1,16 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class ScatterAcc(n: Nat, m: Nat, dt: DataType, - indices: Phrase[ExpType], - array: Phrase[AccType]) extends AccPrimitive { - indices :: expT(n`.`idx(m), read) - array :: accT(n`.`dt) - override val t: AccType = accT(m`.`dt) +final case class ScatterAcc(val n: Nat, val m: Nat, val dt: DataType, val indices: Phrase[ExpType], val array: Phrase[AccType]) extends AccPrimitive { + { + indices :: expT(ArrayType(n, IndexType(m)), read) + array :: accT(ArrayType(m, dt)) + } + override val t: AccType = accT(ArrayType(n, dt)) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ScatterAcc = new ScatterAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(indices, v), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/Seq.scala b/src/main/scala/shine/DPIA/primitives/imperative/Seq.scala index 579940f89..15500208b 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/Seq.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/Seq.scala @@ -1,13 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ -import shine.macros.Primitive.comPrimitive - -@comPrimitive -final case class Seq(c1: Phrase[CommType], - c2: Phrase[CommType]) - extends CommandPrimitive { - c1 :: comm - c2 :: comm +import shine.DPIA._ +final case class Seq(val c1: Phrase[CommType], val c2: Phrase[CommType]) extends CommandPrimitive { + { + c1 :: comm + c2 :: comm + } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Seq = new Seq(VisitAndRebuild(c1, v), VisitAndRebuild(c2, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/Skip.scala b/src/main/scala/shine/DPIA/primitives/imperative/Skip.scala index 569d140b4..1d8ac8dec 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/Skip.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/Skip.scala @@ -1,14 +1,14 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ - -// not final (and not using the macro) because of DSL.typed.skip -case class Skip() extends CommandPrimitive { - +import shine.DPIA._ +final case class Skip() extends CommandPrimitive { + {} override val t: CommType = comm - - override def visitAndRebuild(f: VisitAndRebuild.Visitor): Phrase[CommType] = this - - override def prettyPrint: String = "skip" + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Skip = new Skip() } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/SplitAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/SplitAcc.scala index 07b2b0b01..4ad7b97c6 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/SplitAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/SplitAcc.scala @@ -1,17 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class SplitAcc(n: Nat, - m: Nat, - dt: DataType, - array: Phrase[AccType] - ) extends AccPrimitive { - array :: accT(m`.`(n`.`dt)) - override val t: AccType = accT({n * m}`.`dt) +final case class SplitAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive { + { + array :: accT(ArrayType(m, ArrayType(n, dt))) + } + override val t: AccType = accT(ArrayType(n * m, dt)) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): SplitAcc = new SplitAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/TakeAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/TakeAcc.scala index ccea96f58..923cb6ded 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/TakeAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/TakeAcc.scala @@ -1,17 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class TakeAcc(n: Nat, - m: Nat, - dt: DataType, - array: Phrase[AccType] - ) extends AccPrimitive { - array :: accT({n + m}`.`dt) - override val t: AccType = accT(n`.`dt) +final case class TakeAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive { + { + array :: accT(ArrayType(n + m, dt)) + } + override val t: AccType = accT(ArrayType(n, dt)) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): TakeAcc = new TakeAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/TransposeAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/TransposeAcc.scala index 31430d627..3944e3d8a 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/TransposeAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/TransposeAcc.scala @@ -1,17 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class TransposeAcc(n: Nat, - m: Nat, - dt: DataType, - array: Phrase[AccType] - ) extends AccPrimitive { - array :: accT(m`.`(n`.`dt)) - override val t: AccType = accT(n`.`(m`.`dt)) +final case class TransposeAcc(val n: Nat, val m: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive { + { + array :: accT(ArrayType(m, ArrayType(n, dt))) + } + override val t: AccType = accT(ArrayType(n, ArrayType(m, dt))) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): TransposeAcc = new TransposeAcc(v.nat(n), v.nat(m), v.data(dt), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/UnzipAcc.scala b/src/main/scala/shine/DPIA/primitives/imperative/UnzipAcc.scala index 8e5dc8fb2..9e0069f74 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/UnzipAcc.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/UnzipAcc.scala @@ -1,17 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class UnzipAcc(n: Nat, - dt1: DataType, - dt2: DataType, - a: Phrase[AccType] - ) extends AccPrimitive { - a :: accT((n`.`dt1) x (n`.`dt2)) - override val t: AccType = accT(n`.`(dt1 x dt2)) +final case class UnzipAcc(val n: Nat, val dt1: DataType, val dt2: DataType, val a: Phrase[AccType]) extends AccPrimitive { + { + a :: accT(PairType(ArrayType(n, dt1), ArrayType(n, dt2))) + } + override val t: AccType = accT(ArrayType(n, PairType(dt1, dt2))) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): UnzipAcc = new UnzipAcc(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(a, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc1.scala b/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc1.scala index 3e4821950..28343629d 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc1.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc1.scala @@ -1,17 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class ZipAcc1(n: Nat, - dt1: DataType, - dt2: DataType, - array: Phrase[AccType] - ) extends AccPrimitive { - array :: accT(n`.`(dt1 x dt2)) - override val t: AccType = accT(n`.`dt1) +final case class ZipAcc1(val n: Nat, val dt1: DataType, val dt2: DataType, val array: Phrase[AccType]) extends AccPrimitive { + { + array :: accT(ArrayType(n, PairType(dt1, dt2))) + } + override val t: AccType = accT(ArrayType(n, dt1)) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ZipAcc1 = new ZipAcc1(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc2.scala b/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc2.scala index b18d7e34e..f84abc66e 100644 --- a/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc2.scala +++ b/src/main/scala/shine/DPIA/primitives/imperative/ZipAcc2.scala @@ -1,17 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.DPIA.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class ZipAcc2(n: Nat, - dt1: DataType, - dt2: DataType, - array: Phrase[AccType] - ) extends AccPrimitive { - array :: accT(n`.`(dt1 x dt2)) - override val t: AccType = accT(n`.`dt2) +final case class ZipAcc2(val n: Nat, val dt1: DataType, val dt2: DataType, val array: Phrase[AccType]) extends AccPrimitive { + { + array :: accT(ArrayType(n, PairType(dt1, dt2))) + } + override val t: AccType = accT(ArrayType(n, dt2)) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ZipAcc2 = new ZipAcc2(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/DPIA/primitives/imperative/primitives.dpia b/src/main/scala/shine/DPIA/primitives/imperative/primitives.dpia new file mode 100644 index 000000000..9069428fb --- /dev/null +++ b/src/main/scala/shine/DPIA/primitives/imperative/primitives.dpia @@ -0,0 +1,69 @@ +def asScalarAcc(n: nat, m: nat, dt: data, array: acc[(m*n).dt]): acc[m.vec[dt, n]] + +def assign(dt: data, lhs: acc[dt], rhs: exp[dt, read]): comm + +def asVectorAcc(n: nat, m: nat, dt: data, array: acc[m.vec[dt, n]]): acc[(n*m).dt] + +def comment{comment: String}(): comm + +def cycleAcc(n: nat, m: nat, dt: data, input: acc[m.dt]): acc[n.dt] + +def depIdxAcc(n: nat, ft: nat2data, index: nat, array: acc[n..ft]): acc[ft(index)] + +def depJoinAcc(n: nat, lenF: nat2nat, dt: data, array: acc[(sum_(i=0)^(n-1) lenF(i)).dt]): acc[n..(i: nat |-> lenF(i).dt)] + +// def dMatchI(...) + +// this drops n many elements from an array of m elements +def dropAcc(n: nat, m: nat, dt: data, array: acc[(n+m).dt]): acc[(m-n).dt] + +def for{unroll: Boolean}(n: nat, loopBody: exp[idx[n], read] -> comm): comm +def forNat{unroll: Boolean}(n: nat, loopBody: (i: nat) -> comm): comm +def forVec(n: nat, dt: data, out: acc[vec[dt, n]], loopBody: exp[idx[n], read] -> acc[dt] -> comm): comm + +// note: would not be necessary if generate was defined as indices + map +def generateCont(n: nat, dt: data, f: exp[idx[n], read] -> (exp[dt, read] -> comm) -> comm): exp[n.dt, read] + +def idxAcc(n: nat, dt: data, index: exp[idx[n], read], array: acc[n.dt]): acc[dt] +def idxVecAcc(n: nat, dt: data, index: exp[idx[n], read], vector: acc[vec[dt, n]]): acc[dt] + +def joinAcc(n: nat, m: nat, dt: data, array: acc[(n*m).dt]): acc[n.m.dt] + +def mapAcc(n: nat, dt1: data, dt2: data, f: acc[dt1] -> acc[dt2], array: acc[n.dt1]): acc[n.dt2] + +def mapFstAcc(dt1: data, dt2: data, dt3: data, f: acc[dt3] -> acc[dt1], record: acc[(dt3, dt2)]): acc[(dt1, dt2)] +def mapRead(n: nat, dt1: data, dt2: data, + f: exp[dt1, read] -> (exp[dt2, read] -> comm) -> comm, input: exp[n.dt1, read]): exp[n.dt2, read] +def mapSndAcc(dt1: data, dt2: data, dt3: data, f: acc[dt3] -> acc[dt2], record: acc[(dt1, dt3)]): acc[(dt1, dt2)] + +// def mkDPairFstI(...) +// def mkDPairSndAcc(...) + +def new(dt: data, f: var[dt] -> comm): comm + +def newDoubleBuffer(dt1: data, dt2: data, dt3: data, n: nat, + in: exp[dt1, read], out: acc[dt2], + f: ((var[n.dt3], comm), comm) -> comm): comm + +def pairAcc(dt1: data, dt2: data, fst: acc[dt1], snd: acc[dt2]): acc[(dt1, dt2)] +def pairAcc1(dt1: data, dt2: data, pair: acc[(dt1, dt2)]): acc[dt1] +def pairAcc2(dt1: data, dt2: data, pair: acc[(dt1, dt2)]): acc[dt2] + +def reorderAcc(n: nat, dt: data, idxF: nat2nat, array: acc[n.dt]): acc[n.dt] + +def scatterAcc(n: nat, m: nat, dt: data, indices: exp[n.idx[m], read], array: acc[m.dt]): acc[n.dt] + +def seq(c1: comm, c2: comm): comm + +def skip(): comm + +def splitAcc(n: nat, m: nat, dt: data, array: acc[m.n.dt]): acc[(n*m).dt] + +def takeAcc(n: nat, m: nat, dt: data, array: acc[(n+m).dt]): acc[n.dt] + +def transposeAcc(n: nat, m: nat, dt: data, array: acc[m.n.dt]): acc[n.m.dt] + +def unzipAcc(n: nat, dt1: data, dt2: data, a: acc[(n.dt1, n.dt2)]): acc[n.(dt1, dt2)] + +def zipAcc1(n: nat, dt1: data, dt2: data, array: acc[n.(dt1, dt2)]): acc[n.dt1] +def zipAcc2(n: nat, dt1: data, dt2: data, array: acc[n.(dt1, dt2)]): acc[n.dt2] diff --git a/src/main/scala/shine/DPIA/primitives/intermediate/MapVecI.scala b/src/main/scala/shine/DPIA/primitives/intermediate/MapVecI.scala index 323afad08..25db6c81d 100644 --- a/src/main/scala/shine/DPIA/primitives/intermediate/MapVecI.scala +++ b/src/main/scala/shine/DPIA/primitives/intermediate/MapVecI.scala @@ -7,7 +7,7 @@ import shine.DPIA._ import shine.OpenMP.DSL.parForVec object MapVecI { - def apply(n: Nat, st1: ScalarType, st2: ScalarType, + def apply(n: Nat, st1: DataType, st2: DataType, f: Phrase[ExpType ->: AccType ->: CommType], in: Phrase[ExpType], out: Phrase[AccType]): Phrase[CommType] = diff --git a/src/main/scala/shine/OpenCL/AdjustArraySizesForAllocations.scala b/src/main/scala/shine/OpenCL/AdjustArraySizesForAllocations.scala index 55efcd53d..423d1b0f2 100644 --- a/src/main/scala/shine/OpenCL/AdjustArraySizesForAllocations.scala +++ b/src/main/scala/shine/OpenCL/AdjustArraySizesForAllocations.scala @@ -8,7 +8,7 @@ import shine.DPIA._ import shine.DPIA.primitives.functional._ import shine.OpenCL.{primitives => ocl} import shine.cuda.{primitives => cuda} -import shine.cuda.primitives.functional.{AsMatrix, GenerateFragment, MapFragmentElements, TensorMatMultAdd, AsFragment} +import shine.cuda.primitives.functional.{AsMatrix, GenerateFragment, MapFragment, TensorMatMultAdd, AsFragment} import shine.cuda.warpDim object AdjustArraySizesForAllocations { @@ -85,7 +85,7 @@ object AdjustArraySizesForAllocations { case _: Identifier[_] | _: Literal | _: Natural | _: VectorFromScalar | _: Cast | _: ForeignFunctionCall | _: BinOp | _: UnaryOp | _: GenerateFragment | - _: AsMatrix | _: AsFragment | _: MapFragmentElements | + _: AsMatrix | _: AsFragment | _: MapFragment | _: TensorMatMultAdd => parallInfo //TODO visit value first? @@ -113,7 +113,7 @@ object AdjustArraySizesForAllocations { } val stride = determineStride(parallLevel, dim, addrSpace) - val outerDimension = ocl.imperative.IdxDistributeAcc(adjSize, oldSize, stride, parallLevel, adjElemT, A) + val outerDimension = ocl.imperative.IdxDistributeAcc(parallLevel)(adjSize, oldSize, stride, adjElemT, A) val arr = identifier(freshName("x"), accT(adjElemT)) val mapFunBody = adjustedAcceptor(parallInfo.tail, adjElemT, oldElemT, addrSpace)(arr) @@ -149,7 +149,7 @@ object AdjustArraySizesForAllocations { } val stride = determineStride(parallLevel, dim, addrSpace) - val outerDimension = ocl.imperative.IdxDistribute(adjSize, oldSize, stride, parallLevel, adjElemT, E) + val outerDimension = ocl.imperative.IdxDistribute(parallLevel)(adjSize, oldSize, stride, adjElemT, E) val arr = identifier(freshName("arr"), expT(adjElemT, read)) val mapFunBody = adjustedExpr(parallInfo.tail, adjElemT, oldElemT, addrSpace)(arr) diff --git a/src/main/scala/shine/OpenCL/Compilation/HostCodeGenerator.scala b/src/main/scala/shine/OpenCL/Compilation/HostCodeGenerator.scala index 23f5d5767..a2bb26544 100644 --- a/src/main/scala/shine/OpenCL/Compilation/HostCodeGenerator.scala +++ b/src/main/scala/shine/OpenCL/Compilation/HostCodeGenerator.scala @@ -25,12 +25,13 @@ case class HostCodeGenerator(override val decls: C.Compilation.CodeGenerator.Dec override def name: String = "OpenCL Host" override def cmd(env: Environment): Phrase[CommType] => Stmt = { - case KernelCallCmd(name, LocalSize(ls), GlobalSize(gs), output, args) => - kernelCallCmd(name, ls, gs, output, args, env) - case NewManagedBuffer(dt, access, Lambda(v, p)) => + case k@KernelCallCmd(name, LocalSize(ls), GlobalSize(gs), args) => + kernelCallCmd(name, ls, gs, k.output, args, env) + case n@NewManagedBuffer(access) => + val (dt, Lambda(v, p)) = n.unwrap newManagedBuffer(dt, access, v, p, env) - case HostExecution(params, body) => - hostExecution(params, body, env) + case h@HostExecution(params) => + hostExecution(params, h.body, env) case phrase => phrase |> super.cmd(env) } @@ -202,7 +203,7 @@ case class HostCodeGenerator(override val decls: C.Compilation.CodeGenerator.Dec C.AST.BinaryExpr(C.AST.ArithmeticExpr(a.size), BinaryOperator.*, bufferSize(a.elemType)) case a: DepArrayType => ??? // TODO case _: DepPairType | _: NatToDataApply | _: DataTypeIdentifier | _: OpaqueType | - _: shine.DPIA.Types.FragmentType | _: pipeline.type => + _: shine.DPIA.Types.FragmentType => throw new Exception(s"did not expect ${dt}") } diff --git a/src/main/scala/shine/OpenCL/Compilation/HostManagedBuffers.scala b/src/main/scala/shine/OpenCL/Compilation/HostManagedBuffers.scala index c9da200e2..51a2f8e8f 100644 --- a/src/main/scala/shine/OpenCL/Compilation/HostManagedBuffers.scala +++ b/src/main/scala/shine/OpenCL/Compilation/HostManagedBuffers.scala @@ -72,7 +72,7 @@ object HostManagedBuffers { i -> access }).toMap env.foreach { case (i, a) => recordManagedAccess(managed, i, a) } - (HostExecution(env, p2), ReadsAndWrites.empty) + (HostExecution(env)(p2), ReadsAndWrites.empty) } } @@ -97,10 +97,10 @@ object HostManagedBuffers { collectWrites(lhs, metadata.host_writes) collectReads(rhs, allocs, metadata.host_reads) Stop(p) - case ocl.KernelCallCmd(_, _, _, out, in) => + case k@ocl.KernelCallCmd(_, _, _, in) => in.foreach(collectReads(_, allocs, metadata.device_reads)) - collectWrites(out, metadata.device_writes) - ((out, DEVICE_WRITE) +: in.map(_ -> DEVICE_READ)).foreach { + collectWrites(k.output, metadata.device_writes) + ((k.output, DEVICE_WRITE) +: in.map(_ -> DEVICE_READ)).foreach { case (i: Identifier[_], a) => recordManagedAccess(managed, i, a) case (Proj1(i: Identifier[_]), a) => recordManagedAccess(managed, i, a) case (Proj2(i: Identifier[_]), a) => recordManagedAccess(managed, i, a) @@ -152,11 +152,14 @@ object HostManagedBuffers { case dpia.New(dt, Lambda(x, body)) if managed.contains(x) => val access = managed(x)._1 val x2 = managed(x)._2.asInstanceOf[Identifier[VarType]] - Continue(ocl.NewManagedBuffer(dt, access, Lambda(x2, body)), this) + Continue(ocl.NewManagedBuffer(access)(dt, Lambda(x2, body)), this) case _: dpia.New | _: Lambda[_, _] | _: dpia.Seq | _: Proj2[_, _] | _: Proj1[_, _] | Natural(_) => Continue(p, this) - case _: ocl.KernelCallCmd => Continue(p, this) + case k@ocl.KernelCallCmd(name, ls, gs, args) => + val newOutput = VisitAndRebuild(k.output, this) + Stop(ocl.KernelCallCmd(name, ls, gs, args.map(VisitAndRebuild(_, this)))( + newOutput.t.dataType, newOutput)) case _: HostExecution => Stop(p) case unexpected => throw new Exception(s"did not expect $unexpected") } @@ -180,7 +183,7 @@ object HostManagedBuffers { case JoinAcc(_, _, _, a) => collectWrites(a, writes) case SplitAcc(_, _, _, a) => collectWrites(a, writes) case AsScalarAcc(_, _, _, a) => collectWrites(a, writes) - case ocl.IdxDistributeAcc(_, _, _, _, _, a) => collectWrites(a, writes) + case idx:ocl.IdxDistributeAcc => collectWrites(idx.array, writes) case PairAcc1(_, _, a) => collectWrites(a, writes) case PairAcc2(_, _, a) => collectWrites(a, writes) case TakeAcc(_, _, _, a) => collectWrites(a, writes) @@ -212,7 +215,7 @@ object HostManagedBuffers { collectReads(e1, allocs, reads); collectReads(e2, allocs, reads) case Slide(_, _, _, _, e) => collectReads(e, allocs, reads) case Map(_, _, _, _, _, e) => collectReads(e, allocs, reads) - case ocl.IdxDistribute(_, _, _, _, _, e) => collectReads(e, allocs, reads) + case idx: ocl.IdxDistribute => collectReads(idx.array, allocs, reads) case dpia.MapRead(_, _, _, _, e) => collectReads(e, allocs, reads) case dpia.GenerateCont(_, _, _) => giveUp() case AsScalar(_, _, _, _, e) => collectReads(e, allocs, reads) @@ -226,12 +229,12 @@ object HostManagedBuffers { case Split(_, _, _, _, e) => collectReads(e, allocs, reads) case Zip(_, _, _, _, e1, e2) => collectReads(e1, allocs, reads); collectReads(e2, allocs, reads) - case Pad(_, _, _, _, e1, e2) => + case PadCst(_, _, _, _, e1, e2) => collectReads(e1, allocs, reads); collectReads(e2, allocs, reads) case PadClamp(_, _, _, _, e) => collectReads(e, allocs, reads) case Cast(_, _, e) => collectReads(e, allocs, reads) - case ForeignFunctionCall(_, _, _, es) => + case ForeignFunctionCall(_, _, es) => es.foreach { collectReads(_, allocs, reads) } @@ -242,7 +245,7 @@ object HostManagedBuffers { case MakePair(_, _, _, e1, e2) => collectReads(e1, allocs, reads); collectReads(e2, allocs, reads) case Reorder(_, _, _, _, _, e) => collectReads(e, allocs, reads) - case MakeArray(_, es) => + case MakeArray(es) => es.foreach { collectReads(_, allocs, reads) } diff --git a/src/main/scala/shine/OpenCL/Compilation/KernelCodeGenerator.scala b/src/main/scala/shine/OpenCL/Compilation/KernelCodeGenerator.scala index c2dffa049..af5ede7dd 100644 --- a/src/main/scala/shine/OpenCL/Compilation/KernelCodeGenerator.scala +++ b/src/main/scala/shine/OpenCL/Compilation/KernelCodeGenerator.scala @@ -38,12 +38,18 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, override def cmd(env: Environment): Phrase[CommType] => Stmt = { case f: ocl.ParFor => - val (i, o, p) = f.unwrapBody - OpenCLCodeGen.codeGenOpenCLParFor(f, f.n, f.dt, f.out, i, o, p, env) + f.body match { + case Lambda(i, Lambda(o, p)) => + OpenCLCodeGen.codeGenOpenCLParFor(f, f.n, f.dt, f.out, i, o, p, env) + case _ => throw new Exception("This should not happen") + } case f: ocl.ParForNat => - val (i, o, p) = f.unwrapBody - OpenCLCodeGen.codeGenOpenCLParForNat(f, f.n, f.out, i, o, p, env) + f.body match { + case DepLambda(i: NatIdentifier, Lambda(o, p)) => + OpenCLCodeGen.codeGenOpenCLParForNat(f, f.n, f.out, i, o, p, env) + case _ => throw new Exception("This should not happen") + } case phrase@Assign(dt, a, e) => dt match { case VectorType(_, _) => @@ -84,20 +90,20 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, override def acc(env: Environment, path: Path, cont: Expr => Stmt): Phrase[AccType] => Stmt = { - case AsVectorAcc(n, _, _, a) => path match { - case (i: CIntExpr) :: ps => a |> acc(env, CIntExpr(i / n) :: ps, cont) + case AsVectorAcc(_, m, _, a) => path match { + case (i: CIntExpr) :: ps => a |> acc(env, CIntExpr(i / m) :: ps, cont) case _ => error(s"Expected path to be not empty") } - case AsScalarAcc(_, m, dt, a) => path match { + case AsScalarAcc(n, _, dt, a) => path match { case (i: CIntExpr) :: (j: CIntExpr) :: ps => - a |> acc(env, CIntExpr((i * m) + j) :: ps, cont) + a |> acc(env, CIntExpr((i * n) + j) :: ps, cont) case (i: CIntExpr) :: Nil => // TODO: check alignment and use pointer with correct address space - a |> acc(env, CIntExpr(i * m) :: Nil, array => { + a |> acc(env, CIntExpr(i * n) :: Nil, array => { // the continuation has to add the value ... val ptr = C.AST.UnaryExpr(C.AST.UnaryOperator.&, array) - cont(C.AST.FunCall(C.AST.DeclRef(s"vstore$m"), + cont(C.AST.FunCall(C.AST.DeclRef(s"vstore$n"), immutable.Seq(C.AST.Literal("0"), ptr))) }) case _ => error(s"Expected path to be not empty") @@ -105,10 +111,10 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, case IdxVecAcc(_, _, i, a) => CCodeGen.codeGenIdxAcc(i, a, env, path, cont) - case ocl.IdxDistributeAcc(_, _, stride, _, _, a) => path match { + case idx: ocl.IdxDistributeAcc => path match { // TODO: ensure that i % stride == init ? case (i: CIntExpr) :: ps => - a |> acc(env, CIntExpr(i / stride) :: ps, cont) + idx.array |> acc(env, CIntExpr(i / idx.stride) :: ps, cont) case _ => error(s"Expected a C-Integer-Expression on the path.") } @@ -161,9 +167,9 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, }) case _ => error(s"unexpected $path") } - case AsScalar(_, m, _, _, e) => path match { + case AsScalar(n, _, _, _, e) => path match { case (i: CIntExpr) :: ps => - e |> exp(env, CIntExpr(i / m) :: CIntExpr(i % m) :: ps, cont) + e |> exp(env, CIntExpr(i / n) :: CIntExpr(i % n) :: ps, cont) case _ => error(s"Expected path to be not empty") } // TODO: this has to be refactored @@ -185,13 +191,13 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, case IdxVec(_, _, i, e) => CCodeGen.codeGenIdx(i, e, env, path, cont) - case OpenCLFunctionCall(name, _, _, args) => + case OpenCLFunctionCall(name, _, args) => CCodeGen.codeGenForeignCall(name, args, env, Nil, cont) - case ocl.IdxDistribute(_, _, stride, _, _, e) => path match { + case idx: ocl.IdxDistribute => path match { // TODO: ensure that i % stride == init ? case (i: CIntExpr) :: ps - => e |> exp(env, CIntExpr(i / stride) :: ps, cont) + => idx.array |> exp(env, CIntExpr(i / idx.stride) :: ps, cont) case _ => error(s"Expected a C-Integer-Expression on the path.") } @@ -307,7 +313,7 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, p: Phrase[CommType], env: Environment): Stmt = { assert(!f.unroll) - val cI = C.AST.DeclRef(f.name) + val cI = C.AST.DeclRef(freshName(f.prefix)) val range = RangeAdd(f.init, n, f.step) val updatedGen = updatedRanges(cI.name, range) @@ -363,7 +369,7 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, p: Phrase[CommType], env: Environment): Stmt = { assert(!f.unroll) - val cI = C.AST.DeclRef(f.name) + val cI = C.AST.DeclRef(freshName(f.prefix)) val range = RangeAdd(f.init, n, f.step) val updatedGen = updatedRanges(cI.name, range) @@ -436,7 +442,7 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, AddressSpace.Private, Some(expr))) } - def codeGenVectorLiteral(n: Int, dt: ScalarType, + def codeGenVectorLiteral(n: Int, dt: DataType, f: Int => Phrase[ExpType], env: Environment, cont: Expr => Stmt, diff --git a/src/main/scala/shine/OpenCL/Compilation/Passes/FlagPrivateArrayLoops.scala b/src/main/scala/shine/OpenCL/Compilation/Passes/FlagPrivateArrayLoops.scala index 808339789..ca31433ac 100644 --- a/src/main/scala/shine/OpenCL/Compilation/Passes/FlagPrivateArrayLoops.scala +++ b/src/main/scala/shine/OpenCL/Compilation/Passes/FlagPrivateArrayLoops.scala @@ -4,7 +4,7 @@ import shine.DPIA.Phrases._ import shine.DPIA.Types.{CommType, PhraseType} import shine.DPIA.primitives.functional.{Idx, NatAsIndex} import shine.DPIA.primitives.imperative.{For, ForNat, IdxAcc} -import shine.DPIA.{ArrayData, Nat} +import shine.DPIA.{ArrayData, Nat, NatIdentifier} import shine.OpenCL.AddressSpace import shine.OpenCL.primitives.imperative.{New, ParFor, ParForNat} @@ -44,9 +44,12 @@ object FlagPrivateArrayLoops { eliminateVars ++= indexingIdents Stop(p) case pf: ParFor if collectIdents(pf.out).exists(privMemIdents(_)) => - val (i, o, _) = pf.unwrapBody - eliminateVars += i.name - Continue(p, this.copy(privMemIdents = privMemIdents + o)) + pf.body match { + case Lambda(i, Lambda(o, _)) => + eliminateVars += i.name + Continue(p, this.copy(privMemIdents = privMemIdents + o)) + case _ => throw new Exception("This should not happen") + } case _ => Continue(p, this) } @@ -60,24 +63,30 @@ object FlagPrivateArrayLoops { eliminateVars: mutable.Set[String]): Phrase[CommType] = { VisitAndRebuild(p, new VisitAndRebuild.Visitor { override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = p match { - case f@For(_) if (eliminateVars(f.unwrapBody._1.name)) => - val (i, _) = f.unwrapBody + case f@For(_) if (eliminateVars(f.loopBody.asInstanceOf[Lambda[_, _]].param.name)) => + val i = f.loopBody.asInstanceOf[Lambda[_, _]].param eliminateVars -= i.name Continue(For(unroll = true)(f.n, f.loopBody), this) - case f@ForNat(_) if (eliminateVars(f.unwrapBody._1.name)) => - val (i, _) = f.unwrapBody + case f@ForNat(_) if (eliminateVars(f.loopBody.asInstanceOf[DepLambda[_, _]].x.name)) => + val i = f.loopBody.asInstanceOf[DepLambda[_, _]].x eliminateVars -= i.name Continue(ForNat(unroll = true)(f.n, f.loopBody), this) - case pf@ParFor(level, dim, _) if (eliminateVars(pf.unwrapBody._1.name)) => - val (i, _, _) = pf.unwrapBody - eliminateVars -= i.name - Continue(ParFor(level, dim, unroll = true) - (pf.n, pf.dt, pf.out, pf.loopBody, pf.init, pf.step), this) - case pf@ParForNat(level, dim, _) if (eliminateVars(pf.unwrapBody._1.name)) => - val (i, _, _) = pf.unwrapBody - eliminateVars -= i.name - Continue(ParForNat(level, dim, unroll = true) - (pf.n, pf.ft, pf.out, pf.loopBody, pf.init, pf.step), this) + case pf@ParFor(level, dim, _, name) if (eliminateVars(pf.body.asInstanceOf[Lambda[_, _]].param.name)) => + pf.body match { + case Lambda(i, _) => + eliminateVars -= i.name + Continue(ParFor(level, dim, unroll = true, name)( + pf.init, pf.n, pf.step, pf.dt, pf.out, pf.body), this) + case _ => throw new Exception("This should not happen") + } + case pf@ParForNat(level, dim, _, name) if (eliminateVars(pf.body.asInstanceOf[DepLambda[_, _]].x.name)) => + pf.body match { + case DepLambda(i: NatIdentifier, _) => + eliminateVars -= i.name + Continue(ParForNat(level, dim, unroll = true, name)( + pf.init, pf.n, pf.step, pf.ft, pf.out, pf.body), this) + case _ => throw new Exception("This should not happen") + } case _ => Continue(p, this) } diff --git a/src/main/scala/shine/OpenCL/Compilation/Passes/InsertMemoryBarriers.scala b/src/main/scala/shine/OpenCL/Compilation/Passes/InsertMemoryBarriers.scala index 4ef760f11..e772d1d50 100644 --- a/src/main/scala/shine/OpenCL/Compilation/Passes/InsertMemoryBarriers.scala +++ b/src/main/scala/shine/OpenCL/Compilation/Passes/InsertMemoryBarriers.scala @@ -1,6 +1,6 @@ package shine.OpenCL.Compilation.Passes -import shine.DPIA.->: +import shine.DPIA.{->:, NatIdentifier} import shine.DPIA.Phrases._ import shine.DPIA.Types._ import shine.DPIA.primitives.functional @@ -61,35 +61,53 @@ object InsertMemoryBarriers { override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = { p match { case f@For(unroll) => - val (x, body) = f.unwrapBody - Stop(For(unroll)(f.n, Lambda(x, visitLoopBody(body, allocs, metadata)))) + f.loopBody match { + case Lambda(x, body) => + Stop(For(unroll)(f.n, Lambda(x, visitLoopBody(body, allocs, metadata)))) + case _ => throw new Exception("This should not happen") + } case f@ForNat(unroll) => - val (x, body) = f.unwrapBody - Stop(ForNat(unroll)(f.n, - DepLambda[NatKind, CommType](x, visitLoopBody(body, allocs, metadata)))) - case pf@ocl.ParFor(Local, dim, unroll) => - val (x, o, body) = pf.unwrapBody - val outer_wg_writes = mutable.Map[Identifier[_ <: PhraseType], AddressSpace]() - collectWrites(pf.out, allocs, outer_wg_writes) - Stop(ocl.ParFor(Local, dim, unroll)(pf.n, pf.dt, pf.out, - Lambda(x, Lambda(o, - visitLoopBody(body, allocs, metadata, outer_wg_writes))), pf.init, pf.step)) - case pf@ocl.ParFor(level, dim, unroll) => - val (x, o, body) = pf.unwrapBody - Stop(ocl.ParFor(level, dim, unroll)(pf.n, pf.dt, pf.out, - Lambda(x, Lambda(o, visitLoopBody(body, allocs, metadata))), pf.init, pf.step)) - case pf@ocl.ParForNat(Local, dim, unroll) => - val (x, o, body) = pf.unwrapBody - val outer_wg_writes = mutable.Map[Identifier[_ <: PhraseType], AddressSpace]() - collectWrites(pf.out, allocs, outer_wg_writes) - Stop(ocl.ParForNat(Local, dim, unroll)(pf.n, pf.ft, pf.out, - DepLambda[NatKind, AccType ->: CommType](x, Lambda(o, - visitLoopBody(body, allocs, metadata, outer_wg_writes))), pf.init, pf.step)) - case pf@ocl.ParForNat(level, dim, unroll) => - val (x, o, body) = pf.unwrapBody - Stop(ocl.ParForNat(level, dim, unroll)(pf.n, pf.ft, pf.out, - DepLambda[NatKind, AccType ->: CommType](x, Lambda(o, - visitLoopBody(body, allocs, metadata))), pf.init, pf.step)) + f.loopBody match { + case DepLambda(x, body) => + Stop(ForNat(unroll)(f.n, + DepLambda[NatKind, CommType](x, visitLoopBody(body, allocs, metadata)))) + case _ => throw new Exception("This should not happen") + } + case pf@ocl.ParFor(Local, dim, unroll, name) => + pf.body match { + case Lambda(x, Lambda(o, body)) => + val outer_wg_writes = mutable.Map[Identifier[_ <: PhraseType], AddressSpace]() + collectWrites(pf.out, allocs, outer_wg_writes) + Stop(ocl.ParFor(Local, dim, unroll, name)(pf.init, pf.n, pf.step, pf.dt, pf.out, + Lambda(x, Lambda(o, + visitLoopBody(body, allocs, metadata, outer_wg_writes))))) + case _ => throw new Exception("This should not happen") + } + case pf@ocl.ParFor(level, dim, unroll, name) => + pf.body match { + case Lambda(x, Lambda(o, body)) => + Stop(ocl.ParFor(level, dim, unroll, name)(pf.init, pf.n, pf.step, pf.dt, pf.out, + Lambda(x, Lambda(o, visitLoopBody(body, allocs, metadata))))) + case _ => throw new Exception("This should not happen") + } + case pf@ocl.ParForNat(Local, dim, unroll, name) => + pf.body match { + case DepLambda(i: NatIdentifier, Lambda(o, p)) => + val outer_wg_writes = mutable.Map[Identifier[_ <: PhraseType], AddressSpace]() + collectWrites(pf.out, allocs, outer_wg_writes) + Stop(ocl.ParForNat(Local, dim, unroll, name)(pf.init, pf.n, pf.step, pf.ft, pf.out, + DepLambda[NatKind, AccType ->: CommType](i, Lambda(o, + visitLoopBody(p, allocs, metadata, outer_wg_writes))))) + case _ => throw new Exception("This should not happen") + } + case pf@ocl.ParForNat(level, dim, unroll, name) => + pf.body match { + case DepLambda(i: NatIdentifier, Lambda(o, p)) => + Stop(ocl.ParForNat(level, dim, unroll, name)(pf.init, pf.n, pf.step, pf.ft, pf.out, + DepLambda[NatKind, AccType ->: CommType](i, Lambda(o, + visitLoopBody(p, allocs, metadata))))) + case _ => throw new Exception("This should not happen") + } case ocl.New(addr, _, Lambda(x, _)) if addr != AddressSpace.Private => Continue(p, Visitor(allocs + (x -> addr), metadata)) case ocl.NewDoubleBuffer(addr, dt1, dt2, dt3, n, in, out, Lambda(x, body)) @@ -141,7 +159,7 @@ object InsertMemoryBarriers { case JoinAcc(_, _, _, a) => collectWrites(a, allocs, writes) case SplitAcc(_, _, _, a) => collectWrites(a, allocs, writes) case AsScalarAcc(_, _, _, a) => collectWrites(a, allocs, writes) - case ocl.IdxDistributeAcc(_, _, _, _, _, a) => collectWrites(a, allocs, writes) + case idx: ocl.IdxDistributeAcc => collectWrites(idx.array, allocs, writes) case PairAcc1(_, _, a) => collectWrites(a, allocs, writes) case PairAcc2(_, _, a) => collectWrites(a, allocs, writes) case PairAcc(_, _, a, b) => @@ -175,7 +193,7 @@ object InsertMemoryBarriers { collectReads(e1, allocs, reads); collectReads(e2, allocs, reads) case Slide(_, _, _, _, e) => collectReads(e, allocs, reads) case functional.Map(_, _, _, _, _, e) => collectReads(e, allocs, reads) - case ocl.IdxDistribute(_, _, _, _, _, e) => collectReads(e, allocs, reads) + case idx: ocl.IdxDistribute => collectReads(idx.array, allocs, reads) case MapRead(_, _, _, _, e) => collectReads(e, allocs, reads) case GenerateCont(_, _, _) => giveUp() case AsScalar(_, _, _, _, e) => collectReads(e, allocs, reads) @@ -189,12 +207,12 @@ object InsertMemoryBarriers { case Split(_, _, _, _, e) => collectReads(e, allocs, reads) case Zip(_, _, _, _, e1, e2) => collectReads(e1, allocs, reads); collectReads(e2, allocs, reads) - case Pad(_, _, _, _, e1, e2) => + case PadCst(_, _, _, _, e1, e2) => collectReads(e1, allocs, reads); collectReads(e2, allocs, reads) case PadClamp(_, _, _, _, e) => collectReads(e, allocs, reads) case Cast(_, _, e) => collectReads(e, allocs, reads) - case ForeignFunctionCall(_, _, _, es) => + case ForeignFunctionCall(_, _, es) => es.foreach { collectReads(_, allocs, reads) } @@ -205,7 +223,7 @@ object InsertMemoryBarriers { case MakePair(_, _, _, e1, e2) => collectReads(e1, allocs, reads); collectReads(e2, allocs, reads) case Reorder(_, _, _, _, _, e) => collectReads(e, allocs, reads) - case MakeArray(_, es) => + case MakeArray(es) => es.foreach { collectReads(_, allocs, reads) } diff --git a/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala b/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala index 25c57d303..d6622ff82 100644 --- a/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala +++ b/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala @@ -19,16 +19,15 @@ object SeparateHostAndKernelCode { var kernelDefinitions = mutable.ArrayBuffer[KernelDef]() val hostDefinition = VisitAndRebuild(p, new VisitAndRebuild.Visitor { override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = p match { - case Run(localSize, globalSize, _, value) => + case r@Run(localSize, globalSize) => val name = s"k$kernelNum" kernelNum += 1 - val (closedDef, args) = closeDefinition(value) + val (closedDef, args) = closeDefinition(r.input) val kernelDef = KernelDef(name, closedDef, localSize, globalSize) kernelDefinitions += kernelDef Stop(KernelCall(name, localSize, globalSize, kernelDef.paramTypes.map(_.dataType), - kernelDef.returnType.dataType, - args).asInstanceOf[Phrase[T]]) + args)(kernelDef.returnType.dataType).asInstanceOf[Phrase[T]]) // on the fly beta-reduction case Apply(fun, arg) => diff --git a/src/main/scala/shine/OpenCL/DSL/package.scala b/src/main/scala/shine/OpenCL/DSL/package.scala index 2196f4af3..4ef029e70 100644 --- a/src/main/scala/shine/OpenCL/DSL/package.scala +++ b/src/main/scala/shine/OpenCL/DSL/package.scala @@ -9,32 +9,51 @@ import shine.OpenCL.primitives.imperative._ package object DSL { + def parFor(level: ParallelismLevel, + dim: Int, + unroll: Boolean + ): (Nat, DataType, Phrase[AccType], Phrase[FunType[ExpType, FunType[AccType, CommType]]]) => ParFor = + level match { + case Global => ParFor(level, dim, unroll, "gl_id_")( + get_global_id(dim), _, get_global_size(dim), _, _, _) + case Local => ParFor(level, dim, unroll, "l_id_")( + get_local_id(dim), _, get_local_size(dim), _, _, _) + case WorkGroup => ParFor(level, dim, unroll, "wg_id_")( + get_group_id(dim), _, get_num_groups(dim), _, _, _) + case Sequential | Warp | Lane => throw new Exception("This should not happen") + } + + def parForNat(level: ParallelismLevel, + dim: Int, + unroll: Boolean + ): (Nat, NatToData, Phrase[AccType], Phrase[DepFunType[NatKind, FunType[AccType, CommType]]]) => ParForNat = + level match { + case Global => ParForNat(level, dim, unroll, "gl_id_")( + get_global_id(dim), _, get_global_size(dim), _, _, _) + case Local => ParForNat(level, dim, unroll, "l_id_")( + get_local_id(dim), _, get_local_size(dim), _, _, _) + case WorkGroup => ParForNat(level, dim, unroll, "wg_id_")( + get_group_id(dim), _, get_num_groups(dim), _, _, _) + case Sequential | Warp | Lane => throw new Exception("This should not happen") + } + private def parForBodyFunction(n:Nat, ft:NatToData, f:NatIdentifier => Phrase[AccType] => Phrase[CommType] ): DepLambda[NatKind, AccType ->: CommType] = { nFun(idx => λ(accT(ft(idx)))(o => f(idx)(o)), RangeAdd(0, n, 1)) } - object parForNatGlobal { - def apply(dim:Int)(n:Nat, ft:NatToData, out:Phrase[AccType], - f:NatIdentifier => Phrase[AccType] => Phrase[CommType]): ParForNat = { - ParForNat(Global, dim, unroll = false)(n, ft, out, parForBodyFunction(n, ft, f)) - } - } + def parForNatGlobal(dim:Int)(n:Nat, ft:NatToData, out:Phrase[AccType], + f: NatIdentifier => Phrase[AccType] => Phrase[CommType]): ParForNat = + parForNat(Global, dim, unroll = false)(n, ft, out, parForBodyFunction(n, ft, f)) - object parForNatWorkGroup { - def apply(dim:Int)(n:Nat, ft:NatToData, out:Phrase[AccType], - f:NatIdentifier => Phrase[AccType] => Phrase[CommType]): ParForNat = { - ParForNat(WorkGroup, dim, unroll = false)(n, ft, out, parForBodyFunction(n, ft, f)) - } - } + def parForNatWorkGroup(dim:Int)(n:Nat, ft:NatToData, out:Phrase[AccType], + f:NatIdentifier => Phrase[AccType] => Phrase[CommType]): ParForNat = + parForNat(WorkGroup, dim, unroll = false)(n, ft, out, parForBodyFunction(n, ft, f)) - object parForNatLocal { - def apply(dim:Int)(n:Nat, ft:NatToData, out:Phrase[AccType], - f:NatIdentifier => Phrase[AccType] => Phrase[CommType]): ParForNat = { - ParForNat(Local, dim, unroll = false)(n, ft, out, parForBodyFunction(n, ft, f)) - } - } + def parForNatLocal(dim:Int)(n:Nat, ft:NatToData, out:Phrase[AccType], + f:NatIdentifier => Phrase[AccType] => Phrase[CommType]): ParForNat = + parForNat(Local, dim, unroll = false)(n, ft, out, parForBodyFunction(n, ft, f)) object `new` { def apply(addrSpace: shine.DPIA.Types.AddressSpace) @@ -60,6 +79,6 @@ package object DSL { } object barrier { - def apply(local: Boolean = true, global: Boolean = true) = Barrier(local, global) + def apply(local: Boolean = true, global: Boolean = true) = Barrier(local, global)() } } diff --git a/src/main/scala/shine/OpenCL/KernelExecutor.scala b/src/main/scala/shine/OpenCL/KernelExecutor.scala index a00dfeb6f..9f97232b0 100644 --- a/src/main/scala/shine/OpenCL/KernelExecutor.scala +++ b/src/main/scala/shine/OpenCL/KernelExecutor.scala @@ -378,8 +378,7 @@ object KernelExecutor { case DepArrayType(_, NatToDataLambda(_, elemType)) => getOutputType(elemType) case DepArrayType(_, _) | _: NatToDataApply => throw new Exception("This should not happen") - case _: DepPairType | _: ManagedBufferType | _: OpaqueType | - _: FragmentType | _: pipeline.type => + case _: DepPairType | _: ManagedBufferType | _: OpaqueType | _: FragmentType => throw new Exception(s"${dt} not supported as output type") } @@ -408,7 +407,7 @@ object KernelExecutor { throw new Exception("This should not happen") } case _: DepPairType | _: NatToDataApply | _: DataTypeIdentifier | - _: ManagedBufferType | _: OpaqueType | _: FragmentType | _: pipeline.type => + _: ManagedBufferType | _: OpaqueType | _: FragmentType => throw new Exception(s"the byte size of ${dt} should not be requested") } diff --git a/src/main/scala/shine/OpenCL/primitives/functional/CircularBuffer.scala b/src/main/scala/shine/OpenCL/primitives/functional/CircularBuffer.scala index 9cc0e4305..f4b206cbf 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/CircularBuffer.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/CircularBuffer.scala @@ -1,22 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class CircularBuffer(a: AddressSpace, - n: Nat, - alloc: Nat, - sz: Nat, - dt1: DataType, - dt2: DataType, - load: Phrase[ExpType ->: ExpType], - input: Phrase[ExpType] - ) extends ExpPrimitive { - load :: expT(dt1, read) ->: expT(dt2, write) - input :: expT((n - 1 + sz)`.`dt1, read) - override val t: ExpType = expT(n`.`(sz`.`dt2), read) -} \ No newline at end of file +final case class CircularBuffer(val a: AddressSpace, val n: Nat, val alloc: Nat, val sz: Nat, val dt1: DataType, val dt2: DataType, val load: Phrase[FunType[ExpType, ExpType]], val input: Phrase[ExpType]) extends ExpPrimitive { + { + load :: FunType(expT(dt1, read), expT(dt2, write)) + input :: expT(ArrayType(n - 1 + sz, dt1), read) + } + override val t: ExpType = expT(ArrayType(n, ArrayType(sz, dt2)), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): CircularBuffer = new CircularBuffer(v.addressSpace(a), v.nat(n), v.nat(alloc), v.nat(sz), v.data(dt1), v.data(dt2), VisitAndRebuild(load, v), VisitAndRebuild(input, v)) +} diff --git a/src/main/scala/shine/OpenCL/primitives/functional/DepMap.scala b/src/main/scala/shine/OpenCL/primitives/functional/DepMap.scala index 3ffe8829a..adfe41170 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/DepMap.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/DepMap.scala @@ -1,25 +1,21 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.OpenCL.ParallelismLevel -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class DepMap(level: ParallelismLevel, - dim: Int) - (val n: Nat, - val ft1:NatToData, - val ft2:NatToData, - val f: Phrase[`(nat)->:`[ExpType ->: ExpType]], - val array: Phrase[ExpType] - ) extends ExpPrimitive { - f :: f.t.x ->: expT(ft1(f.t.x), read) ->: expT(ft2(f.t.x), write) - array :: expT(n `.d` ft1, read) - override val t: ExpType = expT(n`.d`ft2, write) - - def unwrap: (Nat, NatToData, NatToData, Phrase[`(nat)->:`[ExpType ->: ExpType]], Phrase[ExpType]) = - (n, ft1, ft2, f, array) +final case class DepMap(level: shine.OpenCL.ParallelismLevel, dim: Int)(val n: Nat, val ft1: NatToData, val ft2: NatToData, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { + { + f :: ({ + val m = f.t.x + DepFunType[NatKind, PhraseType](m, FunType(expT(NatToDataApply(ft1, m), read), expT(NatToDataApply(ft2, m), write))) + }) + array :: expT(DepArrayType(n, ft1), read) + } + override val t: ExpType = expT(DepArrayType(n, ft2), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepMap = new DepMap(level, dim)(v.nat(n), v.natToData(ft1), v.natToData(ft2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) + def unwrap: (Nat, NatToData, NatToData, Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], Phrase[ExpType]) = (n, ft1, ft2, f, array) } diff --git a/src/main/scala/shine/OpenCL/primitives/functional/Iterate.scala b/src/main/scala/shine/OpenCL/primitives/functional/Iterate.scala index 2016c5556..0387931b5 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/Iterate.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/Iterate.scala @@ -1,24 +1,20 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Iterate(a: AddressSpace, - n: Nat, - m: Nat, - k: Nat, - dt: DataType, - f: Phrase[`(nat)->:`[ExpType ->: ExpType]], - array: Phrase[ExpType] - ) extends ExpPrimitive { +final case class Iterate(val a: AddressSpace, val n: Nat, val m: Nat, val k: Nat, val dt: DataType, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { { - val l = f.t.x - f :: l ->: expT({l * n}`.`dt, read) ->: expT(l`.`dt, write) - array :: expT({m * n.pow(k)}`.`dt, read) + f :: ({ + val l = f.t.x + DepFunType[NatKind, PhraseType](l, FunType(expT(ArrayType(l * n, dt), read), expT(ArrayType(l, dt), write))) + }) + array :: expT(ArrayType(m * n.pow(k), dt), read) } - override val t: ExpType = expT(m`.`dt, write) + override val t: ExpType = expT(ArrayType(m, dt), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Iterate = new Iterate(v.addressSpace(a), v.nat(n), v.nat(m), v.nat(k), v.data(dt), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/OpenCL/primitives/functional/KernelCall.scala b/src/main/scala/shine/OpenCL/primitives/functional/KernelCall.scala index edff14958..d2f150c9a 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/KernelCall.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/KernelCall.scala @@ -1,20 +1,14 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.OpenCL.{GlobalSize, LocalSize} -import shine.macros.Primitive.expPrimitive - -@expPrimitive -case class KernelCall(name: String, - localSize: LocalSize, - globalSize: GlobalSize, - inTs: Seq[DataType], - outT: DataType, - args: Seq[Phrase[ExpType]]) extends ExpPrimitive { - (inTs zip args).foreach{ - case (inT, arg) => arg :: expT(inT, read) - } +final case class KernelCall(name: String, localSize: shine.OpenCL.LocalSize, globalSize: shine.OpenCL.GlobalSize, inTs: Seq[DataType], args: Seq[Phrase[ExpType]])(val outT: DataType) extends ExpPrimitive { + {} override val t: ExpType = expT(outT, write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): KernelCall = new KernelCall(name, localSize.visitAndRebuild(v), globalSize.visitAndRebuild(v), inTs.map(v.data), args.map(VisitAndRebuild(_, v)))(v.data(outT)) } diff --git a/src/main/scala/shine/OpenCL/primitives/functional/Map.scala b/src/main/scala/shine/OpenCL/primitives/functional/Map.scala index 45137ca96..a99170788 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/Map.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/Map.scala @@ -1,25 +1,18 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.OpenCL.ParallelismLevel -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Map(level: ParallelismLevel, - dim: Int) - (val n: Nat, - val dt1: DataType, - val dt2: DataType, - val f: Phrase[ExpType ->: ExpType], - val array: Phrase[ExpType] - ) extends ExpPrimitive { - f :: expT(dt1, read) ->: expT(dt2, write) - array :: expT(n `.` dt1, read) - override val t: ExpType = expT(n `.` dt2, write) - - def unwrap: (Nat, DataType, DataType, Phrase[ExpType ->: ExpType], Phrase[ExpType]) = - (n, dt1, dt2, f, array) +final case class Map(level: shine.OpenCL.ParallelismLevel, dim: Int)(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive { + { + f :: FunType(expT(dt1, read), expT(dt2, write)) + array :: expT(ArrayType(n, dt1), read) + } + override val t: ExpType = expT(ArrayType(n, dt2), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Map = new Map(level, dim)(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) + def unwrap: (Nat, DataType, DataType, Phrase[FunType[ExpType, ExpType]], Phrase[ExpType]) = (n, dt1, dt2, f, array) } diff --git a/src/main/scala/shine/OpenCL/primitives/functional/OpenCLFunctionCall.scala b/src/main/scala/shine/OpenCL/primitives/functional/OpenCLFunctionCall.scala index 2c47bfe8a..a3f2ec8a0 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/OpenCLFunctionCall.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/OpenCLFunctionCall.scala @@ -1,20 +1,14 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -import scala.language.reflectiveCalls - -@expPrimitive -final case class OpenCLFunctionCall(name: String, - inTs: Seq[DataType], - outT: DataType, - args: Seq[Phrase[ExpType]] - ) extends ExpPrimitive { - (inTs zip args).foreach{ - case (inT, arg) => arg :: expT(inT, read) - } - override val t: ExpType = expT(outT, read) +final case class OpenCLFunctionCall(name: String, inTs: Seq[DataType], args: Seq[Phrase[ExpType]])(val outT: DataType) extends ExpPrimitive { + {} + override val t: ExpType = expT(outT, write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): OpenCLFunctionCall = new OpenCLFunctionCall(name, inTs.map(v.data), args.map(VisitAndRebuild(_, v)))(v.data(outT)) } diff --git a/src/main/scala/shine/OpenCL/primitives/functional/ReduceSeq.scala b/src/main/scala/shine/OpenCL/primitives/functional/ReduceSeq.scala index 78e416668..f00544913 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/ReduceSeq.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/ReduceSeq.scala @@ -1,23 +1,19 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class ReduceSeq(unroll: Boolean) - (val n: Nat, - val initAddrSpace: AddressSpace, - val dt1: DataType, - val dt2: DataType, - val f: Phrase[ExpType ->: ExpType ->: ExpType], - val init: Phrase[ExpType], - val array: Phrase[ExpType] - ) extends ExpPrimitive { - f :: expT(dt2, read) ->: expT(dt1, read) ->: expT(dt2, write) - init :: expT(dt2, write) - array :: expT(n`.`dt1, read) +final case class ReduceSeq(unroll: Boolean)(val n: Nat, val a: AddressSpace, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, FunType[ExpType, ExpType]]], val init: Phrase[ExpType], val array: Phrase[ExpType]) extends ExpPrimitive { + { + f :: FunType(expT(dt2, read), FunType(expT(dt1, read), expT(dt2, write))) + init :: expT(dt2, write) + array :: expT(ArrayType(n, dt1), read) + } override val t: ExpType = expT(dt2, read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ReduceSeq = new ReduceSeq(unroll)(v.nat(n), v.addressSpace(a), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(init, v), VisitAndRebuild(array, v)) + def unwrap: (Nat, AddressSpace, DataType, DataType, Phrase[FunType[ExpType, FunType[ExpType, ExpType]]], Phrase[ExpType], Phrase[ExpType]) = (n, a, dt1, dt2, f, init, array) } diff --git a/src/main/scala/shine/OpenCL/primitives/functional/RotateValues.scala b/src/main/scala/shine/OpenCL/primitives/functional/RotateValues.scala index bf8f2fe27..093afa594 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/RotateValues.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/RotateValues.scala @@ -1,20 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class RotateValues(a: AddressSpace, - n: Nat, - sz: Nat, - dt: DataType, - write: Phrase[ExpType ->: ExpType], - input: Phrase[ExpType] - ) extends ExpPrimitive { - write :: expT(dt, read) ->: expT(dt, shine.DPIA.Types.write) - input :: expT((n - 1 + sz)`.`dt, read) - override val t: ExpType = expT(n`.`(sz`.`dt), read) -} \ No newline at end of file +final case class RotateValues(val a: AddressSpace, val n: Nat, val sz: Nat, val dt: DataType, val wrt: Phrase[FunType[ExpType, ExpType]], val input: Phrase[ExpType]) extends ExpPrimitive { + { + wrt :: FunType(expT(dt, read), expT(dt, write)) + input :: expT(ArrayType(n - 1 + sz, dt), read) + } + override val t: ExpType = expT(ArrayType(n, ArrayType(sz, dt)), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): RotateValues = new RotateValues(v.addressSpace(a), v.nat(n), v.nat(sz), v.data(dt), VisitAndRebuild(wrt, v), VisitAndRebuild(input, v)) +} diff --git a/src/main/scala/shine/OpenCL/primitives/functional/Run.scala b/src/main/scala/shine/OpenCL/primitives/functional/Run.scala index 25e276c2d..6dc810e2c 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/Run.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/Run.scala @@ -1,17 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.functional - -import shine.DPIA._ +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ -import shine.OpenCL.{GlobalSize, LocalSize} -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Run(localSize: LocalSize, - globalSize: GlobalSize, - dt: DataType, - input: Phrase[ExpType] - ) extends ExpPrimitive { - input :: expT(dt, write) +import shine.DPIA._ +final case class Run(localSize: shine.OpenCL.LocalSize, globalSize: shine.OpenCL.GlobalSize)(val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive { + { + input :: expT(dt, write) + } override val t: ExpType = expT(dt, write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Run = new Run(localSize.visitAndRebuild(v), globalSize.visitAndRebuild(v))(v.data(dt), VisitAndRebuild(input, v)) + def unwrap: (DataType, Phrase[ExpType]) = (dt, input) } diff --git a/src/main/scala/shine/OpenCL/primitives/functional/ToMem.scala b/src/main/scala/shine/OpenCL/primitives/functional/ToMem.scala index e7cd552bc..26cd17bb2 100644 --- a/src/main/scala/shine/OpenCL/primitives/functional/ToMem.scala +++ b/src/main/scala/shine/OpenCL/primitives/functional/ToMem.scala @@ -1,15 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class ToMem(addrSpace: AddressSpace, - dt: DataType, - input: Phrase[ExpType] - ) extends ExpPrimitive { - input :: expT(dt, write) +final case class ToMem(val a: AddressSpace, val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive { + { + input :: expT(dt, write) + } override val t: ExpType = expT(dt, read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ToMem = new ToMem(v.addressSpace(a), v.data(dt), VisitAndRebuild(input, v)) } diff --git a/src/main/scala/shine/OpenCL/primitives/functional/primitives.dpia b/src/main/scala/shine/OpenCL/primitives/functional/primitives.dpia new file mode 100644 index 000000000..7f24dbfd8 --- /dev/null +++ b/src/main/scala/shine/OpenCL/primitives/functional/primitives.dpia @@ -0,0 +1,45 @@ +def circularBuffer(a: address, n: nat, alloc: nat, sz: nat, dt1: data, dt2: data, + load: exp[dt1, read] -> exp[dt2, write], + input: exp[(n-1+sz).dt1, read]): exp[n.sz.dt2, read] + +def depMap{level: shine.OpenCL.ParallelismLevel, dim: Int} + (n: nat, ft1: nat2data, ft2: nat2data, + f: (m: nat) -> exp[ft1(m), read] -> exp[ft2(m), write], + array: exp[n..ft1, read]): exp[n..ft2, write] + +def iterate(a: address, n: nat, m: nat, k: nat, dt: data, + f: (l: nat) -> exp[(l*n).dt, read] -> exp[l.dt, write], + array: exp[(m*(n^k)).dt, read]): exp[m.dt, write] + +def kernelCall{name: String, + localSize: shine.OpenCL.LocalSize, + globalSize: shine.OpenCL.GlobalSize, + inTs: Seq[DataType], + args: Seq[Phrase[ExpType]]} + (outT: data): exp[outT, write] + +def map{level: shine.OpenCL.ParallelismLevel, dim: Int} + (n: nat, dt1: data, dt2: data, + f: exp[dt1, read] -> exp[dt2, write], + array: exp[n.dt1, read]): exp[n.dt2, write] + +def openCLFunctionCall{name: String, + inTs: Seq[DataType], + args: Seq[Phrase[ExpType]]} + (outT: data): exp[outT, write] + +def reduceSeq{unroll: Boolean} + (n: nat, a: address, dt1: data, dt2: data, + f: exp[dt2, read] -> exp[dt1, read] -> exp[dt2, write], + init: exp[dt2, write], + array: exp[n.dt1, read]): exp[dt2, read] + +def rotateValues(a: address, n: nat, sz: nat, dt: data, + wrt: exp[dt, read] -> exp[dt, write], + input: exp[(n-1+sz).dt, read]): exp[n.sz.dt, read] + +def run{localSize: shine.OpenCL.LocalSize, + globalSize: shine.OpenCL.GlobalSize} + (dt: data, input: exp[dt, write]): exp[dt, write] + +def toMem(a: address, dt: data, input: exp[dt, write]): exp[dt, read] diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/Barrier.scala b/src/main/scala/shine/OpenCL/primitives/imperative/Barrier.scala index 9f0520381..cf222f8b7 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/Barrier.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/Barrier.scala @@ -1,11 +1,14 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.imperative - -import shine.DPIA.Phrases.CommandPrimitive -import shine.macros.Primitive.comPrimitive - -@comPrimitive -final case class Barrier(local: Boolean, global: Boolean) extends CommandPrimitive { - override def prettyPrint: String = - s"""barrier( ${if(local) "CLK_LOCAL_MEM_FENCE" else ""} ${if(global && local) "|" else ""} - ${if(global) "CLK_GLOBAL_MEM_FENCE" else ""})""" +import arithexpr.arithmetic._ +import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ +import shine.DPIA.Types._ +import shine.DPIA._ +final case class Barrier(local: Boolean, global: Boolean)() extends CommandPrimitive { + {} + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Barrier = new Barrier(local, global)() } diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/HostExecution.scala b/src/main/scala/shine/OpenCL/primitives/imperative/HostExecution.scala index a3054141c..54f977635 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/HostExecution.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/HostExecution.scala @@ -1,20 +1,19 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ -import shine.OpenCL.AccessFlags -import shine.macros.Primitive.comPrimitive - -// In a host program which contains host managed buffers, -// A host execution is a section where only plain arrays are used -@comPrimitive -final case class HostExecution(params: Map[Identifier[_ <: PhraseType], AccessFlags], - body: Phrase[CommType]) extends CommandPrimitive { - body :: comm - - override def visitAndRebuild(f: VisitAndRebuild.Visitor): Phrase[CommType] = - HostExecution( - params.map({ case (k, v) => - VisitAndRebuild(k, f).asInstanceOf[Identifier[_ <: PhraseType]] -> v }), - VisitAndRebuild(body, f)) -} \ No newline at end of file +import shine.DPIA._ +final case class HostExecution(params: Map[Identifier[_ <: PhraseType], shine.OpenCL.AccessFlags])(val body: Phrase[CommType]) extends CommandPrimitive { + { + body :: comm + } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): HostExecution = new HostExecution(params.map({ + case (key, value) => + VisitAndRebuild(key, v).asInstanceOf[Identifier[_ <: PhraseType]] -> value + }))(VisitAndRebuild(body, v)) +} diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistribute.scala b/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistribute.scala index 09c65e30f..c046d1597 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistribute.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistribute.scala @@ -1,20 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.OpenCL.ParallelismLevel -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class IdxDistribute(m: Nat, - n: Nat, - stride: Nat, - parallelismLevel: ParallelismLevel, - dt: DataType, - array: Phrase[ExpType] - ) extends ExpPrimitive { - array :: expT(m`.`dt, read) - override val t: ExpType = expT(n`.`dt, read) +final case class IdxDistribute(parallelismLevel: shine.OpenCL.ParallelismLevel)(val m: Nat, val n: Nat, val stride: Nat, val dt: DataType, val array: Phrase[ExpType]) extends ExpPrimitive { + { + array :: expT(ArrayType(m, dt), read) + } + override val t: ExpType = expT(ArrayType(n, dt), read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): IdxDistribute = new IdxDistribute(parallelismLevel)(v.nat(m), v.nat(n), v.nat(stride), v.data(dt), VisitAndRebuild(array, v)) + def unwrap: (Nat, Nat, Nat, DataType, Phrase[ExpType]) = (m, n, stride, dt, array) } diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistributeAcc.scala b/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistributeAcc.scala index 0e7de8392..0de4a9e08 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistributeAcc.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/IdxDistributeAcc.scala @@ -1,20 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ -import shine.DPIA.{Nat, _} -import shine.OpenCL.ParallelismLevel -import shine.macros.Primitive.accPrimitive - -@accPrimitive -final case class IdxDistributeAcc(m: Nat, - n: Nat, - stride: Nat, - parallelismLevel: ParallelismLevel, - dt: DataType, - array: Phrase[AccType] - ) extends AccPrimitive { - array :: accT(m`.`dt) - override val t: AccType = accT(n`.`dt) +import shine.DPIA._ +final case class IdxDistributeAcc(parallelismLevel: shine.OpenCL.ParallelismLevel)(val m: Nat, val n: Nat, val stride: Nat, val dt: DataType, val array: Phrase[AccType]) extends AccPrimitive { + { + array :: accT(ArrayType(m, dt)) + } + override val t: AccType = accT(ArrayType(n, dt)) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): IdxDistributeAcc = new IdxDistributeAcc(parallelismLevel)(v.nat(m), v.nat(n), v.nat(stride), v.data(dt), VisitAndRebuild(array, v)) + def unwrap: (Nat, Nat, Nat, DataType, Phrase[AccType]) = (m, n, stride, dt, array) } diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/KernelCallCmd.scala b/src/main/scala/shine/OpenCL/primitives/imperative/KernelCallCmd.scala index 4864c7b20..2e420b356 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/KernelCallCmd.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/KernelCallCmd.scala @@ -1,14 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ -import shine.OpenCL.{GlobalSize, LocalSize} -import shine.macros.Primitive.comPrimitive - -@comPrimitive -case class KernelCallCmd(name: String, - localSize: LocalSize, - globalSize: GlobalSize, - output: Phrase[AccType], - args: Seq[Phrase[ExpType]]) extends CommandPrimitive - +import shine.DPIA._ +final case class KernelCallCmd(name: String, localSize: shine.OpenCL.LocalSize, globalSize: shine.OpenCL.GlobalSize, args: Seq[Phrase[ExpType]])(val dt: DataType, val output: Phrase[AccType]) extends CommandPrimitive { + { + output :: accT(dt) + } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): KernelCallCmd = new KernelCallCmd(name, localSize.visitAndRebuild(v), globalSize.visitAndRebuild(v), args.map(VisitAndRebuild(_, v)))(v.data(dt), VisitAndRebuild(output, v)) + def unwrap: (DataType, Phrase[AccType]) = (dt, output) +} diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/New.scala b/src/main/scala/shine/OpenCL/primitives/imperative/New.scala index 4af7b56f6..4b455cb27 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/New.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/New.scala @@ -1,14 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.comPrimitive - -@comPrimitive -final case class New(a: AddressSpace, - dt: DataType, - f: Phrase[VarType ->: CommType] - ) extends CommandPrimitive { - f :: varT(dt) ->: comm +final case class New(val a: AddressSpace, val dt: DataType, val f: Phrase[FunType[PhrasePairType[ExpType, AccType], CommType]]) extends CommandPrimitive { + { + f :: FunType(PhrasePairType(expT(dt, read), accT(dt)), comm) + } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): New = new New(v.addressSpace(a), v.data(dt), VisitAndRebuild(f, v)) } diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/NewDoubleBuffer.scala b/src/main/scala/shine/OpenCL/primitives/imperative/NewDoubleBuffer.scala index b96f98d74..9e50f0b6a 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/NewDoubleBuffer.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/NewDoubleBuffer.scala @@ -1,22 +1,18 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.imperative - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.comPrimitive - -@comPrimitive -final case class NewDoubleBuffer(a: AddressSpace, - dt1: DataType, - dt2: DataType, - dt3: DataType, - n: Nat, - in: Phrase[ExpType], - out: Phrase[AccType], - f: Phrase[(ExpType x AccType x CommType x CommType) ->: CommType] - ) extends CommandPrimitive { - in :: expT(dt1, read) - out :: accT(dt2) - f :: (((varT(n`.`dt3) x comm) x comm) ->: comm) +final case class NewDoubleBuffer(val a: AddressSpace, val dt1: DataType, val dt2: DataType, val dt3: DataType, val n: Nat, val in: Phrase[ExpType], val out: Phrase[AccType], val f: Phrase[FunType[PhrasePairType[PhrasePairType[PhrasePairType[ExpType, AccType], CommType], CommType], CommType]]) extends CommandPrimitive { + { + in :: expT(dt1, read) + out :: accT(dt2) + f :: FunType(PhrasePairType(PhrasePairType(PhrasePairType(expT(ArrayType(n, dt3), read), accT(ArrayType(n, dt3))), comm), comm), comm) + } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): NewDoubleBuffer = new NewDoubleBuffer(v.addressSpace(a), v.data(dt1), v.data(dt2), v.data(dt3), v.nat(n), VisitAndRebuild(in, v), VisitAndRebuild(out, v), VisitAndRebuild(f, v)) } diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/NewManagedBuffer.scala b/src/main/scala/shine/OpenCL/primitives/imperative/NewManagedBuffer.scala index bfdc31ede..5c439ae17 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/NewManagedBuffer.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/NewManagedBuffer.scala @@ -1,14 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.imperative - -import shine.DPIA._ -import shine.DPIA.Types._ +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ -import shine.OpenCL.AccessFlags -import shine.macros.Primitive.comPrimitive - -@comPrimitive -final case class NewManagedBuffer(dt: DataType, - access: AccessFlags, - k: Phrase[VarType ->: CommType]) extends CommandPrimitive { - k :: varT(ManagedBufferType(dt)) ->: comm +import shine.DPIA.Types.DataType._ +import shine.DPIA.Types._ +import shine.DPIA._ +final case class NewManagedBuffer(access: shine.OpenCL.AccessFlags)(val dt: DataType, val k: Phrase[FunType[PhrasePairType[ExpType, AccType], CommType]]) extends CommandPrimitive { + { + k :: FunType(PhrasePairType(expT(ManagedBufferType(dt), read), accT(ManagedBufferType(dt))), comm) + } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): NewManagedBuffer = new NewManagedBuffer(access)(v.data(dt), VisitAndRebuild(k, v)) + def unwrap: (DataType, Phrase[FunType[PhrasePairType[ExpType, AccType], CommType]]) = (dt, k) } diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/ParFor.scala b/src/main/scala/shine/OpenCL/primitives/imperative/ParFor.scala index 52b1759a5..47efa67bb 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/ParFor.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/ParFor.scala @@ -1,53 +1,18 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.imperative - -import shine.DPIA.Phrases.{CommandPrimitive, _} +import arithexpr.arithmetic._ +import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.OpenCL._ -import shine.macros.Primitive.comPrimitive - -@comPrimitive -final case class ParFor(level: ParallelismLevel, - dim: Int, - unroll: Boolean) - (val n: Nat, - val dt: DataType, - val out: Phrase[AccType], - val loopBody: Phrase[ExpType ->: AccType ->: CommType], - val init: Nat = ParFor.initInit(level, dim), - val step: Nat = ParFor.initStep(level, dim) - ) extends CommandPrimitive { - val name: String = ParFor.initName(level) - - out :: accT(n`.`dt) - loopBody :: expT(idx(n), read) ->: accT(dt) ->: comm - - lazy val unwrapBody: (Identifier[ExpType], Identifier[AccType], Phrase[CommType]) = loopBody match { - case Lambda(i, Lambda(o, body)) => (i, o, body) - case _ => throw new Exception("This should not happen") - } -} - -object ParFor { - def initInit(level: ParallelismLevel, dim: Int): Nat = level match { - case Global => get_global_id(dim) - case Local => get_local_id (dim) - case WorkGroup => get_group_id (dim) - case Sequential | Warp | Lane => throw new Exception("This should not happen") - } - - def initStep(level: ParallelismLevel, dim: Int): Nat = level match { - case Global => get_global_size(dim) - case Local => get_local_size (dim) - case WorkGroup => get_num_groups (dim) - case Sequential | Warp | Lane => throw new Exception("This should not happen") - } - - def initName(level: ParallelismLevel): String = level match { - case Global => freshName("gl_id_") - case Local => freshName( "l_id_") - case WorkGroup => freshName("wg_id_") - case Sequential | Warp | Lane => throw new Exception("This should not happen") +final case class ParFor(level: shine.OpenCL.ParallelismLevel, dim: Int, unroll: Boolean, prefix: String)(val init: Nat, val n: Nat, val step: Nat, val dt: DataType, val out: Phrase[AccType], val body: Phrase[FunType[ExpType, FunType[AccType, CommType]]]) extends CommandPrimitive { + { + out :: accT(ArrayType(n, dt)) + body :: FunType(expT(IndexType(n), read), FunType(accT(dt), comm)) } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ParFor = new ParFor(level, dim, unroll, prefix)(v.nat(init), v.nat(n), v.nat(step), v.data(dt), VisitAndRebuild(out, v), VisitAndRebuild(body, v)) + def unwrap: (Nat, Nat, Nat, DataType, Phrase[AccType], Phrase[FunType[ExpType, FunType[AccType, CommType]]]) = (init, n, step, dt, out, body) } diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/ParForNat.scala b/src/main/scala/shine/OpenCL/primitives/imperative/ParForNat.scala index a7e94c82b..bd45c3fd0 100644 --- a/src/main/scala/shine/OpenCL/primitives/imperative/ParForNat.scala +++ b/src/main/scala/shine/OpenCL/primitives/imperative/ParForNat.scala @@ -1,52 +1,21 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenCL.primitives.imperative - -import shine.DPIA.Phrases.{CommandPrimitive, _} +import arithexpr.arithmetic._ +import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.OpenCL._ -import shine.macros.Primitive.comPrimitive - -@comPrimitive -final case class ParForNat(level: ParallelismLevel, - dim: Int, - unroll: Boolean) - (val n: Nat, - val ft: NatToData, - val out: Phrase[AccType], - val loopBody: Phrase[`(nat)->:`[AccType ->: CommType]], - val init: Nat = ParForNat.initInit(level, dim), - val step: Nat = ParForNat.initStep(level, dim), - val name: String = ParForNat.initName(level) - ) extends CommandPrimitive { - out :: accT(n`.d`ft) - loopBody :: loopBody.t.x ->: accT(ft(loopBody.t.x)) ->: comm - - lazy val unwrapBody: (NatIdentifier, Identifier[AccType], Phrase[CommType]) = loopBody match { - case DepLambda(i: NatIdentifier, Lambda(o, body)) => (i, o, body) - case _ => throw new Exception("This should not happen") - } -} - -object ParForNat { - def initInit(level: ParallelismLevel, dim: Int): Nat = level match { - case Global => get_global_id(dim) - case Local => get_local_id(dim) - case WorkGroup => get_group_id(dim) - case Sequential | Warp | Lane => throw new Exception("This should not happen") - } - - def initStep(level: ParallelismLevel, dim: Int): Nat = level match { - case Global => get_global_size(dim) - case Local => get_local_size(dim) - case WorkGroup => get_num_groups(dim) - case Sequential | Warp | Lane => throw new Exception("This should not happen") - } - - def initName(level: ParallelismLevel): String = level match { - case Global => freshName("gl_id_") - case Local => freshName("l_id_") - case WorkGroup => freshName("wg_id_") - case Sequential | Warp | Lane => throw new Exception("This should not happen") +final case class ParForNat(level: shine.OpenCL.ParallelismLevel, dim: Int, unroll: Boolean, prefix: String)(val init: Nat, val n: Nat, val step: Nat, val ft: NatToData, val out: Phrase[AccType], val body: Phrase[DepFunType[NatKind, FunType[AccType, CommType]]]) extends CommandPrimitive { + { + out :: accT(DepArrayType(n, ft)) + body :: ({ + val i = body.t.x + DepFunType[NatKind, PhraseType](i, FunType(accT(NatToDataApply(ft, i)), comm)) + }) } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ParForNat = new ParForNat(level, dim, unroll, prefix)(v.nat(init), v.nat(n), v.nat(step), v.natToData(ft), VisitAndRebuild(out, v), VisitAndRebuild(body, v)) + def unwrap: (Nat, Nat, Nat, NatToData, Phrase[AccType], Phrase[DepFunType[NatKind, FunType[AccType, CommType]]]) = (init, n, step, ft, out, body) } diff --git a/src/main/scala/shine/OpenCL/primitives/imperative/primitives.dpia b/src/main/scala/shine/OpenCL/primitives/imperative/primitives.dpia new file mode 100644 index 000000000..5b78dbaa0 --- /dev/null +++ b/src/main/scala/shine/OpenCL/primitives/imperative/primitives.dpia @@ -0,0 +1,45 @@ +def barrier{local: Boolean, global: Boolean}(): comm + +// In a host program which contains host managed buffers, +// A host execution is a section where only plain arrays are used +def hostExecution{params: Map[Identifier[_ <: PhraseType], shine.OpenCL.AccessFlags]} + (body: comm): comm + +def idxDistribute{parallelismLevel: shine.OpenCL.ParallelismLevel} + (m: nat, n: nat, stride: nat, dt: data, array: exp[m.dt, read]): exp[n.dt, read] + +def idxDistributeAcc{parallelismLevel: shine.OpenCL.ParallelismLevel} + (m: nat, n: nat, stride: nat, dt: data, array: acc[m.dt]): acc[n.dt] + +def kernelCallCmd{name: String, + localSize: shine.OpenCL.LocalSize, + globalSize: shine.OpenCL.GlobalSize, + args: Seq[Phrase[ExpType]]} + (dt: data, output: acc[dt]): comm + +def new(a: address, dt: data, f: var[dt] -> comm): comm + +def newDoubleBuffer(a: address, dt1: data, dt2: data, dt3: data, n: nat, + in: exp[dt1, read], out: acc[dt2], + f: ((var[n.dt3], comm), comm) -> comm): comm + +def newManagedBuffer{access: shine.OpenCL.AccessFlags} + (dt: data, k: var[managed[dt]] -> comm): comm + +def parFor{level: shine.OpenCL.ParallelismLevel, + dim: Int, + unroll: Boolean, + prefix: String} + (init: nat, n: nat, step: nat, + dt: data, + out: acc[n.dt], + body: exp[idx[n], read] -> acc[dt] -> comm): comm + +def parForNat{level: shine.OpenCL.ParallelismLevel, + dim: Int, + unroll: Boolean, + prefix: String} + (init: nat, n: nat, step: nat, + ft: nat2data, + out: acc[n..ft], + body: (i: nat) -> acc[ft(i)] -> comm): comm diff --git a/src/main/scala/shine/OpenCL/primitives/intermediate/MapI.scala b/src/main/scala/shine/OpenCL/primitives/intermediate/MapI.scala index 2489221dc..0618afeec 100644 --- a/src/main/scala/shine/OpenCL/primitives/intermediate/MapI.scala +++ b/src/main/scala/shine/OpenCL/primitives/intermediate/MapI.scala @@ -6,7 +6,6 @@ import shine.DPIA.Types.DataType.idx import shine.DPIA.Types._ import shine.DPIA._ import shine.OpenCL._ -import shine.OpenCL.primitives.imperative.ParFor final case class MapI(level: ParallelismLevel, dim: Int) { def apply(n: Nat, dt1: DataType, dt2: DataType, @@ -14,7 +13,7 @@ final case class MapI(level: ParallelismLevel, dim: Int) { in: Phrase[ExpType], out: Phrase[AccType]): Phrase[CommType] = { comment(s"map${level.toString}") `;` - ParFor(level, dim, unroll = false)(n, dt2, out, + shine.OpenCL.DSL.parFor(level, dim, unroll = false)(n, dt2, out, λ(expT(idx(n), read))(i => λ(accT(dt2))(a => f(in `@` i)(a)))) } } diff --git a/src/main/scala/shine/OpenMP/CodeGenerator.scala b/src/main/scala/shine/OpenMP/CodeGenerator.scala index ff9271f18..e567531c9 100644 --- a/src/main/scala/shine/OpenMP/CodeGenerator.scala +++ b/src/main/scala/shine/OpenMP/CodeGenerator.scala @@ -46,19 +46,19 @@ class CodeGenerator(override val decls: CCodeGenerator.Declarations, override def acc(env: Environment, path: Path, cont: Expr => Stmt): Phrase[AccType] => Stmt = { - case AsVectorAcc(n, _, _, a) => path match { - case (i: CIntExpr) :: ps => a |> acc(env, CIntExpr(i / n) :: ps, cont) + case AsVectorAcc(_, m, _, a) => path match { + case (i: CIntExpr) :: ps => a |> acc(env, CIntExpr(i / m) :: ps, cont) case _ => error(s"Expected path to be not empty") } - case AsScalarAcc(_, m, dt, a) => path match { + case AsScalarAcc(n, _, dt, a) => path match { case (i: CIntExpr) :: (j: CIntExpr) :: ps => - a |> acc(env, CIntExpr((i * m) + j) :: ps, cont) + a |> acc(env, CIntExpr((i * n) + j) :: ps, cont) case (i: CIntExpr) :: Nil => - a |> acc(env, CIntExpr(i * m) :: Nil, { + a |> acc(env, CIntExpr(i * n) :: Nil, { case ArraySubscript(v, idx) => // emit something like: ((struct float4 *)v)[idx] - val ptrType = C.AST.PointerType(typ(VectorType(m, dt))) + val ptrType = C.AST.PointerType(typ(VectorType(n, dt))) cont(C.AST.ArraySubscript(C.AST.Cast(ptrType, v), idx)) }) case _ => error(s"Expected path to be not empty") @@ -93,9 +93,9 @@ class CodeGenerator(override val decls: CCodeGenerator.Declarations, } case _ => phrase |> super.exp(env, path, cont) } - case ForeignFunctionCall(f, inTs, outT, args) => - OpenMPCodeGen.codeGenForeignFunctionCall(f, inTs, outT, args, env, path, cont) - case AsVectorAligned(n, _, _, dt, e) => path match { + case ffc@ForeignFunctionCall(f, inTs, args) => + OpenMPCodeGen.codeGenForeignFunctionCall(f, inTs, ffc.outT, args, env, path, cont) + case AsVectorAligned(n, _, dt, _, e) => path match { case (i: CIntExpr) :: (j: CIntExpr) :: ps => e |> exp(env, CIntExpr((i * n) + j) :: ps, cont) @@ -277,7 +277,7 @@ class CodeGenerator(override val decls: CCodeGenerator.Declarations, } } - def codeGenForeignFunctionCall(funDecl: ForeignFunction.Declaration, + def codeGenForeignFunctionCall(funDecl: rise.core.ForeignFunction.Decl, inTs: collection.Seq[DataType], outT: DataType, args: collection.Seq[Phrase[ExpType]], diff --git a/src/main/scala/shine/OpenMP/DSL/ImperativePrimitives.scala b/src/main/scala/shine/OpenMP/DSL/ImperativePrimitives.scala index 9b21e9972..9fb9b594e 100644 --- a/src/main/scala/shine/OpenMP/DSL/ImperativePrimitives.scala +++ b/src/main/scala/shine/OpenMP/DSL/ImperativePrimitives.scala @@ -19,7 +19,7 @@ object parFor { object `parForVec` { def apply(n: Nat, - st: ScalarType, + st: DataType, out: Phrase[AccType], f: Phrase[ExpType] => Phrase[AccType] => Phrase[CommType]): ForVec = ForVec(n, st, out, λ(expT(idx(n), read))(i => λ(accT(st))(o => f(i)(o) ))) diff --git a/src/main/scala/shine/OpenMP/primitives/functional/DepMapPar.scala b/src/main/scala/shine/OpenMP/primitives/functional/DepMapPar.scala index e6f2afea6..9035e541e 100644 --- a/src/main/scala/shine/OpenMP/primitives/functional/DepMapPar.scala +++ b/src/main/scala/shine/OpenMP/primitives/functional/DepMapPar.scala @@ -1,19 +1,20 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenMP.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class DepMapPar(n: Nat, - ft1: NatToData, - ft2: NatToData, - f: Phrase[`(nat)->:`[ExpType ->: ExpType]], - array: Phrase[ExpType] - ) extends ExpPrimitive { - f :: f.t.x ->: expT(ft1(f.t.x), read) ->: expT(ft2(f.t.x), write) - array :: expT(n `.d` ft1, read) - override val t: ExpType = expT(n`.d`ft2, write) +final case class DepMapPar(val n: Nat, val ft1: NatToData, val ft2: NatToData, val f: Phrase[DepFunType[NatKind, FunType[ExpType, ExpType]]], val array: Phrase[ExpType]) extends ExpPrimitive { + { + f :: ({ + val m = f.t.x + DepFunType[NatKind, PhraseType](m, FunType(expT(NatToDataApply(ft1, m), read), expT(NatToDataApply(ft2, m), write))) + }) + array :: expT(DepArrayType(n, ft1), read) + } + override val t: ExpType = expT(DepArrayType(n, ft2), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): DepMapPar = new DepMapPar(v.nat(n), v.natToData(ft1), v.natToData(ft2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/OpenMP/primitives/functional/MapPar.scala b/src/main/scala/shine/OpenMP/primitives/functional/MapPar.scala index c84edc469..1b622ffd4 100644 --- a/src/main/scala/shine/OpenMP/primitives/functional/MapPar.scala +++ b/src/main/scala/shine/OpenMP/primitives/functional/MapPar.scala @@ -1,19 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenMP.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class MapPar(n: Nat, - dt1: DataType, - dt2: DataType, - f: Phrase[ExpType ->: ExpType], - array: Phrase[ExpType] - ) extends ExpPrimitive { - f :: expT(dt1, read) ->: expT(dt2, write) - array :: expT(n`.`dt1, read) - override val t: ExpType = expT(n`.`dt2, write) +final case class MapPar(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive { + { + f :: FunType(expT(dt1, read), expT(dt2, write)) + array :: expT(ArrayType(n, dt1), read) + } + override val t: ExpType = expT(ArrayType(n, dt2), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapPar = new MapPar(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/OpenMP/primitives/functional/ReducePar.scala b/src/main/scala/shine/OpenMP/primitives/functional/ReducePar.scala index 37f148228..9cc4bf236 100644 --- a/src/main/scala/shine/OpenMP/primitives/functional/ReducePar.scala +++ b/src/main/scala/shine/OpenMP/primitives/functional/ReducePar.scala @@ -1,20 +1,18 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenMP.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class ReducePar(n: Nat, - dt1: DataType, dt2: DataType, - f: Phrase[ExpType ->: ExpType ->: ExpType], - init: Phrase[ExpType], - array: Phrase[ExpType] - ) extends ExpPrimitive { - f :: expT(dt2, read) ->: expT(dt1, read) ->: expT(dt2, write) - init :: expT(dt2, write) - array :: expT(n`.`dt1, read) +final case class ReducePar(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, FunType[ExpType, ExpType]]], val init: Phrase[ExpType], val array: Phrase[ExpType]) extends ExpPrimitive { + { + f :: FunType(expT(dt2, read), FunType(expT(dt1, read), expT(dt2, write))) + init :: expT(dt2, write) + array :: expT(ArrayType(n, dt1), read) + } override val t: ExpType = expT(dt2, read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ReducePar = new ReducePar(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(init, v), VisitAndRebuild(array, v)) } diff --git a/src/main/scala/shine/OpenMP/primitives/functional/primitives.dpia b/src/main/scala/shine/OpenMP/primitives/functional/primitives.dpia new file mode 100644 index 000000000..ffd7dde02 --- /dev/null +++ b/src/main/scala/shine/OpenMP/primitives/functional/primitives.dpia @@ -0,0 +1,12 @@ +def depMapPar(n: nat, ft1: nat2data, ft2: nat2data, + f: (m: nat) -> exp[ft1(m), read] -> exp[ft2(m), write], + array: exp[n..ft1, read]): exp[n..ft2, write] + +def mapPar(n: nat, dt1: data, dt2: data, + f: exp[dt1, read] -> exp[dt2, write], + array: exp[n.dt1, read]): exp[n.dt2, write] + +def reducePar(n: nat, dt1: data, dt2: data, + f: exp[dt2, read] -> exp[dt1, read] -> exp[dt2, write], + init: exp[dt2, write], + array: exp[n.dt1, read]): exp[dt2, read] \ No newline at end of file diff --git a/src/main/scala/shine/OpenMP/primitives/imperative/ParFor.scala b/src/main/scala/shine/OpenMP/primitives/imperative/ParFor.scala index 3c1a0006b..d467a4c29 100644 --- a/src/main/scala/shine/OpenMP/primitives/imperative/ParFor.scala +++ b/src/main/scala/shine/OpenMP/primitives/imperative/ParFor.scala @@ -1,22 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenMP.primitives.imperative - -import shine.DPIA.Phrases.{CommandPrimitive, Phrase, _} +import arithexpr.arithmetic._ +import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.comPrimitive - -@comPrimitive -final case class ParFor(n: Nat, - dt: DataType, - out: Phrase[AccType], - body: Phrase[ExpType ->: AccType ->: CommType] - ) extends CommandPrimitive { - out :: accT(n`.`dt) - body :: expT(idx(n), read) ->: accT(dt) ->: comm - - lazy val unwrapBody: (Identifier[ExpType], Identifier[AccType], Phrase[CommType]) = body match { - case Lambda(i, Lambda(o, body)) => (i, o, body) - case _ => throw new Exception("This should not happen") +final case class ParFor(val n: Nat, val dt: DataType, val out: Phrase[AccType], val body: Phrase[FunType[ExpType, FunType[AccType, CommType]]]) extends CommandPrimitive { + { + out :: accT(ArrayType(n, dt)) + body :: FunType(expT(IndexType(n), read), FunType(accT(dt), comm)) } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ParFor = new ParFor(v.nat(n), v.data(dt), VisitAndRebuild(out, v), VisitAndRebuild(body, v)) } diff --git a/src/main/scala/shine/OpenMP/primitives/imperative/ParForNat.scala b/src/main/scala/shine/OpenMP/primitives/imperative/ParForNat.scala index fc6eed716..92b4f87b3 100644 --- a/src/main/scala/shine/OpenMP/primitives/imperative/ParForNat.scala +++ b/src/main/scala/shine/OpenMP/primitives/imperative/ParForNat.scala @@ -1,22 +1,20 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.OpenMP.primitives.imperative - -import shine.DPIA.Phrases.{CommandPrimitive, Phrase, _} +import arithexpr.arithmetic._ +import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.comPrimitive - -@comPrimitive -final case class ParForNat(n: Nat, - ft: NatToData, - out: Phrase[AccType], - body: Phrase[`(nat)->:`[AccType ->: CommType]] - ) extends CommandPrimitive { - out :: accT(n`.d`ft) - body :: body.t.x ->: accT(ft(body.t.x)) ->: comm - - lazy val unwrapBody: (NatIdentifier, Identifier[AccType], Phrase[CommType]) = body match { - case DepLambda(n, Lambda(o, body)) => (n, o, body) - case _ => throw new Exception("This should not happen") +final case class ParForNat(val n: Nat, val ft: NatToData, val out: Phrase[AccType], val body: Phrase[DepFunType[NatKind, FunType[AccType, CommType]]]) extends CommandPrimitive { + { + out :: accT(DepArrayType(n, ft)) + body :: ({ + val i = body.t.x + DepFunType[NatKind, PhraseType](i, FunType(accT(NatToDataApply(ft, i)), comm)) + }) } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ParForNat = new ParForNat(v.nat(n), v.natToData(ft), VisitAndRebuild(out, v), VisitAndRebuild(body, v)) } diff --git a/src/main/scala/shine/OpenMP/primitives/imperative/primitives.dpia b/src/main/scala/shine/OpenMP/primitives/imperative/primitives.dpia new file mode 100644 index 000000000..4ceca55c9 --- /dev/null +++ b/src/main/scala/shine/OpenMP/primitives/imperative/primitives.dpia @@ -0,0 +1,7 @@ +def parFor(n: nat, dt: data, + out: acc[n.dt], + body: exp[idx[n], read] -> acc[dt] -> comm): comm + +def parForNat(n: nat, ft: nat2data, + out: acc[n..ft], + body: (i: nat) -> acc[ft(i)] -> comm): comm \ No newline at end of file diff --git a/src/main/scala/shine/cuda/AST/Types.scala b/src/main/scala/shine/cuda/AST/Types.scala index eb683847b..34c41d417 100644 --- a/src/main/scala/shine/cuda/AST/Types.scala +++ b/src/main/scala/shine/cuda/AST/Types.scala @@ -10,8 +10,8 @@ object Wmma { case MatrixLayout.Row_Major => "nvcuda::wmma::row_major" case MatrixLayout.Col_Major => "nvcuda::wmma::col_major" case i: MatrixLayoutIdentifier => - if (i.layout.isDefined) - toString(i.layout.get) + if (i.layout != MatrixLayout.None) + toString(i.layout) else throw new Exception(s"layout $i not infered!") case _ => throw new Exception("this should not happen") diff --git a/src/main/scala/shine/cuda/Compilation/KernelCodeGenerator.scala b/src/main/scala/shine/cuda/Compilation/KernelCodeGenerator.scala index 28b37898d..9f9b76eff 100644 --- a/src/main/scala/shine/cuda/Compilation/KernelCodeGenerator.scala +++ b/src/main/scala/shine/cuda/Compilation/KernelCodeGenerator.scala @@ -48,8 +48,11 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, override def cmd(env: Environment): Phrase[CommType] => Stmt = { case f: ParFor => - val (i, o, p) = f.unwrapBody - codeGenCudaParFor(f, f.n, f.dt, f.out, i, o, p, env) + f.body match { + case Lambda(i, Lambda(o, p)) => + codeGenCudaParFor(f, f.n, f.dt, f.out, i, o, p, env) + case _ => throw new Exception("This should not happen") + } case WmmaLoad(m, n, k, _, fragType, layoutIdentifier, matrix, fragmentAcc) => matrix |> exp(env, List(CIntExpr(0), CIntExpr(0)), matrixTile => { @@ -92,7 +95,7 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, C.AST.ArithmeticExpr(ldm), C.AST.DeclRef(toString(layout)))))})) - case WmmaFill(_, _, _, _, fill, _, _, fragmentAcc) => + case WmmaFill(_, _, _, _, _, _, fill, fragmentAcc) => fill |> exp(env, Nil, fill => fragmentAcc |> acc(env, Nil, fragment => C.AST.ExprStmt(C.AST.FunCall( @@ -114,7 +117,7 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, bMatrix, cMatrix))))))) - case ForFragmentElements(fragType, inFragment, outFragmemt, Lambda(in, Lambda(out, p))) => + case ForFragment(_, _, _, dt, _, _, inFragment, outFragmemt, Lambda(in, Lambda(out, p))) => inFragment |> exp(env, Nil, fragmentIn => outFragmemt |> acc(env, Nil, fragmentOut => { val n = C.AST.StructMemberAccess(fragmentOut, C.AST.DeclRef("num_elements")) @@ -142,8 +145,8 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, C.AST.ForLoop(C.AST.DeclStmt(init), cond, increment, C.AST.Block( immutable.Seq( - C.AST.DeclStmt(C.AST.VarDecl(xIInPointer.name, PointerType(typ(fragType.dataType)), init = Some(xIInDecl))), - C.AST.DeclStmt(C.AST.VarDecl(xIOutPointer.name, PointerType(typ(fragType.dataType)), init = Some(xIOutDecl))), + C.AST.DeclStmt(C.AST.VarDecl(xIInPointer.name, PointerType(typ(dt)), init = Some(xIInDecl))), + C.AST.DeclStmt(C.AST.VarDecl(xIOutPointer.name, PointerType(typ(dt)), init = Some(xIOutDecl))), p |> updatedGen.cmd(env updatedIdentEnv (in -> xIIn) updatedIdentEnv (out -> xIOut))))))})) case Assign(_, lhsAcc, rhs) => @@ -230,18 +233,18 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, case _ => error(s"Expected path to be not empty") } - case AsScalarAcc(_, m, dt, a) => path match { + case AsScalarAcc(n, _, dt, a) => path match { case (i : CIntExpr) :: (j : CIntExpr) :: ps => - a |> acc(env, CIntExpr((i * m) + j) :: ps, cont) + a |> acc(env, CIntExpr((i * n) + j) :: ps, cont) case (i : CIntExpr) :: Nil => - a |> acc(env,CIntExpr(i * m) :: Nil, array => { + a |> acc(env,CIntExpr(i * n) :: Nil, array => { cont( //acces first (vector-)element pointed by the pointer C.AST.ArraySubscript( //cast pointer to array to pointer of vectorType C.AST.Cast( - C.AST.PointerType(getVectorType(dt, m)), + C.AST.PointerType(getVectorType(dt, n)), C.AST.UnaryExpr(C.AST.UnaryOperator.&, array)), C.AST.ArithmeticExpr(0)))}) @@ -254,7 +257,7 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, override def exp(env: Environment, path: Path, cont: Expr => Stmt): Phrase[ExpType] => Stmt = { - case phrase@AsVectorAligned(n, _, _, dt, e) => path match { + case phrase@AsVectorAligned(n, _, dt, _, e) => path match { case (i : CIntExpr) :: (j : CIntExpr) :: ps => e |> exp(env, CIntExpr((i * n) + j) :: ps, cont) @@ -288,14 +291,14 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, C.AST.FragmentType(m, n, k, typ(dataType), fragmentKind, layout) case shine.DPIA.Types.f16 => cuda.AST.Type.half - case shine.DPIA.Types.pipeline => + case shine.DPIA.Types.OpaqueType("pipeline") => cuda.AST.Type.pipeline case _ => super.typ(dt) } - private def getVectorType(dt: ScalarType, n: Nat): Type = { + private def getVectorType(dt: DataType, n: Nat): Type = { if (n.eval > 0 && n.eval <= 4) dt match { case shine.DPIA.Types.u8 => BasicType(s"uchar$n") @@ -306,13 +309,13 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, case shine.DPIA.Types.u32 => BasicType(s"uint$n") case shine.DPIA.Types.f32 => BasicType(s"float$n") case shine.DPIA.Types.f64 => BasicType(s"double$n") - case _ => ??? + case _ => throw new Exception(s"Can't create vector type from: ($dt, $n)") } else dt match { case shine.DPIA.Types.f16 if (n.eval > 0 && n.eval <= 8) => BasicType(s"float${n/2}") case shine.DPIA.Types.f16 if (n.eval > 0 && n.eval <= 16) => BasicType(s"double${n/4}") - case _ => ??? + case _ => throw new Exception(s"Can't create vector type from: ($dt, $n)") } } @@ -331,7 +334,7 @@ class KernelCodeGenerator(override val decls: CCodeGenerator.Declarations, p: Phrase[CommType], env: Environment): Stmt = { assert(!f.unroll) - val cI = C.AST.DeclRef(f.name) + val cI = C.AST.DeclRef(freshName(f.prefix)) val range = RangeAdd(f.init, n, f.step) val updatedGen = updatedRanges(cI.name, range) diff --git a/src/main/scala/shine/cuda/Compilation/TranslationContext.scala b/src/main/scala/shine/cuda/Compilation/TranslationContext.scala index 92ea837aa..41f8b7c5a 100644 --- a/src/main/scala/shine/cuda/Compilation/TranslationContext.scala +++ b/src/main/scala/shine/cuda/Compilation/TranslationContext.scala @@ -5,15 +5,15 @@ import shine.DPIA.Phrases.Phrase import shine.DPIA.Types.{AccType, CommType, DataType, ExpType, FragmentType, read} import shine.DPIA.primitives.imperative.Assign import shine.DPIA.{accT, expT} -import shine.cuda.primitives.imperative.ForFragmentElements +import shine.cuda.primitives.imperative.ForFragment class TranslationContext() extends shine.OpenCL.Compilation.TranslationContext { override def assign(dt: DataType, lhs: Phrase[AccType], rhs: Phrase[ExpType]): Phrase[CommType] = { dt match { - case f: FragmentType => - ForFragmentElements(f, rhs, lhs, + case FragmentType(rows, columns, layers, dt, frag, layout) => + ForFragment(rows, columns, layers, dt, frag, layout, rhs, lhs, λ(expT(dt, read))(x => λ(accT(dt))(o => Assign(dt, o, x)))) diff --git a/src/main/scala/shine/cuda/DSL/package.scala b/src/main/scala/shine/cuda/DSL/package.scala new file mode 100644 index 000000000..5223be1e1 --- /dev/null +++ b/src/main/scala/shine/cuda/DSL/package.scala @@ -0,0 +1,27 @@ +package shine.cuda + +import shine.DPIA.Nat +import shine.DPIA.Phrases._ +import shine.DPIA.Types._ +import shine.OpenCL._ +import shine.cuda.primitives.imperative.ParFor + +package object DSL { + def parFor(level: ParallelismLevel, + dim: Int, + unroll: Boolean + ): (Nat, DataType, Phrase[AccType], Phrase[FunType[ExpType, FunType[AccType, CommType]]]) => ParFor = + level match { + case Global => ParFor(level, dim, unroll, "gl_id_")( + globalId(dim), _, globalDim(dim), _, _, _) + case Local => ParFor(level, dim, unroll, "tid_")( + threadId(dim), _, blockDim(dim), _, _, _) + case WorkGroup => ParFor(level, dim, unroll, "block_id_")( + blockId(dim), _, gridDim(dim), _, _, _) + case Warp => ParFor(level, dim, unroll, "warp_id_")( + warpId(dim), _, warpDim(dim), _, _, _) + case Lane => ParFor(level, dim, unroll, "lane_id_")( + laneId(dim), _, warpSize, _, _, _) + case Sequential => throw new Exception("This should not happen") + } +} diff --git a/src/main/scala/shine/cuda/package.scala b/src/main/scala/shine/cuda/package.scala index 306b1d282..39cc51db1 100644 --- a/src/main/scala/shine/cuda/package.scala +++ b/src/main/scala/shine/cuda/package.scala @@ -1,7 +1,9 @@ package shine -import arithexpr.arithmetic.{ArithExpr, ContinuousRange, PosInf, Range, SimplifiedExpr} -import shine.OpenCL.BuiltInFunctionCall +import arithexpr.arithmetic.{ArithExpr, SimplifiedExpr} +import shine.DPIA.Nat +import shine.DPIA.Phrases.Phrase +import shine.DPIA.Types.{DataType, ExpType, FragmentKind, MatrixLayoutIdentifier} package object cuda { @@ -39,4 +41,11 @@ package object cuda { def apply(param: Int): ArithExpr with SimplifiedExpr = blockDim(param) * gridDim(param) } + + object AsFragment{ + def apply(rows: Nat, columns: Nat, layers: Nat, dataType: DataType, + frag: FragmentKind, matrix: Phrase[ExpType]): shine.cuda.primitives.functional.AsFragment = + shine.cuda.primitives.functional.AsFragment(rows, columns, layers, dataType, + frag, MatrixLayoutIdentifier("ml"), matrix) + } } diff --git a/src/main/scala/shine/cuda/primitives/functional/AsFragment.scala b/src/main/scala/shine/cuda/primitives/functional/AsFragment.scala index 20000579c..31d151cf2 100644 --- a/src/main/scala/shine/cuda/primitives/functional/AsFragment.scala +++ b/src/main/scala/shine/cuda/primitives/functional/AsFragment.scala @@ -1,36 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.cuda.primitives.functional - -import shine.DPIA.Nat +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ -import shine.macros.Primitive.expPrimitive - -object AsFragment{ - def apply(rows: Nat, columns: Nat, d3: Nat, dataType: DataType, fragmentType: FragmentKind, matrix: Phrase[ExpType]): - AsFragment = AsFragment(rows, columns, d3, dataType, fragmentType, matrix, MatrixLayoutIdentifier("ml")) -} - -/** - * Returns a fragment from a matrix tile which must resides in shared or global memory ({@link WmmaLoad}).
- * This primitive needs to be executed by a full warp! - * @param rows number of rows of the fragment ({@link FragmentType#rows}) - * @param columns number of columns of the fragment ({@link FragmentType#columns}) - * @param d3 third dimension which is used in the MMA operation ({@link FragmentType#d3}) - * @param dataType dataType of the elements in the fragment ({@link FragmentType#datatype}) - * @param fragmentKind kind of the fragment ({@link FragmentType#fragmentKind}) - * @param layout layout of the fragment ({@link FragmentType#layout}). - * The layout will be infered in Codegeneration. Hence a `MatrixLayoutIdentifier` can be - * used as layout. - */ -@expPrimitive -case class AsFragment(rows: Nat, - columns: Nat, - d3: Nat, - dataType: DataType, - fragmentKind: FragmentKind, - matrix: Phrase[ExpType], - layout: MatrixLayout) extends ExpPrimitive{ - - matrix :: ExpType(ArrayType(rows, ArrayType(columns, dataType)), write) - override val t: ExpType = ExpType(FragmentType(rows, columns, d3, dataType, fragmentKind, layout), write) +import shine.DPIA._ +final case class AsFragment(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val frag: FragmentKind, val layout: MatrixLayout, val input: Phrase[ExpType]) extends ExpPrimitive { + { + input :: expT(ArrayType(rows, ArrayType(columns, dt)), read) + } + override val t: ExpType = expT(FragmentType(rows, columns, layers, dt, frag, layout), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsFragment = new AsFragment(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), frag, layout, VisitAndRebuild(input, v)) } diff --git a/src/main/scala/shine/cuda/primitives/functional/AsMatrix.scala b/src/main/scala/shine/cuda/primitives/functional/AsMatrix.scala index fbf2c3c96..c3ea35b06 100644 --- a/src/main/scala/shine/cuda/primitives/functional/AsMatrix.scala +++ b/src/main/scala/shine/cuda/primitives/functional/AsMatrix.scala @@ -1,27 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.cuda.primitives.functional - -import shine.DPIA.Nat +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ -import shine.macros.Primitive.expPrimitive - -/** - * Returns a matrix tile with the elements from the `Accumulator`-fragment. The matrix tile must resides - * in shared or global memory ({@link WmmaStore}).
- * This primitive needs to be executed by a full warp! - * @param rows number of rows of the fragment ({@link FragmentType#rows}) - * @param columns number of columns of the fragment ({@link FragmentType#columns}) - * @param d3 third dimension which is used in the MMA operation ({@link FragmentType#d3}) - * @param dataType dataType of the elements in the fragment ({@link FragmentType#datatype}) - * @param fragment fragment from which the elements should be stored - */ -@expPrimitive -case class AsMatrix(rows: Nat, - columns: Nat, - d3: Nat, - dataType: DataType, - fragment: Phrase[ExpType]) extends ExpPrimitive { - - fragment :: ExpType(FragmentType(rows, columns, d3, dataType), read) - override val t: ExpType = ExpType(ArrayType(rows, ArrayType(columns, dataType)), write) +import shine.DPIA._ +final case class AsMatrix(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive { + { + input :: expT(FragmentType(rows, columns, layers, dt, FragmentKind.Accumulator, MatrixLayout.None), read) + } + override val t: ExpType = expT(ArrayType(rows, ArrayType(columns, dt)), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): AsMatrix = new AsMatrix(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), VisitAndRebuild(input, v)) } diff --git a/src/main/scala/shine/cuda/primitives/functional/GenerateFragment.scala b/src/main/scala/shine/cuda/primitives/functional/GenerateFragment.scala index 7bb57180f..fa0144660 100644 --- a/src/main/scala/shine/cuda/primitives/functional/GenerateFragment.scala +++ b/src/main/scala/shine/cuda/primitives/functional/GenerateFragment.scala @@ -1,31 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.cuda.primitives.functional - -import shine.DPIA.Nat +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ -import shine.macros.Primitive.expPrimitive - -/** - * Returns a fragment in which all elements have a specific value ({@link WmmaFill}).
- * This primitive needs to be executed by a full warp! - * @param rows number of rows of the fragment ({@link FragmentType#rows}) - * @param columns number of columns of the fragment ({@link FragmentType#columns}) - * @param d3 third dimension which is used in the MMA operation ({@link FragmentType#d3}) - * @param dataType dataType of the elements in the fragment ({@link FragmentType#datatype}) - * @param fill new value of all elements in the fragment (type of fill: `dataType`) - * @param fragmentKind kind of the fragment ({@link FragmentType#fragmentKind}) - * @param layout layout of the fragment ({@link FragmentType#layout}) - */ -@expPrimitive -case class GenerateFragment(rows: Nat, - columns: Nat, - d3: Nat, - dataType: DataType, - fill: Phrase[ExpType], - fragmentKind: FragmentKind, - layout: MatrixLayout) extends ExpPrimitive { - - fill :: ExpType(dataType, read) - - override val t: ExpType = ExpType(FragmentType(rows, columns, d3, dataType), write) +import shine.DPIA._ +final case class GenerateFragment(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val frag: FragmentKind, val layout: MatrixLayout, val fill: Phrase[ExpType]) extends ExpPrimitive { + { + fill :: expT(dt, read) + } + override val t: ExpType = expT(FragmentType(rows, columns, layers, dt, frag, layout), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): GenerateFragment = new GenerateFragment(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), frag, layout, VisitAndRebuild(fill, v)) } diff --git a/src/main/scala/shine/cuda/primitives/functional/GlobalToShared.scala b/src/main/scala/shine/cuda/primitives/functional/GlobalToShared.scala index caecc088b..cc4f06188 100644 --- a/src/main/scala/shine/cuda/primitives/functional/GlobalToShared.scala +++ b/src/main/scala/shine/cuda/primitives/functional/GlobalToShared.scala @@ -1,21 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.cuda.primitives.functional - -import shine.DPIA.DSL.{`new` => _} +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.expPrimitive - -/** - * Returns a copy in shared memory of data in global memory ({@link GlobalToSharedAcc}). - * @param dt datatype of data which should be copied - * @param inputGlobal data in global memory which should be copied to shared memory - */ -@expPrimitive -final case class GlobalToShared(dt: DataType, - inputGlobal: Phrase[ExpType]) extends ExpPrimitive { - - inputGlobal :: expT(dt, write) +final case class GlobalToShared(val dt: DataType, val input: Phrase[ExpType]) extends ExpPrimitive { + { + input :: expT(dt, write) + } override val t: ExpType = expT(dt, read) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): GlobalToShared = new GlobalToShared(v.data(dt), VisitAndRebuild(input, v)) } - diff --git a/src/main/scala/shine/cuda/primitives/functional/Map.scala b/src/main/scala/shine/cuda/primitives/functional/Map.scala index 8805e60a7..7695393cd 100644 --- a/src/main/scala/shine/cuda/primitives/functional/Map.scala +++ b/src/main/scala/shine/cuda/primitives/functional/Map.scala @@ -1,25 +1,18 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.cuda.primitives.functional - +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.OpenCL.ParallelismLevel -import shine.macros.Primitive.expPrimitive - -@expPrimitive -final case class Map(level: ParallelismLevel, - dim: Int) - (val n: Nat, - val dt1: DataType, - val dt2: DataType, - val f: Phrase[ExpType ->: ExpType], - val array: Phrase[ExpType] - ) extends ExpPrimitive { - f :: expT(dt1, read) ->: expT(dt2, write) - array :: expT(n `.` dt1, read) - override val t: ExpType = expT(n `.` dt2, write) - - def unwrap: (Nat, DataType, DataType, Phrase[ExpType ->: ExpType], Phrase[ExpType]) = - (n, dt1, dt2, f, array) +final case class Map(level: shine.OpenCL.ParallelismLevel, dim: Int)(val n: Nat, val dt1: DataType, val dt2: DataType, val f: Phrase[FunType[ExpType, ExpType]], val array: Phrase[ExpType]) extends ExpPrimitive { + { + f :: FunType(expT(dt1, read), expT(dt2, write)) + array :: expT(ArrayType(n, dt1), read) + } + override val t: ExpType = expT(ArrayType(n, dt2), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): Map = new Map(level, dim)(v.nat(n), v.data(dt1), v.data(dt2), VisitAndRebuild(f, v), VisitAndRebuild(array, v)) + def unwrap: (Nat, DataType, DataType, Phrase[FunType[ExpType, ExpType]], Phrase[ExpType]) = (n, dt1, dt2, f, array) } diff --git a/src/main/scala/shine/cuda/primitives/functional/MapFragment.scala b/src/main/scala/shine/cuda/primitives/functional/MapFragment.scala new file mode 100644 index 000000000..9757cf271 --- /dev/null +++ b/src/main/scala/shine/cuda/primitives/functional/MapFragment.scala @@ -0,0 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package shine.cuda.primitives.functional +import arithexpr.arithmetic._ +import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ +import shine.DPIA.Types._ +import shine.DPIA._ +final case class MapFragment(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val frag: FragmentKind, val layout: MatrixLayout, val f: Phrase[FunType[ExpType, ExpType]], val input: Phrase[ExpType]) extends ExpPrimitive { + { + f :: FunType(expT(dt, read), expT(dt, write)) + input :: expT(FragmentType(rows, columns, layers, dt, frag, layout), read) + } + override val t: ExpType = expT(FragmentType(rows, columns, layers, dt, frag, layout), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): MapFragment = new MapFragment(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), frag, layout, VisitAndRebuild(f, v), VisitAndRebuild(input, v)) +} diff --git a/src/main/scala/shine/cuda/primitives/functional/MapFragmentElements.scala b/src/main/scala/shine/cuda/primitives/functional/MapFragmentElements.scala deleted file mode 100644 index 0ce195bec..000000000 --- a/src/main/scala/shine/cuda/primitives/functional/MapFragmentElements.scala +++ /dev/null @@ -1,26 +0,0 @@ -package shine.cuda.primitives.functional - -import shine.DPIA.->: -import shine.DPIA.Phrases._ -import shine.DPIA.Types._ -import shine.macros.Primitive.expPrimitive - -/** - * Returns a fragment with the values of an applied function to elements of a fragment.
- * This primitive needs to be executed by a full warp! - * @param fragType type of the fragment - * @param fragment fragment of type `fragType` on whose elements the function should be applied - * @param fun function which takes an element of type `fragType.dataType` and - * returns an element of type `fragType.dataType` - */ -@expPrimitive -case class MapFragmentElements(fragType: FragmentType, - fragment: Phrase[ExpType], - fun: Phrase[ExpType ->: ExpType], - ) extends ExpPrimitive { - - fragment :: ExpType(fragType, read) - fun :: ExpType(fragType.dataType, read) ->: ExpType(fragType.dataType, write) - - override val t: ExpType = ExpType(fragType, write) -} diff --git a/src/main/scala/shine/cuda/primitives/functional/TensorMatMultAdd.scala b/src/main/scala/shine/cuda/primitives/functional/TensorMatMultAdd.scala index 88a61459e..c35cf09a7 100644 --- a/src/main/scala/shine/cuda/primitives/functional/TensorMatMultAdd.scala +++ b/src/main/scala/shine/cuda/primitives/functional/TensorMatMultAdd.scala @@ -1,41 +1,18 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.cuda.primitives.functional - -import shine.DPIA.Nat +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ -import shine.macros.Primitive.expPrimitive - -/** - * Executes an MMA instruction using (multiple) Tensor Cores ({@link WmmaMMA}).
- * Returns a `Accumulator`-fragment as result of: aMatrix * bMatrix + cMatrix - * (inplace operations using the same variable as `Acceptor` in `acceptorTranslation` and - * as `cMatrix` are possible).
- * This primitive needs to be executed by a full warp! - * @param m number of rows of the `aMatrix` - * @param n number of columns of the `bMatrix` and the `cMatrix` - * @param k number of columns of the `aMatrix` and number of rows of the `bMatrix` - * @param layoutA layout of the `aMatrix` - * @param layoutB layout of the `bMatrix` - * @param dataType datatype of elements of `aMatrix` and `bMatrix` ({@link FragmentType#datatype}) - * @param dataTypeAcc datatype of elements of `cMatrix` and the resultMatrix ({@link FragmentType#datatype}) - * @param aMatrix first factor of type fragment - * @param bMatrix second factor of type fragment - * @param cMatrix accumulator of type fragment which is added to the product of `aMatrix` * `bMatrix` - */ -@expPrimitive -final case class TensorMatMultAdd(m: Nat, - n: Nat, - k: Nat, - layoutA: MatrixLayout, - layoutB: MatrixLayout, - dataType: DataType, - dataTypeAcc: DataType, - aMatrix: Phrase[ExpType], - bMatrix: Phrase[ExpType], - cMatrix: Phrase[ExpType]) extends ExpPrimitive { - aMatrix :: ExpType(FragmentType(m, k, n, dataType, FragmentKind.AMatrix, layoutA), read) - bMatrix :: ExpType(FragmentType(k, n, m, dataType, FragmentKind.BMatrix, layoutB), read) - cMatrix :: ExpType(FragmentType(m, n, k, dataTypeAcc), read) - - override val t: ExpType = ExpType(FragmentType(m, n, k, dataTypeAcc), write) +import shine.DPIA._ +final case class TensorMatMultAdd(val m: Nat, val n: Nat, val k: Nat, val layoutA: MatrixLayout, val layoutB: MatrixLayout, val dt1: DataType, val dt2: DataType, val aMatrix: Phrase[ExpType], val bMatrix: Phrase[ExpType], val cMatrix: Phrase[ExpType]) extends ExpPrimitive { + { + aMatrix :: expT(FragmentType(m, k, n, dt1, FragmentKind.AMatrix, layoutA), read) + bMatrix :: expT(FragmentType(k, n, m, dt1, FragmentKind.BMatrix, layoutB), read) + cMatrix :: expT(FragmentType(m, n, k, dt2, FragmentKind.Accumulator, MatrixLayout.None), read) + } + override val t: ExpType = expT(FragmentType(m, n, k, dt2, FragmentKind.Accumulator, MatrixLayout.None), write) + override def visitAndRebuild(v: VisitAndRebuild.Visitor): TensorMatMultAdd = new TensorMatMultAdd(v.nat(m), v.nat(n), v.nat(k), layoutA, layoutB, v.data(dt1), v.data(dt2), VisitAndRebuild(aMatrix, v), VisitAndRebuild(bMatrix, v), VisitAndRebuild(cMatrix, v)) } diff --git a/src/main/scala/shine/cuda/primitives/functional/primitives.dpia b/src/main/scala/shine/cuda/primitives/functional/primitives.dpia new file mode 100644 index 000000000..c89d64a7c --- /dev/null +++ b/src/main/scala/shine/cuda/primitives/functional/primitives.dpia @@ -0,0 +1,96 @@ +/** + * Returns a fragment from a matrix tile which must resides in shared or global memory ({@link WmmaLoad}).
+ * This primitive needs to be executed by a full warp! + * @param rows number of rows of the fragment ({@link FragmentType#rows}) + * @param columns number of columns of the fragment ({@link FragmentType#columns}) + * @param layers third dimension which is used in the MMA operation ({@link FragmentType#d3}) + * @param dt dataType of the elements in the fragment ({@link FragmentType#datatype}) + * @param frag kind of the fragment ({@link FragmentType#fragmentKind}) + * @param layout layout of the fragment ({@link FragmentType#layout}). + * The layout will be infered in Codegeneration. Hence a `MatrixLayoutIdentifier` can be + * used as layout. + */ +def asFragment(rows: nat, columns: nat, layers: nat, dt: data, + frag: fragment, layout: matrixLayout, + input: exp[rows.columns.dt, read]): exp[fragment[rows, columns, layers, dt, frag, layout], write] + +/** + * Returns a matrix tile with the elements from the `Accumulator`-fragment. The matrix tile must resides + * in shared or global memory ({@link WmmaStore}). + * This primitive needs to be executed by a full warp! + * @param rows number of rows of the fragment ({@link FragmentType#rows}) + * @param columns number of columns of the fragment ({@link FragmentType#columns}) + * @param layers third dimension which is used in the MMA operation ({@link FragmentType#d3}) + * @param dt dataType of the elements in the fragment ({@link FragmentType#datatype}) + * @param input fragment from which the elements should be stored + */ +def asMatrix(rows: nat, columns: nat, layers: nat, dt: data, + input: exp[fragment[rows, columns, layers, dt, fragment.ACC, matrixLayout.NONE], read] + ): exp[rows.columns.dt, write] + +/** + * Returns a fragment in which all elements have a specific value ({@link WmmaFill}).
+ * This primitive needs to be executed by a full warp! + * @param rows number of rows of the fragment ({@link FragmentType#rows}) + * @param columns number of columns of the fragment ({@link FragmentType#columns}) + * @param layers third dimension which is used in the MMA operation ({@link FragmentType#d3}) + * @param dt dataType of the elements in the fragment ({@link FragmentType#datatype}) + * @param frag kind of the fragment ({@link FragmentType#fragmentKind}) + * @param layout layout of the fragment ({@link FragmentType#layout}) + * @param fill new value of all elements in the fragment (type of fill: `dataType`) + */ +def generateFragment(rows: nat, columns: nat, layers: nat, dt: data, + frag: fragment, layout: matrixLayout, + fill: exp[dt, read] + ): exp[fragment[rows, columns, layers, dt, frag, layout], write] + +/** + * Returns a copy in shared memory of data in global memory ({@link GlobalToSharedAcc}). + * @param dt datatype of data which should be copied + * @param input data in global memory which should be copied to shared memory + */ +def globalToShared(dt: data, input: exp[dt, write]): exp[dt, read] + +def map{level: shine.OpenCL.ParallelismLevel, dim: Int} + (n: nat, dt1: data, dt2: data, + f: exp[dt1, read] -> exp[dt2, write], + array: exp[n.dt1, read]): exp[n.dt2, write] + +/** + * Returns a fragment with the values of an applied function to elements of a fragment.
+ * This primitive needs to be executed by a full warp! + * @param fragType type of the fragment + * @param fragment fragment of type `fragType` on whose elements the function should be applied + * @param fun function which takes an element of type `fragType.dataType` and + * returns an element of type `fragType.dataType` + */ +def mapFragment(rows: nat, columns: nat, layers: nat, dt: data, + frag: fragment, layout: matrixLayout, + f: exp[dt, read] -> exp[dt, write], + input: exp[fragment[rows, columns, layers, dt, frag, layout], read] + ): exp[fragment[rows, columns, layers, dt, frag, layout], write] + +/** + * Executes an MMA instruction using (multiple) Tensor Cores ({@link WmmaMMA}).
+ * Returns a `Accumulator`-fragment as result of: aMatrix * bMatrix + cMatrix + * (inplace operations using the same variable as `Acceptor` in `acceptorTranslation` and + * as `cMatrix` are possible).
+ * This primitive needs to be executed by a full warp! + * @param m number of rows of the `aMatrix` + * @param n number of columns of the `bMatrix` and the `cMatrix` + * @param k number of columns of the `aMatrix` and number of rows of the `bMatrix` + * @param layoutA layout of the `aMatrix` + * @param layoutB layout of the `bMatrix` + * @param dataType datatype of elements of `aMatrix` and `bMatrix` ({@link FragmentType#datatype}) + * @param dataTypeAcc datatype of elements of `cMatrix` and the resultMatrix ({@link FragmentType#datatype}) + * @param aMatrix first factor of type fragment + * @param bMatrix second factor of type fragment + * @param cMatrix accumulator of type fragment which is added to the product of `aMatrix` * `bMatrix` + */ +def tensorMatMultAdd(m: nat, n: nat, k: nat, + layoutA: matrixLayout, layoutB: matrixLayout, + dt1: data, dt2: data, + aMatrix: exp[fragment[m, k, n, dt1, fragment.A, layoutA], read], + bMatrix: exp[fragment[k, n, m, dt1, fragment.B, layoutB], read], + cMatrix: exp[fragment[m, n, k, dt2, fragment.ACC, matrixLayout.NONE], read] + ): exp[fragment[m, n, k, dt2, fragment.ACC, matrixLayout.NONE], write] diff --git a/src/main/scala/shine/cuda/primitives/imperative/ForFragment.scala b/src/main/scala/shine/cuda/primitives/imperative/ForFragment.scala new file mode 100644 index 000000000..283476433 --- /dev/null +++ b/src/main/scala/shine/cuda/primitives/imperative/ForFragment.scala @@ -0,0 +1,18 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +package shine.cuda.primitives.imperative +import arithexpr.arithmetic._ +import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ +import shine.DPIA.Types._ +import shine.DPIA._ +final case class ForFragment(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val frag: FragmentKind, val layout: MatrixLayout, val in: Phrase[ExpType], val out: Phrase[AccType], val fun: Phrase[FunType[ExpType, FunType[AccType, CommType]]]) extends CommandPrimitive { + { + in :: expT(FragmentType(rows, columns, layers, dt, frag, layout), read) + out :: accT(FragmentType(rows, columns, layers, dt, frag, layout)) + fun :: FunType(expT(dt, read), FunType(accT(dt), comm)) + } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ForFragment = new ForFragment(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), frag, layout, VisitAndRebuild(in, v), VisitAndRebuild(out, v), VisitAndRebuild(fun, v)) +} diff --git a/src/main/scala/shine/cuda/primitives/imperative/ForFragmentElements.scala b/src/main/scala/shine/cuda/primitives/imperative/ForFragmentElements.scala deleted file mode 100644 index 5a6a1cfbf..000000000 --- a/src/main/scala/shine/cuda/primitives/imperative/ForFragmentElements.scala +++ /dev/null @@ -1,28 +0,0 @@ -package shine.cuda.primitives.imperative - -import shine.DPIA.->: -import shine.DPIA.Phrases._ -import shine.DPIA.Types._ -import shine.macros.Primitive.comPrimitive - -/** - * Applies a function on every element of a fragment.
- * This primitive needs to be executed by a full warp! - * @param fragType type of the fragment - * @param in fragment of type `fragType` whose elements should be iterated - * @param out fragment-Acceptor of type `fragType` which is used to store the result - * @param fun function which takes an element of type `fragType.dataType` from `in` and - * an element-Acceptor of type `fragType.dataType` from `out` and returns a command - */ -@comPrimitive -final case class ForFragmentElements(fragType: FragmentType, - in: Phrase[ExpType], - out: Phrase[AccType], - fun: Phrase[ExpType ->: AccType ->: CommType], - ) extends CommandPrimitive { - in :: ExpType(fragType, read) - out :: AccType(fragType) - fun :: FunType(ExpType(fragType.dataType, read), FunType(AccType(fragType.dataType), comm)) - - override def prettyPrint: String = s"ForFragmentElements(${PrettyPhrasePrinter(in)}, ${PrettyPhrasePrinter(out)}, ${PrettyPhrasePrinter(fun)})" -} diff --git a/src/main/scala/shine/cuda/primitives/imperative/GlobalToSharedAcc.scala b/src/main/scala/shine/cuda/primitives/imperative/GlobalToSharedAcc.scala index a0ba8cd68..dba030d80 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/GlobalToSharedAcc.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/GlobalToSharedAcc.scala @@ -1,30 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.cuda.primitives.imperative - -import shine.DPIA.DSL.{`new` => _} +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.macros.Primitive.accPrimitive - -/** - * Copy a element from global memory to shared memory without using registers - * (faster than normal copy operations using the =-operator).
- * This requires CUDA 11 and compute capability >= 8. For devices with ompute capability - * smaller than 8 this will be compiled by the CUDA-Compiler to the same as normal copy - * operations using the =-operator). - * @param dt datatype of element which should be copied - * @param pipe pipeline which should be used to execute this copy instruction - * @param outputShared output-Acceptor in shared memory of type `dt` - */ -@accPrimitive -final case class GlobalToSharedAcc(dt: DataType, - pipe: Phrase[ExpType], - outputShared: Phrase[AccType] - ) extends AccPrimitive { - pipe :: expT(pipeline, read) - outputShared :: accT(dt) +final case class GlobalToSharedAcc(val dt: DataType, val pipe: Phrase[ExpType], val outputShared: Phrase[AccType]) extends AccPrimitive { + { + pipe :: expT(OpaqueType("pipeline"), read) + outputShared :: accT(dt) + } override val t: AccType = accT(dt) - - override def prettyPrint: String = - s"(GlobalToSharedAcc $pipe, ${PrettyPhrasePrinter(outputShared)})" + override def visitAndRebuild(v: VisitAndRebuild.Visitor): GlobalToSharedAcc = new GlobalToSharedAcc(v.data(dt), VisitAndRebuild(pipe, v), VisitAndRebuild(outputShared, v)) } diff --git a/src/main/scala/shine/cuda/primitives/imperative/ParFor.scala b/src/main/scala/shine/cuda/primitives/imperative/ParFor.scala index 8d05be355..8a471f68d 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/ParFor.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/ParFor.scala @@ -1,60 +1,18 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.cuda.primitives.imperative - -import shine.DPIA.Phrases.{CommandPrimitive, _} +import arithexpr.arithmetic._ +import shine.DPIA.Phrases._ import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ import shine.DPIA._ -import shine.OpenCL._ -import shine.cuda.{blockDim, blockId, globalDim, globalId, gridDim, laneId, threadId, warpDim, warpId, warpSize} -import shine.macros.Primitive.comPrimitive - -@comPrimitive -final case class ParFor(level: ParallelismLevel, - dim: Int, - unroll: Boolean) - (val n: Nat, - val dt: DataType, - val out: Phrase[AccType], - val loopBody: Phrase[ExpType ->: AccType ->: CommType], - val init: Nat = ParFor.initInit(level, dim), - val step: Nat = ParFor.initStep(level, dim) - ) extends CommandPrimitive { - val name: String = ParFor.initName(level) - - out :: accT(n`.`dt) - loopBody :: expT(idx(n), read) ->: accT(dt) ->: comm - - lazy val unwrapBody: (Identifier[ExpType], Identifier[AccType], Phrase[CommType]) = loopBody match { - case Lambda(i, Lambda(o, body)) => (i, o, body) - case _ => throw new Exception("This should not happen") - } -} - -object ParFor { - def initInit(level: ParallelismLevel, dim: Int): Nat = level match { - case Global => globalId(dim) - case Local => threadId(dim) - case WorkGroup => blockId(dim) - case Warp => warpId(dim) - case Lane => laneId(dim) - case Sequential => throw new Exception("This should not happen") - } - - def initStep(level: ParallelismLevel, dim: Int): Nat = level match { - case Global => globalDim(dim) - case Local => blockDim(dim) - case WorkGroup => gridDim(dim) - case Warp => warpDim(dim) - case Lane => warpSize - case Sequential => throw new Exception("This should not happen") - } - - def initName(level: ParallelismLevel): String = level match { - case Global => freshName("gl_id_") - case Local => freshName("tid_") - case WorkGroup => freshName("block_id_") - case Warp => freshName("warp_id_") - case Lane => freshName(s"lane_id_") - case Sequential => throw new Exception("This should not happen") +final case class ParFor(level: shine.OpenCL.ParallelismLevel, dim: Int, unroll: Boolean, prefix: String)(val init: Nat, val n: Nat, val step: Nat, val dt: DataType, val out: Phrase[AccType], val body: Phrase[FunType[ExpType, FunType[AccType, CommType]]]) extends CommandPrimitive { + { + out :: accT(ArrayType(n, dt)) + body :: FunType(expT(IndexType(n), read), FunType(accT(dt), comm)) } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): ParFor = new ParFor(level, dim, unroll, prefix)(v.nat(init), v.nat(n), v.nat(step), v.data(dt), VisitAndRebuild(out, v), VisitAndRebuild(body, v)) + def unwrap: (Nat, Nat, Nat, DataType, Phrase[AccType], Phrase[FunType[ExpType, FunType[AccType, CommType]]]) = (init, n, step, dt, out, body) } diff --git a/src/main/scala/shine/cuda/primitives/imperative/SyncPipeline.scala b/src/main/scala/shine/cuda/primitives/imperative/SyncPipeline.scala index d22b59bf2..c1f06a010 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/SyncPipeline.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/SyncPipeline.scala @@ -1,16 +1,16 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.cuda.primitives.imperative - -import shine.DPIA.Phrases.{CommandPrimitive, Phrase} -import shine.DPIA.Types.{ExpType, pipeline, read} -import shine.DPIA.expT -import shine.macros.Primitive.comPrimitive - -/** - * Execute and wait for all asynchronous memory transactions (used by the {@link GlobalToSharedAcc}) - */ -@comPrimitive -final case class SyncPipeline(pipe: Phrase[ExpType]) extends CommandPrimitive { - pipe :: expT(pipeline, read) - - override def prettyPrint: String = s"$pipe.commit_and_wait()" +import arithexpr.arithmetic._ +import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ +import shine.DPIA.Types._ +import shine.DPIA._ +final case class SyncPipeline(val pipe: Phrase[ExpType]) extends CommandPrimitive { + { + pipe :: expT(OpaqueType("pipeline"), read) + } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): SyncPipeline = new SyncPipeline(VisitAndRebuild(pipe, v)) } diff --git a/src/main/scala/shine/cuda/primitives/imperative/SyncThreads.scala b/src/main/scala/shine/cuda/primitives/imperative/SyncThreads.scala index 523217d10..d7d7acafc 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/SyncThreads.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/SyncThreads.scala @@ -1,12 +1,14 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.cuda.primitives.imperative - -import shine.DPIA.Phrases.CommandPrimitive -import shine.macros.Primitive.comPrimitive - -/** - * Synchronize all thread in a single thread block. - */ -@comPrimitive +import arithexpr.arithmetic._ +import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ +import shine.DPIA.Types._ +import shine.DPIA._ final case class SyncThreads() extends CommandPrimitive { - override def prettyPrint: String = "__syncthreads()" + {} + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): SyncThreads = new SyncThreads() } diff --git a/src/main/scala/shine/cuda/primitives/imperative/SyncWarp.scala b/src/main/scala/shine/cuda/primitives/imperative/SyncWarp.scala index b058c1a8b..31718237a 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/SyncWarp.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/SyncWarp.scala @@ -1,12 +1,14 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.cuda.primitives.imperative - -import shine.DPIA.Phrases.CommandPrimitive -import shine.macros.Primitive.comPrimitive - -/** - * Synchronize all elements in a single warp. - */ -@comPrimitive +import arithexpr.arithmetic._ +import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ +import shine.DPIA.Types._ +import shine.DPIA._ final case class SyncWarp() extends CommandPrimitive { - override def prettyPrint: String = "__syncwarp()" + {} + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): SyncWarp = new SyncWarp() } diff --git a/src/main/scala/shine/cuda/primitives/imperative/WmmaFill.scala b/src/main/scala/shine/cuda/primitives/imperative/WmmaFill.scala index cfbea2664..cf250936a 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/WmmaFill.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/WmmaFill.scala @@ -1,35 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.cuda.primitives.imperative - -import shine.DPIA.Nat +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ -import shine.macros.Primitive.comPrimitive - -/** - * Fills a fragment with a specific value.
- * This primitive needs to be executed by a full warp! - * @param rows number of rows of the fragment ({@link FragmentType#rows}) - * @param columns number of columns of the fragment ({@link FragmentType#columns}) - * @param d3 third dimension which is used in the MMA operation ({@link FragmentType#d3}) - * @param dataType dataType of the elements in the fragment ({@link FragmentType#datatype}) - * @param fill new value of all elements in the fragment (type of fill: `dataType`) - * @param fragmentKind kind of the fragment ({@link FragmentType#fragmentKind}) - * @param layout layout of the fragment ({@link FragmentType#layout}) - * @param fragment fragment-Acceptor whose elements should be changed to `fill` - */ -@comPrimitive -case class WmmaFill(rows: Nat, - columns: Nat, - d3: Nat, - dataType: DataType, - fill: Phrase[ExpType], - fragmentKind: FragmentKind, - layout: MatrixLayout, - fragment: Phrase[AccType] - ) extends CommandPrimitive { - fill :: ExpType(dataType, read) - fragment :: AccType(FragmentType(rows, columns, d3, dataType, fragmentKind, layout)) - - override def prettyPrint: String = - s"WmmaFill(${PrettyPhrasePrinter(fill)}, ${PrettyPhrasePrinter(fragment)})" +import shine.DPIA._ +final case class WmmaFill(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val frag: FragmentKind, val layout: MatrixLayout, val fill: Phrase[ExpType], val target: Phrase[AccType]) extends CommandPrimitive { + { + fill :: expT(dt, read) + target :: accT(FragmentType(rows, columns, layers, dt, frag, layout)) + } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): WmmaFill = new WmmaFill(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), frag, layout, VisitAndRebuild(fill, v), VisitAndRebuild(target, v)) } diff --git a/src/main/scala/shine/cuda/primitives/imperative/WmmaLoad.scala b/src/main/scala/shine/cuda/primitives/imperative/WmmaLoad.scala index 00cb5479a..61b33bcba 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/WmmaLoad.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/WmmaLoad.scala @@ -1,36 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.cuda.primitives.imperative - -import shine.DPIA.Nat +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ -import shine.macros.Primitive.comPrimitive - -/** - * Loads a tile of a matrix into a fragment.
- * This primitive needs to be executed by a full warp! - * @param rows number of rows of the fragment ({@link FragmentType#rows}) - * @param columns number of columns of the fragment ({@link FragmentType#columns}) - * @param d3 third dimension which is used in the MMA operation ({@link FragmentType#d3}) - * @param dataType dataType of the elements in the fragment ({@link FragmentType#datatype}) - * @param fragmentKind kind of the fragment ({@link FragmentType#fragmentKind}) - * @param layout layout of the fragment ({@link FragmentType#layout}) - * @param matrixTile matrix tile which should be loaded into the fragment - * @param fragment fragment-Acceptor into which the `matrixTile` should be loaded - */ -@comPrimitive -final case class WmmaLoad(rows: Nat, - columns: Nat, - d3: Nat, - dataType: DataType, - fragmentKind: FragmentKind, - layout: MatrixLayout, - matrixTile: Phrase[ExpType], - fragment: Phrase[AccType] - ) extends CommandPrimitive { - fragment :: ExpType(FragmentType(rows, columns, d3, dataType, fragmentKind, layout), write) - matrixTile :: ExpType(ArrayType(rows, ArrayType(columns, dataType)), read) - - override def prettyPrint: String = { - s"wmmaLoad(${PrettyPhrasePrinter(matrixTile)}, ${PrettyPhrasePrinter(fragment)})" +import shine.DPIA._ +final case class WmmaLoad(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val frag: FragmentKind, val layout: MatrixLayout, val matrixTile: Phrase[ExpType], val target: Phrase[AccType]) extends CommandPrimitive { + { + matrixTile :: expT(ArrayType(rows, ArrayType(columns, dt)), read) + target :: accT(FragmentType(rows, columns, layers, dt, frag, layout)) } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): WmmaLoad = new WmmaLoad(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), frag, layout, VisitAndRebuild(matrixTile, v), VisitAndRebuild(target, v)) } diff --git a/src/main/scala/shine/cuda/primitives/imperative/WmmaMMA.scala b/src/main/scala/shine/cuda/primitives/imperative/WmmaMMA.scala index 4f4bb8c9d..793a562f5 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/WmmaMMA.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/WmmaMMA.scala @@ -1,46 +1,19 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.cuda.primitives.imperative - -import shine.DPIA.Nat +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ -import shine.macros.Primitive.comPrimitive - -/** - * Executes an MMA instruction using (multiple) Tensor Cores.
- * Calculates: aMatrix * bMatrix + cMatrix
- * This primitive needs to be executed by a full warp! - * @param m number of rows of the `aMatrix` - * @param n number of columns of the `bMatrix` and the `cMatrix` - * @param k number of columns of the `aMatrix` and number of rows of the `bMatrix` - * @param layoutA layout of the `aMatrix` - * @param layoutB layout of the `bMatrix` - * @param dataType datatype of elements of `aMatrix` and `bMatrix` ({@link FragmentType#datatype}) - * @param dataTypeAcc datatype of elements of `cMatrix` and the resultMatrix ({@link FragmentType#datatype}) - * @param aMatrix first factor of type fragment - * @param bMatrix second factor of type fragment - * @param cMatrix accumulator of type fragment which is added to the product of `aMatrix` * `bMatrix` - * @param resultMatrix fragment-Accumulator in which the result is stored - * (inplace operations using the `cMatrix` as resultMatrix is possible) - */ -@comPrimitive -case class WmmaMMA(m: Nat, - n: Nat, - k: Nat, - layoutA : MatrixLayout, - layoutB : MatrixLayout, - dataType: DataType, - dataTypeAcc: DataType, - aMatrix: Phrase[ExpType], - bMatrix: Phrase[ExpType], - cMatrix: Phrase[ExpType], - resultMatrix: Phrase[AccType] - ) extends CommandPrimitive { - aMatrix :: ExpType(FragmentType(m, k, n, dataType, FragmentKind.AMatrix, layoutA), read) - bMatrix :: ExpType(FragmentType(k, n, m, dataType, FragmentKind.BMatrix, layoutB), read) - cMatrix :: ExpType(FragmentType(m, n, k, dataTypeAcc), read) - resultMatrix :: AccType(FragmentType(m, n, k, dataTypeAcc)) - - override def prettyPrint: String = - s"WmmaMMA(${PrettyPhrasePrinter(aMatrix)}, ${PrettyPhrasePrinter(bMatrix)}," + - s"${PrettyPhrasePrinter(cMatrix)}, ${PrettyPhrasePrinter(resultMatrix)})" +import shine.DPIA._ +final case class WmmaMMA(val m: Nat, val n: Nat, val k: Nat, val layoutA: MatrixLayout, val layoutB: MatrixLayout, val dt1: DataType, val dt2: DataType, val aMatrix: Phrase[ExpType], val bMatrix: Phrase[ExpType], val cMatrix: Phrase[ExpType], val resultMatrix: Phrase[AccType]) extends CommandPrimitive { + { + aMatrix :: expT(FragmentType(m, k, n, dt1, FragmentKind.AMatrix, layoutA), read) + bMatrix :: expT(FragmentType(k, n, m, dt1, FragmentKind.BMatrix, layoutB), read) + cMatrix :: expT(FragmentType(m, n, k, dt2, FragmentKind.Accumulator, MatrixLayout.None), read) + resultMatrix :: accT(FragmentType(m, n, k, dt2, FragmentKind.Accumulator, MatrixLayout.None)) + } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): WmmaMMA = new WmmaMMA(v.nat(m), v.nat(n), v.nat(k), layoutA, layoutB, v.data(dt1), v.data(dt2), VisitAndRebuild(aMatrix, v), VisitAndRebuild(bMatrix, v), VisitAndRebuild(cMatrix, v), VisitAndRebuild(resultMatrix, v)) } diff --git a/src/main/scala/shine/cuda/primitives/imperative/WmmaStore.scala b/src/main/scala/shine/cuda/primitives/imperative/WmmaStore.scala index a463fc53b..fb5e7ffd6 100644 --- a/src/main/scala/shine/cuda/primitives/imperative/WmmaStore.scala +++ b/src/main/scala/shine/cuda/primitives/imperative/WmmaStore.scala @@ -1,34 +1,17 @@ +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // +// This file is automatically generated and should not be changed manually // +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // package shine.cuda.primitives.imperative - -import shine.DPIA.Nat +import arithexpr.arithmetic._ import shine.DPIA.Phrases._ +import shine.DPIA.Types.DataType._ import shine.DPIA.Types._ -import shine.macros.Primitive.comPrimitive - -/** - * Stores the elements from a fragment with fragmentKind `Accumulator` into a - * matrix tile which resides in shared or global memory.
- * This primitive needs to be executed by a full warp! - * @param rows number of rows of the fragment ({@link FragmentType#rows}) - * @param columns number of columns of the fragment ({@link FragmentType#columns}) - * @param d3 third dimension which is used in the MMA operation ({@link FragmentType#d3}) - * @param dataType dataType of the elements in the fragment ({@link FragmentType#datatype}) - * @param fragment fragment from which the elements should be stored - * @param matrixTile matrixTile-Acceptor in which the elements should be stored - */ -@comPrimitive -final case class WmmaStore(rows: Nat, - columns: Nat, - d3: Nat, - dataType: DataType, - fragment: Phrase[ExpType], - matrixTile: Phrase[AccType] - ) extends CommandPrimitive { - fragment :: ExpType(FragmentType(rows, columns, d3, dataType), read) - matrixTile :: AccType(ArrayType(rows, ArrayType(columns, dataType))) - - override def prettyPrint: String = { - s"wmmaStore(${PrettyPhrasePrinter(fragment)}, ${PrettyPhrasePrinter(matrixTile)})" +import shine.DPIA._ +final case class WmmaStore(val rows: Nat, val columns: Nat, val layers: Nat, val dt: DataType, val value: Phrase[ExpType], val matrixTile: Phrase[AccType]) extends CommandPrimitive { + { + value :: expT(FragmentType(rows, columns, layers, dt, FragmentKind.Accumulator, MatrixLayout.None), read) + matrixTile :: accT(ArrayType(rows, ArrayType(columns, dt))) } + override val t: CommType = comm + override def visitAndRebuild(v: VisitAndRebuild.Visitor): WmmaStore = new WmmaStore(v.nat(rows), v.nat(columns), v.nat(layers), v.data(dt), VisitAndRebuild(value, v), VisitAndRebuild(matrixTile, v)) } - diff --git a/src/main/scala/shine/cuda/primitives/imperative/primitives.dpia b/src/main/scala/shine/cuda/primitives/imperative/primitives.dpia new file mode 100644 index 000000000..4922f06ef --- /dev/null +++ b/src/main/scala/shine/cuda/primitives/imperative/primitives.dpia @@ -0,0 +1,123 @@ +/** + * Applies a function on every element of a fragment.
+ * This primitive needs to be executed by a full warp! + * @param fragType type of the fragment + * @param in fragment of type `fragType` whose elements should be iterated + * @param out fragment-Acceptor of type `fragType` which is used to store the result + * @param fun function which takes an element of type `fragType.dataType` from `in` and + * an element-Acceptor of type `fragType.dataType` from `out` and returns a command + */ +def forFragment(rows: nat, columns: nat, layers: nat, dt: data, + frag: fragment, layout: matrixLayout, + in: exp[fragment[rows, columns, layers, dt, frag, layout], read], + out: acc[fragment[rows, columns, layers, dt, frag, layout]], + fun: exp[dt, read] -> acc[dt] -> comm): comm + +/** + * Copy a element from global memory to shared memory without using registers + * (faster than normal copy operations using the =-operator).
+ * This requires CUDA 11 and compute capability >= 8. For devices with ompute capability + * smaller than 8 this will be compiled by the CUDA-Compiler to the same as normal copy + * operations using the =-operator). + * @param dt datatype of element which should be copied + * @param pipe pipeline which should be used to execute this copy instruction + * @param outputShared output-Acceptor in shared memory of type `dt` + */ +def globalToSharedAcc(dt: data, pipe: exp["pipeline", read], outputShared: acc[dt]): acc[dt] + +def parFor{level: shine.OpenCL.ParallelismLevel, + dim: Int, + unroll: Boolean, + prefix: String} + (init: nat, n: nat, step: nat, + dt: data, + out: acc[n.dt], + body: exp[idx[n], read] -> acc[dt] -> comm): comm + +/** + * Execute and wait for all asynchronous memory transactions (used by the {@link GlobalToSharedAcc}) + */ +def syncPipeline(pipe: exp["pipeline", read]): comm + +/** + * Synchronize all thread in a single thread block. + */ +def syncThreads(): comm + +/** + * Synchronize all elements in a single warp. + */ +def syncWarp(): comm + +/** + * Fills a fragment with a specific value.
+ * This primitive needs to be executed by a full warp! + * @param rows number of rows of the fragment ({@link FragmentType#rows}) + * @param columns number of columns of the fragment ({@link FragmentType#columns}) + * @param layers third dimension which is used in the MMA operation ({@link FragmentType#d3}) + * @param dt dataType of the elements in the fragment ({@link FragmentType#datatype}) + * @param fill new value of all elements in the fragment (type of fill: `dataType`) + * @param frag kind of the fragment ({@link FragmentType#fragmentKind}) + * @param layout layout of the fragment ({@link FragmentType#layout}) + * @param target fragment-Acceptor whose elements should be changed to `fill` + */ +def wmmaFill(rows: nat, columns: nat, layers: nat, dt: data, frag: fragment, layout: matrixLayout, + fill: exp[dt, read], + target: acc[fragment[rows, columns, layers, dt, frag, layout]]): comm + +/** + * Loads a tile of a matrix into a fragment.
+ * This primitive needs to be executed by a full warp! + * @param rows number of rows of the fragment ({@link FragmentType#rows}) + * @param columns number of columns of the fragment ({@link FragmentType#columns}) + * @param d3 third dimension which is used in the MMA operation ({@link FragmentType#d3}) + * @param dataType dataType of the elements in the fragment ({@link FragmentType#datatype}) + * @param fragmentKind kind of the fragment ({@link FragmentType#fragmentKind}) + * @param layout layout of the fragment ({@link FragmentType#layout}) + * @param matrixTile matrix tile which should be loaded into the fragment + * @param fragment fragment-Acceptor into which the `matrixTile` should be loaded + */ +def wmmaLoad(rows: nat, columns: nat, layers: nat, dt: data, frag: fragment, layout: matrixLayout, + matrixTile: exp[rows.columns.dt, read], + target: acc[fragment[rows, columns, layers, dt, frag, layout]]): comm + +/** + * Executes an MMA instruction using (multiple) Tensor Cores.
+ * Calculates: aMatrix * bMatrix + cMatrix
+ * This primitive needs to be executed by a full warp! + * @param m number of rows of the `aMatrix` + * @param n number of columns of the `bMatrix` and the `cMatrix` + * @param k number of columns of the `aMatrix` and number of rows of the `bMatrix` + * @param layoutA layout of the `aMatrix` + * @param layoutB layout of the `bMatrix` + * @param dt1 datatype of elements of `aMatrix` and `bMatrix` ({@link FragmentType#datatype}) + * @param dt2 datatype of elements of `cMatrix` and the resultMatrix ({@link FragmentType#datatype}) + * @param aMatrix first factor of type fragment + * @param bMatrix second factor of type fragment + * @param cMatrix accumulator of type fragment which is added to the product of `aMatrix` * `bMatrix` + * @param resultMatrix fragment-Accumulator in which the result is stored + * (inplace operations using the `cMatrix` as resultMatrix is possible) + */ +def wmmaMMA(m: nat, n: nat, k: nat, + layoutA: matrixLayout, + layoutB: matrixLayout, + dt1: data, dt2: data, + aMatrix: exp[fragment[m, k, n, dt1, fragment.A, layoutA], read], + bMatrix: exp[fragment[k, n, m, dt1, fragment.B, layoutB], read], + cMatrix: exp[fragment[m, n, k, dt2, fragment.ACC, matrixLayout.NONE], read], + resultMatrix: acc[fragment[m, n, k, dt2, fragment.ACC, matrixLayout.NONE]]): comm + +/** + * Stores the elements from a fragment with fragmentKind `Accumulator` into a + * matrix tile which resides in shared or global memory.
+ * This primitive needs to be executed by a full warp! + * @param rows number of rows of the fragment ({@link FragmentType#rows}) + * @param columns number of columns of the fragment ({@link FragmentType#columns}) + * @param layers third dimension which is used in the MMA operation ({@link FragmentType#d3}) + * @param dt dataType of the elements in the fragment ({@link FragmentType#datatype}) + * @param value fragment from which the elements should be stored + * @param matrixTile matrixTile-Acceptor in which the elements should be stored + */ +def wmmaStore(rows: nat, columns: nat, layers: nat, dt: data, + value: exp[fragment[rows, columns, layers, dt, fragment.ACC, matrixLayout.NONE], read], + matrixTile: acc[rows.columns.dt]): comm diff --git a/src/main/scala/shine/cuda/primitives/intermediate/MapI.scala b/src/main/scala/shine/cuda/primitives/intermediate/MapI.scala index c2a1f9401..32cdbea62 100644 --- a/src/main/scala/shine/cuda/primitives/intermediate/MapI.scala +++ b/src/main/scala/shine/cuda/primitives/intermediate/MapI.scala @@ -14,7 +14,7 @@ final case class MapI(level: ParallelismLevel, dim: Int) { in: Phrase[ExpType], out: Phrase[AccType]): Phrase[CommType] = { val imperativ = comment(s"map${level.toString}") `;` - ParFor(level, dim, unroll = false)(n, dt2, out, + shine.cuda.DSL.parFor(level, dim, unroll = false)(n, dt2, out, λ(expT(idx(n), read))(i => λ(accT(dt2))(a => f(in `@` i)(a)))) //TODO use other InsertMemoryBarrieres-mechanism level match { diff --git a/src/test/scala/shine/cuda/MMTest.scala b/src/test/scala/shine/cuda/MMTest.scala index 40ef9e0bc..78bb48f74 100644 --- a/src/test/scala/shine/cuda/MMTest.scala +++ b/src/test/scala/shine/cuda/MMTest.scala @@ -37,22 +37,20 @@ class MMTest extends test_util.TestWithCUDA { AsMatrix(mTile, nTile, kTile, f32, //do matrix multiplication - ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, nTile, kTile, f16), - TensorMatMultAdd(mTile, nTile, kTile, Row_Major, Row_Major, f16, f32, + ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, nTile, kTile, f32, FragmentKind.Accumulator, MatrixLayout.None), + TensorMatMultAdd(mTile, nTile, kTile, MatrixLayoutIdentifier("ml"), MatrixLayoutIdentifier("ml"), f16, f32, //load aMatrix into a fragment - ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, kTile, nTile, f16, FragmentKind.AMatrix, Row_Major), - AsFragment(mTile, kTile, nTile, f16, FragmentKind.AMatrix, - matrixATile)), + ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, kTile, nTile, f16, FragmentKind.AMatrix, MatrixLayoutIdentifier("ml")), + shine.cuda.AsFragment(mTile, kTile, nTile, f16, FragmentKind.AMatrix, matrixATile)), //load bMatrix into a fragment - ToMem(shine.cuda.AddressSpace.Private, FragmentType(kTile, nTile, mTile, f16, FragmentKind.BMatrix, Row_Major), - AsFragment(kTile, nTile, mTile, f16, FragmentKind.BMatrix, - matrixBTile)), + ToMem(shine.cuda.AddressSpace.Private, FragmentType(kTile, nTile, mTile, f16, FragmentKind.BMatrix, MatrixLayoutIdentifier("ml")), + shine.cuda.AsFragment(kTile, nTile, mTile, f16, FragmentKind.BMatrix, matrixBTile)), //add fragment with zeros - ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, nTile, kTile, f16), - GenerateFragment(mTile, nTile, kTile, f32, Literal(FloatData(0.0f)), FragmentKind.Accumulator, Row_Major)))))) + ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, nTile, kTile, f32, FragmentKind.Accumulator, MatrixLayout.None), + GenerateFragment(mTile, nTile, kTile, f32, FragmentKind.Accumulator, MatrixLayout.None, Literal(FloatData(0.0f)))))))) ) val kernel = gen.cuda.kernel("matrixMult").fromPhrase(simpleMatMulTile) @@ -93,7 +91,7 @@ class MMTest extends test_util.TestWithCUDA { val matrixATile = Identifier(freshName("MatrixATile"), ExpType(ArrayType(mTile, ArrayType(k, f16)), read)) val matrixBTile = Identifier(freshName("MatrixBTile"), ExpType(ArrayType(k, ArrayType(nTile, f16)), read)) - val matrixCFrag = Identifier(freshName("MatrixCFrag"), ExpType(FragmentType(mTile, nTile, kTile, f32), read)) + val matrixCFrag = Identifier(freshName("MatrixCFrag"), ExpType(FragmentType(mTile, nTile, kTile, f32, FragmentKind.Accumulator, MatrixLayout.None), read)) val matrixABTiles = Identifier(freshName("MatrixABTiles"), ExpType(PairType( ArrayType(mTile, ArrayType(kTile, f16)), ArrayType(kTile, ArrayType(nTile, f16))), read)) @@ -108,21 +106,21 @@ class MMTest extends test_util.TestWithCUDA { PairType( ArrayType(mTile, ArrayType(kTile, f16)), ArrayType(kTile, ArrayType(nTile, f16))), - FragmentType(mTile, nTile, kTile, f32), + FragmentType(mTile, nTile, kTile, f32, FragmentKind.Accumulator, MatrixLayout.None), Lambda[ExpType, FunType[ExpType, ExpType]](matrixCFrag, Lambda[ExpType, ExpType](matrixABTiles, - TensorMatMultAdd(mTile, nTile, kTile, Row_Major, Row_Major, f16, f32, - ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, kTile, nTile, f16, FragmentKind.AMatrix, Row_Major), - AsFragment(mTile, kTile, nTile, f16, FragmentKind.AMatrix, + TensorMatMultAdd(mTile, nTile, kTile, MatrixLayoutIdentifier("ml"), MatrixLayoutIdentifier("ml"), f16, f32, + ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, kTile, nTile, f16, FragmentKind.AMatrix, MatrixLayoutIdentifier("ml")), + shine.cuda.AsFragment(mTile, kTile, nTile, f16, FragmentKind.AMatrix, Transpose(kTile, mTile, f16, read, Fst( ArrayType(kTile, ArrayType(mTile, f16)), ArrayType(kTile, ArrayType(nTile, f16)), matrixABTiles)))), - ToMem(shine.cuda.AddressSpace.Private, FragmentType(kTile, nTile, mTile, f16, FragmentKind.BMatrix, Row_Major), - AsFragment(kTile, nTile, mTile, f16, FragmentKind.BMatrix, + ToMem(shine.cuda.AddressSpace.Private, FragmentType(kTile, nTile, mTile, f16, FragmentKind.BMatrix, MatrixLayoutIdentifier("ml")), + shine.cuda.AsFragment(kTile, nTile, mTile, f16, FragmentKind.BMatrix, Snd( ArrayType(mTile, ArrayType(kTile, f16)), ArrayType(kTile, ArrayType(nTile, f16)), @@ -130,7 +128,7 @@ class MMTest extends test_util.TestWithCUDA { matrixCFrag))), - GenerateFragment(mTile, nTile, kTile, f32, Literal(FloatData(0.0f)), FragmentKind.Accumulator, Row_Major), + GenerateFragment(mTile, nTile, kTile, f32, FragmentKind.Accumulator, Row_Major, Literal(FloatData(0.0f))), Zip(k /^ kTile, ArrayType(mTile, ArrayType(kTile, f16)), @@ -190,7 +188,7 @@ class MMTest extends test_util.TestWithCUDA { val matrixARow = Identifier(freshName("MatrixARow"), ExpType(ArrayType(mTile, ArrayType(k, f16)), read)) val matrixBColumnT = Identifier(freshName("MatrixBColumn"), ExpType(ArrayType(nTile, ArrayType(k, f16)), read)) - val matrixCFrag = Identifier(freshName("MatrixCFrag"), ExpType(FragmentType(mTile, nTile, kTile, f32), read)) + val matrixCFrag = Identifier(freshName("MatrixCFrag"), ExpType(FragmentType(mTile, nTile, kTile, f32, FragmentKind.Accumulator, MatrixLayout.None), read)) val matrixABTiles = Identifier(freshName("MatrixABTiles"), ExpType(PairType( ArrayType(kTile, ArrayType(mTile, f16)), @@ -205,8 +203,8 @@ class MMTest extends test_util.TestWithCUDA { Lambda[ExpType, FunType[ExpType, ExpType]](matrixA, //And matrixB Lambda[ExpType, ExpType](matrixB, - Transpose(n, m, f32, read, - Join(n /^ nTile, nTile, read, ArrayType(m, f32), + Transpose(n, m, f32, write, + Join(n /^ nTile, nTile, write, ArrayType(m, f32), //Map over nTile-column-block of matrixB Map(Local, 1)(n /^ nTile, //A transposed column of matrixB @@ -237,17 +235,17 @@ class MMTest extends test_util.TestWithCUDA { ArrayType(kTile, ArrayType(nTile, f16))), //Result: tile of cMatrix as fragment - FragmentType(mTile, nTile, kTile, f32), + FragmentType(mTile, nTile, kTile, f32, FragmentKind.Accumulator, MatrixLayout.None), //Multiply matrixATile and matrixBTile Lambda[ExpType, FunType[ExpType, ExpType]](matrixCFrag, Lambda[ExpType, ExpType](matrixABTiles, //matrix multiply and accumulate - TensorMatMultAdd(mTile, nTile, kTile, Row_Major, Row_Major, f16, f32, + TensorMatMultAdd(mTile, nTile, kTile, MatrixLayoutIdentifier("ml"), MatrixLayoutIdentifier("ml"), f16, f32, //matrixATile as fragment - ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, kTile, nTile, f16, FragmentKind.AMatrix, Row_Major), - AsFragment(mTile, kTile, nTile, f16, FragmentKind.AMatrix, + ToMem(shine.cuda.AddressSpace.Private, FragmentType(mTile, kTile, nTile, f16, FragmentKind.AMatrix, MatrixLayoutIdentifier("ml")), + shine.cuda.AsFragment(mTile, kTile, nTile, f16, FragmentKind.AMatrix, Transpose(kTile, mTile, f16, read, Fst( ArrayType(mTile, ArrayType(kTile, f16)), @@ -255,8 +253,8 @@ class MMTest extends test_util.TestWithCUDA { matrixABTiles)))), //matrixBTile as fragment - ToMem(shine.cuda.AddressSpace.Private, FragmentType(kTile, nTile, mTile, f16, FragmentKind.BMatrix, Row_Major), - AsFragment(kTile, nTile, mTile, f16, FragmentKind.BMatrix, + ToMem(shine.cuda.AddressSpace.Private, FragmentType(kTile, nTile, mTile, f16, FragmentKind.BMatrix, MatrixLayoutIdentifier("ml")), + shine.cuda.AsFragment(kTile, nTile, mTile, f16, FragmentKind.BMatrix, Snd( ArrayType(mTile, ArrayType(kTile, f16)), ArrayType(kTile, ArrayType(nTile, f16)), @@ -265,7 +263,7 @@ class MMTest extends test_util.TestWithCUDA { matrixCFrag))), //Neutral Element for Reduce: fragment initialized with zeros - GenerateFragment(mTile, nTile, kTile, f32, Literal(FloatData(0.0f)), FragmentKind.Accumulator, Row_Major), + GenerateFragment(mTile, nTile, kTile, f32, FragmentKind.Accumulator, Row_Major, Literal(FloatData(0.0f))), //Zip transposed, splited row of matrixA and splited column of matrixB Zip(k /^ kTile,