Skip to content

Commit

Permalink
Refactored code.
Browse files Browse the repository at this point in the history
  • Loading branch information
aritrakar committed Nov 17, 2023
1 parent f9c379c commit 9334fee
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 88 deletions.
47 changes: 24 additions & 23 deletions src/latexify/codegen/expression_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,11 @@ def _generate_determinant(self, node: ast.Call) -> str | None:

func_arg = node.args[0]
if isinstance(func_arg, ast.Name):
return rf"\det \left( \mathbf{{{func_arg.id}}} \right)"
arg_id = rf"\mathbf{{{func_arg.id}}}"
return rf"\det \mathopen{{}}\left( {arg_id} \mathclose{{}}\right)"
elif isinstance(func_arg, ast.List):
return rf"\det \left( {self._generate_matrix(node)} \right)"
matrix = self._generate_matrix(node)
return rf"\det \mathopen{{}}\left( {matrix} \mathclose{{}}\right)"

return None

Expand All @@ -280,9 +282,15 @@ def _generate_matrix_rank(self, node: ast.Call) -> str | None:

func_arg = node.args[0]
if isinstance(func_arg, ast.Name):
return rf"\mathrm{{rank}} \left( \mathbf{{{func_arg.id}}} \right)"
arg_id = rf"\mathbf{{{func_arg.id}}}"
return (
rf"\mathrm{{rank}} \mathopen{{}}\left( {arg_id} \mathclose{{}}\right)"
)
elif isinstance(func_arg, ast.List):
return rf"\mathrm{{rank}} \left( {self._generate_matrix(node)} \right)"
matrix = self._generate_matrix(node)
return (
rf"\mathrm{{rank}} \mathopen{{}}\left( {matrix} \mathclose{{}}\right)"
)

return None

Expand Down Expand Up @@ -312,8 +320,8 @@ def _generate_matrix_power(self, node: ast.Call) -> str | None:
return rf"{matrix}^{{{power_arg.n}}}"
return None

def _generate_qr_and_svd(self, node: ast.Call) -> str | None:
"""Generates LaTeX for numpy.linalg.qr and numpy.linalg.svd.
def _generate_inv(self, node: ast.Call) -> str | None:
"""Generates LaTeX for numpy.linalg.inv.
Args:
node: ast.Call node containing the appropriate method invocation.
Returns:
Expand All @@ -322,23 +330,20 @@ def _generate_qr_and_svd(self, node: ast.Call) -> str | None:
LatexifyError: Unsupported argument type given.
"""
name = ast_utils.extract_function_name_or_none(node)
assert name == "QR" or name == "SVD"
assert name == "inv"

if len(node.args) != 1:
return None

func_arg = node.args[0]
if isinstance(func_arg, ast.Name):
func_arg_str = rf"\mathbf{{{func_arg.id}}}"
return rf"\mathrm{{{name.upper()}}} \left( {func_arg_str} \right)"

return rf"\mathbf{{{func_arg.id}}}^{{-1}}"
elif isinstance(func_arg, ast.List):
matrix_str = self._generate_matrix(node)
return rf"\mathrm{{{name.upper()}}} \left( {matrix_str} \right)"
return rf"{self._generate_matrix(node)}^{{-1}}"
return None

def _generate_inverses(self, node: ast.Call) -> str | None:
"""Generates LaTeX for numpy.linalg.inv.
def _generate_pinv(self, node: ast.Call) -> str | None:
"""Generates LaTeX for numpy.linalg.pinv.
Args:
node: ast.Call node containing the appropriate method invocation.
Returns:
Expand All @@ -347,19 +352,15 @@ def _generate_inverses(self, node: ast.Call) -> str | None:
LatexifyError: Unsupported argument type given.
"""
name = ast_utils.extract_function_name_or_none(node)
assert name == "inv" or name == "pinv"
assert name == "pinv"

if len(node.args) != 1:
return None

func_arg = node.args[0]
if isinstance(func_arg, ast.Name):
if name == "inv":
return rf"\mathbf{{{func_arg.id}}}^{{-1}}"
return rf"\mathbf{{{func_arg.id}}}^{{+}}"
elif isinstance(func_arg, ast.List):
if name == "inv":
return rf"{self._generate_matrix(node)}^{{-1}}"
return rf"{self._generate_matrix(node)}^{{+}}"
return None

Expand All @@ -385,10 +386,10 @@ def visit_Call(self, node: ast.Call) -> str:
special_latex = self._generate_matrix_rank(node)
elif func_name == "matrix_power":
special_latex = self._generate_matrix_power(node)
elif func_name in ("QR", "SVD"):
special_latex = self._generate_qr_and_svd(node)
elif func_name in ("inv", "pinv"):
special_latex = self._generate_inverses(node)
elif func_name == "inv":
special_latex = self._generate_inv(node)
elif func_name == "pinv":
special_latex = self._generate_pinv(node)
else:
special_latex = None

Expand Down
94 changes: 29 additions & 65 deletions src/latexify/codegen/expression_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,16 +996,17 @@ def test_transpose(code: str, latex: str) -> None:
@pytest.mark.parametrize(
"code,latex",
[
("det(A)", r"\det \left( \mathbf{A} \right)"),
("det(b)", r"\det \left( \mathbf{b} \right)"),
("det(A)", r"\det \mathopen{}\left( \mathbf{A} \mathclose{}\right)"),
("det(b)", r"\det \mathopen{}\left( \mathbf{b} \mathclose{}\right)"),
(
"det([[1, 2], [3, 4]])",
r"\det \left( \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \right)",
r"\det \mathopen{}\left( \begin{bmatrix} 1 & 2 \\"
r" 3 & 4 \end{bmatrix} \mathclose{}\right)",
),
(
"det([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
r"\det \left( \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\"
r" 7 & 8 & 9 \end{bmatrix} \right)",
r"\det \mathopen{}\left( \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\"
r" 7 & 8 & 9 \end{bmatrix} \mathclose{}\right)",
),
# Unsupported
("det()", r"\mathrm{det} \mathopen{}\left( \mathclose{}\right)"),
Expand All @@ -1026,17 +1027,23 @@ def test_determinant(code: str, latex: str) -> None:
@pytest.mark.parametrize(
"code,latex",
[
("matrix_rank(A)", r"\mathrm{rank} \left( \mathbf{A} \right)"),
("matrix_rank(b)", r"\mathrm{rank} \left( \mathbf{b} \right)"),
(
"matrix_rank(A)",
r"\mathrm{rank} \mathopen{}\left( \mathbf{A} \mathclose{}\right)",
),
(
"matrix_rank(b)",
r"\mathrm{rank} \mathopen{}\left( \mathbf{b} \mathclose{}\right)",
),
(
"matrix_rank([[1, 2], [3, 4]])",
r"\mathrm{rank} \left( \begin{bmatrix} 1 & 2 \\"
r" 3 & 4 \end{bmatrix} \right)",
r"\mathrm{rank} \mathopen{}\left( \begin{bmatrix} 1 & 2 \\"
r" 3 & 4 \end{bmatrix} \mathclose{}\right)",
),
(
"matrix_rank([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
r"\mathrm{rank} \left( \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\"
r" 7 & 8 & 9 \end{bmatrix} \right)",
r"\mathrm{rank} \mathopen{}\left( \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\"
r" 7 & 8 & 9 \end{bmatrix} \mathclose{}\right)",
),
# Unsupported
(
Expand Down Expand Up @@ -1098,75 +1105,32 @@ def test_matrix_power(code: str, latex: str) -> None:
@pytest.mark.parametrize(
"code,latex",
[
# Test QR
("QR(A)", r"\mathrm{QR} \left( \mathbf{A} \right)"),
("QR(b)", r"\mathrm{QR} \left( \mathbf{b} \right)"),
(
"QR([[1, 2], [3, 4]])",
r"\mathrm{QR} \left( \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \right)",
),
(
"QR([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
r"\mathrm{QR} \left( \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\"
r" 7 & 8 & 9 \end{bmatrix} \right)",
),
# Unsupported
("QR()", r"\mathrm{QR} \mathopen{}\left( \mathclose{}\right)"),
("QR(2)", r"\mathrm{QR} \mathopen{}\left( 2 \mathclose{}\right)"),
(
"QR(a, (1, 0))",
r"\mathrm{QR} \mathopen{}\left( a, "
r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
),
# Test SVD
("SVD(A)", r"\mathrm{SVD} \left( \mathbf{A} \right)"),
("SVD(b)", r"\mathrm{SVD} \left( \mathbf{b} \right)"),
(
"SVD([[1, 2], [3, 4]])",
r"\mathrm{SVD} \left( \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \right)",
),
("inv(A)", r"\mathbf{A}^{-1}"),
("inv(b)", r"\mathbf{b}^{-1}"),
("inv([[1, 2], [3, 4]])", r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}^{-1}"),
(
"SVD([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
r"\mathrm{SVD} \left( \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\"
r" 7 & 8 & 9 \end{bmatrix} \right)",
"inv([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
r"\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix}^{-1}",
),
# Unsupported
("SVD()", r"\mathrm{SVD} \mathopen{}\left( \mathclose{}\right)"),
("SVD(2)", r"\mathrm{SVD} \mathopen{}\left( 2 \mathclose{}\right)"),
("inv()", r"\mathrm{inv} \mathopen{}\left( \mathclose{}\right)"),
("inv(2)", r"\mathrm{inv} \mathopen{}\left( 2 \mathclose{}\right)"),
(
"SVD(a, (1, 0))",
r"\mathrm{SVD} \mathopen{}\left( a, "
"inv(a, (1, 0))",
r"\mathrm{inv} \mathopen{}\left( a, "
r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
),
],
)
def test_qr_and_svd(code: str, latex: str) -> None:
def test_inv(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Call)
assert expression_codegen.ExpressionCodegen().visit(tree) == latex


# tests for inv and pinv
@pytest.mark.parametrize(
"code,latex",
[
# Test inv
("inv(A)", r"\mathbf{A}^{-1}"),
("inv(b)", r"\mathbf{b}^{-1}"),
("inv([[1, 2], [3, 4]])", r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}^{-1}"),
(
"inv([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
r"\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix}^{-1}",
),
# Unsupported
("inv()", r"\mathrm{inv} \mathopen{}\left( \mathclose{}\right)"),
("inv(2)", r"\mathrm{inv} \mathopen{}\left( 2 \mathclose{}\right)"),
(
"inv(a, (1, 0))",
r"\mathrm{inv} \mathopen{}\left( a, "
r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
),
# Test pinv
("pinv(A)", r"\mathbf{A}^{+}"),
("pinv(b)", r"\mathbf{b}^{+}"),
("pinv([[1, 2], [3, 4]])", r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}^{+}"),
Expand All @@ -1184,7 +1148,7 @@ def test_qr_and_svd(code: str, latex: str) -> None:
),
],
)
def test_inv_and_pinv(code: str, latex: str) -> None:
def test_pinv(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Call)
assert expression_codegen.ExpressionCodegen().visit(tree) == latex
Expand Down

0 comments on commit 9334fee

Please sign in to comment.