Skip to content

Commit

Permalink
Refactor the top level partitioning pass into using the chunk_node method
Browse files Browse the repository at this point in the history
  • Loading branch information
Gabriel Rodriguez-Canal committed Aug 29, 2024
1 parent 40faa69 commit 7643e88
Showing 1 changed file with 11 additions and 73 deletions.
84 changes: 11 additions & 73 deletions xdsl/transforms/experimental/dataflow_graph2.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,79 +158,6 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
# Fixed factor used when chunking loop nodes.
# TODO: this will be a parameter of the DSE
CHUNK_FACTOR = 2

@dataclass
class PartitionTopLevelNodes(RewritePattern):
    """Split the loop node behind each call in the ``top`` function into chunks.

    For every function called from ``top``, the first ``func.Call`` found
    inside it is resolved to its callee. If that callee contains an
    ``scf.For`` loop, the callee is cloned once per chunk, each clone's loop
    bounds are narrowed to a contiguous slice of the original range, and the
    original call is replaced by one call per clone. Memref arguments are
    passed to each chunk call as a subview along their first dimension.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
        """Chunk every top-level node reachable from the ``top`` function.

        Functions other than ``top`` are left untouched. Nodes that contain
        no call, whose callee cannot be resolved, or whose callee has no
        ``scf.For`` loop are skipped instead of raising.
        """
        if op.sym_name.data != "top":
            return

        top_level_nodes = [
            builtin.SymbolTable.lookup_symbol(op, body_op.callee.root_reference)
            for body_op in op.body.ops
            if isinstance(body_op, func.Call)
        ]

        # Chunk the loop node by a fixed chunk factor (this will be a
        # parameter of the DSE in the future).
        for top_level_node in top_level_nodes:
            assert isinstance(top_level_node, func.FuncOp)

            # Only the first call inside the node is considered.
            call = next(
                (
                    inner_op
                    for inner_op in top_level_node.walk()
                    if isinstance(inner_op, func.Call)
                ),
                None,
            )
            if call is None:
                continue

            called_node = builtin.SymbolTable.lookup_symbol(
                op, call.callee.root_reference
            )
            if called_node is None:
                continue

            # The chunking happens effectively in the for loop of the called
            # node. This requires recalculating the bounds for each clone and
            # TODO: the view of the data for each new loop node
            for_loop = next(
                (
                    inner_op
                    for inner_op in called_node.walk()
                    if isinstance(inner_op, scf.For)
                ),
                None,
            )
            if for_loop is None:
                continue

            iters = get_loop_iters(for_loop)
            n_chunks = math.ceil(iters / CHUNK_FACTOR)
            if n_chunks <= 0:
                # Nothing to split.
                continue

            # NOTE(review): assumes both loop bounds are produced by constant
            # ops carrying an integer attribute -- confirm against callers.
            lb = for_loop.lb.owner.value.value.data
            ub = for_loop.ub.owner.value.value.data
            it_range = ub - lb

            # func.Call expects the result *types*, not the result values.
            result_types = [res.type for res in call.res]

            chunks = []
            chunk_calls = []
            for i in range(n_chunks):
                chunk = called_node.clone()
                chunk_name = f"{chunk.sym_name.data}_{i}"
                chunk.sym_name = builtin.StringAttr(chunk_name)

                chunk_for_loop = next(
                    chunk_op
                    for chunk_op in chunk.walk()
                    if isinstance(chunk_op, scf.For)
                )

                # Chunk i covers [lb + i*range/n, lb + (i+1)*range/n). Integer
                # arithmetic keeps consecutive chunks contiguous and makes the
                # last chunk end exactly at the original upper bound, even
                # when the range does not divide evenly.
                chunk_lb = lb + (i * it_range) // n_chunks
                chunk_ub = lb + ((i + 1) * it_range) // n_chunks

                chunk_for_loop.lb.owner.value = (
                    builtin.IntegerAttr.from_index_int_value(chunk_lb)
                )
                chunk_for_loop.ub.owner.value = (
                    builtin.IntegerAttr.from_index_int_value(chunk_ub)
                )

                chunks.append(chunk)
                chunk_calls.append(
                    func.Call(chunk_name, call.arguments, result_types)
                )

            # Data partitioning: each chunk call receives a slice of every
            # memref argument along the first dimension.
            for chunk_idx, chunk_call in enumerate(chunk_calls):
                for idx, arg in enumerate(chunk_call.arguments):
                    if not isinstance(arg.type, memref.MemRefType):
                        continue

                    sizes = [dim.data for dim in arg.type.shape.data]
                    if not sizes:
                        # Rank-0 memrefs cannot be partitioned.
                        continue
                    sizes[0] = sizes[0] // n_chunks

                    rank = len(sizes)
                    offsets = [chunk_idx * sizes[0]] + [0] * (rank - 1)
                    # NOTE(review): strides are kept at 0 as in the previous
                    # implementation; a unit stride may have been intended.
                    strides = [0] * rank

                    subview = memref.Subview.from_static_parameters(
                        arg, arg.type, offsets, sizes, strides
                    )
                    rewriter.insert_op_before(subview, call)
                    chunk_call.operands[idx] = subview.result

            for chunk in chunks:
                rewriter.insert_op_before(chunk, called_node)

            for chunk_call in chunk_calls:
                rewriter.insert_op_before(chunk_call, call)

            rewriter.erase_op(call)

            # This is assuming all the chunk nodes run in parallel, i.e. there
            # were enough resources to instantiate one module per node.
            # NOTE(review): the divisor is CHUNK_FACTOR, not the number of
            # chunks actually created -- confirm which one is intended.
            new_top_level_node_latency = (
                top_level_node.attributes['latency'].value.data / CHUNK_FACTOR
            )
            top_level_node.attributes['latency'] = builtin.FloatAttr(
                new_top_level_node_latency, builtin.Float32Type()
            )

# Module-level mapping shared across calls.
# NOTE(review): presumably counts the chunks generated per base node name so
# that chunk names stay unique -- confirm against chunk_node.
base_chunk_name_counter = {}

def chunk_node(top_level_node: dataflow.Node, rewriter: PatternRewriter, operations_to_remove: set[Operation]):
Expand Down Expand Up @@ -370,6 +297,17 @@ def chunk_node(top_level_node: dataflow.Node, rewriter: PatternRewriter, operati
for chunk in chunks:
chunk_node(chunk, rewriter, operations_to_remove)

@dataclass
class PartitionTopLevelNodes(RewritePattern):
    """Chunk every node called from the ``top`` function via ``chunk_node``."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
        """Resolve each call in the body of ``top`` and chunk its callee.

        Functions whose symbol name is not ``top`` are left untouched.
        """
        if op.sym_name.data != "top":
            return

        top_level_nodes = [
            builtin.SymbolTable.lookup_symbol(op, body_op.callee.root_reference)
            for body_op in op.body.ops
            if isinstance(body_op, func.Call)
        ]

        # chunk_node requires the set it records removable operations in; one
        # set is shared across all nodes, matching how chunk_node threads the
        # same set through its own recursive calls.
        # NOTE(review): what happens to the collected operations afterwards is
        # not visible from here -- if chunk_node only collects them, they
        # still need to be erased once this loop finishes. Confirm.
        operations_to_remove: set[Operation] = set()

        # Chunk the loop node by a fixed chunk factor (this will be a
        # parameter of the DSE in the future).
        for top_level_node in top_level_nodes:
            chunk_node(top_level_node, rewriter, operations_to_remove)

@dataclass
class DSE(RewritePattern):
function_latency: dict[str, float]
Expand Down

0 comments on commit 7643e88

Please sign in to comment.