From 06ff5b832dcc6924aa05132a1bb3f0e7d528f488 Mon Sep 17 00:00:00 2001 From: Jack Douglas Date: Tue, 10 Oct 2023 23:59:42 -0400 Subject: [PATCH] Numpy transpose --- src/latexify/codegen/expression_codegen.py | 23 +++++++++++++++++++ .../codegen/expression_codegen_test.py | 21 +++++++++++++++++ src/latexify/frontend.py | 21 +++++++++-------- 3 files changed, 56 insertions(+), 9 deletions(-) diff --git a/src/latexify/codegen/expression_codegen.py b/src/latexify/codegen/expression_codegen.py index 5a1ae1b..9706e74 100644 --- a/src/latexify/codegen/expression_codegen.py +++ b/src/latexify/codegen/expression_codegen.py @@ -218,6 +218,27 @@ def _generate_identity(self, node: ast.Call) -> str | None: return rf"\mathbf{{I}}_{{{ndims}}}" + def _generate_transpose(self, node: ast.Call) -> str | None: + """Generates LaTeX for numpy.transpose. + Args: + node: ast.Call node containing the appropriate method invocation. + Returns: + Generated LaTeX, or None if the node has unsupported syntax. + Raises: + LatexifyError: Unsupported argument type given. + """ + name = ast_utils.extract_function_name_or_none(node) + assert name == "transpose" + + if len(node.args) != 1: + return None + + func_arg = node.args[0] + if isinstance(func_arg, ast.Name): + return rf"\mathbf{{{func_arg.id}}}^\intercal" + else: + return None + def visit_Call(self, node: ast.Call) -> str: """Visit a Call node.""" func_name = ast_utils.extract_function_name_or_none(node) @@ -232,6 +253,8 @@ def visit_Call(self, node: ast.Call) -> str: special_latex = self._generate_zeros(node) elif func_name == "identity": special_latex = self._generate_identity(node) + elif func_name == "transpose": + special_latex = self._generate_transpose(node) else: special_latex = None diff --git a/src/latexify/codegen/expression_codegen_test.py b/src/latexify/codegen/expression_codegen_test.py index 86d7358..0b90915 100644 --- a/src/latexify/codegen/expression_codegen_test.py +++ b/src/latexify/codegen/expression_codegen_test.py @@ -970,3 +970,24 @@ def test_identity(code: str, latex: str) -> None: tree = ast_utils.parse_expr(code) assert isinstance(tree, ast.Call) assert expression_codegen.ExpressionCodegen().visit(tree) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ("transpose(A)", r"\mathbf{A}^\intercal"), + ("transpose(b)", r"\mathbf{b}^\intercal"), + # Unsupported + ("transpose()", r"\mathrm{transpose} \mathopen{}\left( \mathclose{}\right)"), + ("transpose(2)", r"\mathrm{transpose} \mathopen{}\left( 2 \mathclose{}\right)"), + ( + "transpose(a, (1, 0))", + r"\mathrm{transpose} \mathopen{}\left( a, " + r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)", + ), + ], +) +def test_transpose(code: str, latex: str) -> None: + tree = ast_utils.parse_expr(code) + assert isinstance(tree, ast.Call) + assert expression_codegen.ExpressionCodegen().visit(tree) == latex diff --git a/src/latexify/frontend.py b/src/latexify/frontend.py index 6f38bc7..c2a7875 100644 --- a/src/latexify/frontend.py +++ b/src/latexify/frontend.py @@ -24,9 +24,10 @@ def algorithmic( def algorithmic( fn: Callable[..., Any] | None = None, **kwargs: Any -) -> ipython_wrappers.LatexifiedAlgorithm | Callable[ - [Callable[..., Any]], ipython_wrappers.LatexifiedAlgorithm -]: +) -> ( + ipython_wrappers.LatexifiedAlgorithm + | Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedAlgorithm] +): """Attach LaTeX pretty-printing to the given function. This function works with or without specifying the target function as the @@ -67,9 +68,10 @@ def function( def function( fn: Callable[..., Any] | None = None, **kwargs: Any -) -> ipython_wrappers.LatexifiedFunction | Callable[ - [Callable[..., Any]], ipython_wrappers.LatexifiedFunction -]: +) -> ( + ipython_wrappers.LatexifiedFunction + | Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedFunction] +): """Attach LaTeX pretty-printing to the given function. This function works with or without specifying the target function as the positional @@ -110,9 +112,10 @@ def expression( def expression( fn: Callable[..., Any] | None = None, **kwargs: Any -) -> ipython_wrappers.LatexifiedFunction | Callable[ - [Callable[..., Any]], ipython_wrappers.LatexifiedFunction -]: +) -> ( + ipython_wrappers.LatexifiedFunction + | Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedFunction] +): """Attach LaTeX pretty-printing to the given function. This function is a shortcut for `latexify.function` with the default parameter