diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b4d53cc9144ed..4b8a3d3c37839 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -650,21 +650,25 @@ def sample_inputs_mm(op_info, device, dtype, requires_grad, **kwargs): return inputs def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs): - input = SampleInput( - make_tensor((S, S), device, dtype, low=None, high=None, requires_grad=requires_grad), - args=( - make_tensor((S, S), device, dtype, low=None, high=None, requires_grad=requires_grad), - make_tensor((S, S), device, dtype, low=None, high=None, requires_grad=False))) - if dtype.is_complex: - another_input = SampleInput( - make_tensor((S, S), device, dtype, low=None, high=None, requires_grad=requires_grad), - args=( - make_tensor((S, S), device, dtype, low=None, high=None, requires_grad=requires_grad), - make_tensor((S, S), device, dtype, low=None, high=None, requires_grad=False)), - kwargs=dict(beta=1 + 2j, alpha=2 + 3j)) - return (input, another_input) - else: - return (input, ) + alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6) + beta_val = kwargs.get('beta', 1 + 2j if dtype.is_complex else 0.2) + tests_list = [ + ((2, 3), (2, 2), (2, 3), False) + ] + tests_with_lhs_broadcasting = [ + ((1,), (2, 2), (2, 3), True), + ((), (2, 2), (2, 3), True) + ] + test_cases = tests_list + tests_with_lhs_broadcasting # type: ignore[operator] + inputs = tuple(SampleInput(make_tensor(shape_a, device, dtype, requires_grad=requires_grad), + args=(make_tensor(shape_b, device, dtype, + requires_grad=requires_grad), + make_tensor(shape_c, device, dtype, + requires_grad=requires_grad)), + kwargs={'alpha': alpha_val, 'beta': beta_val}, + broadcasts_input=broadcasts_input) + for shape_a, shape_b, shape_c, broadcasts_input in test_cases) + return inputs def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs): return ( @@ -3239,22 +3243,29 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): sample_inputs_func=sample_inputs_add, supports_inplace_autograd=False), OpInfo('addmm', - dtypes=floating_types(), + # This addmm OpInfo is for when alpha and beta are not both equal to 1. + # alpha=beta=1 is tested in the following opinfo, because that special case will + # trigger addmm being decomposed by a jit pass. + dtypes=floating_and_complex_types_and(torch.float16), dtypesIfCPU=all_types_and_complex_and(torch.float16, torch.bfloat16), - # BFloat16 support on CUDA requires CUDA 11 and SM53 - dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128, - *[torch.bfloat16] if CUDA11OrLater else []), - dtypesIfROCM=floating_types_and(torch.float16, torch.complex64, torch.complex128, torch.bfloat16), + dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []), assert_autodiffed=True, + supports_inplace_autograd=False, gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, - autodiff_nonfusible_nodes=['aten::add', 'aten::mm'], - skips=( - # Skips unsupported bfloat16 check because above support check - # doesn't work on all platforms - SkipInfo('TestOpInfo', 'test_unsupported_dtypes', dtypes=(torch.bfloat16,)), - # TODO: remove redundant method_tests() entries - SkipInfo('TestOpInfo', 'test_duplicate_method_tests')), sample_inputs_func=sample_inputs_addmm), + OpInfo('addmm', + # When alpha=beta=1 as compile-time constants, JIT will decompose addmm into mm and add. + variant_test_name='decomposed', + dtypes=floating_and_complex_types_and(torch.float16), + dtypesIfCPU=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []), + assert_autodiffed=True, + supports_inplace_autograd=False, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + autodiff_nonfusible_nodes=['aten::add', 'aten::mm'], + sample_inputs_func=partial(sample_inputs_addmm, alpha=1, beta=1)), OpInfo('addmv', dtypes=floating_types(), dtypesIfCPU=all_types_and_complex_and(torch.bfloat16), @@ -5581,12 +5592,6 @@ def method_tests(): ('renorm', (S, S, S), (1, 2, 3), 'norm_1'), ('renorm', (S, S, S), (inf, 2, 0.5), 'norm_inf'), ('log_softmax', (S, S, S), (1, torch.float64,), 'kwarg_dtype_would_break_jit_loader', (True,)), - ('addmm', (S, M), ((S, S), (S, M)), '', (True, ['aten::add', 'aten::mm'])), - ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs', (True, ['aten::add', 'aten::mm'])), - ('addmm', (S, M), ((S, S), (S, M)), 'coef', (True,), (), (), ident, {'beta': 0.2, 'alpha': 0.6}), - ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs_coef', (True,), (), (), ident, {'beta': 0.2, 'alpha': 0.6}), - ('addmm', (), ((S, S), (S, M)), 'scalar_broadcast_lhs', (True, ['aten::add', 'aten::mm'])), - ('addmm', (), ((S, S), (S, M)), 'scalar_broadcast_lhs_coef', (True,), (), (), ident, {'beta': 0.2, 'alpha': 0.6}), ('mvlgamma', torch.empty(S,).uniform_(0.5, 1), [1], "p=1"), ('mvlgamma', torch.empty(S,).uniform_(1, 2), [2], "p=2"), ('mvlgamma', torch.empty(S, S).uniform_(1.5, 3), [3], "p=3"),