// Patch summary (descriptive header placed ahead of the first `diff --git` line; the hunks below are unchanged).
//
// 1. src/main/scala/apps/gemv.scala
//    Five `reduceSeq(f)(lf32(0.0f))` calls are replaced by
//    `oclReduceSeq(AddressSpace.Private)(f)(lf32(0.0f))`; the reduction function and the
//    initial value are the same on the removed and added lines.
//
// 2. src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala
//    - The import from rise.core.types additionally brings in NatToNat and NatToNatLambda.
//    - In freeVariables, the two mutable sets (idents, natIdents) are bound with `val` instead of `var`.
//    - The Visitor gains a `natToNat` override: for `NatToNatLambda(x, b)` it rebuilds the lambda after
//      visiting `b` with a copy of the visitor whose boundN also contains `x`; every other case defers to
//      `super.natToNat`. Since the visible `nat` method only collects NamedVars for which `!boundN(v)`
//      holds, this keeps the lambda's own variable out of natIdents while its body is visited.
//
// 3. src/test/scala/apps/gemvCheck.scala
//    - Adds imports for rise.autotune, rise.core.Expr and shine.OpenCL.{GlobalSize, LocalSize};
//      `import util.gen` is rewritten as `import util.{gen}`.
//    - Adds a test that, for gemvBlastN, gemvBlastT, gemvFused, gemvFusedAMD and gemvKeplerBest, wraps the
//      expression with autotune.wrapOclRun(localSize, globalSize), passes it to gen.opencl.hosted.fromExpr
//      and then to shine.OpenCL.Module.translateToString (per the in-patch comment, this is where the
//      syntax checker is called). The returned strings are not inspected by the test.
//
// NOTE(review): in this copy the patch's line breaks are collapsed, so original indentation and blank
// context lines are not recoverable from the text as stored — regenerate the patch before applying it.
diff --git a/src/main/scala/apps/gemv.scala b/src/main/scala/apps/gemv.scala index cd22cb5da..bdf48916c 100644 --- a/src/main/scala/apps/gemv.scala +++ b/src/main/scala/apps/gemv.scala @@ -84,7 +84,7 @@ object gemv { zip(xs)(t._1) |> split(n) |> toLocalFun(mapLocal( - reduceSeq(fun(a => fun(x => mult(x) + a)))(lf32(0.0f)) + oclReduceSeq(AddressSpace.Private)(fun(a => fun(x => mult(x) + a)))(lf32(0.0f)) )) |> mapLocal(fun(x => (alpha * x) + (t._2 * beta))) )) |> join @@ -99,10 +99,10 @@ object gemv { reorderWithStride(128) |> split(n /^ 128) |> toLocalFun(mapLocal( - reduceSeq(fun(a => fun(x => mult(x) + a)))(lf32(0.0f)) + oclReduceSeq(AddressSpace.Private)(fun(a => fun(x => mult(x) + a)))(lf32(0.0f)) )) |> split(128) |> - toLocalFun(mapLocal(reduceSeq(add)(lf32(0.0f)))) |> + toLocalFun(mapLocal(oclReduceSeq(AddressSpace.Private)(add)(lf32(0.0f)))) |> mapLocal(fun(x => (alpha * x) + (t._2 * beta))) )) |> join )) @@ -117,9 +117,9 @@ object gemv { reorderWithStride(128) |> split(n /^ 128) |> toLocalFun(mapLocal( - reduceSeq(fun(a => fun(x => mult(x) + a)))(lf32(0.0f)) + oclReduceSeq(AddressSpace.Private)(fun(a => fun(x => mult(x) + a)))(lf32(0.0f)) )) |> - toLocalFun(reduceSeq(add)(lf32(0.0f))) |> + toLocalFun(oclReduceSeq(AddressSpace.Private)(add)(lf32(0.0f))) |> fun(x => (alpha * x) + (t._2 * beta)) )) )) diff --git a/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala b/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala index 35c267985..5afc81018 100644 --- a/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala +++ b/src/main/scala/shine/OpenCL/Compilation/SeparateHostAndKernelCode.scala @@ -1,7 +1,7 @@ package shine.OpenCL.Compilation import arithexpr.arithmetic.NamedVar -import rise.core.types.{DataKind, DataType, NatIdentifier, NatKind} +import rise.core.types.{DataKind, DataType, NatIdentifier, NatKind, NatToNat, NatToNatLambda} import rise.core.types.DataType.DataTypeIdentifier import 
shine.DPIA.Compilation.FunDef import shine.DPIA.Phrases._ @@ -90,8 +90,8 @@ object SeparateHostAndKernelCode { // TODO: collect free nat identifiers? private def freeVariables(p: Phrase[_ <: PhraseType]) : (Set[Identifier[ExpType]], Set[NamedVar]) = { - var idents = scala.collection.mutable.Set[Identifier[ExpType]]() - var natIdents = scala.collection.mutable.Set[NamedVar]() + val idents = scala.collection.mutable.Set[Identifier[ExpType]]() + val natIdents = scala.collection.mutable.Set[NamedVar]() case class Visitor(boundV: Set[Identifier[_]], boundT: Set[DataTypeIdentifier], @@ -110,6 +110,12 @@ object SeparateHostAndKernelCode { case _ => Continue(p, this) } + override def natToNat(ft: NatToNat): NatToNat = ft match { + case NatToNatLambda(x, b) => + NatToNatLambda(x, this.copy(boundN = boundN + x).nat(b)) + case _ => super.natToNat(ft) + } + override def nat[N <: Nat](n: N): N = { natIdents ++= n.varList.collect { case v: NamedVar if !boundN(v) => v diff --git a/src/test/scala/apps/gemvCheck.scala b/src/test/scala/apps/gemvCheck.scala index 933ee86e4..db1b51562 100644 --- a/src/test/scala/apps/gemvCheck.scala +++ b/src/test/scala/apps/gemvCheck.scala @@ -3,10 +3,13 @@ package apps import gemv._ import rise.core.DSL._ import Type._ +import rise.autotune +import rise.core.Expr import rise.core.types._ -import util.gen +import util.{gen} import util.gen.c.function import rise.core.types.DataType._ +import shine.OpenCL.{GlobalSize, LocalSize} class gemvCheck extends test_util.Tests { private val N = 128 @@ -40,6 +43,21 @@ class gemvCheck extends test_util.Tests { ocl.gemvKeplerBest.toExpr } + test("OpenCL gemv versions host-code generation creates syntactically correct host-code"){ + + def run(e: ToBeTyped[Expr], localSize: LocalSize, globalSize: GlobalSize):String = { + val wrapped = autotune.wrapOclRun(localSize, globalSize)(e) + val codeModule = gen.opencl.hosted.fromExpr(wrapped) + shine.OpenCL.Module.translateToString(codeModule) // syntax checker is called 
here + } + + run(ocl.gemvBlastN, LocalSize(64), GlobalSize(1024)) + run(ocl.gemvBlastT, LocalSize(64), GlobalSize(1024)) + run(ocl.gemvFused, LocalSize(128), GlobalSize(1024)) + run(ocl.gemvFusedAMD, LocalSize(128), GlobalSize(1024)) + run(ocl.gemvKeplerBest, LocalSize(128), GlobalSize(1024)) + } + test("OpenMP gemv versions type inference works") { omp.gemvFused.toExpr }