Skip to content

Commit

Permalink
Merge pull request #218 from rise-lang/host_code_generation_failure
Browse files Browse the repository at this point in the history
Fix SeparateHostAndKernelCode for NatToNatLambda
  • Loading branch information
johanneslenfers authored Oct 27, 2021
2 parents 6eb08b5 + 5071fc7 commit 047cb52
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 9 deletions.
10 changes: 5 additions & 5 deletions src/main/scala/apps/gemv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ object gemv {
zip(xs)(t._1) |>
split(n) |>
toLocalFun(mapLocal(
reduceSeq(fun(a => fun(x => mult(x) + a)))(lf32(0.0f))
oclReduceSeq(AddressSpace.Private)(fun(a => fun(x => mult(x) + a)))(lf32(0.0f))
)) |>
mapLocal(fun(x => (alpha * x) + (t._2 * beta)))
)) |> join
Expand All @@ -99,10 +99,10 @@ object gemv {
reorderWithStride(128) |>
split(n /^ 128) |>
toLocalFun(mapLocal(
reduceSeq(fun(a => fun(x => mult(x) + a)))(lf32(0.0f))
oclReduceSeq(AddressSpace.Private)(fun(a => fun(x => mult(x) + a)))(lf32(0.0f))
)) |>
split(128) |>
toLocalFun(mapLocal(reduceSeq(add)(lf32(0.0f)))) |>
toLocalFun(mapLocal(oclReduceSeq(AddressSpace.Private)(add)(lf32(0.0f)))) |>
mapLocal(fun(x => (alpha * x) + (t._2 * beta)))
)) |> join
))
Expand All @@ -117,9 +117,9 @@ object gemv {
reorderWithStride(128) |>
split(n /^ 128) |>
toLocalFun(mapLocal(
reduceSeq(fun(a => fun(x => mult(x) + a)))(lf32(0.0f))
oclReduceSeq(AddressSpace.Private)(fun(a => fun(x => mult(x) + a)))(lf32(0.0f))
)) |>
toLocalFun(reduceSeq(add)(lf32(0.0f))) |>
toLocalFun(oclReduceSeq(AddressSpace.Private)(add)(lf32(0.0f))) |>
fun(x => (alpha * x) + (t._2 * beta))
))
))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package shine.OpenCL.Compilation

import arithexpr.arithmetic.NamedVar
import rise.core.types.{DataKind, DataType, NatIdentifier, NatKind}
import rise.core.types.{DataKind, DataType, NatIdentifier, NatKind, NatToNat, NatToNatLambda}
import rise.core.types.DataType.DataTypeIdentifier
import shine.DPIA.Compilation.FunDef
import shine.DPIA.Phrases._
Expand Down Expand Up @@ -90,8 +90,8 @@ object SeparateHostAndKernelCode {
// TODO: collect free nat identifiers?
private def freeVariables(p: Phrase[_ <: PhraseType])
: (Set[Identifier[ExpType]], Set[NamedVar]) = {
var idents = scala.collection.mutable.Set[Identifier[ExpType]]()
var natIdents = scala.collection.mutable.Set[NamedVar]()
val idents = scala.collection.mutable.Set[Identifier[ExpType]]()
val natIdents = scala.collection.mutable.Set[NamedVar]()

case class Visitor(boundV: Set[Identifier[_]],
boundT: Set[DataTypeIdentifier],
Expand All @@ -110,6 +110,12 @@ object SeparateHostAndKernelCode {
case _ => Continue(p, this)
}

override def natToNat(ft: NatToNat): NatToNat = ft match {
case NatToNatLambda(x, b) =>
NatToNatLambda(x, this.copy(boundN = boundN + x).nat(b))
case _ => super.natToNat(ft)
}

override def nat[N <: Nat](n: N): N = {
natIdents ++= n.varList.collect {
case v: NamedVar if !boundN(v) => v
Expand Down
20 changes: 19 additions & 1 deletion src/test/scala/apps/gemvCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package apps
import gemv._
import rise.core.DSL._
import Type._
import rise.autotune
import rise.core.Expr
import rise.core.types._
import util.gen
import util.{gen}
import util.gen.c.function
import rise.core.types.DataType._
import shine.OpenCL.{GlobalSize, LocalSize}

class gemvCheck extends test_util.Tests {
private val N = 128
Expand Down Expand Up @@ -40,6 +43,21 @@ class gemvCheck extends test_util.Tests {
ocl.gemvKeplerBest.toExpr
}

test("OpenCL gemv versions host-code generation creates syntactically correct host-code"){

def run(e: ToBeTyped[Expr], localSize: LocalSize, globalSize: GlobalSize):String = {
val wrapped = autotune.wrapOclRun(localSize, globalSize)(e)
val codeModule = gen.opencl.hosted.fromExpr(wrapped)
shine.OpenCL.Module.translateToString(codeModule) // syntax checker is called here
}

run(ocl.gemvBlastN, LocalSize(64), GlobalSize(1024))
run(ocl.gemvBlastT, LocalSize(64), GlobalSize(1024))
run(ocl.gemvFused, LocalSize(128), GlobalSize(1024))
run(ocl.gemvFusedAMD, LocalSize(128), GlobalSize(1024))
run(ocl.gemvKeplerBest, LocalSize(128), GlobalSize(1024))
}

test("OpenMP gemv versions type inference works") {
omp.gemvFused.toExpr
}
Expand Down

0 comments on commit 047cb52

Please sign in to comment.