diff --git a/src/sambal/security.py b/src/sambal/security.py index 2cdb474..7949d10 100644 --- a/src/sambal/security.py +++ b/src/sambal/security.py @@ -1,6 +1,5 @@ from typing import Optional -from ldb import LdbError from pyramid.authentication import AuthTktCookieHelper from pyramid.interfaces import ISecurityPolicy from pyramid.security import forget, remember @@ -40,18 +39,19 @@ def forget(self, request, **kwargs): return self.authtkt.forget(request, **kwargs) -def login(request, username, password, host, realm): - """Log into server and put credentials in session on success only.""" - try: - samdb = connect_samdb(host, username, password, realm) - if user_sid := samdb.connecting_user_sid: - request.session["samba.username"] = username - request.session["samba.password"] = password - request.session["samba.host"] = host - request.session["samba.realm"] = realm - return remember(request, user_sid) - except LdbError: - pass +def login(request, host, username, password, realm): + """Log into server and put credentials in session on success only. + + :raises LdbError: If the login failed or host is incorrect + """ + samdb = connect_samdb(host, username, password, realm) + + if user_sid := samdb.connecting_user_sid: + request.session["samba.username"] = username + request.session["samba.password"] = password + request.session["samba.host"] = host + request.session["samba.realm"] = realm + return remember(request, user_sid) def logout(request): diff --git a/src/sambal/views/auth.py b/src/sambal/views/auth.py index ec930e4..84080e0 100644 --- a/src/sambal/views/auth.py +++ b/src/sambal/views/auth.py @@ -1,3 +1,4 @@ +from ldb import LdbError from pyramid.httpexceptions import HTTPFound from pyramid.view import forbidden_view_config, view_config @@ -21,10 +22,15 @@ def login(request): host = form.host.data realm = form.realm.data - if headers := request.login(username, password, host, realm): - return HTTPFound(location=return_url, headers=headers) - else: - request.session.flash("Login to host failed", queue="error") + try: + if headers := request.login(host, username, password, realm): + return HTTPFound(location=return_url, headers=headers) + else: + request.session.flash("Login failed", queue="error") + + except LdbError as e: + msg = e.args[1] + request.session.flash(f"Login failed: {msg}", queue="error") else: form = LoginForm() diff --git a/tests/test_login.py b/tests/test_login.py index fa4ed43..aad391d 100644 --- a/tests/test_login.py +++ b/tests/test_login.py @@ -61,7 +61,7 @@ def test_login_invalid_credentials(testapp, settings): } response = testapp.post("/login/", login_form, status=200) - assert "Login to host failed" in response.text + assert "Login failed" in response.text def test_login_required(testapp):