Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fat JAR executable that generates C code from Rise programs #240

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ modules.xml
*.pdf
*.gz
*.sc

float-safe-optimizer.jar
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,20 @@ The source code for the compiler is organised into sub-packages of the `shine` p

### Setup and Documentation
Please have a look at: https://rise-lang.org/doc/

### Float Safe Optimizer

This repository contains an optimizer executable that preserves floating-point semantics.
To build a Fat JAR executable:
```sh
sbt float_safe_optimizer/assembly
```
To optimize a Rise program and generate code:
```sh
java -Xss20m -Xms512m -Xmx4G -jar float-safe-optimizer.jar $function_name $rise_source_path $output_path
```
For example:
```sh
java -Xss20m -Xms512m -Xmx4G -jar float-safe-optimizer.jar add3 float-safe-optimizer/examples/add3Seq.rise float-safe-optimizer/examples/add3Seq.c
java -Xss20m -Xms512m -Xmx4G -jar float-safe-optimizer.jar add3 float-safe-optimizer/examples/add3.rise float-safe-optimizer/examples/add3.c
```
22 changes: 21 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,25 @@ clap := {
"echo y" #| (baseDirectory.value + "/lib/clap/buildClap.sh") !
}


lazy val float_safe_optimizer = (project in file("float-safe-optimizer"))
.dependsOn(riseAndShine)
.enablePlugins(AssemblyPlugin)
.settings(
excludeDependencies ++= Seq(
ExclusionRule("org.scala-lang.modules", s"scala-xml_${scalaBinaryVersion.value}"),
ExclusionRule("junit", "junit"),
ExclusionRule("com.novocode", "junit-interface"),
ExclusionRule("org.scalacheck", "scalacheck"),
ExclusionRule("org.scalatest", "scalatest"),
ExclusionRule("com.lihaoyi", s"os-lib_${scalaBinaryVersion.value}"),
ExclusionRule("com.typesafe.play", s"play-json_${scalaBinaryVersion.value}"),
ExclusionRule("org.rise-lang", s"opencl-executor_${scalaBinaryVersion.value}"),
ExclusionRule("org.rise-lang", "CUexecutor"),
ExclusionRule("org.elevate-lang", s"cuda-executor_${scalaBinaryVersion.value}"),
ExclusionRule("org.elevate-lang", s"meta_${scalaBinaryVersion.value}"),
),
name := "float-safe-optimizer",
javaOptions ++= Seq("-Xss20m", "-Xms512m", "-Xmx4G"),
assemblyOutputPath in assembly := file("float-safe-optimizer.jar"),
)

42 changes: 42 additions & 0 deletions float-safe-optimizer/Main.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package float_safe_optimizer

import util.gen
import rise.core.Expr
import rise.core.DSL.ToBeTyped

object Main {
def main(args: Array[String]): Unit = {
val name = args(0)
val exprSourcePath = args(1)
val outputPath = args(2)

val exprSource = util.readFile(exprSourcePath)
val untypedExpr = parseExpr(prefixImports(exprSource))
val typedExpr = untypedExpr.toExpr
val optimizedExpr = Optimize(typedExpr)
println(optimizedExpr)
val code = gen.openmp.function.asStringFromExpr(optimizedExpr)
util.writeToPath(outputPath, code)
}

def prefixImports(source: String): String =
s"""
|import rise.core.DSL._
|import rise.core.DSL.Type._
|import rise.core.DSL.HighLevelConstructs._
|import rise.core.primitives._
|import rise.core.types._
|import rise.core.types.DataType._
|import rise.openmp.DSL._
|import rise.openmp.primitives._
|$source
|""".stripMargin

def parseExpr(source: String): ToBeTyped[Expr] = {
import scala.reflect.runtime.universe
import scala.tools.reflect.ToolBox

val toolbox = universe.runtimeMirror(getClass.getClassLoader).mkToolBox()
toolbox.eval(toolbox.parse(source)).asInstanceOf[ToBeTyped[Expr]]
}
}
98 changes: 98 additions & 0 deletions float-safe-optimizer/Optimize.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package float_safe_optimizer

import rise.eqsat._

object Optimize {
def apply(e: rise.core.Expr): rise.core.Expr = {
val expr = Expr.fromNamed(e)
val (body, annotation, wrapBody) = analyseTopLevel(expr)

val rules = {
import rise.eqsat.rules._
Seq(
// implementation choices:
reduceSeq,
mapSeq,
// satisfying read/write annotations:
mapSeqArray,
// simplifications:
mapFusion,
reduceSeqMapFusion,
removeTransposePair,
fstReduction,
sndReduction,
/* maybe:
omp.mapPar --> need heuristic vs mapSeq
toMemAfterMapSeq / storeToMem
storeToMem
reduceSeqMapFusion
mapSeqUnroll/reduceSeqUnroll --> need heuristic
eliminateMapIdentity

is it worth the cost?:
betaExtract
betaNatExtract
eta

not generic enough, use Elevate passes or custom applier?:
idxReduction_i_n
*/
)
}

LoweringSearch.init().run(BENF, Cost, Seq(body), rules, Some(annotation)) match {
case Some(resBody) =>
val res = wrapBody(resBody)
Expr.toNamed(res)
case None => throw new Exception("could not find valid low-level expression")
}
}

// TODO: this code might be avoidable by making DPIA+codegen rely on top-level type instead of top-level constructs
def analyseTopLevel(e: Expr)
: (Expr, (BeamExtractRW.TypeAnnotation, Map[Int, BeamExtractRW.TypeAnnotation]), Expr => Expr) = {
import rise.eqsat.RWAnnotationDSL._

// returns (body, argCount, wrapBody)
def rec(e: Expr): (Expr, Int, Expr => Expr) = {
e.node match {
case NatLambda(e2) =>
val (b, a, w) = rec(e2)
(b, a, b => Expr(NatLambda(w(b)), e.t))
case DataLambda(e2) =>
val (b, a, w) = rec(e2)
(b, a, b => Expr(DataLambda(w(b)), e.t))
case Lambda(e2) =>
e.t.node match {
case FunType(Type(_: DataTypeNode[_, _]), _) => ()
case _ => throw new Exception("top level higher-order functions are not supported")
}
val (b, a, w) = rec(e2)
(b, a + 1, b => Expr(Lambda(w(b)), e.t))
case _ =>
if (!e.t.node.isInstanceOf[DataTypeNode[_, _]]) {
throw new Exception("expected body with data type")
}
(e, 0, b => b)
}
}

val (b, a, w) = rec(e)
(b, (write, List.tabulate(a)(i => i -> read).toMap), w)
}

object Cost extends CostFunction[Int] {
val ordering = implicitly

override def cost(egraph: EGraph, enode: ENode, t: TypeId, costs: EClassId => Int): Int = {
import rise.core.primitives._

val nodeCost = enode match {
// prefer avoiding mapSeq
case Primitive(mapSeq()) => 5
case _ => 1
}
enode.children().foldLeft(nodeCost) { case (acc, eclass) => acc + costs(eclass) }
}
}
}
1 change: 1 addition & 0 deletions float-safe-optimizer/examples/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.c
3 changes: 3 additions & 0 deletions float-safe-optimizer/examples/add3.rise
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
depFun((n: Nat) => fun((n`.`i32) ->: (n`.`i32))(in =>
map(add(li32(3)))(in)
))
3 changes: 3 additions & 0 deletions float-safe-optimizer/examples/add3Seq.rise
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
depFun((n: Nat) => fun((n`.`i32) ->: (n`.`i32))(in =>
mapSeq(add(li32(3)))(in)
))
3 changes: 3 additions & 0 deletions float-safe-optimizer/examples/add3TypeError.rise
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
depFun((n: Nat) => fun((n`.`i32) ->: (n`.`f32))(in =>
map(add(li32(3)))(in)
))
4 changes: 4 additions & 0 deletions float-safe-optimizer/examples/addv.rise
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
depFun((n: Nat) => depFun((m: Nat) => depFun((o: Nat) =>
fun(((n+o)`.`i32) ->: ((m+o)`.`i32) ->: (o`.`i32))((a, b) =>
zip(take(o)(a))(take(o)(b)) |> map(fun(x => fst(x) + snd(x)))
))))
2 changes: 2 additions & 0 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "1.4.0-RC1")
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.3.0")
addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.2.17")
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.15.0")
addDependencyTreePlugin
13 changes: 8 additions & 5 deletions src/main/scala/rise/eqsat/Analysis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ case class BeamExtractRW[Cost](beamSize: Int, cf: CostFunction[Cost])
case App(f, e) =>
val fInT = egraph(egraph.get(f).t) match {
case FunType(inT, _) => inT
case _ => throw new Exception("this should not happen")
case _ => throw new Exception("app expected fun type")
}
val eT = egraph.get(e).t

Expand All @@ -713,7 +713,7 @@ case class BeamExtractRW[Cost](beamSize: Int, cf: CostFunction[Cost])
}
}
}
case _ => throw new Exception("this should not happen")
case _ => throw new Exception("app expected fun type")
}
}
newBeams
Expand Down Expand Up @@ -748,7 +748,7 @@ case class BeamExtractRW[Cost](beamSize: Int, cf: CostFunction[Cost])
}
annotation match {
case NotDataTypeAnnotation(NatFunType(at)) => (at, env) -> newBeam
case _ => throw new Exception("this should not happen")
case _ => throw new Exception("natApp expected NatFunType")
}
}
case DataApp(f, _) =>
Expand All @@ -762,7 +762,7 @@ case class BeamExtractRW[Cost](beamSize: Int, cf: CostFunction[Cost])
}
annotation match {
case NotDataTypeAnnotation(DataFunType(at)) => (at, env) -> newBeam
case _ => throw new Exception("this should not happen")
case _ => throw new Exception("dataApp expected DataFunType")
}
}
case AddrApp(f, _) =>
Expand All @@ -776,7 +776,7 @@ case class BeamExtractRW[Cost](beamSize: Int, cf: CostFunction[Cost])
}
annotation match {
case NotDataTypeAnnotation(AddrFunType(at)) => (at, env) -> newBeam
case _ => throw new Exception("this should not happen")
case _ => throw new Exception("addrApp expected AddrFunType")
}
}
case NatLambda(e) =>
Expand Down Expand Up @@ -923,6 +923,9 @@ case class BeamExtractRW[Cost](beamSize: Int, cf: CostFunction[Cost])
}
}
Seq(rec(n))
case rp.id() =>
// FIXME: only supports non-functional values
Seq(read ->: read, write ->: write)
case _ => throw new Exception(s"did not expect $p")
}
val beam = Seq((
Expand Down
30 changes: 20 additions & 10 deletions src/main/scala/rise/eqsat/LoweringSearch.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,37 @@ object LoweringSearch {

// TODO: enable giving a sketch, maybe merge with GuidedSearch?
class LoweringSearch(var filter: Predicate) {
private def topLevelAnnotation(e: Expr): BeamExtractRW.TypeAnnotation = {
private def topLevelAnnotation(t: Type): BeamExtractRW.TypeAnnotation = {
import RWAnnotationDSL._
e.node match {
case NatLambda(e) => nFunT(topLevelAnnotation(e))
case DataLambda(e) => dtFunT(topLevelAnnotation(e))
case Lambda(e) => read ->: topLevelAnnotation(e)
case _ =>
assert(e.t.node.isInstanceOf[DataTypeNode[_, _]])
t.node match {
case NatFunType(t) => nFunT(topLevelAnnotation(t))
case DataFunType(t) => dtFunT(topLevelAnnotation(t))
case AddrFunType(t) => aFunT(topLevelAnnotation(t))
case FunType(ta, tb) =>
if (!ta.node.isInstanceOf[DataTypeNode[_, _]]) {
throw new Exception("top level higher-order functions are not supported")
}
read ->: topLevelAnnotation(tb)
case _: DataTypeNode[_, _] =>
write
case _ =>
throw new Exception(s"did not expect type $t")
}
}

def run(normalForm: NF,
costFunction: CostFunction[_],
startBeam: Seq[Expr],
loweringRules: Seq[Rewrite]): Option[Expr] = {
loweringRules: Seq[Rewrite],
annotations: Option[(BeamExtractRW.TypeAnnotation, Map[Int, BeamExtractRW.TypeAnnotation])] = None): Option[Expr] = {
println("---- lowering")
val egraph = EGraph.empty()
val normBeam = startBeam.map(normalForm.normalize)

val expectedAnnotation = topLevelAnnotation(normBeam.head)
val expectedAnnotations = annotations match {
case Some(annotations) => annotations
case None => (topLevelAnnotation(normBeam.head.t), Map.empty[Int, BeamExtractRW.TypeAnnotation])
}

val rootId = normBeam.map(egraph.addExpr)
.reduce[EClassId] { case (a, b) => egraph.union(a, b)._1 }
Expand All @@ -43,7 +53,7 @@ class LoweringSearch(var filter: Predicate) {

util.printTime("lowered extraction time", {
val tmp = Analysis.oneShot(BeamExtractRW(1, costFunction), egraph)(egraph.find(rootId))
tmp.get((expectedAnnotation, Map.empty))
tmp.get(expectedAnnotations)
.map { beam => ExprWithHashCons.expr(egraph)(beam.head._2) }
})
}
Expand Down
Loading