diff --git a/tests/tests.py b/tests/tests.py index c03900f..7a5d29b 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -6,6 +6,16 @@ import requests from requests.auth import HTTPBasicAuth +REQUEST_TIMEOUT = 5 + + +def get(uri, auth=None, headers=None, timeout=REQUEST_TIMEOUT): + return requests.get(uri, auth=auth, headers=headers, timeout=timeout) + + +def post(uri, data=None): + return requests.post(uri, data=data, timeout=REQUEST_TIMEOUT) + def get_uri(suffix, v=None, port=8080): host = "localhost" @@ -17,14 +27,14 @@ def get_uri(suffix, v=None, port=8080): def get_status(): - resp = requests.get(get_uri("api/status")) + resp = get(get_uri("api/status")) assert resp.status_code == 200 return resp.json() def is_responding(uri): try: - requests.get(uri) + get(uri) except requests.exceptions.ConnectionError: return False return True @@ -43,7 +53,7 @@ def send(command, arg=None, arg2=None, expect=200, status=None): for a in [arg, arg2]: if a is not None: api += f"/{a}" - resp = requests.post(get_uri(api)) + resp = post(get_uri(api)) assert resp.status_code == expect if status is not None: return get_status()[status] @@ -97,7 +107,7 @@ class TestsRequests: ], ) def test_static(mpv_instance, uri, status_code, content_type): - resp = requests.get(get_uri(uri)) + resp = get(get_uri(uri)) assert resp.status_code == status_code if status_code != 200: return @@ -113,7 +123,7 @@ def test_static(mpv_instance, uri, status_code, content_type): @staticmethod @pytest.mark.parametrize("uri", ["", "/", "//", "///"]) def test_index(mpv_instance, uri): - resp = requests.get(get_uri(uri)) + resp = get(get_uri(uri)) assert resp.status_code == 200 resp.headers.pop("Content-Length") @@ -206,7 +216,7 @@ def test_post_wrong_args(mpv_instance, snapshot, endpoint, arg, arg2): for a in [arg, arg2]: if a is not None: api += f"/{a}" - response = requests.post(get_uri(api)) + response = post(get_uri(api)) assert response.status_code == 400 snapshot.assert_match(response.json()) @@ -408,7 +418,7 @@ def test_not_allowed_methods(mpv_instance, endpoint, method, expected): ], ) def test_collections(self, endpoint, expected_status, mpv_instance, snapshot): - resp = requests.get(f"{get_uri(endpoint)}") + resp = get(get_uri(endpoint)) assert resp.status_code == expected_status @@ -493,7 +503,7 @@ def send_loadfile(url, mode=None, expect=200): indirect=["mpv_instance"], ) def test_static_dir_config(mpv_instance, status_code): - resp = requests.get(get_uri("static.json")) + resp = get(get_uri("static.json")) assert resp.status_code == status_code if status_code == 200: @@ -588,7 +598,7 @@ def test_disablers(mpv_instance, v4_works, v6_works): ) def test_auth(htpasswd, mpv_instance, auth, status_code): try: - resp = requests.get(get_uri("api/status"), auth=auth, timeout=0.5) + resp = get(get_uri("api/status"), auth=auth, timeout=0.5) except requests.exceptions.ReadTimeout: assert status_code is None return @@ -648,9 +658,7 @@ def test_logging(htpasswd, mpv_instance, use_auth, username, password, status_co auth = None if use_auth and username: auth = HTTPBasicAuth(username, password) - resp = requests.get( - get_uri("api/status"), auth=auth, headers={"Referer": "https://referer"} - ) + resp = get(get_uri("api/status"), auth=auth, headers={"Referer": "https://referer"}) assert resp.status_code == status_code # example log line