From 0ab50335c7b56b5257a18389cad61b5d25fd2886 Mon Sep 17 00:00:00 2001 From: Daniel D'Avella Date: Wed, 20 Sep 2023 10:07:29 -0400 Subject: [PATCH] Preserve comments from removed lines --- src/codemodder/codemods/use_walrus_if.py | 15 +++++++++++++-- tests/codemods/test_walrus_if.py | 13 +++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/codemodder/codemods/use_walrus_if.py b/src/codemodder/codemods/use_walrus_if.py index 930006679..f625fe64f 100644 --- a/src/codemodder/codemods/use_walrus_if.py +++ b/src/codemodder/codemods/use_walrus_if.py @@ -109,11 +109,22 @@ def leave_Assign(self, original_node, updated_node): def leave_SimpleStatementLine(self, original_node, updated_node): """ - This is a workaround for the fact that libcst doesn't preserve the whitespace in the parent node when all children are removed. + Preserves the whitespace and comments in the line when all children are removed. This feels like a bug in libCST but we'll work around it for now. """ if not updated_node.body: - return cst.FlattenSentinel(original_node.leading_lines) + trailing_whitespace = ( + ( + original_node.trailing_whitespace.with_changes( + whitespace=cst.SimpleWhitespace(""), + ), + ) + if original_node.trailing_whitespace.comment + else () + ) + return cst.FlattenSentinel( + original_node.leading_lines + trailing_whitespace + ) return updated_node diff --git a/tests/codemods/test_walrus_if.py b/tests/codemods/test_walrus_if.py index 2bfe519c6..65367c59e 100644 --- a/tests/codemods/test_walrus_if.py +++ b/tests/codemods/test_walrus_if.py @@ -40,6 +40,19 @@ def test_walrus_if_name_only(self, tmpdir): """ self.run_and_assert(tmpdir, input_code, expected_output) + def test_walrus_if_preserve_comments(self, tmpdir): + input_code = """ +val = do_something() # comment +if val is not None: # another comment + do_something_else(val) +""" + expected_output = """ +# comment +if (val := do_something()) is not None: # another comment + do_something_else(val) +""" + self.run_and_assert(tmpdir, input_code, expected_output) + def test_walrus_if_multiple(self, tmpdir): input_code = """ val = do_something()