Add support for bfloat16 in segment_matmul operation (#272)
This PR covers the CPU implementation.

---------

Co-authored-by: Matthias Fey <[email protected]>
DamianSzwichtenberg and rusty1s authored Nov 2, 2023
1 parent 2a0d558 commit 44760ec
Showing 4 changed files with 27 additions and 12 deletions.
CHANGELOG.md (1 addition, 0 deletions)
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [0.4.0] - 2023-MM-DD
 ### Added
+- Added support for `bfloat16` data type in `segment_matmul` and `grouped_matmul` (CPU only) ([#272](https://github.com/pyg-team/pyg-lib/pull/272))
 ### Changed
 - Added `--biased` parameter to run benchmarks for biased sampling ([#267](https://github.com/pyg-team/pyg-lib/pull/267))
 - Improved speed of biased sampling ([#270](https://github.com/pyg-team/pyg-lib/pull/270))
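As a quick illustration of the changelog entry, here is a minimal sketch of the new CPU `bfloat16` path. The shapes and the `segment_matmul(inputs, ptr, other, bias)` call mirror `test/ops/test_matmul.py` below; it assumes a pyg-lib build that includes this commit.

```python
import torch

import pyg_lib

# Rows [0, 5) form segment 0 and rows [5, 8) form segment 1; each segment is
# multiplied by its own (16, 32) weight matrix and gets its own bias row.
inputs = torch.randn(8, 16, dtype=torch.bfloat16)
ptr = torch.tensor([0, 5, 8])
other = torch.randn(2, 16, 32, dtype=torch.bfloat16)
bias = torch.randn(2, 32, dtype=torch.bfloat16)

out = pyg_lib.ops.segment_matmul(inputs, ptr, other, bias)
assert out.size() == (8, 32) and out.dtype == torch.bfloat16
```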
pyg_lib/csrc/ops/cpu/matmul_kernel.cpp (4 additions, 2 deletions)
@@ -297,7 +297,8 @@ std::vector<at::Tensor> grouped_matmul_kernel(const at::TensorList input,
         {input_contig[i].size(0), other_contig[i].size(-1)}));
   }
 
-  AT_DISPATCH_ALL_TYPES(
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16,
       input_contig.front().scalar_type(), "grouped_matmul_kernel", [&] {
         if (mkl_path_available<scalar_t>() &&
             mkl_path_possible(input_contig, other_contig)) {
@@ -413,7 +414,8 @@ at::Tensor segment_matmul_kernel(const at::Tensor& input,
   const auto other_contig = other.contiguous();
   auto out = input_contig.new_empty({input.size(0), other.size(-1)});
 
-  AT_DISPATCH_ALL_TYPES(
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16,
       input_contig.scalar_type(), "segment_matmul_kernel", [&] {
         const auto n = other_contig.size(-1);
         const auto k = input_contig.size(-1);
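Both hunks apply the same fix: `AT_DISPATCH_ALL_TYPES`, which instantiates the kernel only for the standard integral and floating-point scalar types, is replaced by `AT_DISPATCH_ALL_TYPES_AND2` with `Half` and `BFloat16` as the two extra types. Because `grouped_matmul_kernel` is covered as well, `grouped_matmul` also accepts `bfloat16` on CPU. A rough sketch of that path follows; the list-in/list-out call signature is assumed from the test file below.

```python
import torch

import pyg_lib

# Two independent matmuls of different shapes, all bfloat16 on CPU.
inputs = [torch.randn(5, 16, dtype=torch.bfloat16),
          torch.randn(6, 9, dtype=torch.bfloat16)]
others = [torch.randn(16, 32, dtype=torch.bfloat16),
          torch.randn(9, 32, dtype=torch.bfloat16)]

outs = pyg_lib.ops.grouped_matmul(inputs, others)
assert outs[0].size() == (5, 32) and outs[1].size() == (6, 32)
assert all(out.dtype == torch.bfloat16 for out in outs)
```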
pyg_lib/testing.py (6 additions, 5 deletions)
@@ -37,12 +37,13 @@ def onlyTriton(func: Callable) -> Callable:
 
 
 def withCUDA(func: Callable) -> Callable:
-    def wrapper(*args, **kwargs):
-        func(*args, device=torch.device('cpu'), **kwargs)
-        if torch.cuda.is_available():
-            func(*args, device=torch.device('cuda:0'), **kwargs)
+    import pytest
 
-    return wrapper
+    devices = [torch.device('cpu')]
+    if torch.cuda.is_available():
+        devices.append(torch.device('cuda:0'))
+
+    return pytest.mark.parametrize('device', devices)(func)
 
 
 def withDataset(group: str, name: str) -> Callable:
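The rewritten `withCUDA` no longer wraps the test body and calls it once per device; it delegates to `pytest.mark.parametrize`, so CPU and CUDA runs show up as separate pytest cases and stack cleanly with further `parametrize` markers such as the new `dtype` one in the matmul tests. A small usage sketch; the test name is made up, and `withCUDA` is assumed to be importable from `pyg_lib.testing`:

```python
import torch

from pyg_lib.testing import withCUDA


@withCUDA
def test_ones_sum(device):
    # Collected once per device: always CPU, plus cuda:0 when available.
    x = torch.ones(4, device=device)
    assert int(x.sum()) == 4
```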
test/ops/test_matmul.py (16 additions, 5 deletions)
@@ -1,5 +1,6 @@
 import os
 
+import pytest
 import torch
 
 import pyg_lib
Expand All @@ -11,11 +12,17 @@


@withCUDA
def test_segment_matmul_autograd(device):
inputs = torch.randn((8, 16), requires_grad=True, device=device)
@pytest.mark.parametrize('dtype', [torch.float, torch.bfloat16])
def test_segment_matmul_autograd(dtype, device):
if device.type == 'cuda' and dtype == torch.bfloat16:
pytest.skip('CUDA does not support bfloat16')

inputs = torch.randn((8, 16), requires_grad=True, device=device,
dtype=dtype)
ptr = torch.tensor([0, 5, 8]).to(torch.device(device))
other = torch.randn((2, 16, 32), requires_grad=True, device=device)
bias = torch.randn((2, 32), requires_grad=True, device=device)
other = torch.randn((2, 16, 32), requires_grad=True, device=device,
dtype=dtype)
bias = torch.randn((2, 32), requires_grad=True, device=device, dtype=dtype)
out = pyg_lib.ops.segment_matmul(inputs, ptr, other, bias)
assert out.size() == (8, 32)

@@ -31,7 +38,11 @@ def test_segment_matmul_autograd(device):
 
 
 @withCUDA
-def test_grouped_matmul_autograd(device):
+@pytest.mark.parametrize('dtype', [torch.float, torch.bfloat16])
+def test_grouped_matmul_autograd(dtype, device):
+    if device.type == 'cuda' and dtype == torch.bfloat16:
+        pytest.skip('CUDA does not support bfloat16')
+
     inputs = [
         torch.randn(5, 16, device=device, requires_grad=True),
         torch.randn(6, 9, device=device, requires_grad=True),
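The updated tests check output shapes and gradient flow. As a complementary sketch, the `bfloat16` result could also be cross-checked against a float32 reference computed segment by segment; the per-segment comparison and the loose tolerances here are illustrative rather than part of the PR:

```python
import torch

import pyg_lib

inputs = torch.randn(8, 16, dtype=torch.bfloat16)
ptr = torch.tensor([0, 5, 8])
other = torch.randn(2, 16, 32, dtype=torch.bfloat16)

out = pyg_lib.ops.segment_matmul(inputs, ptr, other)

for i in range(ptr.numel() - 1):
    start, end = int(ptr[i]), int(ptr[i + 1])
    # float32 reference; bfloat16 keeps only ~8 bits of mantissa, so the
    # comparison uses deliberately loose tolerances.
    expected = inputs[start:end].float() @ other[i].float()
    assert torch.allclose(out[start:end].float(), expected, atol=1e-1, rtol=5e-2)
```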
