From 1cfb8e78527cb8e13dc647e9c8d91f789a427402 Mon Sep 17 00:00:00 2001 From: Aritra Kar Date: Fri, 17 Nov 2023 01:47:28 -0500 Subject: [PATCH] Added remaining Numpy NDArray single function expressions (#183) * Added support for single expressions involving the following functions: numpy.linalg.{matrix_power, qr, svd, det, matrix_rank, inv, pinv}. * Fixes for CI tests. * Fixed issues with line lengths and import order. * Refactored code. --- src/latexify/codegen/expression_codegen.py | 138 ++++++++++++++- .../codegen/expression_codegen_test.py | 161 ++++++++++++++++++ 2 files changed, 297 insertions(+), 2 deletions(-) diff --git a/src/latexify/codegen/expression_codegen.py b/src/latexify/codegen/expression_codegen.py index c400dac..f88869c 100644 --- a/src/latexify/codegen/expression_codegen.py +++ b/src/latexify/codegen/expression_codegen.py @@ -23,8 +23,8 @@ def __init__( """Initializer. Args: - use_math_symbols: Whether to convert identifiers with a math symbol surface - (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha"). + use_math_symbols: Whether to convert identifiers with a math symbol + surface (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha"). use_set_symbols: Whether to use set symbols or not. """ self._identifier_converter = identifier_converter.IdentifierConverter( @@ -240,6 +240,130 @@ def _generate_transpose(self, node: ast.Call) -> str | None: else: return None + def _generate_determinant(self, node: ast.Call) -> str | None: + """Generates LaTeX for numpy.linalg.det. + Args: + node: ast.Call node containing the appropriate method invocation. + Returns: + Generated LaTeX, or None if the node has unsupported syntax. + Raises: + LatexifyError: Unsupported argument type given. + """ + name = ast_utils.extract_function_name_or_none(node) + assert name == "det" + + if len(node.args) != 1: + return None + + func_arg = node.args[0] + if isinstance(func_arg, ast.Name): + arg_id = rf"\mathbf{{{func_arg.id}}}" + return rf"\det \mathopen{{}}\left( {arg_id} \mathclose{{}}\right)" + elif isinstance(func_arg, ast.List): + matrix = self._generate_matrix(node) + return rf"\det \mathopen{{}}\left( {matrix} \mathclose{{}}\right)" + + return None + + def _generate_matrix_rank(self, node: ast.Call) -> str | None: + """Generates LaTeX for numpy.linalg.matrix_rank. + Args: + node: ast.Call node containing the appropriate method invocation. + Returns: + Generated LaTeX, or None if the node has unsupported syntax. + Raises: + LatexifyError: Unsupported argument type given. + """ + name = ast_utils.extract_function_name_or_none(node) + assert name == "matrix_rank" + + if len(node.args) != 1: + return None + + func_arg = node.args[0] + if isinstance(func_arg, ast.Name): + arg_id = rf"\mathbf{{{func_arg.id}}}" + return ( + rf"\mathrm{{rank}} \mathopen{{}}\left( {arg_id} \mathclose{{}}\right)" + ) + elif isinstance(func_arg, ast.List): + matrix = self._generate_matrix(node) + return ( + rf"\mathrm{{rank}} \mathopen{{}}\left( {matrix} \mathclose{{}}\right)" + ) + + return None + + def _generate_matrix_power(self, node: ast.Call) -> str | None: + """Generates LaTeX for numpy.linalg.matrix_power. + Args: + node: ast.Call node containing the appropriate method invocation. + Returns: + Generated LaTeX, or None if the node has unsupported syntax. + Raises: + LatexifyError: Unsupported argument type given. + """ + name = ast_utils.extract_function_name_or_none(node) + assert name == "matrix_power" + + if len(node.args) != 2: + return None + + func_arg = node.args[0] + power_arg = node.args[1] + if isinstance(power_arg, ast.Num): + if isinstance(func_arg, ast.Name): + return rf"\mathbf{{{func_arg.id}}}^{{{power_arg.n}}}" + elif isinstance(func_arg, ast.List): + matrix = self._generate_matrix(node) + if matrix is not None: + return rf"{matrix}^{{{power_arg.n}}}" + return None + + def _generate_inv(self, node: ast.Call) -> str | None: + """Generates LaTeX for numpy.linalg.inv. + Args: + node: ast.Call node containing the appropriate method invocation. + Returns: + Generated LaTeX, or None if the node has unsupported syntax. + Raises: + LatexifyError: Unsupported argument type given. + """ + name = ast_utils.extract_function_name_or_none(node) + assert name == "inv" + + if len(node.args) != 1: + return None + + func_arg = node.args[0] + if isinstance(func_arg, ast.Name): + return rf"\mathbf{{{func_arg.id}}}^{{-1}}" + elif isinstance(func_arg, ast.List): + return rf"{self._generate_matrix(node)}^{{-1}}" + return None + + def _generate_pinv(self, node: ast.Call) -> str | None: + """Generates LaTeX for numpy.linalg.pinv. + Args: + node: ast.Call node containing the appropriate method invocation. + Returns: + Generated LaTeX, or None if the node has unsupported syntax. + Raises: + LatexifyError: Unsupported argument type given. + """ + name = ast_utils.extract_function_name_or_none(node) + assert name == "pinv" + + if len(node.args) != 1: + return None + + func_arg = node.args[0] + if isinstance(func_arg, ast.Name): + return rf"\mathbf{{{func_arg.id}}}^{{+}}" + elif isinstance(func_arg, ast.List): + return rf"{self._generate_matrix(node)}^{{+}}" + return None + def visit_Call(self, node: ast.Call) -> str: """Visit a Call node.""" func_name = ast_utils.extract_function_name_or_none(node) @@ -256,6 +380,16 @@ def visit_Call(self, node: ast.Call) -> str: special_latex = self._generate_identity(node) elif func_name == "transpose": special_latex = self._generate_transpose(node) + elif func_name == "det": + special_latex = self._generate_determinant(node) + elif func_name == "matrix_rank": + special_latex = self._generate_matrix_rank(node) + elif func_name == "matrix_power": + special_latex = self._generate_matrix_power(node) + elif func_name == "inv": + special_latex = self._generate_inv(node) + elif func_name == "pinv": + special_latex = self._generate_pinv(node) else: special_latex = None diff --git a/src/latexify/codegen/expression_codegen_test.py b/src/latexify/codegen/expression_codegen_test.py index 41510fa..8ed960b 100644 --- a/src/latexify/codegen/expression_codegen_test.py +++ b/src/latexify/codegen/expression_codegen_test.py @@ -995,6 +995,167 @@ def test_transpose(code: str, latex: str) -> None: assert expression_codegen.ExpressionCodegen().visit(tree) == latex +@pytest.mark.parametrize( + "code,latex", + [ + ("det(A)", r"\det \mathopen{}\left( \mathbf{A} \mathclose{}\right)"), + ("det(b)", r"\det \mathopen{}\left( \mathbf{b} \mathclose{}\right)"), + ( + "det([[1, 2], [3, 4]])", + r"\det \mathopen{}\left( \begin{bmatrix} 1 & 2 \\" + r" 3 & 4 \end{bmatrix} \mathclose{}\right)", + ), + ( + "det([[1, 2, 3], [4, 5, 6], [7, 8, 9]])", + r"\det \mathopen{}\left( \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\" + r" 7 & 8 & 9 \end{bmatrix} \mathclose{}\right)", + ), + # Unsupported + ("det()", r"\mathrm{det} \mathopen{}\left( \mathclose{}\right)"), + ("det(2)", r"\mathrm{det} \mathopen{}\left( 2 \mathclose{}\right)"), + ( + "det(a, (1, 0))", + r"\mathrm{det} \mathopen{}\left( a, " + r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)", + ), + ], +) +def test_determinant(code: str, latex: str) -> None: + tree = ast_utils.parse_expr(code) + assert isinstance(tree, ast.Call) + assert expression_codegen.ExpressionCodegen().visit(tree) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ( + "matrix_rank(A)", + r"\mathrm{rank} \mathopen{}\left( \mathbf{A} \mathclose{}\right)", + ), + ( + "matrix_rank(b)", + r"\mathrm{rank} \mathopen{}\left( \mathbf{b} \mathclose{}\right)", + ), + ( + "matrix_rank([[1, 2], [3, 4]])", + r"\mathrm{rank} \mathopen{}\left( \begin{bmatrix} 1 & 2 \\" + r" 3 & 4 \end{bmatrix} \mathclose{}\right)", + ), + ( + "matrix_rank([[1, 2, 3], [4, 5, 6], [7, 8, 9]])", + r"\mathrm{rank} \mathopen{}\left( \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\" + r" 7 & 8 & 9 \end{bmatrix} \mathclose{}\right)", + ), + # Unsupported + ( + "matrix_rank()", + r"\mathrm{matrix\_rank} \mathopen{}\left( \mathclose{}\right)", + ), + ( + "matrix_rank(2)", + r"\mathrm{matrix\_rank} \mathopen{}\left( 2 \mathclose{}\right)", + ), + ( + "matrix_rank(a, (1, 0))", + r"\mathrm{matrix\_rank} \mathopen{}\left( a, " + r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)", + ), + ], +) +def test_matrix_rank(code: str, latex: str) -> None: + tree = ast_utils.parse_expr(code) + assert isinstance(tree, ast.Call) + assert expression_codegen.ExpressionCodegen().visit(tree) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ("matrix_power(A, 2)", r"\mathbf{A}^{2}"), + ("matrix_power(b, 2)", r"\mathbf{b}^{2}"), + ( + "matrix_power([[1, 2], [3, 4]], 2)", + r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}^{2}", + ), + ( + "matrix_power([[1, 2, 3], [4, 5, 6], [7, 8, 9]], 42)", + r"\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix}^{42}", + ), + # Unsupported + ( + "matrix_power()", + r"\mathrm{matrix\_power} \mathopen{}\left( \mathclose{}\right)", + ), + ( + "matrix_power(2)", + r"\mathrm{matrix\_power} \mathopen{}\left( 2 \mathclose{}\right)", + ), + ( + "matrix_power(a, (1, 0))", + r"\mathrm{matrix\_power} \mathopen{}\left( a, " + r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)", + ), + ], +) +def test_matrix_power(code: str, latex: str) -> None: + tree = ast_utils.parse_expr(code) + assert isinstance(tree, ast.Call) + assert expression_codegen.ExpressionCodegen().visit(tree) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ("inv(A)", r"\mathbf{A}^{-1}"), + ("inv(b)", r"\mathbf{b}^{-1}"), + ("inv([[1, 2], [3, 4]])", r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}^{-1}"), + ( + "inv([[1, 2, 3], [4, 5, 6], [7, 8, 9]])", + r"\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix}^{-1}", + ), + # Unsupported + ("inv()", r"\mathrm{inv} \mathopen{}\left( \mathclose{}\right)"), + ("inv(2)", r"\mathrm{inv} \mathopen{}\left( 2 \mathclose{}\right)"), + ( + "inv(a, (1, 0))", + r"\mathrm{inv} \mathopen{}\left( a, " + r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)", + ), + ], +) +def test_inv(code: str, latex: str) -> None: + tree = ast_utils.parse_expr(code) + assert isinstance(tree, ast.Call) + assert expression_codegen.ExpressionCodegen().visit(tree) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ("pinv(A)", r"\mathbf{A}^{+}"), + ("pinv(b)", r"\mathbf{b}^{+}"), + ("pinv([[1, 2], [3, 4]])", r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}^{+}"), + ( + "pinv([[1, 2, 3], [4, 5, 6], [7, 8, 9]])", + r"\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix}^{+}", + ), + # Unsupported + ("pinv()", r"\mathrm{pinv} \mathopen{}\left( \mathclose{}\right)"), + ("pinv(2)", r"\mathrm{pinv} \mathopen{}\left( 2 \mathclose{}\right)"), + ( + "pinv(a, (1, 0))", + r"\mathrm{pinv} \mathopen{}\left( a, " + r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)", + ), + ], +) +def test_pinv(code: str, latex: str) -> None: + tree = ast_utils.parse_expr(code) + assert isinstance(tree, ast.Call) + assert expression_codegen.ExpressionCodegen().visit(tree) == latex + + # Check list for #89. # https://github.com/google/latexify_py/issues/89#issuecomment-1344967636 @pytest.mark.parametrize(