Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Algorithmic decorator (implement output for _repr_latex_) #163

Merged
merged 31 commits into from
Dec 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 33 additions & 26 deletions src/integration_tests/algorithmic_style_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,8 @@
from __future__ import annotations

import textwrap
from typing import Any, Callable

from latexify import generate_latex


def check_algorithm(
fn: Callable[..., Any],
latex: str,
**kwargs,
) -> None:
"""Helper to check if the obtained function has the expected LaTeX form.

Args:
fn: Function to check.
latex: LaTeX form of `fn`.
**kwargs: Arguments passed to `frontend.get_latex`.
"""
# Checks the syntax:
# def fn(...):
# ...
# latexified = get_latex(fn, style=ALGORITHM, **kwargs)
latexified = generate_latex.get_latex(
fn, style=generate_latex.Style.ALGORITHMIC, **kwargs
)
assert latexified == latex
from integration_tests import integration_utils


def test_factorial() -> None:
Expand All @@ -50,7 +27,20 @@ def fact(n):
\end{algorithmic}
""" # noqa: E501
).strip()
check_algorithm(fact, latex)
ipython_latex = (
r"\begin{array}{l}"
r" \mathbf{function} \ \mathrm{fact}(n) \\"
r" \hspace{1em} \mathbf{if} \ n = 0 \\"
r" \hspace{2em} \mathbf{return} \ 1 \\"
r" \hspace{1em} \mathbf{else} \\"
r" \hspace{2em}"
r" \mathbf{return} \ n \cdot"
r" \mathrm{fact} \mathopen{}\left( n - 1 \mathclose{}\right) \\"
r" \hspace{1em} \mathbf{end \ if} \\"
r" \mathbf{end \ function}"
r" \end{array}"
)
integration_utils.check_algorithm(fact, latex, ipython_latex)


def test_collatz() -> None:
Expand Down Expand Up @@ -82,4 +72,21 @@ def collatz(n):
\end{algorithmic}
"""
).strip()
check_algorithm(collatz, latex)
ipython_latex = (
r"\begin{array}{l}"
r" \mathbf{function} \ \mathrm{collatz}(n) \\"
r" \hspace{1em} \mathrm{iterations} \gets 0 \\"
r" \hspace{1em} \mathbf{while} \ n > 1 \\"
r" \hspace{2em} \mathbf{if} \ n \mathbin{\%} 2 = 0 \\"
r" \hspace{3em} n \gets \left\lfloor\frac{n}{2}\right\rfloor \\"
r" \hspace{2em} \mathbf{else} \\"
r" \hspace{3em} n \gets 3 \cdot n + 1 \\"
r" \hspace{2em} \mathbf{end \ if} \\"
r" \hspace{2em}"
r" \mathrm{iterations} \gets \mathrm{iterations} + 1 \\"
r" \hspace{1em} \mathbf{end \ while} \\"
r" \hspace{1em} \mathbf{return} \ \mathrm{iterations} \\"
r" \mathbf{end \ function}"
r" \end{array}"
)
integration_utils.check_algorithm(collatz, latex, ipython_latex)
40 changes: 40 additions & 0 deletions src/integration_tests/integration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,43 @@ def check_function(
latexified = frontend.function(fn, **kwargs)
assert str(latexified) == latex
assert latexified._repr_latex_() == rf"$$ \displaystyle {latex} $$"


def check_algorithm(
fn: Callable[..., Any],
latex: str,
ipython_latex: str,
**kwargs,
) -> None:
"""Helper to check if the obtained function has the expected LaTeX form.

Args:
fn: Function to check.
latex: LaTeX form of `fn`.
ipython_latex: IPython LaTeX form of `fn`
**kwargs: Arguments passed to `frontend.get_latex`.
"""
# Checks the syntax:
# @algorithmic
# def fn(...):
# ...
if not kwargs:
latexified = frontend.algorithmic(fn)
assert str(latexified) == latex
assert latexified._repr_latex_() == f"$ {ipython_latex} $"

# Checks the syntax:
# @algorithmic(**kwargs)
# def fn(...):
# ...
latexified = frontend.algorithmic(**kwargs)(fn)
assert str(latexified) == latex
assert latexified._repr_latex_() == f"$ {ipython_latex} $"

# Checks the syntax:
# def fn(...):
# ...
# latexified = algorithmic(fn, **kwargs)
latexified = frontend.algorithmic(fn, **kwargs)
assert str(latexified) == latex
assert latexified._repr_latex_() == f"$ {ipython_latex} $"
1 change: 1 addition & 0 deletions src/latexify/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
AlgorithmicCodegen = algorithmic_codegen.AlgorithmicCodegen
ExpressionCodegen = expression_codegen.ExpressionCodegen
FunctionCodegen = function_codegen.FunctionCodegen
IPythonAlgorithmicCodegen = algorithmic_codegen.IPythonAlgorithmicCodegen
147 changes: 142 additions & 5 deletions src/latexify/codegen/algorithmic_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def visit_Expr(self, node: ast.Expr) -> str:
rf"\State ${self._expression_codegen.visit(node.value)}$"
)

# TODO(ZibingZhang): support nested functions
def visit_FunctionDef(self, node: ast.FunctionDef) -> str:
"""Visit a FunctionDef node."""
# Arguments
Expand Down Expand Up @@ -89,14 +90,14 @@ def visit_If(self, node: ast.If) -> str:
with self._increment_level():
body_latex = "\n".join(self.visit(stmt) for stmt in node.body)

latex = self._add_indent(f"\\If{{${cond_latex}$}}\n{body_latex}")
latex = self._add_indent(f"\\If{{${cond_latex}$}}\n" + body_latex)

if node.orelse:
latex += "\n" + self._add_indent(r"\Else") + "\n"
latex += "\n" + self._add_indent("\\Else\n")
with self._increment_level():
latex += "\n".join(self.visit(stmt) for stmt in node.orelse)

return latex + "\n" + self._add_indent(r"\EndIf")
return f"{latex}\n" + self._add_indent(r"\EndIf")

def visit_Module(self, node: ast.Module) -> str:
"""Visit a Module node."""
Expand Down Expand Up @@ -136,9 +137,145 @@ def _increment_level(self) -> Generator[None, None, None]:
self._indent_level -= 1

def _add_indent(self, line: str) -> str:
"""Adds whitespace before the line.
"""Adds an indent before the line.

Args:
line: The line to add whitespace to.
line: The line to add an indent to.
"""
return self._indent_level * self._SPACES_PER_INDENT * " " + line


class IPythonAlgorithmicCodegen(ast.NodeVisitor):
"""Codegen for single algorithms targeting IPython.

This codegen works for Module with single FunctionDef node to generate a single
LaTeX expression of the given algorithm.
"""

_EM_PER_INDENT = 1
_LINE_BREAK = r" \\ "

_identifier_converter: identifier_converter.IdentifierConverter
_indent_level: int

def __init__(
self, *, use_math_symbols: bool = False, use_set_symbols: bool = False
) -> None:
"""Initializer.

Args:
use_math_symbols: Whether to convert identifiers with a math symbol surface
(e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha").
use_set_symbols: Whether to use set symbols or not.
"""
self._expression_codegen = expression_codegen.ExpressionCodegen(
use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols
)
self._identifier_converter = identifier_converter.IdentifierConverter(
use_math_symbols=use_math_symbols
)
self._indent_level = 0

def generic_visit(self, node: ast.AST) -> str:
raise exceptions.LatexifyNotSupportedError(
f"Unsupported AST: {type(node).__name__}"
)

def visit_Assign(self, node: ast.Assign) -> str:
"""Visit an Assign node."""
operands: list[str] = [
self._expression_codegen.visit(target) for target in node.targets
]
operands.append(self._expression_codegen.visit(node.value))
operands_latex = r" \gets ".join(operands)
return self._add_indent(operands_latex)

def visit_Expr(self, node: ast.Expr) -> str:
"""Visit an Expr node."""
return self._add_indent(self._expression_codegen.visit(node.value))

# TODO(ZibingZhang): support nested functions
def visit_FunctionDef(self, node: ast.FunctionDef) -> str:
"""Visit a FunctionDef node."""
# Arguments
arg_strs = [
self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args
]
# Body
with self._increment_level():
body_strs: list[str] = [self.visit(stmt) for stmt in node.body]
body = self._LINE_BREAK.join(body_strs)

return (
r"\begin{array}{l} "
+ self._add_indent(r"\mathbf{function}")
+ rf" \ \mathrm{{{node.name}}}({', '.join(arg_strs)})"
+ f"{self._LINE_BREAK}{body}{self._LINE_BREAK}"
+ self._add_indent(r"\mathbf{end \ function}")
+ r" \end{array}"
)

# TODO(ZibingZhang): support \ELSIF
def visit_If(self, node: ast.If) -> str:
"""Visit an If node."""
cond_latex = self._expression_codegen.visit(node.test)
with self._increment_level():
body_latex = self._LINE_BREAK.join(self.visit(stmt) for stmt in node.body)
latex = self._add_indent(
rf"\mathbf{{if}} \ {cond_latex}{self._LINE_BREAK}{body_latex}"
)

if node.orelse:
latex += self._LINE_BREAK + self._add_indent(r"\mathbf{else} \\ ")
with self._increment_level():
latex += self._LINE_BREAK.join(self.visit(stmt) for stmt in node.orelse)

return latex + self._LINE_BREAK + self._add_indent(r"\mathbf{end \ if}")

def visit_Module(self, node: ast.Module) -> str:
"""Visit a Module node."""
return self.visit(node.body[0])

def visit_Return(self, node: ast.Return) -> str:
"""Visit a Return node."""
return (
self._add_indent(r"\mathbf{return} \ ")
+ self._expression_codegen.visit(node.value)
if node.value is not None
else self._add_indent(r"\mathbf{return}")
)

def visit_While(self, node: ast.While) -> str:
"""Visit a While node."""
if node.orelse:
raise exceptions.LatexifyNotSupportedError(
"While statement with the else clause is not supported"
)

cond_latex = self._expression_codegen.visit(node.test)
with self._increment_level():
body_latex = self._LINE_BREAK.join(self.visit(stmt) for stmt in node.body)
return (
self._add_indent(r"\mathbf{while} \ ")
+ f"{cond_latex}{self._LINE_BREAK}{body_latex}{self._LINE_BREAK}"
+ self._add_indent(r"\mathbf{end \ while}")
)

@contextlib.contextmanager
def _increment_level(self) -> Generator[None, None, None]:
"""Context manager controlling indent level."""
self._indent_level += 1
yield
self._indent_level -= 1

def _add_indent(self, line: str) -> str:
"""Adds an indent before the line.

Args:
line: The line to add an indent to.
"""
return (
rf"\hspace{{{self._indent_level * self._EM_PER_INDENT}em}} {line}"
if self._indent_level > 0
else line
)
Loading