This repository has been archived by the owner on May 22, 2023. It is now read-only.

Relax Architecture Overview

Yuchen Jin edited this page Nov 5, 2021 · 11 revisions

Authors: OctoML Relax Working Group

This doc is meant to serve as a high-level design overview of the key elements of the Relax design. Relax is the codename for our effort to evolve the design of the high-level IR. The main goal of this doc is to provide a concise but clear summary of the key motivations and design points without getting into low-level, architecture-dependent details and advanced topics.

It is intended as a first read for understanding Relax at the architecture level. Please also refer to other design docs (TODO: link) for details on specific aspects.

Key Goals

Relax has three key goals, motivated by our past lessons in ML acceleration.

G0: Support dynamic shape workloads

Specifically, we need to support dynamic shape workloads that are currently out of reach for today's intermediate representations, and optimize performance for both new and existing workloads.

(Partially) dynamic shape models are ubiquitous in today's machine learning workloads. The dynamism can come from variable input sizes, or simply from missing information in the program.

G1: Support "computational graph" style with advanced semantics

Most machine learning engineers are familiar with the "computational graph" and its optimizations, which assume that no operation in the graph has side effects. Such optimizations are useful for the majority of programs. However, as we start to work with random numbers, states, and weight updates, we also need to be able to represent programs with more complex semantics, such as control flow, in-place updates, and side effects.

Additionally, some advanced optimizations may require us to work with mutations, such as in-place updates for scatter/gather operations.

We need to find a way to enable most people to write computational graph optimizations, while still being able to represent these advanced semantics.

G2: Unify the abstraction for cross layer optimizations

Right now, TVM has a clear boundary between abstractions: Relay-to-TIR lowering is done as a single-shot translation. However, we are starting to see a strong need for performing optimizations across the layers. For example, ideally the automation decisions in TensorIR should inform fusion and layout decisions at the high level. This need comes up in our applications of TensorCore auto-scheduling as well as in NPU-related workloads.

Key Design Points

import tvm.script
from tvm.script import tir as T, relax as R

@tvm.script.ir_module
class MyIRModule:
    @T.prim_func
    def tir_exp_func(x: T.handle, y: T.handle): ## <= D2
        n = T.var("n")
        X = T.match_buffer(x, (n,), "float32")
        Y = T.match_buffer(y, (n,), "float32")
        for i in T.grid(n):
            Y[i] = T.exp(X[i])

    @R.func
    def relax_func(x: R.Tensor[(n, k), "f32"], w: R.Tensor[_, "f32"]):
        # n, k above are implicitly defined by the signature
        # so we will be able to refer to n, k in the later part of the program
        with R.dataflow(): ### <= D0
            lv0 = R.match_shape(w, (k, m)) ## <= D1
            lv1: R.Tensor[(n, m), "f32"] = R.dot(x, lv0)
            lv2: R.Tensor[(n * m,), "f32"] = R.flatten(lv1) ## <= D1
            lv3: R.Shape = (n * m,)  ## <= D1 
            gv0: R.Tensor[lv3, "f32"] = R.call_dps(lv3, tir_exp_func, [lv2])   ## <= D2
            R.outputs(gv0)

        R.call_packed("custom_inplace_update", gv0)  ## <= D0, D2
        return gv0 

We can use the above code snippet to demonstrate the key design points of Relax. Note that the script syntax is still evolving and subject to change.

D0: Dataflow block as a first class construct

The majority of the relax_func code is encapsulated in a with R.dataflow() construct. All operations in the dataflow block are side-effect-free and do not contain advanced control flow (such as if-then-else) or nested scopes.

A dataflow block can effectively be viewed as a computational graph embedded in the program. Note that most of the binding variables (lv0, lv1, lv2, lv3) within the dataflow block are "local", which means they are only visible within the block. These variables can be viewed as "internal nodes" of the computational graph. We can mark a variable as an output (gv0), in which case it will be visible in the later part of the program. These output variables can be viewed as output nodes of the computational graph.

Note that R.call_packed("custom_inplace_update", gv0) is outside of the dataflow block. Everything outside of a dataflow block can have side effects, so we cannot perform optimizations such as reordering these bindings according to topological order unless we do more careful analysis. We expect most of the optimizations to happen at the dataflow block level. These optimizations can be done by ML engineers who are familiar with the computational graph concept. The ability to isolate and represent effectful components also provides opportunities for more advanced optimizations in the places that need them.
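
To make the D0 argument concrete, here is a toy sketch (plain Python, not the Relax API; the binding encoding and the function name dead_code_elim are hypothetical) of a graph-level rewrite that is only safe inside a dataflow block: dropping bindings that no output depends on. Outside a dataflow block the same rewrite would be wrong, since a call such as R.call_packed("custom_inplace_update", gv0) produces no used value yet must still run.

```python
from typing import List, Set, Tuple

# Toy encoding (hypothetical): a binding `var = op(*args)` is (var, op, args).
Binding = Tuple[str, str, Tuple[str, ...]]


def dead_code_elim(bindings: List[Binding], outputs: List[str]) -> List[Binding]:
    """Keep only the bindings that some block output transitively depends on.

    This is legal only because every binding in a dataflow block is
    side-effect-free, so removing an unused one cannot change behavior.
    """
    live: Set[str] = set(outputs)
    kept: List[Binding] = []
    # Walk backwards so that each var's users are visited before its definition.
    for var, op, args in reversed(bindings):
        if var in live:
            kept.append((var, op, args))
            live.update(args)
    kept.reverse()
    return kept


block = [
    ("lv0", "match_shape", ("w",)),
    ("lv1", "dot", ("x", "lv0")),
    ("tmp", "exp", ("x",)),  # never used by an output
    ("gv0", "flatten", ("lv1",)),
]
print([b[0] for b in dead_code_elim(block, ["gv0"])])  # ['lv0', 'lv1', 'gv0']
```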

D1: Shape deduction as first class computation

Shape deduction is essential to dynamic model workloads. Under a dynamic shape setting, we usually need to compute the shapes of intermediate tensors before running the computation. Additionally, we need to handle cases where the shape itself is data-dependent (e.g., unique). Finally, most dynamic shape workloads still contain a lot of (partially) static shapes, and ideally we want to take advantage of this static shape information for optimization.

from tvm.script import relax as R

@R.func
def shape_example(x: R.Tensor[(n, 2, 2), "f32"]):
    with R.dataflow():
        # symbolic and static shape deduction
        lv0: R.Tensor[(n, 4), "f32"] = R.reshape(x, (n, 4)) 
        lv1: R.Tensor[(n * 4,), "f32"] = R.flatten(lv0)
        lv2: R.Shape = (n * 4,)
        # external opaque shape function
        lv3: R.Shape = R.call_packed("myshape_func", lv2)
        lv4: R.Tensor[lv3, "f32"] = R.call_dps(lv3, "custom_func", [lv1]) 
        # data dependent case
        lv5: R.Tensor[_, "f32"] = R.unique(lv4)
        # re-match shape
        lv6: R.Tensor[(m,), "f32"] = R.match_shape(lv5, (m,))
        gv0: R.Tensor[(m,), "f32"] = R.exp(lv6)
        R.outputs(gv0)
    return gv0

The above program covers typical scenarios in shape deduction (marked in comments). Importantly, shape is now part of the computation along with Tensor values. This reflects the fact that the computation of shapes can happen at runtime.

While the text-format type annotation lv0: R.Tensor[(n, 4), "f32"] shows the shape of each value, this is only syntactic sugar. From the IR's point of view, the shape field (n, 4) is not part of lv0.checked_type. The type of lv0 is DynTensor(rank=2, dtype="f32"), and the shape is a special value field attached to each Expr. We made this explicit choice to simplify type inference, so that we do not need to get into full dependent type territory.
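
The split between type and shape can be pictured with a toy data model (hypothetical Python classes, not TVM's actual node definitions): the type records only rank and dtype, while the shape lives in a separate per-expression field. Two values with different symbolic shapes therefore have equal types, which is why type inference never has to reason about shape expressions.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class DynTensorType:
    """Toy stand-in for DynTensor(rank, dtype): carries no shape information."""
    rank: int
    dtype: str


@dataclass(frozen=True)
class Expr:
    """Toy expression carrying a type plus a separate shape value field."""
    name: str
    checked_type: DynTensorType
    shape: Tuple[Union[int, str], ...]  # symbolic dims as strings, e.g. ("n", 4)


lv0 = Expr("lv0", DynTensorType(2, "f32"), ("n", 4))
lv9 = Expr("lv9", DynTensorType(2, "f32"), ("m", 8))
print(lv0.checked_type == lv9.checked_type)  # True: shapes differ, types do not
```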

There are two key constructs related to symbolic shape computation:

D1a: match_shape

value = match_shape(lhs, pattern)

The match_shape construct takes a lhs value and a pattern (of symbolic integer expressions). It has two overloaded semantics:

  • When lhs is a Tensor, it matches lhs.shape to the pattern, populates the corresponding symbolic integer variable if it occurs in the pattern for the first time, and then returns a new Tensor that is the same as lhs but with its shape field updated to the pattern.
  • lhs can also be a Shape that directly matches the pattern. This is useful when we want to isolate shape functions that do not correspond to any Tensor value.

Examples

from tvm.script import relax as R

@R.func
def shape_example(x: R.Tensor[_, "f32"], y: R.Tensor[_, "f32"]):
    with R.dataflow():
        # the match shape defines n, m because it appears for the first time
        lv0: R.Tensor[(n, m)] = R.match_shape(x, (n, m))
        # the second occurrence of n, m will translate into an assertion
        # that y's shape equals (n, m)
        lv1: R.Tensor[(n, m)] = R.match_shape(y, (n, m)) 
        # we can also call match_shape on shape expressions
        lv2: R.Shape = R.match_shape(R.shape_of(y), (n, m))
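
The binding-then-checking rule described for match_shape can be modelled at runtime with a small sketch. It is a simplification under stated assumptions: shapes are concrete integer tuples, pattern entries are either integer constants or symbol names given as strings, and compound expressions such as n * 4 are not supported. The function and the env dictionary are hypothetical.

```python
from typing import Dict, Tuple, Union

Pattern = Tuple[Union[int, str], ...]


def match_shape(shape: Tuple[int, ...], pattern: Pattern,
                env: Dict[str, int]) -> Tuple[int, ...]:
    """Toy runtime semantics of match_shape on a concrete shape.

    A symbol seen for the first time is bound in `env`; a symbol that is
    already bound becomes an equality assertion. `env` is only updated if
    the whole match succeeds.
    """
    if len(shape) != len(pattern):
        raise ValueError(f"rank mismatch: {shape} vs {pattern}")
    bound = dict(env)
    for dim, pat in zip(shape, pattern):
        if isinstance(pat, int):
            if dim != pat:
                raise ValueError(f"expected {pat}, got {dim}")
        elif pat in bound:
            if bound[pat] != dim:
                raise ValueError(f"{pat} is {bound[pat]}, got {dim}")
        else:
            bound[pat] = dim  # first occurrence defines the symbol
    env.update(bound)
    return shape


env: Dict[str, int] = {}
match_shape((3, 4), ("n", "m"), env)  # defines n=3, m=4
match_shape((3, 4), ("n", "m"), env)  # second occurrence: checks only
print(env)  # {'n': 3, 'm': 4}
```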

D1b: shape construction from tuple of symbolic integers

After we obtain symbolic integers such as n and m, we can recompose them to form an Expr. Any tuple of symbolic integer expressions is recognized as a Shape value in Relax. As a result, (n, m) is a shape value.

Ways to do shape propagation

Importantly, because shapes are now values computed as part of the program, compile-time shape inference can be viewed as compile-time constant folding (or partial evaluation) of the operations that compute shapes. There are a few ways for a program to express shape computations:

  • W1: Symbolic shape propagation. A shape can be destructured into symbolic integers (n or m in the above program), and we can then use expressions of symbolic integers (n * 4) to represent shape calculations. Notably, a static shape is the special case where the symbolic integers are constants. The symbolic integers can then be recomposed to form a shape value (e.g., (n * 4,)).
  • W2: Opaque shape function calls. We can also implement opaque shape functions (myshape_func). These opaque shape functions are useful fallbacks for quickly hacking up a runtime shape function.
  • W3: For data-dependent shapes (unique), we simply defer to a runtime call f(inputs) -> output that takes the input Tensor, allocates, and returns the output tensor. We can then fetch the shape of lv5 from the Tensor value via the match_shape construct.

Implications for pass writing

Many optimization passes will need to look into the shape information. Now that many shapes can be symbolic, e.g. (n, 4), optimization passes will ideally need to generalize a bit to leverage the symbolic information. For example, in the above programs, we know that all occurrences of n correspond to the same value. This kind of constraint is very useful. Additionally, thanks to the symbolic integer support in the arith module, we can reuse its proving mechanism to check equivalence and deduce relations between symbolic expressions (e.g., prove(n * 4 == n * 4)).

Because symbolic integer expressions (tir.PrimExpr) are eagerly constant-folded, when the input shape is static, the results of shape computations are folded into constant integers as well, preserving the properties we need for static-shape-dependent optimizations.

Because we can now represent a mixed symbolic/static shape in a tuple such as (n, 4), we can take advantage of the static information for additional optimizations.
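
A rough sketch of why eager constant folding keeps static shapes static: in the toy symbolic integers below (hypothetical Python classes standing in for tir.PrimExpr, with far less simplification than the real arith module), a product folds to an int whenever both operands are constants. A flatten shape rule therefore yields (24,) for a static input but keeps a symbolic expression for a (partially) dynamic one.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Var:
    """Toy symbolic integer variable, e.g. n."""
    name: str


@dataclass(frozen=True)
class Mul:
    """Toy product node, kept only when it cannot be folded."""
    a: "Dim"
    b: "Dim"


Dim = Union[int, Var, Mul]


def mul(a: Dim, b: Dim) -> Dim:
    """Build a * b, folding eagerly when both operands are constants."""
    if isinstance(a, int) and isinstance(b, int):
        return a * b
    if a == 1:
        return b  # 1 * b simplifies to b
    return Mul(a, b)


def flatten_shape(shape: Tuple[Dim, ...]) -> Tuple[Dim]:
    """Shape rule for flatten: (d0, d1, ...) -> (d0 * d1 * ...,)."""
    total: Dim = 1
    for d in shape:
        total = mul(total, d)
    return (total,)


print(flatten_shape((2, 3, 4)))      # (24,)
print(flatten_shape((Var("n"), 4)))  # (Mul(a=Var(name='n'), b=4),)
```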

D2: Direct interaction with TensorIR and PackedFunc

The final key design decision we made is to allow the high-level IR to directly interact with and call into lower-level TensorIR and PackedFunc. TensorIR functions and many external libraries adopt a destination-passing convention (we need to explicitly allocate the output and pass it in as an argument to the function). We use dps (destination passing style) to denote this convention. DPS is very important in low-level ML optimizations, as it allows us to globally allocate the intermediate storage in a single shot when possible, and to execute the computation without active memory allocation.

Calling a dps function means that, after the call, the result is passed back via the function arguments (e.g., result in the example below) instead of via the function's return value.

// not destination passing: the result is the return value
int func(int x) {
  return 1;
}
// destination passing: the caller provides storage for the result
void func_dps(int x, int *result) {
  *result = 1;
}

DPS is inherently mutating: it writes to the output argument. We need a way to bridge these calls into the high-level (pure) dataflow land, so that we can perform computational-graph-style rewriting on a sequence of TIR calls.

D2a: call_dps

call_dps is an intrinsic that bridges the gap.

def call_dps(output_shape: Shape, lowlevel_func: Expr, inputs: Tuple[Expr]) -> Expr:
    """Example code to demonstrate the semantics of call dps"""
    out_tensor = alloc_tensor(output_shape, current_expr.dtype)
    lowlevel_func(*inputs, out_tensor)
    return out_tensor

call_dps takes in the output shape, lowlevel_func (which can be a packed func or a TIR PrimFunc), and a tuple of inputs. Its semantics are demonstrated by the code above. Notably, when we lower call_dps, we do not need to allocate a separate output tensor for each call. The compiler can choose to create a memory plan for the intermediate tensors and tie things together for effective reuse.
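
The semantics above can be exercised with a runnable toy version in plain Python. The flat-list "tensor", alloc_tensor, and exp_dps are hypothetical stand-ins (exp_dps plays the role of tir_exp_func). The point is that the DPS kernel mutates only the freshly allocated output, so the call_dps expression as a whole behaves like a pure function from inputs to output.

```python
import math
from typing import Callable, List, Sequence, Tuple


def alloc_tensor(shape: Tuple[int, ...]) -> List[float]:
    """Toy allocator: a flat, zero-filled buffer with prod(shape) elements."""
    size = 1
    for d in shape:
        size *= d
    return [0.0] * size


def call_dps(output_shape: Tuple[int, ...],
             lowlevel_func: Callable[..., None],
             inputs: Sequence[List[float]]) -> List[float]:
    """Allocate the output, let the DPS function fill it in, return it."""
    out = alloc_tensor(output_shape)
    lowlevel_func(*inputs, out)
    return out


def exp_dps(x: Sequence[float], y: List[float]) -> None:
    """DPS kernel: writes exp(x[i]) into the caller-provided y[i]."""
    for i, v in enumerate(x):
        y[i] = math.exp(v)


x = [0.0, 0.0, 0.0]
y = call_dps((3,), exp_dps, [x])
print(y, x)  # [1.0, 1.0, 1.0] [0.0, 0.0, 0.0]
```

In a real lowering, the per-call alloc_tensor in this sketch is exactly the step a memory planner would be free to replace with shared, pre-planned storage.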

Notably, the output_shape parameter to call_dps intrinsic can be an opaque shape value, a symbolic integer tuple or a constant shape.

The lowlevel_func can be any function with the signature

fn(input0, input1,... out0, out1...)

The two most common cases are (1) a TIR function and (2) an opaque packed func.

Implementation note

CallDPS can be implemented as a special intrinsic (Op) to minimize changes to the IR (instead of adding a standalone IR node). From the AST point of view, this becomes:

Call(op=Op::Get("relax.call_dps"), shape, lowlevel_func, inputs)

This would also allow future iterations of call_dps without changing the IR itself, which might be needed at some point:

  • Enable a sequence of multiple mutations on the same array (in the case of concat-related ops).
  • Enable passing symbolic shape hints to a fused op.

Implications for Integration

D2 enables us to directly embed lower-level abstractions into the high-level abstraction (R.func). This unlocks a lot of opportunities, including but not limited to:

  • Incrementally lower different parts of the program using different strategies.
  • Allow automation to take a call_dps to TIR, perform optimizations, and rewrite it into multiple call_dps nodes that inform layout rewriting decisions at the high level.
  • Bring the BYOC flow in as a natural part of the transformation (by transforming part of the graph into calls of opaque packed functions).

D2b: Packed function calls

We use R.call_packed to indicate a call to a packed function. From the AST's point of view, we do not need to introduce an additional call node; instead, we can introduce an ExternFunc construct that represents a packed function we can call into.

Call(op=ExternFunc("my_packed_func"), *args)

R.call_packed serves only as syntactic sugar to represent the above AST node. This allows us to unify all the calls. Notably, it also allows us to mix packed functions and call_dps when necessary. For example, the binding

lv4: R.Tensor[lv3, "f32"] = R.call_dps(lv3, "custom_func", [lv1]) 

corresponds to the following AST.

Call(op=Op::Get("relax.call_dps"), lv3, ExternFunc("custom_func"), [lv1])

CallDPS on external packed functions can be useful when we want to directly integrate low-level libraries (such as cuDNN) into the high level without invoking memory allocation.
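
The unification of call forms can be sketched with a toy AST (hypothetical Python dataclasses mirroring the node names used above, not TVM's real classes). Both sugars produce the same Call node and differ only in the callee; a string callee passed to call_dps is wrapped in an ExternFunc.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple, Union


@dataclass(frozen=True)
class Op:
    """Toy registered intrinsic, analogous to Op::Get(name)."""
    name: str


@dataclass(frozen=True)
class ExternFunc:
    """Toy reference to a packed function by its global name."""
    name: str


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op: Union[Op, ExternFunc, Var]
    args: Tuple[object, ...]


def call_packed(func_name: str, *args: object) -> Call:
    """Sugar: R.call_packed(name, *args) -> Call(ExternFunc(name), args)."""
    return Call(ExternFunc(func_name), tuple(args))


def call_dps(shape: object, func: Union[str, Var],
             inputs: Sequence[object]) -> Call:
    """Sugar: call_dps is an ordinary Call to the relax.call_dps intrinsic."""
    callee = ExternFunc(func) if isinstance(func, str) else func
    return Call(Op("relax.call_dps"), (shape, callee, tuple(inputs)))


lv1, lv3 = Var("lv1"), Var("lv3")
print(call_dps(lv3, "custom_func", [lv1]))
```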

Additional Considerations

This section covers additional design considerations that are not directly covered by the three key design points (D0, D1, D2).

A0: Enable opaque object type

In some cases, we might need to go outside the strongly typed world of Tensor and Shape and introduce a generic object type (corresponding to tvm.runtime.Object). This is usually due to two common needs:

  • The ability to support runtime objects that are not (yet) part of the type system, for example, the storage allocator or VM state.
  • The ability to express more flexible programs (like those in TorchScript) that do not have type information.

We still encourage most code to follow the strongly typed version, and will require explicit type casting to convert an object to a Tensor before running operations on it (the cast turns into a runtime assertion). Thanks to TVM's object system, we can easily support this feature at runtime.
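
A minimal sketch of the explicit cast, assuming a toy class hierarchy in which Tensor derives from a generic Object (hypothetical Python classes and a hypothetical cast_to_tensor helper, not tvm.runtime): the cast is just a runtime type check that either returns the value with the narrower static type or fails.

```python
from typing import Tuple


class Object:
    """Toy stand-in for a generic runtime object."""


class Tensor(Object):
    def __init__(self, shape: Tuple[int, ...]) -> None:
        self.shape = shape


class StorageAllocator(Object):
    """Example of a runtime object outside the Tensor/Shape type system."""


def cast_to_tensor(value: Object) -> Tensor:
    """Hypothetical explicit cast: a runtime assertion plus type narrowing."""
    if not isinstance(value, Tensor):
        raise TypeError(f"expected Tensor, got {type(value).__name__}")
    return value


t = cast_to_tensor(Tensor((2, 3)))
print(t.shape)  # (2, 3)
```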

A1: Pythonic syntax

from tvm.script import relax as R

@R.func
def fn(x: R.Tensor[(n, m), "f32"]):
    y: R.Tensor[(n, m), "f32"] = x + 1  
    return y

The above syntax may not pass a pylint check because n and m are undefined variables. One possible trade-off is to allow strings in the signature (similar to Python's string annotations for type checking) to represent shapes.

from tvm.script import relax as R

@R.func
def fn(x: R.Tensor["(n, m)", "f32"]):
    n, m = R.shape_vars("n", "m")
    y: R.Tensor[(n, m), "f32"] = x + 1  
    return y

We can also declare n and m in the global scope (however, that pollutes the global scope and may not be desirable). Similarly, we might need to use "_" in place of _ for unknown shapes if we do not declare it and want the pylint check to pass.
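
With the string form, the shape text is an ordinary Python string, so linters see nothing undefined, while a parser can still recover the symbolic names (the same names that R.shape_vars("n", "m") declares). A minimal sketch using the standard ast module; the helper name is hypothetical, and only plain names and integer constants are handled:

```python
import ast
from typing import List


def shape_var_names(annotation: str) -> List[str]:
    """Return the symbolic dimension names in a string shape like "(n, m)"."""
    node = ast.parse(annotation, mode="eval").body
    elts = node.elts if isinstance(node, ast.Tuple) else [node]
    names: List[str] = []
    for elt in elts:
        if isinstance(elt, ast.Name):
            names.append(elt.id)
        elif isinstance(elt, ast.Constant) and isinstance(elt.value, int):
            continue  # static dimension, nothing to declare
        else:
            raise ValueError(f"unsupported dimension: {ast.dump(elt)}")
    return names


print(shape_var_names("(n, m)"))  # ['n', 'm']
print(shape_var_names("(n, 4)"))  # ['n']
```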