Skip to content

Commit

Permalink
Rotating Salt based Verification Token
Browse files Browse the repository at this point in the history
  • Loading branch information
anshulg954 committed Dec 19, 2024
1 parent e2fe6c1 commit 621a767
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 30 deletions.
40 changes: 38 additions & 2 deletions tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ def is_auth_token_outdated(self, access_token) -> bool | None:
if response.status_code == 200:
is_authenticated = True
elif response.status_code == 403:
# 403 means user is not verified
is_authenticated = None
return is_authenticated

Expand Down Expand Up @@ -507,6 +508,37 @@ def register(
access_token = response.json()["token"] if is_created else None
return is_created, message, access_token

def verify_email(self, token: str, access_token: str) -> tuple[bool, str]:
"""
Verify the email with the provided token.
Parameters
----------
token : str
access_token : str
Returns
-------
is_verified : bool
True if the email is verified successfully.
message : str
The message returned from the server.
"""

response = self.httpx_client.get(
self.server_endpoints.verify_email.path,
params={"token": token, "access_token": access_token},
)
self._validate_response(response, "verify_email", only_version_check=True)
if response.status_code == 200:
is_verified = True
message = response.json()["message"]
else:
is_verified = False
message = response.json()["detail"]

return is_verified, message

def login(self, email: str, password: str) -> tuple[str, str]:
"""
Login with the provided credentials and return the access token if successful.
Expand Down Expand Up @@ -534,10 +566,14 @@ def login(self, email: str, password: str) -> tuple[str, str]:
if response.status_code == 200:
access_token = response.json()["access_token"]
message = ""
elif response.status_code == 403:
access_token = response.headers["access_token"]
message = response.json()["detail"]
else:
message = response.json()["detail"]

return access_token, message
# status code signifies the success of the login, issues with password, and email verification
# 200 : success, 401 : wrong password, 403 : email not verified yet
return access_token, message, response.status_code

def get_password_policy(self) -> {}:
"""
Expand Down
11 changes: 5 additions & 6 deletions tabpfn_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,14 @@ def init(use_server=True):
"TabPFN is inaccessible at the moment, please try again later."
)

is_valid_token_set = user_auth_handler.try_reuse_existing_token()
is_valid_token, access_token = user_auth_handler.try_reuse_existing_token()

if isinstance(is_valid_token_set, bool) and is_valid_token_set:
if is_valid_token:
PromptAgent.prompt_reusing_existing_token()
elif (
isinstance(is_valid_token_set, tuple) and is_valid_token_set[1] is not None
):
elif access_token is not None:
# token holds invalid due to user email verification
print("Your email is not verified. Please verify your email to continue...")
PromptAgent.reverify_email(is_valid_token_set[1], user_auth_handler)
PromptAgent.reverify_email(access_token, user_auth_handler)
else:
PromptAgent.prompt_welcome()
if not PromptAgent.prompt_terms_and_cond():
Expand Down
45 changes: 38 additions & 7 deletions tabpfn_client/prompt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,22 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"):
)
)
additional_info = cls.prompt_add_user_information()
is_created, message = user_auth_handler.set_token_by_registration(
email, password, password_confirm, validation_link, additional_info
is_created, message, access_token = (
user_auth_handler.set_token_by_registration(
email, password, password_confirm, validation_link, additional_info
)
)
if not is_created:
raise RuntimeError("User registration failed: " + str(message) + "\n")

print(
cls.indent(
"Account created successfully! To start using TabPFN please click on the link in the verification email we sent you."
"Account created successfully! To start using TabPFN please enter the verification code we sent you by mail."
)
+ "\n"
)
# verify token from email
cls._verify_user_email(user_auth_handler, access_token=access_token)

# Login
elif choice == "2":
Expand All @@ -120,12 +124,17 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"):
email = input(cls.indent("Please enter your email: "))
password = getpass.getpass(cls.indent("Please enter your password: "))

successful, message = user_auth_handler.set_token_by_login(
email, password
access_token, message, status_code = (
user_auth_handler.set_token_by_login(email, password)
)
if successful:
if status_code == 200 and access_token is not None:
break
print(cls.indent("Login failed: " + str(message)) + "\n")
if status_code == 403:
# 403 implies that the email is not verified
cls._verify_user_email(user_auth_handler, access_token=access_token)
user_auth_handler.set_token_by_login(email, password)
break
print(cls.indent("Login failed: " + message) + "\n")

prompt = "\n".join(
[
Expand Down Expand Up @@ -239,6 +248,9 @@ def reverify_email(
)
+ "\n"
)
# verify token from email
cls._verify_user_email(user_auth_handler, access_token=access_token)
user_auth_handler.set_token(access_token)
return

@classmethod
Expand Down Expand Up @@ -282,3 +294,22 @@ def _choice_with_retries(cls, prompt: str, choices: list) -> str:
break

return choice.lower()

@classmethod
def _verify_user_email(
cls, user_auth_handler: "UserAuthenticationClient", access_token: str
):
verified = False
while not verified:
token = input(
cls.indent(
"Please enter the correct verification code sent to your email: "
)
)
verified, message = user_auth_handler.verify_email(token, access_token)
if not verified:
print("\n" + cls.indent(str(message) + "Please try again!") + "\n")
else:
print(cls.indent("Email verified successfully!") + "\n")
break
return
5 changes: 5 additions & 0 deletions tabpfn_client/server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ endpoints:
methods: [ "POST" ]
description: "Send verifiaction email or for reverification"

verify_email:
path: "/auth/verify_email/"
methods: [ "GET" ]
description: "Verify email"

send_reset_password_email:
path: "/auth/send_reset_password_email/"
methods: [ "POST" ]
Expand Down
25 changes: 15 additions & 10 deletions tabpfn_client/service_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,22 @@ def set_token_by_registration(
)
if access_token is not None:
self.set_token(access_token)
return is_created, message
return is_created, message, access_token

def set_token_by_login(self, email: str, password: str) -> tuple[bool, str]:
access_token, message = self.service_client.login(email, password)
def set_token_by_login(self, email: str, password: str) -> tuple[bool, str, int]:
access_token, message, status_code = self.service_client.login(email, password)

if access_token is None:
return False, message
return False, message, status_code
elif status_code == 200:
self.set_token(access_token)

self.set_token(access_token)
return True, message
return access_token, message, status_code

def try_reuse_existing_token(self) -> bool | tuple[bool, str]:
def try_reuse_existing_token(self) -> tuple[bool, str or None]:
if self.service_client.access_token is None:
if not self.CACHED_TOKEN_FILE.exists():
return False
return False, None

access_token = self.CACHED_TOKEN_FILE.read_text()

Expand All @@ -73,14 +74,14 @@ def try_reuse_existing_token(self) -> bool | tuple[bool, str]:
is_valid = self.service_client.is_auth_token_outdated(access_token)
if is_valid is False:
self._reset_token()
return False
return False, None
elif is_valid is None:
return False, access_token

logger.debug(f"Reusing existing access token? {is_valid}")
self.set_token(access_token)

return True
return True, access_token

def get_password_policy(self):
return self.service_client.get_password_policy()
Expand All @@ -103,6 +104,10 @@ def send_verification_email(self, access_token: str) -> tuple[bool, str]:
sent, message = self.service_client.send_verification_email(access_token)
return sent, message

def verify_email(self, token: str, access_token: str) -> tuple[bool, str]:
verified, message = self.service_client.verify_email(token, access_token)
return verified, message


class UserDataClient(ServiceClientWrapper):
"""
Expand Down
10 changes: 8 additions & 2 deletions tabpfn_client/tests/unit/test_prompt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_password_req_to_policy(self):
@patch("getpass.getpass", side_effect=["Password123!", "Password123!"])
@patch(
"builtins.input",
side_effect=["1", "[email protected]", "test", "test", "test", "y"],
side_effect=["1", "[email protected]", "test", "test", "test", "y", "test"],
)
def test_prompt_and_set_token_registration(
self, mock_input, mock_getpass, mock_server
Expand All @@ -30,16 +30,22 @@ def test_prompt_and_set_token_registration(
mock_auth_client.set_token_by_registration.return_value = (
True,
"Registration successful",
"dummy_token",
)
mock_auth_client.validate_email.return_value = (True, "")
mock_auth_client.verify_email.return_value = (True, "Verification successful")
PromptAgent.prompt_and_set_token(user_auth_handler=mock_auth_client)
mock_auth_client.set_token_by_registration.assert_called_once()

@patch("getpass.getpass", side_effect=["password123"])
@patch("builtins.input", side_effect=["2", "[email protected]"])
def test_prompt_and_set_token_login(self, mock_input, mock_getpass):
mock_auth_client = MagicMock()
mock_auth_client.set_token_by_login.return_value = (True, "Login successful")
mock_auth_client.set_token_by_login.return_value = (
True,
"Login successful",
200,
)
PromptAgent.prompt_and_set_token(user_auth_handler=mock_auth_client)
mock_auth_client.set_token_by_login.assert_called_once()

Expand Down
6 changes: 3 additions & 3 deletions tabpfn_client/tests/unit/test_service_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_set_token_by_invalid_login(self, mock_server):
401, json={"detail": "Incorrect email or password"}
)
self.assertEqual(
(False, "Incorrect email or password"),
(False, "Incorrect email or password", 401),
UserAuthenticationClient(ServiceClient()).set_token_by_login(
"dummy_email", "dummy_password"
),
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_set_token_by_invalid_registration(self, mock_server):
401, json={"detail": "Password mismatch"}
)
self.assertEqual(
(False, "Password mismatch"),
(False, "Password mismatch", None),
UserAuthenticationClient(ServiceClient()).set_token_by_registration(
"dummy_email",
"dummy_password",
Expand Down Expand Up @@ -116,7 +116,7 @@ def test_reset_cache_after_token_set(self, mock_server):
# mock authentication
mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200)
self.assertTrue(
UserAuthenticationClient(ServiceClient()).try_reuse_existing_token()
UserAuthenticationClient(ServiceClient()).try_reuse_existing_token()[0]
)

# assert token is set
Expand Down

0 comments on commit 621a767

Please sign in to comment.