diff --git a/.coveragerc b/.coveragerc index d3e029f2..61b9c624 100644 --- a/.coveragerc +++ b/.coveragerc @@ -3,6 +3,7 @@ source = codemodder omit = */codemodder/scripts/* */codemodder/_version.py + */core_codemods/refactor/* [paths] codemodder = diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 303bf312..966b70ec 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -* @drdavella @andrecsilva +* @drdavella @andrecsilva @clavedeluna diff --git a/integration_tests/base_test.py b/integration_tests/base_test.py index 07ada506..2bfd5356 100644 --- a/integration_tests/base_test.py +++ b/integration_tests/base_test.py @@ -62,8 +62,14 @@ def setup_class(cls): def setup_method(self): try: - self.codemod_wrapper = self.codemod_registry.match_codemods( - codemod_include=[self.codemod.name()] + name = ( + self.codemod().name + if isinstance(self.codemod, type) + else self.codemod.name + ) + # This is how we ensure that the codemod is actually in the registry + self.codemod_instance = self.codemod_registry.match_codemods( + codemod_include=[name] )[0] except IndexError as exc: raise IndexError( @@ -77,7 +83,7 @@ def _assert_run_fields(self, run, output_path): assert run["elapsed"] != "" assert ( run["commandLine"] - == f"codemodder {SAMPLES_DIR} --output {output_path} --codemod-include={self.codemod_wrapper.name} --path-include={self.code_path}" + == f"codemodder {SAMPLES_DIR} --output {output_path} --codemod-include={self.codemod_instance.name} --path-include={self.code_path}" ) assert run["directory"] == os.path.abspath(SAMPLES_DIR) assert run["sarifs"] == [] @@ -85,8 +91,10 @@ def _assert_run_fields(self, run, output_path): def _assert_results_fields(self, results, output_path): assert len(results) == 1 result = results[0] - assert result["codemod"] == self.codemod_wrapper.id - assert result["references"] == self.codemod_wrapper.references + assert result["codemod"] == self.codemod_instance.id + assert result["references"] == [ + ref.to_json() for ref in self.codemod_instance.references + ] # TODO: once we add description for each url. for reference in result["references"]: @@ -147,7 +155,7 @@ def test_file_rewritten(self): SAMPLES_DIR, "--output", self.output_path, - f"--codemod-include={self.codemod_wrapper.name}", + f"--codemod-include={self.codemod_instance.name}", f"--path-include={self.code_path}", ] diff --git a/integration_tests/test_add_requests_timeout.py b/integration_tests/test_add_requests_timeout.py index df7407c4..556dbe4d 100644 --- a/integration_tests/test_add_requests_timeout.py +++ b/integration_tests/test_add_requests_timeout.py @@ -1,4 +1,7 @@ -from core_codemods.add_requests_timeouts import AddRequestsTimeouts +from core_codemods.add_requests_timeouts import ( + AddRequestsTimeouts, + TransformAddRequestsTimeouts, +) from integration_tests.base_test import ( BaseIntegrationTest, original_and_expected_from_code_path, @@ -33,4 +36,4 @@ class TestAddRequestsTimeouts(BaseIntegrationTest): num_changes = 2 expected_line_change = "3" - change_description = AddRequestsTimeouts.CHANGE_DESCRIPTION + change_description = TransformAddRequestsTimeouts.change_description diff --git a/integration_tests/test_combine_startswith_endswith.py b/integration_tests/test_combine_startswith_endswith.py index 9ad639f6..6426e1fa 100644 --- a/integration_tests/test_combine_startswith_endswith.py +++ b/integration_tests/test_combine_startswith_endswith.py @@ -13,4 +13,4 @@ class TestCombineStartswithEndswith(BaseIntegrationTest): ) expected_diff = '--- \n+++ \n@@ -1,3 +1,3 @@\n x = \'foo\'\n-if x.startswith("foo") or x.startswith("bar"):\n+if x.startswith(("foo", "bar")):\n print("Yes")\n' expected_line_change = "2" - change_description = CombineStartswithEndswith.CHANGE_DESCRIPTION + change_description = CombineStartswithEndswith.change_description diff --git a/integration_tests/test_django_debug_flag_on.py b/integration_tests/test_django_debug_flag_on.py index bb4a2c1c..f3abf56c 100644 --- a/integration_tests/test_django_debug_flag_on.py +++ b/integration_tests/test_django_debug_flag_on.py @@ -14,4 +14,4 @@ class TestDjangoDebugFlagFlip(BaseIntegrationTest): ) expected_diff = '--- \n+++ \n@@ -23,7 +23,7 @@\n SECRET_KEY = "django-insecure-t*rrda&qd4^#q+50^%q^rrsp-t$##&u5_#=9)&@ei^ppl6$*c*"\n \n # SECURITY WARNING: don\'t run with debug turned on in production!\n-DEBUG = True\n+DEBUG = False\n \n ALLOWED_HOSTS = []\n \n' expected_line_change = "26" - change_description = DjangoDebugFlagOn.CHANGE_DESCRIPTION + change_description = DjangoDebugFlagOn.change_description diff --git a/integration_tests/test_django_json_response_type.py b/integration_tests/test_django_json_response_type.py index 6b2c28df..59f751d7 100644 --- a/integration_tests/test_django_json_response_type.py +++ b/integration_tests/test_django_json_response_type.py @@ -32,5 +32,5 @@ class TestDjangoJsonResponseType(BaseIntegrationTest): # fmt: on expected_line_change = "6" - change_description = DjangoJsonResponseType.CHANGE_DESCRIPTION + change_description = DjangoJsonResponseType.change_description num_changed_files = 1 diff --git a/integration_tests/test_django_receiver_on_top.py b/integration_tests/test_django_receiver_on_top.py index 4fcf5652..833f3c8d 100644 --- a/integration_tests/test_django_receiver_on_top.py +++ b/integration_tests/test_django_receiver_on_top.py @@ -33,5 +33,5 @@ class TestDjangoReceiverOnTop(BaseIntegrationTest): # fmt: on expected_line_change = "7" - change_description = DjangoReceiverOnTop.CHANGE_DESCRIPTION + change_description = DjangoReceiverOnTop.change_description num_changed_files = 1 diff --git a/integration_tests/test_django_session_cookie_secure_off.py b/integration_tests/test_django_session_cookie_secure_off.py index ca193390..bf726e48 100644 --- a/integration_tests/test_django_session_cookie_secure_off.py +++ b/integration_tests/test_django_session_cookie_secure_off.py @@ -16,4 +16,4 @@ class TestDjangoSessionCookieSecureOff(BaseIntegrationTest): ) expected_diff = '--- \n+++ \n@@ -121,3 +121,4 @@\n # https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field\n \n DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"\n+SESSION_COOKIE_SECURE = True\n' expected_line_change = "124" - change_description = DjangoSessionCookieSecureOff.CHANGE_DESCRIPTION + change_description = DjangoSessionCookieSecureOff.change_description diff --git a/integration_tests/test_exception_without_raise.py b/integration_tests/test_exception_without_raise.py index 565bdf7b..0b5f9b0b 100644 --- a/integration_tests/test_exception_without_raise.py +++ b/integration_tests/test_exception_without_raise.py @@ -29,5 +29,5 @@ class TestExceptionWithoutRaise(BaseIntegrationTest): # fmt: on expected_line_change = "2" - change_description = ExceptionWithoutRaise.CHANGE_DESCRIPTION + change_description = ExceptionWithoutRaise.change_description num_changed_files = 1 diff --git a/integration_tests/test_file_resource_leak.py b/integration_tests/test_file_resource_leak.py index bbf87785..6c518ba3 100644 --- a/integration_tests/test_file_resource_leak.py +++ b/integration_tests/test_file_resource_leak.py @@ -33,5 +33,5 @@ class TestFileResourceLeak(BaseIntegrationTest): # fmt: on expected_line_change = "3" - change_description = FileResourceLeak.CHANGE_DESCRIPTION + change_description = FileResourceLeak.change_description num_changed_files = 1 diff --git a/integration_tests/test_fix_deprecated_logging_warn.py b/integration_tests/test_fix_deprecated_logging_warn.py index caa08ff7..89dd65fb 100644 --- a/integration_tests/test_fix_deprecated_logging_warn.py +++ b/integration_tests/test_fix_deprecated_logging_warn.py @@ -13,4 +13,4 @@ class TestFixDeprecatedLoggingWarn(BaseIntegrationTest): ) expected_diff = '--- \n+++ \n@@ -1,4 +1,4 @@\n import logging\n \n log = logging.getLogger("my logger")\n-log.warn("hello")\n+log.warning("hello")\n' expected_line_change = "4" - change_description = FixDeprecatedLoggingWarn.CHANGE_DESCRIPTION + change_description = FixDeprecatedLoggingWarn.change_description diff --git a/integration_tests/test_fix_mutable_params.py b/integration_tests/test_fix_mutable_params.py index 110e9ac6..b84470b0 100644 --- a/integration_tests/test_fix_mutable_params.py +++ b/integration_tests/test_fix_mutable_params.py @@ -30,4 +30,4 @@ def baz(x=None, y=None): expected_diff = '--- \n+++ \n@@ -1,4 +1,5 @@\n-def foo(x, y=[]):\n+def foo(x, y=None):\n+ y = [] if y is None else y\n y.append(x)\n print(y)\n \n@@ -7,6 +8,8 @@\n print(x)\n \n \n-def baz(x={"foo": 42}, y=set()):\n+def baz(x=None, y=None):\n+ x = {"foo": 42} if x is None else x\n+ y = set() if y is None else y\n print(x)\n print(y)\n' expected_line_change = 1 num_changes = 2 - change_description = FixMutableParams.CHANGE_DESCRIPTION + change_description = FixMutableParams.change_description diff --git a/integration_tests/test_flask_enable_csrf_protection.py b/integration_tests/test_flask_enable_csrf_protection.py index 3c6500e6..5e26f5da 100644 --- a/integration_tests/test_flask_enable_csrf_protection.py +++ b/integration_tests/test_flask_enable_csrf_protection.py @@ -32,5 +32,5 @@ class TestFlaskEnableCSRFProtection(BaseIntegrationTest): # fmt: on expected_line_change = "3" - change_description = FlaskEnableCSRFProtection.CHANGE_DESCRIPTION + change_description = FlaskEnableCSRFProtection.change_description num_changed_files = 2 diff --git a/integration_tests/test_flask_json_response_type.py b/integration_tests/test_flask_json_response_type.py index d7511a2e..f283e742 100644 --- a/integration_tests/test_flask_json_response_type.py +++ b/integration_tests/test_flask_json_response_type.py @@ -32,5 +32,5 @@ class TestFlaskJsonResponseType(BaseIntegrationTest): # fmt: on expected_line_change = "9" - change_description = FlaskJsonResponseType.CHANGE_DESCRIPTION + change_description = FlaskJsonResponseType.change_description num_changed_files = 1 diff --git a/integration_tests/test_harden_pyyaml.py b/integration_tests/test_harden_pyyaml.py index 1a7caca4..407ffa2f 100644 --- a/integration_tests/test_harden_pyyaml.py +++ b/integration_tests/test_harden_pyyaml.py @@ -15,6 +15,6 @@ class TestHardenPyyaml(BaseIntegrationTest): ) expected_diff = '--- \n+++ \n@@ -1,4 +1,4 @@\n import yaml\n \n data = b"!!python/object/apply:subprocess.Popen \\\\n- ls"\n-deserialized_data = yaml.load(data, Loader=yaml.Loader)\n+deserialized_data = yaml.load(data, Loader=yaml.SafeLoader)\n' expected_line_change = "4" - change_description = HardenPyyaml.CHANGE_DESCRIPTION + change_description = HardenPyyaml.change_description # expected exception because the yaml.SafeLoader protects against unsafe code allowed_exceptions = (yaml.constructor.ConstructorError,) diff --git a/integration_tests/test_harden_ruamel.py b/integration_tests/test_harden_ruamel.py index 4a8b06bc..dd5d6244 100644 --- a/integration_tests/test_harden_ruamel.py +++ b/integration_tests/test_harden_ruamel.py @@ -18,4 +18,4 @@ class TestHardenRuamel(BaseIntegrationTest): expected_diff = '--- \n+++ \n@@ -1,4 +1,4 @@\n from ruamel.yaml import YAML\n \n-serializer = YAML(typ="unsafe")\n-serializer = YAML(typ="base")\n+serializer = YAML(typ="safe")\n+serializer = YAML(typ="safe")\n' expected_line_change = "3" num_changes = 2 - change_description = HardenRuamel.CHANGE_DESCRIPTION + change_description = HardenRuamel.change_description diff --git a/integration_tests/test_https_connection.py b/integration_tests/test_https_connection.py index 31f56013..79ac59a4 100644 --- a/integration_tests/test_https_connection.py +++ b/integration_tests/test_https_connection.py @@ -22,4 +22,4 @@ class TestHTTPSConnection(BaseIntegrationTest): expected_diff = '--- \n+++ \n@@ -1,8 +1,7 @@\n import urllib3\n import urllib3.connectionpool as pool\n-from urllib3 import HTTPConnectionPool as something\n \n-urllib3.HTTPConnectionPool("localhost", "80")\n-urllib3.connectionpool.HTTPConnectionPool("localhost", "80")\n-something("localhost", "80")\n-pool.HTTPConnectionPool("localhost", "80")\n+urllib3.HTTPSConnectionPool("localhost", "80")\n+urllib3.connectionpool.HTTPSConnectionPool("localhost", "80")\n+urllib3.HTTPSConnectionPool("localhost", "80")\n+pool.HTTPSConnectionPool("localhost", "80")\n' expected_line_change = "5" num_changes = 4 - change_description = HTTPSConnection.CHANGE_DESCRIPTION + change_description = HTTPSConnection.change_description diff --git a/integration_tests/test_jinja2_autoescape.py b/integration_tests/test_jinja2_autoescape.py index 7dbf0fa0..83a0b8b9 100644 --- a/integration_tests/test_jinja2_autoescape.py +++ b/integration_tests/test_jinja2_autoescape.py @@ -18,4 +18,4 @@ class TestEnableJinja2Autoescape(BaseIntegrationTest): expected_diff = "--- \n+++ \n@@ -1,4 +1,4 @@\n from jinja2 import Environment\n \n-env = Environment()\n-env = Environment(autoescape=False)\n+env = Environment(autoescape=True)\n+env = Environment(autoescape=True)\n" expected_line_change = "3" num_changes = 2 - change_description = EnableJinja2Autoescape.CHANGE_DESCRIPTION + change_description = EnableJinja2Autoescape.change_description diff --git a/integration_tests/test_jwt_decode_verify.py b/integration_tests/test_jwt_decode_verify.py index c6ae99aa..6c1360fd 100644 --- a/integration_tests/test_jwt_decode_verify.py +++ b/integration_tests/test_jwt_decode_verify.py @@ -24,4 +24,4 @@ class TestJwtDecodeVerify(BaseIntegrationTest): expected_diff = '--- \n+++ \n@@ -8,7 +8,7 @@\n \n encoded_jwt = jwt.encode(payload, SECRET_KEY, algorithm="HS256")\n \n-decoded_payload = jwt.decode(encoded_jwt, SECRET_KEY, algorithms=["HS256"], verify=False)\n-decoded_payload = jwt.decode(encoded_jwt, SECRET_KEY, algorithms=["HS256"], options={"verify_signature": False})\n+decoded_payload = jwt.decode(encoded_jwt, SECRET_KEY, algorithms=["HS256"], verify=True)\n+decoded_payload = jwt.decode(encoded_jwt, SECRET_KEY, algorithms=["HS256"], options={"verify_signature": True})\n \n var = "something"\n' expected_line_change = "11" num_changes = 2 - change_description = JwtDecodeVerify.CHANGE_DESCRIPTION + change_description = JwtDecodeVerify.change_description diff --git a/integration_tests/test_limit_readline.py b/integration_tests/test_limit_readline.py index b228b24c..a91517f3 100644 --- a/integration_tests/test_limit_readline.py +++ b/integration_tests/test_limit_readline.py @@ -13,6 +13,6 @@ class TestLimitReadline(BaseIntegrationTest): ) expected_diff = '--- \n+++ \n@@ -1,2 +1,2 @@\n file = open("some_file.txt")\n-file.readline()\n+file.readline(5_000_000)\n' expected_line_change = "2" - change_description = LimitReadline.CHANGE_DESCRIPTION + change_description = LimitReadline.change_description # expected because output code points to fake file allowed_exceptions = (FileNotFoundError,) diff --git a/integration_tests/test_literal_or_new_object_identity.py b/integration_tests/test_literal_or_new_object_identity.py index 7e5bccbc..03f37af3 100644 --- a/integration_tests/test_literal_or_new_object_identity.py +++ b/integration_tests/test_literal_or_new_object_identity.py @@ -28,5 +28,5 @@ class TestLiteralOrNewObjectIdentity(BaseIntegrationTest): # fmt: on expected_line_change = "2" - change_description = LiteralOrNewObjectIdentity.CHANGE_DESCRIPTION + change_description = LiteralOrNewObjectIdentity.change_description num_changed_files = 1 diff --git a/integration_tests/test_lxml_safe_parser_defaults.py b/integration_tests/test_lxml_safe_parser_defaults.py index 25292b9c..5c1a5cba 100644 --- a/integration_tests/test_lxml_safe_parser_defaults.py +++ b/integration_tests/test_lxml_safe_parser_defaults.py @@ -13,4 +13,4 @@ class TestLxmlSafeParserDefaults(BaseIntegrationTest): ) expected_diff = "--- \n+++ \n@@ -1,2 +1,2 @@\n import lxml.etree\n-parser = lxml.etree.XMLParser()\n+parser = lxml.etree.XMLParser(resolve_entities=False)\n" expected_line_change = "2" - change_description = LxmlSafeParserDefaults.CHANGE_DESCRIPTION + change_description = LxmlSafeParserDefaults.change_description diff --git a/integration_tests/test_lxml_safe_parsing.py b/integration_tests/test_lxml_safe_parsing.py index 3a8d65ea..b121dce4 100644 --- a/integration_tests/test_lxml_safe_parsing.py +++ b/integration_tests/test_lxml_safe_parsing.py @@ -24,5 +24,5 @@ class TestLxmlSafeParsing(BaseIntegrationTest): expected_diff = '--- \n+++ \n@@ -1,3 +1,3 @@\n import lxml.etree\n-lxml.etree.parse("path_to_file")\n-lxml.etree.fromstring("xml_str")\n+lxml.etree.parse("path_to_file", parser=lxml.etree.XMLParser(resolve_entities=False))\n+lxml.etree.fromstring("xml_str", parser=lxml.etree.XMLParser(resolve_entities=False))\n' expected_line_change = "2" num_changes = 2 - change_description = LxmlSafeParsing.CHANGE_DESCRIPTION + change_description = LxmlSafeParsing.change_description allowed_exceptions = (OSError,) diff --git a/integration_tests/test_numpy_nan_equality.py b/integration_tests/test_numpy_nan_equality.py index f5b7626a..85ebd36a 100644 --- a/integration_tests/test_numpy_nan_equality.py +++ b/integration_tests/test_numpy_nan_equality.py @@ -30,5 +30,5 @@ class TestNumpyNanEquality(BaseIntegrationTest): # fmt: on expected_line_change = "4" - change_description = NumpyNanEquality.CHANGE_DESCRIPTION + change_description = NumpyNanEquality.change_description num_changed_files = 1 diff --git a/integration_tests/test_order_imports.py b/integration_tests/test_order_imports.py index 3d174c87..0b515178 100644 --- a/integration_tests/test_order_imports.py +++ b/integration_tests/test_order_imports.py @@ -37,4 +37,4 @@ class TestOrderImports(BaseIntegrationTest): expected_diff = "--- \n+++ \n@@ -1,20 +1,14 @@\n #!/bin/env python\n-from abc import ABCMeta\n-\n+# comment builtins4\n+# comment builtins5\n+# comment builtins3\n # comment builtins1\n # comment builtins2\n import builtins\n-\n+import collections\n+import datetime\n # comment a\n-from abc import ABC\n-\n-# comment builtins3\n-import builtins, datetime\n-\n-# comment builtins4\n-# comment builtins5\n-import builtins\n-import collections\n+from abc import ABC, ABCMeta\n \n ABC\n ABCMeta\n" expected_line_change = "2" - change_description = OrderImports.CHANGE_DESCRIPTION + change_description = OrderImports.change_description diff --git a/integration_tests/test_process_sandbox.py b/integration_tests/test_process_sandbox.py index a54fb926..ba6aa5e3 100644 --- a/integration_tests/test_process_sandbox.py +++ b/integration_tests/test_process_sandbox.py @@ -22,7 +22,7 @@ class TestProcessSandbox(BaseIntegrationTest): expected_line_change = "3" num_changes = 4 num_changed_files = 2 - change_description = ProcessSandbox.CHANGE_DESCRIPTION + change_description = ProcessSandbox.change_description requirements_path = "tests/samples/requirements.txt" original_requirements = "# file used to test dependency management\nrequests==2.31.0\nblack==23.7.*\nmypy~=1.4\npylint>1\n" diff --git a/integration_tests/test_remove_debug_breakpoint.py b/integration_tests/test_remove_debug_breakpoint.py index 2bfe0120..14c6b294 100644 --- a/integration_tests/test_remove_debug_breakpoint.py +++ b/integration_tests/test_remove_debug_breakpoint.py @@ -17,4 +17,4 @@ class TestRemoveDebugBreakpoint(BaseIntegrationTest): '--- \n+++ \n@@ -1,3 +1,2 @@\n print("hello")\n-breakpoint()\n print("world")\n' ) expected_line_change = "2" - change_description = RemoveDebugBreakpoint.CHANGE_DESCRIPTION + change_description = RemoveDebugBreakpoint.change_description diff --git a/integration_tests/test_remove_future_imports.py b/integration_tests/test_remove_future_imports.py index 8e779074..9f7bae29 100644 --- a/integration_tests/test_remove_future_imports.py +++ b/integration_tests/test_remove_future_imports.py @@ -33,4 +33,4 @@ class TestRemoveFutureImports(BaseIntegrationTest): num_changes = 2 expected_line_change = "1" - change_description = RemoveFutureImports.CHANGE_DESCRIPTION + change_description = RemoveFutureImports.change_description diff --git a/integration_tests/test_remove_module_global.py b/integration_tests/test_remove_module_global.py index 5daf5fca..355c0b5e 100644 --- a/integration_tests/test_remove_module_global.py +++ b/integration_tests/test_remove_module_global.py @@ -16,4 +16,4 @@ class TestRemoveModuleGlobal(BaseIntegrationTest): """.lstrip() expected_diff = '--- \n+++ \n@@ -1,4 +1,3 @@\n price = 25\n print("hello")\n-global price\n price = 30\n' expected_line_change = "3" - change_description = RemoveModuleGlobal.CHANGE_DESCRIPTION + change_description = RemoveModuleGlobal.change_description diff --git a/integration_tests/test_remove_unused_imports.py b/integration_tests/test_remove_unused_imports.py index 366756ec..1c3d3860 100644 --- a/integration_tests/test_remove_unused_imports.py +++ b/integration_tests/test_remove_unused_imports.py @@ -17,4 +17,4 @@ class TestRemoveUnusedImports(BaseIntegrationTest): expected_diff = "--- \n+++ \n@@ -1,5 +1,5 @@\n import abc\n-from builtins import complex, dict\n+from builtins import complex\n \n abc\n complex\n" expected_line_change = 2 - change_description = RemoveUnusedImports.CHANGE_DESCRIPTION + change_description = RemoveUnusedImports.change_description diff --git a/integration_tests/test_replace_flask_send_file.py b/integration_tests/test_replace_flask_send_file.py index ad105e09..acea4515 100644 --- a/integration_tests/test_replace_flask_send_file.py +++ b/integration_tests/test_replace_flask_send_file.py @@ -49,5 +49,5 @@ class TestReplaceFlaskSendFile(BaseIntegrationTest): # fmt: on expected_line_change = "7" - change_description = ReplaceFlaskSendFile.CHANGE_DESCRIPTION + change_description = ReplaceFlaskSendFile.change_description num_changed_files = 1 diff --git a/integration_tests/test_request_verify.py b/integration_tests/test_request_verify.py index 5478abbb..dd2cf5ec 100644 --- a/integration_tests/test_request_verify.py +++ b/integration_tests/test_request_verify.py @@ -22,6 +22,6 @@ class TestRequestsVerify(BaseIntegrationTest): expected_diff = '--- \n+++ \n@@ -1,5 +1,5 @@\n import requests\n \n-requests.get("https://www.google.com", verify=False)\n-requests.post("https://some-api/", json={"id": 1234, "price": 18}, verify=False)\n+requests.get("https://www.google.com", verify=True)\n+requests.post("https://some-api/", json={"id": 1234, "price": 18}, verify=True)\n var = "hello"\n' expected_line_change = "3" num_changes = 2 - change_description = RequestsVerify.CHANGE_DESCRIPTION + change_description = RequestsVerify.change_description # expected because when executing the output code it will make a request which fails, which is OK. allowed_exceptions = (exceptions.ConnectionError,) diff --git a/integration_tests/test_secure_flask_cookie.py b/integration_tests/test_secure_flask_cookie.py index e8394b67..f1a8cdc9 100644 --- a/integration_tests/test_secure_flask_cookie.py +++ b/integration_tests/test_secure_flask_cookie.py @@ -19,4 +19,4 @@ class TestSecureFlaskCookie(BaseIntegrationTest): ) expected_diff = "--- \n+++ \n@@ -5,5 +5,5 @@\n @app.route('/')\n def index():\n resp = make_response('Custom Cookie Set')\n- resp.set_cookie('custom_cookie', 'value')\n+ resp.set_cookie('custom_cookie', 'value', secure=True, httponly=True, samesite='Lax')\n return resp\n" expected_line_change = "8" - change_description = SecureFlaskCookie.CHANGE_DESCRIPTION + change_description = SecureFlaskCookie.change_description diff --git a/integration_tests/test_secure_flask_session_config.py b/integration_tests/test_secure_flask_session_config.py index f138ba18..b07ccd1d 100644 --- a/integration_tests/test_secure_flask_session_config.py +++ b/integration_tests/test_secure_flask_session_config.py @@ -13,4 +13,4 @@ class TestSecureFlaskSessionConfig(BaseIntegrationTest): ) expected_diff = "--- \n+++ \n@@ -1,6 +1,6 @@\n from flask import Flask\n app = Flask(__name__)\n-app.config['SESSION_COOKIE_HTTPONLY'] = False\n+app.config['SESSION_COOKIE_HTTPONLY'] = True\n @app.route('/')\n def hello_world():\n return 'Hello World!'\n" expected_line_change = "3" - change_description = SecureFlaskSessionConfig.CHANGE_DESCRIPTION + change_description = SecureFlaskSessionConfig.change_description diff --git a/integration_tests/test_secure_random.py b/integration_tests/test_secure_random.py index 458cf785..97863fa9 100644 --- a/integration_tests/test_secure_random.py +++ b/integration_tests/test_secure_random.py @@ -19,4 +19,4 @@ class TestSecureRandom(BaseIntegrationTest): expected_diff = '--- \n+++ \n@@ -1,4 +1,4 @@\n-import random\n+import secrets\n \n-random.random()\n+secrets.SystemRandom().random()\n var = "hello"\n' expected_line_change = "3" - change_description = SecureRandom.CHANGE_DESCRIPTION + change_description = SecureRandom.change_description diff --git a/integration_tests/test_sql_parameterization.py b/integration_tests/test_sql_parameterization.py index f27063fa..99242d33 100644 --- a/integration_tests/test_sql_parameterization.py +++ b/integration_tests/test_sql_parameterization.py @@ -37,5 +37,5 @@ class TestSQLQueryParameterization(BaseIntegrationTest): # fmt: on expected_line_change = "12" - change_description = SQLQueryParameterization.CHANGE_DESCRIPTION + change_description = SQLQueryParameterization.change_description num_changed_files = 1 diff --git a/integration_tests/test_subprocess_shell_false.py b/integration_tests/test_subprocess_shell_false.py index dd520269..078b8898 100644 --- a/integration_tests/test_subprocess_shell_false.py +++ b/integration_tests/test_subprocess_shell_false.py @@ -13,6 +13,6 @@ class TestSubprocessShellFalse(BaseIntegrationTest): ) expected_diff = "--- \n+++ \n@@ -1,2 +1,2 @@\n import subprocess\n-subprocess.run(\"echo 'hi'\", shell=True)\n+subprocess.run(\"echo 'hi'\", shell=False)\n" expected_line_change = "2" - change_description = SubprocessShellFalse.CHANGE_DESCRIPTION + change_description = SubprocessShellFalse.change_description # expected because output code points to fake file allowed_exceptions = (FileNotFoundError,) diff --git a/integration_tests/test_tempfile_mktemp.py b/integration_tests/test_tempfile_mktemp.py index f8e87e70..7acca7ef 100644 --- a/integration_tests/test_tempfile_mktemp.py +++ b/integration_tests/test_tempfile_mktemp.py @@ -13,4 +13,4 @@ class TestTempfileMktemp(BaseIntegrationTest): ) expected_diff = '--- \n+++ \n@@ -1,4 +1,4 @@\n import tempfile\n \n-tempfile.mktemp()\n+tempfile.mkstemp()\n var = "hello"\n' expected_line_change = "3" - change_description = TempfileMktemp.CHANGE_DESCRIPTION + change_description = TempfileMktemp.change_description diff --git a/integration_tests/test_unnecessary_f_str.py b/integration_tests/test_unnecessary_f_str.py index 009cb268..a29bfb29 100644 --- a/integration_tests/test_unnecessary_f_str.py +++ b/integration_tests/test_unnecessary_f_str.py @@ -13,4 +13,4 @@ class TestFStr(BaseIntegrationTest): ) expected_diff = '--- \n+++ \n@@ -1,2 +1,2 @@\n-bad = f"hello"\n+bad = "hello"\n good = f"{2+3}"\n' expected_line_change = "1" - change_description = RemoveUnnecessaryFStr.CHANGE_DESCRIPTION + change_description = RemoveUnnecessaryFStr.change_description diff --git a/integration_tests/test_upgrade_sslcontext_minimum_version.py b/integration_tests/test_upgrade_sslcontext_minimum_version.py index 2b5b7a11..4da3f45e 100644 --- a/integration_tests/test_upgrade_sslcontext_minimum_version.py +++ b/integration_tests/test_upgrade_sslcontext_minimum_version.py @@ -20,4 +20,4 @@ class TestUpgradeSSLContextMininumVersion(BaseIntegrationTest): expected_diff = '--- \n+++ \n@@ -1,8 +1,9 @@\n from ssl import PROTOCOL_TLS_CLIENT, SSLContext, TLSVersion\n+import ssl\n \n my_ctx = SSLContext(protocol=PROTOCOL_TLS_CLIENT)\n \n print("FOO")\n \n my_ctx.maximum_version = TLSVersion.MAXIMUM_SUPPORTED\n-my_ctx.minimum_version = TLSVersion.TLSv1_1\n+my_ctx.minimum_version = ssl.TLSVersion.TLSv1_2\n' expected_line_change = "8" - change_description = UpgradeSSLContextMinimumVersion.CHANGE_DESCRIPTION + change_description = UpgradeSSLContextMinimumVersion.change_description diff --git a/integration_tests/test_upgrade_sslcontext_tls.py b/integration_tests/test_upgrade_sslcontext_tls.py index ea788a62..34d9b3ea 100644 --- a/integration_tests/test_upgrade_sslcontext_tls.py +++ b/integration_tests/test_upgrade_sslcontext_tls.py @@ -9,4 +9,4 @@ class TestUpgradeWeakTLS(BaseIntegrationTest): expected_new_code = "import ssl\n\nssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)\n" expected_diff = "--- \n+++ \n@@ -1,3 +1,3 @@\n import ssl\n \n-ssl.SSLContext(ssl.PROTOCOL_SSLv2)\n+ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)\n" expected_line_change = "3" - change_description = UpgradeSSLContextTLS.CHANGE_DESCRIPTION + change_description = UpgradeSSLContextTLS.change_description diff --git a/integration_tests/test_url_sandbox.py b/integration_tests/test_url_sandbox.py index 1fd9688a..4fba2507 100644 --- a/integration_tests/test_url_sandbox.py +++ b/integration_tests/test_url_sandbox.py @@ -31,7 +31,7 @@ class TestUrlSandbox(BaseIntegrationTest): """ expected_line_change = "5" - change_description = UrlSandbox.CHANGE_DESCRIPTION + change_description = UrlSandbox.change_description num_changed_files = 2 requirements_path = "tests/samples/requirements.txt" diff --git a/integration_tests/test_use_defusedxml.py b/integration_tests/test_use_defusedxml.py index f9734208..681371ee 100644 --- a/integration_tests/test_use_defusedxml.py +++ b/integration_tests/test_use_defusedxml.py @@ -35,7 +35,7 @@ class TestUseDefusedXml(BaseIntegrationTest): """ expected_line_change = "5" - change_description = UseDefusedXml.CHANGE_DESCRIPTION + change_description = UseDefusedXml.change_description requirements_path = "tests/samples/requirements.txt" original_requirements = "# file used to test dependency management\nrequests==2.31.0\nblack==23.7.*\nmypy~=1.4\npylint>1\n" diff --git a/integration_tests/test_use_generator.py b/integration_tests/test_use_generator.py index a7d6b38a..fe432617 100644 --- a/integration_tests/test_use_generator.py +++ b/integration_tests/test_use_generator.py @@ -27,4 +27,4 @@ class TestUseGenerator(BaseIntegrationTest): """ expected_line_change = "6" - change_description = UseGenerator.CHANGE_DESCRIPTION + change_description = UseGenerator.change_description diff --git a/integration_tests/test_use_set_literal.py b/integration_tests/test_use_set_literal.py index 031ed035..07e691d2 100644 --- a/integration_tests/test_use_set_literal.py +++ b/integration_tests/test_use_set_literal.py @@ -26,4 +26,4 @@ class TestUseSetLiteral(BaseIntegrationTest): expected_line_change = "1" num_changes = 2 - change_description = UseSetLiteral.CHANGE_DESCRIPTION + change_description = UseSetLiteral.change_description diff --git a/integration_tests/test_use_walrus_if.py b/integration_tests/test_use_walrus_if.py index 5b091071..6d3abedd 100644 --- a/integration_tests/test_use_walrus_if.py +++ b/integration_tests/test_use_walrus_if.py @@ -29,4 +29,4 @@ def whatever(): num_changes = 3 expected_line_change = 1 - change_description = UseWalrusIf.CHANGE_DESCRIPTION + change_description = UseWalrusIf.change_description diff --git a/integration_tests/test_with_threading_lock.py b/integration_tests/test_with_threading_lock.py index 72add390..d696f779 100644 --- a/integration_tests/test_with_threading_lock.py +++ b/integration_tests/test_with_threading_lock.py @@ -18,4 +18,4 @@ class TestWithThreadingLock(BaseIntegrationTest): ) expected_diff = '--- \n+++ \n@@ -1,3 +1,4 @@\n import threading\n-with threading.Lock():\n+lock = threading.Lock()\n+with lock:\n print("Hello")\n' expected_line_change = "2" - change_description = WithThreadingLock.CHANGE_DESCRIPTION + change_description = WithThreadingLock.change_description diff --git a/pylintrc b/pylintrc index 74e31145..e5b7d0a4 100644 --- a/pylintrc +++ b/pylintrc @@ -7,7 +7,8 @@ jobs=0 # Add files or directories matching the regex patterns to the ignore-list. The # regex matches against paths and can be in Posix or Windows format. ignore-paths= - tests/samples, + tests/samples/, + .*/core_codemods/refactor/.*py$ # Specify a score threshold under which the program will exit with error. fail-under=10.0 diff --git a/src/codemodder/codemodder.py b/src/codemodder/codemodder.py index 0983e04c..cf9215fc 100644 --- a/src/codemodder/codemodder.py +++ b/src/codemodder/codemodder.py @@ -1,24 +1,20 @@ -from concurrent.futures import ThreadPoolExecutor import datetime import itertools import logging import os import sys +from typing import Sequence from pathlib import Path -import libcst as cst -from libcst.codemod import CodemodContext from codemodder.dependency import Dependency -from codemodder.file_context import FileContext from codemodder import registry, __version__ from codemodder.logging import configure_logger, logger, log_section, log_list from codemodder.cli import parse_args -from codemodder.change import ChangeSet -from codemodder.code_directory import file_line_patterns, match_files +from codemodder.code_directory import match_files +from codemodder.codemods.semgrep import SemgrepRuleDetector from codemodder.context import CodemodExecutionContext -from codemodder.diff import create_diff_from_tree -from codemodder.executor import CodemodExecutorWrapper +from codemodder.codemods.api import BaseCodemod from codemodder.project_analysis.file_parsers.package_store import PackageStore from codemodder.project_analysis.python_repo_manager import PythonRepoManager from codemodder.report.codetf_reporter import report_default @@ -36,128 +32,24 @@ def update_code(file_path, new_code): def find_semgrep_results( context: CodemodExecutionContext, - codemods: list[CodemodExecutorWrapper], + codemods: Sequence[BaseCodemod], + files_to_analyze: list[Path] | None = None, ) -> ResultSet: """Run semgrep once with all configuration files from all codemods and return a set of applicable rule IDs""" yaml_files = list( itertools.chain.from_iterable( - [codemod.yaml_files for codemod in codemods if codemod.yaml_files] + [ + codemod.detector.get_yaml_files(codemod.name) + for codemod in codemods + if codemod.detector + and isinstance(codemod.detector, SemgrepRuleDetector) + ] ) ) if not yaml_files: return ResultSet() - return run_semgrep(context, yaml_files) - - -def apply_codemod_to_file( - base_directory: Path, - file_context, - codemod_kls: CodemodExecutorWrapper, - source_tree, - dry_run: bool = False, -): - wrapper = cst.MetadataWrapper(source_tree) - codemod = codemod_kls(CodemodContext(wrapper=wrapper), file_context) - if not codemod.should_transform: - return False - - with file_context.timer.measure("transform"): - output_tree = codemod.transform_module(source_tree) - - # TODO: we can probably just use the presence of recorded changes instead of - # comparing the trees to gain some efficiency - if output_tree.deep_equals(source_tree): - return False - - diff = create_diff_from_tree(source_tree, output_tree) - change_set = ChangeSet( - str(file_context.file_path.relative_to(base_directory)), - diff, - changes=file_context.codemod_changes, - ) - file_context.add_result(change_set) - - if not dry_run: - with file_context.timer.measure("write"): - update_code(file_context.file_path, output_tree.code) - - return True - - -# pylint: disable-next=too-many-arguments -def process_file( - idx: int, - file_path: Path, - base_directory: Path, - codemod, - results: ResultSet, - cli_args, -) -> FileContext: - logger.debug("scanning file %s", file_path) - if idx and idx % 100 == 0: - logger.info("scanned %s files...", idx) # pragma: no cover - - line_exclude = file_line_patterns(file_path, cli_args.path_exclude) - line_include = file_line_patterns(file_path, cli_args.path_include) - findings_for_rule = results.results_for_rule_and_file( - codemod.name, # TODO: should be full ID - file_path, - ) - - file_context = FileContext( - base_directory, - file_path, - line_exclude, - line_include, - findings_for_rule, - ) - - try: - with file_context.timer.measure("parse"): - with open(file_path, "r", encoding="utf-8") as f: - source_tree = cst.parse_module(f.read()) - except Exception: - file_context.add_failure(file_path) - logger.exception("error parsing file %s", file_path) - return file_context - - apply_codemod_to_file( - base_directory, - file_context, - codemod, - source_tree, - cli_args.dry_run, - ) - - return file_context - - -def analyze_files( - execution_context: CodemodExecutionContext, - files_to_analyze, - codemod, - results: ResultSet, - cli_args, -): - with ThreadPoolExecutor(max_workers=cli_args.max_workers) as executor: - logger.debug( - "using executor with %s threads", - cli_args.max_workers, - ) - analysis_results = executor.map( - lambda args: process_file( - args[0], - args[1], - execution_context.directory, - codemod, - results, - cli_args, - ), - enumerate(files_to_analyze), - ) - executor.shutdown(wait=True) - execution_context.process_results(codemod.id, analysis_results) + return run_semgrep(context, yaml_files, files_to_analyze) def log_report(context, argv, elapsed_ms, files_to_analyze): @@ -185,10 +77,9 @@ def log_report(context, argv, elapsed_ms, files_to_analyze): def apply_codemods( context: CodemodExecutionContext, - codemods_to_run: list[CodemodExecutorWrapper], + codemods_to_run: Sequence[BaseCodemod], semgrep_results: ResultSet, files_to_analyze: list[Path], - argv, ): log_section("scanning") @@ -207,25 +98,20 @@ def apply_codemods( # NOTE: this may be used as a progress indicator by upstream tools logger.info("running codemod %s", codemod.id) - # Unfortunately the IDs from semgrep are not fully specified - # TODO: eventually we need to be able to use fully specified IDs here - if codemod.is_semgrep and codemod.name not in semgrep_finding_ids: - logger.debug( - "no results from semgrep for %s, skipping analysis", - codemod.id, - ) - continue + if isinstance(codemod.detector, SemgrepRuleDetector): + # Unfortunately the IDs from semgrep are not fully specified + # TODO: eventually we need to be able to use fully specified IDs here + if codemod.name not in semgrep_finding_ids: + logger.debug( + "no results from semgrep for %s, skipping analysis", + codemod.id, + ) + continue + + files_to_analyze = semgrep_results.files_for_rule(codemod.name) - semgrep_files = semgrep_results.files_for_rule(codemod.name) # Non-semgrep codemods ignore the semgrep results - results = codemod.apply(context, semgrep_files) - analyze_files( - context, - files_to_analyze, - codemod, - results, - argv, - ) + codemod.apply(context, files_to_analyze) record_dependency_update(context.process_dependencies(codemod.id)) context.log_changes(codemod.id) @@ -276,6 +162,9 @@ def run(original_args) -> int: argv.verbose, codemod_registry, repo_manager, + argv.path_include, + argv.path_exclude, + argv.max_workers, ) repo_manager.parse_project() @@ -297,14 +186,17 @@ def run(original_args) -> int: full_names = [str(path) for path in files_to_analyze] log_list(logging.DEBUG, "matched files", full_names) - semgrep_results: ResultSet = find_semgrep_results(context, codemods_to_run) + semgrep_results: ResultSet = find_semgrep_results( + context, + codemods_to_run, + files_to_analyze, + ) apply_codemods( context, codemods_to_run, semgrep_results, files_to_analyze, - argv, ) results = context.compile_results(codemods_to_run) diff --git a/src/codemodder/codemods/api.py b/src/codemodder/codemods/api.py new file mode 100644 index 00000000..4c9c3c0b --- /dev/null +++ b/src/codemodder/codemods/api.py @@ -0,0 +1,52 @@ +from abc import ABCMeta +from typing import Callable + +import libcst as cst + +from codemodder.codemods.base_codemod import ( # pylint: disable=unused-import + BaseCodemod, + Metadata, + Reference, + ReviewGuidance, +) +from codemodder.codemods.libcst_transformer import ( + LibcstResultTransformer, + LibcstTransformerPipeline, +) +from codemodder.file_context import FileContext # pylint: disable=unused-import +from codemodder.codemods.semgrep import SemgrepRuleDetector + + +class SimpleCodemod(LibcstResultTransformer, metaclass=ABCMeta): + """ + Base class for codemods with a single detector and transformer + + Child classes must implement the following attributes: + - metadata: Metadata + - codemod_base: type[BaseCodemod] + """ + + metadata: Metadata + detector_pattern: str + on_result_found: Callable[[cst.CSTNode, cst.CSTNode], cst.CSTNode] + + codemod_base: type[BaseCodemod] + + def __init__(self, *args, **kwargs): + """Obfuscates the type of the constructor to make the type checker happy""" + super().__init__(*args, **kwargs) + + def __new__(cls, *args, **kwargs): + del args + + if kwargs.get("_transformer", False): + return super().__new__(cls) + + return cls.codemod_base( + metadata=cls.metadata, + detector=SemgrepRuleDetector(cls.detector_pattern) + if getattr(cls, "detector_pattern", None) + else None, + # This allows the transformer to inherit all the methods of the class itself + transformer=LibcstTransformerPipeline(cls), + ) diff --git a/src/codemodder/codemods/api/__init__.py b/src/codemodder/codemods/api/__init__.py deleted file mode 100644 index 0a05313b..00000000 --- a/src/codemodder/codemods/api/__init__.py +++ /dev/null @@ -1,142 +0,0 @@ -import io -import os -import tempfile - -import libcst as cst -from libcst.codemod import ( - CodemodContext, -) -import yaml - -from codemodder.codemods.base_codemod import ( - CodemodMetadata, - BaseCodemod as _BaseCodemod, - SemgrepCodemod as _SemgrepCodemod, - # Make this available via the simplified API - ReviewGuidance, # pylint: disable=unused-import -) - -from codemodder.codemods.base_visitor import BaseTransformer -from codemodder.change import Change -from codemodder.file_context import FileContext -from .helpers import Helpers - - -def _populate_yaml(rule: str, metadata: CodemodMetadata) -> str: - config = yaml.safe_load(io.StringIO(rule)) - # TODO: handle more than rule per config? - assert len(config["rules"]) == 1 - config["rules"][0].setdefault("id", metadata.NAME) - config["rules"][0].setdefault("message", "Semgrep found a match") - config["rules"][0].setdefault("severity", "WARNING") - config["rules"][0].setdefault("languages", ["python"]) - return yaml.safe_dump(config) - - -def _create_temp_yaml_file(orig_cls, metadata: CodemodMetadata): - fd, path = tempfile.mkstemp() - with os.fdopen(fd, "w") as ff: - ff.write(_populate_yaml(orig_cls.rule(), metadata)) - - return [path] - - -class _CodemodSubclassWithMetadata: - def __init_subclass__(cls): - # This is a pretty yucky workaround. - # But it is necessary to get around the fact that these fields are - # checked by __init_subclass__ of the other parents of SemgrepCodemod - # first. - if cls.__name__ not in ("BaseCodemod", "SemgrepCodemod"): - # TODO: if we intend to continue to check class-level attributes - # using this mechanism, we should add checks (or defaults) for - # NAME, DESCRIPTION, and REVIEW_GUIDANCE here. - missing_fields = [] - for field in ["SUMMARY", "DESCRIPTION", "REVIEW_GUIDANCE"]: - try: - assert ( - hasattr(cls, field) - and getattr(cls, field) is not NotImplemented - ) - except AssertionError: - missing_fields.append(field) - - if missing_fields: - raise AssertionError( - f"{cls.__name__} is missing the following fields: {missing_fields}" - ) - - cls.METADATA = CodemodMetadata( - cls.DESCRIPTION, # pylint: disable=no-member - cls.NAME, # pylint: disable=no-member - cls.REVIEW_GUIDANCE, # pylint: disable=no-member - cls.REFERENCES, # pylint: disable=no-member - ) - - # This is a little bit hacky, but it also feels like the right solution? - cls.CHANGE_DESCRIPTION = cls.DESCRIPTION # pylint: disable=no-member - - return cls - - -class BaseCodemod( - _CodemodSubclassWithMetadata, - _BaseCodemod, - BaseTransformer, - Helpers, -): - def __init__(self, codemod_context: CodemodContext, file_context: FileContext): - _BaseCodemod.__init__(self, file_context) - BaseTransformer.__init__(self, codemod_context, []) - - def report_change(self, original_node): - line_number = self.lineno_for_node(original_node) - self.file_context.codemod_changes.append( - Change(line_number, self.CHANGE_DESCRIPTION) - ) - - -# NOTE: this shadows base_codemod.SemgrepCodemod but I can't think of a better name right now -# At least it is namespaced but we might want to deconflict these things in the long term -class SemgrepCodemod( - BaseCodemod, - _CodemodSubclassWithMetadata, - _SemgrepCodemod, - BaseTransformer, -): - def __init_subclass__(cls): - super().__init_subclass__() - cls.YAML_FILES = _create_temp_yaml_file(cls, cls.METADATA) - - def __init__(self, codemod_context: CodemodContext, file_context: FileContext): - BaseCodemod.__init__(self, codemod_context, file_context) - _SemgrepCodemod.__init__(self, file_context) - BaseTransformer.__init__(self, codemod_context, file_context.findings) - - def _new_or_updated_node(self, original_node, updated_node): - if self.node_is_selected(original_node): - self.report_change(original_node) - if (attr := getattr(self, "on_result_found", None)) is not None: - # pylint: disable=not-callable - new_node = attr(original_node, updated_node) - return new_node - return updated_node - - # TODO: there needs to be a way to generalize this so that it applies - # more broadly than to just a specific kind of node. There's probably a - # decent way to do this with metaprogramming. We could either apply it - # broadly to every known method (which would probably have a big - # performance impact). Or we could allow users to register the handler - # for a specific node or nodes by means of a decorator or something - # similar when they define their `on_result_found` method. - # Right now this is just to demonstrate a particular use case. - def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): - return self._new_or_updated_node(original_node, updated_node) - - def leave_Assign(self, original_node, updated_node): - return self._new_or_updated_node(original_node, updated_node) - - def leave_ClassDef( - self, original_node: cst.ClassDef, updated_node: cst.ClassDef - ) -> cst.ClassDef: - return self._new_or_updated_node(original_node, updated_node) diff --git a/src/codemodder/codemods/api/helpers.py b/src/codemodder/codemods/api/helpers.py deleted file mode 100644 index 93da9506..00000000 --- a/src/codemodder/codemods/api/helpers.py +++ /dev/null @@ -1,130 +0,0 @@ -from collections import namedtuple -import libcst as cst -from libcst import matchers -from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor -from codemodder.codemods.utils import get_call_name - -NewArg = namedtuple("NewArg", ["name", "value", "add_if_missing"]) - - -class Helpers: - def remove_unused_import(self, original_node): - # pylint: disable=no-member - RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node) - - def add_needed_import(self, module, obj=None): - # TODO: do we need to check if this import already exists? - AddImportsVisitor.add_needed_import( - self.context, module, obj # pylint: disable=no-member - ) - - def update_call_target( - self, original_node, new_target, new_func=None, replacement_args=None - ): - # TODO: is an assertion the best way to handle this? - # Or should we just return the original node if it's not a Call? - assert isinstance(original_node, cst.Call) - - attr = ( - cst.parse_expression(new_func) - if new_func - else cst.Name(value=get_call_name(original_node)) - ) - return cst.Call( - func=cst.Attribute( - value=cst.parse_expression(new_target), - attr=attr, - ), - args=replacement_args if replacement_args else original_node.args, - ) - - def update_arg_target(self, updated_node, new_args: list): - return updated_node.with_changes( - args=[new if isinstance(new, cst.Arg) else cst.Arg(new) for new in new_args] - ) - - def update_assign_rhs(self, updated_node: cst.Assign, rhs: str): - value = cst.parse_expression(rhs) - return updated_node.with_changes(value=value) - - def parse_expression(self, expression: str): - return cst.parse_expression(expression) - - def replace_args(self, original_node, args_info): - """ - Iterate over the args in original_node and replace each arg - with any matching arg in `args_info`. - - :param original_node: libcst node with args attribute. - :param list args_info: List of NewArg - """ - assert hasattr(original_node, "args") - assert all( - isinstance(arg, NewArg) for arg in args_info - ), "`args_info` must contain `NewArg` types." - new_args = [] - - for arg in original_node.args: - arg_name, replacement_val, idx = _match_with_existing_arg(arg, args_info) - if arg_name is not None: - new = self.make_new_arg(replacement_val, arg_name, arg) - del args_info[idx] - else: - new = arg - new_args.append(new) - - for arg_name, replacement_val, add_if_missing in args_info: - if add_if_missing: - new = self.make_new_arg(replacement_val, arg_name) - new_args.append(new) - - return new_args - - def make_new_arg(self, value, name=None, existing_arg=None): - if name is None: - # Make a positional argument - return cst.Arg( - value=cst.parse_expression(value), - ) - - # make a keyword argument - equal = ( - existing_arg.equal - if existing_arg - else cst.AssignEqual( - whitespace_before=cst.SimpleWhitespace(""), - whitespace_after=cst.SimpleWhitespace(""), - ) - ) - return cst.Arg( - keyword=cst.Name(value=name), - value=cst.parse_expression(value), - equal=equal, - ) - - def add_arg_to_call(self, node: cst.Call, name: str, value): - """ - Add a new arg to the end of the args list. - """ - new_args = list(node.args) + [ - cst.Arg( - keyword=cst.Name(value=name), - value=cst.parse_expression(str(value)), - equal=cst.AssignEqual( - whitespace_before=cst.SimpleWhitespace(""), - whitespace_after=cst.SimpleWhitespace(""), - ), - ) - ] - return node.with_changes(args=new_args) - - -def _match_with_existing_arg(arg, args_info): - """ - Given an `arg` and a list of arg info, determine if - any of the names in arg_info match the arg. - """ - for idx, (arg_name, replacement_val, _) in enumerate(args_info): - if matchers.matches(arg.keyword, matchers.Name(arg_name)): - return arg_name, replacement_val, idx - return None, None, None diff --git a/src/codemodder/codemods/base_codemod.py b/src/codemodder/codemods/base_codemod.py index 6b1b4448..2a0e5e0c 100644 --- a/src/codemodder/codemods/base_codemod.py +++ b/src/codemodder/codemods/base_codemod.py @@ -1,14 +1,19 @@ +from abc import ABCMeta, abstractmethod +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum -from typing import List, ClassVar - -from libcst._position import CodeRange - -from codemodder.change import Change -from codemodder.dependency import Dependency +from functools import cached_property +import importlib.resources +from importlib.abc import Traversable +from pathlib import Path + +from codemodder.codemods.base_detector import BaseDetector +from codemodder.codemods.base_transformer import BaseTransformerPipeline +from codemodder.context import CodemodExecutionContext +from codemodder.code_directory import file_line_patterns from codemodder.file_context import FileContext +from codemodder.logging import logger from codemodder.result import ResultSet -from codemodder.semgrep import run as semgrep_run class ReviewGuidance(Enum): @@ -17,109 +22,164 @@ class ReviewGuidance(Enum): MERGE_WITHOUT_REVIEW = 3 -@dataclass(frozen=True) -class CodemodMetadata: - DESCRIPTION: str # TODO: this field should be optional - NAME: str - REVIEW_GUIDANCE: ReviewGuidance - REFERENCES: list = field(default_factory=list) - - # TODO: remove post_init update_references once we add description for each url. - def __post_init__(self): - object.__setattr__(self, "REFERENCES", self.update_references(self.REFERENCES)) - - @staticmethod - def update_references(references): - updated_references = [] - for reference in references: - updated_reference = dict( - reference - ) # Create a copy to avoid modifying the original dict - updated_reference["description"] = updated_reference["url"] - updated_references.append(updated_reference) - return updated_references - - -class BaseCodemod: - # Implementation borrowed from https://stackoverflow.com/a/45250114 - METADATA: ClassVar[CodemodMetadata] = NotImplemented - SUMMARY: ClassVar[str] = NotImplemented - is_semgrep: bool = False - adds_dependency: bool = False - file_context: FileContext - - def __init__(self, file_context: FileContext): - self.file_context = file_context - - @classmethod - def apply_rule(cls, context, *args, **kwargs) -> ResultSet: - """ - Apply rule associated with this codemod and gather results +@dataclass +class Reference: + url: str + description: str = "" - Does nothing by default. Subclasses may override for custom rule logic. - """ - del context, args, kwargs - return ResultSet() + def to_json(self): + return { + "url": self.url, + "description": self.description or self.url, + } - @classmethod - def name(cls): - # pylint: disable=no-member - return cls.METADATA.NAME - @property - def should_transform(self): - return True +@dataclass +class Metadata: + name: str + summary: str + review_guidance: ReviewGuidance + references: list[Reference] = field(default_factory=list) + has_description: bool = True + + +class BaseCodemod(metaclass=ABCMeta): + """ + Base class for all codemods + + Conceptually a codemod is composed of the following attributes: + * Metadata: contains information about the codemod including its name, summary, and review guidance + * Detector (optional): the source of results indicating which code locations the codemod should be applied + * Transformer: a transformer pipeline that will be applied to each applicable file and perform the actual modifications + + A detector may parse result files generated by other tools or it may + perform its own analysis at runtime, potentially by calling another tool + (e.g. Semgrep). - def node_position(self, node): - # pylint: disable=no-member - # See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112 - return self.get_metadata(self.METADATA_DEPENDENCIES[0], node) + Some codemods may not require a detector if the transformation pipeline + itself is capable of determining locations to modify. - def add_change(self, node, description: str, start: bool = True): - position = self.node_position(node) - self.add_change_from_position(position, description, start) + Codemods that apply the same transformation but use different detectors + should be implemented as distinct codemod classes. + """ - def add_change_from_position( - self, position: CodeRange, description: str, start: bool = True + _metadata: Metadata + detector: BaseDetector | None + transformer: BaseTransformerPipeline + + def __init__( + self, + *, + metadata: Metadata, + detector: BaseDetector | None = None, + transformer: BaseTransformerPipeline, ): - lineno = position.start.line if start else position.end.line - self.file_context.codemod_changes.append( - Change( - lineNumber=lineno, - description=description, - ) - ) + # Metadata should only be accessed via properties + self._metadata = metadata + self.detector = detector + self.transformer = transformer + + @property + @abstractmethod + def origin(self) -> str: + ... + + @property + @abstractmethod + def docs_module_path(self) -> str: + ... - def lineno_for_node(self, node): - return self.node_position(node).start.line + @property + def name(self) -> str: + return self._metadata.name @property - def line_exclude(self): - return self.file_context.line_exclude + def id(self) -> str: + return f"{self.origin}:python/{self.name}" @property - def line_include(self): - return self.file_context.line_include + def summary(self): + return self._metadata.summary - def add_dependency(self, dependency: Dependency): - self.file_context.add_dependency(dependency) + @cached_property + def docs_module(self) -> Traversable: + return importlib.resources.files(self.docs_module_path) + @cached_property + def description(self) -> str: + if not self._metadata.has_description: + return "" -class SemgrepCodemod(BaseCodemod): - YAML_FILES: ClassVar[List[str]] = NotImplemented - is_semgrep = True + doc_path = self.docs_module / f"{self.origin}_python_{self.name}.md" + return doc_path.read_text() - @classmethod - def apply_rule(cls, context, *args, **kwargs) -> ResultSet: + @property + def review_guidance(self): + return self._metadata.review_guidance.name.replace("_", " ").title() + + @property + def references(self) -> list[Reference]: + return self._metadata.references + + def describe(self): + return { + "codemod": self.id, + "summary": self.summary, + "description": self.description, + "references": [ref.to_json() for ref in self.references], + } + + def apply( + self, + context: CodemodExecutionContext, + files_to_analyze: list[Path], + ) -> None: """ - Apply semgrep to gather rule results + Apply the codemod to the given list of files + + This method is responsible for orchestrating the application of the codemod to a given list of files. + + It will first apply the detector (if any) to the files to determine which files should be modified. + + It then applies the transformer pipeline to each file applicable file, potentially generating a change set. + + All results are then processed and reported to the context. + + Per-file processing can be parallelized based on the `max_workers` setting. + + :param context: The codemod execution context + :param files_to_analyze: The list of files to analyze """ - yaml_files = kwargs.get("yaml_files") or args[0] - files_to_analyze = kwargs.get("files_to_analyze") or args[1] - with context.timer.measure("semgrep"): - return semgrep_run(context, yaml_files, files_to_analyze) + results = ( + # It seems like semgrep doesn't like our fully-specified id format + self.detector.apply(self.name, context, files_to_analyze) + if self.detector + else ResultSet() + ) - @property - def should_transform(self): - """Semgrep codemods should attempt transform only if there are semgrep results""" - return bool(self.file_context.findings) + def process_file(filename: Path): + line_exclude = file_line_patterns(filename, context.path_exclude) + line_include = file_line_patterns(filename, context.path_include) + findings_for_rule = results.results_for_rule_and_file(self.name, filename) + + file_context = FileContext( + context.directory, + filename, + line_exclude, + line_include, + findings_for_rule, + ) + + if change_set := self.transformer.apply( + context, file_context, findings_for_rule + ): + file_context.add_result(change_set) + + return file_context + + with ThreadPoolExecutor() as executor: + logger.debug("using executor with %s workers", context.max_workers) + contexts = executor.map(process_file, files_to_analyze) + executor.shutdown(wait=True) + + context.process_results(self.id, contexts) diff --git a/src/codemodder/codemods/base_detector.py b/src/codemodder/codemods/base_detector.py new file mode 100644 index 00000000..c1bae4d6 --- /dev/null +++ b/src/codemodder/codemods/base_detector.py @@ -0,0 +1,16 @@ +from abc import ABCMeta, abstractmethod +from pathlib import Path + +from codemodder.context import CodemodExecutionContext +from codemodder.result import ResultSet + + +class BaseDetector(metaclass=ABCMeta): + @abstractmethod + def apply( + self, + codemod_id: str, + context: CodemodExecutionContext, + files_to_analyze: list[Path], + ) -> ResultSet: + ... diff --git a/src/codemodder/codemods/base_transformer.py b/src/codemodder/codemods/base_transformer.py new file mode 100644 index 00000000..434d671b --- /dev/null +++ b/src/codemodder/codemods/base_transformer.py @@ -0,0 +1,45 @@ +from abc import ABCMeta, abstractmethod + +from codemodder.change import ChangeSet +from codemodder.context import CodemodExecutionContext +from codemodder.codemods.base_visitor import BaseTransformer +from codemodder.file_context import FileContext +from codemodder.result import Result + + +class BaseTransformerPipeline(metaclass=ABCMeta): + """ + Base class for a pipeline of transformers + + A pipeline is a list of one or more transformers that are applied in sequence. + + The transformers in a given pipeline can either be homogeneous or heterogeneous in terms of inputs and output formats accepted by each transformer. For a heterogeneous pipeline it may be necessary to implement adapter classes to convert between formats. + + Each transformer pipeline is responsible for writing results to the output files if `dry_run` is `False`. + + **NOTE**: In general, pipelines that rely on detectors will need to account for the fact that the detected results become "stale" after the application of the first transformer in the pipeline. This is not an issue for transformers that do their own detection or which are capable of adjusting the location of results + """ + + transformers: list[type[BaseTransformer]] + + def __init__(self, *transformers: type[BaseTransformer]): + self.transformers = list(transformers) + + @abstractmethod + def apply( + self, + context: CodemodExecutionContext, + file_context: FileContext, + results: list[Result], + ) -> ChangeSet | None: + """ + Apply the pipeline to the given file context + + :param context: The codemod execution context + :param file_context: The file context representing the file to transform + :param results: The (optional) results of the detector phase + + :return: The `ChangeSet` to apply to the file, or `None` if no changes are applied + + This method is responsible for writing the results to the output files if `dry_run` is False. + """ diff --git a/src/codemodder/codemods/base_visitor.py b/src/codemodder/codemods/base_visitor.py index 65d30d1d..e50c9178 100644 --- a/src/codemodder/codemods/base_visitor.py +++ b/src/codemodder/codemods/base_visitor.py @@ -5,6 +5,7 @@ from codemodder.result import Result +# TODO: this should just be part of BaseTransformer and BaseVisitor? class UtilsMixin: results: list[Result] diff --git a/src/codemodder/codemods/libcst_transformer.py b/src/codemodder/codemods/libcst_transformer.py new file mode 100644 index 00000000..f7d459ea --- /dev/null +++ b/src/codemodder/codemods/libcst_transformer.py @@ -0,0 +1,308 @@ +from collections import namedtuple + +import libcst as cst +from libcst import matchers +from libcst.codemod import CodemodContext +from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor +from libcst._position import CodeRange + +from codemodder.codemods.base_visitor import BaseTransformer +from codemodder.codemods.base_transformer import BaseTransformerPipeline +from codemodder.codemods.utils import get_call_name +from codemodder.context import CodemodExecutionContext +from codemodder.diff import create_diff_from_tree +from codemodder.change import ChangeSet, Change +from codemodder.dependency import Dependency +from codemodder.logging import logger +from codemodder.file_context import FileContext +from codemodder.result import Result + + +NewArg = namedtuple("NewArg", ["name", "value", "add_if_missing"]) + + +def update_code(file_path, new_code): + """ + Write the `new_code` to the `file_path` + """ + with open(file_path, "w", encoding="utf-8") as f: + f.write(new_code) + + +class LibcstResultTransformer( + BaseTransformer +): # pylint: disable=too-many-public-methods + """ + Transformer class that performs libcst-based transformations on a given file + + :param context: libcst CodemodContext + :param results: list of `Result` generated by the detector phase (may be empty) + :param file_context: `FileContext` for the file to be transformed + """ + + change_description: str = "" + + def __init__( + self, + context: CodemodContext, + results: list[Result], + file_context: FileContext, + _transformer: bool = False, + ): + del _transformer + + super().__init__(context, results) + self.file_context = file_context + + @classmethod + def transform( + cls, module: cst.Module, results: list[Result], file_context: FileContext + ) -> cst.Module: + wrapper = cst.MetadataWrapper(module) + codemod = cls( + CodemodContext(wrapper=wrapper), + results, + file_context, + _transformer=True, + ) + + return codemod.transform_module(module) + + def _new_or_updated_node(self, original_node, updated_node): + if self.node_is_selected(original_node): + if (attr := getattr(self, "on_result_found", None)) is not None: + # pylint: disable=not-callable + new_node = attr(original_node, updated_node) + self.report_change(original_node) + return new_node + return updated_node + + # TODO: there needs to be a way to generalize this so that it applies + # more broadly than to just a specific kind of node. There's probably a + # decent way to do this with metaprogramming. We could either apply it + # broadly to every known method (which would probably have a big + # performance impact). Or we could allow users to register the handler + # for a specific node or nodes by means of a decorator or something + # similar when they define their `on_result_found` method. + # Right now this is just to demonstrate a particular use case. + def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): + return self._new_or_updated_node(original_node, updated_node) + + def leave_Assign(self, original_node, updated_node): + return self._new_or_updated_node(original_node, updated_node) + + def leave_ClassDef( + self, original_node: cst.ClassDef, updated_node: cst.ClassDef + ) -> cst.ClassDef: + return self._new_or_updated_node(original_node, updated_node) + + def node_position(self, node): + # pylint: disable=no-member + # See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112 + return self.get_metadata(self.METADATA_DEPENDENCIES[0], node) + + def add_change(self, node, description: str, start: bool = True): + position = self.node_position(node) + self.add_change_from_position(position, description, start) + + def add_change_from_position( + self, position: CodeRange, description: str, start: bool = True + ): + lineno = position.start.line if start else position.end.line + self.file_context.codemod_changes.append( + Change( + lineNumber=lineno, + description=description, + ) + ) + + def lineno_for_node(self, node): + return self.node_position(node).start.line + + @property + def line_exclude(self): + return self.file_context.line_exclude + + @property + def line_include(self): + return self.file_context.line_include + + def add_dependency(self, dependency: Dependency): + self.file_context.add_dependency(dependency) + + def report_change(self, original_node): + line_number = self.lineno_for_node(original_node) + self.file_context.codemod_changes.append( + Change(line_number, self.change_description) + ) + + def remove_unused_import(self, original_node): + # pylint: disable=no-member + RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node) + + def add_needed_import(self, module, obj=None): + # TODO: do we need to check if this import already exists? + AddImportsVisitor.add_needed_import( + self.context, module, obj # pylint: disable=no-member + ) + + def update_call_target( + self, + original_node, + new_target, + new_func: str | None = None, + replacement_args=None, + ): + # TODO: is an assertion the best way to handle this? + # Or should we just return the original node if it's not a Call? + assert isinstance(original_node, cst.Call) + + func_name = new_func if new_func else get_call_name(original_node) + return cst.Call( + func=cst.Attribute( + value=cst.parse_expression(new_target), + attr=cst.Name(value=func_name), + ), + args=replacement_args if replacement_args else original_node.args, + ) + + def update_arg_target(self, updated_node, new_args: list): + return updated_node.with_changes( + args=[new if isinstance(new, cst.Arg) else cst.Arg(new) for new in new_args] + ) + + def update_assign_rhs(self, updated_node: cst.Assign, rhs: str): + value = cst.parse_expression(rhs) + return updated_node.with_changes(value=value) + + def parse_expression(self, expression: str): + return cst.parse_expression(expression) + + def replace_args(self, original_node, args_info): + """ + Iterate over the args in original_node and replace each arg + with any matching arg in `args_info`. + + :param original_node: libcst node with args attribute. + :param list args_info: List of NewArg + """ + assert hasattr(original_node, "args") + assert all( + isinstance(arg, NewArg) for arg in args_info + ), "`args_info` must contain `NewArg` types." + new_args = [] + + for arg in original_node.args: + arg_name, replacement_val, idx = _match_with_existing_arg(arg, args_info) + if arg_name is not None: + new = self.make_new_arg(replacement_val, arg_name, arg) + del args_info[idx] + else: + new = arg + new_args.append(new) + + for arg_name, replacement_val, add_if_missing in args_info: + if add_if_missing: + new = self.make_new_arg(replacement_val, arg_name) + new_args.append(new) + + return new_args + + def make_new_arg(self, value, name=None, existing_arg=None): + if name is None: + # Make a positional argument + return cst.Arg( + value=cst.parse_expression(value), + ) + + # make a keyword argument + equal = ( + existing_arg.equal + if existing_arg + else cst.AssignEqual( + whitespace_before=cst.SimpleWhitespace(""), + whitespace_after=cst.SimpleWhitespace(""), + ) + ) + return cst.Arg( + keyword=cst.Name(value=name), + value=cst.parse_expression(value), + equal=equal, + ) + + def add_arg_to_call(self, node: cst.Call, name: str, value): + """ + Add a new arg to the end of the args list. + """ + new_args = list(node.args) + [ + cst.Arg( + keyword=cst.Name(value=name), + value=cst.parse_expression(str(value)), + equal=cst.AssignEqual( + whitespace_before=cst.SimpleWhitespace(""), + whitespace_after=cst.SimpleWhitespace(""), + ), + ) + ] + return node.with_changes(args=new_args) + + +class LibcstTransformerPipeline(BaseTransformerPipeline): + """ + Transformer pipeline class that applies one or more `LibcstResultTransformer` to a given file + + This pipeline expects that all transformers accept a libcst `Module` as input and return a libcst `Module` as output. + """ + + transformers: list[type[LibcstResultTransformer]] + + def apply( + self, + context: CodemodExecutionContext, + file_context: FileContext, + results: list[Result], + ) -> ChangeSet | None: + file_path = file_context.file_path + + try: + with file_context.timer.measure("parse"): + with open(file_path, "r", encoding="utf-8") as f: + source_tree = cst.parse_module(f.read()) + except Exception: + file_context.add_failure(file_path) + logger.exception("error parsing file %s", file_path) + return None + + tree = source_tree + with file_context.timer.measure("transform"): + for transformer in self.transformers: + tree = transformer.transform(tree, results, file_context) + + if not file_context.codemod_changes: + return None + + diff = create_diff_from_tree(source_tree, tree) + if not diff: + return None + + change_set = ChangeSet( + str(file_context.file_path.relative_to(context.directory)), + diff, + changes=file_context.codemod_changes, + ) + + if not context.dry_run: + with file_context.timer.measure("write"): + update_code(file_context.file_path, tree.code) + + return change_set + + +def _match_with_existing_arg(arg, args_info): + """ + Given an `arg` and a list of arg info, determine if any of the names in arg_info match the arg. + """ + for idx, (arg_name, replacement_val, _) in enumerate(args_info): + if matchers.matches(arg.keyword, matchers.Name(arg_name)): + return arg_name, replacement_val, idx + return None, None, None diff --git a/src/codemodder/codemods/semgrep.py b/src/codemodder/codemods/semgrep.py new file mode 100644 index 00000000..c3e07352 --- /dev/null +++ b/src/codemodder/codemods/semgrep.py @@ -0,0 +1,49 @@ +import io +import os +from pathlib import Path +import tempfile + +import yaml + +from codemodder.codemods.base_detector import BaseDetector +from codemodder.context import CodemodExecutionContext +from codemodder.result import ResultSet +from codemodder.semgrep import run as semgrep_run + + +def _populate_yaml(rule: str, codemod_id: str) -> str: + rule_yaml = yaml.safe_load(io.StringIO(rule)) + config = {"rules": rule_yaml} if "rules" not in rule_yaml else rule_yaml + config["rules"][0].setdefault("id", codemod_id) + config["rules"][0].setdefault("message", "Semgrep found a match") + config["rules"][0].setdefault("severity", "WARNING") + config["rules"][0].setdefault("languages", ["python"]) + return yaml.safe_dump(config) + + +def _create_temp_yaml_file(rule: str, codemod_id: str): + fd, path = tempfile.mkstemp() + with os.fdopen(fd, "w") as ff: + ff.write(_populate_yaml(rule, codemod_id)) + + return [Path(path)] + + +class SemgrepRuleDetector(BaseDetector): + rule: str + + def __init__(self, rule: str): + self.rule = rule + + def get_yaml_files(self, codemod_id: str) -> list[Path]: + return _create_temp_yaml_file(self.rule, codemod_id) + + def apply( + self, + codemod_id: str, + context: CodemodExecutionContext, + files_to_analyze: list[Path], + ) -> ResultSet: + yaml_files = self.get_yaml_files(codemod_id) + with context.timer.measure("semgrep"): + return semgrep_run(context, yaml_files, files_to_analyze) diff --git a/src/codemodder/context.py b/src/codemodder/context.py index 6f8b5e26..84811959 100644 --- a/src/codemodder/context.py +++ b/src/codemodder/context.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import logging from pathlib import Path import itertools from textwrap import indent -from typing import List, Iterator +from typing import TYPE_CHECKING, List, Iterator from codemodder.change import ChangeSet from codemodder.dependency import ( @@ -10,7 +12,6 @@ build_dependency_notification, build_failed_dependency_notification, ) -from codemodder.executor import CodemodExecutorWrapper from codemodder.file_context import FileContext from codemodder.logging import logger, log_list from codemodder.project_analysis.file_parsers.package_store import PackageStore @@ -18,6 +19,9 @@ from codemodder.project_analysis.python_repo_manager import PythonRepoManager from codemodder.utils.timer import Timer +if TYPE_CHECKING: + from codemodder.codemods.base_codemod import BaseCodemod + class CodemodExecutionContext: # pylint: disable=too-many-instance-attributes _results_by_codemod: dict[str, list[ChangeSet]] = {} @@ -30,6 +34,9 @@ class CodemodExecutionContext: # pylint: disable=too-many-instance-attributes registry: CodemodRegistry repo_manager: PythonRepoManager timer: Timer + path_include: list[str] + path_exclude: list[str] + max_workers: int = 1 def __init__( self, @@ -38,6 +45,9 @@ def __init__( verbose: bool, registry: CodemodRegistry, repo_manager: PythonRepoManager, + path_include: list[str], + path_exclude: list[str], + max_workers: int = 1, ): # pylint: disable=too-many-arguments self.directory = directory self.dry_run = dry_run @@ -48,6 +58,9 @@ def __init__( self.registry = registry self.repo_manager = repo_manager self.timer = Timer() + self.path_include = path_include + self.path_exclude = path_exclude + self.max_workers = max_workers def add_results(self, codemod_name: str, change_sets: List[ChangeSet]): self._results_by_codemod.setdefault(codemod_name, []).extend(change_sets) @@ -116,7 +129,7 @@ def process_dependencies( return record - def add_description(self, codemod: CodemodExecutorWrapper): + def add_description(self, codemod: BaseCodemod): description = codemod.description if dependencies := list(self.dependencies.get(codemod.id, [])): if pkg_store := self._dependency_update_by_codemod.get(codemod.id): @@ -135,14 +148,14 @@ def process_results(self, codemod_id: str, results: Iterator[FileContext]): self.add_dependencies(codemod_id, file_context.dependencies) self.timer.aggregate(file_context.timer) - def compile_results(self, codemods: list[CodemodExecutorWrapper]): + def compile_results(self, codemods: list[BaseCodemod]): results = [] for codemod in codemods: data = { "codemod": codemod.id, "summary": codemod.summary, "description": self.add_description(codemod), - "references": codemod.references, + "references": [ref.to_json() for ref in codemod.references], "properties": {}, "failedFiles": [str(file) for file in self.get_failures(codemod.id)], "changeset": [ diff --git a/src/codemodder/dependency_management/setup_py_writer.py b/src/codemodder/dependency_management/setup_py_writer.py index 50eab85a..fcba5288 100644 --- a/src/codemodder/dependency_management/setup_py_writer.py +++ b/src/codemodder/dependency_management/setup_py_writer.py @@ -2,8 +2,7 @@ from libcst.codemod import CodemodContext from libcst import matchers from typing import Optional -from codemodder.codemods.api import BaseCodemod -from codemodder.codemods.base_codemod import ReviewGuidance +from codemodder.codemods.api import SimpleCodemod, Metadata, ReviewGuidance from codemodder.codemods.utils import is_setup_py_file from codemodder.codemods.utils_mixin import NameResolutionMixin from codemodder.file_context import FileContext @@ -30,6 +29,7 @@ def add_to_file( CodemodContext(wrapper=wrapper), file_context, dependencies=[dep.requirement for dep in dependencies], + _transformer=True, ) output_tree = codemod.transform_module(input_tree) @@ -56,20 +56,21 @@ def _parse_file(self): return cst.parse_module(f.read()) -class SetupPyAddDependencies(BaseCodemod, NameResolutionMixin): - NAME = "setup-py-add-dependencies" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Add Dependencies to `setup.py` `install_requires`" - DESCRIPTION = SUMMARY - REFERENCES: list = [] +class SetupPyAddDependencies(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="setup-py-add-dependencies", + summary="Add Dependencies to `setup.py` `install_requires`", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + ) def __init__( self, codemod_context: CodemodContext, file_context: FileContext, dependencies: list[Requirement], + **kwargs, ): - BaseCodemod.__init__(self, codemod_context, file_context) + SimpleCodemod.__init__(self, codemod_context, [], file_context, **kwargs) NameResolutionMixin.__init__(self) self.filename = self.file_context.file_path self.dependencies = dependencies diff --git a/src/codemodder/executor.py b/src/codemodder/executor.py deleted file mode 100644 index 7e46ed3f..00000000 --- a/src/codemodder/executor.py +++ /dev/null @@ -1,109 +0,0 @@ -from importlib.abc import Traversable -from pathlib import Path - -from wrapt import CallableObjectProxy - - -class CodemodExecutorWrapper(CallableObjectProxy): - """A wrapper around a codemod that provides additional metadata.""" - - origin: str - docs_module: Traversable - semgrep_config_module: Traversable - - def __init__( - self, - codemod, - origin: str, - docs_module: Traversable, - semgrep_config_module: Traversable, - ): - super().__init__(codemod) - self.origin = origin - self.docs_module = docs_module - self.semgrep_config_module = semgrep_config_module - - def apply(self, context, files: list[Path]): - """ - Wraps the codemod's apply method to inject additional arguments. - - Not all codemods will need these arguments. - """ - return self.apply_rule( - context, - yaml_files=self.yaml_files, - files_to_analyze=files, - ) - - @property - def name(self): - return self.__wrapped__.name() - - @property - def id(self): - return f"{self.origin}:python/{self.name}" - - @property - def is_semgrep(self): - return self.__wrapped__.is_semgrep - - @property - def summary(self): - return self.SUMMARY - - def _get_description(self): - doc_path = self.docs_module / f"{self.origin}_python_{self.name}.md" - return doc_path.read_text() - - @property - def description(self): - try: - return self._get_description() - except FileNotFoundError: - # TODO: temporary workaround - return self.METADATA.DESCRIPTION - - @property - def review_guidance(self): - return self.METADATA.REVIEW_GUIDANCE.name.replace("_", " ").title() - - @property - def references(self): - return self.METADATA.REFERENCES - - @property - def yaml_files(self): - return [ - self.semgrep_config_module / yaml_file - for yaml_file in getattr(self, "YAML_FILES", []) - ] - - def describe(self): - return { - "codemod": self.id, - "summary": self.summary, - "description": self.description, - "references": self.references, - } - - def __repr__(self): - return "<{} at 0x{:x} for {}.{}>".format( - type(self).__name__, - id(self), - self.__wrapped__.__module__, - self.__wrapped__.__name__, - ) - - # The following methods are all abstract in the ObjectProxy class, so - # we just implement them as simple pass-throughs to the wrapped object. - def __copy__(self): - return self.__wrapped__.__copy__() - - def __deepcopy__(self, memo): - return self.__wrapped__.__deepcopy__(memo) - - def __reduce__(self): - return self.__wrapped__.__reduce__() - - def __reduce_ex__(self, protocol): - return self.__wrapped__.__reduce_ex__(protocol) diff --git a/src/codemodder/registry.py b/src/codemodder/registry.py index b58a70e9..293efd9f 100644 --- a/src/codemodder/registry.py +++ b/src/codemodder/registry.py @@ -1,11 +1,14 @@ -from dataclasses import dataclass, asdict -from importlib.resources import files +from __future__ import annotations + +from dataclasses import dataclass from importlib.metadata import entry_points -from typing import Optional +from typing import Optional, TYPE_CHECKING -from codemodder.executor import CodemodExecutorWrapper from codemodder.logging import logger +if TYPE_CHECKING: + from codemodder.codemods.base_codemod import BaseCodemod + # These are generally not intended to be applied directly so they are excluded by default. DEFAULT_EXCLUDED_CODEMODS = [ @@ -25,8 +28,8 @@ class CodemodCollection: class CodemodRegistry: - _codemods_by_name: dict[str, CodemodExecutorWrapper] - _codemods_by_id: dict[str, CodemodExecutorWrapper] + _codemods_by_name: dict[str, BaseCodemod] + _codemods_by_id: dict[str, BaseCodemod] def __init__(self): self._codemods_by_name = {} @@ -45,45 +48,16 @@ def codemods(self): return list(self._codemods_by_name.values()) def add_codemod_collection(self, collection: CodemodCollection): - docs_module = files(collection.docs_module) - semgrep_module = files(collection.semgrep_config_module) for codemod in collection.codemods: - self._validate_codemod(codemod) - wrapper = CodemodExecutorWrapper( - codemod, - collection.origin, - docs_module, - semgrep_module, - ) + wrapper = codemod() if isinstance(codemod, type) else codemod self._codemods_by_name[wrapper.name] = wrapper self._codemods_by_id[wrapper.id] = wrapper - def _validate_codemod(self, codemod): - for name in ["SUMMARY", "METADATA"]: - if not (attr := getattr(codemod, name)) or attr is NotImplemented: - raise ValueError( - f'Missing required attribute "{name}" on codemod {codemod}' - ) - - for k, v in asdict(codemod.METADATA).items(): - if v is NotImplemented: - raise NotImplementedError(f"METADATA.{k} not defined for {codemod}") - if k != "REFERENCES" and not v: - raise NotImplementedError( - f"METADATA.{k} should not be None or empty for {codemod}" - ) - - # TODO: eventually we will represent IS_SEMGREP using the class hierarchy - if codemod.is_semgrep and not codemod.YAML_FILES: - raise ValueError( - f"Missing required attribute YAML_FILES on semgrep codemod {codemod}" - ) - def match_codemods( self, codemod_include: Optional[list] = None, codemod_exclude: Optional[list] = None, - ) -> list[CodemodExecutorWrapper]: + ) -> list[BaseCodemod]: codemod_include = codemod_include or [] codemod_exclude = codemod_exclude or DEFAULT_EXCLUDED_CODEMODS diff --git a/src/core_codemods/add_requests_timeouts.py b/src/core_codemods/add_requests_timeouts.py index 1f48d6d3..cadbf349 100644 --- a/src/core_codemods/add_requests_timeouts.py +++ b/src/core_codemods/add_requests_timeouts.py @@ -1,45 +1,60 @@ -from codemodder.codemods.api import SemgrepCodemod, ReviewGuidance +from core_codemods.api import ( + CoreCodemod, + Metadata, + Reference, + ReviewGuidance, +) +from codemodder.codemods.libcst_transformer import ( + LibcstTransformerPipeline, + LibcstResultTransformer, +) +from codemodder.codemods.semgrep import SemgrepRuleDetector -class AddRequestsTimeouts(SemgrepCodemod): - NAME = "add-requests-timeouts" - SUMMARY = "Add timeout to `requests` calls" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - DESCRIPTION = "Add timeout to `requests` call" - REFERENCES = [ - { - "url": "https://docs.python-requests.org/en/master/user/quickstart/#timeouts", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ - rules: - - patterns: - - pattern-inside: | - import requests - ... - - pattern: $CALL(...) - - pattern-not: $CALL(..., timeout=$TIMEOUT, ...) - - metavariable-pattern: - metavariable: $CALL - patterns: - - pattern-either: - - pattern: requests.get - - pattern: requests.post - - pattern: requests.put - - pattern: requests.delete - - pattern: requests.head - - pattern: requests.options - - pattern: requests.patch - - pattern: requests.request - """ - +class TransformAddRequestsTimeouts(LibcstResultTransformer): # Sets an arbitrary default timeout for all requests DEFAULT_TIMEOUT = 60 + change_description = "Add timeout to `requests` call" + def on_result_found(self, original_node, updated_node): del original_node return self.add_arg_to_call(updated_node, "timeout", self.DEFAULT_TIMEOUT) + + +# This codemod uses the lower level codemod and transformer APIs for the sake of example. +AddRequestsTimeouts = CoreCodemod( + metadata=Metadata( + name="add-requests-timeouts", + summary="Add timeout to `requests` calls", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference( + url="https://docs.python-requests.org/en/master/user/quickstart/#timeouts" + ), + ], + ), + detector=SemgrepRuleDetector( + """ + - patterns: + - pattern-inside: | + import requests + ... + - pattern: $CALL(...) + - pattern-not: $CALL(..., timeout=$TIMEOUT, ...) + - metavariable-pattern: + metavariable: $CALL + patterns: + - pattern-either: + - pattern: requests.get + - pattern: requests.post + - pattern: requests.put + - pattern: requests.delete + - pattern: requests.head + - pattern: requests.options + - pattern: requests.patch + - pattern: requests.request + """ + ), + transformer=LibcstTransformerPipeline(TransformAddRequestsTimeouts), +) diff --git a/src/core_codemods/api/__init__.py b/src/core_codemods/api/__init__.py new file mode 100644 index 00000000..904e6b82 --- /dev/null +++ b/src/core_codemods/api/__init__.py @@ -0,0 +1,6 @@ +from codemodder.codemods.api import ( + Metadata, + Reference, + ReviewGuidance, +) +from .core_codemod import SimpleCodemod, CoreCodemod diff --git a/src/core_codemods/api/core_codemod.py b/src/core_codemods/api/core_codemod.py new file mode 100644 index 00000000..deefb4e4 --- /dev/null +++ b/src/core_codemods/api/core_codemod.py @@ -0,0 +1,23 @@ +from codemodder.codemods.api import BaseCodemod, SimpleCodemod as _SimpleCodemod + + +class CoreCodemod(BaseCodemod): + """ + Base class for all core codemods provided by this package. + """ + + @property + def origin(self): + return "pixee" + + @property + def docs_module_path(self): + return "core_codemods.docs" + + +class SimpleCodemod(_SimpleCodemod): + """ + Base class for all core codemods with a single detector and transformer. + """ + + codemod_base = CoreCodemod diff --git a/src/core_codemods/combine_startswith_endswith.py b/src/core_codemods/combine_startswith_endswith.py index 2dc42d51..3f7a0403 100644 --- a/src/core_codemods/combine_startswith_endswith.py +++ b/src/core_codemods/combine_startswith_endswith.py @@ -1,14 +1,17 @@ import libcst as cst from libcst import matchers as m -from codemodder.codemods.api import BaseCodemod, ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod -class CombineStartswithEndswith(BaseCodemod, NameResolutionMixin): - NAME = "combine-startswith-endswith" - SUMMARY = "Simplify Boolean Expressions Using `startswith` and `endswith`" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Use tuple of matches instead of boolean expression" +class CombineStartswithEndswith(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="combine-startswith-endswith", + summary="Simplify Boolean Expressions Using `startswith` and `endswith`", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[], + ) + change_description = "Use tuple of matches instead of boolean expression" REFERENCES: list = [] def leave_BooleanOperation( diff --git a/src/core_codemods/django_debug_flag_on.py b/src/core_codemods/django_debug_flag_on.py index 38b23c40..3a1a2643 100644 --- a/src/core_codemods/django_debug_flag_on.py +++ b/src/core_codemods/django_debug_flag_on.py @@ -1,28 +1,29 @@ import libcst as cst -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.utils import is_django_settings_file +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class DjangoDebugFlagOn(SemgrepCodemod): - NAME = "django-debug-flag-on" - DESCRIPTION = "Flip `Django` debug flag to off." - SUMMARY = "Disable Django Debug Mode" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - REFERENCES = [ - { - "url": "https://owasp.org/www-project-top-ten/2017/A3_2017-Sensitive_Data_Exposure", - "description": "", - }, - { - "url": "https://docs.djangoproject.com/en/4.2/ref/settings/#std-setting-DEBUG", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ +class DjangoDebugFlagOn(SimpleCodemod): + metadata = Metadata( + name="django-debug-flag-on", + summary="Disable Django Debug Mode", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference( + url="https://owasp.org/www-project-top-ten/2017/A3_2017-Sensitive_Data_Exposure" + ), + Reference( + url="https://docs.djangoproject.com/en/4.2/ref/settings/#std-setting-DEBUG" + ), + ], + ) + change_description = "Flip `Django` debug flag to off." + detector_pattern = """ rules: - id: django-debug-flag-on pattern: DEBUG = True diff --git a/src/core_codemods/django_json_response_type.py b/src/core_codemods/django_json_response_type.py index bc1a963b..a3337ed0 100644 --- a/src/core_codemods/django_json_response_type.py +++ b/src/core_codemods/django_json_response_type.py @@ -1,28 +1,28 @@ import libcst as cst +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod - -class DjangoJsonResponseType(SemgrepCodemod): - NAME = "django-json-response-type" - SUMMARY = "Set content type to `application/json` for `django.http.HttpResponse` with JSON data" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Sets `content_type` to `application/json`." - REFERENCES = [ - { - "url": "https://docs.djangoproject.com/en/4.0/ref/request-response/#django.http.HttpResponse.__init__", - "description": "", - }, - { - "url": "https://cheatsheetseries.owasp.org/cheatsheets/Cross_Site_Scripting_Prevention_Cheat_Sheet.html#output-encoding-for-javascript-contexts", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ +class DjangoJsonResponseType(SimpleCodemod): + metadata = Metadata( + name="django-json-response-type", + summary="Set content type to `application/json` for `django.http.HttpResponse` with JSON data", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://docs.djangoproject.com/en/4.0/ref/request-response/#django.http.HttpResponse.__init__" + ), + Reference( + url="https://cheatsheetseries.owasp.org/cheatsheets/Cross_Site_Scripting_Prevention_Cheat_Sheet.html#output-encoding-for-javascript-contexts" + ), + ], + ) + change_description = "Sets `content_type` to `application/json`." + detector_pattern = """ rules: - id: django-json-response-type mode: taint diff --git a/src/core_codemods/django_receiver_on_top.py b/src/core_codemods/django_receiver_on_top.py index 5d65bf32..c153fb58 100644 --- a/src/core_codemods/django_receiver_on_top.py +++ b/src/core_codemods/django_receiver_on_top.py @@ -1,22 +1,24 @@ from typing import Union import libcst as cst -from codemodder.codemods.api import BaseCodemod -from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class DjangoReceiverOnTop(BaseCodemod, NameResolutionMixin): - NAME = "django-receiver-on-top" - SUMMARY = "Ensure Django @receiver is the first decorator" - DESCRIPTION = SUMMARY - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - REFERENCES = [ - { - "url": "https://docs.djangoproject.com/en/4.1/topics/signals/", - "description": "", - }, - ] - CHANGE_DESCRIPTION = "Moved @receiver to the top." +class DjangoReceiverOnTop(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="django-receiver-on-top", + summary="Ensure Django @receiver is the first decorator", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference(url="https://docs.djangoproject.com/en/4.1/topics/signals/"), + ], + ) + change_description = "Moved @receiver to the top." def leave_FunctionDef( self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef diff --git a/src/core_codemods/django_session_cookie_secure_off.py b/src/core_codemods/django_session_cookie_secure_off.py index bf1bf4b7..2244b580 100644 --- a/src/core_codemods/django_session_cookie_secure_off.py +++ b/src/core_codemods/django_session_cookie_secure_off.py @@ -1,28 +1,29 @@ import libcst as cst -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.utils import is_django_settings_file, is_assigned_to_True +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class DjangoSessionCookieSecureOff(SemgrepCodemod): - NAME = "django-session-cookie-secure-off" - DESCRIPTION = "Sets Django's `SESSION_COOKIE_SECURE` flag if off or missing." - SUMMARY = "Secure Setting for Django `SESSION_COOKIE_SECURE` flag" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - REFERENCES = [ - { - "url": "https://owasp.org/www-community/controls/SecureCookieAttribute", - "description": "", - }, - { - "url": "https://docs.djangoproject.com/en/4.2/ref/settings/#session-cookie-secure", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ +class DjangoSessionCookieSecureOff(SimpleCodemod): + metadata = Metadata( + name="django-session-cookie-secure-off", + summary="Secure Setting for Django `SESSION_COOKIE_SECURE` flag", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference( + url="https://owasp.org/www-community/controls/SecureCookieAttribute" + ), + Reference( + url="https://docs.djangoproject.com/en/4.2/ref/settings/#session-cookie-secure" + ), + ], + ) + change_description = "Sets Django's `SESSION_COOKIE_SECURE` flag if off or missing." + detector_pattern = """ rules: - id: django-session-cookie-secure-off # This pattern creates one finding with no text for settings.py file. @@ -32,8 +33,8 @@ def rule(cls): - settings.py """ - def __init__(self, *args): - super().__init__(*args) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self.is_django_settings_file = is_django_settings_file( self.file_context.file_path ) @@ -60,7 +61,7 @@ def leave_Module( # something else and we changed it in `leave_Assign`. return updated_node - self.add_change(original_node, self.CHANGE_DESCRIPTION, start=False) + self.add_change(original_node, self.change_description, start=False) final_line = cst.parse_statement("SESSION_COOKIE_SECURE = True") new_body = updated_node.body + (final_line,) return updated_node.with_changes(body=new_body) @@ -80,7 +81,7 @@ def leave_Assign( return updated_node # SESSION_COOKIE_SECURE = anything other than True - self.add_change(original_node, self.CHANGE_DESCRIPTION) + self.add_change(original_node, self.change_description) return updated_node.with_changes(value=cst.Name("True")) return updated_node diff --git a/src/core_codemods/enable_jinja2_autoescape.py b/src/core_codemods/enable_jinja2_autoescape.py index 470aac6e..d4426ade 100644 --- a/src/core_codemods/enable_jinja2_autoescape.py +++ b/src/core_codemods/enable_jinja2_autoescape.py @@ -1,24 +1,28 @@ -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.api.helpers import NewArg +from codemodder.codemods.libcst_transformer import NewArg +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class EnableJinja2Autoescape(SemgrepCodemod): - NAME = "enable-jinja2-autoescape" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Enable Jinja2 Autoescape" - DESCRIPTION = "Sets the `autoescape` parameter in jinja2.Environment to `True`." - REFERENCES = [ - {"url": "https://owasp.org/www-community/attacks/xss/", "description": ""}, - { - "url": "https://jinja.palletsprojects.com/en/3.1.x/api/#autoescaping", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ +class EnableJinja2Autoescape(SimpleCodemod): + metadata = Metadata( + name="enable-jinja2-autoescape", + summary="Enable Jinja2 Autoescape", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference(url="https://owasp.org/www-community/attacks/xss/"), + Reference( + url="https://jinja.palletsprojects.com/en/3.1.x/api/#autoescaping" + ), + ], + ) + change_description = ( + "Sets the `autoescape` parameter in jinja2.Environment to `True`." + ) + detector_pattern = """ rules: - pattern-either: - patterns: diff --git a/src/core_codemods/exception_without_raise.py b/src/core_codemods/exception_without_raise.py index c293d0d6..a7d03d72 100644 --- a/src/core_codemods/exception_without_raise.py +++ b/src/core_codemods/exception_without_raise.py @@ -1,24 +1,28 @@ from typing import Union import libcst as cst -from codemodder.codemods.api import BaseCodemod -from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin from codemodder.utils.utils import full_qualified_name_from_class, list_subclasses +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class ExceptionWithoutRaise(BaseCodemod, NameResolutionMixin): - NAME = "exception-without-raise" - SUMMARY = "Ensure bare exception statements are raised" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = SUMMARY - REFERENCES = [ - { - "url": "https://docs.python.org/3/tutorial/errors.html#raising-exceptions", - "description": "", - }, - ] - CHANGE_DESCRIPTION = "Raised bare exception statement" +class ExceptionWithoutRaise(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="exception-without-raise", + summary="Ensure bare exception statements are raised", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/tutorial/errors.html#raising-exceptions" + ), + ], + ) + change_description = "Raised bare exception statement" def leave_SimpleStatementLine( self, diff --git a/src/core_codemods/file_resource_leak.py b/src/core_codemods/file_resource_leak.py index 24992190..ee4bb4da 100644 --- a/src/core_codemods/file_resource_leak.py +++ b/src/core_codemods/file_resource_leak.py @@ -13,32 +13,29 @@ ScopeProvider, ) from codemodder.change import Change -from codemodder.codemods.base_codemod import ( - ReviewGuidance, -) -from codemodder.codemods.api import BaseCodemod from codemodder.codemods.utils import MetadataPreservingTransformer from codemodder.codemods.utils_mixin import AncestorPatternsMixin, NameResolutionMixin from codemodder.file_context import FileContext from functools import partial +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class FileResourceLeak(BaseCodemod): - NAME = "fix-file-resource-leak" - SUMMARY = "Automatically Close Resources" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = SUMMARY - REFERENCES = [ - { - "url": "https://cwe.mitre.org/data/definitions/772.html", - "description": "", - }, - { - "url": "https://cwe.mitre.org/data/definitions/404.html", - "description": "", - }, - ] - CHANGE_DESCRIPTION = "Wrapped opened resource in a with statement." +class FileResourceLeak(SimpleCodemod): + metadata = Metadata( + name="fix-file-resource-leak", + summary="Automatically Close Resources", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference(url="https://cwe.mitre.org/data/definitions/772.html"), + Reference(url="https://cwe.mitre.org/data/definitions/404.html"), + ], + ) + change_description = "Wrapped opened resource in a with statement." METADATA_DEPENDENCIES = ( PositionProvider, @@ -51,11 +48,14 @@ def __init__( context: CodemodContext, file_context: FileContext, *codemod_args, + **codemod_kwargs, ) -> None: self.changed_nodes: dict[ cst.CSTNode, cst.CSTNode | cst.RemovalSentinel | cst.FlattenSentinel ] = {} - BaseCodemod.__init__(self, context, file_context, *codemod_args) + SimpleCodemod.__init__( + self, context, file_context, *codemod_args, **codemod_kwargs + ) def transform_module_impl(self, tree: cst.Module) -> cst.Module: fr = FindResources(self.context) @@ -200,7 +200,7 @@ def leave_Module(self, original_node: cst.Module, updated_node) -> cst.Module: if all(name_condition): line_number = self.get_metadata(PositionProvider, resource).start.line self.changes.append( - Change(line_number, FileResourceLeak.CHANGE_DESCRIPTION) + Change(line_number, FileResourceLeak.change_description) ) last_index = self._find_last_index_with_access( named_targets, block, index diff --git a/src/core_codemods/fix_deprecated_abstractproperty.py b/src/core_codemods/fix_deprecated_abstractproperty.py index c4c0dd38..29805c58 100644 --- a/src/core_codemods/fix_deprecated_abstractproperty.py +++ b/src/core_codemods/fix_deprecated_abstractproperty.py @@ -1,20 +1,27 @@ import libcst as cst - -from codemodder.codemods.api import BaseCodemod, ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class FixDeprecatedAbstractproperty(BaseCodemod, NameResolutionMixin): - NAME = "fix-deprecated-abstractproperty" - SUMMARY = "Replace deprecated abstractproperty" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Replace deprecated abstractproperty with property and abstractmethod" - REFERENCES = [ - { - "url": "https://docs.python.org/3/library/abc.html#abc.abstractproperty", - "description": "", - }, - ] +class FixDeprecatedAbstractproperty(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="fix-deprecated-abstractproperty", + summary="Replace deprecated abstractproperty", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/library/abc.html#abc.abstractproperty" + ), + ], + ) + change_description = ( + "Replace deprecated abstractproperty with property and abstractmethod" + ) def leave_Decorator( self, original_node: cst.Decorator, updated_node: cst.Decorator diff --git a/src/core_codemods/fix_deprecated_logging_warn.py b/src/core_codemods/fix_deprecated_logging_warn.py index 6025d247..291d6fcc 100644 --- a/src/core_codemods/fix_deprecated_logging_warn.py +++ b/src/core_codemods/fix_deprecated_logging_warn.py @@ -1,24 +1,27 @@ import libcst as cst -from codemodder.codemods.api import SemgrepCodemod, ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class FixDeprecatedLoggingWarn(SemgrepCodemod, NameResolutionMixin): - NAME = "fix-deprecated-logging-warn" - SUMMARY = "Replace Deprecated `logging.warn`" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Replace deprecated `logging.warn` with `logging.warning`" - REFERENCES = [ - { - "url": "https://docs.python.org/3/library/logging.html#logging.Logger.warning", - "description": "", - }, - ] +class FixDeprecatedLoggingWarn(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="fix-deprecated-logging-warn", + summary="Replace Deprecated `logging.warn`", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/library/logging.html#logging.Logger.warning" + ), + ], + ) + change_description = "Replace deprecated `logging.warn` with `logging.warning`" _module_name = "logging" - - @classmethod - def rule(cls): - return """ + detector_pattern = """ rules: - pattern-either: - patterns: diff --git a/src/core_codemods/fix_mutable_params.py b/src/core_codemods/fix_mutable_params.py index ae991131..f7051b87 100644 --- a/src/core_codemods/fix_mutable_params.py +++ b/src/core_codemods/fix_mutable_params.py @@ -1,15 +1,16 @@ import libcst as cst from libcst import matchers as m +from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import BaseCodemod - -class FixMutableParams(BaseCodemod): - NAME = "fix-mutable-params" - SUMMARY = "Replace Mutable Default Parameters" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Replace mutable parameter with `None`." +class FixMutableParams(SimpleCodemod): + metadata = Metadata( + name="fix-mutable-params", + summary="Replace Mutable Default Parameters", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[], + ) + change_description = "Replace mutable parameter with `None`." REFERENCES: list = [] _BUILTIN_TO_LITERAL = { "list": cst.List(elements=[]), @@ -168,7 +169,7 @@ def leave_FunctionDef( ) if new_var_decls: # If we're adding statements to the body, we know a change took place - self.add_change(original_node, self.CHANGE_DESCRIPTION) + self.add_change(original_node, self.change_description) if add_annotation: self.add_needed_import("typing", "Optional") diff --git a/src/core_codemods/flask_enable_csrf_protection.py b/src/core_codemods/flask_enable_csrf_protection.py index 1f6148cb..5638daf6 100644 --- a/src/core_codemods/flask_enable_csrf_protection.py +++ b/src/core_codemods/flask_enable_csrf_protection.py @@ -1,22 +1,28 @@ from typing import Optional, Union + import libcst as cst -from codemodder.codemods.api import BaseCodemod -from codemodder.codemods.base_codemod import ReviewGuidance + +from core_codemods.api import SimpleCodemod, Metadata, Reference, ReviewGuidance from codemodder.codemods.utils_mixin import AncestorPatternsMixin, NameResolutionMixin from codemodder.dependency import FlaskWTF class FlaskEnableCSRFProtection( - BaseCodemod, NameResolutionMixin, AncestorPatternsMixin + SimpleCodemod, + NameResolutionMixin, + AncestorPatternsMixin, ): - NAME = "flask-enable-csrf-protection" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_REVIEW - DESCRIPTION = "Uses CSRFProtect module to harden the app." - SUMMARY = "Enable CSRF protection globally for a Flask app." - REFERENCES = [ - {"url": "https://owasp.org/www-community/attacks/csrf", "description": ""}, - {"url": "https://flask-wtf.readthedocs.io/en/1.2.x/csrf/", "description": ""}, - ] + metadata = Metadata( + name="flask-enable-csrf-protection", + summary="Enable CSRF protection globally for a Flask app.", + review_guidance=ReviewGuidance.MERGE_AFTER_REVIEW, + references=[ + Reference(url="https://owasp.org/www-community/attacks/csrf"), + Reference(url="https://flask-wtf.readthedocs.io/en/1.2.x/csrf/"), + ], + ) + + change_description = "Add CSRFProtect module to harden the app" def leave_SimpleStatementSuite( self, diff --git a/src/core_codemods/flask_json_response_type.py b/src/core_codemods/flask_json_response_type.py index 959e5674..273ab7c1 100644 --- a/src/core_codemods/flask_json_response_type.py +++ b/src/core_codemods/flask_json_response_type.py @@ -1,28 +1,31 @@ from typing import Optional, Tuple import libcst as cst from libcst.codemod import CodemodContext, ContextAwareVisitor -from codemodder.codemods.api import BaseCodemod - -from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.utils_mixin import NameAndAncestorResolutionMixin - - -class FlaskJsonResponseType(BaseCodemod, NameAndAncestorResolutionMixin): - NAME = "flask-json-response-type" - SUMMARY = "Set content type to `application/json` for `flask.make_response` with JSON data" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Sets `mimetype` to `application/json`." - REFERENCES = [ - { - "url": "https://flask.palletsprojects.com/en/2.3.x/patterns/javascript/#return-json-from-views", - "description": "", - }, - { - "url": "https://cheatsheetseries.owasp.org/cheatsheets/Cross_Site_Scripting_Prevention_Cheat_Sheet.html#output-encoding-for-javascript-contexts", - "description": "", - }, - ] +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) + + +class FlaskJsonResponseType(SimpleCodemod, NameAndAncestorResolutionMixin): + metadata = Metadata( + name="flask-json-response-type", + summary="Set content type to `application/json` for `flask.make_response` with JSON data", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://flask.palletsprojects.com/en/2.3.x/patterns/javascript/#return-json-from-views" + ), + Reference( + url="https://cheatsheetseries.owasp.org/cheatsheets/Cross_Site_Scripting_Prevention_Cheat_Sheet.html#output-encoding-for-javascript-contexts" + ), + ], + ) + change_description = "Sets `mimetype` to `application/json`." content_type_key = "Content-Type" json_content_type = "application/json" diff --git a/src/core_codemods/harden_pyyaml.py b/src/core_codemods/harden_pyyaml.py index 1387ea78..2dac634a 100644 --- a/src/core_codemods/harden_pyyaml.py +++ b/src/core_codemods/harden_pyyaml.py @@ -1,70 +1,71 @@ from typing import Union import libcst as cst -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class HardenPyyaml(SemgrepCodemod, NameResolutionMixin): - NAME = "harden-pyyaml" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Replace unsafe `pyyaml` loader with `SafeLoader`" - DESCRIPTION = "Replace unsafe `pyyaml` loader with `SafeLoader` in calls to `yaml.load` or custom loader classes." - REFERENCES = [ - { - "url": "https://owasp.org/www-community/vulnerabilities/Deserialization_of_untrusted_data", - "description": "", - } - ] +class HardenPyyaml(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="harden-pyyaml", + summary="Replace unsafe `pyyaml` loader with `SafeLoader`", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://owasp.org/www-community/vulnerabilities/Deserialization_of_untrusted_data" + ), + ], + ) + change_description = "Replace unsafe `pyyaml` loader with `SafeLoader` in calls to `yaml.load` or custom loader classes." _module_name = "yaml" - - @classmethod - def rule(cls): - return """ - rules: - - pattern-either: - - patterns: - - pattern: yaml.load(...) - - pattern-inside: | - import yaml - ... - yaml.load(...,$ARG) - - metavariable-pattern: - metavariable: $ARG - patterns: - - pattern-either: - - pattern: yaml.Loader - - pattern: yaml.BaseLoader - - pattern: yaml.FullLoader - - pattern: yaml.UnsafeLoader - - patterns: - - pattern: yaml.load(...) - - pattern-inside: | - import yaml - ... - yaml.load(...,Loader=$ARG) - - metavariable-pattern: - metavariable: $ARG - patterns: - - pattern-either: - - pattern: yaml.Loader - - pattern: yaml.BaseLoader - - pattern: yaml.FullLoader - - pattern: yaml.UnsafeLoader - - patterns: - - pattern: | - class $X(...,$LOADER, ...): - ... - - metavariable-pattern: - metavariable: $LOADER - patterns: - - pattern-either: - - pattern: yaml.Loader - - pattern: yaml.BaseLoader - - pattern: yaml.FullLoader - - pattern: yaml.UnsafeLoader - + detector_pattern = """ + rules: + - pattern-either: + - patterns: + - pattern: yaml.load(...) + - pattern-inside: | + import yaml + ... + yaml.load(...,$ARG) + - metavariable-pattern: + metavariable: $ARG + patterns: + - pattern-either: + - pattern: yaml.Loader + - pattern: yaml.BaseLoader + - pattern: yaml.FullLoader + - pattern: yaml.UnsafeLoader + - patterns: + - pattern: yaml.load(...) + - pattern-inside: | + import yaml + ... + yaml.load(...,Loader=$ARG) + - metavariable-pattern: + metavariable: $ARG + patterns: + - pattern-either: + - pattern: yaml.Loader + - pattern: yaml.BaseLoader + - pattern: yaml.FullLoader + - pattern: yaml.UnsafeLoader + - patterns: + - pattern: | + class $X(...,$LOADER, ...): + ... + - metavariable-pattern: + metavariable: $LOADER + patterns: + - pattern-either: + - pattern: yaml.Loader + - pattern: yaml.BaseLoader + - pattern: yaml.FullLoader + - pattern: yaml.UnsafeLoader """ def on_result_found( diff --git a/src/core_codemods/harden_ruamel.py b/src/core_codemods/harden_ruamel.py index 8eb0a12b..44fd7342 100644 --- a/src/core_codemods/harden_ruamel.py +++ b/src/core_codemods/harden_ruamel.py @@ -1,23 +1,27 @@ -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.api.helpers import NewArg +from codemodder.codemods.libcst_transformer import NewArg +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class HardenRuamel(SemgrepCodemod): - NAME = "harden-ruamel" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Use `typ='safe'` in ruamel.yaml() Calls" - DESCRIPTION = "Ensures all unsafe calls to ruamel.yaml.YAML use `typ='safe'`." - REFERENCES = [ - { - "url": "https://owasp.org/www-community/vulnerabilities/Deserialization_of_untrusted_data", - "description": "", - } - ] - - @classmethod - def rule(cls): - return """ +class HardenRuamel(SimpleCodemod): + metadata = Metadata( + name="harden-ruamel", + summary="Use `typ='safe'` in ruamel.yaml() Calls", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://owasp.org/www-community/vulnerabilities/Deserialization_of_untrusted_data" + ), + ], + ) + change_description = ( + "Ensures all unsafe calls to ruamel.yaml.YAML use `typ='safe'`." + ) + detector_pattern = """ rules: - pattern-either: - patterns: diff --git a/src/core_codemods/https_connection.py b/src/core_codemods/https_connection.py index 1a011a16..a1b2eeb1 100644 --- a/src/core_codemods/https_connection.py +++ b/src/core_codemods/https_connection.py @@ -3,10 +3,13 @@ import libcst as cst from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor from libcst.metadata import PositionProvider - -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import BaseCodemod from codemodder.codemods.imported_call_modifier import ImportedCallModifier +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) class HTTPSConnectionModifier(ImportedCallModifier[Set[str]]): @@ -48,20 +51,20 @@ def count_positional_args(self, arglist: Sequence[cst.Arg]) -> int: return len(arglist) -class HTTPSConnection(BaseCodemod): - SUMMARY = "Enforce HTTPS Connection for `urllib3`" - NAME = "https-connection" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - REFERENCES = [ - { - "url": "https://owasp.org/www-community/vulnerabilities/Insecure_Transport", - "description": "", - }, - { - "url": "https://urllib3.readthedocs.io/en/stable/reference/urllib3.connectionpool.html#urllib3.HTTPConnectionPool", - "description": "", - }, - ] +class HTTPSConnection(SimpleCodemod): + metadata = Metadata( + name="https-connection", + summary="Enforce HTTPS Connection for `urllib3`", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://owasp.org/www-community/vulnerabilities/Insecure_Transport" + ), + Reference( + url="https://urllib3.readthedocs.io/en/stable/reference/urllib3.connectionpool.html#urllib3.HTTPConnectionPool" + ), + ], + ) METADATA_DEPENDENCIES = (PositionProvider,) @@ -75,7 +78,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: self.context, self.file_context, self.matching_functions, - self.CHANGE_DESCRIPTION, + self.change_description, ) result_tree = visitor.transform_module(tree) self.file_context.codemod_changes.extend(visitor.changes_in_file) diff --git a/src/core_codemods/jwt_decode_verify.py b/src/core_codemods/jwt_decode_verify.py index 343ac94a..604a5c82 100644 --- a/src/core_codemods/jwt_decode_verify.py +++ b/src/core_codemods/jwt_decode_verify.py @@ -1,26 +1,28 @@ import libcst as cst from libcst import matchers -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.api.helpers import NewArg +from codemodder.codemods.libcst_transformer import NewArg +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class JwtDecodeVerify(SemgrepCodemod): - NAME = "jwt-decode-verify" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Verify JWT Decode" - DESCRIPTION = "Enable all verifications in `jwt.decode` call." - REFERENCES = [ - {"url": "https://pyjwt.readthedocs.io/en/stable/api.html", "description": ""}, - { - "url": "https://owasp.org/www-project-web-security-testing-guide/latest/4-Web_Application_Security_Testing/06-Session_Management_Testing/10-Testing_JSON_Web_Tokens", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return r""" +class JwtDecodeVerify(SimpleCodemod): + metadata = Metadata( + name="jwt-decode-verify", + summary="Verify JWT Decode", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference(url="https://pyjwt.readthedocs.io/en/stable/api.html"), + Reference( + url="https://owasp.org/www-project-web-security-testing-guide/latest/4-Web_Application_Security_Testing/06-Session_Management_Testing/10-Testing_JSON_Web_Tokens" + ), + ], + ) + change_description = "Enable all verifications in `jwt.decode` call." + detector_pattern = r""" rules: - pattern-either: - patterns: diff --git a/src/core_codemods/limit_readline.py b/src/core_codemods/limit_readline.py index 694780b0..81f52734 100644 --- a/src/core_codemods/limit_readline.py +++ b/src/core_codemods/limit_readline.py @@ -1,23 +1,26 @@ import libcst as cst -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) default_limit = "5_000_000" -class LimitReadline(SemgrepCodemod): - NAME = "limit-readline" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - SUMMARY = "Limit readline()" - DESCRIPTION = "Adds a size limit argument to readline() calls." - REFERENCES = [ - {"url": "https://cwe.mitre.org/data/definitions/400.html", "description": ""} - ] - - @classmethod - def rule(cls): - return """ +class LimitReadline(SimpleCodemod): + metadata = Metadata( + name="limit-readline", + summary="Limit readline()", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference(url="https://cwe.mitre.org/data/definitions/400.html"), + ], + ) + change_description = "Adds a size limit argument to readline() calls." + detector_pattern = """ rules: - id: limit-readline mode: taint diff --git a/src/core_codemods/literal_or_new_object_identity.py b/src/core_codemods/literal_or_new_object_identity.py index e6d9f73b..32f907aa 100644 --- a/src/core_codemods/literal_or_new_object_identity.py +++ b/src/core_codemods/literal_or_new_object_identity.py @@ -1,22 +1,26 @@ import libcst as cst -from codemodder.codemods.api import BaseCodemod -from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.utils_mixin import NameAndAncestorResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class LiteralOrNewObjectIdentity(BaseCodemod, NameAndAncestorResolutionMixin): - NAME = "literal-or-new-object-identity" - SUMMARY = "Replaces is operator with == for literal or new object comparisons" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = SUMMARY - REFERENCES = [ - { - "url": "https://docs.python.org/3/library/stdtypes.html#comparisons", - "description": "", - }, - ] - CHANGE_DESCRIPTION = "Replaces is operator with ==" +class LiteralOrNewObjectIdentity(SimpleCodemod, NameAndAncestorResolutionMixin): + metadata = Metadata( + name="literal-or-new-object-identity", + summary="Replaces is operator with == for literal or new object comparisons", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/library/stdtypes.html#comparisons" + ), + ], + ) + change_description = "Replaces is operator with ==" def _is_object_creation_or_literal(self, node: cst.BaseExpression): match node: diff --git a/src/core_codemods/lxml_safe_parser_defaults.py b/src/core_codemods/lxml_safe_parser_defaults.py index 20b47b84..d9445040 100644 --- a/src/core_codemods/lxml_safe_parser_defaults.py +++ b/src/core_codemods/lxml_safe_parser_defaults.py @@ -1,31 +1,31 @@ -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.api.helpers import NewArg +from codemodder.codemods.libcst_transformer import NewArg +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class LxmlSafeParserDefaults(SemgrepCodemod): - NAME = "safe-lxml-parser-defaults" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Use Safe Defaults for `lxml` Parsers" - DESCRIPTION = "Replace `lxml` parser parameters with safe defaults." - REFERENCES = [ - { - "url": "https://lxml.de/apidoc/lxml.etree.html#lxml.etree.XMLParser", - "description": "", - }, - { - "url": "https://owasp.org/www-community/vulnerabilities/XML_External_Entity_(XXE)_Processing", - "description": "", - }, - { - "url": "https://cheatsheetseries.owasp.org/cheatsheets/XML_External_Entity_Prevention_Cheat_Sheet.html", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ +class LxmlSafeParserDefaults(SimpleCodemod): + metadata = Metadata( + name="safe-lxml-parser-defaults", + summary="Use Safe Defaults for `lxml` Parsers", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://lxml.de/apidoc/lxml.etree.html#lxml.etree.XMLParser" + ), + Reference( + url="https://owasp.org/www-community/vulnerabilities/XML_External_Entity_(XXE)_Processing" + ), + Reference( + url="https://cheatsheetseries.owasp.org/cheatsheets/XML_External_Entity_Prevention_Cheat_Sheet.html" + ), + ], + ) + change_description = "Replace `lxml` parser parameters with safe defaults." + detector_pattern = """ rules: - patterns: - pattern: lxml.etree.$CLASS(...) diff --git a/src/core_codemods/lxml_safe_parsing.py b/src/core_codemods/lxml_safe_parsing.py index a7a98101..aa5c9b26 100644 --- a/src/core_codemods/lxml_safe_parsing.py +++ b/src/core_codemods/lxml_safe_parsing.py @@ -1,33 +1,33 @@ -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.api.helpers import NewArg +from codemodder.codemods.libcst_transformer import NewArg +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class LxmlSafeParsing(SemgrepCodemod): - NAME = "safe-lxml-parsing" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Use Safe Parsers in `lxml` Parsing Functions" - DESCRIPTION = ( +class LxmlSafeParsing(SimpleCodemod): + metadata = Metadata( + name="safe-lxml-parsing", + summary="Use Safe Parsers in `lxml` Parsing Functions", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://lxml.de/apidoc/lxml.etree.html#lxml.etree.XMLParser" + ), + Reference( + url="https://owasp.org/www-community/vulnerabilities/XML_External_Entity_(XXE)_Processing" + ), + Reference( + url="https://cheatsheetseries.owasp.org/cheatsheets/XML_External_Entity_Prevention_Cheat_Sheet.html" + ), + ], + ) + change_description = ( "Call `lxml.etree.parse` and `lxml.etree.fromstring` with a safe parser." ) - REFERENCES = [ - { - "url": "https://lxml.de/apidoc/lxml.etree.html#lxml.etree.XMLParser", - "description": "", - }, - { - "url": "https://owasp.org/www-community/vulnerabilities/XML_External_Entity_(XXE)_Processing", - "description": "", - }, - { - "url": "https://cheatsheetseries.owasp.org/cheatsheets/XML_External_Entity_Prevention_Cheat_Sheet.html", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ + detector_pattern = """ rules: - pattern-either: - patterns: diff --git a/src/core_codemods/numpy_nan_equality.py b/src/core_codemods/numpy_nan_equality.py index 912439ec..e63016c6 100644 --- a/src/core_codemods/numpy_nan_equality.py +++ b/src/core_codemods/numpy_nan_equality.py @@ -1,23 +1,27 @@ import libcst as cst from libcst import UnaryOperation -from codemodder.codemods.api import BaseCodemod -from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class NumpyNanEquality(BaseCodemod, NameResolutionMixin): - NAME = "numpy-nan-equality" - SUMMARY = "Replace == comparison with numpy.isnan()" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = SUMMARY - REFERENCES = [ - { - "url": "https://numpy.org/doc/stable/reference/constants.html#numpy.nan", - "description": "", - }, - ] - CHANGE_DESCRIPTION = "Replaces == check with numpy.isnan()." +class NumpyNanEquality(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="numpy-nan-equality", + summary="Replace == comparison with numpy.isnan()", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://numpy.org/doc/stable/reference/constants.html#numpy.nan" + ), + ], + ) + change_description = "Replaces == check with numpy.isnan()." np_nan = "numpy.nan" diff --git a/src/core_codemods/order_imports.py b/src/core_codemods/order_imports.py index 17ecd232..f6c49718 100644 --- a/src/core_codemods/order_imports.py +++ b/src/core_codemods/order_imports.py @@ -1,34 +1,25 @@ +import libcst as cst from libcst.metadata import PositionProvider -from codemodder.codemods.base_codemod import ( - BaseCodemod, - CodemodMetadata, - ReviewGuidance, -) + +from core_codemods.api import SimpleCodemod, Metadata, ReviewGuidance from codemodder.change import Change from codemodder.codemods.transformations.clean_imports import ( GatherTopLevelImportBlocks, OrderImportsBlocksTransform, ) -import libcst as cst -from libcst.codemod import Codemod, CodemodContext -class OrderImports(BaseCodemod, Codemod): - METADATA = CodemodMetadata( - DESCRIPTION=("Formats and orders imports by categories."), - NAME="order-imports", - REVIEW_GUIDANCE=ReviewGuidance.MERGE_WITHOUT_REVIEW, - REFERENCES=[], +class OrderImports(SimpleCodemod): + metadata = Metadata( + name="order-imports", + summary="Order Imports", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + has_description=False, ) - SUMMARY = "Order Imports" - CHANGE_DESCRIPTION = "Ordered and formatted import block below this line" + change_description = "Ordered and formatted import block below this line" METADATA_DEPENDENCIES = (PositionProvider,) - def __init__(self, codemod_context: CodemodContext, *codemod_args): - Codemod.__init__(self, codemod_context) - BaseCodemod.__init__(self, *codemod_args) - def transform_module_impl(self, tree: cst.Module) -> cst.Module: top_imports_visitor = GatherTopLevelImportBlocks() tree.visit(top_imports_visitor) @@ -54,7 +45,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: top_imports_visitor.top_imports_blocks[i][0] ).start.line self.file_context.codemod_changes.append( - Change(line_number, self.CHANGE_DESCRIPTION) + Change(line_number, self.change_description) ) return result_tree return tree diff --git a/src/core_codemods/process_creation_sandbox.py b/src/core_codemods/process_creation_sandbox.py index e3a9339a..f9b6068a 100644 --- a/src/core_codemods/process_creation_sandbox.py +++ b/src/core_codemods/process_creation_sandbox.py @@ -1,32 +1,33 @@ import libcst as cst -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod from codemodder.dependency import Security +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class ProcessSandbox(SemgrepCodemod): - NAME = "sandbox-process-creation" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - SUMMARY = "Sandbox Process Creation" - DESCRIPTION = ( +class ProcessSandbox(SimpleCodemod): + metadata = Metadata( + name="sandbox-process-creation", + summary="Sandbox Process Creation", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference( + url="https://github.com/pixee/python-security/blob/main/src/security/safe_command/api.py" + ), + Reference( + url="https://cheatsheetseries.owasp.org/cheatsheets/OS_Command_Injection_Defense_Cheat_Sheet.html" + ), + ], + ) + change_description = ( "Replaces subprocess.{func} with more secure safe_command library functions." ) - REFERENCES = [ - { - "url": "https://github.com/pixee/python-security/blob/main/src/security/safe_command/api.py", - "description": "", - }, - { - "url": "https://cheatsheetseries.owasp.org/cheatsheets/OS_Command_Injection_Defense_Cheat_Sheet.html", - "description": "", - }, - ] adds_dependency = True - - @classmethod - def rule(cls): - return """ + detector_pattern = """ rules: - pattern-either: - patterns: diff --git a/src/core_codemods/semgrep/__init__.py b/src/core_codemods/refactor/__init__.py similarity index 100% rename from src/core_codemods/semgrep/__init__.py rename to src/core_codemods/refactor/__init__.py diff --git a/src/core_codemods/refactor/refactor_new_api.py b/src/core_codemods/refactor/refactor_new_api.py new file mode 100644 index 00000000..02453bbf --- /dev/null +++ b/src/core_codemods/refactor/refactor_new_api.py @@ -0,0 +1,297 @@ +import libcst as cst + +from core_codemods.api import SimpleCodemod, Metadata, ReviewGuidance +from codemodder.codemods.utils_mixin import NameResolutionMixin + + +class RefactorNewApi(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="refactor-new-api", + summary="Refactor to use thew new simplified API", + review_guidance=ReviewGuidance.MERGE_AFTER_REVIEW, + has_description=False, + ) + + new_api_module = "codemodder.codemods.new_api" + new_api_class = "SimpleCodemod" + + no_whitespace_assign = cst.AssignEqual( + whitespace_before=cst.SimpleWhitespace(""), + whitespace_after=cst.SimpleWhitespace(""), + ) + + def _build_metadata(self, metadata: dict) -> tuple[cst.SimpleStatementLine, bool]: + refs = ( + build_references(metadata["references"]) + if "references" in metadata + else cst.List(elements=[]) + ) + + return ( + cst.SimpleStatementLine( + body=[ + cst.Assign( + targets=[cst.AssignTarget(target=cst.Name(value="metadata"))], + value=cst.Call( + func=cst.Name(value="Metadata"), + args=[ + make_metadata_arg(metadata, "name"), + make_metadata_arg(metadata, "summary"), + make_metadata_arg(metadata, "review_guidance"), + make_metadata_arg( + metadata, "references", value=refs, last=True + ), + ], + whitespace_before_args=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(" "), + ), + ), + ) + ] + ), + bool(refs.elements), + ) + + def find_metadata( + self, assign: cst.Assign + ) -> tuple[str, cst.BaseExpression] | None: + match assign: + case cst.Assign( + targets=[cst.AssignTarget(target=cst.Name(value=name))], + value=value, + ): + match name: + case "NAME" as name: + return name, value + case "SUMMARY" as name: + return name, value + case "DESCRIPTION" as name: + return name, value + case "REVIEW_GUIDANCE" as name: + return name, value + case "REFERENCES" as name: + return name, value + case "CHANGE_DESCRIPTION" | "change_description" as name: + return name, value + + return None + + def create_rule(self, body: cst.BaseSuite) -> cst.SimpleStatementLine: + match body: + case cst.IndentedBlock( + body=[ + cst.SimpleStatementLine( + body=[cst.Return(value=cst.SimpleString() as value)] + ) + ] + ): + return cst.SimpleStatementLine( + body=[ + cst.Assign( + targets=[ + cst.AssignTarget( + target=cst.Name(value="detector_pattern") + ) + ], + value=value, + ) + ] + ) + + raise ValueError("Could not find detector pattern for codemod") + + def create_change_description( + self, metadata: dict + ) -> cst.SimpleStatementLine | None: + if ( + "description" in metadata + and "change_description" not in metadata + and "CHANGE_DESCRIPTION" not in metadata + ): + match metadata["description"]: + case cst.SimpleString() as description: + return cst.SimpleStatementLine( + body=[ + cst.Assign( + targets=[ + cst.AssignTarget( + target=cst.Name(value="change_description") + ) + ], + value=description, + ) + ] + ) + + return None + + def leave_Assert(self, original: cst.Assert, updated: cst.Assert): + match original: + case cst.Assert( + test=cst.Comparison( + left=cst.Call( + func=cst.Name(value="len"), + args=[ + cst.Arg( + value=cst.Attribute( + value=cst.Attribute( + value=cst.Name(value="self"), + attr=cst.Name(value="file_context"), + ), + attr=cst.Name(value="codemod_changes"), + ) + ) + ], + ) + ) + ): + return cst.RemoveFromParent() + return updated + + def leave_Name(self, original: cst.Name, updated: cst.Name) -> cst.Name: + if original.value == "CHANGE_DESCRIPTION": + return updated.with_changes(value="change_description") + + return updated + + def leave_ImportFrom(self, original: cst.ImportFrom, updated: cst.ImportFrom): + match original: + case cst.ImportFrom( + module=cst.Attribute( + value=cst.Attribute( + value=cst.Name(value="codemodder"), + attr=cst.Name(value="codemods"), + ), + attr=cst.Name(value="api"), + ) + ): + return cst.RemoveFromParent() + + return updated + + def leave_ClassDef(self, original: cst.ClassDef, new: cst.ClassDef) -> cst.ClassDef: + new_bases: list[cst.Arg] = [ + base.with_changes(value=cst.Name(self.new_api_class)) + if self.find_base_name(base.value) + in ( + "codemodder.codemods.api.BaseCodemod", + "codemodder.codemods.api.SemgrepCodemod", + ) + else base + for base in original.bases + ] + + if all(base.value.value != self.new_api_class for base in new_bases): + return new + + self.add_needed_import(self.new_api_module, obj=self.new_api_class) + + metadata = {} + new_body = [] + for stmt in new.body.body: + match stmt: + case cst.SimpleStatementLine(body=(cst.Assign() as assign,)): + if result := self.find_metadata(assign): + key, value = result + if key == "change_description": + new_body.append(stmt) + continue + + metadata[key.lower()] = value + continue + case cst.FunctionDef( + name=cst.Name(value="rule"), + body=body, + ): + new_body.append(self.create_rule(body)) + continue + + new_body.append(stmt) + + if not metadata or "name" not in metadata: + return new + + new_metadata, has_references = self._build_metadata(metadata) + new_body.insert(0, new_metadata) + + if change_description := self.create_change_description(metadata): + new_body.insert(1, change_description) + + self.add_needed_import(self.new_api_module, obj="SimpleCodemod") + self.add_needed_import(self.new_api_module, obj="Metadata") + if has_references: + self.add_needed_import(self.new_api_module, obj="Reference") + self.add_needed_import(self.new_api_module, obj="ReviewGuidance") + + self.remove_unused_import(original) + + new_body = new.body.with_changes(body=new_body) + return new.with_changes(bases=new_bases, body=new_body) + + +def make_metadata_arg( + metadata: dict, + metadata_key: str, + value: cst.BaseExpression | None = None, + last: bool = False, +) -> cst.Arg: + return cst.Arg( + value=value or metadata[metadata_key], + keyword=cst.Name(value=metadata_key), + equal=RefactorNewApi.no_whitespace_assign, + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace("" if last else " "), + ), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace("")), + ) + + +def build_references(old_refs: cst.List) -> cst.List: + new_refs: list[cst.Call] = [] + for ref in old_refs.elements: + match ref: + case cst.Element(value=cst.Dict(elements=elements)): + args = [ + cst.Arg( + keyword=cst.Name(value=elm.key.raw_value), + value=elm.value, + equal=RefactorNewApi.no_whitespace_assign, + ) + for elm in elements + if elm.value.raw_value + ] + new_refs.append( + cst.Call( + func=cst.Name(value="Reference"), + args=args, + ) + ) + + return cst.List( + elements=[ + cst.Element( + value=ref, + comma=cst.Comma( + whitespace_after=cst.ParenthesizedWhitespace( + indent=True, + first_line=cst.TrailingWhitespace(newline=cst.Newline()), + ) + ), + ) + for ref in new_refs + ], + lbracket=cst.LeftSquareBracket( + whitespace_after=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(" " * 8), + ), + ), + rbracket=cst.RightSquareBracket( + whitespace_before=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(" " * 4), + ), + ), + ) diff --git a/src/core_codemods/remove_debug_breakpoint.py b/src/core_codemods/remove_debug_breakpoint.py index 27e9eeae..e31ccab6 100644 --- a/src/core_codemods/remove_debug_breakpoint.py +++ b/src/core_codemods/remove_debug_breakpoint.py @@ -1,14 +1,17 @@ import libcst as cst from typing import Union -from codemodder.codemods.api import BaseCodemod, ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin, AncestorPatternsMixin +from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod -class RemoveDebugBreakpoint(BaseCodemod, NameResolutionMixin, AncestorPatternsMixin): - NAME = "remove-debug-breakpoint" - SUMMARY = "Remove Calls to `builtin` `breakpoint` and `pdb.set_trace" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Remove breakpoint call" +class RemoveDebugBreakpoint(SimpleCodemod, NameResolutionMixin, AncestorPatternsMixin): + metadata = Metadata( + name="remove-debug-breakpoint", + summary="Remove Calls to `builtin` `breakpoint` and `pdb.set_trace", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[], + ) + change_description = "Remove breakpoint call" REFERENCES: list = [] def leave_Expr( diff --git a/src/core_codemods/remove_future_imports.py b/src/core_codemods/remove_future_imports.py index 1823bcf1..d126c30e 100644 --- a/src/core_codemods/remove_future_imports.py +++ b/src/core_codemods/remove_future_imports.py @@ -1,6 +1,10 @@ import libcst as cst - -from codemodder.codemods.api import BaseCodemod, ReviewGuidance +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) DEPRECATED_NAMES = [ @@ -18,17 +22,16 @@ ] -class RemoveFutureImports(BaseCodemod): - NAME = "remove-future-imports" - SUMMARY = "Remove deprecated `__future__` imports" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Remove deprecated `__future__` imports" - REFERENCES = [ - { - "url": "https://docs.python.org/3/library/__future__.html", - "description": "", - }, - ] +class RemoveFutureImports(SimpleCodemod): + metadata = Metadata( + name="remove-future-imports", + summary="Remove deprecated `__future__` imports", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference(url="https://docs.python.org/3/library/__future__.html"), + ], + ) + change_description = "Remove deprecated `__future__` imports" def leave_ImportFrom( self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom @@ -41,7 +44,7 @@ def leave_ImportFrom( cst.ImportAlias(name=cst.Name(value=name)) for name in CURRENT_NAMES ] - self.add_change(original_node, self.CHANGE_DESCRIPTION) + self.add_change(original_node, self.change_description) return original_node.with_changes(names=names) updated_names: list[cst.ImportAlias] = [ @@ -49,7 +52,7 @@ def leave_ImportFrom( for name in original_node.names if name.name.value not in DEPRECATED_NAMES ] - self.add_change(original_node, self.CHANGE_DESCRIPTION) + self.add_change(original_node, self.change_description) return ( updated_node.with_changes(names=updated_names) if updated_names diff --git a/src/core_codemods/remove_module_global.py b/src/core_codemods/remove_module_global.py index b0f1d28b..0e923780 100644 --- a/src/core_codemods/remove_module_global.py +++ b/src/core_codemods/remove_module_global.py @@ -1,15 +1,18 @@ import libcst as cst from libcst.metadata import GlobalScope, ScopeProvider from typing import Union -from codemodder.codemods.api import BaseCodemod, ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod -class RemoveModuleGlobal(BaseCodemod, NameResolutionMixin): - NAME = "remove-module-global" - SUMMARY = "Remove `global` Usage at Module Level" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Remove `global` usage at module level." +class RemoveModuleGlobal(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="remove-module-global", + summary="Remove `global` Usage at Module Level", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[], + ) + change_description = "Remove `global` usage at module level." REFERENCES: list = [] def leave_Global( diff --git a/src/core_codemods/remove_unnecessary_f_str.py b/src/core_codemods/remove_unnecessary_f_str.py index d9998a7b..1d3003b9 100644 --- a/src/core_codemods/remove_unnecessary_f_str.py +++ b/src/core_codemods/remove_unnecessary_f_str.py @@ -4,29 +4,35 @@ ) from libcst.codemod.commands.unnecessary_format_string import UnnecessaryFormatString import libcst.matchers as m -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import BaseCodemod + +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class RemoveUnnecessaryFStr(BaseCodemod, UnnecessaryFormatString): - NAME = "remove-unnecessary-f-str" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Remove Unnecessary F-strings" - DESCRIPTION = UnnecessaryFormatString.DESCRIPTION - REFERENCES = [ - { - "url": "https://pylint.readthedocs.io/en/latest/user_guide/messages/warning/f-string-without-interpolation.html", - "description": "", - }, - { - "url": "https://github.com/Instagram/LibCST/blob/main/libcst/codemod/commands/unnecessary_format_string.py", - "description": "", - }, - ] +class RemoveUnnecessaryFStr(SimpleCodemod, UnnecessaryFormatString): + metadata = Metadata( + name="remove-unnecessary-f-str", + summary="Remove Unnecessary F-strings", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://pylint.readthedocs.io/en/latest/user_guide/messages/warning/f-string-without-interpolation.html" + ), + Reference( + url="https://github.com/Instagram/LibCST/blob/main/libcst/codemod/commands/unnecessary_format_string.py" + ), + ], + ) - def __init__(self, codemod_context: CodemodContext, *codemod_args): + def __init__( + self, codemod_context: CodemodContext, *codemod_args, **codemod_kwargs + ): UnnecessaryFormatString.__init__(self, codemod_context) - BaseCodemod.__init__(self, codemod_context, *codemod_args) + SimpleCodemod.__init__(self, codemod_context, *codemod_args, **codemod_kwargs) @m.leave(m.FormattedString(parts=(m.FormattedStringText(),))) def _check_formatted_string( diff --git a/src/core_codemods/remove_unused_imports.py b/src/core_codemods/remove_unused_imports.py index 1b54b1a1..f83b2842 100644 --- a/src/core_codemods/remove_unused_imports.py +++ b/src/core_codemods/remove_unused_imports.py @@ -1,3 +1,6 @@ +import re + +import libcst as cst from libcst import CSTVisitor, ensure_type, matchers from libcst.codemod.visitors import GatherUnusedImportsVisitor from libcst.metadata import ( @@ -6,32 +9,25 @@ ScopeProvider, ParentNodeProvider, ) -from codemodder.codemods.base_codemod import ( - BaseCodemod, - CodemodMetadata, - ReviewGuidance, -) + +from pylint.utils.pragma_parser import parse_pragma + +from core_codemods.api import SimpleCodemod, Metadata, ReviewGuidance from codemodder.change import Change from codemodder.codemods.transformations.remove_unused_imports import ( RemoveUnusedImportsTransformer, ) -import libcst as cst -from libcst.codemod import Codemod, CodemodContext -import re -from pylint.utils.pragma_parser import parse_pragma NOQA_PATTERN = re.compile(r"^#\s*noqa", re.IGNORECASE) -class RemoveUnusedImports(BaseCodemod, Codemod): - METADATA = CodemodMetadata( - DESCRIPTION=("Remove unused imports from a module."), - NAME="unused-imports", - REVIEW_GUIDANCE=ReviewGuidance.MERGE_WITHOUT_REVIEW, - REFERENCES=[], +class RemoveUnusedImports(SimpleCodemod): + metadata = Metadata( + name="unused-imports", + summary="Remove Unused Imports", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, ) - SUMMARY = "Remove Unused Imports" - CHANGE_DESCRIPTION = "Unused import." + change_description = "Unused import." METADATA_DEPENDENCIES = ( PositionProvider, @@ -40,10 +36,6 @@ class RemoveUnusedImports(BaseCodemod, Codemod): ParentNodeProvider, ) - def __init__(self, codemod_context: CodemodContext, *codemod_args): - Codemod.__init__(self, codemod_context) - BaseCodemod.__init__(self, *codemod_args) - def transform_module_impl(self, tree: cst.Module) -> cst.Module: # Do nothing in __init__.py files if self.file_context.file_path.name == "__init__.py": @@ -57,7 +49,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: if self.filter_by_path_includes_or_excludes(pos): if not self._is_disabled_by_linter(importt): self.file_context.codemod_changes.append( - Change(pos.start.line, self.CHANGE_DESCRIPTION) + Change(pos.start.line, self.change_description) ) filtered_unused_imports.add((import_alias, importt)) return tree.visit(RemoveUnusedImportsTransformer(filtered_unused_imports)) diff --git a/src/core_codemods/replace_flask_send_file.py b/src/core_codemods/replace_flask_send_file.py index 1318901f..8737347a 100644 --- a/src/core_codemods/replace_flask_send_file.py +++ b/src/core_codemods/replace_flask_send_file.py @@ -1,27 +1,25 @@ -import libcst as cst from typing import Optional -from codemodder.codemods.api import BaseCodemod -from codemodder.codemods.base_codemod import ReviewGuidance + +import libcst as cst + +from core_codemods.api import SimpleCodemod, Metadata, Reference, ReviewGuidance from codemodder.codemods.utils import BaseType, infer_expression_type from codemodder.codemods.utils_mixin import NameAndAncestorResolutionMixin from codemodder.utils.utils import positional_to_keyword -class ReplaceFlaskSendFile(BaseCodemod, NameAndAncestorResolutionMixin): - NAME = "replace-flask-send-file" - SUMMARY = "Replace unsafe usage of `flask.send_file`" - DESCRIPTION = SUMMARY - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - REFERENCES = [ - { - "url": "https://flask.palletsprojects.com/en/3.0.x/api/#flask.send_from_directory", - "description": "", - }, - { - "url": "https://owasp.org/www-community/attacks/Path_Traversal", - "description": "", - }, - ] +class ReplaceFlaskSendFile(SimpleCodemod, NameAndAncestorResolutionMixin): + metadata = Metadata( + name="replace-flask-send-file", + summary="Replace unsafe usage of `flask.send_file`", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://flask.palletsprojects.com/en/3.0.x/api/#flask.send_from_directory" + ), + Reference(url="https://owasp.org/www-community/attacks/Path_Traversal"), + ], + ) pos_to_key_map: list[str | None] = [ "mimetype", diff --git a/src/core_codemods/requests_verify.py b/src/core_codemods/requests_verify.py index 08671265..47380efd 100644 --- a/src/core_codemods/requests_verify.py +++ b/src/core_codemods/requests_verify.py @@ -1,26 +1,28 @@ -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.api.helpers import NewArg +from codemodder.codemods.libcst_transformer import NewArg +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class RequestsVerify(SemgrepCodemod): - NAME = "requests-verify" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - SUMMARY = "Verify SSL Certificates for Requests." - DESCRIPTION = ( +class RequestsVerify(SimpleCodemod): + metadata = Metadata( + name="requests-verify", + summary="Verify SSL Certificates for Requests.", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference(url="https://requests.readthedocs.io/en/latest/api/"), + Reference( + url="https://owasp.org/www-community/attacks/Manipulator-in-the-middle_attack" + ), + ], + ) + change_description = ( "Makes any calls to requests.{func} with `verify=False` to `verify=True`." ) - REFERENCES = [ - {"url": "https://requests.readthedocs.io/en/latest/api/", "description": ""}, - { - "url": "https://owasp.org/www-community/attacks/Manipulator-in-the-middle_attack", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ + detector_pattern = """ rules: - patterns: - pattern: requests.$F(..., verify=False, ...) diff --git a/src/core_codemods/secure_flask_cookie.py b/src/core_codemods/secure_flask_cookie.py index 82fce6de..c935c11e 100644 --- a/src/core_codemods/secure_flask_cookie.py +++ b/src/core_codemods/secure_flask_cookie.py @@ -1,28 +1,29 @@ from libcst import matchers -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.api.helpers import NewArg +from codemodder.codemods.libcst_transformer import NewArg +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class SecureFlaskCookie(SemgrepCodemod): - NAME = "secure-flask-cookie" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - SUMMARY = "Use Safe Parameters in `flask` Response `set_cookie` Call" - DESCRIPTION = "Flask response `set_cookie` call should be called with `secure=True`, `httponly=True`, and `samesite='Lax'`." - REFERENCES = [ - { - "url": "https://flask.palletsprojects.com/en/3.0.x/api/#flask.Response.set_cookie", - "description": "", - }, - { - "url": "https://owasp.org/www-community/controls/SecureCookieAttribute", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ +class SecureFlaskCookie(SimpleCodemod): + metadata = Metadata( + name="secure-flask-cookie", + summary="Use Safe Parameters in `flask` Response `set_cookie` Call", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference( + url="https://flask.palletsprojects.com/en/3.0.x/api/#flask.Response.set_cookie" + ), + Reference( + url="https://owasp.org/www-community/controls/SecureCookieAttribute" + ), + ], + ) + change_description = "Flask response `set_cookie` call should be called with `secure=True`, `httponly=True`, and `samesite='Lax'`." + detector_pattern = """ rules: - id: secure-flask-cookie mode: taint diff --git a/src/core_codemods/secure_flask_session_config.py b/src/core_codemods/secure_flask_session_config.py index 3efba51b..95c2ae5e 100644 --- a/src/core_codemods/secure_flask_session_config.py +++ b/src/core_codemods/secure_flask_session_config.py @@ -3,30 +3,34 @@ from libcst.metadata import ParentNodeProvider, PositionProvider from libcst import matchers -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import BaseCodemod from codemodder.codemods.utils_mixin import NameResolutionMixin from codemodder.utils.utils import extract_targets_of_assignment, true_value from codemodder.codemods.base_visitor import BaseTransformer from codemodder.change import Change from codemodder.file_context import FileContext - - -class SecureFlaskSessionConfig(BaseCodemod, Codemod): - NAME = "secure-flask-session-configuration" - SUMMARY = "Flip Insecure `Flask` Session Configurations" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_REVIEW - DESCRIPTION = "Flip Flask session configuration if defined as insecure." - REFERENCES = [ - { - "url": "https://owasp.org/www-community/controls/SecureCookieAttribute", - "description": "", - }, - { - "url": "https://cheatsheetseries.owasp.org/cheatsheets/Session_Management_Cheat_Sheet.html", - "description": "", - }, - ] +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) + + +class SecureFlaskSessionConfig(SimpleCodemod, Codemod): + metadata = Metadata( + name="secure-flask-session-configuration", + summary="Flip Insecure `Flask` Session Configurations", + review_guidance=ReviewGuidance.MERGE_AFTER_REVIEW, + references=[ + Reference( + url="https://owasp.org/www-community/controls/SecureCookieAttribute" + ), + Reference( + url="https://cheatsheetseries.owasp.org/cheatsheets/Session_Management_Cheat_Sheet.html" + ), + ], + ) + change_description = "Flip Flask session configuration if defined as insecure." def transform_module_impl(self, tree: cst.Module) -> cst.Module: flask_codemod = FixFlaskConfig(self.context, self.file_context) @@ -191,5 +195,5 @@ def _is_config_subscript(self, original_node: cst.Assign): def report_change(self, original_node): line_number = self.lineno_for_node(original_node) self.file_context.codemod_changes.append( - Change(line_number, SecureFlaskSessionConfig.CHANGE_DESCRIPTION) + Change(line_number, SecureFlaskSessionConfig.change_description) ) diff --git a/src/core_codemods/secure_random.py b/src/core_codemods/secure_random.py index b1b6944c..2761ae00 100644 --- a/src/core_codemods/secure_random.py +++ b/src/core_codemods/secure_random.py @@ -1,30 +1,37 @@ -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod +from core_codemods.api import ( + SimpleCodemod, + Metadata, + Reference, + ReviewGuidance, +) -class SecureRandom(SemgrepCodemod): - NAME = "secure-random" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - SUMMARY = "Secure Source of Randomness" - DESCRIPTION = "Replaces random.{func} with more secure secrets library functions." - REFERENCES = [ - { - "url": "https://owasp.org/www-community/vulnerabilities/Insecure_Randomness", - "description": "", - }, - {"url": "https://docs.python.org/3/library/random.html", "description": ""}, - ] +class SecureRandom(SimpleCodemod): + metadata = Metadata( + name="secure-random", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + summary="Secure Source of Randomness", + references=[ + Reference( + url="https://owasp.org/www-community/vulnerabilities/Insecure_Randomness", + ), + Reference( + url="https://docs.python.org/3/library/random.html", + ), + ], + ) - @classmethod - def rule(cls): - return """ - rules: - - patterns: - - pattern: random.$F(...) - - pattern-inside: | - import random - ... - """ + detector_pattern = """ + - patterns: + - pattern: random.$F(...) + - pattern-inside: | + import random + ... + """ + + change_description = ( + "Replace random.{func} with more secure secrets library functions." + ) def on_result_found(self, original_node, updated_node): self.remove_unused_import(original_node) diff --git a/src/core_codemods/semgrep/sandbox_url_creation.yaml b/src/core_codemods/semgrep/sandbox_url_creation.yaml deleted file mode 100644 index aa793bc2..00000000 --- a/src/core_codemods/semgrep/sandbox_url_creation.yaml +++ /dev/null @@ -1,13 +0,0 @@ -rules: - - id: url-sandbox - message: Unbounded URL creation - severity: WARNING - languages: - - python - pattern-either: - - patterns: - - pattern: requests.get(...) - - pattern-not: requests.get("...") - - pattern-inside: | - import requests - ... diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index 3619111f..df115eb1 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -1,6 +1,7 @@ import re from typing import Any, Optional, Tuple import itertools + import libcst as cst from libcst import ( FormattedString, @@ -10,7 +11,6 @@ matchers, ) from libcst.codemod import ( - Codemod, CodemodContext, ContextAwareTransformer, ContextAwareVisitor, @@ -22,13 +22,14 @@ PositionProvider, ScopeProvider, ) -from codemodder.change import Change -from codemodder.codemods.base_codemod import ( - BaseCodemod, - CodemodMetadata, +from core_codemods.api import ( + SimpleCodemod, + Metadata, + Reference, ReviewGuidance, ) +from codemodder.change import Change from codemodder.codemods.base_visitor import UtilsMixin from codemodder.codemods.transformations.remove_empty_string_concatenation import ( RemoveEmptyStringConcatenation, @@ -41,7 +42,6 @@ infer_expression_type, ) from codemodder.codemods.utils_mixin import NameResolutionMixin -from codemodder.file_context import FileContext parameter_token = "?" @@ -49,24 +49,17 @@ raw_quote_pattern = re.compile(r"(? None: self.changed_nodes: dict[ cst.CSTNode, cst.CSTNode | cst.RemovalSentinel | cst.FlattenSentinel | dict[str, Any], ] = {} - BaseCodemod.__init__(self, file_context, *codemod_args) + SimpleCodemod.__init__(self, *codemod_args, **codemod_kwargs) UtilsMixin.__init__(self, []) - Codemod.__init__(self, context) def _build_param_element(self, prepend, middle, append): new_middle = ( @@ -166,7 +157,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: self.changed_nodes = {} line_number = self.get_metadata(PositionProvider, call).start.line self.file_context.codemod_changes.append( - Change(line_number, SQLQueryParameterization.CHANGE_DESCRIPTION) + Change(line_number, SQLQueryParameterization.change_description) ) # Normalization and cleanup result = result.visit(RemoveEmptyStringConcatenation()) diff --git a/src/core_codemods/subprocess_shell_false.py b/src/core_codemods/subprocess_shell_false.py index 536d651b..b9d40c05 100644 --- a/src/core_codemods/subprocess_shell_false.py +++ b/src/core_codemods/subprocess_shell_false.py @@ -1,29 +1,31 @@ import libcst as cst from libcst import matchers -from codemodder.codemods.api import BaseCodemod, ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin -from codemodder.codemods.api.helpers import NewArg +from codemodder.codemods.libcst_transformer import NewArg +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class SubprocessShellFalse(BaseCodemod, NameResolutionMixin): - NAME = "subprocess-shell-false" - SUMMARY = "Use `shell=False` in `subprocess` Function Calls" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - DESCRIPTION = "Set `shell` keyword argument to `False`" - REFERENCES = [ - { - "url": "https://docs.python.org/3/library/subprocess.html#security-considerations", - "description": "", - }, - { - "url": "https://en.wikipedia.org/wiki/Code_injection#Shell_injection", - "description": "", - }, - { - "url": "https://stackoverflow.com/a/3172488", - "description": "", - }, - ] +class SubprocessShellFalse(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="subprocess-shell-false", + summary="Use `shell=False` in `subprocess` Function Calls", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/library/subprocess.html#security-considerations" + ), + Reference( + url="https://en.wikipedia.org/wiki/Code_injection#Shell_injection" + ), + Reference(url="https://stackoverflow.com/a/3172488"), + ], + ) + change_description = "Set `shell` keyword argument to `False`" SUBPROCESS_FUNCS = [ f"subprocess.{func}" for func in {"run", "call", "check_output", "check_call", "Popen"} diff --git a/src/core_codemods/tempfile_mktemp.py b/src/core_codemods/tempfile_mktemp.py index 8de39ac1..c5a00fdb 100644 --- a/src/core_codemods/tempfile_mktemp.py +++ b/src/core_codemods/tempfile_mktemp.py @@ -1,25 +1,27 @@ -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class TempfileMktemp(SemgrepCodemod, NameResolutionMixin): - NAME = "secure-tempfile" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Upgrade and Secure Temp File Creation" - DESCRIPTION = "Replaces `tempfile.mktemp` with `tempfile.mkstemp`." - REFERENCES = [ - { - "url": "https://docs.python.org/3/library/tempfile.html#tempfile.mktemp", - "description": "", - } - ] +class TempfileMktemp(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="secure-tempfile", + summary="Upgrade and Secure Temp File Creation", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/library/tempfile.html#tempfile.mktemp" + ), + ], + ) + change_description = "Replaces `tempfile.mktemp` with `tempfile.mkstemp`." _module_name = "tempfile" - - @classmethod - def rule(cls): - return """ + detector_pattern = """ rules: - patterns: - pattern: tempfile.mktemp(...) diff --git a/src/core_codemods/upgrade_sslcontext_minimum_version.py b/src/core_codemods/upgrade_sslcontext_minimum_version.py index 85ee1a85..cddf1dcd 100644 --- a/src/core_codemods/upgrade_sslcontext_minimum_version.py +++ b/src/core_codemods/upgrade_sslcontext_minimum_version.py @@ -1,30 +1,29 @@ -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class UpgradeSSLContextMinimumVersion(SemgrepCodemod, NameResolutionMixin): - NAME = "upgrade-sslcontext-minimum-version" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - SUMMARY = "Upgrade SSLContext Minimum Version" - DESCRIPTION = "Replaces minimum SSL/TLS version for SSLContext." - REFERENCES = [ - { - "url": "https://docs.python.org/3/library/ssl.html#security-considerations", - "description": "", - }, - {"url": "https://datatracker.ietf.org/doc/rfc8996/", "description": ""}, - { - "url": "https://www.digicert.com/blog/depreciating-tls-1-0-and-1-1", - "description": "", - }, - ] +class UpgradeSSLContextMinimumVersion(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="upgrade-sslcontext-minimum-version", + summary="Upgrade SSLContext Minimum Version", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/library/ssl.html#security-considerations" + ), + Reference(url="https://datatracker.ietf.org/doc/rfc8996/"), + Reference(url="https://www.digicert.com/blog/depreciating-tls-1-0-and-1-1"), + ], + ) + change_description = "Replaces minimum SSL/TLS version for SSLContext." _module_name = "ssl" - - @classmethod - def rule(cls): - return """ + detector_pattern = """ rules: - mode: taint pattern-sources: diff --git a/src/core_codemods/upgrade_sslcontext_tls.py b/src/core_codemods/upgrade_sslcontext_tls.py index 12187838..b3f50d33 100644 --- a/src/core_codemods/upgrade_sslcontext_tls.py +++ b/src/core_codemods/upgrade_sslcontext_tls.py @@ -1,25 +1,27 @@ -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod -from codemodder.codemods.api.helpers import NewArg +from codemodder.codemods.libcst_transformer import NewArg +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class UpgradeSSLContextTLS(SemgrepCodemod): - NAME = "upgrade-sslcontext-tls" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - SUMMARY = "Upgrade TLS Version In SSLContext" - DESCRIPTION = "Replaces known insecure TLS/SSL protocol versions in SSLContext with secure ones." - CHANGE_DESCRIPTION = "Upgrade to use a safe version of TLS in SSLContext" - REFERENCES = [ - { - "url": "https://docs.python.org/3/library/ssl.html#security-considerations", - "description": "", - }, - {"url": "https://datatracker.ietf.org/doc/rfc8996/", "description": ""}, - { - "url": "https://www.digicert.com/blog/depreciating-tls-1-0-and-1-1", - "description": "", - }, - ] +class UpgradeSSLContextTLS(SimpleCodemod): + metadata = Metadata( + name="upgrade-sslcontext-tls", + summary="Upgrade TLS Version In SSLContext", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/library/ssl.html#security-considerations" + ), + Reference(url="https://datatracker.ietf.org/doc/rfc8996/"), + Reference(url="https://www.digicert.com/blog/depreciating-tls-1-0-and-1-1"), + ], + ) + change_description = "Replaces known insecure TLS/SSL protocol versions in SSLContext with secure ones." + change_description = "Upgrade to use a safe version of TLS in SSLContext" # TODO: in the majority of cases, using PROTOCOL_TLS_CLIENT will be the # right fix. However in some cases it will be appropriate to use @@ -27,10 +29,7 @@ class UpgradeSSLContextTLS(SemgrepCodemod): # this. Eventually, when the platform supports parameters, we want to # revisit this to provide PROTOCOL_TLS_SERVER as an alternative fix. SAFE_TLS_PROTOCOL_VERSION = "ssl.PROTOCOL_TLS_CLIENT" - - @classmethod - def rule(cls): - return """ + detector_pattern = """ rules: - patterns: - pattern-inside: | diff --git a/src/core_codemods/url_sandbox.py b/src/core_codemods/url_sandbox.py index 85213384..2893af4a 100644 --- a/src/core_codemods/url_sandbox.py +++ b/src/core_codemods/url_sandbox.py @@ -2,14 +2,15 @@ import libcst as cst from libcst import CSTNode, matchers -from libcst.codemod import Codemod, CodemodContext +from libcst.codemod import CodemodContext from libcst.metadata import PositionProvider, ScopeProvider from libcst.codemod.visitors import AddImportsVisitor, ImportItem from codemodder.change import Change -from codemodder.codemods.base_codemod import ( - SemgrepCodemod, - CodemodMetadata, +from core_codemods.api import ( + SimpleCodemod, + Metadata, + Reference, ReviewGuidance, ) from codemodder.codemods.base_visitor import BaseVisitor @@ -24,47 +25,47 @@ replacement_import = "safe_requests" -class UrlSandbox(SemgrepCodemod, Codemod): - METADATA = CodemodMetadata( - DESCRIPTION=( - "Replaces request.{func} with more secure safe_request library functions." - ), - NAME="url-sandbox", - REVIEW_GUIDANCE=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, - REFERENCES=[ - { - "url": "https://github.com/pixee/python-security/blob/main/src/security/safe_requests/api.py", - "description": "", - }, - {"url": "https://portswigger.net/web-security/ssrf", "description": ""}, - { - "url": "https://cheatsheetseries.owasp.org/cheatsheets/Server_Side_Request_Forgery_Prevention_Cheat_Sheet.html", - "description": "", - }, - { - "url": "https://www.rapid7.com/blog/post/2021/11/23/owasp-top-10-deep-dive-defending-against-server-side-request-forgery/", - "description": "", - }, - { - "url": "https://blog.assetnote.io/2021/01/13/blind-ssrf-chains/", - "description": "", - }, +class UrlSandbox(SimpleCodemod): + metadata = Metadata( + name="url-sandbox", + summary="Sandbox URL Creation", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference( + url="https://github.com/pixee/python-security/blob/main/src/security/safe_requests/api.py" + ), + Reference(url="https://portswigger.net/web-security/ssrf"), + Reference( + url="https://cheatsheetseries.owasp.org/cheatsheets/Server_Side_Request_Forgery_Prevention_Cheat_Sheet.html" + ), + Reference( + url="https://www.rapid7.com/blog/post/2021/11/23/owasp-top-10-deep-dive-defending-against-server-side-request-forgery/" + ), + Reference(url="https://blog.assetnote.io/2021/01/13/blind-ssrf-chains/"), ], ) - SUMMARY = "Sandbox URL Creation" - CHANGE_DESCRIPTION = "Switch use of requests for security.safe_requests" - YAML_FILES = [ - "sandbox_url_creation.yaml", - ] + change_description = "Switch use of requests for security.safe_requests" + + detector_pattern = """ + rules: + - id: url-sandbox + message: Unbounded URL creation + severity: WARNING + languages: + - python + pattern-either: + - patterns: + - pattern: requests.get(...) + - pattern-not: requests.get("...") + - pattern-inside: | + import requests + ... + """ METADATA_DEPENDENCIES = (PositionProvider, ScopeProvider) adds_dependency = True - def __init__(self, codemod_context: CodemodContext, *args): - Codemod.__init__(self, codemod_context) - SemgrepCodemod.__init__(self, *args) - def transform_module_impl(self, tree: cst.Module) -> cst.Module: # we first gather all the nodes we want to change together with their replacements find_requests_visitor = FindRequestCallsAndImports( @@ -142,7 +143,7 @@ def leave_Call(self, original_node: cst.Call): } ) self.changes_in_file.append( - Change(line_number, UrlSandbox.CHANGE_DESCRIPTION) + Change(line_number, UrlSandbox.change_description) ) # case req.get(...) @@ -156,7 +157,7 @@ def leave_Call(self, original_node: cst.Call): } ) self.changes_in_file.append( - Change(line_number, UrlSandbox.CHANGE_DESCRIPTION) + Change(line_number, UrlSandbox.change_description) ) def _find_assignments(self, node: CSTNode): diff --git a/src/core_codemods/use_defused_xml.py b/src/core_codemods/use_defused_xml.py index 498ec3e2..883d7fb2 100644 --- a/src/core_codemods/use_defused_xml.py +++ b/src/core_codemods/use_defused_xml.py @@ -3,9 +3,12 @@ import libcst as cst from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor - -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import BaseCodemod +from core_codemods.api import ( + SimpleCodemod, + Metadata, + Reference, + ReviewGuidance, +) from codemodder.codemods.imported_call_modifier import ImportedCallModifier from codemodder.dependency import DefusedXML @@ -42,31 +45,26 @@ def update_simple_name(self, true_name, original_node, updated_node, new_args): # TODO: add expat methods? -class UseDefusedXml(BaseCodemod): - NAME = "use-defusedxml" - SUMMARY = "Use `defusedxml` for Parsing XML" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_REVIEW - DESCRIPTION = "Replace builtin xml method with safe defusedxml method" - REFERENCES = [ - { - "url": "https://docs.python.org/3/library/xml.html#xml-vulnerabilities", - "description": "", - }, - { - "url": "https://docs.python.org/3/library/xml.html#the-defusedxml-package", - "description": "", - }, - { - "url": "https://pypi.org/project/defusedxml/", - "description": "", - }, - { - "url": "https://cheatsheetseries.owasp.org/cheatsheets/XML_External_Entity_Prevention_Cheat_Sheet.html", - "description": "", - }, - ] +class UseDefusedXml(SimpleCodemod): + metadata = Metadata( + name="use-defusedxml", + summary="Use `defusedxml` for Parsing XML", + review_guidance=ReviewGuidance.MERGE_AFTER_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/library/xml.html#xml-vulnerabilities" + ), + Reference( + url="https://docs.python.org/3/library/xml.html#the-defusedxml-package" + ), + Reference(url="https://pypi.org/project/defusedxml/"), + Reference( + url="https://cheatsheetseries.owasp.org/cheatsheets/XML_External_Entity_Prevention_Cheat_Sheet.html" + ), + ], + ) - adds_dependency = True + change_description = "Replace builtin XML method with safe `defusedxml` method" @cached_property def matching_functions(self) -> dict[str, str]: @@ -89,7 +87,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: self.context, self.file_context, self.matching_functions, - self.CHANGE_DESCRIPTION, + self.change_description, ) result_tree = visitor.transform_module(tree) self.file_context.codemod_changes.extend(visitor.changes_in_file) diff --git a/src/core_codemods/use_generator.py b/src/core_codemods/use_generator.py index 6b80e196..4a451a06 100644 --- a/src/core_codemods/use_generator.py +++ b/src/core_codemods/use_generator.py @@ -1,28 +1,31 @@ import libcst as cst - -from codemodder.codemods.api import BaseCodemod, ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class UseGenerator(BaseCodemod, NameResolutionMixin): - NAME = "use-generator" - SUMMARY = "Use Generator Expressions Instead of List Comprehensions" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Replace list comprehension with generator expression" - REFERENCES = [ - { - "url": "https://pylint.readthedocs.io/en/latest/user_guide/messages/refactor/use-a-generator.html", - "description": "", - }, - { - "url": "https://docs.python.org/3/glossary.html#term-generator-expression", - "description": "", - }, - { - "url": "https://docs.python.org/3/glossary.html#term-list-comprehension", - "description": "", - }, - ] +class UseGenerator(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="use-generator", + summary="Use Generator Expressions Instead of List Comprehensions", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://pylint.readthedocs.io/en/latest/user_guide/messages/refactor/use-a-generator.html" + ), + Reference( + url="https://docs.python.org/3/glossary.html#term-generator-expression" + ), + Reference( + url="https://docs.python.org/3/glossary.html#term-list-comprehension" + ), + ], + ) + change_description = "Replace list comprehension with generator expression" def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): if not self.filter_by_path_includes_or_excludes( @@ -37,7 +40,7 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): if self.is_builtin_function(original_node): match original_node.args[0].value: case cst.ListComp(elt=elt, for_in=for_in): - self.add_change(original_node, self.CHANGE_DESCRIPTION) + self.add_change(original_node, self.change_description) return updated_node.with_changes( args=[ cst.Arg( diff --git a/src/core_codemods/use_set_literal.py b/src/core_codemods/use_set_literal.py index f4ef023a..eb1dfbdc 100644 --- a/src/core_codemods/use_set_literal.py +++ b/src/core_codemods/use_set_literal.py @@ -1,14 +1,16 @@ import libcst as cst - -from codemodder.codemods.api import BaseCodemod, ReviewGuidance from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod -class UseSetLiteral(BaseCodemod, NameResolutionMixin): - NAME = "use-set-literal" - SUMMARY = "Use Set Literals Instead of Sets from Lists" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Replace sets from lists with set literals" +class UseSetLiteral(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="use-set-literal", + summary="Use Set Literals Instead of Sets from Lists", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[], + ) + change_description = "Replace sets from lists with set literals" REFERENCES: list = [] def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): diff --git a/src/core_codemods/use_walrus_if.py b/src/core_codemods/use_walrus_if.py index e2866df5..7c5e4465 100644 --- a/src/core_codemods/use_walrus_if.py +++ b/src/core_codemods/use_walrus_if.py @@ -6,9 +6,12 @@ from libcst._position import CodeRange from libcst import matchers as m from libcst.metadata import ParentNodeProvider, ScopeProvider - -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import BaseCodemod +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) FoundAssign = namedtuple("FoundAssign", ["assign", "target", "value"]) @@ -20,23 +23,24 @@ def pairwise(iterable): return zip(a, b) -class UseWalrusIf(BaseCodemod): - METADATA_DEPENDENCIES = BaseCodemod.METADATA_DEPENDENCIES + ( - ParentNodeProvider, - ScopeProvider, +class UseWalrusIf(SimpleCodemod): + metadata = Metadata( + name="use-walrus-if", + summary="Use Assignment Expression (Walrus) In Conditional", + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, + references=[ + Reference( + url="https://docs.python.org/3/whatsnew/3.8.html#assignment-expressions" + ), + ], ) - NAME = "use-walrus-if" - SUMMARY = "Use Assignment Expression (Walrus) In Conditional" - REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW - DESCRIPTION = ( + change_description = ( "Replaces multiple expressions involving `if` operator with 'walrus' operator." ) - REFERENCES = [ - { - "url": "https://docs.python.org/3/whatsnew/3.8.html#assignment-expressions", - "description": "", - } - ] + METADATA_DEPENDENCIES = SimpleCodemod.METADATA_DEPENDENCIES + ( + ParentNodeProvider, + ScopeProvider, + ) _modify_next_if: List[Tuple[CodeRange, cst.NamedExpr]] _if_stack: List[Optional[Tuple[CodeRange, cst.NamedExpr]]] @@ -123,7 +127,7 @@ def leave_If(self, original_node, updated_node): if (result := self._if_stack.pop()) is not None: position, named_expr = result is_name = m.matches(updated_node.test, m.Name()) - self.add_change_from_position(position, self.CHANGE_DESCRIPTION) + self.add_change_from_position(position, self.change_description) return ( updated_node.with_changes(test=named_expr) if is_name diff --git a/src/core_codemods/with_threading_lock.py b/src/core_codemods/with_threading_lock.py index e7ad02b4..1ffb53a1 100644 --- a/src/core_codemods/with_threading_lock.py +++ b/src/core_codemods/with_threading_lock.py @@ -1,30 +1,31 @@ import libcst as cst -from codemodder.codemods.base_codemod import ReviewGuidance -from codemodder.codemods.api import SemgrepCodemod from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import ( + Metadata, + Reference, + ReviewGuidance, + SimpleCodemod, +) -class WithThreadingLock(SemgrepCodemod, NameResolutionMixin): - NAME = "bad-lock-with-statement" - SUMMARY = "Separate Lock Instantiation from `with` Call" - DESCRIPTION = ( +class WithThreadingLock(SimpleCodemod, NameResolutionMixin): + metadata = Metadata( + name="bad-lock-with-statement", + summary="Separate Lock Instantiation from `with` Call", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://pylint.pycqa.org/en/latest/user_guide/messages/warning/useless-with-lock." + ), + Reference( + url="https://docs.python.org/3/library/threading.html#using-locks-conditions-and-semaphores-in-the-with-statement" + ), + ], + ) + change_description = ( "Replace deprecated usage of threading lock classes as context managers." ) - REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - REFERENCES = [ - { - "url": "https://pylint.pycqa.org/en/latest/user_guide/messages/warning/useless-with-lock.", - "description": "", - }, - { - "url": "https://docs.python.org/3/library/threading.html#using-locks-conditions-and-semaphores-in-the-with-statement", - "description": "", - }, - ] - - @classmethod - def rule(cls): - return """ + detector_pattern = """ rules: - patterns: - pattern: | @@ -45,8 +46,8 @@ def rule(cls): - focus-metavariable: $BODY """ - def __init__(self, *args): - SemgrepCodemod.__init__(self, *args) + def __init__(self, *args, **kwargs): + SimpleCodemod.__init__(self, *args, **kwargs) NameResolutionMixin.__init__(self) self.names_in_module = self.find_used_names_in_module() @@ -82,7 +83,7 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With): ) ] ) - # TODO: add result + self.add_change(original_node, self.change_description) return cst.FlattenSentinel( [ assign, diff --git a/tests/codemods/base_codemod_test.py b/tests/codemods/base_codemod_test.py index 74aa3e06..a5a6db7d 100644 --- a/tests/codemods/base_codemod_test.py +++ b/tests/codemods/base_codemod_test.py @@ -4,13 +4,10 @@ from textwrap import dedent from typing import ClassVar -import libcst as cst -from libcst.codemod import CodemodContext import mock from codemodder.context import CodemodExecutionContext -from codemodder.dependency import Dependency -from codemodder.file_context import FileContext +from codemodder.diff import create_diff from codemodder.registry import CodemodRegistry, CodemodCollection from codemodder.semgrep import run as semgrep_run @@ -19,69 +16,88 @@ class BaseCodemodTest: codemod: ClassVar = NotImplemented def setup_method(self): + if isinstance(self.codemod, type): + self.codemod = self.codemod() + self.file_context = None - def initialize_codemod(self, input_tree): - wrapper = cst.MetadataWrapper(input_tree) - codemod_instance = self.codemod( - CodemodContext(wrapper=wrapper), - self.file_context, - ) - return codemod_instance + def run_and_assert( # pylint: disable=too-many-arguments + self, + tmpdir, + input_code, + expected, + num_changes: int = 1, + root: Path | None = None, + files: list[Path] | None = None, + lines_to_exclude: list[int] | None = None, + ): + root = root or tmpdir + tmp_file_path = files[0] if files else Path(tmpdir) / "code.py" + tmp_file_path.write_text(dedent(input_code)) - def run_and_assert(self, tmpdir, input_code, expected): - tmp_file_path = Path(tmpdir / "code.py") - self.run_and_assert_filepath(tmpdir, tmp_file_path, input_code, expected) + files_to_check = files or [tmp_file_path] + + path_exclude = [f"{tmp_file_path}:{line}" for line in lines_to_exclude or []] - def assert_no_change_line_excluded( - self, tmpdir, input_code, expected, lines_to_exclude - ): - tmp_file_path = Path(tmpdir / "code.py") - input_tree = cst.parse_module(dedent(input_code)) self.execution_context = CodemodExecutionContext( - directory=tmpdir, - dry_run=True, + directory=root, + dry_run=False, verbose=False, registry=mock.MagicMock(), repo_manager=mock.MagicMock(), + path_include=[f.name for f in files_to_check], + path_exclude=path_exclude, ) - self.file_context = FileContext( + self.codemod.apply(self.execution_context, files_to_check) + changes = self.execution_context.get_results(self.codemod.id) + + if input_code == expected: + assert not changes + return + + assert len(changes) == 1 + assert len(changes[0].changes) == num_changes + + self.assert_changes( tmpdir, tmp_file_path, - lines_to_exclude, - [], - [], + input_code, + expected, + changes[0], ) - codemod_instance = self.initialize_codemod(input_tree) - output_tree = codemod_instance.transform_module(input_tree) - - assert output_tree.code == dedent(expected) - assert len(self.file_context.codemod_changes) == 0 - def run_and_assert_filepath(self, root, file_path, input_code, expected): - input_tree = cst.parse_module(dedent(input_code)) - self.execution_context = CodemodExecutionContext( - directory=root, - dry_run=True, - verbose=False, - registry=mock.MagicMock(), - repo_manager=mock.MagicMock(), - ) - self.file_context = FileContext( - root, - file_path, - [], - [], - [], + def assert_changes( # pylint: disable=too-many-arguments + self, root, file_path, input_code, expected, changes + ): + expected_diff = create_diff( + dedent(input_code).splitlines(keepends=True), + dedent(expected).splitlines(keepends=True), ) - codemod_instance = self.initialize_codemod(input_tree) - output_tree = codemod_instance.transform_module(input_tree) - assert output_tree.code == dedent(expected) + assert expected_diff == changes.diff + assert os.path.relpath(file_path, root) == changes.path + + with open(file_path, "r", encoding="utf-8") as tmp_file: + output_code = tmp_file.read() + + assert output_code == dedent(expected) - def assert_dependency(self, dependency: Dependency): - assert self.file_context and self.file_context.dependencies == set([dependency]) + def run_and_assert_filepath( # pylint: disable=too-many-arguments + self, + root: Path, + file_path: Path, + input_code: str, + expected: str, + num_changes: int = 1, + ): + self.run_and_assert( + tmpdir=root, + input_code=input_code, + expected=expected, + num_changes=num_changes, + files=[file_path], + ) class BaseSemgrepCodemodTest(BaseCodemodTest): @@ -100,35 +116,12 @@ def results_by_id_filepath(self, input_code, file_path): with open(file_path, "w", encoding="utf-8") as tmp_file: tmp_file.write(dedent(input_code)) - name = self.codemod.name() + name = self.codemod.name results = self.registry.match_codemods(codemod_include=[name]) return semgrep_run(self.execution_context, results[0].yaml_files) - def run_and_assert_filepath(self, root, file_path, input_code, expected): - self.execution_context = CodemodExecutionContext( - directory=root, - dry_run=True, - verbose=False, - registry=mock.MagicMock(), - repo_manager=mock.MagicMock(), - ) - input_tree = cst.parse_module(dedent(input_code)) - all_results = self.results_by_id_filepath(input_code, file_path) - results = all_results.results_for_rule_and_file(self.codemod.name(), file_path) - self.file_context = FileContext( - root, - file_path, - [], - [], - results, - ) - codemod_instance = self.initialize_codemod(input_tree) - output_tree = codemod_instance.transform_module(input_tree) - - assert output_tree.code == dedent(expected) - -class BaseDjangoCodemodTest(BaseSemgrepCodemodTest): +class BaseDjangoCodemodTest(BaseCodemodTest): def create_dir_structure(self, tmpdir): django_root = Path(tmpdir) / "mysite" settings_folder = django_root / "mysite" diff --git a/tests/codemods/conftest.py b/tests/codemods/conftest.py new file mode 100644 index 00000000..6aab38f6 --- /dev/null +++ b/tests/codemods/conftest.py @@ -0,0 +1,15 @@ +import pytest + + +@pytest.fixture(autouse=True) +def disable_semgrep_run(): + """ + Override the fixture defined in conftest.py + """ + + +@pytest.fixture(autouse=True) +def disable_update_code(): + """ + Override the fixture defined in conftest.py + """ diff --git a/tests/codemods/test_add_requests_timeouts.py b/tests/codemods/test_add_requests_timeouts.py index 196070f2..9382de8c 100644 --- a/tests/codemods/test_add_requests_timeouts.py +++ b/tests/codemods/test_add_requests_timeouts.py @@ -1,10 +1,13 @@ import pytest -from core_codemods.add_requests_timeouts import AddRequestsTimeouts +from core_codemods.add_requests_timeouts import ( + AddRequestsTimeouts, + TransformAddRequestsTimeouts, +) from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest METHODS = ["get", "post", "put", "delete", "head", "options", "patch"] -TIMEOUT = AddRequestsTimeouts.DEFAULT_TIMEOUT +TIMEOUT = TransformAddRequestsTimeouts.DEFAULT_TIMEOUT class TestAddRequestsTimeouts(BaseSemgrepCodemodTest): diff --git a/tests/codemods/test_base_codemod.py b/tests/codemods/test_base_codemod.py index 033e8950..662f7e0b 100644 --- a/tests/codemods/test_base_codemod.py +++ b/tests/codemods/test_base_codemod.py @@ -1,26 +1,20 @@ import libcst as cst -from libcst.codemod import Codemod, CodemodContext +from libcst.codemod import CodemodContext import mock -from codemodder.codemods.base_codemod import ( - SemgrepCodemod, - CodemodMetadata, +from codemodder.codemods.api import ( + SimpleCodemod, + Metadata, ReviewGuidance, ) -class DoNothingCodemod(SemgrepCodemod, Codemod): - METADATA = CodemodMetadata( - DESCRIPTION="An identity codemod for testing purposes.", - NAME="do-nothing", - REVIEW_GUIDANCE=ReviewGuidance.MERGE_WITHOUT_REVIEW, +class DoNothingCodemod(SimpleCodemod): + metadata = Metadata( + name="do-nothing", + summary="An identity codemod for testing purposes.", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, ) - SUMMARY = "An identity codemod for testing purposes." - YAML_FILES = [] - - def __init__(self, codemod_context: CodemodContext, *args): - Codemod.__init__(self, codemod_context) - SemgrepCodemod.__init__(self, *args) def transform_module_impl(self, tree: cst.Module) -> cst.Module: return tree @@ -32,6 +26,8 @@ def run_and_assert(self, input_code, expected_output): command_instance = DoNothingCodemod( CodemodContext(), mock.MagicMock(), + mock.MagicMock(), + _transformer=True, ) output_tree = command_instance.transform_module(input_tree) diff --git a/tests/codemods/test_combine_startswith_endswith.py b/tests/codemods/test_combine_startswith_endswith.py index 6ce21923..be8ec368 100644 --- a/tests/codemods/test_combine_startswith_endswith.py +++ b/tests/codemods/test_combine_startswith_endswith.py @@ -9,7 +9,7 @@ class TestCombineStartswithEndswith(BaseCodemodTest): codemod = CombineStartswithEndswith def test_name(self): - assert self.codemod.name() == "combine-startswith-endswith" + assert self.codemod.name == "combine-startswith-endswith" @each_func def test_combine(self, tmpdir, func): @@ -22,7 +22,6 @@ def test_combine(self, tmpdir, func): x.{func}(("foo", "f")) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 @pytest.mark.parametrize( "code", @@ -37,7 +36,6 @@ def test_combine(self, tmpdir, func): ) def test_no_change(self, tmpdir, code): self.run_and_assert(tmpdir, code, code) - assert len(self.file_context.codemod_changes) == 0 def test_exclude_line(self, tmpdir): input_code = expected = """\ @@ -45,6 +43,9 @@ def test_exclude_line(self, tmpdir): x.startswith("foo") or x.startswith("f") """ lines_to_exclude = [2] - self.assert_no_change_line_excluded( - tmpdir, input_code, expected, lines_to_exclude + self.run_and_assert( + tmpdir, + input_code, + expected, + lines_to_exclude=lines_to_exclude, ) diff --git a/tests/codemods/test_django_debug_flag_on.py b/tests/codemods/test_django_debug_flag_on.py index b66b603f..c539665d 100644 --- a/tests/codemods/test_django_debug_flag_on.py +++ b/tests/codemods/test_django_debug_flag_on.py @@ -6,7 +6,7 @@ class TestDjangoDebugFlagOn(BaseDjangoCodemodTest): codemod = DjangoDebugFlagOn def test_name(self): - assert self.codemod.name() == "django-debug-flag-on" + assert self.codemod.name == "django-debug-flag-on" def test_settings_dot_py(self, tmpdir): django_root, settings_folder = self.create_dir_structure(tmpdir) @@ -15,7 +15,6 @@ def test_settings_dot_py(self, tmpdir): input_code = """DEBUG = True""" expected = """DEBUG = False""" self.run_and_assert_filepath(django_root, file_path, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_not_settings_dot_py(self, tmpdir): django_root, settings_folder = self.create_dir_structure(tmpdir) @@ -24,7 +23,6 @@ def test_not_settings_dot_py(self, tmpdir): input_code = """DEBUG = True""" expected = input_code self.run_and_assert_filepath(django_root, file_path, input_code, expected) - assert len(self.file_context.codemod_changes) == 0 def test_no_manage_dot_py(self, tmpdir): django_root, settings_folder = self.create_dir_structure(tmpdir) @@ -32,4 +30,3 @@ def test_no_manage_dot_py(self, tmpdir): input_code = """DEBUG = True""" expected = input_code self.run_and_assert_filepath(django_root, file_path, input_code, expected) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_django_json_response_type.py b/tests/codemods/test_django_json_response_type.py index 29487000..afb99f5d 100644 --- a/tests/codemods/test_django_json_response_type.py +++ b/tests/codemods/test_django_json_response_type.py @@ -7,7 +7,7 @@ class TestDjangoJsonResponseType(BaseSemgrepCodemodTest): codemod = DjangoJsonResponseType def test_name(self): - assert self.codemod.name() == "django-json-response-type" + assert self.codemod.name == "django-json-response-type" def test_simple(self, tmpdir): input_code = """\ @@ -27,7 +27,6 @@ def foo(request): return HttpResponse(json_response, content_type="application/json") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_alias(self, tmpdir): input_code = """\ @@ -47,7 +46,6 @@ def foo(request): return response(json_response, content_type="application/json") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_direct(self, tmpdir): input_code = """\ @@ -65,7 +63,6 @@ def foo(request): return HttpResponse(json.dumps({ "user_input": request.GET.get("input") }), content_type="application/json") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_content_type_set(self, tmpdir): input_code = """\ @@ -77,7 +74,6 @@ def foo(request): return HttpResponse(json_response, content_type='application/json') """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_no_json_input(self, tmpdir): input_code = """\ @@ -89,4 +85,3 @@ def foo(request): return HttpResponse(dict_response) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_django_receiver_on_top.py b/tests/codemods/test_django_receiver_on_top.py index 95042ce4..e83f6468 100644 --- a/tests/codemods/test_django_receiver_on_top.py +++ b/tests/codemods/test_django_receiver_on_top.py @@ -7,7 +7,7 @@ class TestDjangoReceiverOnTop(BaseCodemodTest): codemod = DjangoReceiverOnTop def test_name(self): - assert self.codemod.name() == "django-receiver-on-top" + assert self.codemod.name == "django-receiver-on-top" def test_simple(self, tmpdir): input_code = """\ @@ -27,7 +27,6 @@ def foo(): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_alias(self, tmpdir): input_code = """\ @@ -47,7 +46,6 @@ def foo(): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_no_receiver(self, tmpdir): input_code = """\ @@ -56,7 +54,6 @@ def foo(): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_receiver_but_not_djangos(self, tmpdir): input_code = """\ @@ -68,7 +65,6 @@ def foo(): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_receiver_on_top(self, tmpdir): input_code = """\ @@ -80,4 +76,3 @@ def foo(): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_django_session_cookie_secure_off.py b/tests/codemods/test_django_session_cookie_secure_off.py index 821c2468..62b1cf45 100644 --- a/tests/codemods/test_django_session_cookie_secure_off.py +++ b/tests/codemods/test_django_session_cookie_secure_off.py @@ -9,7 +9,7 @@ class TestDjangoSessionSecureCookieOff(BaseDjangoCodemodTest): codemod = DjangoSessionCookieSecureOff def test_name(self): - assert self.codemod.name() == "django-session-cookie-secure-off" + assert self.codemod.name == "django-session-cookie-secure-off" def test_not_settings_dot_py(self, tmpdir): django_root, settings_folder = self.create_dir_structure(tmpdir) @@ -18,7 +18,6 @@ def test_not_settings_dot_py(self, tmpdir): input_code = """SESSION_COOKIE_SECURE = True""" expected = input_code self.run_and_assert_filepath(django_root, file_path, input_code, expected) - assert len(self.file_context.codemod_changes) == 0 def test_no_manage_dot_py(self, tmpdir): django_root, settings_folder = self.create_dir_structure(tmpdir) @@ -26,7 +25,6 @@ def test_no_manage_dot_py(self, tmpdir): input_code = """SESSION_COOKIE_SECURE = True""" expected = input_code self.run_and_assert_filepath(django_root, file_path, input_code, expected) - assert len(self.file_context.codemod_changes) == 0 def test_settings_dot_py_secure_true(self, tmpdir): django_root, settings_folder = self.create_dir_structure(tmpdir) @@ -36,7 +34,6 @@ def test_settings_dot_py_secure_true(self, tmpdir): SESSION_COOKIE_SECURE = True """ self.run_and_assert_filepath(django_root, file_path, input_code, input_code) - assert len(self.file_context.codemod_changes) == 0 @pytest.mark.parametrize("value", ["False", "gibberish"]) def test_settings_dot_py_secure_bad(self, tmpdir, value): @@ -50,7 +47,6 @@ def test_settings_dot_py_secure_bad(self, tmpdir, value): SESSION_COOKIE_SECURE = True """ self.run_and_assert_filepath(django_root, file_path, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_settings_dot_py_secure_missing(self, tmpdir): django_root, settings_folder = self.create_dir_structure(tmpdir) @@ -62,4 +58,3 @@ def test_settings_dot_py_secure_missing(self, tmpdir): SESSION_COOKIE_SECURE = True """ self.run_and_assert_filepath(django_root, file_path, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 diff --git a/tests/codemods/test_enable_jinja2_autoescape.py b/tests/codemods/test_enable_jinja2_autoescape.py index e48a3640..9d0eafe4 100644 --- a/tests/codemods/test_enable_jinja2_autoescape.py +++ b/tests/codemods/test_enable_jinja2_autoescape.py @@ -6,7 +6,7 @@ class TestEnableJinja2Autoescape(BaseSemgrepCodemodTest): codemod = EnableJinja2Autoescape def test_name(self): - assert self.codemod.name() == "enable-jinja2-autoescape" + assert self.codemod.name == "enable-jinja2-autoescape" def test_import(self, tmpdir): input_code = """ diff --git a/tests/codemods/test_exception_without_raise.py b/tests/codemods/test_exception_without_raise.py index dd1f8b3a..653d062f 100644 --- a/tests/codemods/test_exception_without_raise.py +++ b/tests/codemods/test_exception_without_raise.py @@ -7,7 +7,7 @@ class TestExceptionWithoutRaise(BaseCodemodTest): codemod = ExceptionWithoutRaise def test_name(self): - assert self.codemod.name() == "exception-without-raise" + assert self.codemod.name == "exception-without-raise" def test_simple(self, tmpdir): input_code = """\ @@ -17,7 +17,6 @@ def test_simple(self, tmpdir): raise ValueError """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_call(self, tmpdir): input_code = """\ @@ -27,7 +26,6 @@ def test_simple_call(self, tmpdir): raise ValueError("Bad value!") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_alias(self, tmpdir): input_code = """\ @@ -39,21 +37,18 @@ def test_alias(self, tmpdir): raise error """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_unknown_exception(self, tmpdir): input_code = """\ Something """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_raised_exception(self, tmpdir): input_code = """\ raise ValueError """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_exclude_line(self, tmpdir): input_code = expected = """\ @@ -61,6 +56,9 @@ def test_exclude_line(self, tmpdir): ValueError("Bad value!") """ lines_to_exclude = [2] - self.assert_no_change_line_excluded( - tmpdir, input_code, expected, lines_to_exclude + self.run_and_assert( + tmpdir, + input_code, + expected, + lines_to_exclude=lines_to_exclude, ) diff --git a/tests/codemods/test_file_resource_leak.py b/tests/codemods/test_file_resource_leak.py index b67a81f2..e3c6835e 100644 --- a/tests/codemods/test_file_resource_leak.py +++ b/tests/codemods/test_file_resource_leak.py @@ -7,7 +7,7 @@ class TestFileResourceLeak(BaseCodemodTest): codemod = FileResourceLeak def test_name(self): - assert self.codemod.name() == "fix-file-resource-leak" + assert self.codemod.name == "fix-file-resource-leak" def test_simple(self, tmpdir): input_code = """\ @@ -19,7 +19,6 @@ def test_simple(self, tmpdir): file.read() """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_annotated(self, tmpdir): input_code = """\ @@ -31,7 +30,6 @@ def test_simple_annotated(self, tmpdir): file.read() """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_just_open(self, tmpdir): # strange as this change may be, it still leaks if left untouched @@ -43,7 +41,6 @@ def test_just_open(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_multiple_assignments(self, tmpdir): input_code = """\ @@ -55,7 +52,6 @@ def test_multiple_assignments(self, tmpdir): file.read() """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_minimal_block(self, tmpdir): input_code = """\ @@ -69,7 +65,6 @@ def test_minimal_block(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 # negative tests below @@ -80,7 +75,6 @@ def test_is_closed(self, tmpdir): file.close() """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_is_closed_with_exit(self, tmpdir): input_code = """\ @@ -89,7 +83,6 @@ def test_is_closed_with_exit(self, tmpdir): file.__exit__() """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_is_closed_with_statement(self, tmpdir): input_code = """\ @@ -98,7 +91,6 @@ def test_is_closed_with_statement(self, tmpdir): file.read() """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_is_closed_with_statement_and_contextlib(self, tmpdir): input_code = """\ @@ -108,7 +100,6 @@ def test_is_closed_with_statement_and_contextlib(self, tmpdir): file.read() """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_is_closed_transitivelly(self, tmpdir): input_code = """\ @@ -117,7 +108,6 @@ def test_is_closed_transitivelly(self, tmpdir): same_file.close() """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_escapes_with_assignment(self, tmpdir): input_code = """\ @@ -125,7 +115,6 @@ def test_escapes_with_assignment(self, tmpdir): Object.attribute = file """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_escapes_as_function_argument(self, tmpdir): input_code = """\ @@ -133,7 +122,6 @@ def test_escapes_as_function_argument(self, tmpdir): foo(file) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_escapes_returned(self, tmpdir): input_code = """\ @@ -142,7 +130,6 @@ def foo(): return file """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_escapes_yielded(self, tmpdir): input_code = """\ @@ -151,7 +138,6 @@ def foo(): yield file """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_escapes_outside_reference(self, tmpdir): input_code = """\ @@ -163,4 +149,3 @@ def test_escapes_outside_reference(self, tmpdir): out.read() """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_fix_deprecated_abstractproperty.py b/tests/codemods/test_fix_deprecated_abstractproperty.py index c4e7db1b..4645b729 100644 --- a/tests/codemods/test_fix_deprecated_abstractproperty.py +++ b/tests/codemods/test_fix_deprecated_abstractproperty.py @@ -132,6 +132,9 @@ def foo(self): pass """ lines_to_exclude = [4] - self.assert_no_change_line_excluded( - tmpdir, input_code, expected, lines_to_exclude + self.run_and_assert( + tmpdir, + input_code, + expected, + lines_to_exclude=lines_to_exclude, ) diff --git a/tests/codemods/test_fix_deprecated_logging_warn.py b/tests/codemods/test_fix_deprecated_logging_warn.py index 4b05f87d..41f3a1ea 100644 --- a/tests/codemods/test_fix_deprecated_logging_warn.py +++ b/tests/codemods/test_fix_deprecated_logging_warn.py @@ -28,7 +28,6 @@ def test_import(self, tmpdir, code): original_code = code.format("warn") new_code = code.format("warning") self.run_and_assert(tmpdir, original_code, new_code) - assert len(self.file_context.codemod_changes) == 1 @pytest.mark.parametrize( "code", @@ -47,7 +46,6 @@ def test_from_import(self, tmpdir, code): original_code = code.format("warn") new_code = code.format("warning") self.run_and_assert(tmpdir, original_code, new_code) - assert len(self.file_context.codemod_changes) == 1 @pytest.mark.parametrize( "input_code,expected_output", @@ -74,7 +72,6 @@ def test_from_import(self, tmpdir, code): ) def test_import_alias(self, tmpdir, input_code, expected_output): self.run_and_assert(tmpdir, input_code, expected_output) - assert len(self.file_context.codemod_changes) == 1 @pytest.mark.parametrize( "code", @@ -92,7 +89,6 @@ def test_import_alias(self, tmpdir, input_code, expected_output): ) def test_different_warn(self, tmpdir, code): self.run_and_assert(tmpdir, code, code) - assert len(self.file_context.codemod_changes) == 0 @pytest.mark.xfail(reason="Not currently supported") def test_log_as_arg(self, tmpdir): @@ -106,4 +102,3 @@ def some_function(logger): original_code = code.format("warn") new_code = code.format("warning") self.run_and_assert(tmpdir, original_code, new_code) - assert len(self.file_context.codemod_changes) == 1 diff --git a/tests/codemods/test_flask_enable_csrf_protection.py b/tests/codemods/test_flask_enable_csrf_protection.py index 436016e3..5a922668 100644 --- a/tests/codemods/test_flask_enable_csrf_protection.py +++ b/tests/codemods/test_flask_enable_csrf_protection.py @@ -7,7 +7,7 @@ class TestFlaskEnableCSRFProtection(BaseCodemodTest): codemod = FlaskEnableCSRFProtection def test_name(self): - assert self.codemod.name() == "flask-enable-csrf-protection" + assert self.codemod.name == "flask-enable-csrf-protection" def test_simple(self, tmpdir): input_code = """\ @@ -22,7 +22,6 @@ def test_simple(self, tmpdir): csrf_app = CSRFProtect(app) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_alias(self, tmpdir): input_code = """\ @@ -37,7 +36,6 @@ def test_simple_alias(self, tmpdir): csrf_app = CSRFProtect(app) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_multiple(self, tmpdir): input_code = """\ @@ -54,8 +52,7 @@ def test_multiple(self, tmpdir): app2 = Flask(__name__) csrf_app2 = CSRFProtect(app2) """ - self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 2 + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected), num_changes=2) def test_multiple_inline(self, tmpdir): input_code = """\ @@ -69,7 +66,6 @@ def test_multiple_inline(self, tmpdir): app = Flask(__name__); app2 = Flask(__name__); csrf_app = CSRFProtect(app); csrf_app2 = CSRFProtect(app2) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_multiple_inline_suite(self, tmpdir): input_code = """\ @@ -83,7 +79,6 @@ def test_multiple_inline_suite(self, tmpdir): if True: app = Flask(__name__); app2 = Flask(__name__); csrf_app = CSRFProtect(app); csrf_app2 = CSRFProtect(app2) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_protected(self, tmpdir): input_code = """\ @@ -94,4 +89,3 @@ def test_simple_protected(self, tmpdir): csrf_app = CSRFProtect(app) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_flask_json_response_type.py b/tests/codemods/test_flask_json_response_type.py index 77e9b1a3..6a675755 100644 --- a/tests/codemods/test_flask_json_response_type.py +++ b/tests/codemods/test_flask_json_response_type.py @@ -7,7 +7,7 @@ class TestFlaskJsonResponseType(BaseCodemodTest): codemod = FlaskJsonResponseType def test_name(self): - assert self.codemod.name() == "flask-json-response-type" + assert self.codemod.name == "flask-json-response-type" def test_simple(self, tmpdir): input_code = """\ @@ -33,7 +33,6 @@ def foo(request): return make_response(json_response, {'Content-Type': 'application/json'}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_indirect(self, tmpdir): input_code = """\ @@ -61,7 +60,6 @@ def foo(request): return response """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_tuple_arg(self, tmpdir): input_code = """\ @@ -87,7 +85,6 @@ def foo(request): return make_response((json_response, 404, {'Content-Type': 'application/json'})) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_return_json(self, tmpdir): input_code = """\ @@ -113,7 +110,6 @@ def foo(request): return (json_response, {'Content-Type': 'application/json'}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_tuple(self, tmpdir): input_code = """\ @@ -139,7 +135,6 @@ def foo(request): return (json_response, 404, {'Content-Type': 'application/json'}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_alias(self, tmpdir): input_code = """\ @@ -165,7 +160,6 @@ def foo(request): return response(json_response, {'Content-Type': 'application/json'}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_indirect_dict(self, tmpdir): input_code = """\ @@ -193,7 +187,6 @@ def foo(request): return make_response(json_response, headers) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_direct_return(self, tmpdir): input_code = """\ @@ -217,7 +210,6 @@ def foo(request): return make_response(json.dumps({ "user_input": request.GET.get("input") }), {'Content-Type': 'application/json'}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_tuple_dict_no_key(self, tmpdir): input_code = """\ @@ -243,7 +235,6 @@ def foo(request): return (make_response(json_response), {'Content-Type': 'application/json'}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_no_route_decorator(self, tmpdir): input_code = """\ @@ -257,7 +248,6 @@ def foo(request): return make_response(json_response) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_content_type_set(self, tmpdir): input_code = """\ @@ -272,7 +262,6 @@ def foo(request): return (make_response(json_response), {'Content-Type': 'application/json'}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_content_type_maybe_set_star(self, tmpdir): input_code = """\ @@ -288,7 +277,6 @@ def foo(request): return (make_response(json_response), {**another_dict}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_content_type_maybe_set(self, tmpdir): input_code = """\ @@ -304,7 +292,6 @@ def foo(request): return (make_response(json_response), {key:'application/json'}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_no_json_dumps_input(self, tmpdir): input_code = """\ @@ -319,7 +306,6 @@ def foo(request): return make_response(dict_response) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_unknown_call_response(self, tmpdir): input_code = """\ @@ -335,4 +321,3 @@ def foo(request): return bar(dict_response) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_harden_pyyaml.py b/tests/codemods/test_harden_pyyaml.py index 1e30f282..cec378ba 100644 --- a/tests/codemods/test_harden_pyyaml.py +++ b/tests/codemods/test_harden_pyyaml.py @@ -12,7 +12,7 @@ class TestHardenPyyaml(BaseSemgrepCodemodTest): codemod = HardenPyyaml def test_name(self): - assert self.codemod.name() == "harden-pyyaml" + assert self.codemod.name == "harden-pyyaml" def test_safe_loader(self, tmpdir): input_code = """import yaml @@ -20,7 +20,6 @@ def test_safe_loader(self, tmpdir): deserialized_data = yaml.load(data, Loader=yaml.SafeLoader) """ self.run_and_assert(tmpdir, input_code, input_code) - assert len(self.file_context.codemod_changes) == 0 @loaders def test_all_unsafe_loaders_arg(self, tmpdir, loader): @@ -34,7 +33,6 @@ def test_all_unsafe_loaders_arg(self, tmpdir, loader): deserialized_data = yaml.load(data, yaml.SafeLoader) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 @loaders def test_all_unsafe_loaders_kwarg(self, tmpdir, loader): @@ -48,7 +46,6 @@ def test_all_unsafe_loaders_kwarg(self, tmpdir, loader): deserialized_data = yaml.load(data, Loader=yaml.SafeLoader) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_import_alias(self, tmpdir): input_code = """import yaml as yam @@ -64,7 +61,6 @@ def test_import_alias(self, tmpdir): deserialized_data = yam.load(data, Loader=yam.SafeLoader) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_preserve_custom_loader(self, tmpdir): expected = input_code = """ @@ -75,7 +71,6 @@ def test_preserve_custom_loader(self, tmpdir): """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 0 def test_preserve_custom_loader_kwarg(self, tmpdir): expected = input_code = """ @@ -86,7 +81,6 @@ def test_preserve_custom_loader_kwarg(self, tmpdir): """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 0 class TestHardenPyyamlClassInherit(BaseSemgrepCodemodTest): @@ -103,7 +97,6 @@ def __init__(self, *args, **kwargs): """ self.run_and_assert(tmpdir, input_code, input_code) - assert len(self.file_context.codemod_changes) == 0 @loaders def test_unsafe_loaders(self, tmpdir, loader): @@ -122,7 +115,6 @@ def __init__(self, *args, **kwargs): super(MyCustomLoader, self).__init__(*args, **kwargs) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_from_import(self, tmpdir): input_code = """\ @@ -140,7 +132,6 @@ def __init__(self, *args, **kwargs): super(MyCustomLoader, self).__init__(*args, **kwargs) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_import_alias(self, tmpdir): input_code = """\ @@ -158,7 +149,6 @@ def __init__(self, *args, **kwargs): super(MyCustomLoader, self).__init__(*args, **kwargs) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_multiple_bases(self, tmpdir): input_code = """\ @@ -180,7 +170,6 @@ def __init__(self, *args, **kwargs): super(MyCustomLoader, self).__init__(*args, **kwargs) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_different_yaml(self, tmpdir): input_code = """\ @@ -198,4 +187,3 @@ class MyLoader(SafeLoader, yaml.Loader): ... """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 diff --git a/tests/codemods/test_harden_ruamel.py b/tests/codemods/test_harden_ruamel.py index 2a388a4d..3f997c9e 100644 --- a/tests/codemods/test_harden_ruamel.py +++ b/tests/codemods/test_harden_ruamel.py @@ -7,7 +7,7 @@ class TestHardenRuamel(BaseSemgrepCodemodTest): codemod = HardenRuamel def test_name(self): - assert self.codemod.name() == "harden-ruamel" + assert self.codemod.name == "harden-ruamel" @pytest.mark.parametrize("loader", ["YAML()", "YAML(typ='rt')", "YAML(typ='safe')"]) def test_safe(self, tmpdir, loader): diff --git a/tests/codemods/test_https_connection.py b/tests/codemods/test_https_connection.py index 00558658..5fbf20b8 100644 --- a/tests/codemods/test_https_connection.py +++ b/tests/codemods/test_https_connection.py @@ -11,7 +11,6 @@ def test_no_change(self, tmpdir): urllib3.HTTPSConnectionPool("localhost", "80") """ self.run_and_assert(tmpdir, before, before) - assert len(self.file_context.codemod_changes) == 0 def test_simple(self, tmpdir): before = r"""import urllib3 @@ -23,7 +22,6 @@ def test_simple(self, tmpdir): urllib3.HTTPSConnectionPool("localhost", "80") """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_module_alias(self, tmpdir): before = r"""import urllib3 as module @@ -35,7 +33,6 @@ def test_module_alias(self, tmpdir): module.HTTPSConnectionPool("localhost", "80") """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_alias(self, tmpdir): before = r"""from urllib3 import HTTPConnectionPool as something @@ -47,7 +44,6 @@ def test_alias(self, tmpdir): urllib3.HTTPSConnectionPool("localhost", "80") """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_connectionpool(self, tmpdir): before = r"""import urllib3 @@ -59,7 +55,6 @@ def test_connectionpool(self, tmpdir): urllib3.connectionpool.HTTPSConnectionPool("localhost", "80") """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_connectionpool_alias(self, tmpdir): before = r"""import urllib3.connectionpool as pool @@ -71,7 +66,6 @@ def test_connectionpool_alias(self, tmpdir): pool.HTTPSConnectionPool("localhost", "80") """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_last_arg(self, tmpdir): before = r"""import urllib3 @@ -83,4 +77,3 @@ def test_last_arg(self, tmpdir): urllib3.HTTPSConnectionPool(None, None, None, None, None, None, None, None, None, _proxy_config = None) """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 diff --git a/tests/codemods/test_jwt_decode_verify.py b/tests/codemods/test_jwt_decode_verify.py index 35dc1a1a..677b4bfa 100644 --- a/tests/codemods/test_jwt_decode_verify.py +++ b/tests/codemods/test_jwt_decode_verify.py @@ -7,7 +7,7 @@ class TestJwtDecodeVerify(BaseSemgrepCodemodTest): codemod = JwtDecodeVerify def test_name(self): - assert self.codemod.name() == "jwt-decode-verify" + assert self.codemod.name == "jwt-decode-verify" def test_import(self, tmpdir): input_code = """import jwt diff --git a/tests/codemods/test_limit_readline.py b/tests/codemods/test_limit_readline.py index c7d94363..bf6a7be2 100644 --- a/tests/codemods/test_limit_readline.py +++ b/tests/codemods/test_limit_readline.py @@ -6,7 +6,7 @@ class TestLimitReadline(BaseSemgrepCodemodTest): codemod = LimitReadline def test_name(self): - assert self.codemod.name() == "limit-readline" + assert self.codemod.name == "limit-readline" def test_file_readline(self, tmpdir): input_code = """file = open('some_file.txt') diff --git a/tests/codemods/test_literal_or_new_object_identity.py b/tests/codemods/test_literal_or_new_object_identity.py index fd165f15..91e2251f 100644 --- a/tests/codemods/test_literal_or_new_object_identity.py +++ b/tests/codemods/test_literal_or_new_object_identity.py @@ -7,7 +7,7 @@ class TestLiteralOrNewObjectIdentity(BaseCodemodTest): codemod = LiteralOrNewObjectIdentity def test_name(self): - assert self.codemod.name() == "literal-or-new-object-identity" + assert self.codemod.name == "literal-or-new-object-identity" def test_list(self, tmpdir): input_code = """\ @@ -17,7 +17,6 @@ def test_list(self, tmpdir): l == [1,2,3] """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_list_indirect(self, tmpdir): input_code = """\ @@ -29,7 +28,6 @@ def test_list_indirect(self, tmpdir): l == some_list """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_list_lhs(self, tmpdir): input_code = """\ @@ -39,7 +37,6 @@ def test_list_lhs(self, tmpdir): [1,2,3] == l """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_list_function(self, tmpdir): input_code = """\ @@ -49,7 +46,6 @@ def test_list_function(self, tmpdir): l == list({1,2,3}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_dict(self, tmpdir): input_code = """\ @@ -59,7 +55,6 @@ def test_dict(self, tmpdir): l == {1:2} """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_dict_function(self, tmpdir): input_code = """\ @@ -69,7 +64,6 @@ def test_dict_function(self, tmpdir): l == dict({1,2,3}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_tuple(self, tmpdir): input_code = """\ @@ -79,7 +73,6 @@ def test_tuple(self, tmpdir): l == (1,2,3) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_tuple_function(self, tmpdir): input_code = """\ @@ -89,7 +82,6 @@ def test_tuple_function(self, tmpdir): l == tuple({1,2,3}) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_set(self, tmpdir): input_code = """\ @@ -99,7 +91,6 @@ def test_set(self, tmpdir): l == {1,2,3} """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_set_function(self, tmpdir): input_code = """\ @@ -109,7 +100,6 @@ def test_set_function(self, tmpdir): l == set([1,2,3]) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_int(self, tmpdir): input_code = """\ @@ -119,7 +109,6 @@ def test_int(self, tmpdir): l == 1 """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_float(self, tmpdir): input_code = """\ @@ -129,7 +118,6 @@ def test_float(self, tmpdir): l == 1.0 """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_imaginary(self, tmpdir): input_code = """\ @@ -139,7 +127,6 @@ def test_imaginary(self, tmpdir): l == 1j """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_str(self, tmpdir): input_code = """\ @@ -149,7 +136,6 @@ def test_str(self, tmpdir): l == '1' """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_fstr(self, tmpdir): input_code = """\ @@ -159,7 +145,6 @@ def test_fstr(self, tmpdir): l == f'1' """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_concatenated_str(self, tmpdir): input_code = """\ @@ -169,7 +154,6 @@ def test_concatenated_str(self, tmpdir): l == '1' ',2' """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_negative(self, tmpdir): input_code = """\ @@ -179,18 +163,15 @@ def test_negative(self, tmpdir): l != [1,2,3] """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_do_nothing(self, tmpdir): input_code = """\ l == [1,2,3] """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_do_nothing_negative(self, tmpdir): input_code = """\ l != [1,2,3] """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_lxml_safe_parameter_defaults.py b/tests/codemods/test_lxml_safe_parameter_defaults.py index 06753904..e2884280 100644 --- a/tests/codemods/test_lxml_safe_parameter_defaults.py +++ b/tests/codemods/test_lxml_safe_parameter_defaults.py @@ -11,7 +11,7 @@ class TestLxmlSafeParserDefaults(BaseSemgrepCodemodTest): codemod = LxmlSafeParserDefaults def test_name(self): - assert self.codemod.name() == "safe-lxml-parser-defaults" + assert self.codemod.name == "safe-lxml-parser-defaults" @each_class def test_import(self, tmpdir, klass): diff --git a/tests/codemods/test_lxml_safe_parsing.py b/tests/codemods/test_lxml_safe_parsing.py index c23a26d2..3400b166 100644 --- a/tests/codemods/test_lxml_safe_parsing.py +++ b/tests/codemods/test_lxml_safe_parsing.py @@ -9,7 +9,7 @@ class TestLxmlSafeParsing(BaseSemgrepCodemodTest): codemod = LxmlSafeParsing def test_name(self): - assert self.codemod.name() == "safe-lxml-parsing" + assert self.codemod.name == "safe-lxml-parsing" @each_func def test_import(self, tmpdir, func): diff --git a/tests/codemods/test_numpy_nan_equality.py b/tests/codemods/test_numpy_nan_equality.py index f1440788..d3fd3e3d 100644 --- a/tests/codemods/test_numpy_nan_equality.py +++ b/tests/codemods/test_numpy_nan_equality.py @@ -7,7 +7,7 @@ class TestNumpyNanEquality(BaseCodemodTest): codemod = NumpyNanEquality def test_name(self): - assert self.codemod.name() == "numpy-nan-equality" + assert self.codemod.name == "numpy-nan-equality" def test_simple(self, tmpdir): input_code = """\ @@ -21,7 +21,6 @@ def test_simple(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_inequality(self, tmpdir): input_code = """\ @@ -35,7 +34,6 @@ def test_simple_inequality(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_inequality_2(self, tmpdir): input_code = """\ @@ -49,7 +47,6 @@ def test_simple_inequality_2(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_parenthesis(self, tmpdir): input_code = """\ @@ -63,7 +60,6 @@ def test_simple_parenthesis(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_conjunction(self, tmpdir): input_code = """\ @@ -76,8 +72,7 @@ def test_conjunction(self, tmpdir): if not numpy.isnan(a) and not numpy.isnan(b): pass """ - self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 2 + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected), num_changes=2) def test_from_numpy(self, tmpdir): input_code = """\ @@ -92,7 +87,6 @@ def test_from_numpy(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_left(self, tmpdir): input_code = """\ @@ -106,7 +100,6 @@ def test_simple_left(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_alias(self, tmpdir): input_code = """\ @@ -120,7 +113,6 @@ def test_alias(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_multiple_comparisons(self, tmpdir): input_code = """\ @@ -129,7 +121,6 @@ def test_multiple_comparisons(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_not_numpy(self, tmpdir): input_code = """\ @@ -138,7 +129,6 @@ def test_not_numpy(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_numpy_other_operator(self, tmpdir): input_code = """\ @@ -147,4 +137,3 @@ def test_numpy_other_operator(self, tmpdir): pass """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_order_imports.py b/tests/codemods/test_order_imports.py index 7afcdbb2..5e503bef 100644 --- a/tests/codemods/test_order_imports.py +++ b/tests/codemods/test_order_imports.py @@ -10,7 +10,6 @@ def test_no_change(self, tmpdir): from b import c """ self.run_and_assert(tmpdir, before, before) - assert len(self.file_context.codemod_changes) == 0 def test_separate_from_imports_and_regular(self, tmpdir): before = r"""import y @@ -21,7 +20,6 @@ def test_separate_from_imports_and_regular(self, tmpdir): import y from a import c""" self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_consolidate_from_imports(self, tmpdir): before = r"""from a import a1 @@ -30,7 +28,6 @@ def test_consolidate_from_imports(self, tmpdir): after = r"""from a import a1, a2, a3""" self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_order_blocks_separately(self, tmpdir): before = r"""import x @@ -45,8 +42,7 @@ def test_order_blocks_separately(self, tmpdir): import b import y""" - self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 2 + self.run_and_assert(tmpdir, before, after, num_changes=2) def test_preserve_comments(self, tmpdir): before = r"""# do not move @@ -69,7 +65,6 @@ def test_preserve_comments(self, tmpdir): """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_handle_star_imports(self, tmpdir): before = r"""from a import x @@ -82,7 +77,6 @@ def test_handle_star_imports(self, tmpdir): from a import b, x""" self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_handle_composite_and_relative_imports(self, tmpdir): before = r"""from . import a @@ -93,7 +87,6 @@ def test_handle_composite_and_relative_imports(self, tmpdir): from . import a""" self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_natural_order(self, tmpdir): before = """from a import Object11 @@ -104,7 +97,6 @@ def test_natural_order(self, tmpdir): """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_wont_remove_unused_future(self, tmpdir): before = """from __future__ import absolute_import @@ -113,7 +105,6 @@ def test_wont_remove_unused_future(self, tmpdir): after = before self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 0 def test_organize_by_sections(self, tmpdir): before = """from codemodder.codemods.transformations.clean_imports import CleanImports @@ -134,7 +125,6 @@ def test_organize_by_sections(self, tmpdir): """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_will_ignore_non_top_level(self, tmpdir): before = """import global2 @@ -159,7 +149,6 @@ def f(): import global3 """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_it_can_change_behavior(self, tmpdir): # note that c will change from b to e due to the sort @@ -175,4 +164,3 @@ def test_it_can_change_behavior(self, tmpdir): c() """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 diff --git a/tests/codemods/test_process_creation_sandbox.py b/tests/codemods/test_process_creation_sandbox.py index 51309755..bafcad04 100644 --- a/tests/codemods/test_process_creation_sandbox.py +++ b/tests/codemods/test_process_creation_sandbox.py @@ -1,16 +1,19 @@ import pytest +import mock + from codemodder.dependency import Security from core_codemods.process_creation_sandbox import ProcessSandbox from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest +@mock.patch("codemodder.codemods.api.FileContext.add_dependency") class TestProcessCreationSandbox(BaseSemgrepCodemodTest): codemod = ProcessSandbox - def test_name(self): - assert self.codemod.name() == "sandbox-process-creation" + def test_name(self, _): + assert self.codemod.name == "sandbox-process-creation" - def test_import_subprocess(self, tmpdir): + def test_import_subprocess(self, adds_dependency, tmpdir): input_code = """ import subprocess @@ -25,9 +28,9 @@ def test_import_subprocess(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + adds_dependency.assert_called_once_with(Security) - def test_import_alias(self, tmpdir): + def test_import_alias(self, adds_dependency, tmpdir): input_code = """ import subprocess as sub @@ -42,9 +45,9 @@ def test_import_alias(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + adds_dependency.assert_called_once_with(Security) - def test_from_subprocess(self, tmpdir): + def test_from_subprocess(self, adds_dependency, tmpdir): input_code = """ from subprocess import run @@ -59,9 +62,9 @@ def test_from_subprocess(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + adds_dependency.assert_called_once_with(Security) - def test_subprocess_nameerror(self, tmpdir): + def test_subprocess_nameerror(self, _, tmpdir): input_code = """ subprocess.run("echo 'hi'", shell=True) @@ -107,11 +110,13 @@ def test_subprocess_nameerror(self, tmpdir): ), ], ) - def test_other_import_untouched(self, tmpdir, input_code, expected): + def test_other_import_untouched( + self, adds_dependency, tmpdir, input_code, expected + ): self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + adds_dependency.assert_called_once_with(Security) - def test_multifunctions(self, tmpdir): + def test_multifunctions(self, adds_dependency, tmpdir): # Test that subprocess methods that aren't part of the codemod are not changed. # If we add the function as one of # our codemods, this test would change. @@ -129,9 +134,9 @@ def test_multifunctions(self, tmpdir): subprocess.check_output(["ls", "-l"])""" self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + adds_dependency.assert_called_once_with(Security) - def test_custom_run(self, tmpdir): + def test_custom_run(self, _, tmpdir): input_code = """ from app_funcs import run @@ -139,7 +144,7 @@ def test_custom_run(self, tmpdir): expected = input_code self.run_and_assert(tmpdir, input_code, expected) - def test_subprocess_call(self, tmpdir): + def test_subprocess_call(self, adds_dependency, tmpdir): input_code = """ import subprocess @@ -152,9 +157,9 @@ def test_subprocess_call(self, tmpdir): safe_command.run(subprocess.call, ["ls", "-l"]) """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + adds_dependency.assert_called_once_with(Security) - def test_subprocess_popen(self, tmpdir): + def test_subprocess_popen(self, adds_dependency, tmpdir): input_code = """ import subprocess @@ -167,4 +172,4 @@ def test_subprocess_popen(self, tmpdir): safe_command.run(subprocess.Popen, ["ls", "-l"]) """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + adds_dependency.assert_called_once_with(Security) diff --git a/tests/codemods/test_remove_debug_breakpoint.py b/tests/codemods/test_remove_debug_breakpoint.py index f384c1c2..c97597ae 100644 --- a/tests/codemods/test_remove_debug_breakpoint.py +++ b/tests/codemods/test_remove_debug_breakpoint.py @@ -6,7 +6,7 @@ class TestRemoveDebugBreakpoint(BaseCodemodTest): codemod = RemoveDebugBreakpoint def test_name(self): - assert self.codemod.name() == "remove-debug-breakpoint" + assert self.codemod.name == "remove-debug-breakpoint" def test_builtin_breakpoint(self, tmpdir): input_code = """\ @@ -21,7 +21,6 @@ def something(): something() """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_builtin_breakpoint_multiple_statements(self, tmpdir): input_code = """\ @@ -37,7 +36,6 @@ def something(): something() """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_inline_pdb(self, tmpdir): input_code = """\ @@ -52,7 +50,6 @@ def something(): something() """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_pdb_import(self, tmpdir): input_code = """\ @@ -68,7 +65,6 @@ def something(): something() """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_pdb_from_import(self, tmpdir): input_code = """\ @@ -84,7 +80,6 @@ def something(): something() """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_exclude_line(self, tmpdir): input_code = expected = """\ @@ -92,6 +87,9 @@ def test_exclude_line(self, tmpdir): breakpoint() """ lines_to_exclude = [2] - self.assert_no_change_line_excluded( - tmpdir, input_code, expected, lines_to_exclude + self.run_and_assert( + tmpdir, + input_code, + expected, + lines_to_exclude=lines_to_exclude, ) diff --git a/tests/codemods/test_remove_module_global.py b/tests/codemods/test_remove_module_global.py index 12e95f79..2e9f3e02 100644 --- a/tests/codemods/test_remove_module_global.py +++ b/tests/codemods/test_remove_module_global.py @@ -7,7 +7,7 @@ class TestJwtDecodeVerify(BaseCodemodTest): codemod = RemoveModuleGlobal def test_name(self): - assert self.codemod.name() == "remove-module-global" + assert self.codemod.name == "remove-module-global" def test_simple(self, tmpdir): input_code = """\ @@ -18,7 +18,6 @@ def test_simple(self, tmpdir): price = 25 """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_reassigned(self, tmpdir): input_code = """\ @@ -33,7 +32,6 @@ def test_reassigned(self, tmpdir): price = 30 """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_attr_call(self, tmpdir): input_code = """\ @@ -50,7 +48,6 @@ class Price: PRICE.__repr__ """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_correct_scope(self, tmpdir): input_code = """\ @@ -60,4 +57,3 @@ def change_price(): price = 30 """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_remove_unnecessary_f_str.py b/tests/codemods/test_remove_unnecessary_f_str.py index 6e70eede..e2b31c91 100644 --- a/tests/codemods/test_remove_unnecessary_f_str.py +++ b/tests/codemods/test_remove_unnecessary_f_str.py @@ -7,37 +7,38 @@ class TestFStr(BaseCodemodTest): def test_no_change(self, tmpdir): before = r""" -good: str = "good" -good: str = f"with_arg{arg}" -good = "good{arg1}".format(1234) -good = "good".format() -good = "good" % {} -good = "good" % () -good = rf"good\d+{bar}" -good = f"wow i don't have args but don't mess my braces {{ up }}" -""" + good: str = "good" + good: str = f"with_arg{arg}" + good = "good{arg1}".format(1234) + good = "good".format() + good = "good" % {} + good = "good" % () + good = rf"good\d+{bar}" + good = f"wow i don't have args but don't mess my braces {{ up }}" + """ self.run_and_assert(tmpdir, before, before) - assert len(self.file_context.codemod_changes) == 0 def test_change(self, tmpdir): before = r""" -bad: str = f"bad" + "bad" -bad: str = f'bad' -bad: str = rf'bad\d+' -""" + bad: str = f"bad" + "bad" + bad: str = f'bad' + bad: str = rf'bad\d+' + """ after = r""" -bad: str = "bad" + "bad" -bad: str = 'bad' -bad: str = r'bad\d+' -""" - self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 3 + bad: str = "bad" + "bad" + bad: str = 'bad' + bad: str = r'bad\d+' + """ + self.run_and_assert(tmpdir, before, after, num_changes=3) def test_exclude_line(self, tmpdir): input_code = expected = """\ bad: str = f"bad" + "bad" """ lines_to_exclude = [1] - self.assert_no_change_line_excluded( - tmpdir, input_code, expected, lines_to_exclude + self.run_and_assert( + tmpdir, + input_code, + expected, + lines_to_exclude=lines_to_exclude, ) diff --git a/tests/codemods/test_remove_unused_imports.py b/tests/codemods/test_remove_unused_imports.py index 9e675b5c..9f2e65fc 100644 --- a/tests/codemods/test_remove_unused_imports.py +++ b/tests/codemods/test_remove_unused_imports.py @@ -14,7 +14,6 @@ def test_no_change(self, tmpdir): c """ self.run_and_assert(tmpdir, before, before) - assert len(self.file_context.codemod_changes) == 0 def test_change(self, tmpdir): before = r"""import a @@ -22,7 +21,6 @@ def test_change(self, tmpdir): after = r""" """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_remove_import(self, tmpdir): before = r"""import a @@ -34,7 +32,6 @@ def test_remove_import(self, tmpdir): """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_remove_single_from_import(self, tmpdir): before = r"""from b import c, d @@ -45,7 +42,6 @@ def test_remove_single_from_import(self, tmpdir): c """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_remove_from_import(self, tmpdir): before = r"""from b import c @@ -53,7 +49,6 @@ def test_remove_from_import(self, tmpdir): after = "\n" self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_remove_inner_import(self, tmpdir): before = r"""import a @@ -67,7 +62,6 @@ def something(): """ self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_no_import_star_removal(self, tmpdir): before = r"""import a @@ -80,17 +74,14 @@ def test_keep_format(self, tmpdir): before = "from a import b,c,d \nprint(b)\nprint(d)" after = "from a import b,d \nprint(b)\nprint(d)" self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_dont_remove_if_noqa_before(self, tmpdir): before = "import a\n# noqa\nimport b\na()" self.run_and_assert(tmpdir, before, before) - assert len(self.file_context.codemod_changes) == 0 def test_dont_remove_if_noqa_trailing(self, tmpdir): before = "import a\nimport b # noqa\na()" self.run_and_assert(tmpdir, before, before) - assert len(self.file_context.codemod_changes) == 0 def test_dont_remove_if_noqa_trailing_multiline(self, tmpdir): before = dedent( @@ -101,34 +92,28 @@ def test_dont_remove_if_noqa_trailing_multiline(self, tmpdir): ) self.run_and_assert(tmpdir, before, before) - assert len(self.file_context.codemod_changes) == 0 def test_dont_remove_if_pylint_disable(self, tmpdir): before = "import a\nimport b # pylint: disable=W0611\na()" self.run_and_assert(tmpdir, before, before) - assert len(self.file_context.codemod_changes) == 0 def test_dont_remove_if_pylint_disable_next(self, tmpdir): before = ( "import a\n# pylint: disable-next=no-member, unused-import\nimport b\na()" ) self.run_and_assert(tmpdir, before, before) - assert len(self.file_context.codemod_changes) == 0 def test_ignore_init_files(self, tmpdir): before = "import a" tmp_file_path = Path(tmpdir / "__init__.py") - self.run_and_assert_filepath(tmpdir, tmp_file_path, before, before) - assert len(self.file_context.codemod_changes) == 0 + self.run_and_assert(tmpdir, before, before, files=[tmp_file_path]) def test_no_pyling_pragma_in_comment_trailing(self, tmpdir): before = "import a # bogus: no-pragma" after = "" self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 def test_no_pyling_pragma_in_comment_before(self, tmpdir): before = "#header\nprint('hello')\n# bogus: no-pragma\nimport a " after = "#header\nprint('hello')" self.run_and_assert(tmpdir, before, after) - assert len(self.file_context.codemod_changes) == 1 diff --git a/tests/codemods/test_replace_flask_send_file.py b/tests/codemods/test_replace_flask_send_file.py index 6a40744c..d603eb26 100644 --- a/tests/codemods/test_replace_flask_send_file.py +++ b/tests/codemods/test_replace_flask_send_file.py @@ -8,7 +8,7 @@ class TestReplaceFlaskSendFile(BaseCodemodTest): codemod = ReplaceFlaskSendFile def test_name(self): - assert self.codemod.name() == "replace-flask-send-file" + assert self.codemod.name == "replace-flask-send-file" def test_direct_string(self, tmpdir): input_code = """\ @@ -32,7 +32,6 @@ def download_file(name): return flask.send_from_directory((p := Path(f'path/to/{name}.txt')).parent, p.name) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_direct_simple_string(self, tmpdir): input_code = """\ @@ -56,7 +55,6 @@ def download_file(name): return flask.send_from_directory((p := Path('path/to/file.txt')).parent, p.name) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_direct_string_convert_arguments(self, tmpdir): input_code = """\ @@ -80,7 +78,6 @@ def download_file(name): return flask.send_from_directory((p := Path(f'path/to/{name}.txt')).parent, p.name, mimetype = None, as_attachment = False, download_name = True) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_direct_path(self, tmpdir): input_code = """\ @@ -105,7 +102,6 @@ def download_file(name): return flask.send_from_directory((p := Path(f'path/to/{name}.txt')).parent, p.name) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_indirect_path(self, tmpdir): input_code = """\ @@ -132,7 +128,6 @@ def download_file(name): return flask.send_from_directory(path.parent, path.name) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_indirect_path_alias(self, tmpdir): input_code = """\ @@ -159,7 +154,6 @@ def download_file(name): return flask.send_from_directory(path.parent, path.name) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_indirect_string(self, tmpdir): input_code = """\ @@ -185,7 +179,6 @@ def download_file(name): return flask.send_from_directory((p := Path(path)).parent, p.name) """ self.run_and_assert(tmpdir, input_code, expected) - assert len(self.file_context.codemod_changes) == 1 def test_unknown_type(self, tmpdir): input_code = """\ @@ -198,4 +191,3 @@ def download_file(name): return send_file(name) """ self.run_and_assert(tmpdir, input_code, input_code) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_request_verify.py b/tests/codemods/test_request_verify.py index 28f00b1b..503c4d4e 100644 --- a/tests/codemods/test_request_verify.py +++ b/tests/codemods/test_request_verify.py @@ -9,7 +9,7 @@ class TestRequestsVerify(BaseSemgrepCodemodTest): codemod = RequestsVerify def test_name(self): - assert self.codemod.name() == "requests-verify" + assert self.codemod.name == "requests-verify" @pytest.mark.parametrize("func", REQUEST_FUNCS) def test_default_verify(self, tmpdir, func): diff --git a/tests/codemods/test_secure_flask_cookie.py b/tests/codemods/test_secure_flask_cookie.py index 923a829a..b7304aaa 100644 --- a/tests/codemods/test_secure_flask_cookie.py +++ b/tests/codemods/test_secure_flask_cookie.py @@ -10,7 +10,7 @@ class TestSecureFlaskCookie(BaseSemgrepCodemodTest): codemod = SecureFlaskCookie def test_name(self): - assert self.codemod.name() == "secure-flask-cookie" + assert self.codemod.name == "secure-flask-cookie" @each_func def test_import(self, tmpdir, func): diff --git a/tests/codemods/test_secure_flask_session_config.py b/tests/codemods/test_secure_flask_session_config.py index 60c834ad..c6038081 100644 --- a/tests/codemods/test_secure_flask_session_config.py +++ b/tests/codemods/test_secure_flask_session_config.py @@ -8,7 +8,7 @@ class TestSecureFlaskSessionConfig(BaseCodemodTest): codemod = SecureFlaskSessionConfig def test_name(self): - assert self.codemod.name() == "secure-flask-session-configuration" + assert self.codemod.name == "secure-flask-session-configuration" def test_no_flask_app(self, tmpdir): input_code = """\ @@ -19,7 +19,6 @@ def test_no_flask_app(self, tmpdir): response.set_cookie("name", "value") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_app_defined_separate_module(self, tmpdir): # TODO: test this as an integration test with two real modules @@ -29,7 +28,6 @@ def test_app_defined_separate_module(self, tmpdir): app.config["SESSION_COOKIE_SECURE"] = False """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_app_not_assigned(self, tmpdir): input_code = """\ @@ -39,7 +37,6 @@ def test_app_not_assigned(self, tmpdir): print(1) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_app_accessed_config_not_called(self, tmpdir): input_code = """\ @@ -50,7 +47,6 @@ def test_app_accessed_config_not_called(self, tmpdir): # more code """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_app_update_no_keyword(self, tmpdir): input_code = """\ @@ -63,8 +59,6 @@ def foo(test_config=None): app.config.update(test_config) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - print(self.file_context.codemod_changes) - assert len(self.file_context.codemod_changes) == 0 def test_from_import(self, tmpdir): input_code = """\ @@ -82,7 +76,6 @@ def test_from_import(self, tmpdir): app.config.update(SESSION_COOKIE_SECURE=True) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expexted_output)) - assert len(self.file_context.codemod_changes) == 1 def test_import_alias(self, tmpdir): input_code = """\ @@ -100,7 +93,6 @@ def test_import_alias(self, tmpdir): app.config.update(SESSION_COOKIE_SECURE=True) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expexted_output)) - assert len(self.file_context.codemod_changes) == 1 def test_annotated_assign(self, tmpdir): input_code = """\ @@ -118,7 +110,6 @@ def test_annotated_assign(self, tmpdir): app.config.update(SESSION_COOKIE_SECURE=True) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expexted_output)) - assert len(self.file_context.codemod_changes) == 1 def test_other_assignment_type(self, tmpdir): input_code = """\ @@ -132,46 +123,54 @@ class AppStore: store.app.config.update(SESSION_COOKIE_SECURE=False) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 @pytest.mark.parametrize( - "config_lines,expected_config_lines", + "config_lines,expected_config_lines,num_changes", [ ( """app.config""", """app.config""", + 0, ), ( """app.config["TESTING"] = True""", """app.config["TESTING"] = True""", + 0, ), ( """app.config.testing = True""", """app.config.testing = True""", + 0, ), ( """app.config.update(SESSION_COOKIE_SECURE=True, SESSION_COOKIE_SAMESITE='Lax')""", """app.config.update(SESSION_COOKIE_SECURE=True, SESSION_COOKIE_SAMESITE='Lax')""", + 0, ), ( """app.config.update(SESSION_COOKIE_SECURE=True)""", """app.config.update(SESSION_COOKIE_SECURE=True)""", + 0, ), ( """app.config.update(SESSION_COOKIE_HTTPONLY=True)""", """app.config.update(SESSION_COOKIE_HTTPONLY=True)""", + 0, ), ( """app.config.update(SESSION_COOKIE_HTTPONLY=False)""", """app.config.update(SESSION_COOKIE_HTTPONLY=True)""", + 1, ), ( """app.config['SESSION_COOKIE_SECURE'] = False""", """app.config['SESSION_COOKIE_SECURE'] = True""", + 1, ), ( """app.config['SESSION_COOKIE_HTTPONLY'] = False""", """app.config['SESSION_COOKIE_HTTPONLY'] = True""", + 1, ), ( """app.config["SESSION_COOKIE_SECURE"] = True @@ -180,6 +179,7 @@ class AppStore: """app.config["SESSION_COOKIE_SECURE"] = True app.config["SESSION_COOKIE_SAMESITE"] = "Lax" """, + 0, ), ( """app.config["SESSION_COOKIE_SECURE"] = False @@ -188,6 +188,7 @@ class AppStore: """app.config["SESSION_COOKIE_SECURE"] = True app.config["SESSION_COOKIE_SAMESITE"] = "Lax" """, + 2, ), ( """app.config["SESSION_COOKIE_SECURE"] = False @@ -198,6 +199,7 @@ class AppStore: app.config["SESSION_COOKIE_HTTPONLY"] = True app.config["SESSION_COOKIE_SAMESITE"] = "Strict" """, + 2, ), ( """app.config["SESSION_COOKIE_SECURE"] = False @@ -206,11 +208,12 @@ class AppStore: """app.config["SESSION_COOKIE_SECURE"] = True app.config["SESSION_COOKIE_SECURE"] = True """, + 1, ), ], ) def test_config_accessed_variations( - self, tmpdir, config_lines, expected_config_lines + self, tmpdir, config_lines, expected_config_lines, num_changes ): input_code = f"""import flask app = flask.Flask(__name__) @@ -222,7 +225,9 @@ def test_config_accessed_variations( app.secret_key = "dev" {expected_config_lines} """ - self.run_and_assert(tmpdir, dedent(input_code), dedent(expected_output)) + self.run_and_assert( + tmpdir, dedent(input_code), dedent(expected_output), num_changes=num_changes + ) @pytest.mark.skip() def test_func_scope(self, tmpdir): @@ -246,4 +251,3 @@ def configure(): # either within configure() call or after it's called """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expexted_output)) - assert len(self.file_context.codemod_changes) == 1 diff --git a/tests/codemods/test_secure_random.py b/tests/codemods/test_secure_random.py index fc1d3d2d..6944722c 100644 --- a/tests/codemods/test_secure_random.py +++ b/tests/codemods/test_secure_random.py @@ -7,74 +7,84 @@ class TestSecureRandom(BaseSemgrepCodemodTest): codemod = SecureRandom def test_name(self): - assert self.codemod.name() == "secure-random" + assert self.codemod.name == "secure-random" def test_import_random(self, tmpdir): - input_code = """import random + input_code = """ + import random -random.random() -var = "hello" -""" - expected_output = """import secrets + random.random() + var = "hello" + """ + expected_output = """ + import secrets -secrets.SystemRandom().random() -var = "hello" -""" + secrets.SystemRandom().random() + var = "hello" + """ self.run_and_assert(tmpdir, input_code, expected_output) def test_from_random(self, tmpdir): - input_code = """from random import random - -random() -var = "hello" -""" - expected_output = """import secrets - -secrets.SystemRandom().random() -var = "hello" -""" + input_code = """ + from random import random + + random() + var = "hello" + """ + expected_output = """ + import secrets + + secrets.SystemRandom().random() + var = "hello" + """ self.run_and_assert(tmpdir, input_code, expected_output) def test_random_alias(self, tmpdir): - input_code = """import random as alleatory - -alleatory.random() -var = "hello" -""" - expected_output = """import secrets - -secrets.SystemRandom().random() -var = "hello" -""" + input_code = """ + import random as alleatory + + alleatory.random() + var = "hello" + """ + expected_output = """ + import secrets + + secrets.SystemRandom().random() + var = "hello" + """ self.run_and_assert(tmpdir, input_code, expected_output) @pytest.mark.parametrize( "input_code,expected_output", [ ( - """import random - -random.randint(0, 10) -var = "hello" -""", - """import secrets - -secrets.SystemRandom().randint(0, 10) -var = "hello" -""", + """ + import random + + random.randint(0, 10) + var = "hello" + """, + """ + import secrets + + secrets.SystemRandom().randint(0, 10) + var = "hello" + """, ), ( - """from random import randint - -randint(0, 10) -var = "hello" -""", - """import secrets - -secrets.SystemRandom().randint(0, 10) -var = "hello" -""", + """ + from random import randint + + randint(0, 10) + var = "hello" + """, + """ + import secrets + + secrets.SystemRandom().randint(0, 10) + var = "hello" + """, ), ], ) @@ -82,48 +92,54 @@ def test_random_randint(self, tmpdir, input_code, expected_output): self.run_and_assert(tmpdir, input_code, expected_output) def test_multiple_calls(self, tmpdir): - input_code = """import random - -random.random() -random.randint() -var = "hello" -""" - expected_output = """import secrets - -secrets.SystemRandom().random() -secrets.SystemRandom().randint() -var = "hello" -""" - self.run_and_assert(tmpdir, input_code, expected_output) + input_code = """ + import random + + random.random() + random.randint() + var = "hello" + """ + expected_output = """ + import secrets + + secrets.SystemRandom().random() + secrets.SystemRandom().randint() + var = "hello" + """ + self.run_and_assert(tmpdir, input_code, expected_output, num_changes=2) @pytest.mark.parametrize( "input_code,expected_output", [ ( - """import random -import csv -random.random() -csv.excel -""", - """import csv -import secrets - -secrets.SystemRandom().random() -csv.excel -""", + """ + import random + import csv + random.random() + csv.excel + """, + """ + import csv + import secrets + + secrets.SystemRandom().random() + csv.excel + """, ), ( - """import random -from csv import excel -random.random() -excel -""", - """from csv import excel -import secrets - -secrets.SystemRandom().random() -excel -""", + """ + import random + from csv import excel + random.random() + excel + """, + """ + from csv import excel + import secrets + + secrets.SystemRandom().random() + excel + """, ), ], ) @@ -131,9 +147,10 @@ def test_random_other_import_untouched(self, tmpdir, input_code, expected_output self.run_and_assert(tmpdir, input_code, expected_output) def test_random_nameerror(self, tmpdir): - input_code = """random.random() + input_code = """ + random.random() -import random""" + import random""" expected_output = input_code self.run_and_assert(tmpdir, input_code, expected_output) @@ -141,17 +158,19 @@ def test_random_multifunctions(self, tmpdir): # Test that `random` import isn't removed if code uses part of the random # library that isn't part of our codemods. If we add the function as one of # our codemods, this test would change. - input_code = """import random + input_code = """ + import random -random.random() -random.__all__ -""" + random.random() + random.__all__ + """ - expected_output = """import random -import secrets + expected_output = """ + import random + import secrets -secrets.SystemRandom().random() -random.__all__ -""" + secrets.SystemRandom().random() + random.__all__ + """ self.run_and_assert(tmpdir, input_code, expected_output) diff --git a/tests/codemods/test_sql_parameterization.py b/tests/codemods/test_sql_parameterization.py index e7c72184..5256d3ba 100644 --- a/tests/codemods/test_sql_parameterization.py +++ b/tests/codemods/test_sql_parameterization.py @@ -7,7 +7,7 @@ class TestSQLQueryParameterization(BaseCodemodTest): codemod = SQLQueryParameterization def test_name(self): - assert self.codemod.name() == "sql-parameterization" + assert self.codemod.name == "sql-parameterization" def test_simple(self, tmpdir): input_code = """\ @@ -27,7 +27,6 @@ def test_simple(self, tmpdir): cursor.execute("SELECT * from USERS WHERE name =?", (name, )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_multiple(self, tmpdir): input_code = """\ @@ -49,7 +48,6 @@ def test_multiple(self, tmpdir): cursor.execute("SELECT * from USERS WHERE name =?" + r" AND phone =?", (name, phone, )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_with_quotes_in_middle(self, tmpdir): input_code = """\ @@ -69,7 +67,6 @@ def test_simple_with_quotes_in_middle(self, tmpdir): cursor.execute("SELECT * from USERS WHERE name =?", ('user_{0}{1}'.format(name, r"_system"), )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_can_deal_with_multiple_variables(self, tmpdir): input_code = """\ @@ -94,7 +91,6 @@ def foo(self, cursor, name, phone): return cursor.execute(a + b + c, (name, phone, )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_if(self, tmpdir): input_code = """\ @@ -114,7 +110,6 @@ def test_simple_if(self, tmpdir): cursor.execute("SELECT * from USERS WHERE name =?", (('Jenny' if True else name), )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_multiple_escaped_quote(self, tmpdir): input_code = """\ @@ -136,7 +131,6 @@ def test_multiple_escaped_quote(self, tmpdir): cursor.execute('SELECT * from USERS WHERE name =?' + ' AND phone =?', (name, phone, )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_simple_concatenated_strings(self, tmpdir): input_code = """\ @@ -156,7 +150,6 @@ def test_simple_concatenated_strings(self, tmpdir): cursor.execute("SELECT * from USERS" "WHERE name =?", (name, )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 class TestSQLQueryParameterizationFormattedString(BaseCodemodTest): @@ -180,7 +173,6 @@ def test_formatted_string_simple(self, tmpdir): cursor.execute(f"SELECT * from USERS WHERE name=?", (name, )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_formatted_string_quote_in_middle(self, tmpdir): input_code = """\ @@ -200,7 +192,6 @@ def test_formatted_string_quote_in_middle(self, tmpdir): cursor.execute(f"SELECT * from USERS WHERE name=?", ('user_{0}_admin'.format(name), )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_formatted_string_with_literal(self, tmpdir): input_code = """\ @@ -220,7 +211,6 @@ def test_formatted_string_with_literal(self, tmpdir): cursor.execute(f"SELECT * from USERS WHERE name=?", ('{0}_{1}'.format(name, 1+2), )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_formatted_string_nested(self, tmpdir): input_code = """\ @@ -240,7 +230,6 @@ def test_formatted_string_nested(self, tmpdir): cursor.execute(f"SELECT * from USERS WHERE name={f"?"}", (name, )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_formatted_string_concat_mixed(self, tmpdir): input_code = """\ @@ -260,7 +249,6 @@ def test_formatted_string_concat_mixed(self, tmpdir): cursor.execute("SELECT * from USERS WHERE name=?", ('{0}_{1}'.format(name, b'123'), )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 def test_multiple_expressions_injection(self, tmpdir): input_code = """\ @@ -280,7 +268,6 @@ def test_multiple_expressions_injection(self, tmpdir): cursor.execute("SELECT * from USERS WHERE name =?", ('{0}_username'.format(name), )) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) - assert len(self.file_context.codemod_changes) == 1 class TestSQLQueryParameterizationNegative(BaseCodemodTest): @@ -299,7 +286,6 @@ def foo(self, cursor, name, phone): return cursor.execute(a + b + c) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_wont_mess_with_byte_strings(self, tmpdir): input_code = """\ @@ -310,7 +296,6 @@ def test_wont_mess_with_byte_strings(self, tmpdir): cursor.execute("SELECT * from USERS WHERE " + b"name ='" + str(1234) + b"'") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_wont_parameterize_literals(self, tmpdir): input_code = """\ @@ -321,7 +306,6 @@ def test_wont_parameterize_literals(self, tmpdir): cursor.execute("SELECT * from USERS WHERE name ='" + str(1234) + "'") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_wont_parameterize_literals_if(self, tmpdir): input_code = """\ @@ -332,7 +316,6 @@ def test_wont_parameterize_literals_if(self, tmpdir): cursor.execute("SELECT * from USERS WHERE name ='" + ('Jenny' if True else 'Lorelei') + "'") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_will_ignore_escaped_quote(self, tmpdir): input_code = """\ @@ -343,7 +326,6 @@ def test_will_ignore_escaped_quote(self, tmpdir): cursor.execute("SELECT * from USERS WHERE name ='Jenny\'s username'") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_already_has_parameters(self, tmpdir): input_code = """\ @@ -357,7 +339,6 @@ def foo(self, cursor, name, phone): return cursor.execute(a + b + c, (phone,)) """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_wont_change_class_attribute(self, tmpdir): # query may be accesed from outside the module by importing A @@ -373,7 +354,6 @@ def foo(self, name, cursor): return cursor.execute(query + name + "'") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 def test_wont_change_module_variable(self, tmpdir): # query may be accesed from outside the module by importing it @@ -386,4 +366,3 @@ def foo(name, cursor): return cursor.execute(query + name + "'") """ self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) - assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_subprocess_shell_false.py b/tests/codemods/test_subprocess_shell_false.py index 4b178a48..71a4ff2f 100644 --- a/tests/codemods/test_subprocess_shell_false.py +++ b/tests/codemods/test_subprocess_shell_false.py @@ -11,7 +11,7 @@ class TestSubprocessShellFalse(BaseCodemodTest): codemod = SubprocessShellFalse def test_name(self): - assert self.codemod.name() == "subprocess-shell-false" + assert self.codemod.name == "subprocess-shell-false" @each_func def test_import(self, tmpdir, func): @@ -24,7 +24,6 @@ def test_import(self, tmpdir, func): subprocess.{func}(args, shell=False) """ self.run_and_assert(tmpdir, input_code, expexted_output) - assert len(self.file_context.codemod_changes) == 1 @each_func def test_from_import(self, tmpdir, func): @@ -37,7 +36,6 @@ def test_from_import(self, tmpdir, func): {func}(args, shell=False) """ self.run_and_assert(tmpdir, input_code, expexted_output) - assert len(self.file_context.codemod_changes) == 1 @each_func def test_no_shell(self, tmpdir, func): @@ -46,7 +44,6 @@ def test_no_shell(self, tmpdir, func): subprocess.{func}(args, timeout=1) """ self.run_and_assert(tmpdir, input_code, input_code) - assert len(self.file_context.codemod_changes) == 0 @each_func def test_shell_False(self, tmpdir, func): @@ -55,7 +52,6 @@ def test_shell_False(self, tmpdir, func): subprocess.{func}(args, shell=False) """ self.run_and_assert(tmpdir, input_code, input_code) - assert len(self.file_context.codemod_changes) == 0 def test_exclude_line(self, tmpdir): input_code = expected = """\ @@ -63,6 +59,9 @@ def test_exclude_line(self, tmpdir): subprocess.run(args, shell=True) """ lines_to_exclude = [2] - self.assert_no_change_line_excluded( - tmpdir, input_code, expected, lines_to_exclude + self.run_and_assert( + tmpdir, + input_code, + expected, + lines_to_exclude=lines_to_exclude, ) diff --git a/tests/codemods/test_tempfile_mktemp.py b/tests/codemods/test_tempfile_mktemp.py index b4bb69e5..3d28019b 100644 --- a/tests/codemods/test_tempfile_mktemp.py +++ b/tests/codemods/test_tempfile_mktemp.py @@ -6,7 +6,7 @@ class TestTempfileMktemp(BaseSemgrepCodemodTest): codemod = TempfileMktemp def test_name(self): - assert self.codemod.name() == "secure-tempfile" + assert self.codemod.name == "secure-tempfile" def test_import(self, tmpdir): input_code = """import tempfile diff --git a/tests/codemods/test_url_sandbox.py b/tests/codemods/test_url_sandbox.py index 9978dd9a..800649ce 100644 --- a/tests/codemods/test_url_sandbox.py +++ b/tests/codemods/test_url_sandbox.py @@ -1,17 +1,19 @@ import pytest +import mock from codemodder.dependency import Security from core_codemods.url_sandbox import UrlSandbox from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest +@mock.patch("codemodder.codemods.api.FileContext.add_dependency") class TestUrlSandbox(BaseSemgrepCodemodTest): codemod = UrlSandbox - def test_name(self): - assert self.codemod.name() == "url-sandbox" + def test_name(self, _): + assert self.codemod.name == "url-sandbox" - def test_import_requests(self, tmpdir): + def test_import_requests(self, add_dependency, tmpdir): input_code = """ import requests @@ -27,9 +29,9 @@ def test_import_requests(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + add_dependency.assert_called_once_with(Security) - def test_from_requests(self, tmpdir): + def test_from_requests(self, add_dependency, tmpdir): input_code = """ from requests import get @@ -45,9 +47,9 @@ def test_from_requests(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + add_dependency.assert_called_once_with(Security) - def test_requests_nameerror(self, tmpdir): + def test_requests_nameerror(self, _, tmpdir): input_code = """ url = input() requests.get(url) @@ -96,11 +98,13 @@ def test_requests_nameerror(self, tmpdir): ), ], ) - def test_requests_other_import_untouched(self, tmpdir, input_code, expected): + def test_requests_other_import_untouched( + self, add_dependency, tmpdir, input_code, expected + ): self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + add_dependency.assert_called_once_with(Security) - def test_requests_multifunctions(self, tmpdir): + def test_requests_multifunctions(self, add_dependency, tmpdir): # Test that `requests` import isn't removed if code uses part of the requests # library that isn't part of our codemods. If we add the function as one of # our codemods, this test would change. @@ -122,9 +126,9 @@ def test_requests_multifunctions(self, tmpdir): """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + add_dependency.assert_called_once_with(Security) - def test_custom_get(self, tmpdir): + def test_custom_get(self, _, tmpdir): input_code = """ from app_funcs import get @@ -134,7 +138,7 @@ def test_custom_get(self, tmpdir): expected = input_code self.run_and_assert(tmpdir, input_code, expected) - def test_ambiguous_get(self, tmpdir): + def test_ambiguous_get(self, _, tmpdir): input_code = """ from requests import get @@ -147,7 +151,7 @@ def get(url): expected = input_code self.run_and_assert(tmpdir, input_code, expected) - def test_from_requests_with_alias(self, tmpdir): + def test_from_requests_with_alias(self, add_dependency, tmpdir): input_code = """ from requests import get as got @@ -163,9 +167,9 @@ def test_from_requests_with_alias(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + add_dependency.assert_called_once_with(Security) - def test_requests_with_alias(self, tmpdir): + def test_requests_with_alias(self, add_dependency, tmpdir): input_code = """ import requests as req @@ -181,9 +185,9 @@ def test_requests_with_alias(self, tmpdir): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) - self.assert_dependency(Security) + add_dependency.assert_called_once_with(Security) - def test_ignore_hardcoded(self, tmpdir): + def test_ignore_hardcoded(self, _, tmpdir): expected = input_code = """ import requests @@ -192,7 +196,7 @@ def test_ignore_hardcoded(self, tmpdir): self.run_and_assert(tmpdir, input_code, expected) - def test_ignore_hardcoded_from_global_variable(self, tmpdir): + def test_ignore_hardcoded_from_global_variable(self, _, tmpdir): expected = input_code = """ import requests @@ -202,7 +206,7 @@ def test_ignore_hardcoded_from_global_variable(self, tmpdir): self.run_and_assert(tmpdir, input_code, expected) - def test_ignore_hardcoded_from_local_variable(self, tmpdir): + def test_ignore_hardcoded_from_local_variable(self, _, tmpdir): expected = input_code = """ import requests @@ -213,7 +217,7 @@ def foo(): self.run_and_assert(tmpdir, input_code, expected) - def test_ignore_hardcoded_from_local_variable_transitive(self, tmpdir): + def test_ignore_hardcoded_from_local_variable_transitive(self, _, tmpdir): expected = input_code = """ import requests @@ -225,7 +229,9 @@ def foo(): self.run_and_assert(tmpdir, input_code, expected) - def test_ignore_hardcoded_from_local_variable_transitive_reassigned(self, tmpdir): + def test_ignore_hardcoded_from_local_variable_transitive_reassigned( + self, _, tmpdir + ): input_code = """ import requests diff --git a/tests/codemods/test_use_defused_xml.py b/tests/codemods/test_use_defused_xml.py index 1df986a8..0c22d42a 100644 --- a/tests/codemods/test_use_defused_xml.py +++ b/tests/codemods/test_use_defused_xml.py @@ -1,4 +1,5 @@ import pytest +import mock from codemodder.dependency import DefusedXML from core_codemods.use_defused_xml import ( @@ -10,12 +11,13 @@ from tests.codemods.base_codemod_test import BaseCodemodTest +@mock.patch("codemodder.codemods.api.FileContext.add_dependency") class TestUseDefusedXml(BaseCodemodTest): codemod = UseDefusedXml @pytest.mark.parametrize("method", ETREE_METHODS) @pytest.mark.parametrize("module", ["ElementTree", "cElementTree"]) - def test_etree_simple_call(self, tmpdir, module, method): + def test_etree_simple_call(self, add_dependency, tmpdir, module, method): original_code = f""" from xml.etree.{module} import {method}, ElementPath @@ -30,10 +32,10 @@ def test_etree_simple_call(self, tmpdir, module, method): """ self.run_and_assert(tmpdir, original_code, new_code) - self.assert_dependency(DefusedXML) + add_dependency.assert_called_once_with(DefusedXML) @pytest.mark.parametrize("method", ETREE_METHODS) - def test_etree_module_alias(self, tmpdir, method): + def test_etree_module_alias(self, add_dependency, tmpdir, method): original_code = f""" import xml.etree.ElementTree as alias import xml.etree.cElementTree as calias @@ -49,12 +51,12 @@ def test_etree_module_alias(self, tmpdir, method): cet = defusedxml.ElementTree.{method}('some.xml') """ - self.run_and_assert(tmpdir, original_code, new_code) - self.assert_dependency(DefusedXML) + self.run_and_assert(tmpdir, original_code, new_code, num_changes=2) + add_dependency.assert_called_once_with(DefusedXML) @pytest.mark.parametrize("method", ETREE_METHODS) @pytest.mark.parametrize("module", ["ElementTree", "cElementTree"]) - def test_etree_attribute_call(self, tmpdir, module, method): + def test_etree_attribute_call(self, add_dependency, tmpdir, module, method): original_code = f""" from xml.etree import {module} @@ -68,9 +70,9 @@ def test_etree_attribute_call(self, tmpdir, module, method): """ self.run_and_assert(tmpdir, original_code, new_code) - self.assert_dependency(DefusedXML) + add_dependency.assert_called_once_with(DefusedXML) - def test_etree_elementtree_with_alias(self, tmpdir): + def test_etree_elementtree_with_alias(self, add_dependency, tmpdir): original_code = """ from xml.etree import ElementTree as ET @@ -84,9 +86,9 @@ def test_etree_elementtree_with_alias(self, tmpdir): """ self.run_and_assert(tmpdir, original_code, new_code) - self.assert_dependency(DefusedXML) + add_dependency.assert_called_once_with(DefusedXML) - def test_etree_parse_with_alias(self, tmpdir): + def test_etree_parse_with_alias(self, add_dependency, tmpdir): original_code = """ from xml.etree.ElementTree import parse as parse_xml @@ -100,10 +102,10 @@ def test_etree_parse_with_alias(self, tmpdir): """ self.run_and_assert(tmpdir, original_code, new_code) - self.assert_dependency(DefusedXML) + add_dependency.assert_called_once_with(DefusedXML) @pytest.mark.parametrize("method", SAX_METHODS) - def test_sax_simple_call(self, tmpdir, method): + def test_sax_simple_call(self, add_dependency, tmpdir, method): original_code = f""" from xml.sax import {method} @@ -117,10 +119,10 @@ def test_sax_simple_call(self, tmpdir, method): """ self.run_and_assert(tmpdir, original_code, new_code) - self.assert_dependency(DefusedXML) + add_dependency.assert_called_once_with(DefusedXML) @pytest.mark.parametrize("method", SAX_METHODS) - def test_sax_attribute_call(self, tmpdir, method): + def test_sax_attribute_call(self, add_dependency, tmpdir, method): original_code = f""" from xml import sax @@ -134,11 +136,11 @@ def test_sax_attribute_call(self, tmpdir, method): """ self.run_and_assert(tmpdir, original_code, new_code) - self.assert_dependency(DefusedXML) + add_dependency.assert_called_once_with(DefusedXML) @pytest.mark.parametrize("method", DOM_METHODS) @pytest.mark.parametrize("module", ["minidom", "pulldom"]) - def test_dom_simple_call(self, tmpdir, module, method): + def test_dom_simple_call(self, add_dependency, tmpdir, module, method): original_code = f""" from xml.dom.{module} import {method} @@ -152,4 +154,4 @@ def test_dom_simple_call(self, tmpdir, module, method): """ self.run_and_assert(tmpdir, original_code, new_code) - self.assert_dependency(DefusedXML) + add_dependency.assert_called_once_with(DefusedXML) diff --git a/tests/codemods/test_use_generator.py b/tests/codemods/test_use_generator.py index 709c76af..9b9947ea 100644 --- a/tests/codemods/test_use_generator.py +++ b/tests/codemods/test_use_generator.py @@ -35,6 +35,9 @@ def test_exclude_line(self, tmpdir): x = any([i for i in range(10)]) """ lines_to_exclude = [1] - self.assert_no_change_line_excluded( - tmpdir, input_code, expected, lines_to_exclude + self.run_and_assert( + tmpdir, + input_code, + expected, + lines_to_exclude=lines_to_exclude, ) diff --git a/tests/codemods/test_use_set_literal.py b/tests/codemods/test_use_set_literal.py index 29ee31aa..29b398ed 100644 --- a/tests/codemods/test_use_set_literal.py +++ b/tests/codemods/test_use_set_literal.py @@ -13,7 +13,6 @@ def test_simple(self, tmpdir): x = {1, 2, 3} """ self.run_and_assert(tmpdir, original_code, expected_code) - assert self.file_context and len(self.file_context.codemod_changes) == 1 def test_empty_list(self, tmpdir): original_code = """ @@ -23,14 +22,12 @@ def test_empty_list(self, tmpdir): x = set() """ self.run_and_assert(tmpdir, original_code, expected_code) - assert self.file_context and len(self.file_context.codemod_changes) == 1 def test_already_empty(self, tmpdir): original_code = """ x = set() """ self.run_and_assert(tmpdir, original_code, original_code) - assert self.file_context and len(self.file_context.codemod_changes) == 0 def test_not_builtin(self, tmpdir): original_code = """ @@ -38,11 +35,9 @@ def test_not_builtin(self, tmpdir): x = set([1, 2, 3]) """ self.run_and_assert(tmpdir, original_code, original_code) - assert self.file_context and len(self.file_context.codemod_changes) == 0 def test_not_list_literal(self, tmpdir): original_code = """ x = set(some_previously_defined_list) """ self.run_and_assert(tmpdir, original_code, original_code) - assert self.file_context and len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_walrus_if.py b/tests/codemods/test_walrus_if.py index add39108..c6d3c177 100644 --- a/tests/codemods/test_walrus_if.py +++ b/tests/codemods/test_walrus_if.py @@ -70,7 +70,7 @@ def test_walrus_if_multiple(self, tmpdir): if (foo := hello()) == "bar": whatever(foo) """ - self.run_and_assert(tmpdir, input_code, expected_output) + self.run_and_assert(tmpdir, input_code, expected_output, num_changes=2) def test_walrus_if_in_function(self, tmpdir): """Make sure this works inside more complex code""" @@ -101,7 +101,7 @@ def test_walrus_if_nested(self, tmpdir): if (y := do_something_else(x)) is not None: bizbaz(x, y) """ - self.run_and_assert(tmpdir, input_code, expected_output) + self.run_and_assert(tmpdir, input_code, expected_output, num_changes=2) def test_walrus_if_used_inner(self, tmpdir): """Make sure this works inside more complex code""" diff --git a/tests/codemods/test_with_threading_lock.py b/tests/codemods/test_with_threading_lock.py index 7a1d56f0..7914e0be 100644 --- a/tests/codemods/test_with_threading_lock.py +++ b/tests/codemods/test_with_threading_lock.py @@ -18,7 +18,7 @@ class TestWithThreadingLock(BaseSemgrepCodemodTest): codemod = WithThreadingLock def test_rule_ids(self): - assert self.codemod.name() == "bad-lock-with-statement" + assert self.codemod.name == "bad-lock-with-statement" @each_class def test_import(self, tmpdir, klass): @@ -92,102 +92,126 @@ class TestThreadingNameResolution(BaseSemgrepCodemodTest): codemod = WithThreadingLock @pytest.mark.parametrize( - "input_code,expected_code", + "input_code,expected_code,num_changes", [ ( - """from threading import Lock -lock = 1 -with Lock(): - ... -""", - """from threading import Lock -lock = 1 -lock_1 = Lock() -with lock_1: - ... -""", + """ + from threading import Lock + + lock = 1 + with Lock(): + ... + """, + """ + from threading import Lock + + lock = 1 + lock_1 = Lock() + with lock_1: + ... + """, + 1, ), ( - """from threading import Lock -from something import lock -with Lock(): - ... -""", - """from threading import Lock -from something import lock -lock_1 = Lock() -with lock_1: - ... -""", + """ + from threading import Lock + from something import lock + with Lock(): + ... + """, + """ + from threading import Lock + from something import lock + lock_1 = Lock() + with lock_1: + ... + """, + 1, ), ( - """import threading -lock = 1 -def f(l): - with threading.Lock(): - return [lock_1 for lock_1 in l] -""", - """import threading -lock = 1 -def f(l): - lock_2 = threading.Lock() - with lock_2: - return [lock_1 for lock_1 in l] -""", + """ + import threading + + lock = 1 + def f(l): + with threading.Lock(): + return [lock_1 for lock_1 in l] + """, + """ + import threading + + lock = 1 + def f(l): + lock_2 = threading.Lock() + with lock_2: + return [lock_1 for lock_1 in l] + """, + 1, ), ( - """import threading -with threading.Lock(): - int("1") -with threading.Lock(): - print() -var = 1 -with threading.Lock(): - print() -""", - """import threading -lock = threading.Lock() -with lock: - int("1") -lock_1 = threading.Lock() -with lock_1: - print() -var = 1 -lock_2 = threading.Lock() -with lock_2: - print() -""", + """ + import threading + with threading.Lock(): + int("1") + with threading.Lock(): + print() + var = 1 + with threading.Lock(): + print() + """, + """ + import threading + lock = threading.Lock() + with lock: + int("1") + lock_1 = threading.Lock() + with lock_1: + print() + var = 1 + lock_2 = threading.Lock() + with lock_2: + print() + """, + 3, ), ( - """import threading -with threading.Lock(): - with threading.Lock(): - print() -""", - """import threading -lock_1 = threading.Lock() -with lock_1: - lock = threading.Lock() - with lock: - print() -""", + """ + import threading + with threading.Lock(): + with threading.Lock(): + print() + """, + """ + import threading + lock_1 = threading.Lock() + with lock_1: + lock = threading.Lock() + with lock: + print() + """, + 2, ), ( - """import threading -def my_func(): - lock = "whatever" - with threading.Lock(): - foo() -""", - """import threading -def my_func(): - lock = "whatever" - lock_1 = threading.Lock() - with lock_1: - foo() -""", + """ + import threading + + def my_func(): + lock = "whatever" + with threading.Lock(): + foo() + """, + """ + import threading + + def my_func(): + lock = "whatever" + lock_1 = threading.Lock() + with lock_1: + foo() + """, + 1, ), ], ) - def test_name_resolution(self, tmpdir, input_code, expected_code): - self.run_and_assert(tmpdir, input_code, expected_code) + def test_name_resolution(self, tmpdir, input_code, expected_code, num_changes): + self.run_and_assert(tmpdir, input_code, expected_code, num_changes=num_changes) diff --git a/tests/conftest.py b/tests/conftest.py index 2ef1826b..37d7cd1b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,57 +1,39 @@ import pytest -import mock -@pytest.fixture(autouse=True, scope="module") -def disable_write_report(): +@pytest.fixture(autouse=True) +def disable_write_report(mocker): """ Unit tests should not write analysis report or update any source files. """ - patch_write_report = mock.patch( - "codemodder.report.codetf_reporter.CodeTF.write_report" - ) - - patch_write_report.start() - yield - patch_write_report.stop() + mocker.patch("codemodder.report.codetf_reporter.CodeTF.write_report") -@pytest.fixture(autouse=True, scope="module") -def disable_update_code(): +@pytest.fixture(autouse=True) +def disable_update_code(mocker): """ Unit tests should not write analysis report or update any source files. """ - patch_update_code = mock.patch("codemodder.codemodder.update_code") - patch_update_code.start() - yield - patch_update_code.stop() + mocker.patch("codemodder.codemods.libcst_transformer.update_code") -@pytest.fixture(autouse=True, scope="module") -def disable_semgrep_run(): +@pytest.fixture(autouse=True) +def disable_semgrep_run(mocker): """ Semgrep run is slow so we mock them or pass hardcoded results when possible. """ - semgrep_run = mock.patch("codemodder.codemods.base_codemod.semgrep_run") + mocker.patch("codemodder.codemods.semgrep.semgrep_run") - semgrep_run.start() - yield - semgrep_run.stop() - -@pytest.fixture(autouse=True, scope="module") -def disable_write_dependencies(): +@pytest.fixture(autouse=True) +def disable_write_dependencies(mocker): """ Unit tests should not write any dependency files """ - dm_write = mock.patch( + mocker.patch( "codemodder.dependency_management.dependency_manager.DependencyManager.write" ) - dm_write.start() - yield - dm_write.stop() - @pytest.fixture(scope="module") def pkg_with_reqs_txt(tmp_path_factory): diff --git a/tests/dependency_management/test_setup_py_writer.py b/tests/dependency_management/test_setup_py_writer.py index 0fa919f8..affc5955 100644 --- a/tests/dependency_management/test_setup_py_writer.py +++ b/tests/dependency_management/test_setup_py_writer.py @@ -85,7 +85,7 @@ def test_update_setuppy_comma_single_element_inline(tmpdir): store = PackageStore( type=FileType.SETUP_PY, - file=str(dependency_file), + file=dependency_file, dependencies=set(), py_versions=[">=3.6"], ) diff --git a/tests/test_codemod_docs.py b/tests/test_codemod_docs.py index 2329a620..a8cbdae7 100644 --- a/tests/test_codemod_docs.py +++ b/tests/test_codemod_docs.py @@ -1,5 +1,6 @@ import pytest +from codemodder.codemods.api import BaseCodemod from codemodder.registry import load_registered_codemods from codemodder.scripts.generate_docs import METADATA @@ -10,10 +11,11 @@ def pytest_generate_tests(metafunc): metafunc.parametrize("codemod", registry.codemods) -def test_load_codemod_docs_info(codemod): - if codemod.name in ["order-imports"]: - pytest.xfail(reason="order-imports has no description") - assert codemod._get_description() is not None # pylint: disable=protected-access +def test_load_codemod_docs_info(codemod: BaseCodemod): + if codemod.name in ["order-imports", "refactor-new-api"]: + pytest.xfail(reason=f"{codemod.name} has no description") + + assert codemod.description is not None # pylint: disable=protected-access assert codemod.review_guidance in ( "Merge After Review", "Merge After Cursory Review", diff --git a/tests/test_codemodder.py b/tests/test_codemodder.py index ad828c52..ff5e27fe 100644 --- a/tests/test_codemodder.py +++ b/tests/test_codemodder.py @@ -61,14 +61,13 @@ def test_cst_parsing_fails(self, mock_reporting, mock_parse): requests_report = results_by_codemod[0] assert requests_report["changeset"] == [] - assert len(requests_report["failedFiles"]) == 2 + assert len(requests_report["failedFiles"]) == 1 assert sorted(requests_report["failedFiles"]) == [ - "tests/samples/make_request.py", "tests/samples/unverified_request.py", ] - @mock.patch("codemodder.codemodder.update_code") - @mock.patch("codemodder.codemods.base_codemod.semgrep_run", side_effect=semgrep_run) + @mock.patch("codemodder.codemods.libcst_transformer.update_code") + @mock.patch("codemodder.codemods.semgrep.semgrep_run", side_effect=semgrep_run) def test_dry_run(self, _, mock_update_code, tmpdir): codetf = tmpdir / "result.codetf" args = [ @@ -111,7 +110,7 @@ def test_reporting(self, mock_reporting, dry_run): assert len(results_by_codemod) == 3 - @mock.patch("codemodder.codemods.base_codemod.semgrep_run") + @mock.patch("codemodder.codemods.semgrep.semgrep_run") def test_no_codemods_to_run(self, mock_semgrep_run, tmpdir): codetf = tmpdir / "result.codetf" assert not codetf.exists() diff --git a/tests/test_context.py b/tests/test_context.py index 1c10707f..b01b45b4 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -10,7 +10,7 @@ def test_successful_dependency_description(self, mocker): repo_manager = PythonRepoManager(mocker.Mock()) codemod = registry.match_codemods(codemod_include=["url-sandbox"])[0] - context = Context(mocker.Mock(), True, False, registry, repo_manager) + context = Context(mocker.Mock(), True, False, registry, repo_manager, [], []) context.add_dependencies(codemod.id, {Security}) pkg_store_name = "pyproject.toml" @@ -39,7 +39,7 @@ def test_failed_dependency_description(self, mocker): repo_manager = PythonRepoManager(mocker.Mock()) codemod = registry.match_codemods(codemod_include=["url-sandbox"])[0] - context = Context(mocker.Mock(), True, False, registry, repo_manager) + context = Context(mocker.Mock(), True, False, registry, repo_manager, [], []) context.add_dependencies(codemod.id, {Security}) mocker.patch(