Skip to content

Commit

Permalink
add support for httpx for requests verify
Browse files Browse the repository at this point in the history
  • Loading branch information
clavedeluna committed Feb 6, 2024
1 parent 08f8fb4 commit 8eab1d7
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 8 deletions.
29 changes: 23 additions & 6 deletions src/core_codemods/requests_verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,29 @@ class RequestsVerify(SimpleCodemod):
"Makes any calls to requests.{func} with `verify=False` to `verify=True`."
)
detector_pattern = """
rules:
- patterns:
- pattern: requests.$F(..., verify=False, ...)
- pattern-inside: |
import requests
...
rules:
- pattern-either:
- patterns:
- pattern: requests.$F(..., verify=False, ...)
- pattern-inside: |
import requests
...
- patterns:
- pattern: httpx.$F(..., verify=False, ...)
- pattern-inside: |
import httpx
...
- patterns:
- pattern: httpx.$CLASS(..., verify=False, ...)
- pattern-inside: |
import httpx
...
- metavariable-pattern:
metavariable: $CLASS
patterns:
- pattern-either:
- pattern: Client
- pattern: AsyncClient
"""

def on_result_found(self, original_node, updated_node):
Expand Down
130 changes: 128 additions & 2 deletions tests/codemods/test_request_verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from core_codemods.requests_verify import RequestsVerify
from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest

# todo: add stream, etc specific to httpx
each_func = pytest.mark.parametrize("func", ["get", "post", "request"])
each_library = pytest.mark.parametrize("library", ["requests", "httpx"])

Expand All @@ -25,7 +24,7 @@ def test_default_verify(self, tmpdir, library, func):

@each_func
@each_library
@pytest.mark.parametrize("verify_val", ["True", "'/some/palibrary, th'"])
@pytest.mark.parametrize("verify_val", ["True", "'/some/path'"])
def test_verify(self, tmpdir, verify_val, library, func):
input_code = f"""
import {library}
Expand Down Expand Up @@ -110,3 +109,130 @@ def test_multiple_kwargs(self, tmpdir, library, func):
var = "hello"
"""
self.run_and_assert(tmpdir, input_code, expected)


class TestHttpxSpecific(BaseSemgrepCodemodTest):
codemod = RequestsVerify

def test_stream(self, tmpdir):
input_code = """
import httpx
with httpx.stream("GET", "https://www.example.com", verify=False) as r:
for data in r.iter_bytes():
print(data)
"""
expected = """
import httpx
with httpx.stream("GET", "https://www.example.com", verify=True) as r:
for data in r.iter_bytes():
print(data)
"""
self.run_and_assert(tmpdir, input_code, expected)

def test_stream_from_import(self, tmpdir):
input_code = """
from httpx import stream
with stream("GET", "https://www.example.com", verify=False) as r:
for data in r.iter_bytes():
print(data)
"""
expected = """
from httpx import stream
with stream("GET", "https://www.example.com", verify=True) as r:
for data in r.iter_bytes():
print(data)
"""
self.run_and_assert(tmpdir, input_code, expected)

def test_verify_with_sslcontext(self, tmpdir):
input_code = f"""
import ssl
import httpx
context = ssl.create_default_context()
context.load_verify_locations(cafile="/tmp/temp.pem")
httpx.get('https://google.com', verify=context)
"""
self.run_and_assert(tmpdir, input_code, input_code)

def test_client_verify(self, tmpdir):
input_code = f"""
import httpx
client = httpx.Client(verify=False)
try:
client.get('https://example.com')
finally:
client.close()
"""
expected_code = f"""
import httpx
client = httpx.Client(verify=True)
try:
client.get('https://example.com')
finally:
client.close()
"""
self.run_and_assert(tmpdir, input_code, expected_code)

def test_client_verify_from_import(self, tmpdir):
input_code = f"""
from httpx import Client
c = Client(verify=False)
try:
c.get('https://example.com')
finally:
c.close()
"""
expected_code = f"""
from httpx import Client
c = Client(verify=True)
try:
c.get('https://example.com')
finally:
c.close()
"""
self.run_and_assert(tmpdir, input_code, expected_code)

def test_client_verify_context_manager(self, tmpdir):
input_code = f"""
import httpx
with httpx.Client(verify=False) as client:
client.get('https://example.com')
"""
expected_code = f"""
import httpx
with httpx.Client(verify=True) as client:
client.get('https://example.com')
"""
self.run_and_assert(tmpdir, input_code, expected_code)

def test_async_client_verify(self, tmpdir):
input_code = f"""
import httpx
client = httpx.AsyncClient(verify=False)
try:
await client.get('https://example.com')
finally:
await client.close()
"""
expected_code = f"""
import httpx
client = httpx.AsyncClient(verify=True)
try:
await client.get('https://example.com')
finally:
await client.close()
"""
self.run_and_assert(tmpdir, input_code, expected_code)

def test_async_client_verify_context_manager(self, tmpdir):
input_code = f"""
import httpx
async with httpx.AsyncClient(verify=False) as client:
await client.get('https://example.com')
"""
expected_code = f"""
import httpx
async with httpx.AsyncClient(verify=True) as client:
await client.get('https://example.com')
"""
self.run_and_assert(tmpdir, input_code, expected_code)

0 comments on commit 8eab1d7

Please sign in to comment.