

Relax Shape Computation Design

Authors (alphabetical): @altanh, @electriclilies, @jroesch, @junrushao1994, @mbs-octoml, @mikepapadim, @tkonolige, @tqchen, @YuchenJin, @ZihengJiang

This document outlines the design of shape deduction and shape computation in Relax. For a broad background of Relax, please refer to Relax Architecture Overview.

Key goals

The following key goals are considered when designing shape computation in relax.

G0: First class support for dynamic symbolic integer shape

from tvm.script import relax as R

@R.function
def shape_example(x: R.Tensor[(n, m, 2), "f32"]):
    with R.dataflow():
        # symbolic and static shape deduction
        lv0: R.Tensor[(n, m * 2), "f32"] = R.reshape(x, (n, m * 2))
        lv1: R.Tensor[(n * m * 2,), "f32"] = R.flatten(lv0)
        ...

Symbolic integer shapes allow us to represent shape computations and their relations as symbolic expressions, as the example above shows. The compiler needs to know that the same n is threaded through multiple computational blocks. The ability to deduce, constant-fold, and analyze (e.g. prove equality and other relations between) symbolic shapes enables more optimizations.

At the API level, we also need to make it easy for developers (and users) to directly look up shapes, prove properties, and run symbolic computations during AST construction, rewriting, and inspection.
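
As an illustration of the kind of symbolic manipulation we want to be cheap at the API level, the following sketch uses TVM's existing tir and arith facilities. The exact entry points Relax exposes are a design choice, so treat this as an assumption rather than the final API:

from tvm import tir, arith

# symbolic dimensions, analogous to n and m in the example above
n = tir.Var("n", "int64")
m = tir.Var("m", "int64")

analyzer = arith.Analyzer()
# constant folding / simplification of a symbolic dimension
flattened = analyzer.simplify(n * (m * 2))
# attempt to prove a relation between two symbolic dimensions
provable = analyzer.can_prove(tir.EQ(n * (m * 2), n * m * 2))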

G1: Safety net for general cases

While fixed-rank, dynamic symbolic shape relations cover most use cases, we inevitably also need to handle general cases that do not fall into that category:

  • C0: Dynamic shape relations where the output shape is data-dependent on the input (e.g. the unique operator).
  • C1: The rank of a tensor is not known (can happen in rare cases involving loops).
  • C2: The dtype of a tensor is not known.
  • C3: Other cases: opaque runtime objects for low-level libraries (e.g. PRNG handle, cuDNN context).

As a result, it is important to have a "safety net" solution so that we cover the general cases.

G2: Compatible path for future improvements

As always, there are ways to continuously improve shape support in general, for example by building capabilities to effectively generate code for dynamic-rank operators. We need to make sure the proposed design has a compatible path toward such future, more advanced improvements. This goal is closely related to G1: we can always fall back to the safety net for now, before we add advanced shape deduction mechanisms.

Shape Constraint Representation

The shape (constraint) of a tensor is represented by two fields of relax.Expr (RelayExpr).

  • checked_type_: Type, stores the generic rank and dtype constraints.
  • shape_: Expr, stores ways to compute shape of the expression at runtime.

DynTensorType

checked_type_ stores the compile-time deduced type of an expression. For a Tensor Expr, this type is DynTensorType, which contains the following two fields:

class DynTensorType: 
    # rank can be unknown
    rank: int
    # dtype can be unknown
    dtype: DataType

In most common cases, the rank and dtype are known for a given Tensor expr. But we also allow types with unknown rank and dtype (due to consideration of G1 and G2).
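
To make the two fields concrete, here is an illustrative stand-in written as a plain Python dataclass. It is not the actual TVM class, and using None to mark "unknown" is an assumption made purely for illustration:

from dataclasses import dataclass
from typing import Optional

@dataclass
class DynTensorTypeSketch:
    # None models "unknown" rank/dtype, per G1/G2
    rank: Optional[int] = None
    dtype: Optional[str] = None

t0 = DynTensorTypeSketch(rank=2, dtype="float32")    # fully known: 2-D f32 tensor
t1 = DynTensorTypeSketch(rank=None, dtype="float32") # unknown rank, known dtype
t2 = DynTensorTypeSketch()                           # both rank and dtype unknown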

Shape field

DynTensorType does not contain the shape information. Instead, the shape of a Tensor is stored in an optional shape_ field in an Expr.

For an Expr x, x.shape_ can contain the following values:

  • V0: ShapeExpr, which contains an Array<PrimExpr>, indicating that x has a known symbolic shape that can be deduced from previous symbolic values.
  • V1: A generic relax.Expr, which can call into opaque (shape) functions or shape deduction intrinsics.
  • V2: None, indicating that the shape is unknown at compile time and needs to be looked up at runtime.

from tvm.script import relax as R

@R.function
def shape_example(x: R.Tensor[(n, 2, 2), "f32"]):
    with R.dataflow():
        # V0: symbolic and static shape deduction
        lv0: R.Tensor[(n, 4), "f32"] = R.reshape(x, (n, 4))
        lv1: R.Tensor[(n * 4,), "f32"] = R.flatten(lv0)
        lv2: R.Shape = (n * 4,)
        # V1: external opaque shape function
        lv3: R.Shape = R.call_packed("myshape_func", lv2)
        lv4: R.Tensor[lv3, "f32"] = R.call_tir(lv3, "custom_func", [lv1])
        # V2: data-dependent case
        lv5: R.Tensor[_, "f32"] = R.unique(lv4)
        # re-match shape
        lv6: R.Tensor[(m,), "f32"] = R.match_shape(lv5, (m,))
        gv0: R.Tensor[(m,), "f32"] = R.exp(lv6)
        R.outputs(gv0)
    return gv0

The above code block shows examples of the three scenarios. V0 is the most common case, for which we want to provide first-class support. V1 handles cases that may not be trivially representable by symbolic expressions (e.g. the dynamic-rank case). V2 is our safety net (G1): when a shape deducer cannot deduce the shape, either due to lack of analysis capability or due to unsolvable obstacles (as in data-dependent cases), we fall back to V2.

Shape refinement

After a data-dependent computation or an external call, we may need to recover/refine the shape information to enable more optimizations. The match_shape construct is used to perform such refinements.

value = match_shape(lhs, pattern)

The match_shape construct takes an lhs value and a pattern (of symbolic integer expressions). It has two overloaded semantics:

  • When lhs is a Tensor, it matches lhs.shape against the pattern, populates each symbolic integer variable that occurs in the pattern for the first time, and returns a new Tensor that is the same as lhs but with its shape field updated to the pattern.
  • lhs can also be a Shape that is matched directly against the pattern. This is useful when we want to isolate shape functions that do not correspond to any Tensor value (see the sketch below).
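
For concreteness, a hedged sketch of the second overload follows; the variable names and the exact surface syntax are assumptions for illustration:

from tvm.script import relax as R

@R.function
def shape_refine_example(s: R.Shape):
    # binds k on its first occurrence, and checks that the second
    # dimension of s equals k * 2; no Tensor value is involved
    lv0: R.Shape = R.match_shape(s, (k, k * 2))
    ...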

Shape Propagation

This section discusses the strategies for shape propagation in Relax.

D0: Safety net — no shape computation

The first thing we need to build is support for computation when no shape information is carried alongside it. This "no shape" mode establishes a safety net that we can always fall back to.

from tvm.script import relax as R

@R.function
def shape_example(x: R.Tensor[_, _]):
    with R.dataflow():
        # computation without shape
        lv0: R.Tensor[_, _] = R.log(x)
        lv1: R.Tensor[_, _] = R.flatten(lv0)
        # bring things back to the shape computation land
        lv2: R.Tensor[(m,), "f32"] = R.match_shape(lv1, (m,))
        gv0: R.Tensor[(m,), "f32"] = R.exp(lv2)
        R.outputs(gv0)
    return gv0

The above code snippet shows an AST without shape information. During lowering, these calls are not translated into destination-passing style, because it is impossible to obtain the shape information and pre-allocate the memory. Instead, they are directly translated into calls that allocate and return the result tensor.

  • R.log can be mapped to a runtime PackedFunc call that takes in an NDArray x and performs an elementwise log operation (see the sketch after this list).
    • We can even dispatch to common runtime libraries such as torch.log.
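
A minimal sketch of such an allocate-and-return PackedFunc, assuming a hypothetical registration name relax_demo.log and a NumPy-backed implementation:

import numpy as np
import tvm

@tvm.register_func("relax_demo.log")
def _log(x: tvm.nd.NDArray) -> tvm.nd.NDArray:
    # no pre-allocated output: allocate and return the result tensor directly
    return tvm.nd.array(np.log(x.numpy()))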

These features can already be readily supported by the VM as PackedFunc calls that return objects. We can bring tensors from the no-shape land back to the shape-aware land by using match_shape and follow-up propagation.

Computation without shape information is by no means the most effective way to handle things, but it is necessary for cases like data-dependent calculation and interfacing with external libraries that carry weaker shape information.

Importantly, it establishes a safety net that allows us to explore advanced shape deduction mechanisms incrementally, without worrying about 100% coverage of all possible cases (coming back to G2).

D1: Compile-time shape propagation for primitive ops

To handle compile-time shape propagation for primitive ops such as R.add, which has a binary broadcast shape relation, we register the following deduction function:

@register_op_attr("relax.add", "FInferShape")
def finfer_shape(call: Call) -> Shape:
    # note: if x.shape_ is None
    # x.shape() will return shape_of(x) that obtains shape in runtime.
    lhs = call.args[0].shape();
    rhs = call.args[1].shape();
    # pattern match to symbolic deduction case
    if isinstance(lhs, ShapeExpr) and isinstance(rhs, ShapeExpr):
	# construct a symbolic shape according to lhs and rhs
	# Examples:
	# (n, m) + (m) => (n, m)
	# (n, 1, m) + (2, m) => (n, 2, m)
	...
	return deduced_shape_expr
    return call(op.get("binary_broadcast"), lhs,rhs)
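
The elided symbolic broadcast step above could look roughly like the following sketch. The helper name broadcast_shapes is hypothetical, and it assumes the shapes are given as lists of tir.PrimExpr; it illustrates the deduction rather than specifying the actual implementation:

from tvm import tir, arith

def broadcast_shapes(lhs, rhs):
    """Best-effort NumPy-style broadcast of two symbolic shapes; None if undecidable."""
    analyzer = arith.Analyzer()
    one = tir.IntImm("int64", 1)
    out = []
    # align from the trailing dimensions
    for i in range(1, max(len(lhs), len(rhs)) + 1):
        l = lhs[-i] if i <= len(lhs) else one
        r = rhs[-i] if i <= len(rhs) else one
        if analyzer.can_prove(tir.EQ(l, one)):
            out.append(r)
        elif analyzer.can_prove(tir.EQ(r, one)):
            out.append(l)
        elif analyzer.can_prove(tir.EQ(l, r)):
            out.append(l)
        else:
            # cannot decide symbolically; the caller falls back (opaque op or None)
            return None
    return list(reversed(out))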

finfer_shape can then be used for shape propagations. Specifically:

  • It handles the most commonly needed case of symbolic shapes (G0).
  • It allows falling back to the opaque shape op binary_broadcast.
  • In the worst case, we can still return None, in which case we fall onto our safety net (G1).

Note that there might be a need to introduce other signatures for advanced deductions. These deduction hooks can be introduced as separate registrations without impacting the overall design.

D2: Extra refinement and shape preservation across transformations

Shape deduction is not perfect. In cases like low-level calls into TIR, it is sometimes necessary to explicitly specify the output shape.

To complement that, we need to be able to make use of extra shape refinement (via match_shape) and preserve shape information across transformations when possible. Specifically:

  • Enable passes to annotate additional shape information after analysis.
  • Explicitly preserve output shape information from the high level after lowering to low-level functions.

Right now we need to make sure we have enough of D2 so that it remains compatible with future improvements (G2). See A0 for more discussion.
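
For illustration, here is a hedged sketch of what D2 looks like in practice, mirroring the call_tir form used earlier; the function name my_exp_tir and the exact lowering are assumptions:

from tvm.script import relax as R

@R.function
def lowered(x: R.Tensor[(n, 4), "f32"]):
    with R.dataflow():
        # the output shape (n, 4) is specified explicitly for the low-level call
        lv0 = R.call_tir((n, 4), "my_exp_tir", [x])
        # a pass re-attaches the symbolic shape so later passes can keep using it
        gv0: R.Tensor[(n, 4), "f32"] = R.match_shape(lv0, (n, 4))
        R.outputs(gv0)
    return gv0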

D3: Shape propagation across function boundaries

Calls into sub-functions can be generated either during construction or by fusion passes. This section discusses shape propagation across function boundaries.

The default calling convention of functions is as follows:

  • tir.PrimFunc: the shapes are checked at runtime in the callee; the caller needs to explicitly specify the output shape.
  • PackedFunc compiled by relax: the input shapes are checked by the callee; the return value's shape needs to be wrapped by match_shape if it is unknown.

Note that the current calling convention isolates shape constraints at the function boundary, so it establishes a safety net at runtime.

Importantly, it is desirable for callers to propagate as much shape information as possible (although it is not strictly necessary thanks to the safety net). There are a few ways to do so:

  • W0: in passes like fusion, where the shape information is available before fusion, generate sub-functions that carry this shape information, preserve it from before fusion to after fusion, and insert match_shape on the caller's side if necessary (see the sketch below).
  • W1: leverage constant propagation through a relax sub-function to do shape deduction.
    • Replace the symbolic shapes in the signature with the shapes of the input arguments.
    • Constant-evaluate the program to see whether the output shape can be obtained independently of the rest of the execution.
  • W2: extract a shape function if it is independent of the values.
    • Analyze the computation inside a function.
    • Extract a shape computation function if the shape computation follows a separate path.

Note that W1/W2 are not always possible. Best-effort deduction is fine as long as we have the fallback safety net that either adds more annotations or relies on D0.
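
As an illustration of W0, here is a hedged sketch of a fused sub-function that keeps the symbolic shapes it had before fusion, with the caller re-matching the result shape; the function names and the exact fusion output are assumptions:

from tvm.script import relax as R

@R.function
def fused_log_exp(x: R.Tensor[(n, m), "f32"]) -> R.Tensor[(n, m), "f32"]:
    # shape information available before fusion is preserved in the sub-function
    return R.exp(R.log(x))

@R.function
def caller(x: R.Tensor[_, "f32"]):
    with R.dataflow():
        lv0 = fused_log_exp(x)
        # re-attach the symbolic shape on the caller's side if necessary
        gv0: R.Tensor[(n, m), "f32"] = R.match_shape(lv0, (n, m))
        R.outputs(gv0)
    return gv0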

NOTE: the plan here is not detailed enough for immediate action, but considering G2, we can proceed with the necessary means (W0) and add W1/W2 later.

Additional Considerations

These are additional considerations that are not part of the immediate scope, but possible future goals that are worth aligning on per consideration G2.

A0: More advanced shape deduction mechanisms

With a clear safety net, we can leverage advanced rewriting to enhance the shape information when necessary. See the following example:

from tvm.script import relax as R

@R.function
def before(x: R.Tensor[_, _], y: R.Tensor[(m, n), "f32"]):
    with R.dataflow():
        # computation without shape
        lv0: R.Tensor[_, _] = R.log(x)
        gv0: R.Tensor[(m, n), "f32"] = R.ewise_add(lv0, y)
        R.outputs(gv0)
    return gv0

@R.function
def after(x: R.Tensor[(m, n), "f32"], y: R.Tensor[(m, n), "f32"]):
    with R.dataflow():
        # shape of x is now known
        lv0: R.Tensor[(m, n), "f32"] = R.log(x)
        gv0: R.Tensor[(m, n), "f32"] = R.ewise_add(lv0, y)
        R.outputs(gv0)
    return gv0

The above code shows a possible transformation that leverages more advanced shape analysis. In this case, the shape and type of the input x are deduced from its use in ewise_add. These improvements can be added incrementally without changing the overall architecture.

A1: Bringing symbolic shape closer to Expr

Right now, symbolic shapes sit in their own type and require wrapping (ShapeExpr) to be used as a relax.Expr. One question is whether we need to open up the mix further. While such a move is worth considering, it would bring additional complexity (such as differentiating Tuple[Int] from ShapeExpr) as well as complexity on the developer side (which can currently pattern match on ShapeExpr in the most common cases). As a result, we constrain symbolic shape computation to ShapeExpr for now. This point can be revisited later.
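
For reference, a minimal sketch of the current wrapping, assuming the relax.ShapeExpr constructor takes a list of tir.PrimExpr (treat the exact import paths and constructor form as an assumption):

from tvm import tir, relax

n = tir.Var("n", "int64")
# wrap symbolic PrimExprs into a relax.Expr so they can flow through the IR
s = relax.ShapeExpr([n, n * 2])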