From 4d9b11286557f0c1c32719d96760c16bc655f9f8 Mon Sep 17 00:00:00 2001 From: LeoGrin Date: Fri, 6 Dec 2024 21:51:04 +0100 Subject: [PATCH 1/3] first working version of using the browser to login --- tabpfn_client/browser_auth.py | 79 +++++++++++++++++++++++ tabpfn_client/client.py | 15 +++++ tabpfn_client/config.py | 5 -- tabpfn_client/prompt_agent.py | 15 ++++- tabpfn_client/server_config.yaml | 1 + tabpfn_client/service_wrapper.py | 7 ++ tabpfn_client/tests/mock_tabpfn_server.py | 1 + 7 files changed, 117 insertions(+), 6 deletions(-) create mode 100644 tabpfn_client/browser_auth.py diff --git a/tabpfn_client/browser_auth.py b/tabpfn_client/browser_auth.py new file mode 100644 index 0000000..58ff14c --- /dev/null +++ b/tabpfn_client/browser_auth.py @@ -0,0 +1,79 @@ +from threading import Event +import http.server +import socketserver +import webbrowser +import logging +import urllib.parse +from typing import Optional, Tuple + +logger = logging.getLogger(__name__) + + +class BrowserAuthHandler: + def __init__(self, gui_url: str): + self.gui_url = gui_url + + def try_browser_login(self) -> Tuple[bool, Optional[str]]: + """ + Attempts to perform browser-based login + Returns (success: bool, token: Optional[str]) + """ + auth_event = Event() + received_token = None + + class CallbackHandler(http.server.SimpleHTTPRequestHandler): + def do_GET(self): + nonlocal received_token + + parsed = urllib.parse.urlparse(self.path) + query = urllib.parse.parse_qs(parsed.query) + + if "token" in query: + received_token = query["token"][0] + logger.debug("Received auth token from callback") + + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + success_html = """ + + +

Login successful!

+

You can close this window and return to your application.

+ + + """ + self.wfile.write(success_html.encode()) + auth_event.set() + + def log_message(self, format, *args): + pass + + try: + with socketserver.TCPServer(("", 0), CallbackHandler) as httpd: + port = httpd.server_address[1] + callback_url = f"http://localhost:{port}" + + login_url = f"{self.gui_url}/login?callback={callback_url}" + logger.debug(f"Opening browser for login at: {login_url}") + + print( + "\nOpening browser for login. Please complete the login/registration process in your browser and return here.\n" + ) + + if not webbrowser.open(login_url): + logger.debug("Failed to open browser") + print( + "\nCould not open browser automatically. Falling back to command-line login...\n" + ) + return False, None + + logger.info("Waiting for browser login completion...") + while not auth_event.is_set(): + httpd.handle_request() + + return received_token is not None, received_token + + except Exception as e: + logger.debug(f"Browser auth failed: {str(e)}") + return False, None diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index f2398c9..be85d74 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -16,6 +16,7 @@ from tabpfn_client.tabpfn_common_utils import utils as common_utils from tabpfn_client.constants import CACHE_DIR +from tabpfn_client.browser_auth import BrowserAuthHandler logger = logging.getLogger(__name__) @@ -704,3 +705,17 @@ def delete_user_account(self, confirm_pass: str) -> None: ) self._validate_response(response, "delete_user_account") + + def try_browser_login(self) -> tuple[bool, str]: + """ + Attempts browser-based login flow + Returns (success: bool, message: str) + """ + browser_auth = BrowserAuthHandler(self.server_config.gui_url) + success, token = browser_auth.try_browser_login() + + if success and token: + # Don't authorize directly, let UserAuthenticationClient handle it + return True, token + + return False, "Browser login failed or was cancelled" diff --git a/tabpfn_client/config.py b/tabpfn_client/config.py index 259432e..ceab63d 100644 --- a/tabpfn_client/config.py +++ b/tabpfn_client/config.py @@ -46,11 +46,6 @@ def init(use_server=True): PromptAgent.reverify_email(is_valid_token_set[1], user_auth_handler) else: PromptAgent.prompt_welcome() - if not PromptAgent.prompt_terms_and_cond(): - raise RuntimeError( - "You must agree to the terms and conditions to use TabPFN" - ) - # prompt for login / register PromptAgent.prompt_and_set_token(user_auth_handler) diff --git a/tabpfn_client/prompt_agent.py b/tabpfn_client/prompt_agent.py index f909afa..640b65f 100644 --- a/tabpfn_client/prompt_agent.py +++ b/tabpfn_client/prompt_agent.py @@ -44,7 +44,20 @@ def prompt_welcome(cls): @classmethod def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"): - # Choose between registration and login + # Try browser login first + success, message = user_auth_handler.try_browser_login() + if success: + print(cls.indent("Login via browser successful!")) + return + + # Fall back to CLI login if browser login failed + # Show terms and conditions for CLI login + if not cls.prompt_terms_and_cond(): + raise RuntimeError( + "You must agree to the terms and conditions to use TabPFN" + ) + + # Rest of the existing CLI login code prompt = "\n".join( [ "Please choose one of the following options:", diff --git a/tabpfn_client/server_config.yaml b/tabpfn_client/server_config.yaml index 3440cbf..77f5fe3 100644 --- a/tabpfn_client/server_config.yaml +++ b/tabpfn_client/server_config.yaml @@ -8,6 +8,7 @@ protocol: "https" host: "tabpfn-server-wjedmz7r5a-ez.a.run.app" # host: tabpfn-server-preprod-wjedmz7r5a-ez.a.run.app # preprod port: "443" +gui_url: "https://frontend-1039906307296.europe-west4.run.app" endpoints: root: path: "/" diff --git a/tabpfn_client/service_wrapper.py b/tabpfn_client/service_wrapper.py index a61a7d7..46d5733 100644 --- a/tabpfn_client/service_wrapper.py +++ b/tabpfn_client/service_wrapper.py @@ -103,6 +103,13 @@ def send_verification_email(self, access_token: str) -> tuple[bool, str]: sent, message = self.service_client.send_verification_email(access_token) return sent, message + def try_browser_login(self) -> tuple[bool, str]: + """Try to authenticate using browser-based login""" + success, token_or_message = self.service_client.try_browser_login() + if success: + self.set_token(token_or_message) + return success, token_or_message + class UserDataClient(ServiceClientWrapper): """ diff --git a/tabpfn_client/tests/mock_tabpfn_server.py b/tabpfn_client/tests/mock_tabpfn_server.py index f9ac292..935f13d 100644 --- a/tabpfn_client/tests/mock_tabpfn_server.py +++ b/tabpfn_client/tests/mock_tabpfn_server.py @@ -12,6 +12,7 @@ def __init__(self): self.router = None def __enter__(self): + print(self.base_url) self.router = respx.mock(base_url=self.base_url, assert_all_called=True) self.router.start() return self From 885921b79ba842fbdf5a3eb7398c3402b96d3b78 Mon Sep 17 00:00:00 2001 From: LeoGrin Date: Tue, 10 Dec 2024 22:37:18 +0100 Subject: [PATCH 2/3] add and fix tests --- tabpfn_client/browser_auth.py | 11 +- tabpfn_client/tests/mock_tabpfn_server.py | 1 - tabpfn_client/tests/unit/test_browser_auth.py | 248 ++++++++++++++++++ tabpfn_client/tests/unit/test_prompt_agent.py | 52 +++- .../tests/unit/test_tabpfn_classifier.py | 34 ++- .../tests/unit/test_tabpfn_regressor.py | 28 +- 6 files changed, 321 insertions(+), 53 deletions(-) create mode 100644 tabpfn_client/tests/unit/test_browser_auth.py diff --git a/tabpfn_client/browser_auth.py b/tabpfn_client/browser_auth.py index 58ff14c..f158526 100644 --- a/tabpfn_client/browser_auth.py +++ b/tabpfn_client/browser_auth.py @@ -2,12 +2,9 @@ import http.server import socketserver import webbrowser -import logging import urllib.parse from typing import Optional, Tuple -logger = logging.getLogger(__name__) - class BrowserAuthHandler: def __init__(self, gui_url: str): @@ -30,7 +27,6 @@ def do_GET(self): if "token" in query: received_token = query["token"][0] - logger.debug("Received auth token from callback") self.send_response(200) self.send_header("Content-type", "text/html") @@ -55,25 +51,22 @@ def log_message(self, format, *args): callback_url = f"http://localhost:{port}" login_url = f"{self.gui_url}/login?callback={callback_url}" - logger.debug(f"Opening browser for login at: {login_url}") print( "\nOpening browser for login. Please complete the login/registration process in your browser and return here.\n" ) if not webbrowser.open(login_url): - logger.debug("Failed to open browser") print( "\nCould not open browser automatically. Falling back to command-line login...\n" ) return False, None - logger.info("Waiting for browser login completion...") while not auth_event.is_set(): httpd.handle_request() return received_token is not None, received_token - except Exception as e: - logger.debug(f"Browser auth failed: {str(e)}") + except Exception: + print("\n Browser auth failed. Falling back to command-line login...\n") return False, None diff --git a/tabpfn_client/tests/mock_tabpfn_server.py b/tabpfn_client/tests/mock_tabpfn_server.py index 935f13d..f9ac292 100644 --- a/tabpfn_client/tests/mock_tabpfn_server.py +++ b/tabpfn_client/tests/mock_tabpfn_server.py @@ -12,7 +12,6 @@ def __init__(self): self.router = None def __enter__(self): - print(self.base_url) self.router = respx.mock(base_url=self.base_url, assert_all_called=True) self.router.start() return self diff --git a/tabpfn_client/tests/unit/test_browser_auth.py b/tabpfn_client/tests/unit/test_browser_auth.py new file mode 100644 index 0000000..81efe45 --- /dev/null +++ b/tabpfn_client/tests/unit/test_browser_auth.py @@ -0,0 +1,248 @@ +import unittest +from unittest.mock import patch, MagicMock +from tabpfn_client.browser_auth import BrowserAuthHandler +import time +import threading +import urllib.parse +import http.client +import socketserver + + +class TestBrowserAuthHandler(unittest.TestCase): + FALLBACK_MESSAGE = "\nCould not open browser automatically. Falling back to command-line login...\n" + + @patch("tabpfn_client.browser_auth.socketserver.TCPServer") + @patch("builtins.print") + def test_server_setup_exception(self, mock_print, mock_tcp_server): + # Simulate exception during callback server setup + mock_tcp_server.side_effect = Exception("Server setup failed") + + # Instantiate BrowserAuthHandler + browser_auth = BrowserAuthHandler(gui_url="http://example.com") + + # Call try_browser_login + success, token = browser_auth.try_browser_login() + + # Assert that the method returned failure + self.assertFalse(success) + self.assertIsNone(token) + + @patch("tabpfn_client.browser_auth.webbrowser.open", return_value=False) + @patch("builtins.print") + def test_browser_open_failure(self, mock_print, mock_webbrowser_open): + # Instantiate BrowserAuthHandler + browser_auth = BrowserAuthHandler(gui_url="http://example.com") + + # Call try_browser_login + success, token = browser_auth.try_browser_login() + + # Assert that browser open was attempted + mock_webbrowser_open.assert_called_once() + + # Assert that the method returned failure + self.assertFalse(success) + self.assertIsNone(token) + + # Verify that fallback message was printed + mock_print.assert_any_call(self.FALLBACK_MESSAGE) + + @patch("tabpfn_client.browser_auth.webbrowser.open", return_value=True) + @patch("tabpfn_client.browser_auth.socketserver.TCPServer") + @patch("tabpfn_client.browser_auth.Event") + @patch("builtins.print") + def test_user_cancels_login( + self, mock_event, mock_tcp_server, mock_webbrowser_open, mock_print + ): + # Mock HTTP server and auth event + mock_httpd = MagicMock() + mock_tcp_server.return_value.__enter__.return_value = mock_httpd + + # Simulate the auth event never being set + mock_auth_event = MagicMock() + mock_auth_event.is_set.return_value = False + mock_event.return_value = mock_auth_event + + # Set up `handle_request` to raise an exception after a few calls to break the loop + handle_request_call_count = {"count": 0} + + def handle_request_side_effect(): + handle_request_call_count["count"] += 1 + if handle_request_call_count["count"] > 2: + raise TimeoutError("Simulated timeout") + + mock_httpd.handle_request.side_effect = handle_request_side_effect + + # Instantiate BrowserAuthHandler + browser_auth = BrowserAuthHandler(gui_url="http://example.com") + + # Call try_browser_login + success, token = browser_auth.try_browser_login() + + # Assert that the method returned failure + self.assertFalse(success) + self.assertIsNone(token) + + def simulate_callback(self, mock_webbrowser, test_token): + """Helper method to simulate the callback response""" + time.sleep(0.1) # Give the server time to start + calls = mock_webbrowser.call_args_list + login_url = calls[0][0][0] + callback_url = urllib.parse.parse_qs(urllib.parse.urlparse(login_url).query)[ + "callback" + ][0] + port = int(callback_url.split(":")[-1]) + + conn = http.client.HTTPConnection("localhost", port) + conn.request("GET", f"/?token={test_token}") + response = conn.getresponse() + html = response.read().decode() + conn.close() + return login_url, html + + @patch("webbrowser.open") + def test_successful_browser_login(self, mock_webbrowser): + mock_webbrowser.return_value = True + test_token = "test_token_123" + + # Start callback simulation thread + callback_thread = threading.Thread( + target=lambda: self.simulate_callback(mock_webbrowser, test_token) + ) + callback_thread.daemon = True + callback_thread.start() + + browser_auth = BrowserAuthHandler(gui_url="http://example.com") + + # Call try_browser_login + success, token = browser_auth.try_browser_login() + + # Verify results + self.assertTrue(success) + self.assertEqual(token, test_token) + self.assertTrue(mock_webbrowser.called) + + # Verify URL format + login_url = mock_webbrowser.call_args[0][0] + parsed_url = urllib.parse.urlparse(login_url) + self.assertEqual(parsed_url.scheme, "http") + self.assertEqual(parsed_url.netloc, "example.com") + self.assertEqual(parsed_url.path, "/login") + self.assertTrue("callback" in urllib.parse.parse_qs(parsed_url.query)) + + # Let callback thread finish + callback_thread.join(timeout=1) + + @patch("webbrowser.open") + def test_invalid_token_response(self, mock_webbrowser): + """Test handling of invalid/empty token in callback""" + mock_webbrowser.return_value = True + + # Start callback simulation with empty token + callback_thread = threading.Thread( + target=lambda: self.simulate_callback(mock_webbrowser, "") + ) + callback_thread.daemon = True + callback_thread.start() + + browser_auth = BrowserAuthHandler(gui_url="http://example.com") + + success, token = browser_auth.try_browser_login() + self.assertFalse(success) + self.assertIsNone(token) + callback_thread.join(timeout=1) + + @patch("webbrowser.open") + def test_multiple_callback_requests(self, mock_webbrowser): + """Test handling of multiple callback requests""" + mock_webbrowser.return_value = True + test_token = "test_token_123" + + def simulate_multiple_callbacks(): + try: + # First callback with one token + self.simulate_callback(mock_webbrowser, test_token) + # Second callback with different token - should be ignored + # This will fail with ConnectionRefused, which is expected + self.simulate_callback(mock_webbrowser, "different_token") + except ConnectionRefusedError: + # This is expected as the server stops after first successful callback + pass + + callback_thread = threading.Thread(target=simulate_multiple_callbacks) + callback_thread.daemon = True + callback_thread.start() + + browser_auth = BrowserAuthHandler(gui_url="http://example.com") + + success, token = browser_auth.try_browser_login() + self.assertTrue(success) + self.assertEqual(token, test_token) # Should use first token + callback_thread.join(timeout=1) + + @patch("webbrowser.open") + @patch("tabpfn_client.browser_auth.socketserver.TCPServer") + def test_timeout_handling(self, mock_tcp_server, mock_webbrowser): + """Test handling of timeout during login""" + mock_webbrowser.return_value = True + + # Mock HTTP server + mock_httpd = MagicMock() + mock_tcp_server.return_value.__enter__.return_value = mock_httpd + + # Simulate timeout by making handle_request raise TimeoutError + mock_httpd.handle_request.side_effect = TimeoutError("Request timed out") + + browser_auth = BrowserAuthHandler(gui_url="http://example.com") + success, token = browser_auth.try_browser_login() + + self.assertFalse(success) + self.assertIsNone(token) + + @patch("webbrowser.open") + def test_malformed_callback_url(self, mock_webbrowser): + """Test handling of malformed callback URL""" + mock_webbrowser.return_value = True + + def simulate_malformed_request(): + time.sleep(0.1) + calls = mock_webbrowser.call_args_list + callback_url = urllib.parse.parse_qs( + urllib.parse.urlparse(calls[0][0][0]).query + )["callback"][0] + port = int(callback_url.split(":")[-1]) + + # Send malformed request + conn = http.client.HTTPConnection("localhost", port) + conn.request("GET", "/malformed?not_a_token=123") + conn.getresponse() + conn.close() + + callback_thread = threading.Thread(target=simulate_malformed_request) + callback_thread.daemon = True + callback_thread.start() + + browser_auth = BrowserAuthHandler(gui_url="http://example.com") + success, token = browser_auth.try_browser_login() + self.assertFalse(success) + self.assertIsNone(token) + callback_thread.join(timeout=1) + + @patch("webbrowser.open") + def test_server_port_in_use(self, mock_webbrowser): + """Test handling when preferred port is in use""" + mock_webbrowser.return_value = True + + # Create a server to occupy a port + with socketserver.TCPServer( + ("", 0), http.server.SimpleHTTPRequestHandler + ) as blocking_server: + port = blocking_server.server_address[1] + + # Try to start auth server on same port + with patch("socketserver.TCPServer.__init__") as mock_server: + mock_server.side_effect = OSError(f"Port {port} already in use") + + browser_auth = BrowserAuthHandler(gui_url="http://example.com") + success, token = browser_auth.try_browser_login() + self.assertFalse(success) + self.assertIsNone(token) diff --git a/tabpfn_client/tests/unit/test_prompt_agent.py b/tabpfn_client/tests/unit/test_prompt_agent.py index 93c4e40..cb66f73 100644 --- a/tabpfn_client/tests/unit/test_prompt_agent.py +++ b/tabpfn_client/tests/unit/test_prompt_agent.py @@ -1,7 +1,6 @@ import unittest from unittest.mock import patch, MagicMock from tabpfn_client.prompt_agent import PromptAgent -from tabpfn_client.tests.mock_tabpfn_server import with_mock_server class TestPromptAgent(unittest.TestCase): @@ -11,16 +10,34 @@ def test_password_req_to_policy(self): requirements = [repr(req) for req in password_policy.test("")] self.assertEqual(password_req, requirements) - @with_mock_server() - @patch("getpass.getpass", side_effect=["Password123!", "Password123!"]) + @patch( + "tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", + return_value=True, + ) + @patch( + "tabpfn_client.prompt_agent.getpass.getpass", + side_effect=["Password123!", "Password123!"], + ) @patch( "builtins.input", - side_effect=["1", "user@example.com", "test", "test", "test", "y"], + side_effect=[ + "1", + "user@example.com", + "Acme Corp", + "Data Analysis", + "Data Scientist", + "y", + ], ) def test_prompt_and_set_token_registration( - self, mock_input, mock_getpass, mock_server + self, + mock_input, + mock_getpass, + mock_prompt_terms_and_cond, ): mock_auth_client = MagicMock() + + mock_auth_client.try_browser_login.return_value = (False, None) mock_auth_client.get_password_policy.return_value = [ "Length(8)", "Uppercase(1)", @@ -32,15 +49,32 @@ def test_prompt_and_set_token_registration( "Registration successful", ) mock_auth_client.validate_email.return_value = (True, "") - PromptAgent.prompt_and_set_token(user_auth_handler=mock_auth_client) + + with patch("builtins.print"): + PromptAgent.prompt_and_set_token(user_auth_handler=mock_auth_client) + + mock_auth_client.try_browser_login.assert_called_once() + mock_auth_client.validate_email.assert_called_once_with("user@example.com") mock_auth_client.set_token_by_registration.assert_called_once() - @patch("getpass.getpass", side_effect=["password123"]) + @patch( + "tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", + return_value=True, + ) + @patch("getpass.getpass", return_value="password123") @patch("builtins.input", side_effect=["2", "user@example.com"]) - def test_prompt_and_set_token_login(self, mock_input, mock_getpass): + def test_prompt_and_set_token_login( + self, mock_input, mock_getpass, mock_prompt_terms_and_cond + ): mock_auth_client = MagicMock() + # Simulate browser login failure + mock_auth_client.try_browser_login.return_value = (False, None) mock_auth_client.set_token_by_login.return_value = (True, "Login successful") - PromptAgent.prompt_and_set_token(user_auth_handler=mock_auth_client) + + # Call prompt_and_set_token + with patch("builtins.print"): + PromptAgent.prompt_and_set_token(user_auth_handler=mock_auth_client) + mock_auth_client.set_token_by_login.assert_called_once() @patch("builtins.input", return_value="y") diff --git a/tabpfn_client/tests/unit/test_tabpfn_classifier.py b/tabpfn_client/tests/unit/test_tabpfn_classifier.py index 4e0430d..cafe136 100644 --- a/tabpfn_client/tests/unit/test_tabpfn_classifier.py +++ b/tabpfn_client/tests/unit/test_tabpfn_classifier.py @@ -36,14 +36,14 @@ def tearDown(self): # remove cache dir shutil.rmtree(CACHE_DIR, ignore_errors=True) - @with_mock_server() + @patch("tabpfn_client.browser_auth.webbrowser.open", return_value=False) @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token") - @patch( - "tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", - return_value=True, - ) + @with_mock_server() def test_init_remote_classifier( - self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token + self, + mock_server, + mock_prompt_and_set_token, + mock_webbrowser_open, ): mock_prompt_and_set_token.side_effect = ( lambda user_auth_handler: user_auth_handler.set_token(self.dummy_token) @@ -71,8 +71,6 @@ def test_init_remote_classifier( self.assertRaises(NotFittedError, tabpfn.predict, self.X_test) tabpfn.fit(self.X_train, self.y_train) self.assertTrue(mock_prompt_and_set_token.called) - self.assertTrue(mock_prompt_for_terms_and_cond.called) - y_pred = tabpfn.predict(self.X_test) self.assertTrue(np.all(np.argmax(mock_predict_response, axis=1) == y_pred)) @@ -104,13 +102,7 @@ def test_reuse_saved_access_token(self, mock_server): @with_mock_server() @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token") - @patch( - "tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", - return_value=True, - ) - def test_invalid_saved_access_token( - self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token - ): + def test_invalid_saved_access_token(self, mock_server, mock_prompt_and_set_token): mock_prompt_and_set_token.side_effect = [RuntimeError] # mock connection and invalid authentication @@ -152,13 +144,19 @@ def test_reset_on_remote_classifier(self, mock_server): # check if config is reset self.assertFalse(estimator.config.g_tabpfn_config.is_initialized) - @with_mock_server() @patch( "tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", return_value=False, ) - def test_decline_terms_and_cond(self, mock_server, mock_prompt_for_terms_and_cond): - # mock connection + @patch("tabpfn_client.browser_auth.webbrowser.open", return_value=False) + @with_mock_server() # TODO (leo): investigate why this needs to be the last decorator + def test_decline_terms_and_cond( + self, + mock_server, + mock_webbrowser_open, + mock_prompt_for_terms_and_cond, + ): + # Mock connection mock_server.router.get(mock_server.endpoints.root.path).respond(200) self.assertRaises(RuntimeError, init, use_server=True) diff --git a/tabpfn_client/tests/unit/test_tabpfn_regressor.py b/tabpfn_client/tests/unit/test_tabpfn_regressor.py index c07aa30..f45318b 100644 --- a/tabpfn_client/tests/unit/test_tabpfn_regressor.py +++ b/tabpfn_client/tests/unit/test_tabpfn_regressor.py @@ -36,14 +36,14 @@ def tearDown(self): # remove cache dir shutil.rmtree(CACHE_DIR, ignore_errors=True) - @with_mock_server() + @patch("tabpfn_client.browser_auth.webbrowser.open", return_value=False) @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token") - @patch( - "tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", - return_value=True, - ) + @with_mock_server() def test_init_remote_regressor( - self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token + self, + mock_server, + mock_prompt_and_set_token, + mock_webbrowser_open, ): mock_prompt_and_set_token.side_effect = ( lambda user_auth_handler: user_auth_handler.set_token(self.dummy_token) @@ -74,7 +74,6 @@ def test_init_remote_regressor( self.assertRaises(NotFittedError, tabpfn.predict, self.X_test) tabpfn.fit(self.X_train, self.y_train) self.assertTrue(mock_prompt_and_set_token.called) - self.assertTrue(mock_prompt_for_terms_and_cond.called) for metric in ["mean", "median", "mode"]: tabpfn.optimize_metric = metric @@ -120,13 +119,7 @@ def test_reuse_saved_access_token(self, mock_server): @with_mock_server() @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token") - @patch( - "tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", - return_value=True, - ) - def test_invalid_saved_access_token( - self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token - ): + def test_invalid_saved_access_token(self, mock_server, mock_prompt_and_set_token): mock_prompt_and_set_token.side_effect = [RuntimeError] # mock connection and invalid authentication @@ -168,12 +161,15 @@ def test_reset_on_remote_regressor(self, mock_server): # check if config is reset self.assertFalse(estimator.config.g_tabpfn_config.is_initialized) - @with_mock_server() @patch( "tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", return_value=False, ) - def test_decline_terms_and_cond(self, mock_server, mock_prompt_for_terms_and_cond): + @patch("tabpfn_client.browser_auth.webbrowser.open", return_value=False) + @with_mock_server() + def test_decline_terms_and_cond( + self, mock_server, mock_webbrowser_open, mock_prompt_for_terms_and_cond + ): # mock connection mock_server.router.get(mock_server.endpoints.root.path).respond(200) From fa6426108e9242ed790d9a2b71d09935032c1920 Mon Sep 17 00:00:00 2001 From: LeoGrin Date: Thu, 12 Dec 2024 18:10:21 +0100 Subject: [PATCH 3/3] change gui url to new adress --- tabpfn_client/server_config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tabpfn_client/server_config.yaml b/tabpfn_client/server_config.yaml index 77f5fe3..2fa9c31 100644 --- a/tabpfn_client/server_config.yaml +++ b/tabpfn_client/server_config.yaml @@ -8,7 +8,7 @@ protocol: "https" host: "tabpfn-server-wjedmz7r5a-ez.a.run.app" # host: tabpfn-server-preprod-wjedmz7r5a-ez.a.run.app # preprod port: "443" -gui_url: "https://frontend-1039906307296.europe-west4.run.app" +gui_url: "https://ux.priorlabs.ai" endpoints: root: path: "/"