Skip to content

Commit

Permalink
Fixes operands could not be broadcast error when the subscripts con…
Browse files Browse the repository at this point in the history
…tain ellipsis and the operands are shapes.
  • Loading branch information
nova77 committed Jun 14, 2024
1 parent eede8fa commit 19de34e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
4 changes: 2 additions & 2 deletions opt_einsum/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,8 @@ def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, L
else:
subscripts, operands = convert_interleaved_input(operands)

if shapes:
operand_shapes = operands
if shapes:
operand_shapes = [list(s) for s in operands]
else:
operand_shapes = [o.shape for o in operands]

Expand Down
9 changes: 9 additions & 0 deletions opt_einsum/tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,12 @@ def test_parse_einsum_input_shapes() -> None:
assert input_subscripts == eq
assert output_subscript == "ad"
assert np.allclose([possibly_convert_to_numpy(shp) for shp in shps], operands)


def test_parse_with_ellisis():
eq = "...a,ab"
shps = [(2, 3), (3, 4)]
input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shps], shapes=True)
assert input_subscripts == "da,ab"
assert output_subscript == "db"
assert np.allclose([possibly_convert_to_numpy(shp) for shp in shps], operands)

0 comments on commit 19de34e

Please sign in to comment.