diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f7e991ff..6acfc880 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,8 @@ repos: exclude: | (?x)^( src/core_codemods/docs/.*| - integration_tests/.* + integration_tests/.*| + tests/test_codemodder.py )$ - id: check-added-large-files - repo: https://github.com/psf/black diff --git a/src/codemodder/codemodder.py b/src/codemodder/codemodder.py index 9c843831..774f32b2 100644 --- a/src/codemodder/codemodder.py +++ b/src/codemodder/codemodder.py @@ -48,6 +48,21 @@ def find_semgrep_results( return {rule_id for file_changes in results.values() for rule_id in file_changes} +def create_diff(original_tree: cst.Module, new_tree: cst.Module) -> str: + diff_lines = list( + difflib.unified_diff( + original_tree.code.splitlines(keepends=True), + new_tree.code.splitlines(keepends=True), + ) + ) + # All but the last diff line should end with a newline + # The last diff line should be preserved as-is (with or without a newline) + diff_lines = [ + line if line.endswith("\n") else line + "\n" for line in diff_lines[:-1] + ] + [diff_lines[-1]] + return "".join(diff_lines) + + def apply_codemod_to_file( base_directory: Path, file_context, @@ -68,12 +83,7 @@ def apply_codemod_to_file( if output_tree.deep_equals(source_tree): return False - diff = "".join( - difflib.unified_diff( - source_tree.code.splitlines(1), output_tree.code.splitlines(1) - ) - ) - + diff = create_diff(source_tree, output_tree) change_set = ChangeSet( str(file_context.file_path.relative_to(base_directory)), diff, diff --git a/tests/test_codemodder.py b/tests/test_codemodder.py index 775ea84c..02f5149f 100644 --- a/tests/test_codemodder.py +++ b/tests/test_codemodder.py @@ -1,6 +1,9 @@ import mock import pytest -from codemodder.codemodder import run, find_semgrep_results + +import libcst as cst + +from codemodder.codemodder import create_diff, run, find_semgrep_results from codemodder.semgrep import run as semgrep_run from codemodder.registry import load_registered_codemods @@ -189,3 +192,30 @@ def test_find_semgrep_results_no_yaml(self, mocker): result = find_semgrep_results(mocker.MagicMock(), codemods) assert result == set() assert run_semgrep.call_count == 0 + + def test_diff_newline_edge_case(self): + source = """ +SECRET_COOKIE_KEY = "PYGOAT" +CSRF_TRUSTED_ORIGINS = ["http://127.0.0.1:8000","http://0.0.0.0:8000","http://172.16.189.10"]""" # no newline here + + result = """ +SECRET_COOKIE_KEY = "PYGOAT" +CSRF_TRUSTED_ORIGINS = ["http://127.0.0.1:8000","http://0.0.0.0:8000","http://172.16.189.10"] +SESSION_COOKIE_SECURE = True""" + + source_tree = cst.parse_module(source) + result_tree = cst.parse_module(result) + + diff = create_diff(source_tree, result_tree) + assert ( + diff + == """\ +--- ++++ +@@ -1,3 +1,4 @@ + + SECRET_COOKIE_KEY = "PYGOAT" +-CSRF_TRUSTED_ORIGINS = ["http://127.0.0.1:8000","http://0.0.0.0:8000","http://172.16.189.10"] ++CSRF_TRUSTED_ORIGINS = ["http://127.0.0.1:8000","http://0.0.0.0:8000","http://172.16.189.10"] ++SESSION_COOKIE_SECURE = True""" + )