From 8eab1d77a5656b05528543e39633c047fffa0e6e Mon Sep 17 00:00:00 2001 From: clavedeluna Date: Tue, 6 Feb 2024 14:42:45 -0300 Subject: [PATCH] add support for httpx for requests verify --- src/core_codemods/requests_verify.py | 29 ++++-- tests/codemods/test_request_verify.py | 130 +++++++++++++++++++++++++- 2 files changed, 151 insertions(+), 8 deletions(-) diff --git a/src/core_codemods/requests_verify.py b/src/core_codemods/requests_verify.py index 47380efd..827ccffa 100644 --- a/src/core_codemods/requests_verify.py +++ b/src/core_codemods/requests_verify.py @@ -23,12 +23,29 @@ class RequestsVerify(SimpleCodemod): "Makes any calls to requests.{func} with `verify=False` to `verify=True`." ) detector_pattern = """ - rules: - - patterns: - - pattern: requests.$F(..., verify=False, ...) - - pattern-inside: | - import requests - ... + rules: + - pattern-either: + - patterns: + - pattern: requests.$F(..., verify=False, ...) + - pattern-inside: | + import requests + ... + - patterns: + - pattern: httpx.$F(..., verify=False, ...) + - pattern-inside: | + import httpx + ... + - patterns: + - pattern: httpx.$CLASS(..., verify=False, ...) + - pattern-inside: | + import httpx + ... + - metavariable-pattern: + metavariable: $CLASS + patterns: + - pattern-either: + - pattern: Client + - pattern: AsyncClient """ def on_result_found(self, original_node, updated_node): diff --git a/tests/codemods/test_request_verify.py b/tests/codemods/test_request_verify.py index 0f7c7e7a..19cccd38 100644 --- a/tests/codemods/test_request_verify.py +++ b/tests/codemods/test_request_verify.py @@ -2,7 +2,6 @@ from core_codemods.requests_verify import RequestsVerify from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest -# todo: add stream, etc specific to httpx each_func = pytest.mark.parametrize("func", ["get", "post", "request"]) each_library = pytest.mark.parametrize("library", ["requests", "httpx"]) @@ -25,7 +24,7 @@ def test_default_verify(self, tmpdir, library, func): @each_func @each_library - @pytest.mark.parametrize("verify_val", ["True", "'/some/palibrary, th'"]) + @pytest.mark.parametrize("verify_val", ["True", "'/some/path'"]) def test_verify(self, tmpdir, verify_val, library, func): input_code = f""" import {library} @@ -110,3 +109,130 @@ def test_multiple_kwargs(self, tmpdir, library, func): var = "hello" """ self.run_and_assert(tmpdir, input_code, expected) + + +class TestHttpxSpecific(BaseSemgrepCodemodTest): + codemod = RequestsVerify + + def test_stream(self, tmpdir): + input_code = """ + import httpx + with httpx.stream("GET", "https://www.example.com", verify=False) as r: + for data in r.iter_bytes(): + print(data) + """ + expected = """ + import httpx + with httpx.stream("GET", "https://www.example.com", verify=True) as r: + for data in r.iter_bytes(): + print(data) + """ + self.run_and_assert(tmpdir, input_code, expected) + + def test_stream_from_import(self, tmpdir): + input_code = """ + from httpx import stream + with stream("GET", "https://www.example.com", verify=False) as r: + for data in r.iter_bytes(): + print(data) + """ + expected = """ + from httpx import stream + with stream("GET", "https://www.example.com", verify=True) as r: + for data in r.iter_bytes(): + print(data) + """ + self.run_and_assert(tmpdir, input_code, expected) + + def test_verify_with_sslcontext(self, tmpdir): + input_code = f""" + import ssl + import httpx + context = ssl.create_default_context() + context.load_verify_locations(cafile="/tmp/temp.pem") + httpx.get('https://google.com', verify=context) + """ + self.run_and_assert(tmpdir, input_code, input_code) + + def test_client_verify(self, tmpdir): + input_code = f""" + import httpx + client = httpx.Client(verify=False) + try: + client.get('https://example.com') + finally: + client.close() + """ + expected_code = f""" + import httpx + client = httpx.Client(verify=True) + try: + client.get('https://example.com') + finally: + client.close() + """ + self.run_and_assert(tmpdir, input_code, expected_code) + + def test_client_verify_from_import(self, tmpdir): + input_code = f""" + from httpx import Client + c = Client(verify=False) + try: + c.get('https://example.com') + finally: + c.close() + """ + expected_code = f""" + from httpx import Client + c = Client(verify=True) + try: + c.get('https://example.com') + finally: + c.close() + """ + self.run_and_assert(tmpdir, input_code, expected_code) + + def test_client_verify_context_manager(self, tmpdir): + input_code = f""" + import httpx + with httpx.Client(verify=False) as client: + client.get('https://example.com') + """ + expected_code = f""" + import httpx + with httpx.Client(verify=True) as client: + client.get('https://example.com') + """ + self.run_and_assert(tmpdir, input_code, expected_code) + + def test_async_client_verify(self, tmpdir): + input_code = f""" + import httpx + client = httpx.AsyncClient(verify=False) + try: + await client.get('https://example.com') + finally: + await client.close() + """ + expected_code = f""" + import httpx + client = httpx.AsyncClient(verify=True) + try: + await client.get('https://example.com') + finally: + await client.close() + """ + self.run_and_assert(tmpdir, input_code, expected_code) + + def test_async_client_verify_context_manager(self, tmpdir): + input_code = f""" + import httpx + async with httpx.AsyncClient(verify=False) as client: + await client.get('https://example.com') + """ + expected_code = f""" + import httpx + async with httpx.AsyncClient(verify=True) as client: + await client.get('https://example.com') + """ + self.run_and_assert(tmpdir, input_code, expected_code)