Relax Shape Computation Design
Authors(alphabetical): @altanh, @electriclilies, @jroesch, @junrushao1994, @mbs-octoml, @mikepapadim, @tkonolige, @tqchen, @YuchenJin, @ZihengJiang
This document outlines the design of shape deduction and shape computation in Relax. For a broad background of Relax, please refer to Relax Architecture Overview.
The following key goals are considered when designing shape computation in Relax: first-class support for symbolic shape deduction (G0), a safety net that covers the general cases (G1), and a compatible path toward future improvements (G2).
from tvm.script import relax as R
@R.function
def shape_example(x: R.Tensor[(n, m, 2), "f32"]):
    with R.dataflow():
        # symbolic and static shape deduction
        lv0: R.Tensor[(n, m * 2), "f32"] = R.reshape(x, (n, m * 2))
        lv1: R.Tensor[(n * m * 2,), "f32"] = R.flatten(lv0)
        ...
Symbolic integer shapes allow us to effectively represent shape computations and their relations as symbolic expressions. The above code gives an example of such a case. The compiler needs to know that the same n gets threaded through multiple computational blocks. The ability to deduce, constant fold, and analyze (prove equality and other relations) over these expressions enables more optimizations.
We also need to make it easy, at the API level, for developers (and users) to directly look up shapes, prove properties, and run symbolic computations during AST construction, rewriting, and inspection.
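To make this concrete, here is a minimal sketch (not part of the design itself) of what such symbolic computation looks like with TVM's existing arithmetic analyzer; symbolic shape dimensions are ordinary TIR PrimExprs, so they can be simplified and compared directly.
import tvm
from tvm import tir

n = tir.Var("n", "int64")
m = tir.Var("m", "int64")

ana = tvm.arith.Analyzer()
# constant folding / simplification of a symbolic dimension
flat = ana.simplify(n * (m * 2))
# prove that two symbolic shape expressions are equal
assert ana.can_prove(n * m * 2 == flat)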
While fixed-rank, dynamic symbolic shape relations cover most of the use cases, inevitably we also need to cover general cases that do not fall into this category:
- C0: Dynamic shape relations where the output shape is data-dependent on the input (e.g. the `unique` operator).
- C1: The rank of a tensor is not known (can happen in rare cases of loops).
- C2: The dtype of a tensor is not known.
- C3: Other cases, such as opaque runtime objects for low-level libraries (e.g. PRNG handle, cuDNN context).
As a result, it is important to have a "safety net" solution so that we cover the general cases.
As always, there are ways to continuously improve shape support in general, for example, building the capability to effectively generate code for dynamic-rank operators. One thing we need to make sure of is that the proposed design has a compatible path for future, advanced improvements. This perspective is also closely related to G1, as we can always fall back to the safety net for now before we add advanced shape deduction mechanisms.
A shape (constraint) of a tensor is represented by two fields of the `relax.Expr` (`RelayExpr`):
- `checked_type_: Type`, which stores the generic rank and dtype constraints.
- `shape_: Expr`, which stores ways to compute the shape of the expression at runtime.
`checked_type_` stores the compile-time deduced type of an expression. The checked type of a Tensor Expr is `DynTensorType`, which contains the following two fields:
class DynTensorType:
    # rank can be unknown
    rank: int
    # dtype can be unknown
    dtype: DataType
In most common cases, the rank and dtype are known for a given Tensor Expr, but we also allow types with unknown rank and dtype (due to considerations G1 and G2).
`DynTensorType` does not contain the shape information. Instead, the shape of a Tensor is stored in an optional `shape_` field on the Expr.
For an Expr `x`, `x.shape_` can contain the following values:
- V0: `ShapeExpr`, which contains an `Array<PrimExpr>`, indicating that `x` has a known symbolic shape that can be deduced from previous symbolic values.
- V1: A generic `relax.Expr`, which can call into opaque (shape) functions or shape deduction intrinsics.
- V2: `None`, indicating that the shape is unknown at compile time and needs to be looked up at runtime.
from tvm.script import relax as R
@R.function
def shape_example(x: R.Tensor[(n, 2, 2), "f32"]):
    with R.dataflow():
        # V0: symbolic and static shape deduction
        lv0: R.Tensor[(n, 4), "f32"] = R.reshape(x, (n, 4))
        lv1: R.Tensor[(n * 4,), "f32"] = R.flatten(lv0)
        lv2: R.Shape = (n * 4,)
        # V1: external opaque shape function
        lv3: R.Shape = R.call_packed("myshape_func", lv2)
        lv4: R.Tensor[lv3, "f32"] = R.call_tir(lv3, "custom_func", [lv1])
        # V2: data dependent case
        lv5: R.Tensor[_, "f32"] = R.unique(lv4)
        # re-match shape
        lv6: R.Tensor[(m,), "f32"] = R.match_shape(lv5, (m,))
        gv0: R.Tensor[(m,), "f32"] = R.exp(lv6)
        R.outputs(gv0)
    return gv0
The above code block shows examples of the three scenarios. V0 is the most common case and the one we want to give first-class support. V1 is used to handle cases that may not be trivially representable by symbolic expressions (e.g. the dynamic-rank case). V2 is our safety net (G1). When a shape deducer cannot deduce a shape, either due to lack of analysis capabilities or due to unsolvable obstacles (in the data-dependent case), we can fall back to V2.
After a data-dependent computation or external calls, we may need to recover/refine the shape information to enable more optimizations. The `match_shape` construct is used to perform such refinements.
value = match_shape(lhs, pattern)
The match_shape construct takes an lhs value and a pattern (of symbolic integer expressions). It has two overloaded semantics (see the sketch below):
- When lhs is a Tensor, it matches `lhs.shape` to the pattern, populates the corresponding symbolic integer variables if they occur in the pattern for the first time, and then returns a new Tensor that is the same as lhs but with the shape field updated to the pattern.
- lhs can also be a Shape that directly matches the pattern. This is useful when we want to isolate out shape functions that do not correspond to any Tensor value.
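The following sketch, written in the same schematic TVMScript style as the examples in this document (the variable k and the concrete shapes are purely illustrative), shows both overloads side by side.
from tvm.script import relax as R

@R.function
def match_shape_example(x: R.Tensor[_, "f32"], s: R.Shape):
    with R.dataflow():
        # Tensor overload: binds k on its first occurrence and returns a
        # Tensor identical to x but with shape_ set to (k, 8)
        lv0: R.Tensor[(k, 8), "f32"] = R.match_shape(x, (k, 8))
        # Shape overload: matches a Shape value directly against the pattern
        lv1: R.Shape = R.match_shape(s, (k * 8,))
        gv0: R.Tensor[(k, 8), "f32"] = R.exp(lv0)
        R.outputs(gv0)
    return gv0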
This section discusses the strategies for shape propagation in Relax.
The first thing we need to build is support for computation when no shape information is available. This support for no-shape computation establishes a safety net that we can always fall back to.
from tvm.script import relax as R
@R.function
def shape_example(x: R.Tensor[_, _]):
    with R.dataflow():
        # computation without shape
        lv0: R.Tensor[_, _] = R.log(x)
        lv1: R.Tensor[_, _] = R.flatten(lv0)
        # bring things back to the shape computation land
        lv2: R.Tensor[(m,), "f32"] = R.match_shape(lv1, (m,))
        gv0: R.Tensor[(m,), "f32"] = R.exp(lv2)
        R.outputs(gv0)
    return gv0
The above code snippet shows an AST without shape information. During lowering, these calls won't get translated into destination-passing style, because it is impossible to obtain the shape information and pre-allocate the memory. Instead, they are directly translated to calls that allocate and return the result tensor.
- `R.log` can be mapped to a runtime PackedFunc call that takes in an NDArray x and performs an elementwise log operation (see the sketch below).
- We can even dispatch to common runtime libraries such as `torch.log`.
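As a concrete illustration (the registration name `relax.run.log` is hypothetical and not part of the design), such a call could be backed by a packed function that allocates and returns its own output instead of writing into a pre-allocated destination buffer.
import numpy as np
import tvm

# hypothetical runtime kernel for the no-shape path: it allocates the result
# itself and returns it, rather than using destination-passing style
@tvm.register_func("relax.run.log")
def _log(x: tvm.nd.NDArray) -> tvm.nd.NDArray:
    return tvm.nd.array(np.log(x.numpy()))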
These features can already be supported by the VM as PackedFunc calls that return objects. We can bring tensors from the no-shape-computation land back to the shape-aware land by using match_shape and follow-up propagation.
Computing without shapes is by no means the most effective way to handle things, but it is necessary for cases like data-dependent calculations and interfaces to external libraries that carry weaker shape information.
Importantly, it establishes a safety net that allows us to explore advanced shape deduction mechanisms incrementally, without worrying about 100% coverage of all possible cases (coming back to G2).
To handle compile-time shape propagation for primitive ops, such as `R.add`, which has a binary broadcasting relation, we register the following deduction function.
@register_op_attr("relax.add", "FInferShape")
def finfer_shape(call: Call) -> Shape:
    # note: if x.shape_ is None,
    # x.shape() returns shape_of(x), which obtains the shape at runtime
    lhs = call.args[0].shape()
    rhs = call.args[1].shape()
    # pattern match to the symbolic deduction case
    if isinstance(lhs, ShapeExpr) and isinstance(rhs, ShapeExpr):
        # construct a symbolic shape according to lhs and rhs
        # Examples:
        #   (n, m) + (m,) => (n, m)
        #   (n, 1, m) + (2, m) => (n, 2, m)
        ...
        return deduced_shape_expr
    return Call(op.get("binary_broadcast"), [lhs, rhs])
`finfer_shape` can then be used for shape propagation. Specifically:
- It handles the most commonly needed case of symbolic shapes (G0).
- It allows a fallback to the opaque shape op `binary_broadcast`.
- In the worst case, we can still return `None`, in which case we fall onto our safety net (G1).
Note that there might be a need to introduce other signatures for advanced deductions. These deduction hooks can be introduced as separate registrations without impacting the overall design.
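For instance, a shape propagation pass could consult the registered hook roughly as follows; this is only a sketch and assumes the hook is exposed through the standard Op attribute mechanism (`get_attr`), with `propagate_shape` being a made-up helper name.
import tvm
from tvm import relax

def propagate_shape(call: relax.Call):
    # look up the per-op deduction hook registered above (if any)
    finfer = call.op.get_attr("FInferShape")
    if finfer is None:
        # no deduction rule registered: fall back to the safety net (G1)
        return None
    return finfer(call)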
Shape deduction is not perfect. In cases like low-level calls into TIR, it is sometimes necessary to explicitly specify the output shape. To complement that, we need to be able to make use of extra shape refinement (via `match_shape`) and preserve shape information across transformations when possible. Specifically:
- Enable passes to annotate additional shape information after analysis.
- Preserve the output shape information explicitly from the high level after lowering to low-level functions.
Right now we need to make sure we have enough of D2 so that programs remain compatible with future improvements (G2). See A0 for more discussion.
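As an illustrative sketch (following the call_tir convention used in the earlier example; the TIR function name `exp_tir` is made up), the output shape is specified explicitly at the low-level call site so that the high-level shape information survives lowering.
from tvm.script import relax as R

@R.function
def lowered(x: R.Tensor[(n, 4), "f32"]):
    with R.dataflow():
        # the output shape (n, 4) is supplied explicitly to the low-level call,
        # preserving the high-level shape information after lowering
        lv0: R.Tensor[(n, 4), "f32"] = R.call_tir((n, 4), "exp_tir", [x])
        R.outputs(lv0)
    return lv0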
Calls into sub-functions can be generated either during construction or by fusion passes. This section discusses shape propagation across function boundaries.
The default calling convention of functions is as follows:
- tir.PrimFunc: the shapes are checked at runtime in the callee; the caller needs to explicitly specify the output shape.
- PackedFunc compiled by relax: the input shapes are checked by the callee; the return value needs to be wrapped by `match_shape` if its shape is unknown.
Note that the current calling convention isolates shape constraints at the function boundary, so it establishes a safety net at runtime.
Importantly, it is desirable for callers to propagate as much shape information as possible (although this is not strictly necessary thanks to the safety net). There are a few ways to do so:
- W0: in passes like fusion, where the shape information is available before fusion, generate sub-functions that contain this shape information, preserve it from before fusion to after fusion, and insert match_shape on the caller's side if necessary (see the sketch after this list).
- W1: leverage constant propagation of a sub relax function to do shape deduction.
  - Replace the symbolic shapes in the input with the shapes of the input arguments.
  - Constant-evaluate the program to see if we can obtain the shape of the output independently of the other executions.
- W2: extract out a shape function if it is independent of the value.
  - Analyze the computation inside a function.
  - Extract out a shape computation function if the shape computation follows a separate path.
Note that W1/W2 are not always possible. It is fine to make a best effort as long as we have the fallback safety net, which either adds more annotations or relies on D0.
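As a rough illustration of W0 (the function and operator names below are invented, and the snippet follows the same schematic TVMScript style as the rest of this document), a fusion pass that already knows the symbolic shapes before fusion can carry them into the generated sub-function and re-match them on the caller's side.
from tvm.script import relax as R

@R.function
def fused_log_exp(x: R.Tensor[(n, m), "f32"]) -> R.Tensor[(n, m), "f32"]:
    # shape annotations preserved from before fusion
    return R.exp(R.log(x))

@R.function
def caller(x: R.Tensor[(n, m), "f32"]):
    with R.dataflow():
        lv0 = fused_log_exp(x)
        # insert match_shape on the caller's side if necessary
        gv0: R.Tensor[(n, m), "f32"] = R.match_shape(lv0, (n, m))
        R.outputs(gv0)
    return gv0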
NOTE: the plan here is not detailed enough for immediate action, but considering G2, we can proceed with the necessary means (W0) and then add W1/W2 later.
These are additional considerations that are not part of the immediate scope, but possible future goals worth aligning on per consideration G2.
With a clear safety net, we can leverage advanced rewriting to enhance the shape information when necessary. See the following example:
from tvm.script import relax as R
@R.function
def before(x: R.Tensor[_, _], y: R.Tensor[(m, n), "f32"]):
    with R.dataflow():
        # computation without shape
        lv0: R.Tensor[_, _] = R.log(x)
        gv0: R.Tensor[(m, n), "f32"] = R.ewise_add(lv0, y)
        R.outputs(gv0)
    return gv0

@R.function
def after(x: R.Tensor[(m, n), "f32"], y: R.Tensor[(m, n), "f32"]):
    with R.dataflow():
        # shape information recovered via analysis of ewise_add
        lv0: R.Tensor[(m, n), "f32"] = R.log(x)
        gv0: R.Tensor[(m, n), "f32"] = R.ewise_add(lv0, y)
        R.outputs(gv0)
    return gv0
The above code shows a possible transformation that leverages more advanced shape analysis. In this case, we use ewise_add to figure out the shape and type of the input `x`. These improvements can be added in an incremental fashion without changing the overall architecture.
Right now symbolic shape sits in its own type and requires wrapping (`ShapeExpr`) to be used as a relax.Expr. One question is whether we need to open up the mix further. While it is worthwhile to consider such a move, doing so would bring more complexity (such as differentiating Tuple[Int] vs ShapeExpr) as well as complexity on the developer side (the ability to pattern match ShapeExpr for the most common cases). As a result, we keep symbolic shape computation constrained to ShapeExpr for now. This is a point that can be revisited later.