Exercise Three

In this practical we are going to explore transformations in more detail to add OpenMP parallelism and/or vectorisation to our loop in an automated manner.

Learning objectives are:

  • Exploring the role of transformations and how these can manipulate the IR
  • To understand how transformations are developed
  • Gain an understanding of the key ways in which the IR can be traversed and manipulated
  • Awareness of the parallel operation in the scf dialect
  • To further demonstrate reusability benefits of MLIR transformations

Sample solutions to this exercise are provided in sample_solutions in case you get stuck or just want to compare your efforts with ours.

It is assumed that you have a command line terminal in the training-intro/practical/three directory.

Having problems?
As you go through this exercise if there is anything you are unsure about or are stuck on then please do not hesitate to ask one of the tutorial demonstrators and we will be happy to assist!

The starting point and the plan

We are starting with the same code as in practical two, illustrated below, but now we are going to write a transformation that converts the resulting for operation in the scf dialect into a parallel operation of the same dialect.

@python_compile
def ex_three():
    val=0.0
    add_val=88.2
    for a in range(0, 100000):
      val=val+add_val
    print(val)

ex_three()

The parallel operation represents a parallel for loop, and there are existing MLIR transformations, run via mlir-opt, that will then parallelise this via OpenMP by lowering into the omp dialect, apply vectorisation by lowering to the vector dialect, or accelerate this via GPUs by lowering to the gpu dialect. This illustrates the major reuse benefits of MLIR: developers need not understand the underlying omp, vector, or gpu dialects, but instead can convert a loop into this higher level parallel operation, and all the transformations needed to exploit these facets already exist and can be easily reused.

Illustration of parallel lowering

Driving the transformation

Of course, one way of leveraging the parallel operation would be to edit our tiny_py_to_standard transformation, which lowers from tiny py down to the standard dialects, so that it issues parallel instead of for. However, let's assume that we do not want to edit that and instead wish to apply an optimisation/transformation pass on the IR that comes out of tiny_py_to_standard in order to convert our sequential loop into a parallel one.

If you take a look at the tinypy-opt tool (which is in src/tools from the practical directory) you will see at line 17 the register_all_passes function, which registers the transformations that can be performed on the IR. The second of these, ConvertForToParallel, is the transformation that we will be working with in this exercise and which we have already started off for you.

This transformation can be found in src/for_to_parallel.py and the transformation entry point is defined at the bottom of the file by the class ConvertForToParallel, where the name field defines the name of the transformation as provided to tinypy-opt.

@dataclass
class ConvertForToParallel(ModulePass):
  """
  This is the entry point for the transformation pass which will then apply the rewriter
  """
  name = 'for-to-parallel'

  def apply(self, ctx: MLContext, input_module: ModuleOp):
    walker = PatternRewriteWalker(GreedyRewritePatternApplier([]), apply_recursively=False)
    walker.rewrite_module(input_module)

Here we are creating the PatternRewriteWalker, which walks the IR in block and instruction order and rewrites it in place where needed. As an argument we provide an instantiation of GreedyRewritePatternApplier, which applies a list of patterns in order until one pattern matches. Currently an empty list of patterns is provided (the []) and we need to provide our rewrite pattern here.

The rewrite pattern, defined above the pass in that file and which we will be working with in a moment, is ApplyForToParallelRewriter, so instantiate this at the first line of the apply method and then pass it as a member of the list argument to GreedyRewritePatternApplier.
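For reference, the body of the completed apply method should look broadly like the following. This is a minimal sketch based on the description above; the variable name for_to_parallel is just illustrative and the sample solution may name things differently.

  def apply(self, ctx: MLContext, input_module: ModuleOp):
    # Instantiate our rewrite pattern and pass it, in a list, to the greedy applier
    for_to_parallel = ApplyForToParallelRewriter()
    walker = PatternRewriteWalker(GreedyRewritePatternApplier([for_to_parallel]), apply_recursively=False)
    walker.rewrite_module(input_module)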

Not sure or having problems? Please feel free to ask if there is anything you are unsure about, or you can check the sample solution

What is needed in the IR

We now want to replace the for operation in the scf dialect with a parallel operation. Based on what we generated for exercise two, you might assume that this would look something like the following:

%7 = "scf.parallel"(%4, %5, %6, %0) ({
^0(%8 : index, %9 : f32):
  %10 = "arith.addf"(%9, %1) : (f32, f32) -> f32
  "scf.yield"(%10) : (f32) -> ()
}) : (index, index, index, f32) -> f32

However, unfortunately it is not quite this easy! You can see from the IR above that we are updating the left hand side of the addf operation on each loop iteration, effectively undertaking a sum reduction overall. Because of this loop-carried dependency, the reduction must be wrapped in a reduce operation, which instructs MLIR to implement it as a reduction when it lowers to the omp, vector, or gpu dialects.

Instead, the following is what we are after, where it can be seen that there is now only one argument to the top level block (the loop iteration index), and within this block sits the reduce operation from the scf dialect. This operation must contain one block whose arguments are the left hand and right hand sides of the operation that is being reduced, in this case addf, and the result is returned out via the reduce.return operation.

%7 = "scf.parallel"(%4, %5, %6, %0) ({
^0(%8 : index):
  "scf.reduce"(%1) ({
    ^1(%lhs : f32, %rhs : f32):
      %11 = "arith.addf"(%lhs, %rhs) : (f32, f32) -> f32
      "scf.reduce.return"(%11) : (f32) -> ()
    }) : (f32) -> ()
  "scf.yield"() : () -> ()
}) {"operand_segment_sizes" = array<i32: 1, 1, 1, 1>} : (index, index, index, f32) -> f32

If we look at the first line, %7 = "scf.parallel"(%4, %5, %6, %0), in the above snippet, the first three arguments are the loop lower bound, upper bound, and step size respectively. All others, here %0, are values provided as arguments, and MLIR will use each of these as the left hand side (lhs) of the corresponding reduce operation. Therefore the number of such value arguments, in this case one, must match the number of reductions. The same holds for results of the parallel operation, where the result of the ith reduce operation is mapped to the ith overall result. Lastly, the argument provided in "scf.reduce"(%1) is set as the right hand side (rhs) of the reduced operation.

Developing the rewrite pass

We now need to develop the rewrite pass to convert the for operation to a parallel operation and extract out the values being updated each iteration and wrap these in a reduce operation. We have started this for you in the ApplyForToParallelRewriter class of the src/for_to_parallel.py file.

class ApplyForToParallelRewriter(RewritePattern):

    @op_type_rewrite_pattern
    def match_and_rewrite(self,
                          for_loop: scf.For, rewriter: PatternRewriter):

        # First we get the body of the for loop and detach it (as we will attach it to
        # the parallel loop when we create it)
        loop_body=for_loop.body.blocks[0]
        for_loop.body.detach_block(0)
        # Now get the arguments to the yield at the end of the for loop and the arguments
        # to the loop block too
        yielded_args=list(loop_body.ops.last.arguments)
        block_args=list(loop_body.args)

        ops_to_add=[]
        for op in loop_body.ops:          
          # We go through each operation in the loop body and see if it is one that needs
          # a reduction operation applied to it
          if op.name in matched_operations.keys():
            # We need to find if it is the LHS or RHS that is based upon the argument to the block
            # if it is neither then ignore this as it is not going to be updated from one iteration
            # to the next so no need to wrap in a reduction
            if isinstance(op.lhs, BlockArgument):
              block_arg_op=op.lhs
              other_arg=op.rhs
            elif isinstance(op.rhs, BlockArgument):
              block_arg_op=op.rhs
              other_arg=op.lhs
            else:
              continue

            # Now detach op from the body and remove from those arguments yielded
            # and arguments to the top level block
            op.detach()
            yielded_args.remove(op.results[0])
            block_args.remove(block_arg_op)

            # Create a new block for this reduction operation which has the type of
            # operation LHS and RHS present
            block_arg_types=[] # Needs to be completed!
            block = Block(arg_types=block_arg_types)

            # Retrieve the dialect operation to instantiate
            op_instance=matched_operations[op.name]
            assert op_instance is not None

            # Instantiate the dialect operation and create a reduce return operation
            # that will return the result, then add these operations to the block
            new_op=op_instance.get(block.args[0], block.args[1])
            reduce_result=None # Needs to be completed!
            block.add_ops([new_op, reduce_result])

            # Create the reduce operation and add to the top level block
            reduce_op=None # Needs to be completed!
            #ops_to_add.append(reduce_op)

        # Create a new top level block which will have far fewer arguments
        # as none of the reduction arguments are now present here
        new_block=Block(arg_types=[arg.typ for arg in block_args])
        new_block.add_ops(ops_to_add)

        for op in loop_body.ops:
            op.detach()
            new_block.add_op(op)

        # We have a yield at the end of the block which yields non reduction
        # arguments
        new_yield=scf.Yield.get(*yielded_args)
        new_block.erase_op(new_block.ops.last)
        new_block.add_op(new_yield)

        # Create our parallel operation and replace the for loop with this
        parallel_loop=None # Needs to be completed!         

The method match_and_rewrite, defined as def match_and_rewrite(self, for_loop: scf.For, rewriter: PatternRewriter), will be called whenever the IR walker encounters a node of type scf.For. This node is passed as the for_loop argument to the method, which we can then manipulate as required by the transformation.

If we look at line 55 of src/for_to_parallel.py, which is block_arg_types=[] # Needs to be completed!, we need to provide the types of the left and right hand sides as the argument types of the block. These are block_arg_op.typ and other_arg.typ respectively, and each should be a member of the list (with a comma separating them).

At line 65, which is reduce_result=None # Needs to be completed!, we need to create the reduce.return operation which will return the result of the calculation. We can create this by calling the get method on scf.ReduceReturnOp, with new_op.results[0] as the argument (this provides the SSA result of the new_op operation that we created on the line above).

At line 69, reduce_op=None # Needs to be completed!, we need to create the overall reduce operation. This is done by calling the get method on scf.ReduceOp, which needs two arguments here. The first is the operand, other_arg, provided to it (%1 in the IR example of the previous section) and the second is the block, the block variable in the code, that will comprise this operation. Having created reduce_op, you should also uncomment the ops_to_add.append(reduce_op) line below it so that the reduce operation is added to the new top level block.

Now we have done this we need to create the parallel loop operation itself, which is line 88, parallel_loop=None # Needs to be completed!. Again we will be calling the get method, this time on scf.ParallelOp. We can directly reuse the loop bounds and step from the for loop, for_loop.lb, for_loop.ub, and for_loop.step, as the first three arguments, but crucially each of these needs to be wrapped in a list (so it will be [for_loop.lb]) - we will explain why that is the case a little later on. The new_block variable is our block; that is the fourth argument and again must be wrapped in a list. The fifth argument is the list of SSA argument values provided (in the IR example above this is %0) and is for_loop.iter_args, which is already a list so need not be wrapped in one.

We are almost there; the last step is to instruct xDSL to replace the for loop with the new parallel loop. As the last line of the match_and_rewrite method, just after you have created the parallel operation, add rewriter.replace_matched_op(parallel_loop). A consolidated sketch of these completions is shown below.
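Pulling the above steps together, a sketch of just the completed lines (using the variable names from the skeleton code; the sample solution may differ slightly) is:

            # The reduction block takes the LHS and RHS types of the reduced operation
            block_arg_types=[block_arg_op.typ, other_arg.typ]
            block = Block(arg_types=block_arg_types)

            op_instance=matched_operations[op.name]
            assert op_instance is not None

            # Recreate the arithmetic operation on the block arguments and return its result
            new_op=op_instance.get(block.args[0], block.args[1])
            reduce_result=scf.ReduceReturnOp.get(new_op.results[0])
            block.add_ops([new_op, reduce_result])

            # Wrap the block in a reduce operation, with other_arg as its operand (the rhs)
            reduce_op=scf.ReduceOp.get(other_arg, block)
            ops_to_add.append(reduce_op)

and then, at the end of match_and_rewrite:

        # Bounds, step and body are wrapped in lists as parallel can represent nested loops
        parallel_loop=scf.ParallelOp.get([for_loop.lb], [for_loop.ub], [for_loop.step], [new_block], for_loop.iter_args)
        rewriter.replace_matched_op(parallel_loop)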

Not sure or having problems? Please feel free to ask if there is anything you are unsure about, or you can check the sample solution

Looking at the rewrite pass in more detail

A lot of the work being done in this transformation is in extracting the arguments out of the top level block and removing them from the final yield. Effectively the transformation is manipulating the internal structure from the first IR snippet to the second IR snippet of the previous section, so have a look through the code and see if you can understand how it is making these modifications; don't hesitate to ask one of the presenters if you are unsure about anything. For instance, calling detach on an operation will remove it from its block (an operation can only be a member of one block), erase_op will remove an operation from a block, and add_op adds an operation to a block.

There are some things worth highlighting in the IR and rewrite pass that we have skipped over so far. Firstly, we passed the parallel loop's lower and upper bounds, along with the step, in a list. This is because a parallel operation can represent a nested loop, and the IR example below illustrates a parallel operation over a nested loop, with %4 being the lower bound of the outer loop and %5 the lower bound of the inner loop. The upper bounds and step values follow a similar logic. You can see that the block now has two index arguments, representing the current iteration of both the outer and inner loops. MLIR will transform this as it deems most appropriate; for instance with the OpenMP lowering it will likely apply the collapse clause.

%7 = "scf.parallel"(%4, %5, %6, %7, %8, %9) ({
  ^0(%10 : index, %11 : index):
  ...
  "scf.yield"() : () -> ()
}) {"operand_segment_sizes" = array<i32: 2, 2, 2, 0>} : (index, index, index, index, index, index) -> ()

You can see in the above IR that operand_segment_sizes is provided as an attribute on the operation. This is required for variadic operands, which are operands that can accept a varying number of values. Here the attribute is informing the operation that there are two lower bound operands, two upper bound operands, and two step operands, but no SSA value arguments passed in.

Running our transformation pass

Now that we have developed our pass, let's run it through tinypy-opt as per the following snippet. Note that here we are undertaking two transformations: first our previous tiny-py-to-standard lowering and then for-to-parallel, which, because it comes second, operates on the result of the first transformation.

user@login01:~$ tinypy-opt output.mlir -p tiny-py-to-standard,for-to-parallel

The following is the IR output from these two transformations; you can see the parallel, reduce, and reduce.return operations that our transformation has introduced in this section. The rest of the IR is the same as that generated in exercise two, and that is a major benefit of using parallel: we can parallelise a loop without requiring extensive IR changes elsewhere.

"builtin.module"() ({
  "func.func"() ({
    %0 = "arith.constant"() {"value" = 0.0 : f32} : () -> f32
    %1 = "arith.constant"() {"value" = 88.2 : f32} : () -> f32
    %2 = "arith.constant"() {"value" = 0 : i32} : () -> i32
    %3 = "arith.constant"() {"value" = 100000 : i32} : () -> i32
    %4 = "arith.index_cast"(%2) : (i32) -> index
    %5 = "arith.index_cast"(%3) : (i32) -> index
    %6 = "arith.constant"() {"value" = 1 : index} : () -> index
    %7 = "scf.parallel"(%4, %5, %6, %0) ({
    ^0(%8 : index):
      "scf.reduce"(%1) ({
      ^1(%9 : f32, %10 : f32):
        %11 = "arith.addf"(%9, %10) : (f32, f32) -> f32
        "scf.reduce.return"(%11) : (f32) -> ()
      }) : (f32) -> ()
      "scf.yield"() : () -> ()
    }) {"operand_segment_sizes" = array<i32: 1, 1, 1, 1>} : (index, index, index, f32) -> f32
    %12 = "llvm.mlir.addressof"() {"global_name" = @str0} : () -> !llvm.ptr<!llvm.array<3 x i8>>
    %13 = "llvm.getelementptr"(%12) {"rawConstantIndices" = array<i32: 0, 0>} : (!llvm.ptr<!llvm.array<3 x i8>>) -> !llvm.ptr<i8>
    %14 = "arith.extf"(%7) : (f32) -> f64
    "func.call"(%13, %14) {"callee" = @printf} : (!llvm.ptr<i8>, f64) -> ()
    "func.return"() : () -> ()
  }) {"sym_name" = "main", "function_type" = () -> (), "sym_visibility" = "public"} : () -> ()
  "llvm.mlir.global"() ({
  }) {"global_type" = !llvm.array<3 x i8>, "sym_name" = "str0", "linkage" = #llvm.linkage<"internal">, "addr_space" = 0 : i32, "constant", "value" = "%f\n", "unnamed_addr" = 0 : i64} : () -> ()
  "func.func"() ({
  }) {"sym_name" = "printf", "function_type" = (!llvm.ptr<i8>, f64) -> (), "sym_visibility" = "private"} : () -> ()
}) : () -> ()

Compile and run

We are now ready to feed this into mlir-opt and generate LLVM IR to pass to Clang to build our executable. Similarly to exercise one, you should create a file with the .mlir ending, via

user@login01:~$ tinypy-opt output.mlir -p tiny-py-to-standard,for-to-parallel -o ex_three.mlir

Threaded parallelism via OpenMP

Execute the following:

user@login01:~$ mlir-opt --pass-pipeline="builtin.module(loop-invariant-code-motion, convert-scf-to-openmp, convert-scf-to-cf, convert-cf-to-llvm{index-bitwidth=64}, convert-arith-to-llvm{index-bitwidth=64}, convert-openmp-to-llvm, convert-func-to-llvm, reconcile-unrealized-casts)" ex-three.mlir | mlir-translate -mlir-to-llvmir | clang -fopenmp -x ir -o test -

This is similar to the mlir-opt command that we issued in exercise two, but with a few additions. Firstly, convert-scf-to-openmp will run the MLIR transformation to lower our parallel loop to the omp dialect, and secondly convert-openmp-to-llvm will then lower this to the llvm dialect. Furthermore, you can see that we have had to pass the -fopenmp flag to clang as we must now link against the OpenMP runtime.

You can either run this on the login node (or local machine), or submit to the batch queue for execution on a compute node.

We can execute the test executable directly on the login node (or on your local machine if you are following the tutorial there):

user@login01:~$ export OMP_NUM_THREADS=8
user@login01:~$ ./test

A submission script called sub_ex3.srun is provided which you can submit to the batch queue; it will run over all 128 cores of the node.

user@login01:~$ sbatch sub_ex3.srun

You can check on the status of your job in the queue via squeue -u $USER, and once it has completed an output file will appear in your directory containing the stdout of the job. You can cat or less this file, whichever you prefer.

In the submission file we have added the time command, which reports how long the executable took to run; if running locally you can achieve the same via time ./test. Experiment with running over different numbers of OpenMP threads via the OMP_NUM_THREADS environment variable (which you will see is set to 128 in sub_ex3.srun and can be changed). How does this impact the runtime? You can also change the problem size (e.g. the number of loop iterations) by modifying the value in the original ex_three.py Python file and then regenerating and recompiling.

Adding vectorisation

We can use the scf-parallel-loop-specialization pass to apply vectorisation to our parallel loop (here we do this instead of OpenMP, although the two can be mixed):

user@login01:~$ mlir-opt --pass-pipeline="builtin.module(loop-invariant-code-motion, scf-parallel-loop-specialization, convert-scf-to-cf, convert-cf-to-llvm{index-bitwidth=64}, convert-arith-to-llvm{index-bitwidth=64}, convert-func-to-llvm, reconcile-unrealized-casts)" ex-three.mlir | mlir-translate -mlir-to-llvmir | clang -fopenmp -x ir -o test -

The executable is then run in the same manner as with OpenMP.

Running on a GPU

We don't have GPUs on ARCHER2, so their use is beyond the scope of this course, but if you have a GPU machine then you can transform your parallel loop into the gpu dialect via the following:

user@login01:~$ mlir-opt --pass-pipeline="builtin.module(scf-parallel-loop-tiling{parallel-loop-tile-sizes=1024,1,1}, canonicalize, func.func(gpu-map-parallel-loops), convert-parallel-loops-to-gpu, lower-affine, gpu-kernel-outlining,func.func(gpu-async-region),canonicalize,convert-arith-to-llvm{index-bitwidth=64},convert-scf-to-cf,convert-cf-to-llvm{index-bitwidth=64},gpu.module(convert-gpu-to-nvvm,reconcile-unrealized-casts,canonicalize,gpu-to-cubin),gpu-to-llvm,canonicalize)" ex-three.mlir | mlir-translate -mlir-to-llvmir | clang -x ir -o test -

Note: Your LLVM must have been built with explicit support for GPUs by passing the -DLLVM_TARGETS_TO_BUILD="X86;NVPTX" flag to cmake.