diff --git a/src/latexify/codegen/algorithmic_codegen.py b/src/latexify/codegen/algorithmic_codegen.py index e53fbea..dede73c 100644 --- a/src/latexify/codegen/algorithmic_codegen.py +++ b/src/latexify/codegen/algorithmic_codegen.py @@ -61,6 +61,24 @@ def visit_Expr(self, node: ast.Expr) -> str: rf"\State ${self._expression_codegen.visit(node.value)}$" ) + def visit_For(self, node: ast.For) -> str: + """Visit a For node.""" + if len(node.orelse) != 0: + raise exceptions.LatexifyNotSupportedError( + "For statement with the else clause is not supported" + ) + + target_latex = self._expression_codegen.visit(node.target) + iter_latex = self._expression_codegen.visit(node.iter) + with self._increment_level(): + body_latex = "\n".join(self.visit(stmt) for stmt in node.body) + + return ( + self._add_indent(f"\\For{{${target_latex} \\in {iter_latex}$}}\n") + + f"{body_latex}\n" + + self._add_indent("\\EndFor") + ) + # TODO(ZibingZhang): support nested functions def visit_FunctionDef(self, node: ast.FunctionDef) -> str: """Visit a FunctionDef node.""" @@ -197,25 +215,44 @@ def visit_Expr(self, node: ast.Expr) -> str: """Visit an Expr node.""" return self._add_indent(self._expression_codegen.visit(node.value)) + def visit_For(self, node: ast.For) -> str: + """Visit a For node.""" + if len(node.orelse) != 0: + raise exceptions.LatexifyNotSupportedError( + "For statement with the else clause is not supported" + ) + + target_latex = self._expression_codegen.visit(node.target) + iter_latex = self._expression_codegen.visit(node.iter) + with self._increment_level(): + body_latex = self._LINE_BREAK.join(self.visit(stmt) for stmt in node.body) + + return ( + self._add_indent(r"\mathbf{for}") + + rf" \ {target_latex} \in {iter_latex} \ \mathbf{{do}}{self._LINE_BREAK}" + + f"{body_latex}{self._LINE_BREAK}" + + self._add_indent(r"\mathbf{end \ for}") + ) + # TODO(ZibingZhang): support nested functions def visit_FunctionDef(self, node: ast.FunctionDef) -> str: """Visit a FunctionDef node.""" name_latex = self._identifier_converter.convert(node.name)[0] # Arguments - arg_strs = [ + args_latex = [ self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args ] # Body with self._increment_level(): - body_strs: list[str] = [self.visit(stmt) for stmt in node.body] - body = self._LINE_BREAK.join(body_strs) + body_stmts_latex: list[str] = [self.visit(stmt) for stmt in node.body] + body_latex = self._LINE_BREAK.join(body_stmts_latex) return ( r"\begin{array}{l} " + self._add_indent(r"\mathbf{function}") - + rf" \ {name_latex}({', '.join(arg_strs)})" - + f"{self._LINE_BREAK}{body}{self._LINE_BREAK}" + + rf" \ {name_latex}({', '.join(args_latex)})" + + f"{self._LINE_BREAK}{body_latex}{self._LINE_BREAK}" + self._add_indent(r"\mathbf{end \ function}") + r" \end{array}" ) diff --git a/src/latexify/codegen/algorithmic_codegen_test.py b/src/latexify/codegen/algorithmic_codegen_test.py index e7b3236..80b86d3 100644 --- a/src/latexify/codegen/algorithmic_codegen_test.py +++ b/src/latexify/codegen/algorithmic_codegen_test.py @@ -41,6 +41,28 @@ def test_visit_assign(code: str, latex: str) -> None: assert algorithmic_codegen.AlgorithmicCodegen().visit(node) == latex +@pytest.mark.parametrize( + "code,latex", + [ + ( + "for i in {1}: x = i", + r""" + \For{$i \in \mathopen{}\left\{ 1 \mathclose{}\right\}$} + \State $x \gets i$ + \EndFor + """, + ), + ], +) +def test_visit_for(code: str, latex: str) -> None: + node = ast.parse(textwrap.dedent(code)).body[0] + assert isinstance(node, ast.For) + assert ( + algorithmic_codegen.AlgorithmicCodegen().visit(node) + == textwrap.dedent(latex).strip() + ) + + @pytest.mark.parametrize( "code,latex", [ @@ -180,6 +202,29 @@ def test_visit_assign_ipython(code: str, latex: str) -> None: assert algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == latex +@pytest.mark.parametrize( + "code,latex", + [ + ( + "for i in {1}: x = i", + ( + r"\mathbf{for} \ i \in \mathopen{}\left\{ 1 \mathclose{}\right\}" + r" \ \mathbf{do} \\" + r" \hspace{1em} x \gets i \\" + r" \mathbf{end \ for}" + ), + ), + ], +) +def test_visit_for_ipython(code: str, latex: str) -> None: + node = ast.parse(textwrap.dedent(code)).body[0] + assert isinstance(node, ast.For) + assert ( + algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) + == textwrap.dedent(latex).strip() + ) + + @pytest.mark.parametrize( "code,latex", [