Skip to content

Commit

Permalink
Numpy transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
J-Douglas committed Oct 11, 2023
1 parent b3ae7fa commit 06ff5b8
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 9 deletions.
23 changes: 23 additions & 0 deletions src/latexify/codegen/expression_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,27 @@ def _generate_identity(self, node: ast.Call) -> str | None:

return rf"\mathbf{{I}}_{{{ndims}}}"

def _generate_transpose(self, node: ast.Call) -> str | None:
"""Generates LaTeX for numpy.transpose.
Args:
node: ast.Call node containing the appropriate method invocation.
Returns:
Generated LaTeX, or None if the node has unsupported syntax.
Raises:
LatexifyError: Unsupported argument type given.
"""
name = ast_utils.extract_function_name_or_none(node)
assert name == "transpose"

if len(node.args) != 1:
return None

func_arg = node.args[0]
if isinstance(func_arg, ast.Name):
return rf"\mathbf{{{func_arg.id}}}^\intercal"
else:
return None

def visit_Call(self, node: ast.Call) -> str:
"""Visit a Call node."""
func_name = ast_utils.extract_function_name_or_none(node)
Expand All @@ -232,6 +253,8 @@ def visit_Call(self, node: ast.Call) -> str:
special_latex = self._generate_zeros(node)
elif func_name == "identity":
special_latex = self._generate_identity(node)
elif func_name == "transpose":
special_latex = self._generate_transpose(node)
else:
special_latex = None

Expand Down
21 changes: 21 additions & 0 deletions src/latexify/codegen/expression_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,3 +970,24 @@ def test_identity(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Call)
assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
"code,latex",
[
("transpose(A)", r"\mathbf{A}^\intercal"),
("transpose(b)", r"\mathbf{b}^\intercal"),
# Unsupported
("transpose()", r"\mathrm{transpose} \mathopen{}\left( \mathclose{}\right)"),
("transpose(2)", r"\mathrm{transpose} \mathopen{}\left( 2 \mathclose{}\right)"),
(
"transpose(a, (1, 0))",
r"\mathrm{transpose} \mathopen{}\left( a, "
r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
),
],
)
def test_transpose(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Call)
assert expression_codegen.ExpressionCodegen().visit(tree) == latex
21 changes: 12 additions & 9 deletions src/latexify/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ def algorithmic(

def algorithmic(
fn: Callable[..., Any] | None = None, **kwargs: Any
) -> ipython_wrappers.LatexifiedAlgorithm | Callable[
[Callable[..., Any]], ipython_wrappers.LatexifiedAlgorithm
]:
) -> (
ipython_wrappers.LatexifiedAlgorithm
| Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedAlgorithm]
):
"""Attach LaTeX pretty-printing to the given function.
This function works with or without specifying the target function as the
Expand Down Expand Up @@ -67,9 +68,10 @@ def function(

def function(
fn: Callable[..., Any] | None = None, **kwargs: Any
) -> ipython_wrappers.LatexifiedFunction | Callable[
[Callable[..., Any]], ipython_wrappers.LatexifiedFunction
]:
) -> (
ipython_wrappers.LatexifiedFunction
| Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedFunction]
):
"""Attach LaTeX pretty-printing to the given function.
This function works with or without specifying the target function as the positional
Expand Down Expand Up @@ -110,9 +112,10 @@ def expression(

def expression(
fn: Callable[..., Any] | None = None, **kwargs: Any
) -> ipython_wrappers.LatexifiedFunction | Callable[
[Callable[..., Any]], ipython_wrappers.LatexifiedFunction
]:
) -> (
ipython_wrappers.LatexifiedFunction
| Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedFunction]
):
"""Attach LaTeX pretty-printing to the given function.
This function is a shortcut for `latexify.function` with the default parameter
Expand Down

0 comments on commit 06ff5b8

Please sign in to comment.