diff --git a/pyproject.toml b/pyproject.toml index 949958f..ed73d78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "tabpfn-client" -version = "0.0.25" +version = "0.0.26" requires-python = ">=3.10" dynamic = ["dependencies", "optional-dependencies"] diff --git a/tabpfn_client/config.py b/tabpfn_client/config.py index cae6d94..f098373 100644 --- a/tabpfn_client/config.py +++ b/tabpfn_client/config.py @@ -45,10 +45,6 @@ def init(use_server=True): PromptAgent.reverify_email(access_token) else: PromptAgent.prompt_welcome() - if not PromptAgent.prompt_terms_and_cond(): - raise RuntimeError( - "You must agree to the terms and conditions to use TabPFN" - ) # prompt for login / register PromptAgent.prompt_and_set_token() diff --git a/tabpfn_client/prompt_agent.py b/tabpfn_client/prompt_agent.py index 0f2c202..672a1d6 100644 --- a/tabpfn_client/prompt_agent.py +++ b/tabpfn_client/prompt_agent.py @@ -47,7 +47,6 @@ def prompt_welcome(cls): @classmethod def prompt_and_set_token(cls): - print("inside prompt_and_set_token") # Choose between registration and login prompt = "\n".join( [ @@ -64,6 +63,13 @@ def prompt_and_set_token(cls): # Registration if choice == "1": validation_link = "tabpfn-2023" + + agreed_terms_and_cond = cls.prompt_terms_and_cond() + if not agreed_terms_and_cond: + raise RuntimeError( + "You must agree to the terms and conditions to use TabPFN" + ) + while True: email = input(cls.indent("Please enter your email: ")) # Send request to server to check if email is valid and not already taken. @@ -102,7 +108,19 @@ def prompt_and_set_token(cls): "Entered password and confirmation password do not match, please try again.\n" ) ) + agreed_personally_identifiable_information = ( + cls.prompt_personally_identifiable_information() + ) + if not agreed_personally_identifiable_information: + raise RuntimeError( + "You must agree to not upload personally identifiable information." + ) + additional_info = cls.prompt_add_user_information() + additional_info["agreed_terms_and_cond"] = agreed_terms_and_cond + additional_info["agreed_personally_identifiable_information"] = ( + agreed_personally_identifiable_information + ) is_created, message, access_token = ( UserAuthenticationClient.set_token_by_registration( email, password, password_confirm, validation_link, additional_info @@ -180,7 +198,7 @@ def prompt_and_set_token(cls): def prompt_terms_and_cond(cls) -> bool: t_and_c = "\n".join( [ - "Please refer to our terms and conditions at: https://www.priorlabs.ai/terms-eu-en " + "\nPlease refer to our terms and conditions at: https://www.priorlabs.ai/terms " "By using TabPFN, you agree to the following terms and conditions:", "Do you agree to the above terms and conditions? (y/n): ", ] @@ -188,15 +206,45 @@ def prompt_terms_and_cond(cls) -> bool: choice = cls._choice_with_retries(t_and_c, ["y", "n"]) return choice == "y" + @classmethod + def prompt_personally_identifiable_information(cls) -> bool: + pii = "\n".join( + [ + "Do you agree to not upload personally identifiable information? (y/n): ", + ] + ) + choice = cls._choice_with_retries(pii, ["y", "n"]) + return choice == "y" + @classmethod def prompt_add_user_information(cls) -> dict: + print(cls.indent("\nPlease provide your name:")) + + # Required fields + while True: + first_name = input(cls.indent("First Name: ")).strip() + if not first_name: + print( + cls.indent("First name is required. Please enter your first name.") + ) + continue + break + + while True: + last_name = input(cls.indent("Last Name: ")).strip() + if not last_name: + print(cls.indent("Last name is required. Please enter your last name.")) + continue + break + print( cls.indent( - "To help us tailor our support and services to your needs, we have a few optional questions. " + "\nTo help us tailor our support and services to your needs, we have a few optional questions. " "Feel free to skip any question by leaving it blank." ) + "\n" ) + company = input(cls.indent("Where do you work? ")) role = input(cls.indent("What is your role? ")) use_case = input(cls.indent("What do you want to use TabPFN for? ")) @@ -207,6 +255,8 @@ def prompt_add_user_information(cls) -> dict: contact_via_email = True if choice_contact == "y" else False return { + "first_name": first_name, + "last_name": last_name, "company": company, "role": role, "use_case": use_case, diff --git a/tabpfn_client/tests/unit/test_prompt_agent.py b/tabpfn_client/tests/unit/test_prompt_agent.py index e88e5b1..f056a17 100644 --- a/tabpfn_client/tests/unit/test_prompt_agent.py +++ b/tabpfn_client/tests/unit/test_prompt_agent.py @@ -15,7 +15,19 @@ def test_password_req_to_policy(self): @patch("getpass.getpass", side_effect=["Password123!", "Password123!"]) @patch( "builtins.input", - side_effect=["1", "user@example.com", "test", "test", "test", "y", "test"], + side_effect=[ + "1", + "y", + "user@example.com", + "y", + "first", + "last", + "test", + "test", + "test", + "y", + "test", + ], ) def test_prompt_and_set_token_registration( self, mock_input, mock_getpass, mock_server @@ -50,10 +62,14 @@ def test_prompt_and_set_token_registration( "Password123!", "tabpfn-2023", { + "first_name": "first", + "last_name": "last", "company": "test", "role": "test", "use_case": "test", "contact_via_email": True, + "agreed_terms_and_cond": True, + "agreed_personally_identifiable_information": True, }, ) diff --git a/tabpfn_client/tests/unit/test_tabpfn_classifier.py b/tabpfn_client/tests/unit/test_tabpfn_classifier.py index 70b8b6e..8abc2e5 100644 --- a/tabpfn_client/tests/unit/test_tabpfn_classifier.py +++ b/tabpfn_client/tests/unit/test_tabpfn_classifier.py @@ -36,13 +36,7 @@ def tearDown(self): @with_mock_server() @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token") - @patch( - "tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", - return_value=True, - ) - def test_init_remote_classifier( - self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token - ): + def test_init_remote_classifier(self, mock_server, mock_prompt_and_set_token): mock_prompt_and_set_token.side_effect = ( lambda: UserAuthenticationClient.set_token(self.dummy_token) ) @@ -67,7 +61,6 @@ def test_init_remote_classifier( init(use_server=True) self.assertTrue(mock_prompt_and_set_token.called) - self.assertTrue(mock_prompt_for_terms_and_cond.called) tabpfn = TabPFNClassifier(n_estimators=10) self.assertRaises(NotFittedError, tabpfn.predict, self.X_test) @@ -157,7 +150,15 @@ def test_reset_on_remote_classifier(self, mock_server): "tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", return_value=False, ) - def test_decline_terms_and_cond(self, mock_server, mock_prompt_for_terms_and_cond): + @patch( + "builtins.input", + side_effect=[ + "1", + ], + ) + def test_decline_terms_and_cond( + self, mock_server, mock_input, mock_prompt_for_terms_and_cond + ): # mock connection mock_server.router.get(mock_server.endpoints.root.path).respond(200) diff --git a/tabpfn_client/tests/unit/test_tabpfn_regressor.py b/tabpfn_client/tests/unit/test_tabpfn_regressor.py index 1689b0e..91bcfa8 100644 --- a/tabpfn_client/tests/unit/test_tabpfn_regressor.py +++ b/tabpfn_client/tests/unit/test_tabpfn_regressor.py @@ -34,13 +34,7 @@ def tearDown(self): @with_mock_server() @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token") - @patch( - "tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", - return_value=True, - ) - def test_init_remote_regressor( - self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token - ): + def test_init_remote_regressor(self, mock_server, mock_prompt_and_set_token): mock_prompt_and_set_token.side_effect = ( lambda: UserAuthenticationClient.set_token(self.dummy_token) ) @@ -68,7 +62,6 @@ def test_init_remote_regressor( init(use_server=True) self.assertTrue(mock_prompt_and_set_token.called) - self.assertTrue(mock_prompt_for_terms_and_cond.called) tabpfn = TabPFNRegressor(n_estimators=10) self.assertRaises(NotFittedError, tabpfn.predict, self.X_test) @@ -118,13 +111,7 @@ def test_reuse_saved_access_token(self, mock_server): @with_mock_server() @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token") - @patch( - "tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", - return_value=True, - ) - def test_invalid_saved_access_token( - self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token - ): + def test_invalid_saved_access_token(self, mock_server, mock_prompt_and_set_token): mock_prompt_and_set_token.side_effect = [RuntimeError] # mock connection and invalid authentication @@ -171,7 +158,15 @@ def test_reset_on_remote_regressor(self, mock_server): "tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", return_value=False, ) - def test_decline_terms_and_cond(self, mock_server, mock_prompt_for_terms_and_cond): + @patch( + "builtins.input", + side_effect=[ + "1", + ], + ) + def test_decline_terms_and_cond( + self, mock_server, mock_input, mock_prompt_for_terms_and_cond + ): # mock connection mock_server.router.get(mock_server.endpoints.root.path).respond(200)