Skip to content

Commit

Permalink
B113 request without timeout (#57)
Browse files Browse the repository at this point in the history
* Update settings.py

* Update utils.py

* Update utils.py

* Update clients.py

* Update client.py

* Update client.py

* Update client.py

* Update tests.py

* Update tests.py
  • Loading branch information
mwalkowski authored Mar 13, 2023
1 parent c5b46e3 commit 63afadc
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 16 deletions.
7 changes: 4 additions & 3 deletions src/vmc/common/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from vmc.common.apps import CommonConfig
from vmc.common.enum import TupleValueEnum
from vmc.common.utils import is_downloadable, get_file, handle_ranges
from vmc.config.settings import DEFAULT_REQUEST_TIMEOUT


def get_fixture_location(module, name):
Expand Down Expand Up @@ -97,7 +98,7 @@ def test_call_is_downloadable(self, content_type, verify, result, requests):

self.assertEqual(is_downloadable(UtilsTest.URL, verify), result)

requests.head.assert_called_once_with(UtilsTest.URL, allow_redirects=True, verify=verify)
requests.head.assert_called_once_with(UtilsTest.URL, allow_redirects=True, verify=verify, timeout=DEFAULT_REQUEST_TIMEOUT)

@parameterized.expand([
('test.zip', 'zip', True, b'ZIP is working\n'),
Expand All @@ -115,7 +116,7 @@ def test_call_get_file(self, filename, content_type, verify, result, requests):
)

self.assertEqual(get_file(UtilsTest.URL, verify).readline(), result)
requests.get.assert_called_once_with(UtilsTest.URL, verify=verify)
requests.get.assert_called_once_with(UtilsTest.URL, verify=verify, timeout=DEFAULT_REQUEST_TIMEOUT)

@patch('vmc.common.utils.requests')
def test_call_get_file_json(self, requests):
Expand Down Expand Up @@ -144,4 +145,4 @@ def test_handle_ranges(self):
self.assertEqual(handle_ranges([s, e2]), ["192.168.1.1", "192.168.20.10"])
self.assertEqual(handle_ranges([s, e3]), ["192.168.1.1", "192.169.1.1"])
self.assertEqual(handle_ranges([s, e4]), ["192.168.1.1", "192.168.1.10"])
self.assertEqual(handle_ranges([s, e4]), ["192.168.1.1", "192.168.1.10"])
self.assertEqual(handle_ranges([s, e4]), ["192.168.1.1", "192.168.1.10"])
5 changes: 3 additions & 2 deletions src/vmc/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from io import BytesIO
from zipfile import ZipFile
from vmc.config.celery import app as celery_app
from vmc.config.settings import DEFAULT_REQUEST_TIMEOUT


class ThreadPoolExecutor:
Expand All @@ -46,7 +47,7 @@ def wait_for_all(self):


def is_downloadable(url: str, verify: bool = True) -> bool:
h = requests.head(url, allow_redirects=True, verify=verify)
h = requests.head(url, allow_redirects=True, verify=verify, timeout=DEFAULT_REQUEST_TIMEOUT)
header = h.headers
content_type = header.get('Content-Type')
if 'text' in content_type.lower():
Expand All @@ -59,7 +60,7 @@ def is_downloadable(url: str, verify: bool = True) -> bool:
def get_file(url: str, verify: bool = True) -> [BytesIO, None]:
content = None
if is_downloadable(url, verify):
response = requests.get(url, verify=verify)
response = requests.get(url, verify=verify, timeout=DEFAULT_REQUEST_TIMEOUT)

if response.status_code == 200:

Expand Down
2 changes: 2 additions & 0 deletions src/vmc/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def get_config(key, default):
# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = get_config('secret_key', 'SECRET_KEY')

DEFAULT_REQUEST_TIMEOUT = 30

# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = get_config('debug', False)

Expand Down
3 changes: 2 additions & 1 deletion src/vmc/scanners/nessus/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from typing import Dict
from io import BytesIO

from vmc.config.settings import DEFAULT_REQUEST_TIMEOUT
from vmc.scanners.models import Config
from vmc.scanners.clients import Client

Expand Down Expand Up @@ -234,7 +235,7 @@ def get_version(self) -> str:

def _get_version(self) -> str:
try:
resp = requests.get(F'{self._url}/server/properties', verify=not self._config.insecure)
resp = requests.get(F'{self._url}/server/properties', verify=not self._config.insecure, timeout=DEFAULT_REQUEST_TIMEOUT)
version = resp.json()
return version['nessus_ui_version']

Expand Down
12 changes: 7 additions & 5 deletions src/vmc/webhook/thehive/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import logging
import requests

from vmc.config.settings import DEFAULT_REQUEST_TIMEOUT


LOGGER = logging.getLogger(__name__)

Expand All @@ -31,35 +33,35 @@ def __init__(self, url, token):

def get_alert(self, alert_id):
return TheHiveClient._log_response_if_error(
requests.get(F"{self._url}/api/alert/{alert_id}", headers=self.headers))
requests.get(F"{self._url}/api/alert/{alert_id}", headers=self.headers, timeout=DEFAULT_REQUEST_TIMEOUT))

def create_case(self, title, description):
return TheHiveClient._log_response_if_error(requests.post(F"{self._url}/api/case", headers=self.headers, data={
'title': title,
'description': description
}))['caseId']
}, timeout=DEFAULT_REQUEST_TIMEOUT))['caseId']

def update_case(self, case_id, description, tags):
return TheHiveClient._log_response_if_error(
requests.patch(F"{self._url}/api/case/{case_id}", headers=self.headers, json={
'description': description,
'tags': list(tags)
}))
}, timeout=DEFAULT_REQUEST_TIMEOUT))

def merge_alert_to_case(self, alert_id, case_id):
return TheHiveClient._log_response_if_error(
requests.post(F"{self._url}/api/alert/merge/_bulk", headers=self.headers, json={
"caseId": str(case_id),
"alertIds": [str(alert_id)]
}))
}, timeout=DEFAULT_REQUEST_TIMEOUT))

def create_task(self, case_id, title, description, group):
return TheHiveClient._log_response_if_error(
requests.post(F"{self._url}/api/case/{case_id}/task", headers=self.headers, data={
'title': title,
'description': description,
'group': group
}))['id']
}, timeout=DEFAULT_REQUEST_TIMEOUT))['id']

@staticmethod
def _log_response_if_error(resp):
Expand Down
11 changes: 6 additions & 5 deletions src/vmc/webhook/thehive/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from vmc.vulnerabilities.tests import create_vulnerability
from vmc.webhook.thehive.tasks import process_task_log
from vmc.vulnerabilities.documents import VulnerabilityDocument
from vmc.config.settings import DEFAULT_REQUEST_TIMEOUT


class TheHiveClientTest(TestCase):
Expand All @@ -47,7 +48,7 @@ def setUp(self) -> None:
@patch('vmc.webhook.thehive.client.requests')
def test_call_get_alert(self, requests):
self.uut.get_alert(1)
requests.get.assert_called_once_with('http://localhost/api/alert/1', headers={"Authorization": "Bearer token"})
requests.get.assert_called_once_with('http://localhost/api/alert/1', headers={"Authorization": "Bearer token"}, timeout=DEFAULT_REQUEST_TIMEOUT)

@patch('vmc.webhook.thehive.client.requests')
def test_call_create_case(self, requests):
Expand All @@ -58,23 +59,23 @@ def test_call_create_case(self, requests):
requests.post.assert_called_once_with('http://localhost/api/case', headers={"Authorization": "Bearer token"}, data={
'title': 'sample title',
'description': 'sample desc'
})
}, timeout=DEFAULT_REQUEST_TIMEOUT)

@patch('vmc.webhook.thehive.client.requests')
def test_call_update_case(self, requests):
self.uut.update_case(15, 'sample desc', ['tags'])
requests.patch.assert_called_once_with('http://localhost/api/case/15', headers={"Authorization": "Bearer token"}, json={
'description': 'sample desc',
'tags': ['tags']
})
}, timeout=DEFAULT_REQUEST_TIMEOUT)

@patch('vmc.webhook.thehive.client.requests')
def test_call_merge_alert_to_case(self, requests):
self.uut.merge_alert_to_case(15, 12)
requests.post.assert_called_once_with('http://localhost/api/alert/merge/_bulk', headers={"Authorization": "Bearer token"}, json={
"caseId": '12',
"alertIds": ['15']
})
}, timeout=DEFAULT_REQUEST_TIMEOUT)

@patch('vmc.webhook.thehive.client.requests')
def test_call_create_task(self, requests):
Expand All @@ -83,7 +84,7 @@ def test_call_create_task(self, requests):
'title': 'sample title',
'description': 'sample desc',
'group': 'group'
})
}, timeout=DEFAULT_REQUEST_TIMEOUT)


class TaskProcessorTests(TestCase):
Expand Down

0 comments on commit 63afadc

Please sign in to comment.