-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SM68] Initial proposal for Wave Matrix #61
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,396 @@ | ||
<!-- {% raw %} --> | ||
|
||
# Wave Matrix | ||
|
||
* Proposal: [NNNN](NNNN-wave-matrix.md) | ||
* Author(s): [Chris Bieneman](https://github.com/llvm-beanz) | ||
* Sponsor: [Greg Roth](https://github.com/pow2clk), [Tex Riddell](https://github.com/tex3d) | ||
* Status: **Under Consideration** | ||
* Planned Version: Shader Model 6.8 | ||
|
||
|
||
## Introduction | ||
|
||
This proposal adds HLSL support for a new DXIL feature WaveMatrix. The new data | ||
types and operations add support for wave cooperative matrix multiplication and | ||
accumulation. The underlying hardware and driver interfaces for this feature are | ||
introduced in Shader Model 6.8. | ||
|
||
## Motivation | ||
|
||
GPUs have added dedicated hardware to support cooperative matrix multiplication | ||
across SIMD lanes. This feature proposes adding new data types and built-in | ||
functions to enable higher throughput matrix operations that fully utilize GPU | ||
SIMD hardware. | ||
|
||
These higher throughput matrix operations are required for optimal performance | ||
of many machine learning and image processing workloads. Adding native support | ||
to HLSL will enable high-performance matrix operations across all supported | ||
hardware with Shader Model 6.8 drivers. | ||
|
||
## Proposed solution | ||
|
||
WaveMatrix introduces new matrix templates to facilitate wave cooperative | ||
operations: | ||
|
||
```c++ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It doesn't seem to make a lick of difference in this case, but ```hlsl is an option in github: WaveMatrixLeft <TYPE_IN, M, N> ; // M x K WaveMatrixLeft <TYPE_IN, M, N> ; // M x K Maybe someday we'll get proper syntax highlighting? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FYI if you check the 3rd party grammars supported by Github's Linguist integration here, HLSL is listed as being supported by Tim's textmate grammar (so There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this code block the HLSL highlighting is probably okay, but GitHub’s HLSL syntax highlighting doesn’t handle HLSL 2021 particularly well. The C++ mode syntax highlighting does much better IMO. As an example here’s HLSL highlighted: namespace detail {
template <typename ElTy, int NRows, int NCols>
class WaveMatrixBase {
};
} // namespace detail And the same code C++ highlighted: namespace detail {
template <typename ElTy, int NRows, int NCols>
class WaveMatrixBase {
};
} // namespace detail I’m fine changing the simpler code samples to be HLSL highlighted, but I’d greatly prefer to keep the ones that have templates C++-mode. I had used C++ everywhere for consistency, but that isn’t strictly necessary. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right, I think the thing to do is to fork the C++ tmLanguage file as a base, adding HLSL specific constructs/keywords, and then swap it as the official highlighting grammar for github (instructions for doing that are here. It'd be a shame for this to be fixed later and then have a bunch of HLSL snippets throughout the github ecosystem be tagged as C++ snippets into perpetuity I think! |
||
// Matrix Depth (K) dimension is hardware dependent | ||
// With TYPE_IN one of {float32_t, float16_t, uint8_t4_packed, int8_t4_packed} | ||
WaveMatrixLeft <TYPE_IN, M, N> ; // M x K | ||
WaveMatrixRight <TYPE_IN, M, N> ; // K x N | ||
|
||
// With TYPE_ACC one of {float32_t, float16_t, int32_t} | ||
WaveMatrixAccumulator <TYPE_ACC, M, N> ; // M x N | ||
// WaveMatrixLeftColAcc and WaveMatrixRightRowAcc are provided support for | ||
// quantization algorithms. See Zero Point section | ||
|
||
// For accumulating columns from WaveMatrixLeft into a single column of sums | ||
WaveMatrixLeftColAcc <TYPE_ACC, M, N> ; // M x 1 | ||
|
||
// For accumulating rows from WaveMatrixRight into a single row of sums | ||
WaveMatrixRightRowAcc <TYPE_ACC, M, N> ; // 1 x N | ||
``` | ||
|
||
WaveMatrix accumulator object methods support operating on corresponding Left | ||
and Right operands. Results of operations can be stored or accumulated back into | ||
the accumulator. A simple example of multiplication is: | ||
|
||
```c++ | ||
[numthreads(64,1,1)] | ||
void main(uint3 GTID : SV_GroupThreadID, uint GIDX : SV_GroupIndex) | ||
{ | ||
WaveMatrixLeft<float, 16, 16> Left; | ||
WaveMatrixRight<float, 16, 16> Right; | ||
WaveMatrixAccumulator<float, 16, 16> Acc; | ||
|
||
Acc.Multiply(Left, Right); // Stores Left * Right into Acc | ||
Acc.MultiplyAccumulate(Left, Right); // Adds Left * Right to Acc | ||
} | ||
``` | ||
|
||
## Detailed design | ||
|
||
### Reading This Spec | ||
|
||
The next few sections include HLSL object definitions written in HLSL 2021 | ||
syntax and using inheritance to represent interface composition. The objects in | ||
the `detail` namespace are not exposed in the HLSL runtime. They are provided | ||
here to make the specification more concise to consume. Objects in no namespace | ||
are exposed as public interfaces. | ||
|
||
### WaveMatrix Fill | ||
|
||
All WaveMatrix objects have a `Fill` method of the form `void Fill(ElTy Value)` | ||
where `ElTy` is the element type. | ||
|
||
The `Fill` method fills the matrix or matrix fragment with the provided value. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think a more technical description would be "Assigns the given |
||
All wave threads must provide the same value or the result is undefined. All | ||
WaveMatrix objects have the same `Fill` method with the same behavior. | ||
|
||
### WaveMatrix Matrix Objects | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel like the explanation of |
||
|
||
The code below approximately declares the base interface that WaveMatrix matrix | ||
objects implement. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add "the following sections will explain in detail what these methods do and what the parameters represent." |
||
|
||
```c++ | ||
namespace detail { | ||
template <typename ElTy, int NRows, int NCols> | ||
class WaveMatrixBase { | ||
void Fill(ElTy Val); | ||
void Load(ByteAddressBuffer Res, uint StartOffset, uint Stride, bool ColMajor, | ||
uint Align = 0); | ||
void Load(RWByteAddressBuffer Res, uint StartOffset, uint Stride, | ||
bool ColMajor, uint Align = sizeof(ElTy)); | ||
|
||
void Load(groupshared ElTy Arr[], uint StartIdx, uint Stride, bool ColMajor); | ||
|
||
void Store(RWByteAddressBuffer Res, uint StartOffset, uint Stride, | ||
bool ColMajor, uint Align = sizeof(ElTy)); | ||
|
||
void Store(groupshared ElTy Arr[], uint StartIdx, uint Stride, bool ColMajor); | ||
}; | ||
} // namespace detail | ||
``` | ||
|
||
#### Loading and Storing WaveMatrix Matrix Objects | ||
|
||
Values for WaveMatrix matrix objects can be loaded from and stored to | ||
`[RW]ByteAddressBuffer` or `groupshared` array of type `ElTy`. | ||
|
||
Loading from or storing to a `[RW]ByteAddressBuffer` takes a start offset in | ||
bytes, row stride in bytes, and optional row alignment in bytes. Rows begin at | ||
algined offsets from the `StartOffset` based on the `Stride` and `Alignment` | ||
(`StartOffset + (RowIdx * [(Stride % Alignment) + Stride])`). | ||
|
||
The `Alignment` must be at least `sizeof(ElTy)` otherwise the load behavior is | ||
undefined. | ||
|
||
Loading from or storing to a `groupshared` array takes the starting index, and | ||
the row stride as a number of elements. | ||
|
||
##### Orientation | ||
|
||
When loading and storing WaveMatrix matrices a boolean parameter is provided to | ||
indicate if the matrix being loaded or stored is column major. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe a brief explanation of what column major means is in order. I think we should say that, when false, row-major orientation is indicated |
||
|
||
Matrices may be stored in row or column layout. Matrices are always loaded into | ||
row-major orientation in memory for WaveMatrix objects. The `Load` and `Store` | ||
methods perform matrix transposition when loading or storing column major | ||
matrices. | ||
|
||
##### Stride | ||
|
||
When loading and storing from `groupshared` arrays, the stride is expressed in | ||
number of elements. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As above, perhaps a quick definition of what the stride means in this context in addition to the units for the matrix types |
||
|
||
When loading and storing from `[RW]ByteAddressBuffer` types, the stride is | ||
expressed in bytes. The row stride must be a multiple of the size of the | ||
element, and greater than or equal to the size of the element multiplied by the | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think a comma is needed here since there are just two operands to "and" |
||
number of elements per row. Any value below the minimum legal values are | ||
ignored. The behavior of row stride values that are not a multiple of element | ||
stride is undefined. | ||
|
||
#### WaveMatrix Left & Right | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Text similar to the above "The code below approximately declares the base interface that WaveMatrix matrix |
||
|
||
```c++ | ||
template <typename ElTy, int NRows, int NCols> | ||
class WaveMatrixLeft : detail::WaveMatrixBase<ElTy, NRows, NCols> { | ||
uint MatrixDepth(); | ||
}; | ||
|
||
template <typename ElTy, int NRows, int NCols> | ||
class WaveMatrixRight : detail::WaveMatrixBase<ElTy, NRows, NCols> { | ||
uint MatrixDepth(); | ||
}; | ||
``` | ||
|
||
`ElTy` must be either a 32 or 16 bit floating point type or an 8 bit packed | ||
signed or unsigned integer type. `NRows` and `NCols` must be compile-time | ||
constant expressions. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is fairly obvious, but perhaps stating that NRows is the number of rows and NCols the number of columns. I think we should explain what Left and Right matrices are and how they can be used There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I want to be a little careful about how deep this document goes in terms of usage. This isn't intended to be user documentation, this is a specification of the language feature. |
||
|
||
##### WaveMatrix(Left|Right) MatrixDepth | ||
|
||
The `MatrixDepth` method returns the hardware-dependent depth for the matrix | ||
multiplication unit. The resulting value must be an even multiple of 16. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you're using "even" here to indicate that there is no remainder. I'm not sure it's necessary and it might be interpreted as the number itself will not be odd. Obviously a multiple of 16 won't be, but still it's a bit ambiguous initially. |
||
|
||
### WaveMatrix Matrix Fragment Objects | ||
|
||
The code below approximately declares the base interface that all WaveMatrix | ||
fragment objects implement. A WaveMatrix fragment object stores a single row or | ||
column of a WaveMatrix matrix. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does what it represents depend on whether the matrix or row or column major? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, it depends on which specific fragment type you use, either |
||
|
||
```c++ | ||
namespace detail { | ||
template <typename ElTy, int NRows, int NCols> | ||
class WaveMatrixFragmentBase : WaveMatrixBase<ElTy, NRows, NCols> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The inheritance here would seem to suggest that redeclaring these methods isn't strictly necessary. I agree that repeating it makes things clearer, but then I don't know if the inheritance adds anything and it makes the reader check above to see if there are any differences. I did exactly that briefly and saw none. Are there any? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The inheritance should not be there. The fragment classes don't support orientation specification. |
||
void Fill(ElTy Val); | ||
|
||
void Load(ByteAddressBuffer Res, uint StartOffset, uint Stride, | ||
uint Align = 0); | ||
void Load(RWByteAddressBuffer Res, uint StartOffset, uint Stride, | ||
uint Align = 0); | ||
|
||
void Load(groupshared ElTy Arr[], uint StartIdx, uint Stride); | ||
|
||
void Store(RWByteAddressBuffer Res, uint StartOffset, uint Stride, | ||
uint Align = 0); | ||
|
||
void Store(groupshared ElTy Arr[], uint StartIdx, uint Stride); | ||
}; | ||
} // namespace detail | ||
``` | ||
|
||
#### Loading and Storing WaveMatrix Fragment Objects | ||
|
||
Values for WaveMatrix fragment objects can be loaded from and stored to | ||
`[RW]ByteAddressBuffer` or `groupshared` array of type `ElTy`. | ||
|
||
Loading from or storing to a `[RW]ByteAddressBuffer` takes a start offset in | ||
bytes, element stride in bytes, and optional alignment in bytes. The alignment | ||
must be at least `sizeof(ElTy)` otherwise the load behavior is undefined. | ||
|
||
Loading from or storing to a `groupshared` array takes the starting index, and | ||
the element stride as a number of elements. | ||
|
||
##### Stride | ||
|
||
When loading and storing from `groupshared` arrays, the stride is expressed in | ||
number of elements. | ||
|
||
When loading and storing from `[RW]ByteAddressBuffer` types the stride must be | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This lacks the clarification that these are specified in bytes |
||
greater than or equal to the size of the element type. Any value below the | ||
minimum legal values are ignored. The behavior of row stride values that are not | ||
a multiple of element stride is undefined. | ||
|
||
### WaveMatrix Accumulator Objects | ||
|
||
The WaveMatrix Accumulator objects come in three forms which are represented by | ||
two categories: matrix accumulators and fragment accumulators. All accumulators | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I found "which are represented by" to be confusing because there are 3 of one and 2 of another. Would "fall into two categories" be equally accurate? |
||
implement the interface below: | ||
|
||
```c++ | ||
namespace detail { | ||
template <typename ElTy, int NRows, int NCols, typename BaseT> | ||
class WaveMatrixAccumulatorBase : BaseT<ElTy, NRows, NCols> { | ||
|
||
void ScalarMultiply(ElTy Value); | ||
void ScalarDivide(ElTy Value); | ||
void ScalarAdd(ElTy Value); | ||
void ScalarSubtract(ElTy Value); | ||
}; | ||
} // namespace detail | ||
``` | ||
|
||
Each of these operations performs the corresponding element-wise arithmetic | ||
operation on the accumulator and the provided scalar value, storing the result | ||
back into the accumulator. | ||
|
||
### WaveMatrixAccumulator | ||
|
||
```c++ | ||
namespace detail { | ||
template <typename T> | ||
struct is_8bit_packed_int_type = | ||
std::enable_if_t<(std::is_same<T, int8_t4_packed>::value || | ||
std::is_same<T, uint8_t4_packed>::value), | ||
std::true_type>; | ||
|
||
template <typename T> struct is_8bit_packed_int_type = std::false_type; | ||
|
||
template <typename T> | ||
struct is_32bit_int_type = std::enable_if_t<(std::is_same<T, int32_t>::value || | ||
std::is_same<T, uint32_t>::value), | ||
std::true_type>; | ||
|
||
template <typename T> struct is_32bit_int_type = std::false_type; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm afraid I'm not familiar with this feature of C++ and it looks like you're redeclaring variables. Am I exceptionally ignorant of this feature (possible) or is it going to confuse others? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's because I mixed syntaxes and made this invalid. That should be a type alias: using is_32bit_int_type = std::false_type; |
||
} // namespace detail | ||
|
||
template <typename ElTy, int NRows, int NCols> | ||
class WaveMatrixAccumulator | ||
: WaveMatrixAccumulatorBase<ElTy, NRows, NCols, detail::WaveMatrixBase> { | ||
void Add(WaveMatrixLeftColAcc<ElTy, NRows, NCols> LeftMatrix); | ||
void Add(WaveMatrixRightRowAcc<ElTy, NRows, NCols> RightMatrix); | ||
void Add(WaveMatrixAccumulator<ElTy, NRows, NCols> Matrix); | ||
|
||
void Multiply(WaveMatrixLeft<ElTy, NRows, NCols> LHS, | ||
WaveMatrixRight<ElTy, NRows, NCols> RHS); | ||
void MultiplyAccumulate(WaveMatrixLeft<ElTy, NRows, NCols> LHS, | ||
WaveMatrixRight<ElTy, NRows, NCols> RHS); | ||
|
||
template <typename MatElTy1, typename MatElTy2> | ||
std::enable_if_t<detail::is_32bit_int_type<ElTy> && | ||
detail::is_8bit_packed_int_type<MatElTy1> && | ||
detail::is_8bit_packed_int_type<MatElTy2>, | ||
void>::type | ||
Multiply(WaveMatrixLeft<MatElTy1, NRows, NCols> LHS, | ||
WaveMatrixRight<MatElTy2, NRows, NCols> RHS); | ||
|
||
template <typename MatElTy1, typename MatElTy2> | ||
std::enable_if_t<detail::is_32bit_int_type<ElTy> && | ||
detail::is_8bit_packed_int_type<MatElTy1> && | ||
detail::is_8bit_packed_int_type<MatElTy2>, | ||
void>::type | ||
MultiplyAccumulate(WaveMatrixLeft<MatElTy1, NRows, NCols> LHS, | ||
WaveMatrixRight<MatElTy2, NRows, NCols> RHS); | ||
}; | ||
|
||
``` | ||
|
||
#### Broadcast Add | ||
|
||
The `WaveMatrixAccumulator::Add` methods perform an element-wise addition of an | ||
accumulator matrix and the provided matrix or fragment accumulator. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The last "accumulator" seems out of place to me. Is the intent to say that it adds an accumulator matrix (which I think I know what is) to a matrix accumulator or a [matrix] fragment accumulator? I'm not sure it's clear what those are and the terms are kind of ambiguous with the accumulator matrix term There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it adds one accumulator to another. It doesn't work with matrix types, just the accumulator types on both sides. |
||
|
||
If the argument is a fragment object, the `Add` method broadcasts the argument | ||
accumulator up from an Mx1 or 1xN matrix to an MxN matrix then performs | ||
element-wise addition. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is what I tend to call a "splat". Is it a conversion to a full matrix that copies the values of the fragment into a temporary full matrix that gets accumulated? How do the values get selected? I think I can guess, but I think it should be explicit. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is explicitly not specified. The hardware can decide if it needs to copy the fragment to a full matrix or if it can more optimally perform a row/column-wise operation. |
||
|
||
#### WaveMatrixAccumulator::Multiply | ||
|
||
The `WaveMatrixAccumulator::Multiply` method performs multiplication of the left | ||
and right arguments and stores the result back into the `WaveMatrixAccumulator`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know you specify that it "stores" the result, but given that this is the only one that doesn't accumulate (right?) maybe we should call out explicitly that it overwrites the contents of the accumulator |
||
This is a wave-level operation and cannot be used inside divergent control flow. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it is, is the result undefined? Longer term goal, but perhaps we should crib from other specs' formal definitions of "should", "must" etc. |
||
|
||
For `Multiply` operations the matrix element types must match the accumulator | ||
type unless the accumulator is a signed or unsigned 32-bit integer type. For | ||
signed or unsigned 32-bit integer accumulators, signed or unsigned 8-bit packed | ||
integers are also supported and can be mixed interchangeably. | ||
|
||
#### WaveMatrixAccumulator::MultiplyAccumulate | ||
|
||
The `WaveMatrixAccumulator::MultiplyAccumulate` method performs multiplication | ||
of the left and right arguments and adds the result back into the | ||
`WaveMatrixAccumulator`. This is a wave-level operation and cannot be used | ||
inside divergent control flow. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. as above, what if I do? |
||
|
||
For `MultiplyAccumulate` operations the matrix element types must match the | ||
accumulator type unless the accumulator is a signed or unsigned 32-bit integer | ||
type. For signed or unsigned 32-bit integer accumulators, signed or unsigned | ||
8-bit packed integers are also supported and can be mixed interchangeably. | ||
|
||
### WaveMatrix Fragment Accumulators | ||
|
||
WaveMatrix intrinsics are defined to support quantization calculations. | ||
Including calculating a sum for the rows of the left matrix and a sum of the | ||
columns of the right matrix. The `WaveMatrixRightRowAcc` and | ||
`WaveMatrixLeftColAcc` fragment accumulators perform this operation. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should they have "Fragment" in the name? I was initially confused why they took full matrices as parameters, but I see through the inheritance that they are fragment accumulators There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't love the names of those types, but I think if we add "Fragment" to the name we're going to be getting dangerously close to a type name that can't fit on one line of code with reasonable line wrapping rules. |
||
|
||
```c++ | ||
template <typename ElTy, int NRows, int NCols> | ||
class WaveMatrixLeftColAcc | ||
: WaveMatrixAccumulatorBase<ElTy, NRows, NCols, | ||
detail::WaveMatrixFragmentBase> { | ||
template <typename MatElTy> | ||
void SumAccumulate(WaveMatrixLeft<MatElTy, NRows, NCols> LeftMatrix); | ||
}; | ||
|
||
template <typename ElTy, int NRows, int NCols> | ||
class WaveMatrixRightRowAcc | ||
: WaveMatrixAccumulatorBase<ElTy, NRows, NCols, | ||
detail::WaveMatrixFragmentBase> { | ||
void SumAccumulate(WaveMatrixRight<MatElTy, NRows, NCols> RightMatrix); | ||
}; | ||
``` | ||
|
||
#### Zero Point | ||
|
||
The following is the equation for matrix multiplication with zero point | ||
adjustment included: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This section feels out of place in between the definition of WaveMatrix Fragment Acuumulators and the description of their SumAccumulate method. I recognize that it is referenced in the section just after it, but if this is defining the multiply operations, perhaps it can go after the multiply method description and the next section can link back to it? |
||
|
||
$C_{[x,y]} = (\sum_{i=0}^{K} A_{[x,i]} * B_{[i,y]}) - Z_a * (\sum_{i=0}^{K} B_{[i,y]}) - Z_b * (\sum_{i=0}^{K} A_{[x,i]}) + Z_a * Z_b * K$ | ||
|
||
$(\sum_{i=0}^{K} A_{[x,i]} * B_{[i,y]})$ is basic matrix multiplication. | ||
|
||
$- Z_a * (\sum_{i=0}^{K} B_{[i,y]})$ is the zero point adjustment for matrix $A$ | ||
|
||
$- Z_b * (\sum_{i=0}^{K} A_{[x,i]})$ is the zero point adjustment for matrix $B$ | ||
|
||
$+ Z_a * Z_b * K$ is the static zero point adjustment for both matrix $A$ and $B$ | ||
|
||
$Z_*$ are constant zero points values | ||
|
||
#### Wave Matrix SumAccumulate | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe "WaveMatrix Fragment SumAccumulate" to be consistent with the title of that section? |
||
|
||
The `SumAccumulate` methods accumulate the values of the argument matrix into | ||
the WaveMatrix fragment accumulator. The fragment WaveMatrix must have the same | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reading this and the description above the class, it's not clear to me what actually gets placed into the accumulator. Is it the sum of all elements in each row of a row-major matrix and the sum of all elements in each column of a column-major matrix? |
||
data type as the fragment accumulator. | ||
|
||
This intrinsics is used to calculated | ||
$(\sum_{i=0}^{K} A_{[x,i]})$ and $(\sum_{i=0}^{K} B_{[i,y]})$ from our above | ||
equation. | ||
|
||
|
||
## Acknowledgments | ||
|
||
This spec was developed as an extensive collaboration between the Microsoft HLSL | ||
and Direct3D teams and IHV partners. | ||
|
||
Special thanks to: | ||
|
||
Claire Andrews | ||
Nick Feeney | ||
Amar Patel | ||
Tex Riddell | ||
Greg Roth | ||
|
||
<!-- {% endraw %} --> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's minor, but I fear the notion of "supported" is vague here. It may be interpreted as all hardware that supports SM 6.8, which is probably not the case. My feeble attempt to recraft the sentence:
Adding support to HLSL will enable the high-performance of native matrix operations across all hardware with such support through Shader Model 6.8 drivers.