Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update nvFuser matmul #419

Merged
merged 6 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2236,17 +2236,11 @@ def _matmul_check(
a: TensorProxy,
b: TensorProxy,
) -> bool:
if nv_version < LooseVersion("0.2.2"):
if nv_version < LooseVersion("0.2.4"):
crcrpar marked this conversation as resolved.
Show resolved Hide resolved
return False

enable_matmul: None | bool = get_compile_option("nv_enable_matmul", "Enable nvFuser matmul.")
if not enable_matmul:
return False
if not are_supported_tensors(a, b):
return False
if not (a.ndim == b.ndim and a.ndim == 2):
return False
return True
return enable_matmul and are_supported_tensors(a, b)


def matmul(
Expand Down
35 changes: 15 additions & 20 deletions thunder/tests/test_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
TorchExecutor,
)
from thunder.tests.make_tensor import make_tensor, make_tensor_like
from thunder.tests.opinfos import opinfos, push_away_from_singularities, tensor_creation_ops, get_opinfo
from thunder.tests.opinfos import opinfos, push_away_from_singularities, tensor_creation_ops, get_opinfo, matmul_opinfo
from looseversion import LooseVersion


Expand Down Expand Up @@ -887,29 +887,24 @@ def fn(a, b, bias=None):


@instantiate(
dtypes=(thunder.float16, thunder.bfloat16), devicetypes=(devices.DeviceType.CUDA,), executors=(nvFuserExecutor,)
dtypes=(thunder.float16, thunder.bfloat16),
devicetypes=(devices.DeviceType.CUDA,),
executors=(nvFuserExecutor,),
decorators=(
pytest.mark.skipif(nvfuser_version() < LooseVersion("0.2.4"), reason="Requires nvFuser version 0.2.4 or later"),
),
)
def test_matmul(executor, device: str, dtype: dtypes.dtype):
m, n, k = 128, 64, 32
torch_dtype = ltorch.to_torch_dtype(dtype)
a = torch.randn((m, k), dtype=torch_dtype, device=device)
b = torch.randn((k, n), dtype=torch_dtype, device=device)

def fn(a, b):
return a.matmul(b)
return torch.matmul(a, b)

compiled_func = thunder.jit(
fn,
executors_list=executor.executors_list(),
nv_enable_matmul=True,
)
for sample in matmul_opinfo.sample_inputs(device, dtype):
compiled_func = thunder.jit(fn, executors_list=executor.executors_list(), nv_enable_matmul=True)

out = compiled_func(a, b)
traces = thunder.last_traces(compiled_func)
fusions = examine.get_fusions(traces[-1])
nv_version = nvfuser_version()

expected_fusions = 1 if nv_version >= "0.2.2" else 0
out = compiled_func(*sample.args)
traces = thunder.last_traces(compiled_func)
fusions = examine.get_fusions(traces[-1])

assert len(fusions) == expected_fusions
assert torch.allclose(out, torch.matmul(a, b))
assert len(fusions) == 1
torch.testing.assert_close(out, torch.matmul(*sample.args))
Loading