Skip to content

Commit

Permalink
Store additional user info (#60)
Browse files Browse the repository at this point in the history
* Ask user to provide name and to agree to PII

* Fix test

* Move terms and cond to registration and send agreements to server

* Add new line print before toc

* Update terms link
  • Loading branch information
davidotte authored Jan 5, 2025
1 parent 30901db commit f407788
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 34 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "tabpfn-client"
version = "0.0.25"
version = "0.0.26"
requires-python = ">=3.10"
dynamic = ["dependencies", "optional-dependencies"]

Expand Down
4 changes: 0 additions & 4 deletions tabpfn_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@ def init(use_server=True):
PromptAgent.reverify_email(access_token)
else:
PromptAgent.prompt_welcome()
if not PromptAgent.prompt_terms_and_cond():
raise RuntimeError(
"You must agree to the terms and conditions to use TabPFN"
)

# prompt for login / register
PromptAgent.prompt_and_set_token()
Expand Down
56 changes: 53 additions & 3 deletions tabpfn_client/prompt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def prompt_welcome(cls):

@classmethod
def prompt_and_set_token(cls):
print("inside prompt_and_set_token")
# Choose between registration and login
prompt = "\n".join(
[
Expand All @@ -64,6 +63,13 @@ def prompt_and_set_token(cls):
# Registration
if choice == "1":
validation_link = "tabpfn-2023"

agreed_terms_and_cond = cls.prompt_terms_and_cond()
if not agreed_terms_and_cond:
raise RuntimeError(
"You must agree to the terms and conditions to use TabPFN"
)

while True:
email = input(cls.indent("Please enter your email: "))
# Send request to server to check if email is valid and not already taken.
Expand Down Expand Up @@ -102,7 +108,19 @@ def prompt_and_set_token(cls):
"Entered password and confirmation password do not match, please try again.\n"
)
)
agreed_personally_identifiable_information = (
cls.prompt_personally_identifiable_information()
)
if not agreed_personally_identifiable_information:
raise RuntimeError(
"You must agree to not upload personally identifiable information."
)

additional_info = cls.prompt_add_user_information()
additional_info["agreed_terms_and_cond"] = agreed_terms_and_cond
additional_info["agreed_personally_identifiable_information"] = (
agreed_personally_identifiable_information
)
is_created, message, access_token = (
UserAuthenticationClient.set_token_by_registration(
email, password, password_confirm, validation_link, additional_info
Expand Down Expand Up @@ -180,23 +198,53 @@ def prompt_and_set_token(cls):
def prompt_terms_and_cond(cls) -> bool:
t_and_c = "\n".join(
[
"Please refer to our terms and conditions at: https://www.priorlabs.ai/terms-eu-en "
"\nPlease refer to our terms and conditions at: https://www.priorlabs.ai/terms "
"By using TabPFN, you agree to the following terms and conditions:",
"Do you agree to the above terms and conditions? (y/n): ",
]
)
choice = cls._choice_with_retries(t_and_c, ["y", "n"])
return choice == "y"

@classmethod
def prompt_personally_identifiable_information(cls) -> bool:
pii = "\n".join(
[
"Do you agree to not upload personally identifiable information? (y/n): ",
]
)
choice = cls._choice_with_retries(pii, ["y", "n"])
return choice == "y"

@classmethod
def prompt_add_user_information(cls) -> dict:
print(cls.indent("\nPlease provide your name:"))

# Required fields
while True:
first_name = input(cls.indent("First Name: ")).strip()
if not first_name:
print(
cls.indent("First name is required. Please enter your first name.")
)
continue
break

while True:
last_name = input(cls.indent("Last Name: ")).strip()
if not last_name:
print(cls.indent("Last name is required. Please enter your last name."))
continue
break

print(
cls.indent(
"To help us tailor our support and services to your needs, we have a few optional questions. "
"\nTo help us tailor our support and services to your needs, we have a few optional questions. "
"Feel free to skip any question by leaving it blank."
)
+ "\n"
)

company = input(cls.indent("Where do you work? "))
role = input(cls.indent("What is your role? "))
use_case = input(cls.indent("What do you want to use TabPFN for? "))
Expand All @@ -207,6 +255,8 @@ def prompt_add_user_information(cls) -> dict:
contact_via_email = True if choice_contact == "y" else False

return {
"first_name": first_name,
"last_name": last_name,
"company": company,
"role": role,
"use_case": use_case,
Expand Down
18 changes: 17 additions & 1 deletion tabpfn_client/tests/unit/test_prompt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,19 @@ def test_password_req_to_policy(self):
@patch("getpass.getpass", side_effect=["Password123!", "Password123!"])
@patch(
"builtins.input",
side_effect=["1", "[email protected]", "test", "test", "test", "y", "test"],
side_effect=[
"1",
"y",
"[email protected]",
"y",
"first",
"last",
"test",
"test",
"test",
"y",
"test",
],
)
def test_prompt_and_set_token_registration(
self, mock_input, mock_getpass, mock_server
Expand Down Expand Up @@ -50,10 +62,14 @@ def test_prompt_and_set_token_registration(
"Password123!",
"tabpfn-2023",
{
"first_name": "first",
"last_name": "last",
"company": "test",
"role": "test",
"use_case": "test",
"contact_via_email": True,
"agreed_terms_and_cond": True,
"agreed_personally_identifiable_information": True,
},
)

Expand Down
19 changes: 10 additions & 9 deletions tabpfn_client/tests/unit/test_tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,7 @@ def tearDown(self):

@with_mock_server()
@patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token")
@patch(
"tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond",
return_value=True,
)
def test_init_remote_classifier(
self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token
):
def test_init_remote_classifier(self, mock_server, mock_prompt_and_set_token):
mock_prompt_and_set_token.side_effect = (
lambda: UserAuthenticationClient.set_token(self.dummy_token)
)
Expand All @@ -67,7 +61,6 @@ def test_init_remote_classifier(

init(use_server=True)
self.assertTrue(mock_prompt_and_set_token.called)
self.assertTrue(mock_prompt_for_terms_and_cond.called)

tabpfn = TabPFNClassifier(n_estimators=10)
self.assertRaises(NotFittedError, tabpfn.predict, self.X_test)
Expand Down Expand Up @@ -157,7 +150,15 @@ def test_reset_on_remote_classifier(self, mock_server):
"tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond",
return_value=False,
)
def test_decline_terms_and_cond(self, mock_server, mock_prompt_for_terms_and_cond):
@patch(
"builtins.input",
side_effect=[
"1",
],
)
def test_decline_terms_and_cond(
self, mock_server, mock_input, mock_prompt_for_terms_and_cond
):
# mock connection
mock_server.router.get(mock_server.endpoints.root.path).respond(200)

Expand Down
27 changes: 11 additions & 16 deletions tabpfn_client/tests/unit/test_tabpfn_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,7 @@ def tearDown(self):

@with_mock_server()
@patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token")
@patch(
"tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond",
return_value=True,
)
def test_init_remote_regressor(
self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token
):
def test_init_remote_regressor(self, mock_server, mock_prompt_and_set_token):
mock_prompt_and_set_token.side_effect = (
lambda: UserAuthenticationClient.set_token(self.dummy_token)
)
Expand Down Expand Up @@ -68,7 +62,6 @@ def test_init_remote_regressor(

init(use_server=True)
self.assertTrue(mock_prompt_and_set_token.called)
self.assertTrue(mock_prompt_for_terms_and_cond.called)

tabpfn = TabPFNRegressor(n_estimators=10)
self.assertRaises(NotFittedError, tabpfn.predict, self.X_test)
Expand Down Expand Up @@ -118,13 +111,7 @@ def test_reuse_saved_access_token(self, mock_server):

@with_mock_server()
@patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token")
@patch(
"tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond",
return_value=True,
)
def test_invalid_saved_access_token(
self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token
):
def test_invalid_saved_access_token(self, mock_server, mock_prompt_and_set_token):
mock_prompt_and_set_token.side_effect = [RuntimeError]

# mock connection and invalid authentication
Expand Down Expand Up @@ -171,7 +158,15 @@ def test_reset_on_remote_regressor(self, mock_server):
"tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond",
return_value=False,
)
def test_decline_terms_and_cond(self, mock_server, mock_prompt_for_terms_and_cond):
@patch(
"builtins.input",
side_effect=[
"1",
],
)
def test_decline_terms_and_cond(
self, mock_server, mock_input, mock_prompt_for_terms_and_cond
):
# mock connection
mock_server.router.get(mock_server.endpoints.root.path).respond(200)

Expand Down

0 comments on commit f407788

Please sign in to comment.