diff --git a/tests/conftest.py b/tests/conftest.py index 691ecf8..1819834 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +import os + import pytest import webtest from pyramid.scripting import prepare @@ -8,7 +10,12 @@ @pytest.fixture(scope="session") def settings(): """Fixture that returns the Pyramid settings dict.""" - return sambal.SETTINGS + test_settings = dict(sambal.SETTINGS) + test_settings["samba.host"] = os.getenv("SAMBAL_SAMBA_HOST") + test_settings["samba.username"] = os.getenv("SAMBAL_SAMBA_USERNAME") + test_settings["samba.password"] = os.getenv("SAMBAL_SAMBA_PASSWORD") + test_settings["samba.realm"] = os.getenv("SAMBAL_SAMBA_REALM") + return test_settings @pytest.fixture(scope="session") diff --git a/tests/test_login.py b/tests/test_login.py index 4e4e159..9e8a1ca 100644 --- a/tests/test_login.py +++ b/tests/test_login.py @@ -1,3 +1,46 @@ +from html.parser import HTMLParser + + +class LoginHTMLParser(HTMLParser): + """Simple HTML parser to extract csrf token using the standard library.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.csrf_token = None + self.return_url = None + + def handle_starttag(self, tag, attrs): + if tag == "input": + tag_attrs = dict(attrs) + + if tag_attrs["name"] == "csrf_token": + self.csrf_token = tag_attrs["value"] + + if tag_attrs["name"] == "return_url": + self.return_url = tag_attrs["value"] + + +def test_login(testapp, settings): + response = testapp.get("/login/", status=200) + parser = LoginHTMLParser() + parser.feed(response.text) + + login_form = { + "host": settings["samba.host"], + "username": settings["samba.username"], + "password": settings["samba.password"], + "realm": settings["samba.realm"], + "csrf_token": parser.csrf_token, + "return_url": parser.return_url, + } + + response = testapp.post("/login/", login_form, status=302) + assert response.headers["location"] == parser.return_url + + response = testapp.get("/", status=200) + assert "Sambal Login" not in response.text + + def test_login_required(testapp): response = testapp.get("/", status=200) - assert b"Sambal Login" in response.body + assert "Sambal Login" in response.text