diff --git a/opt_einsum/parser.py b/opt_einsum/parser.py index 47567ae..5c6ac53 100644 --- a/opt_einsum/parser.py +++ b/opt_einsum/parser.py @@ -309,8 +309,8 @@ def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, L else: subscripts, operands = convert_interleaved_input(operands) - if shapes: - operand_shapes = operands + if shapes: + operand_shapes = [list(s) for s in operands] else: operand_shapes = [o.shape for o in operands] diff --git a/opt_einsum/tests/test_parser.py b/opt_einsum/tests/test_parser.py index d582ca4..ed1a30b 100644 --- a/opt_einsum/tests/test_parser.py +++ b/opt_einsum/tests/test_parser.py @@ -41,3 +41,12 @@ def test_parse_einsum_input_shapes() -> None: assert input_subscripts == eq assert output_subscript == "ad" assert np.allclose([possibly_convert_to_numpy(shp) for shp in shps], operands) + + +def test_parse_with_ellisis(): + eq = "...a,ab" + shps = [(2, 3), (3, 4)] + input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shps], shapes=True) + assert input_subscripts == "da,ab" + assert output_subscript == "db" + assert np.allclose([possibly_convert_to_numpy(shp) for shp in shps], operands)