Skip to content

Commit

Permalink
Add One-Hot Encoding for Static FSMs (#2037)
Browse files Browse the repository at this point in the history
* cli options

* static control tests

* fsm implementation type

* FSM allocation

* extended instatiation of staticfsm object with one hot encoding

* one hot

* clippy

* clippy

* does this make clippy stop complainng

* one-hot initialize to 00001

* cleaner code

* loose ends

* code refactoring

* hopefully github shows fewer changes

* clippy

* rewrite tests

* documentation, code cleaning

* better documentation

* higher one-hot cutoff

---------

Co-authored-by: Parth Sarkar <[email protected]>
  • Loading branch information
calebmkim and parthsarkar17 authored May 13, 2024
1 parent cf8224e commit 015da80
Show file tree
Hide file tree
Showing 25 changed files with 850 additions and 219 deletions.
347 changes: 262 additions & 85 deletions calyx-opt/src/analysis/static_schedule.rs

Large diffs are not rendered by default.

213 changes: 129 additions & 84 deletions calyx-opt/src/passes/compile_static.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use crate::analysis::{GraphColoring, StaticFSM, StaticSchedule};
use crate::traversal::{Action, Named, VisResult, Visitor};
use crate::traversal::{
Action, ConstructVisitor, Named, ParseVal, PassOpt, VisResult, Visitor,
};
use calyx_ir as ir;
use calyx_ir::{guard, structure, GetAttributes};
use calyx_utils::Error;
use calyx_utils::{CalyxResult, Error};
use ir::{build_assignments, RRC};
use itertools::Itertools;
use std::collections::{HashMap, HashSet};
use std::ops::Not;
use std::rc::Rc;

#[derive(Default)]
/// Compiles Static Islands
pub struct CompileStatic {
/// maps original static group names to the corresponding group that has an FSM that reset early
Expand All @@ -20,6 +21,10 @@ pub struct CompileStatic {
signal_reg_map: HashMap<ir::Id, ir::Id>,
/// maps reset_early_group names to StaticFSM object
fsm_info_map: HashMap<ir::Id, ir::RRC<StaticFSM>>,
// ========= Pass Options ============
/// How many states the static FSM must have before we pick binary encoding over
/// one-hot
one_hot_cutoff: u64,
}

impl Named for CompileStatic {
Expand All @@ -30,39 +35,37 @@ impl Named for CompileStatic {
fn description() -> &'static str {
"compiles static sub-programs into a dynamic group"
}
}

// Given a list of `static_groups`, find the group named `name`.
// If there is no such group, then there is an unreachable! error.
fn find_static_group(
name: &ir::Id,
static_groups: &[ir::RRC<ir::StaticGroup>],
) -> ir::RRC<ir::StaticGroup> {
Rc::clone(
static_groups
.iter()
.find(|static_group| static_group.borrow().name() == name)
.unwrap_or_else(|| {
unreachable!("couldn't find static group {name}")
}),
)
fn opts() -> Vec<PassOpt> {
vec![PassOpt::new(
"one-hot-cutoff",
"The upper limit on the number of states the static FSM must have before we pick binary \
encoding over one-hot. Defaults to 0 (i.e., always choose binary encoding)",
ParseVal::Num(0),
PassOpt::parse_num,
)]
}
}

// Given an input static_group `sgroup`, finds the names of all of the groups
// that it triggers through their go hole.
// E.g., if `sgroup` has assignments that write to `sgroup1[go]` and `sgroup2[go]`
// then return `{sgroup1, sgroup2}`
// Assumes that static groups will only write the go holes of other static
// groups, and never dynamic groups (which seems like a reasonable assumption).
fn get_go_writes(sgroup: &ir::RRC<ir::StaticGroup>) -> HashSet<ir::Id> {
let mut uses = HashSet::new();
for asgn in &sgroup.borrow().assignments {
let dst = asgn.dst.borrow();
if dst.is_hole() && dst.name == "go" {
uses.insert(dst.get_parent_name());
}
impl ConstructVisitor for CompileStatic {
fn from(ctx: &ir::Context) -> CalyxResult<Self> {
let opts = Self::get_opts(ctx);

Ok(CompileStatic {
one_hot_cutoff: opts["one-hot-cutoff"].pos_num().unwrap(),
reset_early_map: HashMap::new(),
wrapper_map: HashMap::new(),
signal_reg_map: HashMap::new(),
fsm_info_map: HashMap::new(),
})
}

fn clear_data(&mut self) {
self.reset_early_map = HashMap::new();
self.wrapper_map = HashMap::new();
self.signal_reg_map = HashMap::new();
self.fsm_info_map = HashMap::new();
}
uses
}

impl CompileStatic {
Expand Down Expand Up @@ -92,7 +95,8 @@ impl CompileStatic {
});

// fsm.out == 0
let first_state = *fsm_object.borrow().query_between(builder, (0, 1));
let first_state =
*fsm_object.borrow_mut().query_between(builder, (0, 1));
structure!( builder;
let signal_on = constant(1, 1);
let signal_off = constant(0, 1);
Expand Down Expand Up @@ -163,7 +167,7 @@ impl CompileStatic {
)
});

let fsm_eq_0 = *fsm_object.borrow().query_between(builder, (0, 1));
let fsm_eq_0 = *fsm_object.borrow_mut().query_between(builder, (0, 1));

let wrapper_group =
builder.add_group(format!("while_wrapper_{}", group_name));
Expand All @@ -189,34 +193,6 @@ impl CompileStatic {
wrapper_group
}

// Get early reset group name from static control (we assume the static control
// is an enable).
fn get_reset_group_name(&self, sc: &mut ir::StaticControl) -> &ir::Id {
// assume that there are only static enables left.
// if there are any other type of static control, then error out.
let ir::StaticControl::Enable(s) = sc else {
unreachable!("Non-Enable Static Control should have been compiled away. Run {} to do this", crate::passes::StaticInliner::name());
};

let sgroup = s.group.borrow_mut();
let sgroup_name = sgroup.name();
// get the "early reset group". It should exist, since we made an
// early_reset group for every static group in the component
let early_reset_name =
self.reset_early_map.get(&sgroup_name).unwrap_or_else(|| {
unreachable!(
"group {} not in self.reset_early_map",
sgroup_name
)
});

early_reset_name
}
}

// These are the functions/methods used to assign FSMs to static islands
// (Currently we use greedy coloring).
impl CompileStatic {
// Given a `coloring` of static group names, along with the actual `static_groups`,
// it builds one StaticSchedule per color.
fn build_schedule_objects(
Expand Down Expand Up @@ -252,6 +228,66 @@ impl CompileStatic {
.collect()
}

// Get early reset group name from static control (we assume the static control
// is an enable).
fn get_reset_group_name(&self, sc: &mut ir::StaticControl) -> &ir::Id {
// assume that there are only static enables left.
// if there are any other type of static control, then error out.
let ir::StaticControl::Enable(s) = sc else {
unreachable!("Non-Enable Static Control should have been compiled away. Run {} to do this", crate::passes::StaticInliner::name());
};

let sgroup = s.group.borrow_mut();
let sgroup_name = sgroup.name();
// get the "early reset group". It should exist, since we made an
// early_reset group for every static group in the component
let early_reset_name =
self.reset_early_map.get(&sgroup_name).unwrap_or_else(|| {
unreachable!(
"group {} not in self.reset_early_map",
sgroup_name
)
});

early_reset_name
}
}

// These are the functions used to allocate FSMs to static islands
impl CompileStatic {
// Given a list of `static_groups`, find the group named `name`.
// If there is no such group, then there is an unreachable! error.
fn find_static_group(
name: &ir::Id,
static_groups: &[ir::RRC<ir::StaticGroup>],
) -> ir::RRC<ir::StaticGroup> {
Rc::clone(
static_groups
.iter()
.find(|static_group| static_group.borrow().name() == name)
.unwrap_or_else(|| {
unreachable!("couldn't find static group {name}")
}),
)
}

// Given an input static_group `sgroup`, finds the names of all of the groups
// that it triggers through their go hole.
// E.g., if `sgroup` has assignments that write to `sgroup1[go]` and `sgroup2[go]`
// then return `{sgroup1, sgroup2}`
// Assumes that static groups will only write the go holes of other static
// groups, and never dynamic groups (which seems like a reasonable assumption).
fn get_go_writes(sgroup: &ir::RRC<ir::StaticGroup>) -> HashSet<ir::Id> {
let mut uses = HashSet::new();
for asgn in &sgroup.borrow().assignments {
let dst = asgn.dst.borrow();
if dst.is_hole() && dst.name == "go" {
uses.insert(dst.get_parent_name());
}
}
uses
}

// Gets all of the triggered static groups within `c`, and adds it to `cur_names`.
// Relies on sgroup_uses_map to take into account groups that are triggered through
// their `go` hole.
Expand Down Expand Up @@ -417,8 +453,10 @@ impl CompileStatic {
group_names: &mut HashSet<ir::Id>,
sgroups: &Vec<ir::RRC<ir::StaticGroup>>,
) {
let group_uses =
get_go_writes(&find_static_group(parent_group, sgroups));
let group_uses = Self::get_go_writes(&Self::find_static_group(
parent_group,
sgroups,
));
for group_use in group_uses {
for ancestor in full_group_ancestry.iter() {
cur_mapping.entry(*ancestor).or_default().insert(group_use);
Expand Down Expand Up @@ -461,6 +499,21 @@ impl CompileStatic {
}
cur_mapping
}

pub fn get_coloring(
sgroups: &Vec<ir::RRC<ir::StaticGroup>>,
control: &ir::Control,
) -> HashMap<ir::Id, ir::Id> {
// `sgroup_uses_map` builds a mapping of static groups -> groups that
// it (even indirectly) triggers the `go` port of.
let sgroup_uses_map = Self::build_sgroup_uses_map(sgroups);
// Build conflict graph and get coloring.
let mut conflict_graph: GraphColoring<ir::Id> =
GraphColoring::from(sgroups.iter().map(|g| g.borrow().name()));
Self::add_par_conflicts(control, &sgroup_uses_map, &mut conflict_graph);
Self::add_go_port_conflicts(&sgroup_uses_map, &mut conflict_graph);
conflict_graph.color_greedy(None, true)
}
}

// These are the functions used to compile for the static *component* interface
Expand Down Expand Up @@ -519,7 +572,7 @@ impl CompileStatic {

// Makes `done` signal for promoted static<n> component.
fn make_done_signal_for_promoted_component(
fsm: &StaticFSM,
fsm: &mut StaticFSM,
builder: &mut ir::Builder,
comp_sig: RRC<ir::Cell>,
) -> Vec<ir::Assignment<ir::Nothing>> {
Expand Down Expand Up @@ -581,14 +634,16 @@ impl CompileStatic {
// The assignments are removed from `sgroup` and placed into
// `builder.component`'s continuous assignments.
fn compile_static_interface(
&self,
sgroup: ir::RRC<ir::StaticGroup>,
builder: &mut ir::Builder,
) {
if sgroup.borrow().get_latency() > 1 {
// Build a StaticSchedule object, realize it and add assignments
// as continuous assignments.
let mut sch = StaticSchedule::from(vec![Rc::clone(&sgroup)]);
let (mut assigns, fsm) = sch.realize_schedule(builder, true);
let (mut assigns, mut fsm) =
sch.realize_schedule(builder, true, self.one_hot_cutoff);
builder
.component
.continuous_assignments
Expand All @@ -598,7 +653,7 @@ impl CompileStatic {
// If necessary, add the logic to produce a done signal.
let done_assigns =
Self::make_done_signal_for_promoted_component(
&fsm, builder, comp_sig,
&mut fsm, builder, comp_sig,
);
builder
.component
Expand Down Expand Up @@ -666,20 +721,7 @@ impl Visitor for CompileStatic {
// The first thing is to assign FSMs -> static islands.
// We sometimes assign the same FSM to different static islands
// to reduce register usage. We do this by getting greedy coloring.

// `sgroup_uses_map` builds a mapping of static groups -> groups that
// it (even indirectly) triggers the `go` port of.
let sgroup_uses_map = Self::build_sgroup_uses_map(&sgroups);
// Build conflict graph and get coloring.
let mut conflict_graph: GraphColoring<ir::Id> =
GraphColoring::from(sgroups.iter().map(|g| g.borrow().name()));
Self::add_par_conflicts(
&comp.control.borrow(),
&sgroup_uses_map,
&mut conflict_graph,
);
Self::add_go_port_conflicts(&sgroup_uses_map, &mut conflict_graph);
let coloring = conflict_graph.color_greedy(None, true);
let coloring = Self::get_coloring(&sgroups, &comp.control.borrow());

let mut builder = ir::Builder::new(comp, sigs);
// Build one StaticSchedule object per color
Expand All @@ -706,13 +748,16 @@ impl Visitor for CompileStatic {
// Compile top level static group differently.
// We know that the top level static island has its own
// unique FSM so we can do `.pop().unwrap()`
Self::compile_static_interface(
self.compile_static_interface(
sch.static_groups.pop().unwrap(),
&mut builder,
)
} else {
let (mut static_group_assigns, fsm) = sch
.realize_schedule(&mut builder, static_component_interface);
let (mut static_group_assigns, fsm) = sch.realize_schedule(
&mut builder,
static_component_interface,
self.one_hot_cutoff,
);
let fsm_ref = ir::rrc(fsm);
for static_group in sch.static_groups.iter() {
// Create the dynamic "early reset group" that will replace the static group.
Expand Down
4 changes: 2 additions & 2 deletions examples/futil/dot-product.expect
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done: 1) {
fsm.write_en = early_reset_cond00_go.out | early_reset_static_seq_go.out ? 1'd1;
fsm.clk = clk;
fsm.reset = reset;
fsm.in = fsm.out != 4'd0 & early_reset_cond00_go.out ? adder.out;
fsm.in = !(fsm.out == 4'd0) & early_reset_cond00_go.out ? adder.out;
fsm.in = fsm.out == 4'd0 & early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_seq_go.out ? 4'd0;
fsm.in = fsm.out != 4'd7 & early_reset_static_seq_go.out ? adder0.out;
fsm.in = !(fsm.out == 4'd7) & early_reset_static_seq_go.out ? adder0.out;
adder.left = early_reset_cond00_go.out ? fsm.out;
adder.right = early_reset_cond00_go.out ? 4'd1;
add0.left = fsm.out == 4'd6 & early_reset_static_seq_go.out ? v0.read_data;
Expand Down
2 changes: 1 addition & 1 deletion examples/futil/simple.expect
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done: 1) {
fsm.write_en = early_reset_static_seq_go.out ? 1'd1;
fsm.clk = clk;
fsm.reset = reset;
fsm.in = fsm.out != 3'd4 & early_reset_static_seq_go.out ? adder.out;
fsm.in = !(fsm.out == 3'd4) & early_reset_static_seq_go.out ? adder.out;
fsm.in = fsm.out == 3'd4 & early_reset_static_seq_go.out ? 3'd0;
adder.left = early_reset_static_seq_go.out ? fsm.out;
adder.right = early_reset_static_seq_go.out ? 3'd1;
Expand Down
4 changes: 2 additions & 2 deletions examples/futil/vectorized-add.expect
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done: 1) {
fsm.write_en = early_reset_cond00_go.out | early_reset_static_seq_go.out ? 1'd1;
fsm.clk = clk;
fsm.reset = reset;
fsm.in = fsm.out != 3'd0 & early_reset_cond00_go.out ? adder.out;
fsm.in = fsm.out != 3'd3 & early_reset_static_seq_go.out ? adder0.out;
fsm.in = !(fsm.out == 3'd0) & early_reset_cond00_go.out ? adder.out;
fsm.in = !(fsm.out == 3'd3) & early_reset_static_seq_go.out ? adder0.out;
fsm.in = fsm.out == 3'd0 & early_reset_cond00_go.out | fsm.out == 3'd3 & early_reset_static_seq_go.out ? 3'd0;
adder.left = early_reset_cond00_go.out ? fsm.out;
adder.right = early_reset_cond00_go.out ? 3'd1;
Expand Down
Loading

0 comments on commit 015da80

Please sign in to comment.