Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store additional user info #60

Merged
merged 5 commits into from
Jan 5, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "tabpfn-client"
version = "0.0.25"
version = "0.0.26"
requires-python = ">=3.10"
dynamic = ["dependencies", "optional-dependencies"]

Expand Down
4 changes: 0 additions & 4 deletions tabpfn_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@ def init(use_server=True):
PromptAgent.reverify_email(access_token)
else:
PromptAgent.prompt_welcome()
if not PromptAgent.prompt_terms_and_cond():
raise RuntimeError(
"You must agree to the terms and conditions to use TabPFN"
)

# prompt for login / register
PromptAgent.prompt_and_set_token()
Expand Down
56 changes: 53 additions & 3 deletions tabpfn_client/prompt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def prompt_welcome(cls):

@classmethod
def prompt_and_set_token(cls):
print("inside prompt_and_set_token")
# Choose between registration and login
prompt = "\n".join(
[
Expand All @@ -64,6 +63,13 @@ def prompt_and_set_token(cls):
# Registration
if choice == "1":
validation_link = "tabpfn-2023"

agreed_terms_and_cond = cls.prompt_terms_and_cond()
if not agreed_terms_and_cond:
raise RuntimeError(
"You must agree to the terms and conditions to use TabPFN"
)

while True:
email = input(cls.indent("Please enter your email: "))
# Send request to server to check if email is valid and not already taken.
Expand Down Expand Up @@ -102,7 +108,19 @@ def prompt_and_set_token(cls):
"Entered password and confirmation password do not match, please try again.\n"
)
)
agreed_personally_identifiable_information = (
cls.prompt_personally_identifiable_information()
)
if not agreed_personally_identifiable_information:
raise RuntimeError(
"You must agree to not upload personally identifiable information."
)

additional_info = cls.prompt_add_user_information()
additional_info["agreed_terms_and_cond"] = agreed_terms_and_cond
additional_info["agreed_personally_identifiable_information"] = (
agreed_personally_identifiable_information
)
is_created, message, access_token = (
UserAuthenticationClient.set_token_by_registration(
email, password, password_confirm, validation_link, additional_info
Expand Down Expand Up @@ -180,23 +198,53 @@ def prompt_and_set_token(cls):
def prompt_terms_and_cond(cls) -> bool:
t_and_c = "\n".join(
[
"Please refer to our terms and conditions at: https://www.priorlabs.ai/terms-eu-en "
"\nPlease refer to our terms and conditions at: https://www.priorlabs.ai/terms-eu-en "
davidotte marked this conversation as resolved.
Show resolved Hide resolved
"By using TabPFN, you agree to the following terms and conditions:",
"Do you agree to the above terms and conditions? (y/n): ",
]
)
choice = cls._choice_with_retries(t_and_c, ["y", "n"])
return choice == "y"

@classmethod
def prompt_personally_identifiable_information(cls) -> bool:
pii = "\n".join(
[
"Do you agree to not upload personally identifiable information? (y/n): ",
]
)
choice = cls._choice_with_retries(pii, ["y", "n"])
return choice == "y"

@classmethod
def prompt_add_user_information(cls) -> dict:
print(cls.indent("\nPlease provide your name:"))

# Required fields
while True:
first_name = input(cls.indent("First Name: ")).strip()
if not first_name:
print(
cls.indent("First name is required. Please enter your first name.")
)
continue
break

while True:
last_name = input(cls.indent("Last Name: ")).strip()
if not last_name:
print(cls.indent("Last name is required. Please enter your last name."))
continue
break

print(
cls.indent(
"To help us tailor our support and services to your needs, we have a few optional questions. "
"\nTo help us tailor our support and services to your needs, we have a few optional questions. "
"Feel free to skip any question by leaving it blank."
)
+ "\n"
)

company = input(cls.indent("Where do you work? "))
role = input(cls.indent("What is your role? "))
use_case = input(cls.indent("What do you want to use TabPFN for? "))
Expand All @@ -207,6 +255,8 @@ def prompt_add_user_information(cls) -> dict:
contact_via_email = True if choice_contact == "y" else False

return {
"first_name": first_name,
"last_name": last_name,
"company": company,
"role": role,
"use_case": use_case,
Expand Down
18 changes: 17 additions & 1 deletion tabpfn_client/tests/unit/test_prompt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,19 @@ def test_password_req_to_policy(self):
@patch("getpass.getpass", side_effect=["Password123!", "Password123!"])
@patch(
"builtins.input",
side_effect=["1", "[email protected]", "test", "test", "test", "y", "test"],
side_effect=[
"1",
"y",
"[email protected]",
"y",
"first",
"last",
"test",
"test",
"test",
"y",
"test",
],
)
def test_prompt_and_set_token_registration(
self, mock_input, mock_getpass, mock_server
Expand Down Expand Up @@ -50,10 +62,14 @@ def test_prompt_and_set_token_registration(
"Password123!",
"tabpfn-2023",
{
"first_name": "first",
"last_name": "last",
"company": "test",
"role": "test",
"use_case": "test",
"contact_via_email": True,
"agreed_terms_and_cond": True,
"agreed_personally_identifiable_information": True,
},
)

Expand Down
19 changes: 10 additions & 9 deletions tabpfn_client/tests/unit/test_tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,7 @@ def tearDown(self):

@with_mock_server()
@patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token")
@patch(
"tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond",
return_value=True,
)
def test_init_remote_classifier(
self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token
):
def test_init_remote_classifier(self, mock_server, mock_prompt_and_set_token):
mock_prompt_and_set_token.side_effect = (
lambda: UserAuthenticationClient.set_token(self.dummy_token)
)
Expand All @@ -67,7 +61,6 @@ def test_init_remote_classifier(

init(use_server=True)
self.assertTrue(mock_prompt_and_set_token.called)
self.assertTrue(mock_prompt_for_terms_and_cond.called)

tabpfn = TabPFNClassifier(n_estimators=10)
self.assertRaises(NotFittedError, tabpfn.predict, self.X_test)
Expand Down Expand Up @@ -157,7 +150,15 @@ def test_reset_on_remote_classifier(self, mock_server):
"tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond",
return_value=False,
)
def test_decline_terms_and_cond(self, mock_server, mock_prompt_for_terms_and_cond):
@patch(
"builtins.input",
side_effect=[
"1",
],
)
def test_decline_terms_and_cond(
self, mock_server, mock_input, mock_prompt_for_terms_and_cond
):
# mock connection
mock_server.router.get(mock_server.endpoints.root.path).respond(200)

Expand Down
27 changes: 11 additions & 16 deletions tabpfn_client/tests/unit/test_tabpfn_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,7 @@ def tearDown(self):

@with_mock_server()
@patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token")
@patch(
"tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond",
return_value=True,
)
def test_init_remote_regressor(
self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token
):
def test_init_remote_regressor(self, mock_server, mock_prompt_and_set_token):
mock_prompt_and_set_token.side_effect = (
lambda: UserAuthenticationClient.set_token(self.dummy_token)
)
Expand Down Expand Up @@ -68,7 +62,6 @@ def test_init_remote_regressor(

init(use_server=True)
self.assertTrue(mock_prompt_and_set_token.called)
self.assertTrue(mock_prompt_for_terms_and_cond.called)

tabpfn = TabPFNRegressor(n_estimators=10)
self.assertRaises(NotFittedError, tabpfn.predict, self.X_test)
Expand Down Expand Up @@ -118,13 +111,7 @@ def test_reuse_saved_access_token(self, mock_server):

@with_mock_server()
@patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token")
@patch(
"tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond",
return_value=True,
)
def test_invalid_saved_access_token(
self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token
):
def test_invalid_saved_access_token(self, mock_server, mock_prompt_and_set_token):
mock_prompt_and_set_token.side_effect = [RuntimeError]

# mock connection and invalid authentication
Expand Down Expand Up @@ -171,7 +158,15 @@ def test_reset_on_remote_regressor(self, mock_server):
"tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond",
return_value=False,
)
def test_decline_terms_and_cond(self, mock_server, mock_prompt_for_terms_and_cond):
@patch(
"builtins.input",
side_effect=[
"1",
],
)
def test_decline_terms_and_cond(
self, mock_server, mock_input, mock_prompt_for_terms_and_cond
):
# mock connection
mock_server.router.get(mock_server.endpoints.root.path).respond(200)

Expand Down
Loading