diff --git a/src/core_codemods/remove_debug_breakpoint.py b/src/core_codemods/remove_debug_breakpoint.py index a12e1ef1..8879912b 100644 --- a/src/core_codemods/remove_debug_breakpoint.py +++ b/src/core_codemods/remove_debug_breakpoint.py @@ -21,4 +21,9 @@ def leave_Expr( ) == "breakpoint" and self.is_builtin_function(call_node): self.report_change(original_node) return cst.RemovalSentinel.REMOVE + if self.find_base_name(call_node) == "pdb.set_trace": + self.remove_unused_import(call_node) + self.report_change(original_node) + return cst.RemovalSentinel.REMOVE + return original_node diff --git a/tests/codemods/test_remove_debug_breakpoint.py b/tests/codemods/test_remove_debug_breakpoint.py index 344a48c8..b3d82540 100644 --- a/tests/codemods/test_remove_debug_breakpoint.py +++ b/tests/codemods/test_remove_debug_breakpoint.py @@ -24,6 +24,22 @@ def something(): self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) assert len(self.file_context.codemod_changes) == 1 + def test_builtin_breakpoint_multiple_statements(self, tmpdir): + input_code = """\ + def something(): + var = 1 + print(var); breakpoint() + something() + """ + expected = """\ + def something(): + var = 1 + print(var); + something() + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + def test_inline_pdb(self, tmpdir): input_code = """\ def something(): @@ -57,7 +73,7 @@ def something(): def test_pdb_from_import(self, tmpdir): input_code = """\ - from pdb import set_trace() + from pdb import set_trace def something(): var = 1 set_trace() @@ -70,5 +86,3 @@ def something(): """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) assert len(self.file_context.codemod_changes) == 1 - - # what about line line print(1); breakpoint