Added OpInfo for addmm (pytorch#55920)
Summary:
Added an OpInfo for `addmm` and ported its `method_tests` entries.

Skipping `test_variant_consistency_eager` on CPU, as it's blocked by pytorch#56233
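
For context, `torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha)` computes `beta * input + alpha * (mat1 @ mat2)`. A minimal sketch of the semantics the new samples exercise (shapes here are illustrative, not the ones used by the tests):

import torch

# addmm computes: beta * input + alpha * (mat1 @ mat2)
inp = torch.randn(2, 3)
mat1 = torch.randn(2, 2)
mat2 = torch.randn(2, 3)
out = torch.addmm(inp, mat1, mat2, beta=0.2, alpha=0.6)
ref = 0.2 * inp + 0.6 * mat1.mm(mat2)
assert torch.allclose(out, ref)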

Pull Request resolved: pytorch#55920

Reviewed By: agolynski

Differential Revision: D27800325

Pulled By: heitorschueroff

fbshipit-source-id: 311cd26c6b491b486f652cf64275c6901fea03c5
imaginary-person authored and facebook-github-bot committed Apr 26, 2021
1 parent b3f56ec commit ed9c7e1
Showing 1 changed file with 38 additions and 33 deletions.
71 changes: 38 additions & 33 deletions torch/testing/_internal/common_methods_invocations.py
@@ -650,21 +650,25 @@ def sample_inputs_mm(op_info, device, dtype, requires_grad, **kwargs):
     return inputs
 
 def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs):
-    input = SampleInput(
-        make_tensor((S, S), device, dtype, low=None, high=None, requires_grad=requires_grad),
-        args=(
-            make_tensor((S, S), device, dtype, low=None, high=None, requires_grad=requires_grad),
-            make_tensor((S, S), device, dtype, low=None, high=None, requires_grad=False)))
-    if dtype.is_complex:
-        another_input = SampleInput(
-            make_tensor((S, S), device, dtype, low=None, high=None, requires_grad=requires_grad),
-            args=(
-                make_tensor((S, S), device, dtype, low=None, high=None, requires_grad=requires_grad),
-                make_tensor((S, S), device, dtype, low=None, high=None, requires_grad=False)),
-            kwargs=dict(beta=1 + 2j, alpha=2 + 3j))
-        return (input, another_input)
-    else:
-        return (input, )
+    alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6)
+    beta_val = kwargs.get('beta', 1 + 2j if dtype.is_complex else 0.2)
+    tests_list = [
+        ((2, 3), (2, 2), (2, 3), False)
+    ]
+    tests_with_lhs_broadcasting = [
+        ((1,), (2, 2), (2, 3), True),
+        ((), (2, 2), (2, 3), True)
+    ]
+    test_cases = tests_list + tests_with_lhs_broadcasting  # type: ignore[operator]
+    inputs = tuple(SampleInput(make_tensor(shape_a, device, dtype, requires_grad=requires_grad),
+                               args=(make_tensor(shape_b, device, dtype,
+                                                 requires_grad=requires_grad),
+                                     make_tensor(shape_c, device, dtype,
+                                                 requires_grad=requires_grad)),
+                               kwargs={'alpha': alpha_val, 'beta': beta_val},
+                               broadcasts_input=broadcasts_input)
+                   for shape_a, shape_b, shape_c, broadcasts_input in test_cases)
+    return inputs
 
 def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs):
     return (
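
The `tests_with_lhs_broadcasting` cases above feed a `(1,)`-shaped or scalar `input` that broadcasts against the `(2, 3)` result of the matrix product, which is why they set `broadcasts_input=True`. A small sketch of the behavior being sampled, using the same shapes:

import torch

inp = torch.zeros(1)        # shape (1,) broadcasts against the (2, 3) result of mat1 @ mat2
mat1 = torch.randn(2, 2)
mat2 = torch.randn(2, 3)
out = torch.addmm(inp, mat1, mat2, beta=0.2, alpha=0.6)
assert out.shape == (2, 3)
# The inplace variant cannot broadcast its own storage here:
# inp.addmm_(mat1, mat2)  # raises a RuntimeError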
@@ -3239,22 +3243,29 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
            sample_inputs_func=sample_inputs_add,
            supports_inplace_autograd=False),
     OpInfo('addmm',
-           dtypes=floating_types(),
+           # This addmm OpInfo is for when alpha and beta are not both equal to 1.
+           # alpha=beta=1 is tested in the following opinfo, because that special case will
+           # trigger addmm being decomposed by a jit pass.
+           dtypes=floating_and_complex_types_and(torch.float16),
            dtypesIfCPU=all_types_and_complex_and(torch.float16, torch.bfloat16),
-           # BFloat16 support on CUDA requires CUDA 11 and SM53
-           dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
-                                           *[torch.bfloat16] if CUDA11OrLater else []),
-           dtypesIfROCM=floating_types_and(torch.float16, torch.complex64, torch.complex128, torch.bfloat16),
+           dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
            assert_autodiffed=True,
            supports_inplace_autograd=False,
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
            autodiff_nonfusible_nodes=['aten::add', 'aten::mm'],
            skips=(
                # Skips unsupported bfloat16 check because above support check
                # doesn't work on all platforms
                SkipInfo('TestOpInfo', 'test_unsupported_dtypes', dtypes=(torch.bfloat16,)),
                # TODO: remove redundant method_tests() entries
                SkipInfo('TestOpInfo', 'test_duplicate_method_tests')),
            sample_inputs_func=sample_inputs_addmm),
+    OpInfo('addmm',
+           # When alpha=beta=1 as compile-time constants, JIT will decompose addmm into mm and add.
+           variant_test_name='decomposed',
+           dtypes=floating_and_complex_types_and(torch.float16),
+           dtypesIfCPU=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
+           assert_autodiffed=True,
+           supports_inplace_autograd=False,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           autodiff_nonfusible_nodes=['aten::add', 'aten::mm'],
+           sample_inputs_func=partial(sample_inputs_addmm, alpha=1, beta=1)),
     OpInfo('addmv',
            dtypes=floating_types(),
            dtypesIfCPU=all_types_and_complex_and(torch.bfloat16),
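
The split into two OpInfo entries mirrors the JIT behavior noted in the comments: with alpha and beta both 1 as compile-time constants, `addmm` is decomposed into `mm` plus `add`, so the autodiff/fusion expectations differ. A sketch of the arithmetic identity behind that decomposition:

import torch

inp = torch.randn(2, 3)
mat1 = torch.randn(2, 2)
mat2 = torch.randn(2, 3)
# With alpha = beta = 1, addmm(inp, mat1, mat2) equals inp + mat1 @ mat2,
# which is the mm-plus-add form the JIT pass rewrites it into.
assert torch.allclose(torch.addmm(inp, mat1, mat2),
                      inp + mat1.mm(mat2))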
@@ -5581,12 +5592,6 @@ def method_tests():
         ('renorm', (S, S, S), (1, 2, 3), 'norm_1'),
         ('renorm', (S, S, S), (inf, 2, 0.5), 'norm_inf'),
         ('log_softmax', (S, S, S), (1, torch.float64,), 'kwarg_dtype_would_break_jit_loader', (True,)),
-        ('addmm', (S, M), ((S, S), (S, M)), '', (True, ['aten::add', 'aten::mm'])),
-        ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs', (True, ['aten::add', 'aten::mm'])),
-        ('addmm', (S, M), ((S, S), (S, M)), 'coef', (True,), (), (), ident, {'beta': 0.2, 'alpha': 0.6}),
-        ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs_coef', (True,), (), (), ident, {'beta': 0.2, 'alpha': 0.6}),
-        ('addmm', (), ((S, S), (S, M)), 'scalar_broadcast_lhs', (True, ['aten::add', 'aten::mm'])),
-        ('addmm', (), ((S, S), (S, M)), 'scalar_broadcast_lhs_coef', (True,), (), (), ident, {'beta': 0.2, 'alpha': 0.6}),
         ('mvlgamma', torch.empty(S,).uniform_(0.5, 1), [1], "p=1"),
         ('mvlgamma', torch.empty(S,).uniform_(1, 2), [2], "p=2"),
         ('mvlgamma', torch.empty(S, S).uniform_(1.5, 3), [3], "p=3"),
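
For reference, the deleted `'coef'` entries exercised `addmm` with `beta=0.2` and `alpha=0.6` under gradcheck; the new `sample_inputs_addmm` uses those same factors as its real-dtype defaults. A rough equivalent of what those entries verified (double precision, as gradcheck requires; shapes arbitrary):

import torch

inp = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
m1 = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
m2 = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
# Numerically check analytical gradients for the non-default scaling factors.
torch.autograd.gradcheck(
    lambda i, a, b: torch.addmm(i, a, b, beta=0.2, alpha=0.6), (inp, m1, m2))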
