Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

login: bubble up LdbError so that it can be handled better #26

Merged
merged 4 commits into from
Mar 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@ build :
lint :
@ruff check

.PHONY : lint-fix
lint-fix :
@ruff check --fix

.PHONY : format
format :
@ruff format

.PHONY : test
test :
@pytest
74 changes: 42 additions & 32 deletions src/sambal/client.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,62 @@
from typing import Optional

from ldb import LdbError
from samba.auth import system_session
from samba.credentials import Credentials
from samba.param import LoadParm
from samba.samdb import SamDB


def connect_samdb(username, password, host, realm=None) -> Optional[SamDB]:
"""Connect to Samba or Windows host and return SamDB on success."""
if host and username and password:
if host.startswith(("ldap://", "ldaps://")):
url = host
else:
url = f"ldap://{host}"
def connect_samdb(host, username, password, realm=None) -> SamDB:
"""Connect to Samba or Windows host and return SamDB on success.

lp = LoadParm()
lp.load_default()
:param host: Host name or URL
:param username: Account name
:param password: Account password
:param realm: Optional realm
:raises LdbError: on failure, caller should handle error.
"""
if host.startswith(("ldap://", "ldaps://")):
url = host
else:
url = f"ldap://{host}"

lp = LoadParm()
lp.load_default()

creds = Credentials()
creds.set_username(username)
creds.set_password(password)
creds = Credentials()
creds.set_username(username)
creds.set_password(password)

if realm:
creds.set_realm(realm)
if realm:
creds.set_realm(realm)

try:
return SamDB(
url=url,
session_info=system_session(),
credentials=creds,
lp=lp,
)
except LdbError:
return None
return SamDB(
url=url,
session_info=system_session(),
credentials=creds,
lp=lp,
)


def get_samdb(request) -> Optional[SamDB]:
"""Returns a SamDB connection to be used via the request.samdb property.

Fetch credentials out of the session after user has logged in.
All keys except for realm must be present.

For this to be secure the session MUST be a backend session only,
with a password on Redis and a unique session secret different to the
authtkt cookie secret.

:param request: Pyramid request object
:return: SamDB or None if no credentials in session
:raises LdbError: On connection error or if the credentials no longer work
"""
# Fetch credentials out of the session after user logs in.
# For this to be secure the session MUST be a backend session only.
username = request.session.get("samba.username")
password = request.session.get("samba.password")
host = request.session.get("samba.host")
realm = request.session.get("samba.realm")

return connect_samdb(username, password, host, realm)
try:
host = request.session["samba.host"]
username = request.session["samba.username"]
password = request.session["samba.password"]
realm = request.session.get("samba.realm")
return connect_samdb(host, username, password, realm)
except KeyError:
return None
4 changes: 1 addition & 3 deletions src/sambal/forms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from .login import LoginForm

__all__ = (
"LoginForm",
)
__all__ = ("LoginForm",)
12 changes: 8 additions & 4 deletions src/sambal/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,14 @@ def forget(self, request, **kwargs):
return self.authtkt.forget(request, **kwargs)


def login(request, username, password, host, realm):
"""Log into server and put credentials in session on success only."""
samdb = connect_samdb(username, password, host, realm)
if samdb and (user_sid := samdb.connecting_user_sid):
def login(request, host, username, password, realm):
"""Log into server and put credentials in session on success only.

:raises LdbError: If the login failed or host is incorrect
"""
samdb = connect_samdb(host, username, password, realm)

if user_sid := samdb.connecting_user_sid:
request.session["samba.username"] = username
request.session["samba.password"] = password
request.session["samba.host"] = host
Expand Down
4 changes: 1 addition & 3 deletions src/sambal/tweens/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from .headers import SecurityHeaders

__all__ = (
"SecurityHeaders",
)
__all__ = ("SecurityHeaders",)
14 changes: 10 additions & 4 deletions src/sambal/views/auth.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ldb import LdbError
from pyramid.httpexceptions import HTTPFound
from pyramid.view import forbidden_view_config, view_config

Expand All @@ -21,10 +22,15 @@ def login(request):
host = form.host.data
realm = form.realm.data

if headers := request.login(username, password, host, realm):
return HTTPFound(location=return_url, headers=headers)
else:
request.session.flash("Login to host failed", queue="error")
try:
if headers := request.login(host, username, password, realm):
return HTTPFound(location=return_url, headers=headers)
else:
request.session.flash("Login failed", queue="error")

except LdbError as e:
msg = e.args[1]
request.session.flash(f"Login failed: {msg}", queue="error")
else:
form = LoginForm()

Expand Down
2 changes: 1 addition & 1 deletion tests/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_login_invalid_credentials(testapp, settings):
}

response = testapp.post("/login/", login_form, status=200)
assert "Login to host failed" in response.text
assert "Login failed" in response.text


def test_login_required(testapp):
Expand Down
Loading