Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Algorithmic decorator (implement output for _repr_latex_) #163

Merged
merged 31 commits into from
Dec 15, 2022
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 29 additions & 26 deletions src/integration_tests/algorithmic_style_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,8 @@
from __future__ import annotations

import textwrap
from typing import Any, Callable

from latexify import generate_latex


def check_algorithm(
fn: Callable[..., Any],
latex: str,
**kwargs,
) -> None:
"""Helper to check if the obtained function has the expected LaTeX form.

Args:
fn: Function to check.
latex: LaTeX form of `fn`.
**kwargs: Arguments passed to `frontend.get_latex`.
"""
# Checks the syntax:
# def fn(...):
# ...
# latexified = get_latex(fn, style=ALGORITHM, **kwargs)
latexified = generate_latex.get_latex(
fn, style=generate_latex.Style.ALGORITHMIC, **kwargs
)
assert latexified == latex
from integration_tests import integration_utils


def test_factorial() -> None:
Expand All @@ -50,7 +27,18 @@ def fact(n):
\end{algorithmic}
""" # noqa: E501
).strip()
check_algorithm(fact, latex)
ipython_latex = (
r"\mathbf{function} \ \mathrm{FACT}(n) \\"
r" \hspace{1em} \mathbf{if} \ n = 0 \\"
r" \hspace{2em} \mathbf{return} \ 1 \\"
r" \hspace{1em} \mathbf{else} \\"
r" \hspace{2em}"
r" \mathbf{return} \ n \cdot"
r" \mathrm{fact} \mathopen{}\left( n - 1 \mathclose{}\right) \\"
r" \hspace{1em} \mathbf{end \ if} \\"
r" \mathbf{end \ function}"
)
integration_utils.check_algorithm(fact, latex, ipython_latex)


def test_collatz() -> None:
Expand Down Expand Up @@ -82,4 +70,19 @@ def collatz(n):
\end{algorithmic}
"""
).strip()
check_algorithm(collatz, latex)
ipython_latex = (
r"\mathbf{function} \ \mathrm{COLLATZ}(n) \\"
r" \hspace{1em} \mathrm{iterations} \gets 0 \\"
r" \hspace{1em} \mathbf{while} \ n > 1 \\"
r" \hspace{2em} \mathbf{if} \ n \mathbin{\%} 2 = 0 \\"
r" \hspace{3em} n \gets \left\lfloor\frac{n}{2}\right\rfloor \\"
r" \hspace{2em} \mathbf{else} \\"
r" \hspace{3em} n \gets 3 \cdot n + 1 \\"
r" \hspace{2em} \mathbf{end \ if} \\"
r" \hspace{2em}"
r" \mathrm{iterations} \gets \mathrm{iterations} + 1 \\"
r" \hspace{1em} \mathbf{end \ while} \\"
r" \hspace{1em} \mathbf{return} \ \mathrm{iterations} \\"
r" \mathbf{end \ function}"
)
integration_utils.check_algorithm(collatz, latex, ipython_latex)
40 changes: 40 additions & 0 deletions src/integration_tests/integration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,43 @@ def check_function(
latexified = frontend.function(fn, **kwargs)
assert str(latexified) == latex
assert latexified._repr_latex_() == rf"$$ \displaystyle {latex} $$"


def check_algorithm(
fn: Callable[..., Any],
latex: str,
ipython_latex: str,
**kwargs,
) -> None:
"""Helper to check if the obtained function has the expected LaTeX form.

Args:
fn: Function to check.
latex: LaTeX form of `fn`.
ipython_latex: IPython LaTeX form of `fn`
**kwargs: Arguments passed to `frontend.get_latex`.
"""
# Checks the syntax:
# @algorithmic
# def fn(...):
# ...
if not kwargs:
latexified = frontend.algorithmic(fn)
assert str(latexified) == latex
assert latexified._repr_latex_() == f"$ {ipython_latex} $"

# Checks the syntax:
# @algorithmic(**kwargs)
# def fn(...):
# ...
latexified = frontend.algorithmic(**kwargs)(fn)
assert str(latexified) == latex
assert latexified._repr_latex_() == f"$ {ipython_latex} $"

# Checks the syntax:
# def fn(...):
# ...
# latexified = algorithmic(fn, **kwargs)
latexified = frontend.algorithmic(fn, **kwargs)
assert str(latexified) == latex
assert latexified._repr_latex_() == f"$ {ipython_latex} $"
1 change: 1 addition & 0 deletions src/latexify/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
AlgorithmicCodegen = algorithmic_codegen.AlgorithmicCodegen
ExpressionCodegen = expression_codegen.ExpressionCodegen
FunctionCodegen = function_codegen.FunctionCodegen
IPythonAlgorithmicCodegen = algorithmic_codegen.IPythonAlgorithmicCodegen
125 changes: 125 additions & 0 deletions src/latexify/codegen/algorithmic_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,128 @@ def _add_indent(self, line: str) -> str:
line: The line to add whitespace to.
"""
return self._indent_level * self._SPACES_PER_INDENT * " " + line


class IPythonAlgorithmicCodegen(ast.NodeVisitor):
"""Codegen for single algorithms targeting IPython.

This codegen works for Module with single FunctionDef node to generate a single
LaTeX expression of the given algorithm.
"""

_EM_PER_INDENT = 1

_identifier_converter: identifier_converter.IdentifierConverter
_indent_level: int

def __init__(
self, *, use_math_symbols: bool = False, use_set_symbols: bool = False
) -> None:
"""Initializer.

Args:
use_math_symbols: Whether to convert identifiers with a math symbol surface
(e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha").
use_set_symbols: Whether to use set symbols or not.
"""
self._expression_codegen = expression_codegen.ExpressionCodegen(
use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols
)
self._identifier_converter = identifier_converter.IdentifierConverter(
use_math_symbols=use_math_symbols
)
self._indent_level = 0

def generic_visit(self, node: ast.AST) -> str:
raise exceptions.LatexifyNotSupportedError(
f"Unsupported AST: {type(node).__name__}"
)

def visit_Assign(self, node: ast.Assign) -> str:
"""Visit an Assign node."""
operands: list[str] = [
self._expression_codegen.visit(target) for target in node.targets
]
operands.append(self._expression_codegen.visit(node.value))
operands_latex = r" \gets ".join(operands)
return self._add_prefix() + operands_latex

def visit_Expr(self, node: ast.Expr) -> str:
"""Visit an Expr node."""
return self._add_prefix() + self._expression_codegen.visit(node.value)

def visit_FunctionDef(self, node: ast.FunctionDef) -> str:
"""Visit a FunctionDef node."""
# Arguments
arg_strs = [
self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args
]
# Body
with self._increment_level():
body_strs: list[str] = [self.visit(stmt) for stmt in node.body]
body = r" \\ ".join(body_strs)

return (
rf"{self._add_prefix()}\mathbf{{function}}"
rf" \ \mathrm{{{node.name.upper()}}}({', '.join(arg_strs)}) \\"
rf" {body} \\"
rf" {self._add_prefix()}\mathbf{{end \ function}}"
)

# TODO(ZibingZhang): support \ELSIF
def visit_If(self, node: ast.If) -> str:
"""Visit an If node."""
cond_latex = self._expression_codegen.visit(node.test)
with self._increment_level():
body_latex = r" \\ ".join(self.visit(stmt) for stmt in node.body)
latex = rf"{self._add_prefix()}\mathbf{{if}} \ {cond_latex} \\ {body_latex}"

if node.orelse:
latex += rf" \\ {self._add_prefix()}\mathbf{{else}} \\ "
with self._increment_level():
latex += r" \\ ".join(self.visit(stmt) for stmt in node.orelse)

return latex + rf" \\ {self._add_prefix()}\mathbf{{end \ if}}"

def visit_Module(self, node: ast.Module) -> str:
"""Visit a Module node."""
return self.visit(node.body[0])

def visit_Return(self, node: ast.Return) -> str:
"""Visit a Return node."""
return (
rf"{self._add_prefix()}\mathbf{{return}}"
rf" \ {self._expression_codegen.visit(node.value)}"
if node.value is not None
else rf"{self._add_prefix()}\mathbf{{return}}"
)

def visit_While(self, node: ast.While) -> str:
"""Visit a While node."""
if node.orelse:
raise exceptions.LatexifyNotSupportedError(
"While statement with the else clause is not supported"
)

cond_latex = self._expression_codegen.visit(node.test)
with self._increment_level():
body_latex = r" \\ ".join(self.visit(stmt) for stmt in node.body)
return (
rf"{self._add_prefix()}\mathbf{{while}} \ {cond_latex} \\ "
rf"{body_latex} \\ "
rf"{self._add_prefix()}\mathbf{{end \ while}}"
)

@contextlib.contextmanager
def _increment_level(self) -> Generator[None, None, None]:
"""Context manager controlling indent level."""
self._indent_level += 1
yield
self._indent_level -= 1

def _add_prefix(self) -> str:
return (
rf"\hspace{{{self._indent_level * self._EM_PER_INDENT}em}} "
if self._indent_level > 0
else ""
)
128 changes: 128 additions & 0 deletions src/latexify/codegen/algorithmic_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,131 @@ def test_visit_while_with_else() -> None:
match="^While statement with the else clause is not supported$",
):
algorithmic_codegen.AlgorithmicCodegen().visit(node)


@pytest.mark.parametrize(
"code,latex",
[
("x = 3", r"x \gets 3"),
("a = b = 0", r"a \gets b \gets 0"),
],
)
def test_visit_assign_jupyter(code: str, latex: str) -> None:
node = ast.parse(textwrap.dedent(code)).body[0]
assert isinstance(node, ast.Assign)
assert algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == latex


@pytest.mark.parametrize(
"code,latex",
[
(
"def f(x): return x",
(
r"\mathbf{function}"
r" \ \mathrm{F}(x) \\"
r" \hspace{1em} \mathbf{return} \ x \\"
r" \mathbf{end \ function}"
),
),
(
"def f(a, b, c): return 3",
(
r"\mathbf{function}"
r" \ \mathrm{F}(a, b, c) \\"
r" \hspace{1em} \mathbf{return} \ 3 \\"
r" \mathbf{end \ function}"
),
),
],
)
def test_visit_functiondef_ipython(code: str, latex: str) -> None:
node = ast.parse(textwrap.dedent(code)).body[0]
assert isinstance(node, ast.FunctionDef)
assert algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == latex


@pytest.mark.parametrize(
"code,latex",
[
(
"if x < y: return x",
(
r"\mathbf{if} \ x < y \\"
r" \hspace{1em} \mathbf{return} \ x \\"
r" \mathbf{end \ if}"
),
),
(
"if True: x\nelse: y",
(
r"\mathbf{if} \ \mathrm{True} \\"
r" \hspace{1em} x \\"
r" \mathbf{else} \\"
r" \hspace{1em} y \\"
r" \mathbf{end \ if}"
),
),
],
)
def test_visit_if_ipython(code: str, latex: str) -> None:
node = ast.parse(textwrap.dedent(code)).body[0]
assert isinstance(node, ast.If)
assert algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == latex


@pytest.mark.parametrize(
"code,latex",
[
(
"return x + y",
r"\mathbf{return} \ x + y",
),
(
"return",
r"\mathbf{return}",
),
],
)
def test_visit_return_ipython(code: str, latex: str) -> None:
node = ast.parse(textwrap.dedent(code)).body[0]
assert isinstance(node, ast.Return)
assert algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == latex


@pytest.mark.parametrize(
"code,latex",
[
(
"while x < y: x = x + 1",
(
r"\mathbf{while} \ x < y \\"
r" \hspace{1em} x \gets x + 1 \\"
r" \mathbf{end \ while}"
),
)
],
)
def test_visit_while_ipython(code: str, latex: str) -> None:
node = ast.parse(textwrap.dedent(code)).body[0]
assert isinstance(node, ast.While)
assert algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == latex


def test_visit_while_with_else_ipython() -> None:
node = ast.parse(
textwrap.dedent(
"""
while True:
x = x
else:
x = y
"""
)
).body[0]
assert isinstance(node, ast.While)
with pytest.raises(
exceptions.LatexifyNotSupportedError,
match="^While statement with the else clause is not supported$",
):
algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node)
Loading