Skip to content

Commit

Permalink
[XLS] Replace Param with StateRead in Proc
Browse files Browse the repository at this point in the history
This switches Procs in XLS IR from overloading Param nodes to represent both state elements & the act of reading them to owning StateElements with StateRead nodes to represent the act of reading them.

PiperOrigin-RevId: 698160989
  • Loading branch information
ericastor authored and copybara-github committed Nov 19, 2024
1 parent d59743e commit bdc3c79
Show file tree
Hide file tree
Showing 124 changed files with 1,935 additions and 1,272 deletions.
65 changes: 64 additions & 1 deletion docs_src/ir_semantics.md
Original file line number Diff line number Diff line change
Expand Up @@ -1263,9 +1263,72 @@ Value | Type

`after_all` can consume an arbitrary number of token operands including zero.

### State-affecting operations

Procs include a concept of local state, represented as a set of elements. Each
activation can read the state values as they would be left by the previous
activation, and can set the state values for subsequent activations to see.

#### **`state_read`**

Reads (and consumes) the value in the given state element. Every state element
must have a corresponding `state_read` operation.

```
result = state_read(state_element=st)
```

**Types**

Value | Type
-------- | ----
`result` | `T`

**Keyword arguments**

<!-- mdformat off(multiline table cells not supported in mkdocs) -->

| Keyword | Type | Required | Default | Description |
| --------------- | ----------------- | -------- | ------- | --------------------------------- |
| `state_element` | `string` | yes | | Name of the state element to read |

<!-- mdformat on -->

#### **`next_value`**

If `predicate` is true or absent, sets the value that the next activation will
see for the given state element. For each state element, at most one
`next_value` node may fire in a given activation; otherwise, undefined behavior
can result. For this reason, frontends & optimizations should be exceptionally
careful when emitting predicated `next_value` nodes; for safety, frontends may
choose instead to emit a single `next_value` node where `value` uses either a
`sel` or a `priority_sel`, in which case optimizations may translate the result
to multiple `next_value` nodes to potentially enable better throughput.

```
result = next_value(param=read, value=v)
```

**Types**

Value | Type
-------- | ----
`result` | `()`

**Keyword arguments**

<!-- mdformat off(multiline table cells not supported in mkdocs) -->

| Keyword | Type | Required | Default | Description |
| ------- | ---- | -------- | ------- | ---------------------------------------------- |
| `param` | `T` | yes | | The `state_read` for the target state element |
| `value` | `T` | yes | | The value to write to the target state element |

<!-- mdformat on -->

### Other side-effecting operations

Aside from channels operations such as `send` and `receive` several other
Aside from channel operations such as `send` and `receive`, several other
operations have side-effects. Care must be taken when adding, removing, or
transforming these operations, e.g., in the optimizer.

Expand Down
2 changes: 2 additions & 0 deletions xls/codegen/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ cc_library(
"//xls/ir:op",
"//xls/ir:register",
"//xls/ir:source_location",
"//xls/ir:state_element",
"//xls/ir:type",
"//xls/ir:value",
"//xls/ir:value_utils",
Expand Down Expand Up @@ -1132,6 +1133,7 @@ cc_test(
"//xls/ir:op",
"//xls/ir:register",
"//xls/ir:source_location",
"//xls/ir:state_element",
"//xls/passes:pass_base",
"//xls/scheduling:pipeline_schedule",
"@com_google_absl//absl/status:status_matchers",
Expand Down
67 changes: 36 additions & 31 deletions xls/codegen/block_conversion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
#include "xls/ir/proc.h"
#include "xls/ir/register.h"
#include "xls/ir/source_location.h"
#include "xls/ir/state_element.h"
#include "xls/ir/topo_sort.h"
#include "xls/ir/type.h"
#include "xls/ir/value.h"
Expand Down Expand Up @@ -2185,11 +2186,11 @@ class CloneNodesIntoBlockHandler {
for (Node* node : sorted_nodes) {
Node* next_node = nullptr;
if (node->Is<Param>()) {
if (is_proc_) {
XLS_ASSIGN_OR_RETURN(next_node, HandleStateParam(node, stage));
} else {
XLS_ASSIGN_OR_RETURN(next_node, HandleFunctionParam(node));
}
XLS_RET_CHECK(!is_proc_);
XLS_ASSIGN_OR_RETURN(next_node, HandleFunctionParam(node));
} else if (node->Is<StateRead>()) {
XLS_RET_CHECK(is_proc_);
XLS_ASSIGN_OR_RETURN(next_node, HandleStateRead(node, stage));
} else if (node->Is<Next>()) {
XLS_RET_CHECK(is_proc_);
XLS_RETURN_IF_ERROR(HandleNextValue(node, stage));
Expand Down Expand Up @@ -2318,14 +2319,16 @@ class CloneNodesIntoBlockHandler {
}

private:
// Don't clone state Param operations. Instead replace with a RegisterRead
// Don't clone state read operations. Instead replace with a RegisterRead
// operation.
absl::StatusOr<Node*> HandleStateParam(Node* node, Stage stage) {
absl::StatusOr<Node*> HandleStateRead(Node* node, Stage stage) {
CHECK_GE(stage, 0);

Proc* proc = function_base_->AsProcOrDie();
Param* param = node->As<Param>();
XLS_ASSIGN_OR_RETURN(int64_t index, proc->GetStateParamIndex(param));
StateRead* state_read = node->As<StateRead>();
StateElement* state_element = state_read->state_element();
XLS_ASSIGN_OR_RETURN(int64_t index,
proc->GetStateElementIndex(state_element));

Register* reg = nullptr;
RegisterRead* reg_read = nullptr;
Expand All @@ -2334,7 +2337,7 @@ class CloneNodesIntoBlockHandler {
// and updated. That register should be created with the
// state parameter's name. See UpdateStateRegisterWithReset().
std::string name =
block()->UniquifyNodeName(absl::StrCat("__", param->name()));
block()->UniquifyNodeName(absl::StrCat("__", state_element->name()));

XLS_ASSIGN_OR_RETURN(reg, block()->AddRegister(name, node->GetType()));

Expand All @@ -2347,8 +2350,8 @@ class CloneNodesIntoBlockHandler {

// The register write will be created later in HandleNextValue.
result_.state_registers[index] =
StateRegister{.name = std::string(param->name()),
.reset_value = proc->GetInitValueElement(index),
StateRegister{.name = std::string(state_element->name()),
.reset_value = state_element->initial_value(),
.read_stage = stage,
.reg = reg,
.reg_write = nullptr,
Expand Down Expand Up @@ -2391,23 +2394,26 @@ class CloneNodesIntoBlockHandler {
absl::Status HandleNextValue(Node* node, Stage stage) {
Proc* proc = function_base_->AsProcOrDie();
Next* next = node->As<Next>();
Param* param = next->param()->As<Param>();
XLS_ASSIGN_OR_RETURN(int64_t index, proc->GetStateParamIndex(param));
StateElement* state_element =
next->state_read()->As<StateRead>()->state_element();
XLS_ASSIGN_OR_RETURN(int64_t index,
proc->GetStateElementIndex(state_element));

CHECK_EQ(proc->GetNextStateElement(index), param);
CHECK_EQ(proc->GetNextStateElement(index), next->state_read());
StateRegister& state_register = *result_.state_registers.at(index);
state_register.next_values.push_back(
{.stage = stage,
.value = next->value() == next->param()
.value = next->value() == next->state_read()
? std::nullopt
: std::make_optional(node_map_.at(next->value())),
.predicate =
next->predicate().has_value()
? std::make_optional(node_map_.at(next->predicate().value()))
: std::nullopt});

bool last_next_value =
absl::c_all_of(proc->next_values(param), [&](Next* next_value) {
bool last_next_value = absl::c_all_of(
proc->next_values(proc->GetStateRead(state_element)),
[&](Next* next_value) {
return next_value == next || node_map_.contains(next_value);
});
if (!last_next_value) {
Expand All @@ -2416,7 +2422,7 @@ class CloneNodesIntoBlockHandler {
return absl::OkStatus();
}

if (param->GetType()->GetFlatBitCount() > 0) {
if (state_element->type()->GetFlatBitCount() > 0) {
// We need a write for the actual value.

// We should only create the RegisterWrite once.
Expand All @@ -2431,17 +2437,17 @@ class CloneNodesIntoBlockHandler {
/*reset=*/std::nullopt, state_register.reg));
result_.output_states[stage].push_back(index);
result_.node_to_stage_map[state_register.reg_write] = stage;
} else if (!param->GetType()->IsToken() &&
param->GetType() != proc->package()->GetTupleType({})) {
} else if (!state_element->type()->IsToken() &&
state_element->type() != proc->package()->GetTupleType({})) {
return absl::UnimplementedError(
absl::StrFormat("Proc has zero-width state element %d, but type is "
"not token or empty tuple, instead got %s.",
index, node->GetType()->ToString()));
}

// If the next state can be determined in a later cycle than the param
// access, we have a non-trivial backedge between initiations (II>1); use a
// "full" bit to track whether the state is currently valid.
// If the next state can be determined in a later cycle than the state read,
// we have a non-trivial backedge between initiations (II>1); use a "full"
// bit to track whether the state is currently valid.
//
// TODO(epastor): Consider an optimization that merges the "full" bits for
// all states with the same read stage & matching write stages/predicates...
Expand Down Expand Up @@ -3510,16 +3516,15 @@ absl::StatusOr<CodegenPassUnit> ProcToCombinationalBlock(
// In a combinational module, the proc cannot have any state to avoid
// combinational loops. That is, the only loop state must be empty tuples.
if (proc->GetStateElementCount() > 1 &&
!std::all_of(proc->StateParams().begin(), proc->StateParams().end(),
[&](Param* p) {
return p->GetType() == proc->package()->GetTupleType({});
})) {
!absl::c_all_of(proc->StateElements(), [&](StateElement* st) {
return st->type() == proc->package()->GetTupleType({});
})) {
return absl::InvalidArgumentError(absl::StrFormat(
"Proc must have no state (or state type is all empty tuples) when "
"lowering to a combinational block. Proc state type is: {%s}",
absl::StrJoin(proc->StateParams(), ", ",
[](std::string* out, Param* p) {
absl::StrAppend(out, p->GetType()->ToString());
absl::StrJoin(proc->StateElements(), ", ",
[](std::string* out, StateElement* st) {
absl::StrAppend(out, st->type()->ToString());
})));
}

Expand Down
1 change: 1 addition & 0 deletions xls/codegen/block_conversion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class TestDelayEstimator : public DelayEstimator {
case Op::kAfterAll:
case Op::kMinDelay:
case Op::kParam:
case Op::kStateRead:
case Op::kNext:
case Op::kInputPort:
case Op::kOutputPort:
Expand Down
1 change: 1 addition & 0 deletions xls/codegen/pipeline_generator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class TestDelayEstimator : public DelayEstimator {
absl::StatusOr<int64_t> GetOperationDelayInPs(Node* node) const override {
switch (node->op()) {
case Op::kParam:
case Op::kStateRead:
case Op::kLiteral:
case Op::kBitSlice:
case Op::kConcat:
Expand Down
18 changes: 11 additions & 7 deletions xls/codegen/register_combining_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "xls/ir/op.h"
#include "xls/ir/register.h"
#include "xls/ir/source_location.h"
#include "xls/ir/state_element.h"
#include "xls/passes/pass_base.h"
#include "xls/scheduling/pipeline_schedule.h"

Expand Down Expand Up @@ -82,22 +83,25 @@ class RegisterCombiningPassTest : public IrTestBase {

auto StateToRegMatcher(BValue st) {
// TODO(allight): Recreates a block-conversion function.
EXPECT_THAT(st.node(), m::Param());
return Reg(absl::StrFormat("__%s", st.GetName()));
EXPECT_THAT(st.node(), m::StateRead());
StateElement* state_element = st.node()->As<StateRead>()->state_element();
return Reg(absl::StrFormat("__%s", state_element->name()));
}
auto StateToRegFullMatcher(BValue st) {
// TODO(allight): Recreates a block-conversion function.
EXPECT_THAT(st.node(), m::Param());
return Reg(absl::StrFormat("__%s_full", st.GetName()));
EXPECT_THAT(st.node(), m::StateRead());
StateElement* state_element = st.node()->As<StateRead>()->state_element();
return Reg(absl::StrFormat("__%s_full", state_element->name()));
}
auto StageValidMatcher(Stage s) {
return Reg(absl::StrFormat("p%d_valid", s));
}
auto NodeToRegMatcher(BValue v, Stage s) {
// TODO(allight): Recreates a block-conversion function.
if (v.node()->Is<Param>()) {
return Reg(
MatchesRegex(absl::StrFormat("p%d___%s__[0-9]+", s, v.GetName())));
if (v.node()->Is<StateRead>()) {
StateElement* state_element = v.node()->As<StateRead>()->state_element();
return Reg(MatchesRegex(
absl::StrFormat("p%d___%s__[0-9]+", s, state_element->name())));
}
return Reg(MatchesRegex(absl::StrFormat("p%d_%s", s, NodeName(v))));
}
Expand Down
1 change: 1 addition & 0 deletions xls/contrib/xlscc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ cc_library(
"//xls/ir:ir_parser",
"//xls/ir:op",
"//xls/ir:source_location",
"//xls/ir:state_element",
"//xls/ir:type",
"//xls/ir:value",
"//xls/ir:value_utils",
Expand Down
Loading

0 comments on commit bdc3c79

Please sign in to comment.