From f6619d1e547e28520673f3d02086e3c84fc3d702 Mon Sep 17 00:00:00 2001 From: Daniel D'Avella Date: Wed, 3 Jan 2024 09:35:09 -0500 Subject: [PATCH] Implement new codemod API --- .coveragerc | 1 + integration_tests/base_test.py | 20 +- .../test_add_requests_timeout.py | 7 +- .../test_combine_startswith_endswith.py | 2 +- .../test_django_debug_flag_on.py | 2 +- .../test_django_json_response_type.py | 2 +- .../test_django_receiver_on_top.py | 2 +- .../test_django_session_cookie_secure_off.py | 2 +- .../test_exception_without_raise.py | 2 +- integration_tests/test_file_resource_leak.py | 2 +- .../test_fix_deprecated_logging_warn.py | 2 +- integration_tests/test_fix_mutable_params.py | 2 +- .../test_flask_json_response_type.py | 2 +- integration_tests/test_harden_pyyaml.py | 2 +- integration_tests/test_harden_ruamel.py | 2 +- integration_tests/test_https_connection.py | 2 +- integration_tests/test_jinja2_autoescape.py | 2 +- integration_tests/test_jwt_decode_verify.py | 2 +- integration_tests/test_limit_readline.py | 2 +- .../test_literal_or_new_object_identity.py | 2 +- .../test_lxml_safe_parser_defaults.py | 2 +- integration_tests/test_lxml_safe_parsing.py | 2 +- integration_tests/test_numpy_nan_equality.py | 2 +- integration_tests/test_order_imports.py | 2 +- integration_tests/test_process_sandbox.py | 2 +- .../test_remove_debug_breakpoint.py | 2 +- .../test_remove_future_imports.py | 2 +- .../test_remove_module_global.py | 2 +- .../test_remove_unused_imports.py | 2 +- integration_tests/test_request_verify.py | 2 +- integration_tests/test_secure_flask_cookie.py | 2 +- .../test_secure_flask_session_config.py | 2 +- integration_tests/test_secure_random.py | 2 +- .../test_sql_parameterization.py | 2 +- .../test_subprocess_shell_false.py | 2 +- integration_tests/test_tempfile_mktemp.py | 2 +- integration_tests/test_unnecessary_f_str.py | 2 +- ...test_upgrade_sslcontext_minimum_version.py | 2 +- .../test_upgrade_sslcontext_tls.py | 2 +- integration_tests/test_url_sandbox.py | 2 +- integration_tests/test_use_defusedxml.py | 2 +- integration_tests/test_use_generator.py | 2 +- integration_tests/test_use_set_literal.py | 2 +- integration_tests/test_use_walrus_if.py | 2 +- integration_tests/test_with_threading_lock.py | 2 +- pylintrc | 3 +- src/codemodder/codemodder.py | 176 ++-------- src/codemodder/codemods/api.py | 52 +++ src/codemodder/codemods/api/__init__.py | 142 -------- src/codemodder/codemods/api/helpers.py | 130 -------- src/codemodder/codemods/base_codemod.py | 247 ++++++++------ src/codemodder/codemods/base_detector.py | 16 + src/codemodder/codemods/base_transformer.py | 43 +++ src/codemodder/codemods/base_visitor.py | 1 + src/codemodder/codemods/libcst_transformer.py | 305 ++++++++++++++++++ src/codemodder/codemods/semgrep.py | 49 +++ src/codemodder/context.py | 23 +- .../dependency_management/setup_py_writer.py | 19 +- src/codemodder/executor.py | 109 ------- src/codemodder/registry.py | 48 +-- src/core_codemods/add_requests_timeouts.py | 89 ++--- src/core_codemods/api/__init__.py | 6 + src/core_codemods/api/core_codemod.py | 23 ++ .../combine_startswith_endswith.py | 15 +- src/core_codemods/django_debug_flag_on.py | 43 +-- .../django_json_response_type.py | 44 +-- src/core_codemods/django_receiver_on_top.py | 30 +- .../django_session_cookie_secure_off.py | 51 +-- src/core_codemods/enable_jinja2_autoescape.py | 42 +-- src/core_codemods/exception_without_raise.py | 32 +- src/core_codemods/file_resource_leak.py | 44 +-- .../fix_deprecated_abstractproperty.py | 33 +- .../fix_deprecated_logging_warn.py | 35 +- src/core_codemods/fix_mutable_params.py | 19 +- src/core_codemods/flask_json_response_type.py | 43 +-- src/core_codemods/harden_pyyaml.py | 121 +++---- src/core_codemods/harden_ruamel.py | 40 +-- src/core_codemods/https_connection.py | 39 +-- src/core_codemods/jwt_decode_verify.py | 40 +-- src/core_codemods/limit_readline.py | 31 +- .../literal_or_new_object_identity.py | 32 +- .../lxml_safe_parser_defaults.py | 52 +-- src/core_codemods/lxml_safe_parsing.py | 52 +-- src/core_codemods/numpy_nan_equality.py | 32 +- src/core_codemods/order_imports.py | 31 +- src/core_codemods/process_creation_sandbox.py | 43 +-- .../{semgrep => refactor}/__init__.py | 0 .../refactor/refactor_new_api.py | 297 +++++++++++++++++ src/core_codemods/remove_debug_breakpoint.py | 15 +- src/core_codemods/remove_future_imports.py | 33 +- src/core_codemods/remove_module_global.py | 15 +- src/core_codemods/remove_unnecessary_f_str.py | 44 +-- src/core_codemods/remove_unused_imports.py | 36 +-- src/core_codemods/requests_verify.py | 40 +-- src/core_codemods/secure_flask_cookie.py | 45 +-- .../secure_flask_session_config.py | 44 +-- src/core_codemods/secure_random.py | 55 ++-- .../semgrep/sandbox_url_creation.yaml | 13 - src/core_codemods/sql_parameterization.py | 45 ++- src/core_codemods/subprocess_shell_false.py | 44 +-- src/core_codemods/tempfile_mktemp.py | 36 ++- .../upgrade_sslcontext_minimum_version.py | 43 ++- src/core_codemods/upgrade_sslcontext_tls.py | 47 ++- src/core_codemods/url_sandbox.py | 81 ++--- src/core_codemods/use_defused_xml.py | 54 ++-- src/core_codemods/use_generator.py | 47 +-- src/core_codemods/use_set_literal.py | 16 +- src/core_codemods/use_walrus_if.py | 40 +-- src/core_codemods/with_threading_lock.py | 47 +-- tests/codemods/base_codemod_test.py | 123 +++---- tests/codemods/conftest.py | 15 + tests/codemods/test_add_requests_timeouts.py | 7 +- tests/codemods/test_base_codemod.py | 26 +- .../test_combine_startswith_endswith.py | 4 +- tests/codemods/test_django_debug_flag_on.py | 5 +- .../test_django_json_response_type.py | 7 +- tests/codemods/test_django_receiver_on_top.py | 7 +- .../test_django_session_cookie_secure_off.py | 7 +- .../codemods/test_enable_jinja2_autoescape.py | 2 +- .../codemods/test_exception_without_raise.py | 7 +- tests/codemods/test_file_resource_leak.py | 17 +- .../test_fix_deprecated_logging_warn.py | 5 - .../codemods/test_flask_json_response_type.py | 17 +- tests/codemods/test_harden_pyyaml.py | 14 +- tests/codemods/test_harden_ruamel.py | 2 +- tests/codemods/test_https_connection.py | 7 - tests/codemods/test_jwt_decode_verify.py | 2 +- tests/codemods/test_limit_readline.py | 2 +- .../test_literal_or_new_object_identity.py | 21 +- .../test_lxml_safe_parameter_defaults.py | 2 +- tests/codemods/test_lxml_safe_parsing.py | 2 +- tests/codemods/test_numpy_nan_equality.py | 13 +- tests/codemods/test_order_imports.py | 12 - .../codemods/test_process_creation_sandbox.py | 41 +-- .../codemods/test_remove_debug_breakpoint.py | 7 +- tests/codemods/test_remove_module_global.py | 6 +- .../codemods/test_remove_unnecessary_f_str.py | 2 - tests/codemods/test_remove_unused_imports.py | 17 +- tests/codemods/test_request_verify.py | 2 +- tests/codemods/test_secure_flask_cookie.py | 2 +- .../test_secure_flask_session_config.py | 13 +- tests/codemods/test_secure_random.py | 207 ++++++------ tests/codemods/test_sql_parameterization.py | 23 +- tests/codemods/test_subprocess_shell_false.py | 6 +- tests/codemods/test_tempfile_mktemp.py | 2 +- tests/codemods/test_url_sandbox.py | 50 +-- tests/codemods/test_use_defused_xml.py | 34 +- tests/codemods/test_use_set_literal.py | 5 - tests/codemods/test_with_threading_lock.py | 2 +- tests/conftest.py | 42 +-- .../test_setup_py_writer.py | 2 +- tests/test_codemod_docs.py | 10 +- tests/test_codemodder.py | 9 +- tests/test_context.py | 4 +- 154 files changed, 2386 insertions(+), 2100 deletions(-) create mode 100644 src/codemodder/codemods/api.py delete mode 100644 src/codemodder/codemods/api/__init__.py delete mode 100644 src/codemodder/codemods/api/helpers.py create mode 100644 src/codemodder/codemods/base_detector.py create mode 100644 src/codemodder/codemods/base_transformer.py create mode 100644 src/codemodder/codemods/libcst_transformer.py create mode 100644 src/codemodder/codemods/semgrep.py delete mode 100644 src/codemodder/executor.py create mode 100644 src/core_codemods/api/__init__.py create mode 100644 src/core_codemods/api/core_codemod.py rename src/core_codemods/{semgrep => refactor}/__init__.py (100%) create mode 100644 src/core_codemods/refactor/refactor_new_api.py delete mode 100644 src/core_codemods/semgrep/sandbox_url_creation.yaml create mode 100644 tests/codemods/conftest.py diff --git a/.coveragerc b/.coveragerc index d3e029f22..61b9c6245 100644 --- a/.coveragerc +++ b/.coveragerc @@ -3,6 +3,7 @@ source = codemodder omit = */codemodder/scripts/* */codemodder/_version.py + */core_codemods/refactor/* [paths] codemodder = diff --git a/integration_tests/base_test.py b/integration_tests/base_test.py index dd715c9cd..485c23960 100644 --- a/integration_tests/base_test.py +++ b/integration_tests/base_test.py @@ -62,8 +62,14 @@ def setup_class(cls): def setup_method(self): try: - self.codemod_wrapper = self.codemod_registry.match_codemods( - codemod_include=[self.codemod.name()] + name = ( + self.codemod().name + if isinstance(self.codemod, type) + else self.codemod.name + ) + # This is how we ensure that the codemod is actually in the registry + self.codemod_instance = self.codemod_registry.match_codemods( + codemod_include=[name] )[0] except IndexError as exc: raise IndexError( @@ -77,7 +83,7 @@ def _assert_run_fields(self, run, output_path): assert run["elapsed"] != "" assert ( run["commandLine"] - == f"codemodder {SAMPLES_DIR} --output {output_path} --codemod-include={self.codemod_wrapper.name} --path-include={self.code_path}" + == f"codemodder {SAMPLES_DIR} --output {output_path} --codemod-include={self.codemod_instance.name} --path-include={self.code_path}" ) assert run["directory"] == os.path.abspath(SAMPLES_DIR) assert run["sarifs"] == [] @@ -85,8 +91,10 @@ def _assert_run_fields(self, run, output_path): def _assert_results_fields(self, results, output_path): assert len(results) == 1 result = results[0] - assert result["codemod"] == self.codemod_wrapper.id - assert result["references"] == self.codemod_wrapper.references + assert result["codemod"] == self.codemod_instance.id + assert result["references"] == [ + ref.to_json() for ref in self.codemod_instance.references + ] # TODO: once we add description for each url. for reference in result["references"]: @@ -147,7 +155,7 @@ def test_file_rewritten(self): SAMPLES_DIR, "--output", self.output_path, - f"--codemod-include={self.codemod_wrapper.name}", + f"--codemod-include={self.codemod_instance.name}", f"--path-include={self.code_path}", ] diff --git a/integration_tests/test_add_requests_timeout.py b/integration_tests/test_add_requests_timeout.py index df7407c40..556dbe4d1 100644 --- a/integration_tests/test_add_requests_timeout.py +++ b/integration_tests/test_add_requests_timeout.py @@ -1,4 +1,7 @@ -from core_codemods.add_requests_timeouts import AddRequestsTimeouts +from core_codemods.add_requests_timeouts import ( + AddRequestsTimeouts, + TransformAddRequestsTimeouts, +) from integration_tests.base_test import ( BaseIntegrationTest, original_and_expected_from_code_path, @@ -33,4 +36,4 @@ class TestAddRequestsTimeouts(BaseIntegrationTest): num_changes = 2 expected_line_change = "3" - change_description = AddRequestsTimeouts.CHANGE_DESCRIPTION + change_description = TransformAddRequestsTimeouts.change_description diff --git a/integration_tests/test_combine_startswith_endswith.py b/integration_tests/test_combine_startswith_endswith.py index 9ad639f61..6426e1faa 100644 --- a/integration_tests/test_combine_startswith_endswith.py +++ b/integration_tests/test_combine_startswith_endswith.py @@ -13,4 +13,4 @@ class TestCombineStartswithEndswith(BaseIntegrationTest): ) expected_diff = '--- \n+++ \n@@ -1,3 +1,3 @@\n x = \'foo\'\n-if x.startswith("foo") or x.startswith("bar"):\n+if x.startswith(("foo", "bar")):\n print("Yes")\n' expected_line_change = "2" - change_description = CombineStartswithEndswith.CHANGE_DESCRIPTION + change_description = CombineStartswithEndswith.change_description diff --git a/integration_tests/test_django_debug_flag_on.py b/integration_tests/test_django_debug_flag_on.py index bb4a2c1cb..f3abf56cb 100644 --- a/integration_tests/test_django_debug_flag_on.py +++ b/integration_tests/test_django_debug_flag_on.py @@ -14,4 +14,4 @@ class TestDjangoDebugFlagFlip(BaseIntegrationTest): ) expected_diff = '--- \n+++ \n@@ -23,7 +23,7 @@\n SECRET_KEY = "django-insecure-t*rrda&qd4^#q+50^%q^rrsp-t$##&u5_#=9)&@ei^ppl6$*c*"\n \n # SECURITY WARNING: don\'t run with debug turned on in production!\n-DEBUG = True\n+DEBUG = False\n \n ALLOWED_HOSTS = []\n \n' expected_line_change = "26" - change_description = DjangoDebugFlagOn.CHANGE_DESCRIPTION + change_description = DjangoDebugFlagOn.change_description diff --git a/integration_tests/test_django_json_response_type.py b/integration_tests/test_django_json_response_type.py index 6b2c28df7..59f751d7b 100644 --- a/integration_tests/test_django_json_response_type.py +++ b/integration_tests/test_django_json_response_type.py @@ -32,5 +32,5 @@ class TestDjangoJsonResponseType(BaseIntegrationTest): # fmt: on expected_line_change = "6" - change_description = DjangoJsonResponseType.CHANGE_DESCRIPTION + change_description = DjangoJsonResponseType.change_description num_changed_files = 1 diff --git a/integration_tests/test_django_receiver_on_top.py b/integration_tests/test_django_receiver_on_top.py index 4fcf56520..833f3c8d9 100644 --- a/integration_tests/test_django_receiver_on_top.py +++ b/integration_tests/test_django_receiver_on_top.py @@ -33,5 +33,5 @@ class TestDjangoReceiverOnTop(BaseIntegrationTest): # fmt: on expected_line_change = "7" - change_description = DjangoReceiverOnTop.CHANGE_DESCRIPTION + change_description = DjangoReceiverOnTop.change_description num_changed_files = 1 diff --git a/integration_tests/test_django_session_cookie_secure_off.py b/integration_tests/test_django_session_cookie_secure_off.py index ca193390f..bf726e487 100644 --- a/integration_tests/test_django_session_cookie_secure_off.py +++ b/integration_tests/test_django_session_cookie_secure_off.py @@ -16,4 +16,4 @@ class TestDjangoSessionCookieSecureOff(BaseIntegrationTest): ) expected_diff = '--- \n+++ \n@@ -121,3 +121,4 @@\n # https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field\n \n DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"\n+SESSION_COOKIE_SECURE = True\n' expected_line_change = "124" - change_description = DjangoSessionCookieSecureOff.CHANGE_DESCRIPTION + change_description = DjangoSessionCookieSecureOff.change_description diff --git a/integration_tests/test_exception_without_raise.py b/integration_tests/test_exception_without_raise.py index 565bdf7b2..0b5f9b0bd 100644 --- a/integration_tests/test_exception_without_raise.py +++ b/integration_tests/test_exception_without_raise.py @@ -29,5 +29,5 @@ class TestExceptionWithoutRaise(BaseIntegrationTest): # fmt: on expected_line_change = "2" - change_description = ExceptionWithoutRaise.CHANGE_DESCRIPTION + change_description = ExceptionWithoutRaise.change_description num_changed_files = 1 diff --git a/integration_tests/test_file_resource_leak.py b/integration_tests/test_file_resource_leak.py index bbf877857..6c518ba39 100644 --- a/integration_tests/test_file_resource_leak.py +++ b/integration_tests/test_file_resource_leak.py @@ -33,5 +33,5 @@ class TestFileResourceLeak(BaseIntegrationTest): # fmt: on expected_line_change = "3" - change_description = FileResourceLeak.CHANGE_DESCRIPTION + change_description = FileResourceLeak.change_description num_changed_files = 1 diff --git a/integration_tests/test_fix_deprecated_logging_warn.py b/integration_tests/test_fix_deprecated_logging_warn.py index caa08ff75..89dd65fbb 100644 --- a/integration_tests/test_fix_deprecated_logging_warn.py +++ b/integration_tests/test_fix_deprecated_logging_warn.py @@ -13,4 +13,4 @@ class TestFixDeprecatedLoggingWarn(BaseIntegrationTest): ) expected_diff = '--- \n+++ \n@@ -1,4 +1,4 @@\n import logging\n \n log = logging.getLogger("my logger")\n-log.warn("hello")\n+log.warning("hello")\n' expected_line_change = "4" - change_description = FixDeprecatedLoggingWarn.CHANGE_DESCRIPTION + change_description = FixDeprecatedLoggingWarn.change_description diff --git a/integration_tests/test_fix_mutable_params.py b/integration_tests/test_fix_mutable_params.py index 110e9ac67..b84470b01 100644 --- a/integration_tests/test_fix_mutable_params.py +++ b/integration_tests/test_fix_mutable_params.py @@ -30,4 +30,4 @@ def baz(x=None, y=None): expected_diff = '--- \n+++ \n@@ -1,4 +1,5 @@\n-def foo(x, y=[]):\n+def foo(x, y=None):\n+ y = [] if y is None else y\n y.append(x)\n print(y)\n \n@@ -7,6 +8,8 @@\n print(x)\n \n \n-def baz(x={"foo": 42}, y=set()):\n+def baz(x=None, y=None):\n+ x = {"foo": 42} if x is None else x\n+ y = set() if y is None else y\n print(x)\n print(y)\n' expected_line_change = 1 num_changes = 2 - change_description = FixMutableParams.CHANGE_DESCRIPTION + change_description = FixMutableParams.change_description diff --git a/integration_tests/test_flask_json_response_type.py b/integration_tests/test_flask_json_response_type.py index d7511a2ec..f283e7423 100644 --- a/integration_tests/test_flask_json_response_type.py +++ b/integration_tests/test_flask_json_response_type.py @@ -32,5 +32,5 @@ class TestFlaskJsonResponseType(BaseIntegrationTest): # fmt: on expected_line_change = "9" - change_description = FlaskJsonResponseType.CHANGE_DESCRIPTION + change_description = FlaskJsonResponseType.change_description num_changed_files = 1 diff --git a/integration_tests/test_harden_pyyaml.py b/integration_tests/test_harden_pyyaml.py index 1a7caca47..407ffa2f0 100644 --- a/integration_tests/test_harden_pyyaml.py +++ b/integration_tests/test_harden_pyyaml.py @@ -15,6 +15,6 @@ class TestHardenPyyaml(BaseIntegrationTest): ) expected_diff = '--- \n+++ \n@@ -1,4 +1,4 @@\n import yaml\n \n data = b"!!python/object/apply:subprocess.Popen \\\\n- ls"\n-deserialized_data = yaml.load(data, Loader=yaml.Loader)\n+deserialized_data = yaml.load(data, Loader=yaml.SafeLoader)\n' expected_line_change = "4" - change_description = HardenPyyaml.CHANGE_DESCRIPTION + change_description = HardenPyyaml.change_description # expected exception because the yaml.SafeLoader protects against unsafe code allowed_exceptions = (yaml.constructor.ConstructorError,) diff --git a/integration_tests/test_harden_ruamel.py b/integration_tests/test_harden_ruamel.py index 4a8b06bca..dd5d62447 100644 --- a/integration_tests/test_harden_ruamel.py +++ b/integration_tests/test_harden_ruamel.py @@ -18,4 +18,4 @@ class TestHardenRuamel(BaseIntegrationTest): expected_diff = '--- \n+++ \n@@ -1,4 +1,4 @@\n from ruamel.yaml import YAML\n \n-serializer = YAML(typ="unsafe")\n-serializer = YAML(typ="base")\n+serializer = YAML(typ="safe")\n+serializer = YAML(typ="safe")\n' expected_line_change = "3" num_changes = 2 - change_description = HardenRuamel.CHANGE_DESCRIPTION + change_description = HardenRuamel.change_description diff --git a/integration_tests/test_https_connection.py b/integration_tests/test_https_connection.py index 31f560133..79ac59a4f 100644 --- a/integration_tests/test_https_connection.py +++ b/integration_tests/test_https_connection.py @@ -22,4 +22,4 @@ class TestHTTPSConnection(BaseIntegrationTest): expected_diff = '--- \n+++ \n@@ -1,8 +1,7 @@\n import urllib3\n import urllib3.connectionpool as pool\n-from urllib3 import HTTPConnectionPool as something\n \n-urllib3.HTTPConnectionPool("localhost", "80")\n-urllib3.connectionpool.HTTPConnectionPool("localhost", "80")\n-something("localhost", "80")\n-pool.HTTPConnectionPool("localhost", "80")\n+urllib3.HTTPSConnectionPool("localhost", "80")\n+urllib3.connectionpool.HTTPSConnectionPool("localhost", "80")\n+urllib3.HTTPSConnectionPool("localhost", "80")\n+pool.HTTPSConnectionPool("localhost", "80")\n' expected_line_change = "5" num_changes = 4 - change_description = HTTPSConnection.CHANGE_DESCRIPTION + change_description = HTTPSConnection.change_description diff --git a/integration_tests/test_jinja2_autoescape.py b/integration_tests/test_jinja2_autoescape.py index 7dbf0fa05..83a0b8b96 100644 --- a/integration_tests/test_jinja2_autoescape.py +++ b/integration_tests/test_jinja2_autoescape.py @@ -18,4 +18,4 @@ class TestEnableJinja2Autoescape(BaseIntegrationTest): expected_diff = "--- \n+++ \n@@ -1,4 +1,4 @@\n from jinja2 import Environment\n \n-env = Environment()\n-env = Environment(autoescape=False)\n+env = Environment(autoescape=True)\n+env = Environment(autoescape=True)\n" expected_line_change = "3" num_changes = 2 - change_description = EnableJinja2Autoescape.CHANGE_DESCRIPTION + change_description = EnableJinja2Autoescape.change_description diff --git a/integration_tests/test_jwt_decode_verify.py b/integration_tests/test_jwt_decode_verify.py index c6ae99aae..6c1360fd6 100644 --- a/integration_tests/test_jwt_decode_verify.py +++ b/integration_tests/test_jwt_decode_verify.py @@ -24,4 +24,4 @@ class TestJwtDecodeVerify(BaseIntegrationTest): expected_diff = '--- \n+++ \n@@ -8,7 +8,7 @@\n \n encoded_jwt = jwt.encode(payload, SECRET_KEY, algorithm="HS256")\n \n-decoded_payload = jwt.decode(encoded_jwt, SECRET_KEY, algorithms=["HS256"], verify=False)\n-decoded_payload = jwt.decode(encoded_jwt, SECRET_KEY, algorithms=["HS256"], options={"verify_signature": False})\n+decoded_payload = jwt.decode(encoded_jwt, SECRET_KEY, algorithms=["HS256"], verify=True)\n+decoded_payload = jwt.decode(encoded_jwt, SECRET_KEY, algorithms=["HS256"], options={"verify_signature": True})\n \n var = "something"\n' expected_line_change = "11" num_changes = 2 - change_description = JwtDecodeVerify.CHANGE_DESCRIPTION + change_description = JwtDecodeVerify.change_description diff --git a/integration_tests/test_limit_readline.py b/integration_tests/test_limit_readline.py index b228b24ca..a91517f36 100644 --- a/integration_tests/test_limit_readline.py +++ b/integration_tests/test_limit_readline.py @@ -13,6 +13,6 @@ class TestLimitReadline(BaseIntegrationTest): ) expected_diff = '--- \n+++ \n@@ -1,2 +1,2 @@\n file = open("some_file.txt")\n-file.readline()\n+file.readline(5_000_000)\n' expected_line_change = "2" - change_description = LimitReadline.CHANGE_DESCRIPTION + change_description = LimitReadline.change_description # expected because output code points to fake file allowed_exceptions = (FileNotFoundError,) diff --git a/integration_tests/test_literal_or_new_object_identity.py b/integration_tests/test_literal_or_new_object_identity.py index 7e5bccbc7..03f37af3f 100644 --- a/integration_tests/test_literal_or_new_object_identity.py +++ b/integration_tests/test_literal_or_new_object_identity.py @@ -28,5 +28,5 @@ class TestLiteralOrNewObjectIdentity(BaseIntegrationTest): # fmt: on expected_line_change = "2" - change_description = LiteralOrNewObjectIdentity.CHANGE_DESCRIPTION + change_description = LiteralOrNewObjectIdentity.change_description num_changed_files = 1 diff --git a/integration_tests/test_lxml_safe_parser_defaults.py b/integration_tests/test_lxml_safe_parser_defaults.py index 25292b9cb..5c1a5cba0 100644 --- a/integration_tests/test_lxml_safe_parser_defaults.py +++ b/integration_tests/test_lxml_safe_parser_defaults.py @@ -13,4 +13,4 @@ class TestLxmlSafeParserDefaults(BaseIntegrationTest): ) expected_diff = "--- \n+++ \n@@ -1,2 +1,2 @@\n import lxml.etree\n-parser = lxml.etree.XMLParser()\n+parser = lxml.etree.XMLParser(resolve_entities=False)\n" expected_line_change = "2" - change_description = LxmlSafeParserDefaults.CHANGE_DESCRIPTION + change_description = LxmlSafeParserDefaults.change_description diff --git a/integration_tests/test_lxml_safe_parsing.py b/integration_tests/test_lxml_safe_parsing.py index 3a8d65ea3..b121dce42 100644 --- a/integration_tests/test_lxml_safe_parsing.py +++ b/integration_tests/test_lxml_safe_parsing.py @@ -24,5 +24,5 @@ class TestLxmlSafeParsing(BaseIntegrationTest): expected_diff = '--- \n+++ \n@@ -1,3 +1,3 @@\n import lxml.etree\n-lxml.etree.parse("path_to_file")\n-lxml.etree.fromstring("xml_str")\n+lxml.etree.parse("path_to_file", parser=lxml.etree.XMLParser(resolve_entities=False))\n+lxml.etree.fromstring("xml_str", parser=lxml.etree.XMLParser(resolve_entities=False))\n' expected_line_change = "2" num_changes = 2 - change_description = LxmlSafeParsing.CHANGE_DESCRIPTION + change_description = LxmlSafeParsing.change_description allowed_exceptions = (OSError,) diff --git a/integration_tests/test_numpy_nan_equality.py b/integration_tests/test_numpy_nan_equality.py index f5b7626a4..85ebd36a7 100644 --- a/integration_tests/test_numpy_nan_equality.py +++ b/integration_tests/test_numpy_nan_equality.py @@ -30,5 +30,5 @@ class TestNumpyNanEquality(BaseIntegrationTest): # fmt: on expected_line_change = "4" - change_description = NumpyNanEquality.CHANGE_DESCRIPTION + change_description = NumpyNanEquality.change_description num_changed_files = 1 diff --git a/integration_tests/test_order_imports.py b/integration_tests/test_order_imports.py index 3d174c877..0b515178c 100644 --- a/integration_tests/test_order_imports.py +++ b/integration_tests/test_order_imports.py @@ -37,4 +37,4 @@ class TestOrderImports(BaseIntegrationTest): expected_diff = "--- \n+++ \n@@ -1,20 +1,14 @@\n #!/bin/env python\n-from abc import ABCMeta\n-\n+# comment builtins4\n+# comment builtins5\n+# comment builtins3\n # comment builtins1\n # comment builtins2\n import builtins\n-\n+import collections\n+import datetime\n # comment a\n-from abc import ABC\n-\n-# comment builtins3\n-import builtins, datetime\n-\n-# comment builtins4\n-# comment builtins5\n-import builtins\n-import collections\n+from abc import ABC, ABCMeta\n \n ABC\n ABCMeta\n" expected_line_change = "2" - change_description = OrderImports.CHANGE_DESCRIPTION + change_description = OrderImports.change_description diff --git a/integration_tests/test_process_sandbox.py b/integration_tests/test_process_sandbox.py index a54fb9269..ba6aa5e32 100644 --- a/integration_tests/test_process_sandbox.py +++ b/integration_tests/test_process_sandbox.py @@ -22,7 +22,7 @@ class TestProcessSandbox(BaseIntegrationTest): expected_line_change = "3" num_changes = 4 num_changed_files = 2 - change_description = ProcessSandbox.CHANGE_DESCRIPTION + change_description = ProcessSandbox.change_description requirements_path = "tests/samples/requirements.txt" original_requirements = "# file used to test dependency management\nrequests==2.31.0\nblack==23.7.*\nmypy~=1.4\npylint>1\n" diff --git a/integration_tests/test_remove_debug_breakpoint.py b/integration_tests/test_remove_debug_breakpoint.py index 2bfe01202..14c6b294c 100644 --- a/integration_tests/test_remove_debug_breakpoint.py +++ b/integration_tests/test_remove_debug_breakpoint.py @@ -17,4 +17,4 @@ class TestRemoveDebugBreakpoint(BaseIntegrationTest): '--- \n+++ \n@@ -1,3 +1,2 @@\n print("hello")\n-breakpoint()\n print("world")\n' ) expected_line_change = "2" - change_description = RemoveDebugBreakpoint.CHANGE_DESCRIPTION + change_description = RemoveDebugBreakpoint.change_description diff --git a/integration_tests/test_remove_future_imports.py b/integration_tests/test_remove_future_imports.py index 8e7790743..9f7bae29e 100644 --- a/integration_tests/test_remove_future_imports.py +++ b/integration_tests/test_remove_future_imports.py @@ -33,4 +33,4 @@ class TestRemoveFutureImports(BaseIntegrationTest): num_changes = 2 expected_line_change = "1" - change_description = RemoveFutureImports.CHANGE_DESCRIPTION + change_description = RemoveFutureImports.change_description diff --git a/integration_tests/test_remove_module_global.py b/integration_tests/test_remove_module_global.py index 5daf5fca5..355c0b5e3 100644 --- a/integration_tests/test_remove_module_global.py +++ b/integration_tests/test_remove_module_global.py @@ -16,4 +16,4 @@ class TestRemoveModuleGlobal(BaseIntegrationTest): """.lstrip() expected_diff = '--- \n+++ \n@@ -1,4 +1,3 @@\n price = 25\n print("hello")\n-global price\n price = 30\n' expected_line_change = "3" - change_description = RemoveModuleGlobal.CHANGE_DESCRIPTION + change_description = RemoveModuleGlobal.change_description diff --git a/integration_tests/test_remove_unused_imports.py b/integration_tests/test_remove_unused_imports.py index 366756ece..1c3d38601 100644 --- a/integration_tests/test_remove_unused_imports.py +++ b/integration_tests/test_remove_unused_imports.py @@ -17,4 +17,4 @@ class TestRemoveUnusedImports(BaseIntegrationTest): expected_diff = "--- \n+++ \n@@ -1,5 +1,5 @@\n import abc\n-from builtins import complex, dict\n+from builtins import complex\n \n abc\n complex\n" expected_line_change = 2 - change_description = RemoveUnusedImports.CHANGE_DESCRIPTION + change_description = RemoveUnusedImports.change_description diff --git a/integration_tests/test_request_verify.py b/integration_tests/test_request_verify.py index 5478abbb7..dd2cf5ec3 100644 --- a/integration_tests/test_request_verify.py +++ b/integration_tests/test_request_verify.py @@ -22,6 +22,6 @@ class TestRequestsVerify(BaseIntegrationTest): expected_diff = '--- \n+++ \n@@ -1,5 +1,5 @@\n import requests\n \n-requests.get("https://www.google.com", verify=False)\n-requests.post("https://some-api/", json={"id": 1234, "price": 18}, verify=False)\n+requests.get("https://www.google.com", verify=True)\n+requests.post("https://some-api/", json={"id": 1234, "price": 18}, verify=True)\n var = "hello"\n' expected_line_change = "3" num_changes = 2 - change_description = RequestsVerify.CHANGE_DESCRIPTION + change_description = RequestsVerify.change_description # expected because when executing the output code it will make a request which fails, which is OK. allowed_exceptions = (exceptions.ConnectionError,) diff --git a/integration_tests/test_secure_flask_cookie.py b/integration_tests/test_secure_flask_cookie.py index e8394b672..f1a8cdc9c 100644 --- a/integration_tests/test_secure_flask_cookie.py +++ b/integration_tests/test_secure_flask_cookie.py @@ -19,4 +19,4 @@ class TestSecureFlaskCookie(BaseIntegrationTest): ) expected_diff = "--- \n+++ \n@@ -5,5 +5,5 @@\n @app.route('/')\n def index():\n resp = make_response('Custom Cookie Set')\n- resp.set_cookie('custom_cookie', 'value')\n+ resp.set_cookie('custom_cookie', 'value', secure=True, httponly=True, samesite='Lax')\n return resp\n" expected_line_change = "8" - change_description = SecureFlaskCookie.CHANGE_DESCRIPTION + change_description = SecureFlaskCookie.change_description diff --git a/integration_tests/test_secure_flask_session_config.py b/integration_tests/test_secure_flask_session_config.py index f138ba184..b07ccd1d4 100644 --- a/integration_tests/test_secure_flask_session_config.py +++ b/integration_tests/test_secure_flask_session_config.py @@ -13,4 +13,4 @@ class TestSecureFlaskSessionConfig(BaseIntegrationTest): ) expected_diff = "--- \n+++ \n@@ -1,6 +1,6 @@\n from flask import Flask\n app = Flask(__name__)\n-app.config['SESSION_COOKIE_HTTPONLY'] = False\n+app.config['SESSION_COOKIE_HTTPONLY'] = True\n @app.route('/')\n def hello_world():\n return 'Hello World!'\n" expected_line_change = "3" - change_description = SecureFlaskSessionConfig.CHANGE_DESCRIPTION + change_description = SecureFlaskSessionConfig.change_description diff --git a/integration_tests/test_secure_random.py b/integration_tests/test_secure_random.py index 458cf785a..97863fa9b 100644 --- a/integration_tests/test_secure_random.py +++ b/integration_tests/test_secure_random.py @@ -19,4 +19,4 @@ class TestSecureRandom(BaseIntegrationTest): expected_diff = '--- \n+++ \n@@ -1,4 +1,4 @@\n-import random\n+import secrets\n \n-random.random()\n+secrets.SystemRandom().random()\n var = "hello"\n' expected_line_change = "3" - change_description = SecureRandom.CHANGE_DESCRIPTION + change_description = SecureRandom.change_description diff --git a/integration_tests/test_sql_parameterization.py b/integration_tests/test_sql_parameterization.py index f27063fa9..99242d33f 100644 --- a/integration_tests/test_sql_parameterization.py +++ b/integration_tests/test_sql_parameterization.py @@ -37,5 +37,5 @@ class TestSQLQueryParameterization(BaseIntegrationTest): # fmt: on expected_line_change = "12" - change_description = SQLQueryParameterization.CHANGE_DESCRIPTION + change_description = SQLQueryParameterization.change_description num_changed_files = 1 diff --git a/integration_tests/test_subprocess_shell_false.py b/integration_tests/test_subprocess_shell_false.py index dd520269f..078b88988 100644 --- a/integration_tests/test_subprocess_shell_false.py +++ b/integration_tests/test_subprocess_shell_false.py @@ -13,6 +13,6 @@ class TestSubprocessShellFalse(BaseIntegrationTest): ) expected_diff = "--- \n+++ \n@@ -1,2 +1,2 @@\n import subprocess\n-subprocess.run(\"echo 'hi'\", shell=True)\n+subprocess.run(\"echo 'hi'\", shell=False)\n" expected_line_change = "2" - change_description = SubprocessShellFalse.CHANGE_DESCRIPTION + change_description = SubprocessShellFalse.change_description # expected because output code points to fake file allowed_exceptions = (FileNotFoundError,) diff --git a/integration_tests/test_tempfile_mktemp.py b/integration_tests/test_tempfile_mktemp.py index f8e87e70a..7acca7efb 100644 --- a/integration_tests/test_tempfile_mktemp.py +++ b/integration_tests/test_tempfile_mktemp.py @@ -13,4 +13,4 @@ class TestTempfileMktemp(BaseIntegrationTest): ) expected_diff = '--- \n+++ \n@@ -1,4 +1,4 @@\n import tempfile\n \n-tempfile.mktemp()\n+tempfile.mkstemp()\n var = "hello"\n' expected_line_change = "3" - change_description = TempfileMktemp.CHANGE_DESCRIPTION + change_description = TempfileMktemp.change_description diff --git a/integration_tests/test_unnecessary_f_str.py b/integration_tests/test_unnecessary_f_str.py index 009cb2687..a29bfb291 100644 --- a/integration_tests/test_unnecessary_f_str.py +++ b/integration_tests/test_unnecessary_f_str.py @@ -13,4 +13,4 @@ class TestFStr(BaseIntegrationTest): ) expected_diff = '--- \n+++ \n@@ -1,2 +1,2 @@\n-bad = f"hello"\n+bad = "hello"\n good = f"{2+3}"\n' expected_line_change = "1" - change_description = RemoveUnnecessaryFStr.CHANGE_DESCRIPTION + change_description = RemoveUnnecessaryFStr.change_description diff --git a/integration_tests/test_upgrade_sslcontext_minimum_version.py b/integration_tests/test_upgrade_sslcontext_minimum_version.py index 2b5b7a11e..4da3f45ea 100644 --- a/integration_tests/test_upgrade_sslcontext_minimum_version.py +++ b/integration_tests/test_upgrade_sslcontext_minimum_version.py @@ -20,4 +20,4 @@ class TestUpgradeSSLContextMininumVersion(BaseIntegrationTest): expected_diff = '--- \n+++ \n@@ -1,8 +1,9 @@\n from ssl import PROTOCOL_TLS_CLIENT, SSLContext, TLSVersion\n+import ssl\n \n my_ctx = SSLContext(protocol=PROTOCOL_TLS_CLIENT)\n \n print("FOO")\n \n my_ctx.maximum_version = TLSVersion.MAXIMUM_SUPPORTED\n-my_ctx.minimum_version = TLSVersion.TLSv1_1\n+my_ctx.minimum_version = ssl.TLSVersion.TLSv1_2\n' expected_line_change = "8" - change_description = UpgradeSSLContextMinimumVersion.CHANGE_DESCRIPTION + change_description = UpgradeSSLContextMinimumVersion.change_description diff --git a/integration_tests/test_upgrade_sslcontext_tls.py b/integration_tests/test_upgrade_sslcontext_tls.py index ea788a624..34d9b3eac 100644 --- a/integration_tests/test_upgrade_sslcontext_tls.py +++ b/integration_tests/test_upgrade_sslcontext_tls.py @@ -9,4 +9,4 @@ class TestUpgradeWeakTLS(BaseIntegrationTest): expected_new_code = "import ssl\n\nssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)\n" expected_diff = "--- \n+++ \n@@ -1,3 +1,3 @@\n import ssl\n \n-ssl.SSLContext(ssl.PROTOCOL_SSLv2)\n+ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)\n" expected_line_change = "3" - change_description = UpgradeSSLContextTLS.CHANGE_DESCRIPTION + change_description = UpgradeSSLContextTLS.change_description diff --git a/integration_tests/test_url_sandbox.py b/integration_tests/test_url_sandbox.py index 1fd9688ac..4fba2507f 100644 --- a/integration_tests/test_url_sandbox.py +++ b/integration_tests/test_url_sandbox.py @@ -31,7 +31,7 @@ class TestUrlSandbox(BaseIntegrationTest): """ expected_line_change = "5" - change_description = UrlSandbox.CHANGE_DESCRIPTION + change_description = UrlSandbox.change_description num_changed_files = 2 requirements_path = "tests/samples/requirements.txt" diff --git a/integration_tests/test_use_defusedxml.py b/integration_tests/test_use_defusedxml.py index f9734208d..681371eeb 100644 --- a/integration_tests/test_use_defusedxml.py +++ b/integration_tests/test_use_defusedxml.py @@ -35,7 +35,7 @@ class TestUseDefusedXml(BaseIntegrationTest): """ expected_line_change = "5" - change_description = UseDefusedXml.CHANGE_DESCRIPTION + change_description = UseDefusedXml.change_description requirements_path = "tests/samples/requirements.txt" original_requirements = "# file used to test dependency management\nrequests==2.31.0\nblack==23.7.*\nmypy~=1.4\npylint>1\n" diff --git a/integration_tests/test_use_generator.py b/integration_tests/test_use_generator.py index a7d6b38aa..fe432617b 100644 --- a/integration_tests/test_use_generator.py +++ b/integration_tests/test_use_generator.py @@ -27,4 +27,4 @@ class TestUseGenerator(BaseIntegrationTest): """ expected_line_change = "6" - change_description = UseGenerator.CHANGE_DESCRIPTION + change_description = UseGenerator.change_description diff --git a/integration_tests/test_use_set_literal.py b/integration_tests/test_use_set_literal.py index 031ed0355..07e691d27 100644 --- a/integration_tests/test_use_set_literal.py +++ b/integration_tests/test_use_set_literal.py @@ -26,4 +26,4 @@ class TestUseSetLiteral(BaseIntegrationTest): expected_line_change = "1" num_changes = 2 - change_description = UseSetLiteral.CHANGE_DESCRIPTION + change_description = UseSetLiteral.change_description diff --git a/integration_tests/test_use_walrus_if.py b/integration_tests/test_use_walrus_if.py index 5b091071c..6d3abedde 100644 --- a/integration_tests/test_use_walrus_if.py +++ b/integration_tests/test_use_walrus_if.py @@ -29,4 +29,4 @@ def whatever(): num_changes = 3 expected_line_change = 1 - change_description = UseWalrusIf.CHANGE_DESCRIPTION + change_description = UseWalrusIf.change_description diff --git a/integration_tests/test_with_threading_lock.py b/integration_tests/test_with_threading_lock.py index 72add390d..d696f7793 100644 --- a/integration_tests/test_with_threading_lock.py +++ b/integration_tests/test_with_threading_lock.py @@ -18,4 +18,4 @@ class TestWithThreadingLock(BaseIntegrationTest): ) expected_diff = '--- \n+++ \n@@ -1,3 +1,4 @@\n import threading\n-with threading.Lock():\n+lock = threading.Lock()\n+with lock:\n print("Hello")\n' expected_line_change = "2" - change_description = WithThreadingLock.CHANGE_DESCRIPTION + change_description = WithThreadingLock.change_description diff --git a/pylintrc b/pylintrc index 74e31145e..e5b7d0a43 100644 --- a/pylintrc +++ b/pylintrc @@ -7,7 +7,8 @@ jobs=0 # Add files or directories matching the regex patterns to the ignore-list. The # regex matches against paths and can be in Posix or Windows format. ignore-paths= - tests/samples, + tests/samples/, + .*/core_codemods/refactor/.*py$ # Specify a score threshold under which the program will exit with error. fail-under=10.0 diff --git a/src/codemodder/codemodder.py b/src/codemodder/codemodder.py index 0983e04c2..cf9215fc9 100644 --- a/src/codemodder/codemodder.py +++ b/src/codemodder/codemodder.py @@ -1,24 +1,20 @@ -from concurrent.futures import ThreadPoolExecutor import datetime import itertools import logging import os import sys +from typing import Sequence from pathlib import Path -import libcst as cst -from libcst.codemod import CodemodContext from codemodder.dependency import Dependency -from codemodder.file_context import FileContext from codemodder import registry, __version__ from codemodder.logging import configure_logger, logger, log_section, log_list from codemodder.cli import parse_args -from codemodder.change import ChangeSet -from codemodder.code_directory import file_line_patterns, match_files +from codemodder.code_directory import match_files +from codemodder.codemods.semgrep import SemgrepRuleDetector from codemodder.context import CodemodExecutionContext -from codemodder.diff import create_diff_from_tree -from codemodder.executor import CodemodExecutorWrapper +from codemodder.codemods.api import BaseCodemod from codemodder.project_analysis.file_parsers.package_store import PackageStore from codemodder.project_analysis.python_repo_manager import PythonRepoManager from codemodder.report.codetf_reporter import report_default @@ -36,128 +32,24 @@ def update_code(file_path, new_code): def find_semgrep_results( context: CodemodExecutionContext, - codemods: list[CodemodExecutorWrapper], + codemods: Sequence[BaseCodemod], + files_to_analyze: list[Path] | None = None, ) -> ResultSet: """Run semgrep once with all configuration files from all codemods and return a set of applicable rule IDs""" yaml_files = list( itertools.chain.from_iterable( - [codemod.yaml_files for codemod in codemods if codemod.yaml_files] + [ + codemod.detector.get_yaml_files(codemod.name) + for codemod in codemods + if codemod.detector + and isinstance(codemod.detector, SemgrepRuleDetector) + ] ) ) if not yaml_files: return ResultSet() - return run_semgrep(context, yaml_files) - - -def apply_codemod_to_file( - base_directory: Path, - file_context, - codemod_kls: CodemodExecutorWrapper, - source_tree, - dry_run: bool = False, -): - wrapper = cst.MetadataWrapper(source_tree) - codemod = codemod_kls(CodemodContext(wrapper=wrapper), file_context) - if not codemod.should_transform: - return False - - with file_context.timer.measure("transform"): - output_tree = codemod.transform_module(source_tree) - - # TODO: we can probably just use the presence of recorded changes instead of - # comparing the trees to gain some efficiency - if output_tree.deep_equals(source_tree): - return False - - diff = create_diff_from_tree(source_tree, output_tree) - change_set = ChangeSet( - str(file_context.file_path.relative_to(base_directory)), - diff, - changes=file_context.codemod_changes, - ) - file_context.add_result(change_set) - - if not dry_run: - with file_context.timer.measure("write"): - update_code(file_context.file_path, output_tree.code) - - return True - - -# pylint: disable-next=too-many-arguments -def process_file( - idx: int, - file_path: Path, - base_directory: Path, - codemod, - results: ResultSet, - cli_args, -) -> FileContext: - logger.debug("scanning file %s", file_path) - if idx and idx % 100 == 0: - logger.info("scanned %s files...", idx) # pragma: no cover - - line_exclude = file_line_patterns(file_path, cli_args.path_exclude) - line_include = file_line_patterns(file_path, cli_args.path_include) - findings_for_rule = results.results_for_rule_and_file( - codemod.name, # TODO: should be full ID - file_path, - ) - - file_context = FileContext( - base_directory, - file_path, - line_exclude, - line_include, - findings_for_rule, - ) - - try: - with file_context.timer.measure("parse"): - with open(file_path, "r", encoding="utf-8") as f: - source_tree = cst.parse_module(f.read()) - except Exception: - file_context.add_failure(file_path) - logger.exception("error parsing file %s", file_path) - return file_context - - apply_codemod_to_file( - base_directory, - file_context, - codemod, - source_tree, - cli_args.dry_run, - ) - - return file_context - - -def analyze_files( - execution_context: CodemodExecutionContext, - files_to_analyze, - codemod, - results: ResultSet, - cli_args, -): - with ThreadPoolExecutor(max_workers=cli_args.max_workers) as executor: - logger.debug( - "using executor with %s threads", - cli_args.max_workers, - ) - analysis_results = executor.map( - lambda args: process_file( - args[0], - args[1], - execution_context.directory, - codemod, - results, - cli_args, - ), - enumerate(files_to_analyze), - ) - executor.shutdown(wait=True) - execution_context.process_results(codemod.id, analysis_results) + return run_semgrep(context, yaml_files, files_to_analyze) def log_report(context, argv, elapsed_ms, files_to_analyze): @@ -185,10 +77,9 @@ def log_report(context, argv, elapsed_ms, files_to_analyze): def apply_codemods( context: CodemodExecutionContext, - codemods_to_run: list[CodemodExecutorWrapper], + codemods_to_run: Sequence[BaseCodemod], semgrep_results: ResultSet, files_to_analyze: list[Path], - argv, ): log_section("scanning") @@ -207,25 +98,20 @@ def apply_codemods( # NOTE: this may be used as a progress indicator by upstream tools logger.info("running codemod %s", codemod.id) - # Unfortunately the IDs from semgrep are not fully specified - # TODO: eventually we need to be able to use fully specified IDs here - if codemod.is_semgrep and codemod.name not in semgrep_finding_ids: - logger.debug( - "no results from semgrep for %s, skipping analysis", - codemod.id, - ) - continue + if isinstance(codemod.detector, SemgrepRuleDetector): + # Unfortunately the IDs from semgrep are not fully specified + # TODO: eventually we need to be able to use fully specified IDs here + if codemod.name not in semgrep_finding_ids: + logger.debug( + "no results from semgrep for %s, skipping analysis", + codemod.id, + ) + continue + + files_to_analyze = semgrep_results.files_for_rule(codemod.name) - semgrep_files = semgrep_results.files_for_rule(codemod.name) # Non-semgrep codemods ignore the semgrep results - results = codemod.apply(context, semgrep_files) - analyze_files( - context, - files_to_analyze, - codemod, - results, - argv, - ) + codemod.apply(context, files_to_analyze) record_dependency_update(context.process_dependencies(codemod.id)) context.log_changes(codemod.id) @@ -276,6 +162,9 @@ def run(original_args) -> int: argv.verbose, codemod_registry, repo_manager, + argv.path_include, + argv.path_exclude, + argv.max_workers, ) repo_manager.parse_project() @@ -297,14 +186,17 @@ def run(original_args) -> int: full_names = [str(path) for path in files_to_analyze] log_list(logging.DEBUG, "matched files", full_names) - semgrep_results: ResultSet = find_semgrep_results(context, codemods_to_run) + semgrep_results: ResultSet = find_semgrep_results( + context, + codemods_to_run, + files_to_analyze, + ) apply_codemods( context, codemods_to_run, semgrep_results, files_to_analyze, - argv, ) results = context.compile_results(codemods_to_run) diff --git a/src/codemodder/codemods/api.py b/src/codemodder/codemods/api.py new file mode 100644 index 000000000..4c9c3c0b6 --- /dev/null +++ b/src/codemodder/codemods/api.py @@ -0,0 +1,52 @@ +from abc import ABCMeta +from typing import Callable + +import libcst as cst + +from codemodder.codemods.base_codemod import ( # pylint: disable=unused-import + BaseCodemod, + Metadata, + Reference, + ReviewGuidance, +) +from codemodder.codemods.libcst_transformer import ( + LibcstResultTransformer, + LibcstTransformerPipeline, +) +from codemodder.file_context import FileContext # pylint: disable=unused-import +from codemodder.codemods.semgrep import SemgrepRuleDetector + + +class SimpleCodemod(LibcstResultTransformer, metaclass=ABCMeta): + """ + Base class for codemods with a single detector and transformer + + Child classes must implement the following attributes: + - metadata: Metadata + - codemod_base: type[BaseCodemod] + """ + + metadata: Metadata + detector_pattern: str + on_result_found: Callable[[cst.CSTNode, cst.CSTNode], cst.CSTNode] + + codemod_base: type[BaseCodemod] + + def __init__(self, *args, **kwargs): + """Obfuscates the type of the constructor to make the type checker happy""" + super().__init__(*args, **kwargs) + + def __new__(cls, *args, **kwargs): + del args + + if kwargs.get("_transformer", False): + return super().__new__(cls) + + return cls.codemod_base( + metadata=cls.metadata, + detector=SemgrepRuleDetector(cls.detector_pattern) + if getattr(cls, "detector_pattern", None) + else None, + # This allows the transformer to inherit all the methods of the class itself + transformer=LibcstTransformerPipeline(cls), + ) diff --git a/src/codemodder/codemods/api/__init__.py b/src/codemodder/codemods/api/__init__.py deleted file mode 100644 index 0a05313bc..000000000 --- a/src/codemodder/codemods/api/__init__.py +++ /dev/null @@ -1,142 +0,0 @@ -import io -import os -import tempfile - -import libcst as cst -from libcst.codemod import ( - CodemodContext, -) -import yaml - -from codemodder.codemods.base_codemod import ( - CodemodMetadata, - BaseCodemod as _BaseCodemod, - SemgrepCodemod as _SemgrepCodemod, - # Make this available via the simplified API - ReviewGuidance, # pylint: disable=unused-import -) - -from codemodder.codemods.base_visitor import BaseTransformer -from codemodder.change import Change -from codemodder.file_context import FileContext -from .helpers import Helpers - - -def _populate_yaml(rule: str, metadata: CodemodMetadata) -> str: - config = yaml.safe_load(io.StringIO(rule)) - # TODO: handle more than rule per config? - assert len(config["rules"]) == 1 - config["rules"][0].setdefault("id", metadata.NAME) - config["rules"][0].setdefault("message", "Semgrep found a match") - config["rules"][0].setdefault("severity", "WARNING") - config["rules"][0].setdefault("languages", ["python"]) - return yaml.safe_dump(config) - - -def _create_temp_yaml_file(orig_cls, metadata: CodemodMetadata): - fd, path = tempfile.mkstemp() - with os.fdopen(fd, "w") as ff: - ff.write(_populate_yaml(orig_cls.rule(), metadata)) - - return [path] - - -class _CodemodSubclassWithMetadata: - def __init_subclass__(cls): - # This is a pretty yucky workaround. - # But it is necessary to get around the fact that these fields are - # checked by __init_subclass__ of the other parents of SemgrepCodemod - # first. - if cls.__name__ not in ("BaseCodemod", "SemgrepCodemod"): - # TODO: if we intend to continue to check class-level attributes - # using this mechanism, we should add checks (or defaults) for - # NAME, DESCRIPTION, and REVIEW_GUIDANCE here. - missing_fields = [] - for field in ["SUMMARY", "DESCRIPTION", "REVIEW_GUIDANCE"]: - try: - assert ( - hasattr(cls, field) - and getattr(cls, field) is not NotImplemented - ) - except AssertionError: - missing_fields.append(field) - - if missing_fields: - raise AssertionError( - f"{cls.__name__} is missing the following fields: {missing_fields}" - ) - - cls.METADATA = CodemodMetadata( - cls.DESCRIPTION, # pylint: disable=no-member - cls.NAME, # pylint: disable=no-member - cls.REVIEW_GUIDANCE, # pylint: disable=no-member - cls.REFERENCES, # pylint: disable=no-member - ) - - # This is a little bit hacky, but it also feels like the right solution? - cls.CHANGE_DESCRIPTION = cls.DESCRIPTION # pylint: disable=no-member - - return cls - - -class BaseCodemod( - _CodemodSubclassWithMetadata, - _BaseCodemod, - BaseTransformer, - Helpers, -): - def __init__(self, codemod_context: CodemodContext, file_context: FileContext): - _BaseCodemod.__init__(self, file_context) - BaseTransformer.__init__(self, codemod_context, []) - - def report_change(self, original_node): - line_number = self.lineno_for_node(original_node) - self.file_context.codemod_changes.append( - Change(line_number, self.CHANGE_DESCRIPTION) - ) - - -# NOTE: this shadows base_codemod.SemgrepCodemod but I can't think of a better name right now -# At least it is namespaced but we might want to deconflict these things in the long term -class SemgrepCodemod( - BaseCodemod, - _CodemodSubclassWithMetadata, - _SemgrepCodemod, - BaseTransformer, -): - def __init_subclass__(cls): - super().__init_subclass__() - cls.YAML_FILES = _create_temp_yaml_file(cls, cls.METADATA) - - def __init__(self, codemod_context: CodemodContext, file_context: FileContext): - BaseCodemod.__init__(self, codemod_context, file_context) - _SemgrepCodemod.__init__(self, file_context) - BaseTransformer.__init__(self, codemod_context, file_context.findings) - - def _new_or_updated_node(self, original_node, updated_node): - if self.node_is_selected(original_node): - self.report_change(original_node) - if (attr := getattr(self, "on_result_found", None)) is not None: - # pylint: disable=not-callable - new_node = attr(original_node, updated_node) - return new_node - return updated_node - - # TODO: there needs to be a way to generalize this so that it applies - # more broadly than to just a specific kind of node. There's probably a - # decent way to do this with metaprogramming. We could either apply it - # broadly to every known method (which would probably have a big - # performance impact). Or we could allow users to register the handler - # for a specific node or nodes by means of a decorator or something - # similar when they define their `on_result_found` method. - # Right now this is just to demonstrate a particular use case. - def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): - return self._new_or_updated_node(original_node, updated_node) - - def leave_Assign(self, original_node, updated_node): - return self._new_or_updated_node(original_node, updated_node) - - def leave_ClassDef( - self, original_node: cst.ClassDef, updated_node: cst.ClassDef - ) -> cst.ClassDef: - return self._new_or_updated_node(original_node, updated_node) diff --git a/src/codemodder/codemods/api/helpers.py b/src/codemodder/codemods/api/helpers.py deleted file mode 100644 index 93da9506f..000000000 --- a/src/codemodder/codemods/api/helpers.py +++ /dev/null @@ -1,130 +0,0 @@ -from collections import namedtuple -import libcst as cst -from libcst import matchers -from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor -from codemodder.codemods.utils import get_call_name - -NewArg = namedtuple("NewArg", ["name", "value", "add_if_missing"]) - - -class Helpers: - def remove_unused_import(self, original_node): - # pylint: disable=no-member - RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node) - - def add_needed_import(self, module, obj=None): - # TODO: do we need to check if this import already exists? - AddImportsVisitor.add_needed_import( - self.context, module, obj # pylint: disable=no-member - ) - - def update_call_target( - self, original_node, new_target, new_func=None, replacement_args=None - ): - # TODO: is an assertion the best way to handle this? - # Or should we just return the original node if it's not a Call? - assert isinstance(original_node, cst.Call) - - attr = ( - cst.parse_expression(new_func) - if new_func - else cst.Name(value=get_call_name(original_node)) - ) - return cst.Call( - func=cst.Attribute( - value=cst.parse_expression(new_target), - attr=attr, - ), - args=replacement_args if replacement_args else original_node.args, - ) - - def update_arg_target(self, updated_node, new_args: list): - return updated_node.with_changes( - args=[new if isinstance(new, cst.Arg) else cst.Arg(new) for new in new_args] - ) - - def update_assign_rhs(self, updated_node: cst.Assign, rhs: str): - value = cst.parse_expression(rhs) - return updated_node.with_changes(value=value) - - def parse_expression(self, expression: str): - return cst.parse_expression(expression) - - def replace_args(self, original_node, args_info): - """ - Iterate over the args in original_node and replace each arg - with any matching arg in `args_info`. - - :param original_node: libcst node with args attribute. - :param list args_info: List of NewArg - """ - assert hasattr(original_node, "args") - assert all( - isinstance(arg, NewArg) for arg in args_info - ), "`args_info` must contain `NewArg` types." - new_args = [] - - for arg in original_node.args: - arg_name, replacement_val, idx = _match_with_existing_arg(arg, args_info) - if arg_name is not None: - new = self.make_new_arg(replacement_val, arg_name, arg) - del args_info[idx] - else: - new = arg - new_args.append(new) - - for arg_name, replacement_val, add_if_missing in args_info: - if add_if_missing: - new = self.make_new_arg(replacement_val, arg_name) - new_args.append(new) - - return new_args - - def make_new_arg(self, value, name=None, existing_arg=None): - if name is None: - # Make a positional argument - return cst.Arg( - value=cst.parse_expression(value), - ) - - # make a keyword argument - equal = ( - existing_arg.equal - if existing_arg - else cst.AssignEqual( - whitespace_before=cst.SimpleWhitespace(""), - whitespace_after=cst.SimpleWhitespace(""), - ) - ) - return cst.Arg( - keyword=cst.Name(value=name), - value=cst.parse_expression(value), - equal=equal, - ) - - def add_arg_to_call(self, node: cst.Call, name: str, value): - """ - Add a new arg to the end of the args list. - """ - new_args = list(node.args) + [ - cst.Arg( - keyword=cst.Name(value=name), - value=cst.parse_expression(str(value)), - equal=cst.AssignEqual( - whitespace_before=cst.SimpleWhitespace(""), - whitespace_after=cst.SimpleWhitespace(""), - ), - ) - ] - return node.with_changes(args=new_args) - - -def _match_with_existing_arg(arg, args_info): - """ - Given an `arg` and a list of arg info, determine if - any of the names in arg_info match the arg. - """ - for idx, (arg_name, replacement_val, _) in enumerate(args_info): - if matchers.matches(arg.keyword, matchers.Name(arg_name)): - return arg_name, replacement_val, idx - return None, None, None diff --git a/src/codemodder/codemods/base_codemod.py b/src/codemodder/codemods/base_codemod.py index 6b1b44480..b2edfeae0 100644 --- a/src/codemodder/codemods/base_codemod.py +++ b/src/codemodder/codemods/base_codemod.py @@ -1,14 +1,19 @@ +from abc import ABCMeta, abstractmethod +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum -from typing import List, ClassVar - -from libcst._position import CodeRange - -from codemodder.change import Change -from codemodder.dependency import Dependency +from functools import cached_property +import importlib.resources +from importlib.abc import Traversable +from pathlib import Path + +from codemodder.codemods.base_detector import BaseDetector +from codemodder.codemods.base_transformer import BaseTransformerPipeline +from codemodder.context import CodemodExecutionContext +from codemodder.code_directory import file_line_patterns from codemodder.file_context import FileContext +from codemodder.logging import logger from codemodder.result import ResultSet -from codemodder.semgrep import run as semgrep_run class ReviewGuidance(Enum): @@ -17,109 +22,163 @@ class ReviewGuidance(Enum): MERGE_WITHOUT_REVIEW = 3 -@dataclass(frozen=True) -class CodemodMetadata: - DESCRIPTION: str # TODO: this field should be optional - NAME: str - REVIEW_GUIDANCE: ReviewGuidance - REFERENCES: list = field(default_factory=list) - - # TODO: remove post_init update_references once we add description for each url. - def __post_init__(self): - object.__setattr__(self, "REFERENCES", self.update_references(self.REFERENCES)) - - @staticmethod - def update_references(references): - updated_references = [] - for reference in references: - updated_reference = dict( - reference - ) # Create a copy to avoid modifying the original dict - updated_reference["description"] = updated_reference["url"] - updated_references.append(updated_reference) - return updated_references - - -class BaseCodemod: - # Implementation borrowed from https://stackoverflow.com/a/45250114 - METADATA: ClassVar[CodemodMetadata] = NotImplemented - SUMMARY: ClassVar[str] = NotImplemented - is_semgrep: bool = False - adds_dependency: bool = False - file_context: FileContext - - def __init__(self, file_context: FileContext): - self.file_context = file_context - - @classmethod - def apply_rule(cls, context, *args, **kwargs) -> ResultSet: - """ - Apply rule associated with this codemod and gather results +@dataclass +class Reference: + url: str + description: str = "" - Does nothing by default. Subclasses may override for custom rule logic. - """ - del context, args, kwargs - return ResultSet() + def to_json(self): + return { + "url": self.url, + "description": self.description or self.url, + } - @classmethod - def name(cls): - # pylint: disable=no-member - return cls.METADATA.NAME - @property - def should_transform(self): - return True +@dataclass +class Metadata: + name: str + summary: str + review_guidance: ReviewGuidance + references: list[Reference] = field(default_factory=list) + has_description: bool = True + + +class BaseCodemod(metaclass=ABCMeta): + """ + Base class for all codemods + + Conceptually a codemod is composed of the following attributes: + * Metadata: contains information about the codemod including its name, summary, and review guidance + * Detector (optional): the source of results indicating which code locations the codemod should be applied + * Transformer: a transformer pipeline that will be applied to each applicable file and perform the actual modifications + + A detector may parse result files generated by other tools or it may + perform its own analysis at runtime, potentially by calling another tool + (e.g. Semgrep). - def node_position(self, node): - # pylint: disable=no-member - # See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112 - return self.get_metadata(self.METADATA_DEPENDENCIES[0], node) + Some codemods may not require a detector if the transformation pipeline + itself is capable of determining locations to modify. - def add_change(self, node, description: str, start: bool = True): - position = self.node_position(node) - self.add_change_from_position(position, description, start) + Codemods that apply the same transformation but use different detectors + should be implemented as distinct codemod classes. + """ - def add_change_from_position( - self, position: CodeRange, description: str, start: bool = True + metadata: Metadata + detector: BaseDetector | None + transformer: BaseTransformerPipeline + + def __init__( + self, + *, + metadata: Metadata, + detector: BaseDetector | None = None, + transformer: BaseTransformerPipeline, ): - lineno = position.start.line if start else position.end.line - self.file_context.codemod_changes.append( - Change( - lineNumber=lineno, - description=description, - ) - ) + self.metadata = metadata + self.detector = detector + self.transformer = transformer + + @property + @abstractmethod + def origin(self) -> str: + ... + + @property + @abstractmethod + def docs_module_path(self) -> str: + ... - def lineno_for_node(self, node): - return self.node_position(node).start.line + @property + def name(self) -> str: + return self.metadata.name @property - def line_exclude(self): - return self.file_context.line_exclude + def id(self) -> str: + return f"{self.origin}:python/{self.name}" @property - def line_include(self): - return self.file_context.line_include + def summary(self): + return self.metadata.summary - def add_dependency(self, dependency: Dependency): - self.file_context.add_dependency(dependency) + @cached_property + def docs_module(self) -> Traversable: + return importlib.resources.files(self.docs_module_path) + @cached_property + def description(self) -> str: + if not self.metadata.has_description: + return "" -class SemgrepCodemod(BaseCodemod): - YAML_FILES: ClassVar[List[str]] = NotImplemented - is_semgrep = True + doc_path = self.docs_module / f"{self.origin}_python_{self.name}.md" + return doc_path.read_text() - @classmethod - def apply_rule(cls, context, *args, **kwargs) -> ResultSet: + @property + def review_guidance(self): + return self.metadata.review_guidance.name.replace("_", " ").title() + + @property + def references(self) -> list[Reference]: + return self.metadata.references + + def describe(self): + return { + "codemod": self.id, + "summary": self.summary, + "description": self.description, + "references": [ref.to_json() for ref in self.references], + } + + def apply( + self, + context: CodemodExecutionContext, + files_to_analyze: list[Path], + ) -> None: """ - Apply semgrep to gather rule results + Apply the codemod to the given list of files + + This method is responsible for orchestrating the application of the codemod to a given list of files. + + It will first apply the detector (if any) to the files to determine which files should be modified. + + It then applies the transformer pipeline to each file applicable file, potentially generating a change set. + + All results are then processed and reported to the context. + + Per-file processing can be parallelized based on the `max_workers` setting. + + :param context: The codemod execution context + :param files_to_analyze: The list of files to analyze """ - yaml_files = kwargs.get("yaml_files") or args[0] - files_to_analyze = kwargs.get("files_to_analyze") or args[1] - with context.timer.measure("semgrep"): - return semgrep_run(context, yaml_files, files_to_analyze) + results = ( + # It seems like semgrep doesn't like our fully-specified id format + self.detector.apply(self.name, context, files_to_analyze) + if self.detector + else ResultSet() + ) - @property - def should_transform(self): - """Semgrep codemods should attempt transform only if there are semgrep results""" - return bool(self.file_context.findings) + def process_file(filename: Path): + line_exclude = file_line_patterns(filename, context.path_exclude) + line_include = file_line_patterns(filename, context.path_include) + findings_for_rule = results.results_for_rule_and_file(self.name, filename) + + file_context = FileContext( + context.directory, + filename, + line_exclude, + line_include, + findings_for_rule, + ) + + if change_set := self.transformer.apply( + context, file_context, findings_for_rule + ): + file_context.add_result(change_set) + + return file_context + + with ThreadPoolExecutor() as executor: + logger.debug("using executor with %s workers", context.max_workers) + contexts = executor.map(process_file, files_to_analyze) + executor.shutdown(wait=True) + + context.process_results(self.id, contexts) diff --git a/src/codemodder/codemods/base_detector.py b/src/codemodder/codemods/base_detector.py new file mode 100644 index 000000000..c1bae4d61 --- /dev/null +++ b/src/codemodder/codemods/base_detector.py @@ -0,0 +1,16 @@ +from abc import ABCMeta, abstractmethod +from pathlib import Path + +from codemodder.context import CodemodExecutionContext +from codemodder.result import ResultSet + + +class BaseDetector(metaclass=ABCMeta): + @abstractmethod + def apply( + self, + codemod_id: str, + context: CodemodExecutionContext, + files_to_analyze: list[Path], + ) -> ResultSet: + ... diff --git a/src/codemodder/codemods/base_transformer.py b/src/codemodder/codemods/base_transformer.py new file mode 100644 index 000000000..998b26328 --- /dev/null +++ b/src/codemodder/codemods/base_transformer.py @@ -0,0 +1,43 @@ +from abc import ABCMeta, abstractmethod + +from codemodder.change import ChangeSet +from codemodder.context import CodemodExecutionContext +from codemodder.codemods.base_visitor import BaseTransformer +from codemodder.file_context import FileContext +from codemodder.result import Result + + +class BaseTransformerPipeline(metaclass=ABCMeta): + """ + Base class for a pipeline of transformers + + A pipeline is a list of one or more transformers that are applied in sequence. + + The transformers in a given pipeline can either be homogeneous or heterogeneous in terms of inputs and output formats accepted by each transformer. For a heterogeneous pipeline it may be necessary to implement adapter classes to convert between formats. + + Each transformer pipeline is responsible for writing results to the output files if `dry_run` is `False`. + """ + + transformers: list[type[BaseTransformer]] + + def __init__(self, *transformers: type[BaseTransformer]): + self.transformers = list(transformers) + + @abstractmethod + def apply( + self, + context: CodemodExecutionContext, + file_context: FileContext, + results: list[Result], + ) -> ChangeSet | None: + """ + Apply the pipeline to the given file context + + :param context: The codemod execution context + :param file_context: The file context representing the file to transform + :param results: The (optional) results of the detector phase + + :return: The `ChangeSet` to apply to the file, or `None` if no changes are applied + + This method is responsible for writing the results to the output files if `dry_run` is False. + """ diff --git a/src/codemodder/codemods/base_visitor.py b/src/codemodder/codemods/base_visitor.py index 65d30d1d0..e50c91784 100644 --- a/src/codemodder/codemods/base_visitor.py +++ b/src/codemodder/codemods/base_visitor.py @@ -5,6 +5,7 @@ from codemodder.result import Result +# TODO: this should just be part of BaseTransformer and BaseVisitor? class UtilsMixin: results: list[Result] diff --git a/src/codemodder/codemods/libcst_transformer.py b/src/codemodder/codemods/libcst_transformer.py new file mode 100644 index 000000000..ef752b8b3 --- /dev/null +++ b/src/codemodder/codemods/libcst_transformer.py @@ -0,0 +1,305 @@ +from collections import namedtuple + +import libcst as cst +from libcst import matchers +from libcst.codemod import CodemodContext +from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor +from libcst._position import CodeRange + +from codemodder.codemods.base_visitor import BaseTransformer +from codemodder.codemods.base_transformer import BaseTransformerPipeline +from codemodder.codemods.utils import get_call_name +from codemodder.context import CodemodExecutionContext +from codemodder.diff import create_diff_from_tree +from codemodder.change import ChangeSet, Change +from codemodder.dependency import Dependency +from codemodder.logging import logger +from codemodder.file_context import FileContext +from codemodder.result import Result + + +NewArg = namedtuple("NewArg", ["name", "value", "add_if_missing"]) + + +def update_code(file_path, new_code): + """ + Write the `new_code` to the `file_path` + """ + with open(file_path, "w", encoding="utf-8") as f: + f.write(new_code) + + +class LibcstResultTransformer( + BaseTransformer +): # pylint: disable=too-many-public-methods + """ + Transformer class that performs libcst-based transformations on a given file + + :param context: libcst CodemodContext + :param results: list of `Result` generated by the detector phase (may be empty) + :param file_context: `FileContext` for the file to be transformed + """ + + change_description: str = "" + + def __init__( + self, + context: CodemodContext, + results: list[Result], + file_context: FileContext, + _transformer: bool = False, + ): + del _transformer + + super().__init__(context, results) + self.file_context = file_context + + @classmethod + def transform( + cls, module: cst.Module, results: list[Result], file_context: FileContext + ) -> cst.Module: + wrapper = cst.MetadataWrapper(module) + codemod = cls( + CodemodContext(wrapper=wrapper), + results, + file_context, + _transformer=True, + ) + + return codemod.transform_module(module) + + def _new_or_updated_node(self, original_node, updated_node): + if self.node_is_selected(original_node): + self.report_change(original_node) + if (attr := getattr(self, "on_result_found", None)) is not None: + # pylint: disable=not-callable + new_node = attr(original_node, updated_node) + return new_node + return updated_node + + # TODO: there needs to be a way to generalize this so that it applies + # more broadly than to just a specific kind of node. There's probably a + # decent way to do this with metaprogramming. We could either apply it + # broadly to every known method (which would probably have a big + # performance impact). Or we could allow users to register the handler + # for a specific node or nodes by means of a decorator or something + # similar when they define their `on_result_found` method. + # Right now this is just to demonstrate a particular use case. + def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): + return self._new_or_updated_node(original_node, updated_node) + + def leave_Assign(self, original_node, updated_node): + return self._new_or_updated_node(original_node, updated_node) + + def leave_ClassDef( + self, original_node: cst.ClassDef, updated_node: cst.ClassDef + ) -> cst.ClassDef: + return self._new_or_updated_node(original_node, updated_node) + + def node_position(self, node): + # pylint: disable=no-member + # See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112 + return self.get_metadata(self.METADATA_DEPENDENCIES[0], node) + + def add_change(self, node, description: str, start: bool = True): + position = self.node_position(node) + self.add_change_from_position(position, description, start) + + def add_change_from_position( + self, position: CodeRange, description: str, start: bool = True + ): + lineno = position.start.line if start else position.end.line + self.file_context.codemod_changes.append( + Change( + lineNumber=lineno, + description=description, + ) + ) + + def lineno_for_node(self, node): + return self.node_position(node).start.line + + @property + def line_exclude(self): + return self.file_context.line_exclude + + @property + def line_include(self): + return self.file_context.line_include + + def add_dependency(self, dependency: Dependency): + self.file_context.add_dependency(dependency) + + def report_change(self, original_node): + line_number = self.lineno_for_node(original_node) + self.file_context.codemod_changes.append( + Change(line_number, self.change_description) + ) + + def remove_unused_import(self, original_node): + # pylint: disable=no-member + RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node) + + def add_needed_import(self, module, obj=None): + # TODO: do we need to check if this import already exists? + AddImportsVisitor.add_needed_import( + self.context, module, obj # pylint: disable=no-member + ) + + def update_call_target( + self, + original_node, + new_target, + new_func: str | None = None, + replacement_args=None, + ): + # TODO: is an assertion the best way to handle this? + # Or should we just return the original node if it's not a Call? + assert isinstance(original_node, cst.Call) + + func_name = new_func if new_func else get_call_name(original_node) + return cst.Call( + func=cst.Attribute( + value=cst.parse_expression(new_target), + attr=cst.Name(value=func_name), + ), + args=replacement_args if replacement_args else original_node.args, + ) + + def update_arg_target(self, updated_node, new_args: list): + return updated_node.with_changes( + args=[new if isinstance(new, cst.Arg) else cst.Arg(new) for new in new_args] + ) + + def update_assign_rhs(self, updated_node: cst.Assign, rhs: str): + value = cst.parse_expression(rhs) + return updated_node.with_changes(value=value) + + def parse_expression(self, expression: str): + return cst.parse_expression(expression) + + def replace_args(self, original_node, args_info): + """ + Iterate over the args in original_node and replace each arg + with any matching arg in `args_info`. + + :param original_node: libcst node with args attribute. + :param list args_info: List of NewArg + """ + assert hasattr(original_node, "args") + assert all( + isinstance(arg, NewArg) for arg in args_info + ), "`args_info` must contain `NewArg` types." + new_args = [] + + for arg in original_node.args: + arg_name, replacement_val, idx = _match_with_existing_arg(arg, args_info) + if arg_name is not None: + new = self.make_new_arg(replacement_val, arg_name, arg) + del args_info[idx] + else: + new = arg + new_args.append(new) + + for arg_name, replacement_val, add_if_missing in args_info: + if add_if_missing: + new = self.make_new_arg(replacement_val, arg_name) + new_args.append(new) + + return new_args + + def make_new_arg(self, value, name=None, existing_arg=None): + if name is None: + # Make a positional argument + return cst.Arg( + value=cst.parse_expression(value), + ) + + # make a keyword argument + equal = ( + existing_arg.equal + if existing_arg + else cst.AssignEqual( + whitespace_before=cst.SimpleWhitespace(""), + whitespace_after=cst.SimpleWhitespace(""), + ) + ) + return cst.Arg( + keyword=cst.Name(value=name), + value=cst.parse_expression(value), + equal=equal, + ) + + def add_arg_to_call(self, node: cst.Call, name: str, value): + """ + Add a new arg to the end of the args list. + """ + new_args = list(node.args) + [ + cst.Arg( + keyword=cst.Name(value=name), + value=cst.parse_expression(str(value)), + equal=cst.AssignEqual( + whitespace_before=cst.SimpleWhitespace(""), + whitespace_after=cst.SimpleWhitespace(""), + ), + ) + ] + return node.with_changes(args=new_args) + + +class LibcstTransformerPipeline(BaseTransformerPipeline): + """ + Transformer pipeline class that applies one or more `LibcstResultTransformer` to a given file + + This pipeline expects that all transformers accept a libcst `Module` as input and return a libcst `Module` as output. + """ + + transformers: list[type[LibcstResultTransformer]] + + def apply( + self, + context: CodemodExecutionContext, + file_context: FileContext, + results: list[Result], + ) -> ChangeSet | None: + file_path = file_context.file_path + + try: + with file_context.timer.measure("parse"): + with open(file_path, "r", encoding="utf-8") as f: + source_tree = cst.parse_module(f.read()) + except Exception: + file_context.add_failure(file_path) + logger.exception("error parsing file %s", file_path) + return None + + tree = source_tree + with file_context.timer.measure("transform"): + for transformer in self.transformers: + tree = transformer.transform(tree, results, file_context) + + if tree.deep_equals(source_tree): + return None + + diff = create_diff_from_tree(source_tree, tree) + change_set = ChangeSet( + str(file_context.file_path.relative_to(context.directory)), + diff, + changes=file_context.codemod_changes, + ) + + if not context.dry_run: + with file_context.timer.measure("write"): + update_code(file_context.file_path, tree.code) + + return change_set + + +def _match_with_existing_arg(arg, args_info): + """ + Given an `arg` and a list of arg info, determine if any of the names in arg_info match the arg. + """ + for idx, (arg_name, replacement_val, _) in enumerate(args_info): + if matchers.matches(arg.keyword, matchers.Name(arg_name)): + return arg_name, replacement_val, idx + return None, None, None diff --git a/src/codemodder/codemods/semgrep.py b/src/codemodder/codemods/semgrep.py new file mode 100644 index 000000000..c3e073527 --- /dev/null +++ b/src/codemodder/codemods/semgrep.py @@ -0,0 +1,49 @@ +import io +import os +from pathlib import Path +import tempfile + +import yaml + +from codemodder.codemods.base_detector import BaseDetector +from codemodder.context import CodemodExecutionContext +from codemodder.result import ResultSet +from codemodder.semgrep import run as semgrep_run + + +def _populate_yaml(rule: str, codemod_id: str) -> str: + rule_yaml = yaml.safe_load(io.StringIO(rule)) + config = {"rules": rule_yaml} if "rules" not in rule_yaml else rule_yaml + config["rules"][0].setdefault("id", codemod_id) + config["rules"][0].setdefault("message", "Semgrep found a match") + config["rules"][0].setdefault("severity", "WARNING") + config["rules"][0].setdefault("languages", ["python"]) + return yaml.safe_dump(config) + + +def _create_temp_yaml_file(rule: str, codemod_id: str): + fd, path = tempfile.mkstemp() + with os.fdopen(fd, "w") as ff: + ff.write(_populate_yaml(rule, codemod_id)) + + return [Path(path)] + + +class SemgrepRuleDetector(BaseDetector): + rule: str + + def __init__(self, rule: str): + self.rule = rule + + def get_yaml_files(self, codemod_id: str) -> list[Path]: + return _create_temp_yaml_file(self.rule, codemod_id) + + def apply( + self, + codemod_id: str, + context: CodemodExecutionContext, + files_to_analyze: list[Path], + ) -> ResultSet: + yaml_files = self.get_yaml_files(codemod_id) + with context.timer.measure("semgrep"): + return semgrep_run(context, yaml_files, files_to_analyze) diff --git a/src/codemodder/context.py b/src/codemodder/context.py index 6f8b5e260..84811959d 100644 --- a/src/codemodder/context.py +++ b/src/codemodder/context.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import logging from pathlib import Path import itertools from textwrap import indent -from typing import List, Iterator +from typing import TYPE_CHECKING, List, Iterator from codemodder.change import ChangeSet from codemodder.dependency import ( @@ -10,7 +12,6 @@ build_dependency_notification, build_failed_dependency_notification, ) -from codemodder.executor import CodemodExecutorWrapper from codemodder.file_context import FileContext from codemodder.logging import logger, log_list from codemodder.project_analysis.file_parsers.package_store import PackageStore @@ -18,6 +19,9 @@ from codemodder.project_analysis.python_repo_manager import PythonRepoManager from codemodder.utils.timer import Timer +if TYPE_CHECKING: + from codemodder.codemods.base_codemod import BaseCodemod + class CodemodExecutionContext: # pylint: disable=too-many-instance-attributes _results_by_codemod: dict[str, list[ChangeSet]] = {} @@ -30,6 +34,9 @@ class CodemodExecutionContext: # pylint: disable=too-many-instance-attributes registry: CodemodRegistry repo_manager: PythonRepoManager timer: Timer + path_include: list[str] + path_exclude: list[str] + max_workers: int = 1 def __init__( self, @@ -38,6 +45,9 @@ def __init__( verbose: bool, registry: CodemodRegistry, repo_manager: PythonRepoManager, + path_include: list[str], + path_exclude: list[str], + max_workers: int = 1, ): # pylint: disable=too-many-arguments self.directory = directory self.dry_run = dry_run @@ -48,6 +58,9 @@ def __init__( self.registry = registry self.repo_manager = repo_manager self.timer = Timer() + self.path_include = path_include + self.path_exclude = path_exclude + self.max_workers = max_workers def add_results(self, codemod_name: str, change_sets: List[ChangeSet]): self._results_by_codemod.setdefault(codemod_name, []).extend(change_sets) @@ -116,7 +129,7 @@ def process_dependencies( return record - def add_description(self, codemod: CodemodExecutorWrapper): + def add_description(self, codemod: BaseCodemod): description = codemod.description if dependencies := list(self.dependencies.get(codemod.id, [])): if pkg_store := self._dependency_update_by_codemod.get(codemod.id): @@ -135,14 +148,14 @@ def process_results(self, codemod_id: str, results: Iterator[FileContext]): self.add_dependencies(codemod_id, file_context.dependencies) self.timer.aggregate(file_context.timer) - def compile_results(self, codemods: list[CodemodExecutorWrapper]): + def compile_results(self, codemods: list[BaseCodemod]): results = [] for codemod in codemods: data = { "codemod": codemod.id, "summary": codemod.summary, "description": self.add_description(codemod), - "references": codemod.references, + "references": [ref.to_json() for ref in codemod.references], "properties": {}, "failedFiles": [str(file) for file in self.get_failures(codemod.id)], "changeset": [ diff --git a/src/codemodder/dependency_management/setup_py_writer.py b/src/codemodder/dependency_management/setup_py_writer.py index 50eab85a8..fcba52880 100644 --- a/src/codemodder/dependency_management/setup_py_writer.py +++ b/src/codemodder/dependency_management/setup_py_writer.py @@ -2,8 +2,7 @@ from libcst.codemod import CodemodContext from libcst import matchers from typing import Optional -from codemodder.codemods.api import BaseCodemod -from codemodder.codemods.base_codemod import ReviewGuidance +from codemodder.codemods.api import SimpleCodemod, Metadata, ReviewGuidance from codemodder.codemods.utils import is_setup_py_file from codemodder.codemods.utils_mixin import NameResolutionMixin from codemodder.file_context import FileContext @@ -30,6 +29,7 @@ def add_to_file( CodemodContext(wrapper=wrapper), file_context, dependencies=[dep.requirement for dep in dependencies], + _transformer=True, ) output_tree = codemod.transform_module(input_tree) @@ -56,20 +56,21 @@ def _parse_file(self): return cst.parse_module(f.read()) -class SetupPyAddDependencies(BaseCodemod, NameResolutionMixin): - NAME = "setup-py-add-dependencies" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Add Dependencies to `setup.py` `install_requires`" - DESCRIPTION = SUMMARY - REFERENCES: list = [] +class SetupPyAddDependencies(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="setup-py-add-dependencies", + summary="Add Dependencies to `setup.py` `install_requires`", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + ) def __init__( self, codemod_context: CodemodContext, file_context: FileContext, dependencies: list[Requirement], + **kwargs, ): - BaseCodemod.__init__(self, codemod_context, file_context) + SimpleCodemod.__init__(self, codemod_context, [], file_context, **kwargs) NameResolutionMixin.__init__(self) self.filename = self.file_context.file_path self.dependencies = dependencies diff --git a/src/codemodder/executor.py b/src/codemodder/executor.py deleted file mode 100644 index 7e46ed3f5..000000000 --- a/src/codemodder/executor.py +++ /dev/null @@ -1,109 +0,0 @@ -from importlib.abc import Traversable -from pathlib import Path - -from wrapt import CallableObjectProxy - - -class CodemodExecutorWrapper(CallableObjectProxy): - """A wrapper around a codemod that provides additional metadata.""" - - origin: str - docs_module: Traversable - semgrep_config_module: Traversable - - def __init__( - self, - codemod, - origin: str, - docs_module: Traversable, - semgrep_config_module: Traversable, - ): - super().__init__(codemod) - self.origin = origin - self.docs_module = docs_module - self.semgrep_config_module = semgrep_config_module - - def apply(self, context, files: list[Path]): - """ - Wraps the codemod's apply method to inject additional arguments. - - Not all codemods will need these arguments. - """ - return self.apply_rule( - context, - yaml_files=self.yaml_files, - files_to_analyze=files, - ) - - @property - def name(self): - return self.__wrapped__.name() - - @property - def id(self): - return f"{self.origin}:python/{self.name}" - - @property - def is_semgrep(self): - return self.__wrapped__.is_semgrep - - @property - def summary(self): - return self.SUMMARY - - def _get_description(self): - doc_path = self.docs_module / f"{self.origin}_python_{self.name}.md" - return doc_path.read_text() - - @property - def description(self): - try: - return self._get_description() - except FileNotFoundError: - # TODO: temporary workaround - return self.METADATA.DESCRIPTION - - @property - def review_guidance(self): - return self.METADATA.REVIEW_GUIDANCE.name.replace("_", " ").title() - - @property - def references(self): - return self.METADATA.REFERENCES - - @property - def yaml_files(self): - return [ - self.semgrep_config_module / yaml_file - for yaml_file in getattr(self, "YAML_FILES", []) - ] - - def describe(self): - return { - "codemod": self.id, - "summary": self.summary, - "description": self.description, - "references": self.references, - } - - def __repr__(self): - return "<{} at 0x{:x} for {}.{}>".format( - type(self).__name__, - id(self), - self.__wrapped__.__module__, - self.__wrapped__.__name__, - ) - - # The following methods are all abstract in the ObjectProxy class, so - # we just implement them as simple pass-throughs to the wrapped object. - def __copy__(self): - return self.__wrapped__.__copy__() - - def __deepcopy__(self, memo): - return self.__wrapped__.__deepcopy__(memo) - - def __reduce__(self): - return self.__wrapped__.__reduce__() - - def __reduce_ex__(self, protocol): - return self.__wrapped__.__reduce_ex__(protocol) diff --git a/src/codemodder/registry.py b/src/codemodder/registry.py index b58a70e94..293efd9f7 100644 --- a/src/codemodder/registry.py +++ b/src/codemodder/registry.py @@ -1,11 +1,14 @@ -from dataclasses import dataclass, asdict -from importlib.resources import files +from __future__ import annotations + +from dataclasses import dataclass from importlib.metadata import entry_points -from typing import Optional +from typing import Optional, TYPE_CHECKING -from codemodder.executor import CodemodExecutorWrapper from codemodder.logging import logger +if TYPE_CHECKING: + from codemodder.codemods.base_codemod import BaseCodemod + # These are generally not intended to be applied directly so they are excluded by default. DEFAULT_EXCLUDED_CODEMODS = [ @@ -25,8 +28,8 @@ class CodemodCollection: class CodemodRegistry: - _codemods_by_name: dict[str, CodemodExecutorWrapper] - _codemods_by_id: dict[str, CodemodExecutorWrapper] + _codemods_by_name: dict[str, BaseCodemod] + _codemods_by_id: dict[str, BaseCodemod] def __init__(self): self._codemods_by_name = {} @@ -45,45 +48,16 @@ def codemods(self): return list(self._codemods_by_name.values()) def add_codemod_collection(self, collection: CodemodCollection): - docs_module = files(collection.docs_module) - semgrep_module = files(collection.semgrep_config_module) for codemod in collection.codemods: - self._validate_codemod(codemod) - wrapper = CodemodExecutorWrapper( - codemod, - collection.origin, - docs_module, - semgrep_module, - ) + wrapper = codemod() if isinstance(codemod, type) else codemod self._codemods_by_name[wrapper.name] = wrapper self._codemods_by_id[wrapper.id] = wrapper - def _validate_codemod(self, codemod): - for name in ["SUMMARY", "METADATA"]: - if not (attr := getattr(codemod, name)) or attr is NotImplemented: - raise ValueError( - f'Missing required attribute "{name}" on codemod {codemod}' - ) - - for k, v in asdict(codemod.METADATA).items(): - if v is NotImplemented: - raise NotImplementedError(f"METADATA.{k} not defined for {codemod}") - if k != "REFERENCES" and not v: - raise NotImplementedError( - f"METADATA.{k} should not be None or empty for {codemod}" - ) - - # TODO: eventually we will represent IS_SEMGREP using the class hierarchy - if codemod.is_semgrep and not codemod.YAML_FILES: - raise ValueError( - f"Missing required attribute YAML_FILES on semgrep codemod {codemod}" - ) - def match_codemods( self, codemod_include: Optional[list] = None, codemod_exclude: Optional[list] = None, - ) -> list[CodemodExecutorWrapper]: + ) -> list[BaseCodemod]: codemod_include = codemod_include or [] codemod_exclude = codemod_exclude or DEFAULT_EXCLUDED_CODEMODS diff --git a/src/core_codemods/add_requests_timeouts.py b/src/core_codemods/add_requests_timeouts.py index 1f48d6d36..cadbf3494 100644 --- a/src/core_codemods/add_requests_timeouts.py +++ b/src/core_codemods/add_requests_timeouts.py @@ -1,45 +1,60 @@ -from codemodder.codemods.api import SemgrepCodemod, ReviewGuidance +from core_codemods.api import ( + CoreCodemod, + Metadata, + Reference, + ReviewGuidance, +) +from codemodder.codemods.libcst_transformer import ( + LibcstTransformerPipeline, + LibcstResultTransformer, +) +from codemodder.codemods.semgrep import SemgrepRuleDetector -class AddRequestsTimeouts(SemgrepCodemod): - NAME = "add-requests-timeouts" - SUMMARY = "Add timeout to `requests` calls" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - DESCRIPTION = "Add timeout to `requests` call" - REFERENCES = [ - { - "url": "https://docs.python-requests.org/en/master/user/quickstart/#timeouts", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ - rules: - - patterns: - - pattern-inside: | - import requests - ... - - pattern: $CALL(...) - - pattern-not: $CALL(..., timeout=$TIMEOUT, ...) - - metavariable-pattern: - metavariable: $CALL - patterns: - - pattern-either: - - pattern: requests.get - - pattern: requests.post - - pattern: requests.put - - pattern: requests.delete - - pattern: requests.head - - pattern: requests.options - - pattern: requests.patch - - pattern: requests.request - """ - +class TransformAddRequestsTimeouts(LibcstResultTransformer): # Sets an arbitrary default timeout for all requests DEFAULT_TIMEOUT = 60 + change_description = "Add timeout to `requests` call" + def on_result_found(self, original_node, updated_node): del original_node return self.add_arg_to_call(updated_node, "timeout", self.DEFAULT_TIMEOUT) + + +# This codemod uses the lower level codemod and transformer APIs for the sake of example. +AddRequestsTimeouts = CoreCodemod( + metadata=Metadata( + name="add-requests-timeouts", + summary="Add timeout to `requests` calls", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference( + url="https://docs.python-requests.org/en/master/user/quickstart/#timeouts" + ), + ], + ), + detector=SemgrepRuleDetector( + """ + - patterns: + - pattern-inside: | + import requests + ... + - pattern: $CALL(...) + - pattern-not: $CALL(..., timeout=$TIMEOUT, ...) + - metavariable-pattern: + metavariable: $CALL + patterns: + - pattern-either: + - pattern: requests.get + - pattern: requests.post + - pattern: requests.put + - pattern: requests.delete + - pattern: requests.head + - pattern: requests.options + - pattern: requests.patch + - pattern: requests.request + """ + ), + transformer=LibcstTransformerPipeline(TransformAddRequestsTimeouts), +) diff --git a/src/core_codemods/api/__init__.py b/src/core_codemods/api/__init__.py new file mode 100644 index 000000000..904e6b82d --- /dev/null +++ b/src/core_codemods/api/__init__.py @@ -0,0 +1,6 @@ +from codemodder.codemods.api import ( + Metadata, + Reference, + ReviewGuidance, +) +from .core_codemod import SimpleCodemod, CoreCodemod diff --git a/src/core_codemods/api/core_codemod.py b/src/core_codemods/api/core_codemod.py new file mode 100644 index 000000000..deefb4e4c --- /dev/null +++ b/src/core_codemods/api/core_codemod.py @@ -0,0 +1,23 @@ +from codemodder.codemods.api import BaseCodemod, SimpleCodemod as _SimpleCodemod + + +class CoreCodemod(BaseCodemod): + """ + Base class for all core codemods provided by this package. + """ + + @property + def origin(self): + return "pixee" + + @property + def docs_module_path(self): + return "core_codemods.docs" + + +class SimpleCodemod(_SimpleCodemod): + """ + Base class for all core codemods with a single detector and transformer. + """ + + codemod_base = CoreCodemod diff --git a/src/core_codemods/combine_startswith_endswith.py b/src/core_codemods/combine_startswith_endswith.py index 070497bc1..3c2e88370 100644 --- a/src/core_codemods/combine_startswith_endswith.py +++ b/src/core_codemods/combine_startswith_endswith.py @@ -1,14 +1,17 @@ import libcst as cst from libcst import matchers as m -from codemodder.codemods.api import BaseCodemod, ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod -class CombineStartswithEndswith(BaseCodemod, NameResolutionMixin): - NAME = "combine-startswith-endswith" - SUMMARY = "Simplify Boolean Expressions Using `startswith` and `endswith`" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Use tuple of matches instead of boolean expression" +class CombineStartswithEndswith(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="combine-startswith-endswith", + summary="Simplify Boolean Expressions Using `startswith` and `endswith`", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[], + ) + change_description = "Use tuple of matches instead of boolean expression" REFERENCES: list = [] def leave_BooleanOperation( diff --git a/src/core_codemods/django_debug_flag_on.py b/src/core_codemods/django_debug_flag_on.py index 38b23c400..3a1a26438 100644 --- a/src/core_codemods/django_debug_flag_on.py +++ b/src/core_codemods/django_debug_flag_on.py @@ -1,28 +1,29 @@ import libcst as cst -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.utils import is_django_settings_file +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class DjangoDebugFlagOn(SemgrepCodemod): - NAME = "django-debug-flag-on" - DESCRIPTION = "Flip `Django` debug flag to off." - SUMMARY = "Disable Django Debug Mode" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - REFERENCES = [ - { - "url": "https://owasp.org/www-project-top-ten/2017/A3_2017-Sensitive_Data_Exposure", - "description": "", - }, - { - "url": "https://docs.djangoproject.com/en/4.2/ref/settings/#std-setting-DEBUG", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ +class DjangoDebugFlagOn(SimpleCodemod): + metadata = Metadata( + name="django-debug-flag-on", + summary="Disable Django Debug Mode", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference( + url="https://owasp.org/www-project-top-ten/2017/A3_2017-Sensitive_Data_Exposure" + ), + Reference( + url="https://docs.djangoproject.com/en/4.2/ref/settings/#std-setting-DEBUG" + ), + ], + ) + change_description = "Flip `Django` debug flag to off." + detector_pattern = """ rules: - id: django-debug-flag-on pattern: DEBUG = True diff --git a/src/core_codemods/django_json_response_type.py b/src/core_codemods/django_json_response_type.py index bc1a963bd..a3337ed08 100644 --- a/src/core_codemods/django_json_response_type.py +++ b/src/core_codemods/django_json_response_type.py @@ -1,28 +1,28 @@ import libcst as cst +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod - -class DjangoJsonResponseType(SemgrepCodemod): - NAME = "django-json-response-type" - SUMMARY = "Set content type to `application/json` for `django.http.HttpResponse` with JSON data" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Sets `content_type` to `application/json`." - REFERENCES = [ - { - "url": "https://docs.djangoproject.com/en/4.0/ref/request-response/#django.http.HttpResponse.__init__", - "description": "", - }, - { - "url": "https://cheatsheetseries.owasp.org/cheatsheets/Cross_Site_Scripting_Prevention_Cheat_Sheet.html#output-encoding-for-javascript-contexts", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ +class DjangoJsonResponseType(SimpleCodemod): + metadata = Metadata( + name="django-json-response-type", + summary="Set content type to `application/json` for `django.http.HttpResponse` with JSON data", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://docs.djangoproject.com/en/4.0/ref/request-response/#django.http.HttpResponse.__init__" + ), + Reference( + url="https://cheatsheetseries.owasp.org/cheatsheets/Cross_Site_Scripting_Prevention_Cheat_Sheet.html#output-encoding-for-javascript-contexts" + ), + ], + ) + change_description = "Sets `content_type` to `application/json`." + detector_pattern = """ rules: - id: django-json-response-type mode: taint diff --git a/src/core_codemods/django_receiver_on_top.py b/src/core_codemods/django_receiver_on_top.py index 6ebeab0ac..331b73eff 100644 --- a/src/core_codemods/django_receiver_on_top.py +++ b/src/core_codemods/django_receiver_on_top.py @@ -1,22 +1,24 @@ from typing import Union import libcst as cst -from codemodder.codemods.api import BaseCodemod -from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class DjangoReceiverOnTop(BaseCodemod, NameResolutionMixin): - NAME = "django-receiver-on-top" - SUMMARY = "Ensure Django @receiver is the first decorator" - DESCRIPTION = SUMMARY - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - REFERENCES = [ - { - "url": "https://docs.djangoproject.com/en/4.1/topics/signals/", - "description": "", - }, - ] - CHANGE_DESCRIPTION = "Moved @receiver to the top." +class DjangoReceiverOnTop(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="django-receiver-on-top", + summary="Ensure Django @receiver is the first decorator", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference(url="https://docs.djangoproject.com/en/4.1/topics/signals/"), + ], + ) + change_description = "Moved @receiver to the top." def leave_FunctionDef( self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef diff --git a/src/core_codemods/django_session_cookie_secure_off.py b/src/core_codemods/django_session_cookie_secure_off.py index bf1bf4b72..2244b5801 100644 --- a/src/core_codemods/django_session_cookie_secure_off.py +++ b/src/core_codemods/django_session_cookie_secure_off.py @@ -1,28 +1,29 @@ import libcst as cst -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.utils import is_django_settings_file, is_assigned_to_True +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class DjangoSessionCookieSecureOff(SemgrepCodemod): - NAME = "django-session-cookie-secure-off" - DESCRIPTION = "Sets Django's `SESSION_COOKIE_SECURE` flag if off or missing." - SUMMARY = "Secure Setting for Django `SESSION_COOKIE_SECURE` flag" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - REFERENCES = [ - { - "url": "https://owasp.org/www-community/controls/SecureCookieAttribute", - "description": "", - }, - { - "url": "https://docs.djangoproject.com/en/4.2/ref/settings/#session-cookie-secure", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ +class DjangoSessionCookieSecureOff(SimpleCodemod): + metadata = Metadata( + name="django-session-cookie-secure-off", + summary="Secure Setting for Django `SESSION_COOKIE_SECURE` flag", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference( + url="https://owasp.org/www-community/controls/SecureCookieAttribute" + ), + Reference( + url="https://docs.djangoproject.com/en/4.2/ref/settings/#session-cookie-secure" + ), + ], + ) + change_description = "Sets Django's `SESSION_COOKIE_SECURE` flag if off or missing." + detector_pattern = """ rules: - id: django-session-cookie-secure-off # This pattern creates one finding with no text for settings.py file. @@ -32,8 +33,8 @@ def rule(cls): - settings.py """ - def __init__(self, *args): - super().__init__(*args) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self.is_django_settings_file = is_django_settings_file( self.file_context.file_path ) @@ -60,7 +61,7 @@ def leave_Module( # something else and we changed it in `leave_Assign`. return updated_node - self.add_change(original_node, self.CHANGE_DESCRIPTION, start=False) + self.add_change(original_node, self.change_description, start=False) final_line = cst.parse_statement("SESSION_COOKIE_SECURE = True") new_body = updated_node.body + (final_line,) return updated_node.with_changes(body=new_body) @@ -80,7 +81,7 @@ def leave_Assign( return updated_node # SESSION_COOKIE_SECURE = anything other than True - self.add_change(original_node, self.CHANGE_DESCRIPTION) + self.add_change(original_node, self.change_description) return updated_node.with_changes(value=cst.Name("True")) return updated_node diff --git a/src/core_codemods/enable_jinja2_autoescape.py b/src/core_codemods/enable_jinja2_autoescape.py index 470aac6e8..d4426adef 100644 --- a/src/core_codemods/enable_jinja2_autoescape.py +++ b/src/core_codemods/enable_jinja2_autoescape.py @@ -1,24 +1,28 @@ -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.api.helpers import NewArg +from codemodder.codemods.libcst_transformer import NewArg +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class EnableJinja2Autoescape(SemgrepCodemod): - NAME = "enable-jinja2-autoescape" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Enable Jinja2 Autoescape" - DESCRIPTION = "Sets the `autoescape` parameter in jinja2.Environment to `True`." - REFERENCES = [ - {"url": "https://owasp.org/www-community/attacks/xss/", "description": ""}, - { - "url": "https://jinja.palletsprojects.com/en/3.1.x/api/#autoescaping", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ +class EnableJinja2Autoescape(SimpleCodemod): + metadata = Metadata( + name="enable-jinja2-autoescape", + summary="Enable Jinja2 Autoescape", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference(url="https://owasp.org/www-community/attacks/xss/"), + Reference( + url="https://jinja.palletsprojects.com/en/3.1.x/api/#autoescaping" + ), + ], + ) + change_description = ( + "Sets the `autoescape` parameter in jinja2.Environment to `True`." + ) + detector_pattern = """ rules: - pattern-either: - patterns: diff --git a/src/core_codemods/exception_without_raise.py b/src/core_codemods/exception_without_raise.py index 3a796a321..c848a41fb 100644 --- a/src/core_codemods/exception_without_raise.py +++ b/src/core_codemods/exception_without_raise.py @@ -1,24 +1,28 @@ from typing import Union import libcst as cst -from codemodder.codemods.api import BaseCodemod -from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin from codemodder.utils.utils import full_qualified_name_from_class, list_subclasses +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class ExceptionWithoutRaise(BaseCodemod, NameResolutionMixin): - NAME = "exception-without-raise" - SUMMARY = "Ensure bare exception statements are raised" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = SUMMARY - REFERENCES = [ - { - "url": "https://docs.python.org/3/tutorial/errors.html#raising-exceptions", - "description": "", - }, - ] - CHANGE_DESCRIPTION = "Raised bare exception statement" +class ExceptionWithoutRaise(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="exception-without-raise", + summary="Ensure bare exception statements are raised", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/tutorial/errors.html#raising-exceptions" + ), + ], + ) + change_description = "Raised bare exception statement" def leave_SimpleStatementLine( self, diff --git a/src/core_codemods/file_resource_leak.py b/src/core_codemods/file_resource_leak.py index 24992190c..ee4bb4dae 100644 --- a/src/core_codemods/file_resource_leak.py +++ b/src/core_codemods/file_resource_leak.py @@ -13,32 +13,29 @@ ScopeProvider, ) from codemodder.change import Change -from codemodder.codemods.base_codemod import ( - ReviewGuidance, -) -from codemodder.codemods.api import BaseCodemod from codemodder.codemods.utils import MetadataPreservingTransformer from codemodder.codemods.utils_mixin import AncestorPatternsMixin, NameResolutionMixin from codemodder.file_context import FileContext from functools import partial +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class FileResourceLeak(BaseCodemod): - NAME = "fix-file-resource-leak" - SUMMARY = "Automatically Close Resources" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = SUMMARY - REFERENCES = [ - { - "url": "https://cwe.mitre.org/data/definitions/772.html", - "description": "", - }, - { - "url": "https://cwe.mitre.org/data/definitions/404.html", - "description": "", - }, - ] - CHANGE_DESCRIPTION = "Wrapped opened resource in a with statement." +class FileResourceLeak(SimpleCodemod): + metadata = Metadata( + name="fix-file-resource-leak", + summary="Automatically Close Resources", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference(url="https://cwe.mitre.org/data/definitions/772.html"), + Reference(url="https://cwe.mitre.org/data/definitions/404.html"), + ], + ) + change_description = "Wrapped opened resource in a with statement." METADATA_DEPENDENCIES = ( PositionProvider, @@ -51,11 +48,14 @@ def __init__( context: CodemodContext, file_context: FileContext, *codemod_args, + **codemod_kwargs, ) -> None: self.changed_nodes: dict[ cst.CSTNode, cst.CSTNode | cst.RemovalSentinel | cst.FlattenSentinel ] = {} - BaseCodemod.__init__(self, context, file_context, *codemod_args) + SimpleCodemod.__init__( + self, context, file_context, *codemod_args, **codemod_kwargs + ) def transform_module_impl(self, tree: cst.Module) -> cst.Module: fr = FindResources(self.context) @@ -200,7 +200,7 @@ def leave_Module(self, original_node: cst.Module, updated_node) -> cst.Module: if all(name_condition): line_number = self.get_metadata(PositionProvider, resource).start.line self.changes.append( - Change(line_number, FileResourceLeak.CHANGE_DESCRIPTION) + Change(line_number, FileResourceLeak.change_description) ) last_index = self._find_last_index_with_access( named_targets, block, index diff --git a/src/core_codemods/fix_deprecated_abstractproperty.py b/src/core_codemods/fix_deprecated_abstractproperty.py index 3ed64a112..43593ef17 100644 --- a/src/core_codemods/fix_deprecated_abstractproperty.py +++ b/src/core_codemods/fix_deprecated_abstractproperty.py @@ -1,20 +1,27 @@ import libcst as cst - -from codemodder.codemods.api import BaseCodemod, ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class FixDeprecatedAbstractproperty(BaseCodemod, NameResolutionMixin): - NAME = "fix-deprecated-abstractproperty" - SUMMARY = "Replace deprecated abstractproperty" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Replace deprecated abstractproperty with property and abstractmethod" - REFERENCES = [ - { - "url": "https://docs.python.org/3/library/abc.html#abc.abstractproperty", - "description": "", - }, - ] +class FixDeprecatedAbstractproperty(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="fix-deprecated-abstractproperty", + summary="Replace deprecated abstractproperty", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/library/abc.html#abc.abstractproperty" + ), + ], + ) + change_description = ( + "Replace deprecated abstractproperty with property and abstractmethod" + ) def leave_Decorator( self, original_node: cst.Decorator, updated_node: cst.Decorator diff --git a/src/core_codemods/fix_deprecated_logging_warn.py b/src/core_codemods/fix_deprecated_logging_warn.py index 6025d2475..291d6fcce 100644 --- a/src/core_codemods/fix_deprecated_logging_warn.py +++ b/src/core_codemods/fix_deprecated_logging_warn.py @@ -1,24 +1,27 @@ import libcst as cst -from codemodder.codemods.api import SemgrepCodemod, ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class FixDeprecatedLoggingWarn(SemgrepCodemod, NameResolutionMixin): - NAME = "fix-deprecated-logging-warn" - SUMMARY = "Replace Deprecated `logging.warn`" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Replace deprecated `logging.warn` with `logging.warning`" - REFERENCES = [ - { - "url": "https://docs.python.org/3/library/logging.html#logging.Logger.warning", - "description": "", - }, - ] +class FixDeprecatedLoggingWarn(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="fix-deprecated-logging-warn", + summary="Replace Deprecated `logging.warn`", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/library/logging.html#logging.Logger.warning" + ), + ], + ) + change_description = "Replace deprecated `logging.warn` with `logging.warning`" _module_name = "logging" - - @classmethod - def rule(cls): - return """ + detector_pattern = """ rules: - pattern-either: - patterns: diff --git a/src/core_codemods/fix_mutable_params.py b/src/core_codemods/fix_mutable_params.py index a1ebf8985..159aa52dc 100644 --- a/src/core_codemods/fix_mutable_params.py +++ b/src/core_codemods/fix_mutable_params.py @@ -1,15 +1,16 @@ import libcst as cst from libcst import matchers as m +from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import BaseCodemod - -class FixMutableParams(BaseCodemod): - NAME = "fix-mutable-params" - SUMMARY = "Replace Mutable Default Parameters" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Replace mutable parameter with `None`." +class FixMutableParams(SimpleCodemod): + metadata = Metadata( + name="fix-mutable-params", + summary="Replace Mutable Default Parameters", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[], + ) + change_description = "Replace mutable parameter with `None`." REFERENCES: list = [] _BUILTIN_TO_LITERAL = { "list": cst.List(elements=[]), @@ -166,7 +167,7 @@ def leave_FunctionDef( ) if new_var_decls: # If we're adding statements to the body, we know a change took place - self.add_change(original_node, self.CHANGE_DESCRIPTION) + self.add_change(original_node, self.change_description) if add_annotation: self.add_needed_import("typing", "Optional") diff --git a/src/core_codemods/flask_json_response_type.py b/src/core_codemods/flask_json_response_type.py index 959e5674c..273ab7c1d 100644 --- a/src/core_codemods/flask_json_response_type.py +++ b/src/core_codemods/flask_json_response_type.py @@ -1,28 +1,31 @@ from typing import Optional, Tuple import libcst as cst from libcst.codemod import CodemodContext, ContextAwareVisitor -from codemodder.codemods.api import BaseCodemod - -from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.utils_mixin import NameAndAncestorResolutionMixin - - -class FlaskJsonResponseType(BaseCodemod, NameAndAncestorResolutionMixin): - NAME = "flask-json-response-type" - SUMMARY = "Set content type to `application/json` for `flask.make_response` with JSON data" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Sets `mimetype` to `application/json`." - REFERENCES = [ - { - "url": "https://flask.palletsprojects.com/en/2.3.x/patterns/javascript/#return-json-from-views", - "description": "", - }, - { - "url": "https://cheatsheetseries.owasp.org/cheatsheets/Cross_Site_Scripting_Prevention_Cheat_Sheet.html#output-encoding-for-javascript-contexts", - "description": "", - }, - ] +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) + + +class FlaskJsonResponseType(SimpleCodemod, NameAndAncestorResolutionMixin): + metadata = Metadata( + name="flask-json-response-type", + summary="Set content type to `application/json` for `flask.make_response` with JSON data", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://flask.palletsprojects.com/en/2.3.x/patterns/javascript/#return-json-from-views" + ), + Reference( + url="https://cheatsheetseries.owasp.org/cheatsheets/Cross_Site_Scripting_Prevention_Cheat_Sheet.html#output-encoding-for-javascript-contexts" + ), + ], + ) + change_description = "Sets `mimetype` to `application/json`." content_type_key = "Content-Type" json_content_type = "application/json" diff --git a/src/core_codemods/harden_pyyaml.py b/src/core_codemods/harden_pyyaml.py index 1387ea78f..2dac634a6 100644 --- a/src/core_codemods/harden_pyyaml.py +++ b/src/core_codemods/harden_pyyaml.py @@ -1,70 +1,71 @@ from typing import Union import libcst as cst -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class HardenPyyaml(SemgrepCodemod, NameResolutionMixin): - NAME = "harden-pyyaml" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Replace unsafe `pyyaml` loader with `SafeLoader`" - DESCRIPTION = "Replace unsafe `pyyaml` loader with `SafeLoader` in calls to `yaml.load` or custom loader classes." - REFERENCES = [ - { - "url": "https://owasp.org/www-community/vulnerabilities/Deserialization_of_untrusted_data", - "description": "", - } - ] +class HardenPyyaml(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="harden-pyyaml", + summary="Replace unsafe `pyyaml` loader with `SafeLoader`", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://owasp.org/www-community/vulnerabilities/Deserialization_of_untrusted_data" + ), + ], + ) + change_description = "Replace unsafe `pyyaml` loader with `SafeLoader` in calls to `yaml.load` or custom loader classes." _module_name = "yaml" - - @classmethod - def rule(cls): - return """ - rules: - - pattern-either: - - patterns: - - pattern: yaml.load(...) - - pattern-inside: | - import yaml - ... - yaml.load(...,$ARG) - - metavariable-pattern: - metavariable: $ARG - patterns: - - pattern-either: - - pattern: yaml.Loader - - pattern: yaml.BaseLoader - - pattern: yaml.FullLoader - - pattern: yaml.UnsafeLoader - - patterns: - - pattern: yaml.load(...) - - pattern-inside: | - import yaml - ... - yaml.load(...,Loader=$ARG) - - metavariable-pattern: - metavariable: $ARG - patterns: - - pattern-either: - - pattern: yaml.Loader - - pattern: yaml.BaseLoader - - pattern: yaml.FullLoader - - pattern: yaml.UnsafeLoader - - patterns: - - pattern: | - class $X(...,$LOADER, ...): - ... - - metavariable-pattern: - metavariable: $LOADER - patterns: - - pattern-either: - - pattern: yaml.Loader - - pattern: yaml.BaseLoader - - pattern: yaml.FullLoader - - pattern: yaml.UnsafeLoader - + detector_pattern = """ + rules: + - pattern-either: + - patterns: + - pattern: yaml.load(...) + - pattern-inside: | + import yaml + ... + yaml.load(...,$ARG) + - metavariable-pattern: + metavariable: $ARG + patterns: + - pattern-either: + - pattern: yaml.Loader + - pattern: yaml.BaseLoader + - pattern: yaml.FullLoader + - pattern: yaml.UnsafeLoader + - patterns: + - pattern: yaml.load(...) + - pattern-inside: | + import yaml + ... + yaml.load(...,Loader=$ARG) + - metavariable-pattern: + metavariable: $ARG + patterns: + - pattern-either: + - pattern: yaml.Loader + - pattern: yaml.BaseLoader + - pattern: yaml.FullLoader + - pattern: yaml.UnsafeLoader + - patterns: + - pattern: | + class $X(...,$LOADER, ...): + ... + - metavariable-pattern: + metavariable: $LOADER + patterns: + - pattern-either: + - pattern: yaml.Loader + - pattern: yaml.BaseLoader + - pattern: yaml.FullLoader + - pattern: yaml.UnsafeLoader """ def on_result_found( diff --git a/src/core_codemods/harden_ruamel.py b/src/core_codemods/harden_ruamel.py index 8eb0a12b1..44fd73425 100644 --- a/src/core_codemods/harden_ruamel.py +++ b/src/core_codemods/harden_ruamel.py @@ -1,23 +1,27 @@ -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.api.helpers import NewArg +from codemodder.codemods.libcst_transformer import NewArg +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class HardenRuamel(SemgrepCodemod): - NAME = "harden-ruamel" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Use `typ='safe'` in ruamel.yaml() Calls" - DESCRIPTION = "Ensures all unsafe calls to ruamel.yaml.YAML use `typ='safe'`." - REFERENCES = [ - { - "url": "https://owasp.org/www-community/vulnerabilities/Deserialization_of_untrusted_data", - "description": "", - } - ] - - @classmethod - def rule(cls): - return """ +class HardenRuamel(SimpleCodemod): + metadata = Metadata( + name="harden-ruamel", + summary="Use `typ='safe'` in ruamel.yaml() Calls", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://owasp.org/www-community/vulnerabilities/Deserialization_of_untrusted_data" + ), + ], + ) + change_description = ( + "Ensures all unsafe calls to ruamel.yaml.YAML use `typ='safe'`." + ) + detector_pattern = """ rules: - pattern-either: - patterns: diff --git a/src/core_codemods/https_connection.py b/src/core_codemods/https_connection.py index 1a011a169..a1b2eeb1a 100644 --- a/src/core_codemods/https_connection.py +++ b/src/core_codemods/https_connection.py @@ -3,10 +3,13 @@ import libcst as cst from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor from libcst.metadata import PositionProvider - -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import BaseCodemod from codemodder.codemods.imported_call_modifier import ImportedCallModifier +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) class HTTPSConnectionModifier(ImportedCallModifier[Set[str]]): @@ -48,20 +51,20 @@ def count_positional_args(self, arglist: Sequence[cst.Arg]) -> int: return len(arglist) -class HTTPSConnection(BaseCodemod): - SUMMARY = "Enforce HTTPS Connection for `urllib3`" - NAME = "https-connection" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - REFERENCES = [ - { - "url": "https://owasp.org/www-community/vulnerabilities/Insecure_Transport", - "description": "", - }, - { - "url": "https://urllib3.readthedocs.io/en/stable/reference/urllib3.connectionpool.html#urllib3.HTTPConnectionPool", - "description": "", - }, - ] +class HTTPSConnection(SimpleCodemod): + metadata = Metadata( + name="https-connection", + summary="Enforce HTTPS Connection for `urllib3`", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://owasp.org/www-community/vulnerabilities/Insecure_Transport" + ), + Reference( + url="https://urllib3.readthedocs.io/en/stable/reference/urllib3.connectionpool.html#urllib3.HTTPConnectionPool" + ), + ], + ) METADATA_DEPENDENCIES = (PositionProvider,) @@ -75,7 +78,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: self.context, self.file_context, self.matching_functions, - self.CHANGE_DESCRIPTION, + self.change_description, ) result_tree = visitor.transform_module(tree) self.file_context.codemod_changes.extend(visitor.changes_in_file) diff --git a/src/core_codemods/jwt_decode_verify.py b/src/core_codemods/jwt_decode_verify.py index 343ac94a9..604a5c82d 100644 --- a/src/core_codemods/jwt_decode_verify.py +++ b/src/core_codemods/jwt_decode_verify.py @@ -1,26 +1,28 @@ import libcst as cst from libcst import matchers -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.api.helpers import NewArg +from codemodder.codemods.libcst_transformer import NewArg +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class JwtDecodeVerify(SemgrepCodemod): - NAME = "jwt-decode-verify" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Verify JWT Decode" - DESCRIPTION = "Enable all verifications in `jwt.decode` call." - REFERENCES = [ - {"url": "https://pyjwt.readthedocs.io/en/stable/api.html", "description": ""}, - { - "url": "https://owasp.org/www-project-web-security-testing-guide/latest/4-Web_Application_Security_Testing/06-Session_Management_Testing/10-Testing_JSON_Web_Tokens", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return r""" +class JwtDecodeVerify(SimpleCodemod): + metadata = Metadata( + name="jwt-decode-verify", + summary="Verify JWT Decode", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference(url="https://pyjwt.readthedocs.io/en/stable/api.html"), + Reference( + url="https://owasp.org/www-project-web-security-testing-guide/latest/4-Web_Application_Security_Testing/06-Session_Management_Testing/10-Testing_JSON_Web_Tokens" + ), + ], + ) + change_description = "Enable all verifications in `jwt.decode` call." + detector_pattern = r""" rules: - pattern-either: - patterns: diff --git a/src/core_codemods/limit_readline.py b/src/core_codemods/limit_readline.py index 694780b04..81f527341 100644 --- a/src/core_codemods/limit_readline.py +++ b/src/core_codemods/limit_readline.py @@ -1,23 +1,26 @@ import libcst as cst -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) default_limit = "5_000_000" -class LimitReadline(SemgrepCodemod): - NAME = "limit-readline" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - SUMMARY = "Limit readline()" - DESCRIPTION = "Adds a size limit argument to readline() calls." - REFERENCES = [ - {"url": "https://cwe.mitre.org/data/definitions/400.html", "description": ""} - ] - - @classmethod - def rule(cls): - return """ +class LimitReadline(SimpleCodemod): + metadata = Metadata( + name="limit-readline", + summary="Limit readline()", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference(url="https://cwe.mitre.org/data/definitions/400.html"), + ], + ) + change_description = "Adds a size limit argument to readline() calls." + detector_pattern = """ rules: - id: limit-readline mode: taint diff --git a/src/core_codemods/literal_or_new_object_identity.py b/src/core_codemods/literal_or_new_object_identity.py index e6d9f73ba..32f907aa9 100644 --- a/src/core_codemods/literal_or_new_object_identity.py +++ b/src/core_codemods/literal_or_new_object_identity.py @@ -1,22 +1,26 @@ import libcst as cst -from codemodder.codemods.api import BaseCodemod -from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.utils_mixin import NameAndAncestorResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class LiteralOrNewObjectIdentity(BaseCodemod, NameAndAncestorResolutionMixin): - NAME = "literal-or-new-object-identity" - SUMMARY = "Replaces is operator with == for literal or new object comparisons" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = SUMMARY - REFERENCES = [ - { - "url": "https://docs.python.org/3/library/stdtypes.html#comparisons", - "description": "", - }, - ] - CHANGE_DESCRIPTION = "Replaces is operator with ==" +class LiteralOrNewObjectIdentity(SimpleCodemod, NameAndAncestorResolutionMixin): + metadata = Metadata( + name="literal-or-new-object-identity", + summary="Replaces is operator with == for literal or new object comparisons", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/library/stdtypes.html#comparisons" + ), + ], + ) + change_description = "Replaces is operator with ==" def _is_object_creation_or_literal(self, node: cst.BaseExpression): match node: diff --git a/src/core_codemods/lxml_safe_parser_defaults.py b/src/core_codemods/lxml_safe_parser_defaults.py index 20b47b844..d94450404 100644 --- a/src/core_codemods/lxml_safe_parser_defaults.py +++ b/src/core_codemods/lxml_safe_parser_defaults.py @@ -1,31 +1,31 @@ -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.api.helpers import NewArg +from codemodder.codemods.libcst_transformer import NewArg +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class LxmlSafeParserDefaults(SemgrepCodemod): - NAME = "safe-lxml-parser-defaults" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Use Safe Defaults for `lxml` Parsers" - DESCRIPTION = "Replace `lxml` parser parameters with safe defaults." - REFERENCES = [ - { - "url": "https://lxml.de/apidoc/lxml.etree.html#lxml.etree.XMLParser", - "description": "", - }, - { - "url": "https://owasp.org/www-community/vulnerabilities/XML_External_Entity_(XXE)_Processing", - "description": "", - }, - { - "url": "https://cheatsheetseries.owasp.org/cheatsheets/XML_External_Entity_Prevention_Cheat_Sheet.html", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ +class LxmlSafeParserDefaults(SimpleCodemod): + metadata = Metadata( + name="safe-lxml-parser-defaults", + summary="Use Safe Defaults for `lxml` Parsers", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://lxml.de/apidoc/lxml.etree.html#lxml.etree.XMLParser" + ), + Reference( + url="https://owasp.org/www-community/vulnerabilities/XML_External_Entity_(XXE)_Processing" + ), + Reference( + url="https://cheatsheetseries.owasp.org/cheatsheets/XML_External_Entity_Prevention_Cheat_Sheet.html" + ), + ], + ) + change_description = "Replace `lxml` parser parameters with safe defaults." + detector_pattern = """ rules: - patterns: - pattern: lxml.etree.$CLASS(...) diff --git a/src/core_codemods/lxml_safe_parsing.py b/src/core_codemods/lxml_safe_parsing.py index a7a981016..aa5c9b26b 100644 --- a/src/core_codemods/lxml_safe_parsing.py +++ b/src/core_codemods/lxml_safe_parsing.py @@ -1,33 +1,33 @@ -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.api.helpers import NewArg +from codemodder.codemods.libcst_transformer import NewArg +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class LxmlSafeParsing(SemgrepCodemod): - NAME = "safe-lxml-parsing" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Use Safe Parsers in `lxml` Parsing Functions" - DESCRIPTION = ( +class LxmlSafeParsing(SimpleCodemod): + metadata = Metadata( + name="safe-lxml-parsing", + summary="Use Safe Parsers in `lxml` Parsing Functions", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://lxml.de/apidoc/lxml.etree.html#lxml.etree.XMLParser" + ), + Reference( + url="https://owasp.org/www-community/vulnerabilities/XML_External_Entity_(XXE)_Processing" + ), + Reference( + url="https://cheatsheetseries.owasp.org/cheatsheets/XML_External_Entity_Prevention_Cheat_Sheet.html" + ), + ], + ) + change_description = ( "Call `lxml.etree.parse` and `lxml.etree.fromstring` with a safe parser." ) - REFERENCES = [ - { - "url": "https://lxml.de/apidoc/lxml.etree.html#lxml.etree.XMLParser", - "description": "", - }, - { - "url": "https://owasp.org/www-community/vulnerabilities/XML_External_Entity_(XXE)_Processing", - "description": "", - }, - { - "url": "https://cheatsheetseries.owasp.org/cheatsheets/XML_External_Entity_Prevention_Cheat_Sheet.html", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ + detector_pattern = """ rules: - pattern-either: - patterns: diff --git a/src/core_codemods/numpy_nan_equality.py b/src/core_codemods/numpy_nan_equality.py index 912439ec8..e63016c66 100644 --- a/src/core_codemods/numpy_nan_equality.py +++ b/src/core_codemods/numpy_nan_equality.py @@ -1,23 +1,27 @@ import libcst as cst from libcst import UnaryOperation -from codemodder.codemods.api import BaseCodemod -from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class NumpyNanEquality(BaseCodemod, NameResolutionMixin): - NAME = "numpy-nan-equality" - SUMMARY = "Replace == comparison with numpy.isnan()" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = SUMMARY - REFERENCES = [ - { - "url": "https://numpy.org/doc/stable/reference/constants.html#numpy.nan", - "description": "", - }, - ] - CHANGE_DESCRIPTION = "Replaces == check with numpy.isnan()." +class NumpyNanEquality(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="numpy-nan-equality", + summary="Replace == comparison with numpy.isnan()", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://numpy.org/doc/stable/reference/constants.html#numpy.nan" + ), + ], + ) + change_description = "Replaces == check with numpy.isnan()." np_nan = "numpy.nan" diff --git a/src/core_codemods/order_imports.py b/src/core_codemods/order_imports.py index 17ecd2321..f6c497186 100644 --- a/src/core_codemods/order_imports.py +++ b/src/core_codemods/order_imports.py @@ -1,34 +1,25 @@ +import libcst as cst from libcst.metadata import PositionProvider -from codemodder.codemods.base_codemod import ( - BaseCodemod, - CodemodMetadata, - ReviewGuidance, -) + +from core_codemods.api import SimpleCodemod, Metadata, ReviewGuidance from codemodder.change import Change from codemodder.codemods.transformations.clean_imports import ( GatherTopLevelImportBlocks, OrderImportsBlocksTransform, ) -import libcst as cst -from libcst.codemod import Codemod, CodemodContext -class OrderImports(BaseCodemod, Codemod): - METADATA = CodemodMetadata( - DESCRIPTION=("Formats and orders imports by categories."), - NAME="order-imports", - REVIEW_GUIDANCE=ReviewGuidance.MERGE_WITHOUT_REVIEW, - REFERENCES=[], +class OrderImports(SimpleCodemod): + metadata = Metadata( + name="order-imports", + summary="Order Imports", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + has_description=False, ) - SUMMARY = "Order Imports" - CHANGE_DESCRIPTION = "Ordered and formatted import block below this line" + change_description = "Ordered and formatted import block below this line" METADATA_DEPENDENCIES = (PositionProvider,) - def __init__(self, codemod_context: CodemodContext, *codemod_args): - Codemod.__init__(self, codemod_context) - BaseCodemod.__init__(self, *codemod_args) - def transform_module_impl(self, tree: cst.Module) -> cst.Module: top_imports_visitor = GatherTopLevelImportBlocks() tree.visit(top_imports_visitor) @@ -54,7 +45,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: top_imports_visitor.top_imports_blocks[i][0] ).start.line self.file_context.codemod_changes.append( - Change(line_number, self.CHANGE_DESCRIPTION) + Change(line_number, self.change_description) ) return result_tree return tree diff --git a/src/core_codemods/process_creation_sandbox.py b/src/core_codemods/process_creation_sandbox.py index e3a9339ad..f9b6068ac 100644 --- a/src/core_codemods/process_creation_sandbox.py +++ b/src/core_codemods/process_creation_sandbox.py @@ -1,32 +1,33 @@ import libcst as cst -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod from codemodder.dependency import Security +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class ProcessSandbox(SemgrepCodemod): - NAME = "sandbox-process-creation" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - SUMMARY = "Sandbox Process Creation" - DESCRIPTION = ( +class ProcessSandbox(SimpleCodemod): + metadata = Metadata( + name="sandbox-process-creation", + summary="Sandbox Process Creation", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference( + url="https://github.com/pixee/python-security/blob/main/src/security/safe_command/api.py" + ), + Reference( + url="https://cheatsheetseries.owasp.org/cheatsheets/OS_Command_Injection_Defense_Cheat_Sheet.html" + ), + ], + ) + change_description = ( "Replaces subprocess.{func} with more secure safe_command library functions." ) - REFERENCES = [ - { - "url": "https://github.com/pixee/python-security/blob/main/src/security/safe_command/api.py", - "description": "", - }, - { - "url": "https://cheatsheetseries.owasp.org/cheatsheets/OS_Command_Injection_Defense_Cheat_Sheet.html", - "description": "", - }, - ] adds_dependency = True - - @classmethod - def rule(cls): - return """ + detector_pattern = """ rules: - pattern-either: - patterns: diff --git a/src/core_codemods/semgrep/__init__.py b/src/core_codemods/refactor/__init__.py similarity index 100% rename from src/core_codemods/semgrep/__init__.py rename to src/core_codemods/refactor/__init__.py diff --git a/src/core_codemods/refactor/refactor_new_api.py b/src/core_codemods/refactor/refactor_new_api.py new file mode 100644 index 000000000..02453bbf7 --- /dev/null +++ b/src/core_codemods/refactor/refactor_new_api.py @@ -0,0 +1,297 @@ +import libcst as cst + +from core_codemods.api import SimpleCodemod, Metadata, ReviewGuidance +from codemodder.codemods.utils_mixin import NameResolutionMixin + + +class RefactorNewApi(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="refactor-new-api", + summary="Refactor to use thew new simplified API", + review_guidance=ReviewGuidance.MERGE_AFTER_REVIEW, + has_description=False, + ) + + new_api_module = "codemodder.codemods.new_api" + new_api_class = "SimpleCodemod" + + no_whitespace_assign = cst.AssignEqual( + whitespace_before=cst.SimpleWhitespace(""), + whitespace_after=cst.SimpleWhitespace(""), + ) + + def _build_metadata(self, metadata: dict) -> tuple[cst.SimpleStatementLine, bool]: + refs = ( + build_references(metadata["references"]) + if "references" in metadata + else cst.List(elements=[]) + ) + + return ( + cst.SimpleStatementLine( + body=[ + cst.Assign( + targets=[cst.AssignTarget(target=cst.Name(value="metadata"))], + value=cst.Call( + func=cst.Name(value="Metadata"), + args=[ + make_metadata_arg(metadata, "name"), + make_metadata_arg(metadata, "summary"), + make_metadata_arg(metadata, "review_guidance"), + make_metadata_arg( + metadata, "references", value=refs, last=True + ), + ], + whitespace_before_args=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(" "), + ), + ), + ) + ] + ), + bool(refs.elements), + ) + + def find_metadata( + self, assign: cst.Assign + ) -> tuple[str, cst.BaseExpression] | None: + match assign: + case cst.Assign( + targets=[cst.AssignTarget(target=cst.Name(value=name))], + value=value, + ): + match name: + case "NAME" as name: + return name, value + case "SUMMARY" as name: + return name, value + case "DESCRIPTION" as name: + return name, value + case "REVIEW_GUIDANCE" as name: + return name, value + case "REFERENCES" as name: + return name, value + case "CHANGE_DESCRIPTION" | "change_description" as name: + return name, value + + return None + + def create_rule(self, body: cst.BaseSuite) -> cst.SimpleStatementLine: + match body: + case cst.IndentedBlock( + body=[ + cst.SimpleStatementLine( + body=[cst.Return(value=cst.SimpleString() as value)] + ) + ] + ): + return cst.SimpleStatementLine( + body=[ + cst.Assign( + targets=[ + cst.AssignTarget( + target=cst.Name(value="detector_pattern") + ) + ], + value=value, + ) + ] + ) + + raise ValueError("Could not find detector pattern for codemod") + + def create_change_description( + self, metadata: dict + ) -> cst.SimpleStatementLine | None: + if ( + "description" in metadata + and "change_description" not in metadata + and "CHANGE_DESCRIPTION" not in metadata + ): + match metadata["description"]: + case cst.SimpleString() as description: + return cst.SimpleStatementLine( + body=[ + cst.Assign( + targets=[ + cst.AssignTarget( + target=cst.Name(value="change_description") + ) + ], + value=description, + ) + ] + ) + + return None + + def leave_Assert(self, original: cst.Assert, updated: cst.Assert): + match original: + case cst.Assert( + test=cst.Comparison( + left=cst.Call( + func=cst.Name(value="len"), + args=[ + cst.Arg( + value=cst.Attribute( + value=cst.Attribute( + value=cst.Name(value="self"), + attr=cst.Name(value="file_context"), + ), + attr=cst.Name(value="codemod_changes"), + ) + ) + ], + ) + ) + ): + return cst.RemoveFromParent() + return updated + + def leave_Name(self, original: cst.Name, updated: cst.Name) -> cst.Name: + if original.value == "CHANGE_DESCRIPTION": + return updated.with_changes(value="change_description") + + return updated + + def leave_ImportFrom(self, original: cst.ImportFrom, updated: cst.ImportFrom): + match original: + case cst.ImportFrom( + module=cst.Attribute( + value=cst.Attribute( + value=cst.Name(value="codemodder"), + attr=cst.Name(value="codemods"), + ), + attr=cst.Name(value="api"), + ) + ): + return cst.RemoveFromParent() + + return updated + + def leave_ClassDef(self, original: cst.ClassDef, new: cst.ClassDef) -> cst.ClassDef: + new_bases: list[cst.Arg] = [ + base.with_changes(value=cst.Name(self.new_api_class)) + if self.find_base_name(base.value) + in ( + "codemodder.codemods.api.BaseCodemod", + "codemodder.codemods.api.SemgrepCodemod", + ) + else base + for base in original.bases + ] + + if all(base.value.value != self.new_api_class for base in new_bases): + return new + + self.add_needed_import(self.new_api_module, obj=self.new_api_class) + + metadata = {} + new_body = [] + for stmt in new.body.body: + match stmt: + case cst.SimpleStatementLine(body=(cst.Assign() as assign,)): + if result := self.find_metadata(assign): + key, value = result + if key == "change_description": + new_body.append(stmt) + continue + + metadata[key.lower()] = value + continue + case cst.FunctionDef( + name=cst.Name(value="rule"), + body=body, + ): + new_body.append(self.create_rule(body)) + continue + + new_body.append(stmt) + + if not metadata or "name" not in metadata: + return new + + new_metadata, has_references = self._build_metadata(metadata) + new_body.insert(0, new_metadata) + + if change_description := self.create_change_description(metadata): + new_body.insert(1, change_description) + + self.add_needed_import(self.new_api_module, obj="SimpleCodemod") + self.add_needed_import(self.new_api_module, obj="Metadata") + if has_references: + self.add_needed_import(self.new_api_module, obj="Reference") + self.add_needed_import(self.new_api_module, obj="ReviewGuidance") + + self.remove_unused_import(original) + + new_body = new.body.with_changes(body=new_body) + return new.with_changes(bases=new_bases, body=new_body) + + +def make_metadata_arg( + metadata: dict, + metadata_key: str, + value: cst.BaseExpression | None = None, + last: bool = False, +) -> cst.Arg: + return cst.Arg( + value=value or metadata[metadata_key], + keyword=cst.Name(value=metadata_key), + equal=RefactorNewApi.no_whitespace_assign, + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace("" if last else " "), + ), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace("")), + ) + + +def build_references(old_refs: cst.List) -> cst.List: + new_refs: list[cst.Call] = [] + for ref in old_refs.elements: + match ref: + case cst.Element(value=cst.Dict(elements=elements)): + args = [ + cst.Arg( + keyword=cst.Name(value=elm.key.raw_value), + value=elm.value, + equal=RefactorNewApi.no_whitespace_assign, + ) + for elm in elements + if elm.value.raw_value + ] + new_refs.append( + cst.Call( + func=cst.Name(value="Reference"), + args=args, + ) + ) + + return cst.List( + elements=[ + cst.Element( + value=ref, + comma=cst.Comma( + whitespace_after=cst.ParenthesizedWhitespace( + indent=True, + first_line=cst.TrailingWhitespace(newline=cst.Newline()), + ) + ), + ) + for ref in new_refs + ], + lbracket=cst.LeftSquareBracket( + whitespace_after=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(" " * 8), + ), + ), + rbracket=cst.RightSquareBracket( + whitespace_before=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(" " * 4), + ), + ), + ) diff --git a/src/core_codemods/remove_debug_breakpoint.py b/src/core_codemods/remove_debug_breakpoint.py index 7a35846f7..798652d13 100644 --- a/src/core_codemods/remove_debug_breakpoint.py +++ b/src/core_codemods/remove_debug_breakpoint.py @@ -1,14 +1,17 @@ import libcst as cst from typing import Union -from codemodder.codemods.api import BaseCodemod, ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin, AncestorPatternsMixin +from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod -class RemoveDebugBreakpoint(BaseCodemod, NameResolutionMixin, AncestorPatternsMixin): - NAME = "remove-debug-breakpoint" - SUMMARY = "Remove Calls to `builtin` `breakpoint` and `pdb.set_trace" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Remove breakpoint call" +class RemoveDebugBreakpoint(SimpleCodemod, NameResolutionMixin, AncestorPatternsMixin): + metadata = Metadata( + name="remove-debug-breakpoint", + summary="Remove Calls to `builtin` `breakpoint` and `pdb.set_trace", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[], + ) + change_description = "Remove breakpoint call" REFERENCES: list = [] def leave_Expr( diff --git a/src/core_codemods/remove_future_imports.py b/src/core_codemods/remove_future_imports.py index 1823bcf17..d126c30e2 100644 --- a/src/core_codemods/remove_future_imports.py +++ b/src/core_codemods/remove_future_imports.py @@ -1,6 +1,10 @@ import libcst as cst - -from codemodder.codemods.api import BaseCodemod, ReviewGuidance +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) DEPRECATED_NAMES = [ @@ -18,17 +22,16 @@ ] -class RemoveFutureImports(BaseCodemod): - NAME = "remove-future-imports" - SUMMARY = "Remove deprecated `__future__` imports" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Remove deprecated `__future__` imports" - REFERENCES = [ - { - "url": "https://docs.python.org/3/library/__future__.html", - "description": "", - }, - ] +class RemoveFutureImports(SimpleCodemod): + metadata = Metadata( + name="remove-future-imports", + summary="Remove deprecated `__future__` imports", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference(url="https://docs.python.org/3/library/__future__.html"), + ], + ) + change_description = "Remove deprecated `__future__` imports" def leave_ImportFrom( self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom @@ -41,7 +44,7 @@ def leave_ImportFrom( cst.ImportAlias(name=cst.Name(value=name)) for name in CURRENT_NAMES ] - self.add_change(original_node, self.CHANGE_DESCRIPTION) + self.add_change(original_node, self.change_description) return original_node.with_changes(names=names) updated_names: list[cst.ImportAlias] = [ @@ -49,7 +52,7 @@ def leave_ImportFrom( for name in original_node.names if name.name.value not in DEPRECATED_NAMES ] - self.add_change(original_node, self.CHANGE_DESCRIPTION) + self.add_change(original_node, self.change_description) return ( updated_node.with_changes(names=updated_names) if updated_names diff --git a/src/core_codemods/remove_module_global.py b/src/core_codemods/remove_module_global.py index 1b404093c..54b6ec67c 100644 --- a/src/core_codemods/remove_module_global.py +++ b/src/core_codemods/remove_module_global.py @@ -1,15 +1,18 @@ import libcst as cst from libcst.metadata import GlobalScope, ScopeProvider from typing import Union -from codemodder.codemods.api import BaseCodemod, ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod -class RemoveModuleGlobal(BaseCodemod, NameResolutionMixin): - NAME = "remove-module-global" - SUMMARY = "Remove `global` Usage at Module Level" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Remove `global` usage at module level." +class RemoveModuleGlobal(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="remove-module-global", + summary="Remove `global` Usage at Module Level", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[], + ) + change_description = "Remove `global` usage at module level." REFERENCES: list = [] def leave_Global( diff --git a/src/core_codemods/remove_unnecessary_f_str.py b/src/core_codemods/remove_unnecessary_f_str.py index c07e8246b..d53220507 100644 --- a/src/core_codemods/remove_unnecessary_f_str.py +++ b/src/core_codemods/remove_unnecessary_f_str.py @@ -4,29 +4,35 @@ ) from libcst.codemod.commands.unnecessary_format_string import UnnecessaryFormatString import libcst.matchers as m -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import BaseCodemod + +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class RemoveUnnecessaryFStr(BaseCodemod, UnnecessaryFormatString): - NAME = "remove-unnecessary-f-str" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Remove Unnecessary F-strings" - DESCRIPTION = UnnecessaryFormatString.DESCRIPTION - REFERENCES = [ - { - "url": "https://pylint.readthedocs.io/en/latest/user_guide/messages/warning/f-string-without-interpolation.html", - "description": "", - }, - { - "url": "https://github.com/Instagram/LibCST/blob/main/libcst/codemod/commands/unnecessary_format_string.py", - "description": "", - }, - ] +class RemoveUnnecessaryFStr(SimpleCodemod, UnnecessaryFormatString): + metadata = Metadata( + name="remove-unnecessary-f-str", + summary="Remove Unnecessary F-strings", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://pylint.readthedocs.io/en/latest/user_guide/messages/warning/f-string-without-interpolation.html" + ), + Reference( + url="https://github.com/Instagram/LibCST/blob/main/libcst/codemod/commands/unnecessary_format_string.py" + ), + ], + ) - def __init__(self, codemod_context: CodemodContext, *codemod_args): + def __init__( + self, codemod_context: CodemodContext, *codemod_args, **codemod_kwargs + ): UnnecessaryFormatString.__init__(self, codemod_context) - BaseCodemod.__init__(self, codemod_context, *codemod_args) + SimpleCodemod.__init__(self, codemod_context, *codemod_args, **codemod_kwargs) @m.leave(m.FormattedString(parts=(m.FormattedStringText(),))) def _check_formatted_string( diff --git a/src/core_codemods/remove_unused_imports.py b/src/core_codemods/remove_unused_imports.py index 1b54b1a19..f83b28420 100644 --- a/src/core_codemods/remove_unused_imports.py +++ b/src/core_codemods/remove_unused_imports.py @@ -1,3 +1,6 @@ +import re + +import libcst as cst from libcst import CSTVisitor, ensure_type, matchers from libcst.codemod.visitors import GatherUnusedImportsVisitor from libcst.metadata import ( @@ -6,32 +9,25 @@ ScopeProvider, ParentNodeProvider, ) -from codemodder.codemods.base_codemod import ( - BaseCodemod, - CodemodMetadata, - ReviewGuidance, -) + +from pylint.utils.pragma_parser import parse_pragma + +from core_codemods.api import SimpleCodemod, Metadata, ReviewGuidance from codemodder.change import Change from codemodder.codemods.transformations.remove_unused_imports import ( RemoveUnusedImportsTransformer, ) -import libcst as cst -from libcst.codemod import Codemod, CodemodContext -import re -from pylint.utils.pragma_parser import parse_pragma NOQA_PATTERN = re.compile(r"^#\s*noqa", re.IGNORECASE) -class RemoveUnusedImports(BaseCodemod, Codemod): - METADATA = CodemodMetadata( - DESCRIPTION=("Remove unused imports from a module."), - NAME="unused-imports", - REVIEW_GUIDANCE=ReviewGuidance.MERGE_WITHOUT_REVIEW, - REFERENCES=[], +class RemoveUnusedImports(SimpleCodemod): + metadata = Metadata( + name="unused-imports", + summary="Remove Unused Imports", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, ) - SUMMARY = "Remove Unused Imports" - CHANGE_DESCRIPTION = "Unused import." + change_description = "Unused import." METADATA_DEPENDENCIES = ( PositionProvider, @@ -40,10 +36,6 @@ class RemoveUnusedImports(BaseCodemod, Codemod): ParentNodeProvider, ) - def __init__(self, codemod_context: CodemodContext, *codemod_args): - Codemod.__init__(self, codemod_context) - BaseCodemod.__init__(self, *codemod_args) - def transform_module_impl(self, tree: cst.Module) -> cst.Module: # Do nothing in __init__.py files if self.file_context.file_path.name == "__init__.py": @@ -57,7 +49,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: if self.filter_by_path_includes_or_excludes(pos): if not self._is_disabled_by_linter(importt): self.file_context.codemod_changes.append( - Change(pos.start.line, self.CHANGE_DESCRIPTION) + Change(pos.start.line, self.change_description) ) filtered_unused_imports.add((import_alias, importt)) return tree.visit(RemoveUnusedImportsTransformer(filtered_unused_imports)) diff --git a/src/core_codemods/requests_verify.py b/src/core_codemods/requests_verify.py index 086712650..47380efd6 100644 --- a/src/core_codemods/requests_verify.py +++ b/src/core_codemods/requests_verify.py @@ -1,26 +1,28 @@ -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.api.helpers import NewArg +from codemodder.codemods.libcst_transformer import NewArg +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class RequestsVerify(SemgrepCodemod): - NAME = "requests-verify" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - SUMMARY = "Verify SSL Certificates for Requests." - DESCRIPTION = ( +class RequestsVerify(SimpleCodemod): + metadata = Metadata( + name="requests-verify", + summary="Verify SSL Certificates for Requests.", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference(url="https://requests.readthedocs.io/en/latest/api/"), + Reference( + url="https://owasp.org/www-community/attacks/Manipulator-in-the-middle_attack" + ), + ], + ) + change_description = ( "Makes any calls to requests.{func} with `verify=False` to `verify=True`." ) - REFERENCES = [ - {"url": "https://requests.readthedocs.io/en/latest/api/", "description": ""}, - { - "url": "https://owasp.org/www-community/attacks/Manipulator-in-the-middle_attack", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ + detector_pattern = """ rules: - patterns: - pattern: requests.$F(..., verify=False, ...) diff --git a/src/core_codemods/secure_flask_cookie.py b/src/core_codemods/secure_flask_cookie.py index 82fce6de0..c935c11e6 100644 --- a/src/core_codemods/secure_flask_cookie.py +++ b/src/core_codemods/secure_flask_cookie.py @@ -1,28 +1,29 @@ from libcst import matchers -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.api.helpers import NewArg +from codemodder.codemods.libcst_transformer import NewArg +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class SecureFlaskCookie(SemgrepCodemod): - NAME = "secure-flask-cookie" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - SUMMARY = "Use Safe Parameters in `flask` Response `set_cookie` Call" - DESCRIPTION = "Flask response `set_cookie` call should be called with `secure=True`, `httponly=True`, and `samesite='Lax'`." - REFERENCES = [ - { - "url": "https://flask.palletsprojects.com/en/3.0.x/api/#flask.Response.set_cookie", - "description": "", - }, - { - "url": "https://owasp.org/www-community/controls/SecureCookieAttribute", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ +class SecureFlaskCookie(SimpleCodemod): + metadata = Metadata( + name="secure-flask-cookie", + summary="Use Safe Parameters in `flask` Response `set_cookie` Call", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference( + url="https://flask.palletsprojects.com/en/3.0.x/api/#flask.Response.set_cookie" + ), + Reference( + url="https://owasp.org/www-community/controls/SecureCookieAttribute" + ), + ], + ) + change_description = "Flask response `set_cookie` call should be called with `secure=True`, `httponly=True`, and `samesite='Lax'`." + detector_pattern = """ rules: - id: secure-flask-cookie mode: taint diff --git a/src/core_codemods/secure_flask_session_config.py b/src/core_codemods/secure_flask_session_config.py index 3efba51b5..95c2ae5ec 100644 --- a/src/core_codemods/secure_flask_session_config.py +++ b/src/core_codemods/secure_flask_session_config.py @@ -3,30 +3,34 @@ from libcst.metadata import ParentNodeProvider, PositionProvider from libcst import matchers -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import BaseCodemod from codemodder.codemods.utils_mixin import NameResolutionMixin from codemodder.utils.utils import extract_targets_of_assignment, true_value from codemodder.codemods.base_visitor import BaseTransformer from codemodder.change import Change from codemodder.file_context import FileContext - - -class SecureFlaskSessionConfig(BaseCodemod, Codemod): - NAME = "secure-flask-session-configuration" - SUMMARY = "Flip Insecure `Flask` Session Configurations" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_REVIEW - DESCRIPTION = "Flip Flask session configuration if defined as insecure." - REFERENCES = [ - { - "url": "https://owasp.org/www-community/controls/SecureCookieAttribute", - "description": "", - }, - { - "url": "https://cheatsheetseries.owasp.org/cheatsheets/Session_Management_Cheat_Sheet.html", - "description": "", - }, - ] +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) + + +class SecureFlaskSessionConfig(SimpleCodemod, Codemod): + metadata = Metadata( + name="secure-flask-session-configuration", + summary="Flip Insecure `Flask` Session Configurations", + review_guidance=ReviewGuidance.MERGE_AFTER_REVIEW, + references=[ + Reference( + url="https://owasp.org/www-community/controls/SecureCookieAttribute" + ), + Reference( + url="https://cheatsheetseries.owasp.org/cheatsheets/Session_Management_Cheat_Sheet.html" + ), + ], + ) + change_description = "Flip Flask session configuration if defined as insecure." def transform_module_impl(self, tree: cst.Module) -> cst.Module: flask_codemod = FixFlaskConfig(self.context, self.file_context) @@ -191,5 +195,5 @@ def _is_config_subscript(self, original_node: cst.Assign): def report_change(self, original_node): line_number = self.lineno_for_node(original_node) self.file_context.codemod_changes.append( - Change(line_number, SecureFlaskSessionConfig.CHANGE_DESCRIPTION) + Change(line_number, SecureFlaskSessionConfig.change_description) ) diff --git a/src/core_codemods/secure_random.py b/src/core_codemods/secure_random.py index b1b6944c1..2761ae00e 100644 --- a/src/core_codemods/secure_random.py +++ b/src/core_codemods/secure_random.py @@ -1,30 +1,37 @@ -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod +from core_codemods.api import ( + SimpleCodemod, + Metadata, + Reference, + ReviewGuidance, +) -class SecureRandom(SemgrepCodemod): - NAME = "secure-random" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - SUMMARY = "Secure Source of Randomness" - DESCRIPTION = "Replaces random.{func} with more secure secrets library functions." - REFERENCES = [ - { - "url": "https://owasp.org/www-community/vulnerabilities/Insecure_Randomness", - "description": "", - }, - {"url": "https://docs.python.org/3/library/random.html", "description": ""}, - ] +class SecureRandom(SimpleCodemod): + metadata = Metadata( + name="secure-random", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + summary="Secure Source of Randomness", + references=[ + Reference( + url="https://owasp.org/www-community/vulnerabilities/Insecure_Randomness", + ), + Reference( + url="https://docs.python.org/3/library/random.html", + ), + ], + ) - @classmethod - def rule(cls): - return """ - rules: - - patterns: - - pattern: random.$F(...) - - pattern-inside: | - import random - ... - """ + detector_pattern = """ + - patterns: + - pattern: random.$F(...) + - pattern-inside: | + import random + ... + """ + + change_description = ( + "Replace random.{func} with more secure secrets library functions." + ) def on_result_found(self, original_node, updated_node): self.remove_unused_import(original_node) diff --git a/src/core_codemods/semgrep/sandbox_url_creation.yaml b/src/core_codemods/semgrep/sandbox_url_creation.yaml deleted file mode 100644 index aa793bc2f..000000000 --- a/src/core_codemods/semgrep/sandbox_url_creation.yaml +++ /dev/null @@ -1,13 +0,0 @@ -rules: - - id: url-sandbox - message: Unbounded URL creation - severity: WARNING - languages: - - python - pattern-either: - - patterns: - - pattern: requests.get(...) - - pattern-not: requests.get("...") - - pattern-inside: | - import requests - ... diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index 3619111fb..df115eb14 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -1,6 +1,7 @@ import re from typing import Any, Optional, Tuple import itertools + import libcst as cst from libcst import ( FormattedString, @@ -10,7 +11,6 @@ matchers, ) from libcst.codemod import ( - Codemod, CodemodContext, ContextAwareTransformer, ContextAwareVisitor, @@ -22,13 +22,14 @@ PositionProvider, ScopeProvider, ) -from codemodder.change import Change -from codemodder.codemods.base_codemod import ( - BaseCodemod, - CodemodMetadata, +from core_codemods.api import ( + SimpleCodemod, + Metadata, + Reference, ReviewGuidance, ) +from codemodder.change import Change from codemodder.codemods.base_visitor import UtilsMixin from codemodder.codemods.transformations.remove_empty_string_concatenation import ( RemoveEmptyStringConcatenation, @@ -41,7 +42,6 @@ infer_expression_type, ) from codemodder.codemods.utils_mixin import NameResolutionMixin -from codemodder.file_context import FileContext parameter_token = "?" @@ -49,24 +49,17 @@ raw_quote_pattern = re.compile(r"(? None: self.changed_nodes: dict[ cst.CSTNode, cst.CSTNode | cst.RemovalSentinel | cst.FlattenSentinel | dict[str, Any], ] = {} - BaseCodemod.__init__(self, file_context, *codemod_args) + SimpleCodemod.__init__(self, *codemod_args, **codemod_kwargs) UtilsMixin.__init__(self, []) - Codemod.__init__(self, context) def _build_param_element(self, prepend, middle, append): new_middle = ( @@ -166,7 +157,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: self.changed_nodes = {} line_number = self.get_metadata(PositionProvider, call).start.line self.file_context.codemod_changes.append( - Change(line_number, SQLQueryParameterization.CHANGE_DESCRIPTION) + Change(line_number, SQLQueryParameterization.change_description) ) # Normalization and cleanup result = result.visit(RemoveEmptyStringConcatenation()) diff --git a/src/core_codemods/subprocess_shell_false.py b/src/core_codemods/subprocess_shell_false.py index 4286d5abf..312d1eb7f 100644 --- a/src/core_codemods/subprocess_shell_false.py +++ b/src/core_codemods/subprocess_shell_false.py @@ -1,29 +1,31 @@ import libcst as cst from libcst import matchers -from codemodder.codemods.api import BaseCodemod, ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin -from codemodder.codemods.api.helpers import NewArg +from codemodder.codemods.libcst_transformer import NewArg +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class SubprocessShellFalse(BaseCodemod, NameResolutionMixin): - NAME = "subprocess-shell-false" - SUMMARY = "Use `shell=False` in `subprocess` Function Calls" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - DESCRIPTION = "Set `shell` keyword argument to `False`" - REFERENCES = [ - { - "url": "https://docs.python.org/3/library/subprocess.html#security-considerations", - "description": "", - }, - { - "url": "https://en.wikipedia.org/wiki/Code_injection#Shell_injection", - "description": "", - }, - { - "url": "https://stackoverflow.com/a/3172488", - "description": "", - }, - ] +class SubprocessShellFalse(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="subprocess-shell-false", + summary="Use `shell=False` in `subprocess` Function Calls", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/library/subprocess.html#security-considerations" + ), + Reference( + url="https://en.wikipedia.org/wiki/Code_injection#Shell_injection" + ), + Reference(url="https://stackoverflow.com/a/3172488"), + ], + ) + change_description = "Set `shell` keyword argument to `False`" SUBPROCESS_FUNCS = [ f"subprocess.{func}" for func in {"run", "call", "check_output", "check_call", "Popen"} diff --git a/src/core_codemods/tempfile_mktemp.py b/src/core_codemods/tempfile_mktemp.py index 8de39ac1f..c5a00fdb5 100644 --- a/src/core_codemods/tempfile_mktemp.py +++ b/src/core_codemods/tempfile_mktemp.py @@ -1,25 +1,27 @@ -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class TempfileMktemp(SemgrepCodemod, NameResolutionMixin): - NAME = "secure-tempfile" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Upgrade and Secure Temp File Creation" - DESCRIPTION = "Replaces `tempfile.mktemp` with `tempfile.mkstemp`." - REFERENCES = [ - { - "url": "https://docs.python.org/3/library/tempfile.html#tempfile.mktemp", - "description": "", - } - ] +class TempfileMktemp(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="secure-tempfile", + summary="Upgrade and Secure Temp File Creation", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/library/tempfile.html#tempfile.mktemp" + ), + ], + ) + change_description = "Replaces `tempfile.mktemp` with `tempfile.mkstemp`." _module_name = "tempfile" - - @classmethod - def rule(cls): - return """ + detector_pattern = """ rules: - patterns: - pattern: tempfile.mktemp(...) diff --git a/src/core_codemods/upgrade_sslcontext_minimum_version.py b/src/core_codemods/upgrade_sslcontext_minimum_version.py index 85ee1a850..cddf1dcdd 100644 --- a/src/core_codemods/upgrade_sslcontext_minimum_version.py +++ b/src/core_codemods/upgrade_sslcontext_minimum_version.py @@ -1,30 +1,29 @@ -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class UpgradeSSLContextMinimumVersion(SemgrepCodemod, NameResolutionMixin): - NAME = "upgrade-sslcontext-minimum-version" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Upgrade SSLContext Minimum Version" - DESCRIPTION = "Replaces minimum SSL/TLS version for SSLContext." - REFERENCES = [ - { - "url": "https://docs.python.org/3/library/ssl.html#security-considerations", - "description": "", - }, - {"url": "https://datatracker.ietf.org/doc/rfc8996/", "description": ""}, - { - "url": "https://www.digicert.com/blog/depreciating-tls-1-0-and-1-1", - "description": "", - }, - ] +class UpgradeSSLContextMinimumVersion(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="upgrade-sslcontext-minimum-version", + summary="Upgrade SSLContext Minimum Version", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/library/ssl.html#security-considerations" + ), + Reference(url="https://datatracker.ietf.org/doc/rfc8996/"), + Reference(url="https://www.digicert.com/blog/depreciating-tls-1-0-and-1-1"), + ], + ) + change_description = "Replaces minimum SSL/TLS version for SSLContext." _module_name = "ssl" - - @classmethod - def rule(cls): - return """ + detector_pattern = """ rules: - mode: taint pattern-sources: diff --git a/src/core_codemods/upgrade_sslcontext_tls.py b/src/core_codemods/upgrade_sslcontext_tls.py index 12187838c..b3f50d332 100644 --- a/src/core_codemods/upgrade_sslcontext_tls.py +++ b/src/core_codemods/upgrade_sslcontext_tls.py @@ -1,25 +1,27 @@ -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.api.helpers import NewArg +from codemodder.codemods.libcst_transformer import NewArg +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class UpgradeSSLContextTLS(SemgrepCodemod): - NAME = "upgrade-sslcontext-tls" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - SUMMARY = "Upgrade TLS Version In SSLContext" - DESCRIPTION = "Replaces known insecure TLS/SSL protocol versions in SSLContext with secure ones." - CHANGE_DESCRIPTION = "Upgrade to use a safe version of TLS in SSLContext" - REFERENCES = [ - { - "url": "https://docs.python.org/3/library/ssl.html#security-considerations", - "description": "", - }, - {"url": "https://datatracker.ietf.org/doc/rfc8996/", "description": ""}, - { - "url": "https://www.digicert.com/blog/depreciating-tls-1-0-and-1-1", - "description": "", - }, - ] +class UpgradeSSLContextTLS(SimpleCodemod): + metadata = Metadata( + name="upgrade-sslcontext-tls", + summary="Upgrade TLS Version In SSLContext", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/library/ssl.html#security-considerations" + ), + Reference(url="https://datatracker.ietf.org/doc/rfc8996/"), + Reference(url="https://www.digicert.com/blog/depreciating-tls-1-0-and-1-1"), + ], + ) + change_description = "Replaces known insecure TLS/SSL protocol versions in SSLContext with secure ones." + change_description = "Upgrade to use a safe version of TLS in SSLContext" # TODO: in the majority of cases, using PROTOCOL_TLS_CLIENT will be the # right fix. However in some cases it will be appropriate to use @@ -27,10 +29,7 @@ class UpgradeSSLContextTLS(SemgrepCodemod): # this. Eventually, when the platform supports parameters, we want to # revisit this to provide PROTOCOL_TLS_SERVER as an alternative fix. SAFE_TLS_PROTOCOL_VERSION = "ssl.PROTOCOL_TLS_CLIENT" - - @classmethod - def rule(cls): - return """ + detector_pattern = """ rules: - patterns: - pattern-inside: | diff --git a/src/core_codemods/url_sandbox.py b/src/core_codemods/url_sandbox.py index 85213384f..2893af4a1 100644 --- a/src/core_codemods/url_sandbox.py +++ b/src/core_codemods/url_sandbox.py @@ -2,14 +2,15 @@ import libcst as cst from libcst import CSTNode, matchers -from libcst.codemod import Codemod, CodemodContext +from libcst.codemod import CodemodContext from libcst.metadata import PositionProvider, ScopeProvider from libcst.codemod.visitors import AddImportsVisitor, ImportItem from codemodder.change import Change -from codemodder.codemods.base_codemod import ( - SemgrepCodemod, - CodemodMetadata, +from core_codemods.api import ( + SimpleCodemod, + Metadata, + Reference, ReviewGuidance, ) from codemodder.codemods.base_visitor import BaseVisitor @@ -24,47 +25,47 @@ replacement_import = "safe_requests" -class UrlSandbox(SemgrepCodemod, Codemod): - METADATA = CodemodMetadata( - DESCRIPTION=( - "Replaces request.{func} with more secure safe_request library functions." - ), - NAME="url-sandbox", - REVIEW_GUIDANCE=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, - REFERENCES=[ - { - "url": "https://github.com/pixee/python-security/blob/main/src/security/safe_requests/api.py", - "description": "", - }, - {"url": "https://portswigger.net/web-security/ssrf", "description": ""}, - { - "url": "https://cheatsheetseries.owasp.org/cheatsheets/Server_Side_Request_Forgery_Prevention_Cheat_Sheet.html", - "description": "", - }, - { - "url": "https://www.rapid7.com/blog/post/2021/11/23/owasp-top-10-deep-dive-defending-against-server-side-request-forgery/", - "description": "", - }, - { - "url": "https://blog.assetnote.io/2021/01/13/blind-ssrf-chains/", - "description": "", - }, +class UrlSandbox(SimpleCodemod): + metadata = Metadata( + name="url-sandbox", + summary="Sandbox URL Creation", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference( + url="https://github.com/pixee/python-security/blob/main/src/security/safe_requests/api.py" + ), + Reference(url="https://portswigger.net/web-security/ssrf"), + Reference( + url="https://cheatsheetseries.owasp.org/cheatsheets/Server_Side_Request_Forgery_Prevention_Cheat_Sheet.html" + ), + Reference( + url="https://www.rapid7.com/blog/post/2021/11/23/owasp-top-10-deep-dive-defending-against-server-side-request-forgery/" + ), + Reference(url="https://blog.assetnote.io/2021/01/13/blind-ssrf-chains/"), ], ) - SUMMARY = "Sandbox URL Creation" - CHANGE_DESCRIPTION = "Switch use of requests for security.safe_requests" - YAML_FILES = [ - "sandbox_url_creation.yaml", - ] + change_description = "Switch use of requests for security.safe_requests" + + detector_pattern = """ + rules: + - id: url-sandbox + message: Unbounded URL creation + severity: WARNING + languages: + - python + pattern-either: + - patterns: + - pattern: requests.get(...) + - pattern-not: requests.get("...") + - pattern-inside: | + import requests + ... + """ METADATA_DEPENDENCIES = (PositionProvider, ScopeProvider) adds_dependency = True - def __init__(self, codemod_context: CodemodContext, *args): - Codemod.__init__(self, codemod_context) - SemgrepCodemod.__init__(self, *args) - def transform_module_impl(self, tree: cst.Module) -> cst.Module: # we first gather all the nodes we want to change together with their replacements find_requests_visitor = FindRequestCallsAndImports( @@ -142,7 +143,7 @@ def leave_Call(self, original_node: cst.Call): } ) self.changes_in_file.append( - Change(line_number, UrlSandbox.CHANGE_DESCRIPTION) + Change(line_number, UrlSandbox.change_description) ) # case req.get(...) @@ -156,7 +157,7 @@ def leave_Call(self, original_node: cst.Call): } ) self.changes_in_file.append( - Change(line_number, UrlSandbox.CHANGE_DESCRIPTION) + Change(line_number, UrlSandbox.change_description) ) def _find_assignments(self, node: CSTNode): diff --git a/src/core_codemods/use_defused_xml.py b/src/core_codemods/use_defused_xml.py index 498ec3e29..883d7fb28 100644 --- a/src/core_codemods/use_defused_xml.py +++ b/src/core_codemods/use_defused_xml.py @@ -3,9 +3,12 @@ import libcst as cst from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor - -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import BaseCodemod +from core_codemods.api import ( + SimpleCodemod, + Metadata, + Reference, + ReviewGuidance, +) from codemodder.codemods.imported_call_modifier import ImportedCallModifier from codemodder.dependency import DefusedXML @@ -42,31 +45,26 @@ def update_simple_name(self, true_name, original_node, updated_node, new_args): # TODO: add expat methods? -class UseDefusedXml(BaseCodemod): - NAME = "use-defusedxml" - SUMMARY = "Use `defusedxml` for Parsing XML" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_REVIEW - DESCRIPTION = "Replace builtin xml method with safe defusedxml method" - REFERENCES = [ - { - "url": "https://docs.python.org/3/library/xml.html#xml-vulnerabilities", - "description": "", - }, - { - "url": "https://docs.python.org/3/library/xml.html#the-defusedxml-package", - "description": "", - }, - { - "url": "https://pypi.org/project/defusedxml/", - "description": "", - }, - { - "url": "https://cheatsheetseries.owasp.org/cheatsheets/XML_External_Entity_Prevention_Cheat_Sheet.html", - "description": "", - }, - ] +class UseDefusedXml(SimpleCodemod): + metadata = Metadata( + name="use-defusedxml", + summary="Use `defusedxml` for Parsing XML", + review_guidance=ReviewGuidance.MERGE_AFTER_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/library/xml.html#xml-vulnerabilities" + ), + Reference( + url="https://docs.python.org/3/library/xml.html#the-defusedxml-package" + ), + Reference(url="https://pypi.org/project/defusedxml/"), + Reference( + url="https://cheatsheetseries.owasp.org/cheatsheets/XML_External_Entity_Prevention_Cheat_Sheet.html" + ), + ], + ) - adds_dependency = True + change_description = "Replace builtin XML method with safe `defusedxml` method" @cached_property def matching_functions(self) -> dict[str, str]: @@ -89,7 +87,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: self.context, self.file_context, self.matching_functions, - self.CHANGE_DESCRIPTION, + self.change_description, ) result_tree = visitor.transform_module(tree) self.file_context.codemod_changes.extend(visitor.changes_in_file) diff --git a/src/core_codemods/use_generator.py b/src/core_codemods/use_generator.py index 0a507db42..726db7ac1 100644 --- a/src/core_codemods/use_generator.py +++ b/src/core_codemods/use_generator.py @@ -1,28 +1,31 @@ import libcst as cst - -from codemodder.codemods.api import BaseCodemod, ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class UseGenerator(BaseCodemod, NameResolutionMixin): - NAME = "use-generator" - SUMMARY = "Use Generator Expressions Instead of List Comprehensions" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Replace list comprehension with generator expression" - REFERENCES = [ - { - "url": "https://pylint.readthedocs.io/en/latest/user_guide/messages/refactor/use-a-generator.html", - "description": "", - }, - { - "url": "https://docs.python.org/3/glossary.html#term-generator-expression", - "description": "", - }, - { - "url": "https://docs.python.org/3/glossary.html#term-list-comprehension", - "description": "", - }, - ] +class UseGenerator(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="use-generator", + summary="Use Generator Expressions Instead of List Comprehensions", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://pylint.readthedocs.io/en/latest/user_guide/messages/refactor/use-a-generator.html" + ), + Reference( + url="https://docs.python.org/3/glossary.html#term-generator-expression" + ), + Reference( + url="https://docs.python.org/3/glossary.html#term-list-comprehension" + ), + ], + ) + change_description = "Replace list comprehension with generator expression" def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): match original_node.func: @@ -32,7 +35,7 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): if self.is_builtin_function(original_node): match original_node.args[0].value: case cst.ListComp(elt=elt, for_in=for_in): - self.add_change(original_node, self.CHANGE_DESCRIPTION) + self.add_change(original_node, self.change_description) return updated_node.with_changes( args=[ cst.Arg( diff --git a/src/core_codemods/use_set_literal.py b/src/core_codemods/use_set_literal.py index f4ef023a7..eb1dfbdcd 100644 --- a/src/core_codemods/use_set_literal.py +++ b/src/core_codemods/use_set_literal.py @@ -1,14 +1,16 @@ import libcst as cst - -from codemodder.codemods.api import BaseCodemod, ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod -class UseSetLiteral(BaseCodemod, NameResolutionMixin): - NAME = "use-set-literal" - SUMMARY = "Use Set Literals Instead of Sets from Lists" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Replace sets from lists with set literals" +class UseSetLiteral(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="use-set-literal", + summary="Use Set Literals Instead of Sets from Lists", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[], + ) + change_description = "Replace sets from lists with set literals" REFERENCES: list = [] def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): diff --git a/src/core_codemods/use_walrus_if.py b/src/core_codemods/use_walrus_if.py index a66988fb5..df5eeab52 100644 --- a/src/core_codemods/use_walrus_if.py +++ b/src/core_codemods/use_walrus_if.py @@ -6,9 +6,12 @@ from libcst._position import CodeRange from libcst import matchers as m from libcst.metadata import ParentNodeProvider, ScopeProvider - -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import BaseCodemod +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) FoundAssign = namedtuple("FoundAssign", ["assign", "target", "value"]) @@ -20,23 +23,24 @@ def pairwise(iterable): return zip(a, b) -class UseWalrusIf(BaseCodemod): - METADATA_DEPENDENCIES = BaseCodemod.METADATA_DEPENDENCIES + ( - ParentNodeProvider, - ScopeProvider, +class UseWalrusIf(SimpleCodemod): + metadata = Metadata( + name="use-walrus-if", + summary="Use Assignment Expression (Walrus) In Conditional", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/whatsnew/3.8.html#assignment-expressions" + ), + ], ) - NAME = "use-walrus-if" - SUMMARY = "Use Assignment Expression (Walrus) In Conditional" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - DESCRIPTION = ( + change_description = ( "Replaces multiple expressions involving `if` operator with 'walrus' operator." ) - REFERENCES = [ - { - "url": "https://docs.python.org/3/whatsnew/3.8.html#assignment-expressions", - "description": "", - } - ] + METADATA_DEPENDENCIES = SimpleCodemod.METADATA_DEPENDENCIES + ( + ParentNodeProvider, + ScopeProvider, + ) _modify_next_if: List[Tuple[CodeRange, cst.NamedExpr]] _if_stack: List[Optional[Tuple[CodeRange, cst.NamedExpr]]] @@ -121,7 +125,7 @@ def leave_If(self, original_node, updated_node): if (result := self._if_stack.pop()) is not None: position, named_expr = result is_name = m.matches(updated_node.test, m.Name()) - self.add_change_from_position(position, self.CHANGE_DESCRIPTION) + self.add_change_from_position(position, self.change_description) return ( updated_node.with_changes(test=named_expr) if is_name diff --git a/src/core_codemods/with_threading_lock.py b/src/core_codemods/with_threading_lock.py index e7ad02b48..32a51f159 100644 --- a/src/core_codemods/with_threading_lock.py +++ b/src/core_codemods/with_threading_lock.py @@ -1,30 +1,31 @@ import libcst as cst -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class WithThreadingLock(SemgrepCodemod, NameResolutionMixin): - NAME = "bad-lock-with-statement" - SUMMARY = "Separate Lock Instantiation from `with` Call" - DESCRIPTION = ( +class WithThreadingLock(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="bad-lock-with-statement", + summary="Separate Lock Instantiation from `with` Call", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://pylint.pycqa.org/en/latest/user_guide/messages/warning/useless-with-lock." + ), + Reference( + url="https://docs.python.org/3/library/threading.html#using-locks-conditions-and-semaphores-in-the-with-statement" + ), + ], + ) + change_description = ( "Replace deprecated usage of threading lock classes as context managers." ) - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - REFERENCES = [ - { - "url": "https://pylint.pycqa.org/en/latest/user_guide/messages/warning/useless-with-lock.", - "description": "", - }, - { - "url": "https://docs.python.org/3/library/threading.html#using-locks-conditions-and-semaphores-in-the-with-statement", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ + detector_pattern = """ rules: - patterns: - pattern: | @@ -45,8 +46,8 @@ def rule(cls): - focus-metavariable: $BODY """ - def __init__(self, *args): - SemgrepCodemod.__init__(self, *args) + def __init__(self, *args, **kwargs): + SimpleCodemod.__init__(self, *args, **kwargs) NameResolutionMixin.__init__(self) self.names_in_module = self.find_used_names_in_module() diff --git a/tests/codemods/base_codemod_test.py b/tests/codemods/base_codemod_test.py index 221aaa473..b402c4e5f 100644 --- a/tests/codemods/base_codemod_test.py +++ b/tests/codemods/base_codemod_test.py @@ -4,13 +4,10 @@ from textwrap import dedent from typing import ClassVar -import libcst as cst -from libcst.codemod import CodemodContext import mock from codemodder.context import CodemodExecutionContext -from codemodder.dependency import Dependency -from codemodder.file_context import FileContext +from codemodder.diff import create_diff from codemodder.registry import CodemodRegistry, CodemodCollection from codemodder.semgrep import run as semgrep_run @@ -19,43 +16,84 @@ class BaseCodemodTest: codemod: ClassVar = NotImplemented def setup_method(self): - self.file_context = None + if isinstance(self.codemod, type): + self.codemod = self.codemod() - def initialize_codemod(self, input_tree): - wrapper = cst.MetadataWrapper(input_tree) - codemod_instance = self.codemod( - CodemodContext(wrapper=wrapper), - self.file_context, - ) - return codemod_instance + self.file_context = None - def run_and_assert(self, tmpdir, input_code, expected): - tmp_file_path = Path(tmpdir / "code.py") - self.run_and_assert_filepath(tmpdir, tmp_file_path, input_code, expected) + def run_and_assert( # pylint: disable=too-many-arguments + self, + tmpdir, + input_code, + expected, + num_changes: int = 1, + root: Path | None = None, + files: list[Path] | None = None, + ): + root = root or tmpdir + tmp_file_path = files[0] if files else Path(tmpdir) / "code.py" + tmp_file_path.write_text(dedent(input_code)) + + files_to_check = files or [tmp_file_path] - def run_and_assert_filepath(self, root, file_path, input_code, expected): - input_tree = cst.parse_module(dedent(input_code)) self.execution_context = CodemodExecutionContext( directory=root, - dry_run=True, + dry_run=False, verbose=False, registry=mock.MagicMock(), repo_manager=mock.MagicMock(), + path_include=[f.name for f in files_to_check], + path_exclude=[], ) - self.file_context = FileContext( - root, - file_path, - [], - [], - [], + + self.codemod.apply(self.execution_context, files_to_check) + changes = self.execution_context.get_results(self.codemod.id) + + if input_code == expected: + assert not changes + return + + assert len(changes) == num_changes + + self.assert_changes( + tmpdir, + tmp_file_path, + input_code, + expected, + changes[0], ) - codemod_instance = self.initialize_codemod(input_tree) - output_tree = codemod_instance.transform_module(input_tree) - assert output_tree.code == dedent(expected) + def assert_changes( # pylint: disable=too-many-arguments + self, root, file_path, input_code, expected, changes + ): + expected_diff = create_diff( + dedent(input_code).splitlines(keepends=True), + dedent(expected).splitlines(keepends=True), + ) - def assert_dependency(self, dependency: Dependency): - assert self.file_context and self.file_context.dependencies == set([dependency]) + assert expected_diff == changes.diff + assert os.path.relpath(file_path, root) == changes.path + + with open(file_path, "r", encoding="utf-8") as tmp_file: + output_code = tmp_file.read() + + assert output_code == dedent(expected) + + def run_and_assert_filepath( # pylint: disable=too-many-arguments + self, + root: Path, + file_path: Path, + input_code: str, + expected: str, + num_changes: int = 1, + ): + self.run_and_assert( + tmpdir=root, + input_code=input_code, + expected=expected, + num_changes=num_changes, + files=[file_path], + ) class BaseSemgrepCodemodTest(BaseCodemodTest): @@ -74,35 +112,12 @@ def results_by_id_filepath(self, input_code, file_path): with open(file_path, "w", encoding="utf-8") as tmp_file: tmp_file.write(dedent(input_code)) - name = self.codemod.name() + name = self.codemod.name results = self.registry.match_codemods(codemod_include=[name]) return semgrep_run(self.execution_context, results[0].yaml_files) - def run_and_assert_filepath(self, root, file_path, input_code, expected): - self.execution_context = CodemodExecutionContext( - directory=root, - dry_run=True, - verbose=False, - registry=mock.MagicMock(), - repo_manager=mock.MagicMock(), - ) - input_tree = cst.parse_module(dedent(input_code)) - all_results = self.results_by_id_filepath(input_code, file_path) - results = all_results.results_for_rule_and_file(self.codemod.name(), file_path) - self.file_context = FileContext( - root, - file_path, - [], - [], - results, - ) - codemod_instance = self.initialize_codemod(input_tree) - output_tree = codemod_instance.transform_module(input_tree) - - assert output_tree.code == dedent(expected) - -class BaseDjangoCodemodTest(BaseSemgrepCodemodTest): +class BaseDjangoCodemodTest(BaseCodemodTest): def create_dir_structure(self, tmpdir): django_root = Path(tmpdir) / "mysite" settings_folder = django_root / "mysite" diff --git a/tests/codemods/conftest.py b/tests/codemods/conftest.py new file mode 100644 index 000000000..6aab38f60 --- /dev/null +++ b/tests/codemods/conftest.py @@ -0,0 +1,15 @@ +import pytest + + +@pytest.fixture(autouse=True) +def disable_semgrep_run(): + """ + Override the fixture defined in conftest.py + """ + + +@pytest.fixture(autouse=True) +def disable_update_code(): + """ + Override the fixture defined in conftest.py + """ diff --git a/tests/codemods/test_add_requests_timeouts.py b/tests/codemods/test_add_requests_timeouts.py index 196070f25..9382de8c6 100644 --- a/tests/codemods/test_add_requests_timeouts.py +++ b/tests/codemods/test_add_requests_timeouts.py @@ -1,10 +1,13 @@ import pytest -from core_codemods.add_requests_timeouts import AddRequestsTimeouts +from core_codemods.add_requests_timeouts import ( + AddRequestsTimeouts, + TransformAddRequestsTimeouts, +) from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest METHODS = ["get", "post", "put", "delete", "head", "options", "patch"] -TIMEOUT = AddRequestsTimeouts.DEFAULT_TIMEOUT +TIMEOUT = TransformAddRequestsTimeouts.DEFAULT_TIMEOUT class TestAddRequestsTimeouts(BaseSemgrepCodemodTest): diff --git a/tests/codemods/test_base_codemod.py b/tests/codemods/test_base_codemod.py index 033e8950d..662f7e0b3 100644 --- a/tests/codemods/test_base_codemod.py +++ b/tests/codemods/test_base_codemod.py @@ -1,26 +1,20 @@ import libcst as cst -from libcst.codemod import Codemod, CodemodContext +from libcst.codemod import CodemodContext import mock -from codemodder.codemods.base_codemod import ( - SemgrepCodemod, - CodemodMetadata, +from codemodder.codemods.api import ( + SimpleCodemod, + Metadata, ReviewGuidance, ) -class DoNothingCodemod(SemgrepCodemod, Codemod): - METADATA = CodemodMetadata( - DESCRIPTION="An identity codemod for testing purposes.", - NAME="do-nothing", - REVIEW_GUIDANCE=ReviewGuidance.MERGE_WITHOUT_REVIEW, +class DoNothingCodemod(SimpleCodemod): + metadata = Metadata( + name="do-nothing", + summary="An identity codemod for testing purposes.", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, ) - SUMMARY = "An identity codemod for testing purposes." - YAML_FILES = [] - - def __init__(self, codemod_context: CodemodContext, *args): - Codemod.__init__(self, codemod_context) - SemgrepCodemod.__init__(self, *args) def transform_module_impl(self, tree: cst.Module) -> cst.Module: return tree @@ -32,6 +26,8 @@ def run_and_assert(self, input_code, expected_output): command_instance = DoNothingCodemod( CodemodContext(), mock.MagicMock(), + mock.MagicMock(), + _transformer=True, ) output_tree = command_instance.transform_module(input_tree) diff --git a/tests/codemods/test_combine_startswith_endswith.py b/tests/codemods/test_combine_startswith_endswith.py index 19305f9d3..6d1baa86e 100644 --- a/tests/codemods/test_combine_startswith_endswith.py +++ b/tests/codemods/test_combine_startswith_endswith.py @@ -9,7 +9,7 @@ class TestCombineStartswithEndswith(BaseCodemodTest): codemod = CombineStartswithEndswith def test_name(self): - assert self.codemod.name() == "combine-startswith-endswith" + assert self.codemod.name == "combine-startswith-endswith" @each_func def test_combine(self, tmpdir, func): @@ -22,7 +22,6 @@ def test_combine(self, tmpdir, func): x.{func}(("foo", "f")) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 @pytest.mark.parametrize( "code", @@ -37,4 +36,3 @@ def test_combine(self, tmpdir, func): ) def test_no_change(self, tmpdir, code): self.run_and_assert(tmpdir, code, code) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_django_debug_flag_on.py b/tests/codemods/test_django_debug_flag_on.py index b66b603f4..c539665d4 100644 --- a/tests/codemods/test_django_debug_flag_on.py +++ b/tests/codemods/test_django_debug_flag_on.py @@ -6,7 +6,7 @@ class TestDjangoDebugFlagOn(BaseDjangoCodemodTest): codemod = DjangoDebugFlagOn def test_name(self): - assert self.codemod.name() == "django-debug-flag-on" + assert self.codemod.name == "django-debug-flag-on" def test_settings_dot_py(self, tmpdir): django_root, settings_folder = self.create_dir_structure(tmpdir) @@ -15,7 +15,6 @@ def test_settings_dot_py(self, tmpdir): input_code = """DEBUG = True""" expected = """DEBUG = False""" self.run_and_assert_filepath(django_root, file_path, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_not_settings_dot_py(self, tmpdir): django_root, settings_folder = self.create_dir_structure(tmpdir) @@ -24,7 +23,6 @@ def test_not_settings_dot_py(self, tmpdir): input_code = """DEBUG = True""" expected = input_code self.run_and_assert_filepath(django_root, file_path, input_code, expected) - assert len(self.file_context.codemod_changes) == 0 def test_no_manage_dot_py(self, tmpdir): django_root, settings_folder = self.create_dir_structure(tmpdir) @@ -32,4 +30,3 @@ def test_no_manage_dot_py(self, tmpdir): input_code = """DEBUG = True""" expected = input_code self.run_and_assert_filepath(django_root, file_path, input_code, expected) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_django_json_response_type.py b/tests/codemods/test_django_json_response_type.py index 294870009..afb99f5de 100644 --- a/tests/codemods/test_django_json_response_type.py +++ b/tests/codemods/test_django_json_response_type.py @@ -7,7 +7,7 @@ class TestDjangoJsonResponseType(BaseSemgrepCodemodTest): codemod = DjangoJsonResponseType def test_name(self): - assert self.codemod.name() == "django-json-response-type" + assert self.codemod.name == "django-json-response-type" def test_simple(self, tmpdir): input_code = """\ @@ -27,7 +27,6 @@ def foo(request): return HttpResponse(json_response, content_type="application/json") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_alias(self, tmpdir): input_code = """\ @@ -47,7 +46,6 @@ def foo(request): return response(json_response, content_type="application/json") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_direct(self, tmpdir): input_code = """\ @@ -65,7 +63,6 @@ def foo(request): return HttpResponse(json.dumps({ "user_input": request.GET.get("input") }), content_type="application/json") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_content_type_set(self, tmpdir): input_code = """\ @@ -77,7 +74,6 @@ def foo(request): return HttpResponse(json_response, content_type='application/json') """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_no_json_input(self, tmpdir): input_code = """\ @@ -89,4 +85,3 @@ def foo(request): return HttpResponse(dict_response) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_django_receiver_on_top.py b/tests/codemods/test_django_receiver_on_top.py index 95042ce49..e83f6468c 100644 --- a/tests/codemods/test_django_receiver_on_top.py +++ b/tests/codemods/test_django_receiver_on_top.py @@ -7,7 +7,7 @@ class TestDjangoReceiverOnTop(BaseCodemodTest): codemod = DjangoReceiverOnTop def test_name(self): - assert self.codemod.name() == "django-receiver-on-top" + assert self.codemod.name == "django-receiver-on-top" def test_simple(self, tmpdir): input_code = """\ @@ -27,7 +27,6 @@ def foo(): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_alias(self, tmpdir): input_code = """\ @@ -47,7 +46,6 @@ def foo(): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_no_receiver(self, tmpdir): input_code = """\ @@ -56,7 +54,6 @@ def foo(): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_receiver_but_not_djangos(self, tmpdir): input_code = """\ @@ -68,7 +65,6 @@ def foo(): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_receiver_on_top(self, tmpdir): input_code = """\ @@ -80,4 +76,3 @@ def foo(): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_django_session_cookie_secure_off.py b/tests/codemods/test_django_session_cookie_secure_off.py index 821c24685..62b1cf453 100644 --- a/tests/codemods/test_django_session_cookie_secure_off.py +++ b/tests/codemods/test_django_session_cookie_secure_off.py @@ -9,7 +9,7 @@ class TestDjangoSessionSecureCookieOff(BaseDjangoCodemodTest): codemod = DjangoSessionCookieSecureOff def test_name(self): - assert self.codemod.name() == "django-session-cookie-secure-off" + assert self.codemod.name == "django-session-cookie-secure-off" def test_not_settings_dot_py(self, tmpdir): django_root, settings_folder = self.create_dir_structure(tmpdir) @@ -18,7 +18,6 @@ def test_not_settings_dot_py(self, tmpdir): input_code = """SESSION_COOKIE_SECURE = True""" expected = input_code self.run_and_assert_filepath(django_root, file_path, input_code, expected) - assert len(self.file_context.codemod_changes) == 0 def test_no_manage_dot_py(self, tmpdir): django_root, settings_folder = self.create_dir_structure(tmpdir) @@ -26,7 +25,6 @@ def test_no_manage_dot_py(self, tmpdir): input_code = """SESSION_COOKIE_SECURE = True""" expected = input_code self.run_and_assert_filepath(django_root, file_path, input_code, expected) - assert len(self.file_context.codemod_changes) == 0 def test_settings_dot_py_secure_true(self, tmpdir): django_root, settings_folder = self.create_dir_structure(tmpdir) @@ -36,7 +34,6 @@ def test_settings_dot_py_secure_true(self, tmpdir): SESSION_COOKIE_SECURE = True """ self.run_and_assert_filepath(django_root, file_path, input_code, input_code) - assert len(self.file_context.codemod_changes) == 0 @pytest.mark.parametrize("value", ["False", "gibberish"]) def test_settings_dot_py_secure_bad(self, tmpdir, value): @@ -50,7 +47,6 @@ def test_settings_dot_py_secure_bad(self, tmpdir, value): SESSION_COOKIE_SECURE = True """ self.run_and_assert_filepath(django_root, file_path, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_settings_dot_py_secure_missing(self, tmpdir): django_root, settings_folder = self.create_dir_structure(tmpdir) @@ -62,4 +58,3 @@ def test_settings_dot_py_secure_missing(self, tmpdir): SESSION_COOKIE_SECURE = True """ self.run_and_assert_filepath(django_root, file_path, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 diff --git a/tests/codemods/test_enable_jinja2_autoescape.py b/tests/codemods/test_enable_jinja2_autoescape.py index e48a36409..9d0eafe4e 100644 --- a/tests/codemods/test_enable_jinja2_autoescape.py +++ b/tests/codemods/test_enable_jinja2_autoescape.py @@ -6,7 +6,7 @@ class TestEnableJinja2Autoescape(BaseSemgrepCodemodTest): codemod = EnableJinja2Autoescape def test_name(self): - assert self.codemod.name() == "enable-jinja2-autoescape" + assert self.codemod.name == "enable-jinja2-autoescape" def test_import(self, tmpdir): input_code = """ diff --git a/tests/codemods/test_exception_without_raise.py b/tests/codemods/test_exception_without_raise.py index 6ec37df94..bc4b727c7 100644 --- a/tests/codemods/test_exception_without_raise.py +++ b/tests/codemods/test_exception_without_raise.py @@ -7,7 +7,7 @@ class TestExceptionWithoutRaise(BaseCodemodTest): codemod = ExceptionWithoutRaise def test_name(self): - assert self.codemod.name() == "exception-without-raise" + assert self.codemod.name == "exception-without-raise" def test_simple(self, tmpdir): input_code = """\ @@ -17,7 +17,6 @@ def test_simple(self, tmpdir): raise ValueError """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_call(self, tmpdir): input_code = """\ @@ -27,7 +26,6 @@ def test_simple_call(self, tmpdir): raise ValueError("Bad value!") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_alias(self, tmpdir): input_code = """\ @@ -39,18 +37,15 @@ def test_alias(self, tmpdir): raise error """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_unknown_exception(self, tmpdir): input_code = """\ Something """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_raised_exception(self, tmpdir): input_code = """\ raise ValueError """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_file_resource_leak.py b/tests/codemods/test_file_resource_leak.py index b67a81f2c..e3c6835e0 100644 --- a/tests/codemods/test_file_resource_leak.py +++ b/tests/codemods/test_file_resource_leak.py @@ -7,7 +7,7 @@ class TestFileResourceLeak(BaseCodemodTest): codemod = FileResourceLeak def test_name(self): - assert self.codemod.name() == "fix-file-resource-leak" + assert self.codemod.name == "fix-file-resource-leak" def test_simple(self, tmpdir): input_code = """\ @@ -19,7 +19,6 @@ def test_simple(self, tmpdir): file.read() """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_annotated(self, tmpdir): input_code = """\ @@ -31,7 +30,6 @@ def test_simple_annotated(self, tmpdir): file.read() """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_just_open(self, tmpdir): # strange as this change may be, it still leaks if left untouched @@ -43,7 +41,6 @@ def test_just_open(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_multiple_assignments(self, tmpdir): input_code = """\ @@ -55,7 +52,6 @@ def test_multiple_assignments(self, tmpdir): file.read() """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_minimal_block(self, tmpdir): input_code = """\ @@ -69,7 +65,6 @@ def test_minimal_block(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 # negative tests below @@ -80,7 +75,6 @@ def test_is_closed(self, tmpdir): file.close() """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_is_closed_with_exit(self, tmpdir): input_code = """\ @@ -89,7 +83,6 @@ def test_is_closed_with_exit(self, tmpdir): file.__exit__() """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_is_closed_with_statement(self, tmpdir): input_code = """\ @@ -98,7 +91,6 @@ def test_is_closed_with_statement(self, tmpdir): file.read() """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_is_closed_with_statement_and_contextlib(self, tmpdir): input_code = """\ @@ -108,7 +100,6 @@ def test_is_closed_with_statement_and_contextlib(self, tmpdir): file.read() """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_is_closed_transitivelly(self, tmpdir): input_code = """\ @@ -117,7 +108,6 @@ def test_is_closed_transitivelly(self, tmpdir): same_file.close() """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_escapes_with_assignment(self, tmpdir): input_code = """\ @@ -125,7 +115,6 @@ def test_escapes_with_assignment(self, tmpdir): Object.attribute = file """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_escapes_as_function_argument(self, tmpdir): input_code = """\ @@ -133,7 +122,6 @@ def test_escapes_as_function_argument(self, tmpdir): foo(file) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_escapes_returned(self, tmpdir): input_code = """\ @@ -142,7 +130,6 @@ def foo(): return file """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_escapes_yielded(self, tmpdir): input_code = """\ @@ -151,7 +138,6 @@ def foo(): yield file """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_escapes_outside_reference(self, tmpdir): input_code = """\ @@ -163,4 +149,3 @@ def test_escapes_outside_reference(self, tmpdir): out.read() """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_fix_deprecated_logging_warn.py b/tests/codemods/test_fix_deprecated_logging_warn.py index 4b05f87dc..41f3a1ea1 100644 --- a/tests/codemods/test_fix_deprecated_logging_warn.py +++ b/tests/codemods/test_fix_deprecated_logging_warn.py @@ -28,7 +28,6 @@ def test_import(self, tmpdir, code): original_code = code.format("warn") new_code = code.format("warning") self.run_and_assert(tmpdir, original_code, new_code) - assert len(self.file_context.codemod_changes) == 1 @pytest.mark.parametrize( "code", @@ -47,7 +46,6 @@ def test_from_import(self, tmpdir, code): original_code = code.format("warn") new_code = code.format("warning") self.run_and_assert(tmpdir, original_code, new_code) - assert len(self.file_context.codemod_changes) == 1 @pytest.mark.parametrize( "input_code,expected_output", @@ -74,7 +72,6 @@ def test_from_import(self, tmpdir, code): ) def test_import_alias(self, tmpdir, input_code, expected_output): self.run_and_assert(tmpdir, input_code, expected_output) - assert len(self.file_context.codemod_changes) == 1 @pytest.mark.parametrize( "code", @@ -92,7 +89,6 @@ def test_import_alias(self, tmpdir, input_code, expected_output): ) def test_different_warn(self, tmpdir, code): self.run_and_assert(tmpdir, code, code) - assert len(self.file_context.codemod_changes) == 0 @pytest.mark.xfail(reason="Not currently supported") def test_log_as_arg(self, tmpdir): @@ -106,4 +102,3 @@ def some_function(logger): original_code = code.format("warn") new_code = code.format("warning") self.run_and_assert(tmpdir, original_code, new_code) - assert len(self.file_context.codemod_changes) == 1 diff --git a/tests/codemods/test_flask_json_response_type.py b/tests/codemods/test_flask_json_response_type.py index 77e9b1a39..6a6757556 100644 --- a/tests/codemods/test_flask_json_response_type.py +++ b/tests/codemods/test_flask_json_response_type.py @@ -7,7 +7,7 @@ class TestFlaskJsonResponseType(BaseCodemodTest): codemod = FlaskJsonResponseType def test_name(self): - assert self.codemod.name() == "flask-json-response-type" + assert self.codemod.name == "flask-json-response-type" def test_simple(self, tmpdir): input_code = """\ @@ -33,7 +33,6 @@ def foo(request): return make_response(json_response, {'Content-Type': 'application/json'}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_indirect(self, tmpdir): input_code = """\ @@ -61,7 +60,6 @@ def foo(request): return response """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_tuple_arg(self, tmpdir): input_code = """\ @@ -87,7 +85,6 @@ def foo(request): return make_response((json_response, 404, {'Content-Type': 'application/json'})) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_return_json(self, tmpdir): input_code = """\ @@ -113,7 +110,6 @@ def foo(request): return (json_response, {'Content-Type': 'application/json'}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_tuple(self, tmpdir): input_code = """\ @@ -139,7 +135,6 @@ def foo(request): return (json_response, 404, {'Content-Type': 'application/json'}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_alias(self, tmpdir): input_code = """\ @@ -165,7 +160,6 @@ def foo(request): return response(json_response, {'Content-Type': 'application/json'}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_indirect_dict(self, tmpdir): input_code = """\ @@ -193,7 +187,6 @@ def foo(request): return make_response(json_response, headers) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_direct_return(self, tmpdir): input_code = """\ @@ -217,7 +210,6 @@ def foo(request): return make_response(json.dumps({ "user_input": request.GET.get("input") }), {'Content-Type': 'application/json'}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_tuple_dict_no_key(self, tmpdir): input_code = """\ @@ -243,7 +235,6 @@ def foo(request): return (make_response(json_response), {'Content-Type': 'application/json'}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_no_route_decorator(self, tmpdir): input_code = """\ @@ -257,7 +248,6 @@ def foo(request): return make_response(json_response) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_content_type_set(self, tmpdir): input_code = """\ @@ -272,7 +262,6 @@ def foo(request): return (make_response(json_response), {'Content-Type': 'application/json'}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_content_type_maybe_set_star(self, tmpdir): input_code = """\ @@ -288,7 +277,6 @@ def foo(request): return (make_response(json_response), {**another_dict}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_content_type_maybe_set(self, tmpdir): input_code = """\ @@ -304,7 +292,6 @@ def foo(request): return (make_response(json_response), {key:'application/json'}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_no_json_dumps_input(self, tmpdir): input_code = """\ @@ -319,7 +306,6 @@ def foo(request): return make_response(dict_response) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_unknown_call_response(self, tmpdir): input_code = """\ @@ -335,4 +321,3 @@ def foo(request): return bar(dict_response) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_harden_pyyaml.py b/tests/codemods/test_harden_pyyaml.py index 1e30f2827..cec378bad 100644 --- a/tests/codemods/test_harden_pyyaml.py +++ b/tests/codemods/test_harden_pyyaml.py @@ -12,7 +12,7 @@ class TestHardenPyyaml(BaseSemgrepCodemodTest): codemod = HardenPyyaml def test_name(self): - assert self.codemod.name() == "harden-pyyaml" + assert self.codemod.name == "harden-pyyaml" def test_safe_loader(self, tmpdir): input_code = """import yaml @@ -20,7 +20,6 @@ def test_safe_loader(self, tmpdir): deserialized_data = yaml.load(data, Loader=yaml.SafeLoader) """ self.run_and_assert(tmpdir, input_code, input_code) - assert len(self.file_context.codemod_changes) == 0 @loaders def test_all_unsafe_loaders_arg(self, tmpdir, loader): @@ -34,7 +33,6 @@ def test_all_unsafe_loaders_arg(self, tmpdir, loader): deserialized_data = yaml.load(data, yaml.SafeLoader) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 @loaders def test_all_unsafe_loaders_kwarg(self, tmpdir, loader): @@ -48,7 +46,6 @@ def test_all_unsafe_loaders_kwarg(self, tmpdir, loader): deserialized_data = yaml.load(data, Loader=yaml.SafeLoader) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_import_alias(self, tmpdir): input_code = """import yaml as yam @@ -64,7 +61,6 @@ def test_import_alias(self, tmpdir): deserialized_data = yam.load(data, Loader=yam.SafeLoader) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_preserve_custom_loader(self, tmpdir): expected = input_code = """ @@ -75,7 +71,6 @@ def test_preserve_custom_loader(self, tmpdir): """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 0 def test_preserve_custom_loader_kwarg(self, tmpdir): expected = input_code = """ @@ -86,7 +81,6 @@ def test_preserve_custom_loader_kwarg(self, tmpdir): """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 0 class TestHardenPyyamlClassInherit(BaseSemgrepCodemodTest): @@ -103,7 +97,6 @@ def __init__(self, *args, **kwargs): """ self.run_and_assert(tmpdir, input_code, input_code) - assert len(self.file_context.codemod_changes) == 0 @loaders def test_unsafe_loaders(self, tmpdir, loader): @@ -122,7 +115,6 @@ def __init__(self, *args, **kwargs): super(MyCustomLoader, self).__init__(*args, **kwargs) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_from_import(self, tmpdir): input_code = """\ @@ -140,7 +132,6 @@ def __init__(self, *args, **kwargs): super(MyCustomLoader, self).__init__(*args, **kwargs) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_import_alias(self, tmpdir): input_code = """\ @@ -158,7 +149,6 @@ def __init__(self, *args, **kwargs): super(MyCustomLoader, self).__init__(*args, **kwargs) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_multiple_bases(self, tmpdir): input_code = """\ @@ -180,7 +170,6 @@ def __init__(self, *args, **kwargs): super(MyCustomLoader, self).__init__(*args, **kwargs) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_different_yaml(self, tmpdir): input_code = """\ @@ -198,4 +187,3 @@ class MyLoader(SafeLoader, yaml.Loader): ... """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 diff --git a/tests/codemods/test_harden_ruamel.py b/tests/codemods/test_harden_ruamel.py index 2a388a4d3..3f997c9e4 100644 --- a/tests/codemods/test_harden_ruamel.py +++ b/tests/codemods/test_harden_ruamel.py @@ -7,7 +7,7 @@ class TestHardenRuamel(BaseSemgrepCodemodTest): codemod = HardenRuamel def test_name(self): - assert self.codemod.name() == "harden-ruamel" + assert self.codemod.name == "harden-ruamel" @pytest.mark.parametrize("loader", ["YAML()", "YAML(typ='rt')", "YAML(typ='safe')"]) def test_safe(self, tmpdir, loader): diff --git a/tests/codemods/test_https_connection.py b/tests/codemods/test_https_connection.py index 005586589..5fbf20b83 100644 --- a/tests/codemods/test_https_connection.py +++ b/tests/codemods/test_https_connection.py @@ -11,7 +11,6 @@ def test_no_change(self, tmpdir): urllib3.HTTPSConnectionPool("localhost", "80") """ self.run_and_assert(tmpdir, before, before) - assert len(self.file_context.codemod_changes) == 0 def test_simple(self, tmpdir): before = r"""import urllib3 @@ -23,7 +22,6 @@ def test_simple(self, tmpdir): urllib3.HTTPSConnectionPool("localhost", "80") """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_module_alias(self, tmpdir): before = r"""import urllib3 as module @@ -35,7 +33,6 @@ def test_module_alias(self, tmpdir): module.HTTPSConnectionPool("localhost", "80") """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_alias(self, tmpdir): before = r"""from urllib3 import HTTPConnectionPool as something @@ -47,7 +44,6 @@ def test_alias(self, tmpdir): urllib3.HTTPSConnectionPool("localhost", "80") """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_connectionpool(self, tmpdir): before = r"""import urllib3 @@ -59,7 +55,6 @@ def test_connectionpool(self, tmpdir): urllib3.connectionpool.HTTPSConnectionPool("localhost", "80") """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_connectionpool_alias(self, tmpdir): before = r"""import urllib3.connectionpool as pool @@ -71,7 +66,6 @@ def test_connectionpool_alias(self, tmpdir): pool.HTTPSConnectionPool("localhost", "80") """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_last_arg(self, tmpdir): before = r"""import urllib3 @@ -83,4 +77,3 @@ def test_last_arg(self, tmpdir): urllib3.HTTPSConnectionPool(None, None, None, None, None, None, None, None, None, _proxy_config = None) """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 diff --git a/tests/codemods/test_jwt_decode_verify.py b/tests/codemods/test_jwt_decode_verify.py index 35dc1a1ab..677b4bfa9 100644 --- a/tests/codemods/test_jwt_decode_verify.py +++ b/tests/codemods/test_jwt_decode_verify.py @@ -7,7 +7,7 @@ class TestJwtDecodeVerify(BaseSemgrepCodemodTest): codemod = JwtDecodeVerify def test_name(self): - assert self.codemod.name() == "jwt-decode-verify" + assert self.codemod.name == "jwt-decode-verify" def test_import(self, tmpdir): input_code = """import jwt diff --git a/tests/codemods/test_limit_readline.py b/tests/codemods/test_limit_readline.py index c7d943639..bf6a7be24 100644 --- a/tests/codemods/test_limit_readline.py +++ b/tests/codemods/test_limit_readline.py @@ -6,7 +6,7 @@ class TestLimitReadline(BaseSemgrepCodemodTest): codemod = LimitReadline def test_name(self): - assert self.codemod.name() == "limit-readline" + assert self.codemod.name == "limit-readline" def test_file_readline(self, tmpdir): input_code = """file = open('some_file.txt') diff --git a/tests/codemods/test_literal_or_new_object_identity.py b/tests/codemods/test_literal_or_new_object_identity.py index fd165f15a..91e2251fe 100644 --- a/tests/codemods/test_literal_or_new_object_identity.py +++ b/tests/codemods/test_literal_or_new_object_identity.py @@ -7,7 +7,7 @@ class TestLiteralOrNewObjectIdentity(BaseCodemodTest): codemod = LiteralOrNewObjectIdentity def test_name(self): - assert self.codemod.name() == "literal-or-new-object-identity" + assert self.codemod.name == "literal-or-new-object-identity" def test_list(self, tmpdir): input_code = """\ @@ -17,7 +17,6 @@ def test_list(self, tmpdir): l == [1,2,3] """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_list_indirect(self, tmpdir): input_code = """\ @@ -29,7 +28,6 @@ def test_list_indirect(self, tmpdir): l == some_list """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_list_lhs(self, tmpdir): input_code = """\ @@ -39,7 +37,6 @@ def test_list_lhs(self, tmpdir): [1,2,3] == l """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_list_function(self, tmpdir): input_code = """\ @@ -49,7 +46,6 @@ def test_list_function(self, tmpdir): l == list({1,2,3}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_dict(self, tmpdir): input_code = """\ @@ -59,7 +55,6 @@ def test_dict(self, tmpdir): l == {1:2} """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_dict_function(self, tmpdir): input_code = """\ @@ -69,7 +64,6 @@ def test_dict_function(self, tmpdir): l == dict({1,2,3}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_tuple(self, tmpdir): input_code = """\ @@ -79,7 +73,6 @@ def test_tuple(self, tmpdir): l == (1,2,3) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_tuple_function(self, tmpdir): input_code = """\ @@ -89,7 +82,6 @@ def test_tuple_function(self, tmpdir): l == tuple({1,2,3}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_set(self, tmpdir): input_code = """\ @@ -99,7 +91,6 @@ def test_set(self, tmpdir): l == {1,2,3} """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_set_function(self, tmpdir): input_code = """\ @@ -109,7 +100,6 @@ def test_set_function(self, tmpdir): l == set([1,2,3]) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_int(self, tmpdir): input_code = """\ @@ -119,7 +109,6 @@ def test_int(self, tmpdir): l == 1 """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_float(self, tmpdir): input_code = """\ @@ -129,7 +118,6 @@ def test_float(self, tmpdir): l == 1.0 """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_imaginary(self, tmpdir): input_code = """\ @@ -139,7 +127,6 @@ def test_imaginary(self, tmpdir): l == 1j """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_str(self, tmpdir): input_code = """\ @@ -149,7 +136,6 @@ def test_str(self, tmpdir): l == '1' """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_fstr(self, tmpdir): input_code = """\ @@ -159,7 +145,6 @@ def test_fstr(self, tmpdir): l == f'1' """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_concatenated_str(self, tmpdir): input_code = """\ @@ -169,7 +154,6 @@ def test_concatenated_str(self, tmpdir): l == '1' ',2' """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_negative(self, tmpdir): input_code = """\ @@ -179,18 +163,15 @@ def test_negative(self, tmpdir): l != [1,2,3] """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_do_nothing(self, tmpdir): input_code = """\ l == [1,2,3] """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_do_nothing_negative(self, tmpdir): input_code = """\ l != [1,2,3] """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_lxml_safe_parameter_defaults.py b/tests/codemods/test_lxml_safe_parameter_defaults.py index 067539041..e28842801 100644 --- a/tests/codemods/test_lxml_safe_parameter_defaults.py +++ b/tests/codemods/test_lxml_safe_parameter_defaults.py @@ -11,7 +11,7 @@ class TestLxmlSafeParserDefaults(BaseSemgrepCodemodTest): codemod = LxmlSafeParserDefaults def test_name(self): - assert self.codemod.name() == "safe-lxml-parser-defaults" + assert self.codemod.name == "safe-lxml-parser-defaults" @each_class def test_import(self, tmpdir, klass): diff --git a/tests/codemods/test_lxml_safe_parsing.py b/tests/codemods/test_lxml_safe_parsing.py index c23a26d2e..3400b1668 100644 --- a/tests/codemods/test_lxml_safe_parsing.py +++ b/tests/codemods/test_lxml_safe_parsing.py @@ -9,7 +9,7 @@ class TestLxmlSafeParsing(BaseSemgrepCodemodTest): codemod = LxmlSafeParsing def test_name(self): - assert self.codemod.name() == "safe-lxml-parsing" + assert self.codemod.name == "safe-lxml-parsing" @each_func def test_import(self, tmpdir, func): diff --git a/tests/codemods/test_numpy_nan_equality.py b/tests/codemods/test_numpy_nan_equality.py index f1440788e..21ead13f0 100644 --- a/tests/codemods/test_numpy_nan_equality.py +++ b/tests/codemods/test_numpy_nan_equality.py @@ -7,7 +7,7 @@ class TestNumpyNanEquality(BaseCodemodTest): codemod = NumpyNanEquality def test_name(self): - assert self.codemod.name() == "numpy-nan-equality" + assert self.codemod.name == "numpy-nan-equality" def test_simple(self, tmpdir): input_code = """\ @@ -21,7 +21,6 @@ def test_simple(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_inequality(self, tmpdir): input_code = """\ @@ -35,7 +34,6 @@ def test_simple_inequality(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_inequality_2(self, tmpdir): input_code = """\ @@ -49,7 +47,6 @@ def test_simple_inequality_2(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_parenthesis(self, tmpdir): input_code = """\ @@ -63,7 +60,6 @@ def test_simple_parenthesis(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_conjunction(self, tmpdir): input_code = """\ @@ -77,7 +73,6 @@ def test_conjunction(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 2 def test_from_numpy(self, tmpdir): input_code = """\ @@ -92,7 +87,6 @@ def test_from_numpy(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_left(self, tmpdir): input_code = """\ @@ -106,7 +100,6 @@ def test_simple_left(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_alias(self, tmpdir): input_code = """\ @@ -120,7 +113,6 @@ def test_alias(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_multiple_comparisons(self, tmpdir): input_code = """\ @@ -129,7 +121,6 @@ def test_multiple_comparisons(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_not_numpy(self, tmpdir): input_code = """\ @@ -138,7 +129,6 @@ def test_not_numpy(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_numpy_other_operator(self, tmpdir): input_code = """\ @@ -147,4 +137,3 @@ def test_numpy_other_operator(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_order_imports.py b/tests/codemods/test_order_imports.py index 7afcdbb28..f7a4c3655 100644 --- a/tests/codemods/test_order_imports.py +++ b/tests/codemods/test_order_imports.py @@ -10,7 +10,6 @@ def test_no_change(self, tmpdir): from b import c """ self.run_and_assert(tmpdir, before, before) - assert len(self.file_context.codemod_changes) == 0 def test_separate_from_imports_and_regular(self, tmpdir): before = r"""import y @@ -21,7 +20,6 @@ def test_separate_from_imports_and_regular(self, tmpdir): import y from a import c""" self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_consolidate_from_imports(self, tmpdir): before = r"""from a import a1 @@ -30,7 +28,6 @@ def test_consolidate_from_imports(self, tmpdir): after = r"""from a import a1, a2, a3""" self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_order_blocks_separately(self, tmpdir): before = r"""import x @@ -46,7 +43,6 @@ def test_order_blocks_separately(self, tmpdir): import y""" self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 2 def test_preserve_comments(self, tmpdir): before = r"""# do not move @@ -69,7 +65,6 @@ def test_preserve_comments(self, tmpdir): """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_handle_star_imports(self, tmpdir): before = r"""from a import x @@ -82,7 +77,6 @@ def test_handle_star_imports(self, tmpdir): from a import b, x""" self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_handle_composite_and_relative_imports(self, tmpdir): before = r"""from . import a @@ -93,7 +87,6 @@ def test_handle_composite_and_relative_imports(self, tmpdir): from . import a""" self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_natural_order(self, tmpdir): before = """from a import Object11 @@ -104,7 +97,6 @@ def test_natural_order(self, tmpdir): """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_wont_remove_unused_future(self, tmpdir): before = """from __future__ import absolute_import @@ -113,7 +105,6 @@ def test_wont_remove_unused_future(self, tmpdir): after = before self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 0 def test_organize_by_sections(self, tmpdir): before = """from codemodder.codemods.transformations.clean_imports import CleanImports @@ -134,7 +125,6 @@ def test_organize_by_sections(self, tmpdir): """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_will_ignore_non_top_level(self, tmpdir): before = """import global2 @@ -159,7 +149,6 @@ def f(): import global3 """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_it_can_change_behavior(self, tmpdir): # note that c will change from b to e due to the sort @@ -175,4 +164,3 @@ def test_it_can_change_behavior(self, tmpdir): c() """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 diff --git a/tests/codemods/test_process_creation_sandbox.py b/tests/codemods/test_process_creation_sandbox.py index 51309755d..bafcad044 100644 --- a/tests/codemods/test_process_creation_sandbox.py +++ b/tests/codemods/test_process_creation_sandbox.py @@ -1,16 +1,19 @@ import pytest +import mock + from codemodder.dependency import Security from core_codemods.process_creation_sandbox import ProcessSandbox from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest +@mock.patch("codemodder.codemods.api.FileContext.add_dependency") class TestProcessCreationSandbox(BaseSemgrepCodemodTest): codemod = ProcessSandbox - def test_name(self): - assert self.codemod.name() == "sandbox-process-creation" + def test_name(self, _): + assert self.codemod.name == "sandbox-process-creation" - def test_import_subprocess(self, tmpdir): + def test_import_subprocess(self, adds_dependency, tmpdir): input_code = """ import subprocess @@ -25,9 +28,9 @@ def test_import_subprocess(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + adds_dependency.assert_called_once_with(Security) - def test_import_alias(self, tmpdir): + def test_import_alias(self, adds_dependency, tmpdir): input_code = """ import subprocess as sub @@ -42,9 +45,9 @@ def test_import_alias(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + adds_dependency.assert_called_once_with(Security) - def test_from_subprocess(self, tmpdir): + def test_from_subprocess(self, adds_dependency, tmpdir): input_code = """ from subprocess import run @@ -59,9 +62,9 @@ def test_from_subprocess(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + adds_dependency.assert_called_once_with(Security) - def test_subprocess_nameerror(self, tmpdir): + def test_subprocess_nameerror(self, _, tmpdir): input_code = """ subprocess.run("echo 'hi'", shell=True) @@ -107,11 +110,13 @@ def test_subprocess_nameerror(self, tmpdir): ), ], ) - def test_other_import_untouched(self, tmpdir, input_code, expected): + def test_other_import_untouched( + self, adds_dependency, tmpdir, input_code, expected + ): self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + adds_dependency.assert_called_once_with(Security) - def test_multifunctions(self, tmpdir): + def test_multifunctions(self, adds_dependency, tmpdir): # Test that subprocess methods that aren't part of the codemod are not changed. # If we add the function as one of # our codemods, this test would change. @@ -129,9 +134,9 @@ def test_multifunctions(self, tmpdir): subprocess.check_output(["ls", "-l"])""" self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + adds_dependency.assert_called_once_with(Security) - def test_custom_run(self, tmpdir): + def test_custom_run(self, _, tmpdir): input_code = """ from app_funcs import run @@ -139,7 +144,7 @@ def test_custom_run(self, tmpdir): expected = input_code self.run_and_assert(tmpdir, input_code, expected) - def test_subprocess_call(self, tmpdir): + def test_subprocess_call(self, adds_dependency, tmpdir): input_code = """ import subprocess @@ -152,9 +157,9 @@ def test_subprocess_call(self, tmpdir): safe_command.run(subprocess.call, ["ls", "-l"]) """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + adds_dependency.assert_called_once_with(Security) - def test_subprocess_popen(self, tmpdir): + def test_subprocess_popen(self, adds_dependency, tmpdir): input_code = """ import subprocess @@ -167,4 +172,4 @@ def test_subprocess_popen(self, tmpdir): safe_command.run(subprocess.Popen, ["ls", "-l"]) """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + adds_dependency.assert_called_once_with(Security) diff --git a/tests/codemods/test_remove_debug_breakpoint.py b/tests/codemods/test_remove_debug_breakpoint.py index 27e10b834..cd37e1ea0 100644 --- a/tests/codemods/test_remove_debug_breakpoint.py +++ b/tests/codemods/test_remove_debug_breakpoint.py @@ -6,7 +6,7 @@ class TestRemoveDebugBreakpoint(BaseCodemodTest): codemod = RemoveDebugBreakpoint def test_name(self): - assert self.codemod.name() == "remove-debug-breakpoint" + assert self.codemod.name == "remove-debug-breakpoint" def test_builtin_breakpoint(self, tmpdir): input_code = """\ @@ -21,7 +21,6 @@ def something(): something() """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_builtin_breakpoint_multiple_statements(self, tmpdir): input_code = """\ @@ -37,7 +36,6 @@ def something(): something() """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_inline_pdb(self, tmpdir): input_code = """\ @@ -52,7 +50,6 @@ def something(): something() """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_pdb_import(self, tmpdir): input_code = """\ @@ -68,7 +65,6 @@ def something(): something() """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_pdb_from_import(self, tmpdir): input_code = """\ @@ -84,4 +80,3 @@ def something(): something() """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 diff --git a/tests/codemods/test_remove_module_global.py b/tests/codemods/test_remove_module_global.py index 12e95f794..2e9f3e022 100644 --- a/tests/codemods/test_remove_module_global.py +++ b/tests/codemods/test_remove_module_global.py @@ -7,7 +7,7 @@ class TestJwtDecodeVerify(BaseCodemodTest): codemod = RemoveModuleGlobal def test_name(self): - assert self.codemod.name() == "remove-module-global" + assert self.codemod.name == "remove-module-global" def test_simple(self, tmpdir): input_code = """\ @@ -18,7 +18,6 @@ def test_simple(self, tmpdir): price = 25 """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_reassigned(self, tmpdir): input_code = """\ @@ -33,7 +32,6 @@ def test_reassigned(self, tmpdir): price = 30 """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_attr_call(self, tmpdir): input_code = """\ @@ -50,7 +48,6 @@ class Price: PRICE.__repr__ """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_correct_scope(self, tmpdir): input_code = """\ @@ -60,4 +57,3 @@ def change_price(): price = 30 """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_remove_unnecessary_f_str.py b/tests/codemods/test_remove_unnecessary_f_str.py index 35c49c32e..042ac37e7 100644 --- a/tests/codemods/test_remove_unnecessary_f_str.py +++ b/tests/codemods/test_remove_unnecessary_f_str.py @@ -17,7 +17,6 @@ def test_no_change(self, tmpdir): good = f"wow i don't have args but don't mess my braces {{ up }}" """ self.run_and_assert(tmpdir, before, before) - assert len(self.file_context.codemod_changes) == 0 def test_change(self, tmpdir): before = r""" @@ -31,4 +30,3 @@ def test_change(self, tmpdir): bad: str = r'bad\d+' """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 3 diff --git a/tests/codemods/test_remove_unused_imports.py b/tests/codemods/test_remove_unused_imports.py index 9e675b5c3..9f2e65fc5 100644 --- a/tests/codemods/test_remove_unused_imports.py +++ b/tests/codemods/test_remove_unused_imports.py @@ -14,7 +14,6 @@ def test_no_change(self, tmpdir): c """ self.run_and_assert(tmpdir, before, before) - assert len(self.file_context.codemod_changes) == 0 def test_change(self, tmpdir): before = r"""import a @@ -22,7 +21,6 @@ def test_change(self, tmpdir): after = r""" """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_remove_import(self, tmpdir): before = r"""import a @@ -34,7 +32,6 @@ def test_remove_import(self, tmpdir): """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_remove_single_from_import(self, tmpdir): before = r"""from b import c, d @@ -45,7 +42,6 @@ def test_remove_single_from_import(self, tmpdir): c """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_remove_from_import(self, tmpdir): before = r"""from b import c @@ -53,7 +49,6 @@ def test_remove_from_import(self, tmpdir): after = "\n" self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_remove_inner_import(self, tmpdir): before = r"""import a @@ -67,7 +62,6 @@ def something(): """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_no_import_star_removal(self, tmpdir): before = r"""import a @@ -80,17 +74,14 @@ def test_keep_format(self, tmpdir): before = "from a import b,c,d \nprint(b)\nprint(d)" after = "from a import b,d \nprint(b)\nprint(d)" self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_dont_remove_if_noqa_before(self, tmpdir): before = "import a\n# noqa\nimport b\na()" self.run_and_assert(tmpdir, before, before) - assert len(self.file_context.codemod_changes) == 0 def test_dont_remove_if_noqa_trailing(self, tmpdir): before = "import a\nimport b # noqa\na()" self.run_and_assert(tmpdir, before, before) - assert len(self.file_context.codemod_changes) == 0 def test_dont_remove_if_noqa_trailing_multiline(self, tmpdir): before = dedent( @@ -101,34 +92,28 @@ def test_dont_remove_if_noqa_trailing_multiline(self, tmpdir): ) self.run_and_assert(tmpdir, before, before) - assert len(self.file_context.codemod_changes) == 0 def test_dont_remove_if_pylint_disable(self, tmpdir): before = "import a\nimport b # pylint: disable=W0611\na()" self.run_and_assert(tmpdir, before, before) - assert len(self.file_context.codemod_changes) == 0 def test_dont_remove_if_pylint_disable_next(self, tmpdir): before = ( "import a\n# pylint: disable-next=no-member, unused-import\nimport b\na()" ) self.run_and_assert(tmpdir, before, before) - assert len(self.file_context.codemod_changes) == 0 def test_ignore_init_files(self, tmpdir): before = "import a" tmp_file_path = Path(tmpdir / "__init__.py") - self.run_and_assert_filepath(tmpdir, tmp_file_path, before, before) - assert len(self.file_context.codemod_changes) == 0 + self.run_and_assert(tmpdir, before, before, files=[tmp_file_path]) def test_no_pyling_pragma_in_comment_trailing(self, tmpdir): before = "import a # bogus: no-pragma" after = "" self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_no_pyling_pragma_in_comment_before(self, tmpdir): before = "#header\nprint('hello')\n# bogus: no-pragma\nimport a " after = "#header\nprint('hello')" self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 diff --git a/tests/codemods/test_request_verify.py b/tests/codemods/test_request_verify.py index 28f00b1b8..503c4d4e7 100644 --- a/tests/codemods/test_request_verify.py +++ b/tests/codemods/test_request_verify.py @@ -9,7 +9,7 @@ class TestRequestsVerify(BaseSemgrepCodemodTest): codemod = RequestsVerify def test_name(self): - assert self.codemod.name() == "requests-verify" + assert self.codemod.name == "requests-verify" @pytest.mark.parametrize("func", REQUEST_FUNCS) def test_default_verify(self, tmpdir, func): diff --git a/tests/codemods/test_secure_flask_cookie.py b/tests/codemods/test_secure_flask_cookie.py index 923a829a1..b7304aaa6 100644 --- a/tests/codemods/test_secure_flask_cookie.py +++ b/tests/codemods/test_secure_flask_cookie.py @@ -10,7 +10,7 @@ class TestSecureFlaskCookie(BaseSemgrepCodemodTest): codemod = SecureFlaskCookie def test_name(self): - assert self.codemod.name() == "secure-flask-cookie" + assert self.codemod.name == "secure-flask-cookie" @each_func def test_import(self, tmpdir, func): diff --git a/tests/codemods/test_secure_flask_session_config.py b/tests/codemods/test_secure_flask_session_config.py index 60c834ad6..353c94600 100644 --- a/tests/codemods/test_secure_flask_session_config.py +++ b/tests/codemods/test_secure_flask_session_config.py @@ -8,7 +8,7 @@ class TestSecureFlaskSessionConfig(BaseCodemodTest): codemod = SecureFlaskSessionConfig def test_name(self): - assert self.codemod.name() == "secure-flask-session-configuration" + assert self.codemod.name == "secure-flask-session-configuration" def test_no_flask_app(self, tmpdir): input_code = """\ @@ -19,7 +19,6 @@ def test_no_flask_app(self, tmpdir): response.set_cookie("name", "value") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_app_defined_separate_module(self, tmpdir): # TODO: test this as an integration test with two real modules @@ -29,7 +28,6 @@ def test_app_defined_separate_module(self, tmpdir): app.config["SESSION_COOKIE_SECURE"] = False """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_app_not_assigned(self, tmpdir): input_code = """\ @@ -39,7 +37,6 @@ def test_app_not_assigned(self, tmpdir): print(1) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_app_accessed_config_not_called(self, tmpdir): input_code = """\ @@ -50,7 +47,6 @@ def test_app_accessed_config_not_called(self, tmpdir): # more code """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_app_update_no_keyword(self, tmpdir): input_code = """\ @@ -63,8 +59,6 @@ def foo(test_config=None): app.config.update(test_config) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - print(self.file_context.codemod_changes) - assert len(self.file_context.codemod_changes) == 0 def test_from_import(self, tmpdir): input_code = """\ @@ -82,7 +76,6 @@ def test_from_import(self, tmpdir): app.config.update(SESSION_COOKIE_SECURE=True) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expexted_output)) - assert len(self.file_context.codemod_changes) == 1 def test_import_alias(self, tmpdir): input_code = """\ @@ -100,7 +93,6 @@ def test_import_alias(self, tmpdir): app.config.update(SESSION_COOKIE_SECURE=True) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expexted_output)) - assert len(self.file_context.codemod_changes) == 1 def test_annotated_assign(self, tmpdir): input_code = """\ @@ -118,7 +110,6 @@ def test_annotated_assign(self, tmpdir): app.config.update(SESSION_COOKIE_SECURE=True) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expexted_output)) - assert len(self.file_context.codemod_changes) == 1 def test_other_assignment_type(self, tmpdir): input_code = """\ @@ -132,7 +123,6 @@ class AppStore: store.app.config.update(SESSION_COOKIE_SECURE=False) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 @pytest.mark.parametrize( "config_lines,expected_config_lines", @@ -246,4 +236,3 @@ def configure(): # either within configure() call or after it's called """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expexted_output)) - assert len(self.file_context.codemod_changes) == 1 diff --git a/tests/codemods/test_secure_random.py b/tests/codemods/test_secure_random.py index fc1d3d2d1..305c14220 100644 --- a/tests/codemods/test_secure_random.py +++ b/tests/codemods/test_secure_random.py @@ -7,74 +7,84 @@ class TestSecureRandom(BaseSemgrepCodemodTest): codemod = SecureRandom def test_name(self): - assert self.codemod.name() == "secure-random" + assert self.codemod.name == "secure-random" def test_import_random(self, tmpdir): - input_code = """import random + input_code = """ + import random -random.random() -var = "hello" -""" - expected_output = """import secrets + random.random() + var = "hello" + """ + expected_output = """ + import secrets -secrets.SystemRandom().random() -var = "hello" -""" + secrets.SystemRandom().random() + var = "hello" + """ self.run_and_assert(tmpdir, input_code, expected_output) def test_from_random(self, tmpdir): - input_code = """from random import random - -random() -var = "hello" -""" - expected_output = """import secrets - -secrets.SystemRandom().random() -var = "hello" -""" + input_code = """ + from random import random + + random() + var = "hello" + """ + expected_output = """ + import secrets + + secrets.SystemRandom().random() + var = "hello" + """ self.run_and_assert(tmpdir, input_code, expected_output) def test_random_alias(self, tmpdir): - input_code = """import random as alleatory - -alleatory.random() -var = "hello" -""" - expected_output = """import secrets - -secrets.SystemRandom().random() -var = "hello" -""" + input_code = """ + import random as alleatory + + alleatory.random() + var = "hello" + """ + expected_output = """ + import secrets + + secrets.SystemRandom().random() + var = "hello" + """ self.run_and_assert(tmpdir, input_code, expected_output) @pytest.mark.parametrize( "input_code,expected_output", [ ( - """import random - -random.randint(0, 10) -var = "hello" -""", - """import secrets - -secrets.SystemRandom().randint(0, 10) -var = "hello" -""", + """ + import random + + random.randint(0, 10) + var = "hello" + """, + """ + import secrets + + secrets.SystemRandom().randint(0, 10) + var = "hello" + """, ), ( - """from random import randint - -randint(0, 10) -var = "hello" -""", - """import secrets - -secrets.SystemRandom().randint(0, 10) -var = "hello" -""", + """ + from random import randint + + randint(0, 10) + var = "hello" + """, + """ + import secrets + + secrets.SystemRandom().randint(0, 10) + var = "hello" + """, ), ], ) @@ -82,48 +92,54 @@ def test_random_randint(self, tmpdir, input_code, expected_output): self.run_and_assert(tmpdir, input_code, expected_output) def test_multiple_calls(self, tmpdir): - input_code = """import random - -random.random() -random.randint() -var = "hello" -""" - expected_output = """import secrets - -secrets.SystemRandom().random() -secrets.SystemRandom().randint() -var = "hello" -""" + input_code = """ + import random + + random.random() + random.randint() + var = "hello" + """ + expected_output = """ + import secrets + + secrets.SystemRandom().random() + secrets.SystemRandom().randint() + var = "hello" + """ self.run_and_assert(tmpdir, input_code, expected_output) @pytest.mark.parametrize( "input_code,expected_output", [ ( - """import random -import csv -random.random() -csv.excel -""", - """import csv -import secrets - -secrets.SystemRandom().random() -csv.excel -""", + """ + import random + import csv + random.random() + csv.excel + """, + """ + import csv + import secrets + + secrets.SystemRandom().random() + csv.excel + """, ), ( - """import random -from csv import excel -random.random() -excel -""", - """from csv import excel -import secrets - -secrets.SystemRandom().random() -excel -""", + """ + import random + from csv import excel + random.random() + excel + """, + """ + from csv import excel + import secrets + + secrets.SystemRandom().random() + excel + """, ), ], ) @@ -131,9 +147,10 @@ def test_random_other_import_untouched(self, tmpdir, input_code, expected_output self.run_and_assert(tmpdir, input_code, expected_output) def test_random_nameerror(self, tmpdir): - input_code = """random.random() + input_code = """ + random.random() -import random""" + import random""" expected_output = input_code self.run_and_assert(tmpdir, input_code, expected_output) @@ -141,17 +158,19 @@ def test_random_multifunctions(self, tmpdir): # Test that `random` import isn't removed if code uses part of the random # library that isn't part of our codemods. If we add the function as one of # our codemods, this test would change. - input_code = """import random + input_code = """ + import random -random.random() -random.__all__ -""" + random.random() + random.__all__ + """ - expected_output = """import random -import secrets + expected_output = """ + import random + import secrets -secrets.SystemRandom().random() -random.__all__ -""" + secrets.SystemRandom().random() + random.__all__ + """ self.run_and_assert(tmpdir, input_code, expected_output) diff --git a/tests/codemods/test_sql_parameterization.py b/tests/codemods/test_sql_parameterization.py index e7c721849..5256d3bad 100644 --- a/tests/codemods/test_sql_parameterization.py +++ b/tests/codemods/test_sql_parameterization.py @@ -7,7 +7,7 @@ class TestSQLQueryParameterization(BaseCodemodTest): codemod = SQLQueryParameterization def test_name(self): - assert self.codemod.name() == "sql-parameterization" + assert self.codemod.name == "sql-parameterization" def test_simple(self, tmpdir): input_code = """\ @@ -27,7 +27,6 @@ def test_simple(self, tmpdir): cursor.execute("SELECT * from USERS WHERE name =?", (name, )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_multiple(self, tmpdir): input_code = """\ @@ -49,7 +48,6 @@ def test_multiple(self, tmpdir): cursor.execute("SELECT * from USERS WHERE name =?" + r" AND phone =?", (name, phone, )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_with_quotes_in_middle(self, tmpdir): input_code = """\ @@ -69,7 +67,6 @@ def test_simple_with_quotes_in_middle(self, tmpdir): cursor.execute("SELECT * from USERS WHERE name =?", ('user_{0}{1}'.format(name, r"_system"), )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_can_deal_with_multiple_variables(self, tmpdir): input_code = """\ @@ -94,7 +91,6 @@ def foo(self, cursor, name, phone): return cursor.execute(a + b + c, (name, phone, )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_if(self, tmpdir): input_code = """\ @@ -114,7 +110,6 @@ def test_simple_if(self, tmpdir): cursor.execute("SELECT * from USERS WHERE name =?", (('Jenny' if True else name), )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_multiple_escaped_quote(self, tmpdir): input_code = """\ @@ -136,7 +131,6 @@ def test_multiple_escaped_quote(self, tmpdir): cursor.execute('SELECT * from USERS WHERE name =?' + ' AND phone =?', (name, phone, )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_concatenated_strings(self, tmpdir): input_code = """\ @@ -156,7 +150,6 @@ def test_simple_concatenated_strings(self, tmpdir): cursor.execute("SELECT * from USERS" "WHERE name =?", (name, )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 class TestSQLQueryParameterizationFormattedString(BaseCodemodTest): @@ -180,7 +173,6 @@ def test_formatted_string_simple(self, tmpdir): cursor.execute(f"SELECT * from USERS WHERE name=?", (name, )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_formatted_string_quote_in_middle(self, tmpdir): input_code = """\ @@ -200,7 +192,6 @@ def test_formatted_string_quote_in_middle(self, tmpdir): cursor.execute(f"SELECT * from USERS WHERE name=?", ('user_{0}_admin'.format(name), )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_formatted_string_with_literal(self, tmpdir): input_code = """\ @@ -220,7 +211,6 @@ def test_formatted_string_with_literal(self, tmpdir): cursor.execute(f"SELECT * from USERS WHERE name=?", ('{0}_{1}'.format(name, 1+2), )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_formatted_string_nested(self, tmpdir): input_code = """\ @@ -240,7 +230,6 @@ def test_formatted_string_nested(self, tmpdir): cursor.execute(f"SELECT * from USERS WHERE name={f"?"}", (name, )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_formatted_string_concat_mixed(self, tmpdir): input_code = """\ @@ -260,7 +249,6 @@ def test_formatted_string_concat_mixed(self, tmpdir): cursor.execute("SELECT * from USERS WHERE name=?", ('{0}_{1}'.format(name, b'123'), )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_multiple_expressions_injection(self, tmpdir): input_code = """\ @@ -280,7 +268,6 @@ def test_multiple_expressions_injection(self, tmpdir): cursor.execute("SELECT * from USERS WHERE name =?", ('{0}_username'.format(name), )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 class TestSQLQueryParameterizationNegative(BaseCodemodTest): @@ -299,7 +286,6 @@ def foo(self, cursor, name, phone): return cursor.execute(a + b + c) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_wont_mess_with_byte_strings(self, tmpdir): input_code = """\ @@ -310,7 +296,6 @@ def test_wont_mess_with_byte_strings(self, tmpdir): cursor.execute("SELECT * from USERS WHERE " + b"name ='" + str(1234) + b"'") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_wont_parameterize_literals(self, tmpdir): input_code = """\ @@ -321,7 +306,6 @@ def test_wont_parameterize_literals(self, tmpdir): cursor.execute("SELECT * from USERS WHERE name ='" + str(1234) + "'") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_wont_parameterize_literals_if(self, tmpdir): input_code = """\ @@ -332,7 +316,6 @@ def test_wont_parameterize_literals_if(self, tmpdir): cursor.execute("SELECT * from USERS WHERE name ='" + ('Jenny' if True else 'Lorelei') + "'") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_will_ignore_escaped_quote(self, tmpdir): input_code = """\ @@ -343,7 +326,6 @@ def test_will_ignore_escaped_quote(self, tmpdir): cursor.execute("SELECT * from USERS WHERE name ='Jenny\'s username'") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_already_has_parameters(self, tmpdir): input_code = """\ @@ -357,7 +339,6 @@ def foo(self, cursor, name, phone): return cursor.execute(a + b + c, (phone,)) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_wont_change_class_attribute(self, tmpdir): # query may be accesed from outside the module by importing A @@ -373,7 +354,6 @@ def foo(self, name, cursor): return cursor.execute(query + name + "'") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_wont_change_module_variable(self, tmpdir): # query may be accesed from outside the module by importing it @@ -386,4 +366,3 @@ def foo(name, cursor): return cursor.execute(query + name + "'") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_subprocess_shell_false.py b/tests/codemods/test_subprocess_shell_false.py index b19c95134..35a971ffe 100644 --- a/tests/codemods/test_subprocess_shell_false.py +++ b/tests/codemods/test_subprocess_shell_false.py @@ -11,7 +11,7 @@ class TestSubprocessShellFalse(BaseCodemodTest): codemod = SubprocessShellFalse def test_name(self): - assert self.codemod.name() == "subprocess-shell-false" + assert self.codemod.name == "subprocess-shell-false" @each_func def test_import(self, tmpdir, func): @@ -24,7 +24,6 @@ def test_import(self, tmpdir, func): subprocess.{func}(args, shell=False) """ self.run_and_assert(tmpdir, input_code, expexted_output) - assert len(self.file_context.codemod_changes) == 1 @each_func def test_from_import(self, tmpdir, func): @@ -37,7 +36,6 @@ def test_from_import(self, tmpdir, func): {func}(args, shell=False) """ self.run_and_assert(tmpdir, input_code, expexted_output) - assert len(self.file_context.codemod_changes) == 1 @each_func def test_no_shell(self, tmpdir, func): @@ -46,7 +44,6 @@ def test_no_shell(self, tmpdir, func): subprocess.{func}(args, timeout=1) """ self.run_and_assert(tmpdir, input_code, input_code) - assert len(self.file_context.codemod_changes) == 0 @each_func def test_shell_False(self, tmpdir, func): @@ -55,4 +52,3 @@ def test_shell_False(self, tmpdir, func): subprocess.{func}(args, shell=False) """ self.run_and_assert(tmpdir, input_code, input_code) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_tempfile_mktemp.py b/tests/codemods/test_tempfile_mktemp.py index b4bb69e55..3d28019be 100644 --- a/tests/codemods/test_tempfile_mktemp.py +++ b/tests/codemods/test_tempfile_mktemp.py @@ -6,7 +6,7 @@ class TestTempfileMktemp(BaseSemgrepCodemodTest): codemod = TempfileMktemp def test_name(self): - assert self.codemod.name() == "secure-tempfile" + assert self.codemod.name == "secure-tempfile" def test_import(self, tmpdir): input_code = """import tempfile diff --git a/tests/codemods/test_url_sandbox.py b/tests/codemods/test_url_sandbox.py index 9978dd9a3..800649ce9 100644 --- a/tests/codemods/test_url_sandbox.py +++ b/tests/codemods/test_url_sandbox.py @@ -1,17 +1,19 @@ import pytest +import mock from codemodder.dependency import Security from core_codemods.url_sandbox import UrlSandbox from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest +@mock.patch("codemodder.codemods.api.FileContext.add_dependency") class TestUrlSandbox(BaseSemgrepCodemodTest): codemod = UrlSandbox - def test_name(self): - assert self.codemod.name() == "url-sandbox" + def test_name(self, _): + assert self.codemod.name == "url-sandbox" - def test_import_requests(self, tmpdir): + def test_import_requests(self, add_dependency, tmpdir): input_code = """ import requests @@ -27,9 +29,9 @@ def test_import_requests(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + add_dependency.assert_called_once_with(Security) - def test_from_requests(self, tmpdir): + def test_from_requests(self, add_dependency, tmpdir): input_code = """ from requests import get @@ -45,9 +47,9 @@ def test_from_requests(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + add_dependency.assert_called_once_with(Security) - def test_requests_nameerror(self, tmpdir): + def test_requests_nameerror(self, _, tmpdir): input_code = """ url = input() requests.get(url) @@ -96,11 +98,13 @@ def test_requests_nameerror(self, tmpdir): ), ], ) - def test_requests_other_import_untouched(self, tmpdir, input_code, expected): + def test_requests_other_import_untouched( + self, add_dependency, tmpdir, input_code, expected + ): self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + add_dependency.assert_called_once_with(Security) - def test_requests_multifunctions(self, tmpdir): + def test_requests_multifunctions(self, add_dependency, tmpdir): # Test that `requests` import isn't removed if code uses part of the requests # library that isn't part of our codemods. If we add the function as one of # our codemods, this test would change. @@ -122,9 +126,9 @@ def test_requests_multifunctions(self, tmpdir): """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + add_dependency.assert_called_once_with(Security) - def test_custom_get(self, tmpdir): + def test_custom_get(self, _, tmpdir): input_code = """ from app_funcs import get @@ -134,7 +138,7 @@ def test_custom_get(self, tmpdir): expected = input_code self.run_and_assert(tmpdir, input_code, expected) - def test_ambiguous_get(self, tmpdir): + def test_ambiguous_get(self, _, tmpdir): input_code = """ from requests import get @@ -147,7 +151,7 @@ def get(url): expected = input_code self.run_and_assert(tmpdir, input_code, expected) - def test_from_requests_with_alias(self, tmpdir): + def test_from_requests_with_alias(self, add_dependency, tmpdir): input_code = """ from requests import get as got @@ -163,9 +167,9 @@ def test_from_requests_with_alias(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + add_dependency.assert_called_once_with(Security) - def test_requests_with_alias(self, tmpdir): + def test_requests_with_alias(self, add_dependency, tmpdir): input_code = """ import requests as req @@ -181,9 +185,9 @@ def test_requests_with_alias(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + add_dependency.assert_called_once_with(Security) - def test_ignore_hardcoded(self, tmpdir): + def test_ignore_hardcoded(self, _, tmpdir): expected = input_code = """ import requests @@ -192,7 +196,7 @@ def test_ignore_hardcoded(self, tmpdir): self.run_and_assert(tmpdir, input_code, expected) - def test_ignore_hardcoded_from_global_variable(self, tmpdir): + def test_ignore_hardcoded_from_global_variable(self, _, tmpdir): expected = input_code = """ import requests @@ -202,7 +206,7 @@ def test_ignore_hardcoded_from_global_variable(self, tmpdir): self.run_and_assert(tmpdir, input_code, expected) - def test_ignore_hardcoded_from_local_variable(self, tmpdir): + def test_ignore_hardcoded_from_local_variable(self, _, tmpdir): expected = input_code = """ import requests @@ -213,7 +217,7 @@ def foo(): self.run_and_assert(tmpdir, input_code, expected) - def test_ignore_hardcoded_from_local_variable_transitive(self, tmpdir): + def test_ignore_hardcoded_from_local_variable_transitive(self, _, tmpdir): expected = input_code = """ import requests @@ -225,7 +229,9 @@ def foo(): self.run_and_assert(tmpdir, input_code, expected) - def test_ignore_hardcoded_from_local_variable_transitive_reassigned(self, tmpdir): + def test_ignore_hardcoded_from_local_variable_transitive_reassigned( + self, _, tmpdir + ): input_code = """ import requests diff --git a/tests/codemods/test_use_defused_xml.py b/tests/codemods/test_use_defused_xml.py index 1df986a8d..94dfc1999 100644 --- a/tests/codemods/test_use_defused_xml.py +++ b/tests/codemods/test_use_defused_xml.py @@ -1,4 +1,5 @@ import pytest +import mock from codemodder.dependency import DefusedXML from core_codemods.use_defused_xml import ( @@ -10,12 +11,13 @@ from tests.codemods.base_codemod_test import BaseCodemodTest +@mock.patch("codemodder.codemods.api.FileContext.add_dependency") class TestUseDefusedXml(BaseCodemodTest): codemod = UseDefusedXml @pytest.mark.parametrize("method", ETREE_METHODS) @pytest.mark.parametrize("module", ["ElementTree", "cElementTree"]) - def test_etree_simple_call(self, tmpdir, module, method): + def test_etree_simple_call(self, add_dependency, tmpdir, module, method): original_code = f""" from xml.etree.{module} import {method}, ElementPath @@ -30,10 +32,10 @@ def test_etree_simple_call(self, tmpdir, module, method): """ self.run_and_assert(tmpdir, original_code, new_code) - self.assert_dependency(DefusedXML) + add_dependency.assert_called_once_with(DefusedXML) @pytest.mark.parametrize("method", ETREE_METHODS) - def test_etree_module_alias(self, tmpdir, method): + def test_etree_module_alias(self, add_dependency, tmpdir, method): original_code = f""" import xml.etree.ElementTree as alias import xml.etree.cElementTree as calias @@ -50,11 +52,11 @@ def test_etree_module_alias(self, tmpdir, method): """ self.run_and_assert(tmpdir, original_code, new_code) - self.assert_dependency(DefusedXML) + add_dependency.assert_called_once_with(DefusedXML) @pytest.mark.parametrize("method", ETREE_METHODS) @pytest.mark.parametrize("module", ["ElementTree", "cElementTree"]) - def test_etree_attribute_call(self, tmpdir, module, method): + def test_etree_attribute_call(self, add_dependency, tmpdir, module, method): original_code = f""" from xml.etree import {module} @@ -68,9 +70,9 @@ def test_etree_attribute_call(self, tmpdir, module, method): """ self.run_and_assert(tmpdir, original_code, new_code) - self.assert_dependency(DefusedXML) + add_dependency.assert_called_once_with(DefusedXML) - def test_etree_elementtree_with_alias(self, tmpdir): + def test_etree_elementtree_with_alias(self, add_dependency, tmpdir): original_code = """ from xml.etree import ElementTree as ET @@ -84,9 +86,9 @@ def test_etree_elementtree_with_alias(self, tmpdir): """ self.run_and_assert(tmpdir, original_code, new_code) - self.assert_dependency(DefusedXML) + add_dependency.assert_called_once_with(DefusedXML) - def test_etree_parse_with_alias(self, tmpdir): + def test_etree_parse_with_alias(self, add_dependency, tmpdir): original_code = """ from xml.etree.ElementTree import parse as parse_xml @@ -100,10 +102,10 @@ def test_etree_parse_with_alias(self, tmpdir): """ self.run_and_assert(tmpdir, original_code, new_code) - self.assert_dependency(DefusedXML) + add_dependency.assert_called_once_with(DefusedXML) @pytest.mark.parametrize("method", SAX_METHODS) - def test_sax_simple_call(self, tmpdir, method): + def test_sax_simple_call(self, add_dependency, tmpdir, method): original_code = f""" from xml.sax import {method} @@ -117,10 +119,10 @@ def test_sax_simple_call(self, tmpdir, method): """ self.run_and_assert(tmpdir, original_code, new_code) - self.assert_dependency(DefusedXML) + add_dependency.assert_called_once_with(DefusedXML) @pytest.mark.parametrize("method", SAX_METHODS) - def test_sax_attribute_call(self, tmpdir, method): + def test_sax_attribute_call(self, add_dependency, tmpdir, method): original_code = f""" from xml import sax @@ -134,11 +136,11 @@ def test_sax_attribute_call(self, tmpdir, method): """ self.run_and_assert(tmpdir, original_code, new_code) - self.assert_dependency(DefusedXML) + add_dependency.assert_called_once_with(DefusedXML) @pytest.mark.parametrize("method", DOM_METHODS) @pytest.mark.parametrize("module", ["minidom", "pulldom"]) - def test_dom_simple_call(self, tmpdir, module, method): + def test_dom_simple_call(self, add_dependency, tmpdir, module, method): original_code = f""" from xml.dom.{module} import {method} @@ -152,4 +154,4 @@ def test_dom_simple_call(self, tmpdir, module, method): """ self.run_and_assert(tmpdir, original_code, new_code) - self.assert_dependency(DefusedXML) + add_dependency.assert_called_once_with(DefusedXML) diff --git a/tests/codemods/test_use_set_literal.py b/tests/codemods/test_use_set_literal.py index 29ee31aa1..29b398eda 100644 --- a/tests/codemods/test_use_set_literal.py +++ b/tests/codemods/test_use_set_literal.py @@ -13,7 +13,6 @@ def test_simple(self, tmpdir): x = {1, 2, 3} """ self.run_and_assert(tmpdir, original_code, expected_code) - assert self.file_context and len(self.file_context.codemod_changes) == 1 def test_empty_list(self, tmpdir): original_code = """ @@ -23,14 +22,12 @@ def test_empty_list(self, tmpdir): x = set() """ self.run_and_assert(tmpdir, original_code, expected_code) - assert self.file_context and len(self.file_context.codemod_changes) == 1 def test_already_empty(self, tmpdir): original_code = """ x = set() """ self.run_and_assert(tmpdir, original_code, original_code) - assert self.file_context and len(self.file_context.codemod_changes) == 0 def test_not_builtin(self, tmpdir): original_code = """ @@ -38,11 +35,9 @@ def test_not_builtin(self, tmpdir): x = set([1, 2, 3]) """ self.run_and_assert(tmpdir, original_code, original_code) - assert self.file_context and len(self.file_context.codemod_changes) == 0 def test_not_list_literal(self, tmpdir): original_code = """ x = set(some_previously_defined_list) """ self.run_and_assert(tmpdir, original_code, original_code) - assert self.file_context and len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_with_threading_lock.py b/tests/codemods/test_with_threading_lock.py index 7a1d56f0e..39392b119 100644 --- a/tests/codemods/test_with_threading_lock.py +++ b/tests/codemods/test_with_threading_lock.py @@ -18,7 +18,7 @@ class TestWithThreadingLock(BaseSemgrepCodemodTest): codemod = WithThreadingLock def test_rule_ids(self): - assert self.codemod.name() == "bad-lock-with-statement" + assert self.codemod.name == "bad-lock-with-statement" @each_class def test_import(self, tmpdir, klass): diff --git a/tests/conftest.py b/tests/conftest.py index 2ef1826b5..37d7cd1be 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,57 +1,39 @@ import pytest -import mock -@pytest.fixture(autouse=True, scope="module") -def disable_write_report(): +@pytest.fixture(autouse=True) +def disable_write_report(mocker): """ Unit tests should not write analysis report or update any source files. """ - patch_write_report = mock.patch( - "codemodder.report.codetf_reporter.CodeTF.write_report" - ) - - patch_write_report.start() - yield - patch_write_report.stop() + mocker.patch("codemodder.report.codetf_reporter.CodeTF.write_report") -@pytest.fixture(autouse=True, scope="module") -def disable_update_code(): +@pytest.fixture(autouse=True) +def disable_update_code(mocker): """ Unit tests should not write analysis report or update any source files. """ - patch_update_code = mock.patch("codemodder.codemodder.update_code") - patch_update_code.start() - yield - patch_update_code.stop() + mocker.patch("codemodder.codemods.libcst_transformer.update_code") -@pytest.fixture(autouse=True, scope="module") -def disable_semgrep_run(): +@pytest.fixture(autouse=True) +def disable_semgrep_run(mocker): """ Semgrep run is slow so we mock them or pass hardcoded results when possible. """ - semgrep_run = mock.patch("codemodder.codemods.base_codemod.semgrep_run") + mocker.patch("codemodder.codemods.semgrep.semgrep_run") - semgrep_run.start() - yield - semgrep_run.stop() - -@pytest.fixture(autouse=True, scope="module") -def disable_write_dependencies(): +@pytest.fixture(autouse=True) +def disable_write_dependencies(mocker): """ Unit tests should not write any dependency files """ - dm_write = mock.patch( + mocker.patch( "codemodder.dependency_management.dependency_manager.DependencyManager.write" ) - dm_write.start() - yield - dm_write.stop() - @pytest.fixture(scope="module") def pkg_with_reqs_txt(tmp_path_factory): diff --git a/tests/dependency_management/test_setup_py_writer.py b/tests/dependency_management/test_setup_py_writer.py index 0fa919f88..affc59550 100644 --- a/tests/dependency_management/test_setup_py_writer.py +++ b/tests/dependency_management/test_setup_py_writer.py @@ -85,7 +85,7 @@ def test_update_setuppy_comma_single_element_inline(tmpdir): store = PackageStore( type=FileType.SETUP_PY, - file=str(dependency_file), + file=dependency_file, dependencies=set(), py_versions=[">=3.6"], ) diff --git a/tests/test_codemod_docs.py b/tests/test_codemod_docs.py index 2329a620a..a8cbdae70 100644 --- a/tests/test_codemod_docs.py +++ b/tests/test_codemod_docs.py @@ -1,5 +1,6 @@ import pytest +from codemodder.codemods.api import BaseCodemod from codemodder.registry import load_registered_codemods from codemodder.scripts.generate_docs import METADATA @@ -10,10 +11,11 @@ def pytest_generate_tests(metafunc): metafunc.parametrize("codemod", registry.codemods) -def test_load_codemod_docs_info(codemod): - if codemod.name in ["order-imports"]: - pytest.xfail(reason="order-imports has no description") - assert codemod._get_description() is not None # pylint: disable=protected-access +def test_load_codemod_docs_info(codemod: BaseCodemod): + if codemod.name in ["order-imports", "refactor-new-api"]: + pytest.xfail(reason=f"{codemod.name} has no description") + + assert codemod.description is not None # pylint: disable=protected-access assert codemod.review_guidance in ( "Merge After Review", "Merge After Cursory Review", diff --git a/tests/test_codemodder.py b/tests/test_codemodder.py index ad828c521..ff5e27fe0 100644 --- a/tests/test_codemodder.py +++ b/tests/test_codemodder.py @@ -61,14 +61,13 @@ def test_cst_parsing_fails(self, mock_reporting, mock_parse): requests_report = results_by_codemod[0] assert requests_report["changeset"] == [] - assert len(requests_report["failedFiles"]) == 2 + assert len(requests_report["failedFiles"]) == 1 assert sorted(requests_report["failedFiles"]) == [ - "tests/samples/make_request.py", "tests/samples/unverified_request.py", ] - @mock.patch("codemodder.codemodder.update_code") - @mock.patch("codemodder.codemods.base_codemod.semgrep_run", side_effect=semgrep_run) + @mock.patch("codemodder.codemods.libcst_transformer.update_code") + @mock.patch("codemodder.codemods.semgrep.semgrep_run", side_effect=semgrep_run) def test_dry_run(self, _, mock_update_code, tmpdir): codetf = tmpdir / "result.codetf" args = [ @@ -111,7 +110,7 @@ def test_reporting(self, mock_reporting, dry_run): assert len(results_by_codemod) == 3 - @mock.patch("codemodder.codemods.base_codemod.semgrep_run") + @mock.patch("codemodder.codemods.semgrep.semgrep_run") def test_no_codemods_to_run(self, mock_semgrep_run, tmpdir): codetf = tmpdir / "result.codetf" assert not codetf.exists() diff --git a/tests/test_context.py b/tests/test_context.py index 1c10707f1..b01b45b42 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -10,7 +10,7 @@ def test_successful_dependency_description(self, mocker): repo_manager = PythonRepoManager(mocker.Mock()) codemod = registry.match_codemods(codemod_include=["url-sandbox"])[0] - context = Context(mocker.Mock(), True, False, registry, repo_manager) + context = Context(mocker.Mock(), True, False, registry, repo_manager, [], []) context.add_dependencies(codemod.id, {Security}) pkg_store_name = "pyproject.toml" @@ -39,7 +39,7 @@ def test_failed_dependency_description(self, mocker): repo_manager = PythonRepoManager(mocker.Mock()) codemod = registry.match_codemods(codemod_include=["url-sandbox"])[0] - context = Context(mocker.Mock(), True, False, registry, repo_manager) + context = Context(mocker.Mock(), True, False, registry, repo_manager, [], []) context.add_dependencies(codemod.id, {Security}) mocker.patch(