From b96fb7e0843fb07b8dfa756ae3c74e29edb79755 Mon Sep 17 00:00:00 2001
From: "Liam, SB Hoo"
Date: Thu, 19 Oct 2023 14:40:23 +0200
Subject: [PATCH 1/7] Major refactoring

---
 tabpfn_client/client.py                       | 262 ++++++++++++++++++
 tabpfn_client/constants.py                    |   3 +
 tabpfn_client/prompt_agent.py                 |  96 +++++++
 tabpfn_client/remote_tabpfn_classifier.py     |  70 +++++
 tabpfn_client/service_wrapper.py              | 124 +++++++++
 tabpfn_client/tabpfn_classifier.py            | 172 +++---------
 tabpfn_client/tabpfn_service_client.py        | 215 --------------
 .../integration/test_tabpfn_classifier.py     |  24 +-
 tabpfn_client/tests/mock_tabpfn_server.py     |  11 +-
 ...abpfn_service_client.py => test_client.py} |  31 +--
 .../unit/test_remote_tabpfn_classifier.py     |  64 +++++
 .../tests/unit/test_tabpfn_classifier.py      | 100 +++----
 12 files changed, 740 insertions(+), 432 deletions(-)
 create mode 100644 tabpfn_client/client.py
 create mode 100644 tabpfn_client/constants.py
 create mode 100644 tabpfn_client/prompt_agent.py
 create mode 100644 tabpfn_client/remote_tabpfn_classifier.py
 create mode 100644 tabpfn_client/service_wrapper.py
 delete mode 100644 tabpfn_client/tabpfn_service_client.py
 rename tabpfn_client/tests/unit/{test_tabpfn_service_client.py => test_client.py} (70%)
 create mode 100644 tabpfn_client/tests/unit/test_remote_tabpfn_classifier.py

diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py
new file mode 100644
index 0000000..eb80b31
--- /dev/null
+++ b/tabpfn_client/client.py
@@ -0,0 +1,262 @@
+from pathlib import Path
+import httpx
+import logging
+
+import numpy as np
+from omegaconf import OmegaConf
+
+from tabpfn_client.tabpfn_common_utils import utils as common_utils
+
+
+logger = logging.getLogger(__name__)
+
+SERVER_CONFIG_FILE = Path(__file__).parent.resolve() / "server_config.yaml"
+SERVER_CONFIG = OmegaConf.load(SERVER_CONFIG_FILE)
+
+
+@common_utils.singleton
+class ServiceClient:
+    """
+    Singleton class for handling communication with the server.
+    It encapsulates all the API calls to the server.
+    """
+
+    def __init__(self):
+        self.server_config = SERVER_CONFIG
+        self.server_endpoints = SERVER_CONFIG["endpoints"]
+        self.httpx_timeout_s = 30  # temporary workaround for slow computation on server side
+        self.httpx_client = httpx.Client(
+            base_url=f"https://{self.server_config.host}:{self.server_config.port}",
+            timeout=self.httpx_timeout_s
+        )
+
+        self._access_token = None
+
+    @property
+    def access_token(self):
+        return self._access_token
+
+    def set_access_token(self, access_token: str):
+        self._access_token = access_token
+
+    def reset_access_token(self):
+        self._access_token = None
+
+    @property
+    def is_initialized(self):
+        return self.access_token is not None \
+            and self.access_token != ""
+
+    def upload_train_set(self, X, y) -> str:
+        """
+        Upload a train set to server and return the train set UID if successful.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features)
+            The training input samples.
+        y : array-like of shape (n_samples,) or (n_samples, n_outputs)
+            The target values.
+
+        Returns
+        -------
+        train_set_uid : str
+            The unique ID of the train set in the server.
+
+        """
+        X = common_utils.serialize_to_csv_formatted_bytes(X)
+        y = common_utils.serialize_to_csv_formatted_bytes(y)
+
+        response = self.httpx_client.post(
+            url=self.server_endpoints.upload_train_set.path,
+            headers={"Authorization": f"Bearer {self.access_token}"},
+            files=common_utils.to_httpx_post_file_format([
+                ("x_file", "x_train_filename", X),
+                ("y_file", "y_train_filename", y)
+            ])
+        )
+
+        if response.status_code != 200:
+            logger.error(f"Fail to call upload_train_set(), response status: {response.status_code}")
+            raise RuntimeError(f"Fail to call upload_train_set(), server response: {response.json()}")
+
+        train_set_uid = response.json()["train_set_uid"]
+        return train_set_uid
+
+    def predict(self, train_set_uid: str, x_test):
+        """
+        Predict the class labels for the provided data (test set).
+
+        Parameters
+        ----------
+        train_set_uid : str
+            The unique ID of the train set in the server.
+        x_test : array-like of shape (n_samples, n_features)
+            The test input.
+
+        Returns
+        -------
+        y_pred : array-like of shape (n_samples,)
+            The predicted class labels.
+        """
+
+        x_test = common_utils.serialize_to_csv_formatted_bytes(x_test)
+
+        response = self.httpx_client.post(
+            url=self.server_endpoints.predict.path,
+            headers={"Authorization": f"Bearer {self.access_token}"},
+            params={"train_set_uid": train_set_uid},
+            files=common_utils.to_httpx_post_file_format([
+                ("x_file", "x_test_filename", x_test)
+            ])
+        )
+
+        if response.status_code != 200:
+            logger.error(f"Fail to call predict(), response status: {response.status_code}")
+            raise RuntimeError(f"Fail to call predict(), server response: {response.json()}")
+
+        return np.array(response.json()["y_pred"])
+
+    def predict_proba(self, train_set_uid: str, x_test):
+        """
+        Predict the class probabilities for the provided data (test set).
+
+        Parameters
+        ----------
+        train_set_uid : str
+            The unique ID of the train set in the server.
+        x_test : array-like of shape (n_samples, n_features)
+            The test input.
+
+        Returns
+        -------
+
+        """
+        x_test = common_utils.serialize_to_csv_formatted_bytes(x_test)
+
+        response = self.httpx_client.post(
+            url=self.server_endpoints.predict_proba.path,
+            headers={"Authorization": f"Bearer {self.access_token}"},
+            params={"train_set_uid": train_set_uid},
+            files=common_utils.to_httpx_post_file_format([
+                ("x_file", "x_test_filename", x_test)
+            ])
+        )
+
+        if response.status_code != 200:
+            logger.error(f"Fail to call predict_proba(), response status: {response.status_code}")
+            raise RuntimeError(f"Fail to call predict_proba(), server response: {response.json()}")
+
+        return np.array(response.json()["y_pred_proba"])
+
+    def try_connection(self) -> bool:
+        """
+        Check if server is reachable and return True if successful.
+        """
+        found_valid_connection = False
+        try:
+            response = self.httpx_client.get(self.server_endpoints.root.path)
+            if response.status_code == 200:
+                found_valid_connection = True
+
+        except httpx.ConnectError:
+            found_valid_connection = False
+
+        return found_valid_connection
+
+    def try_authenticate(self, access_token) -> bool:
+        """
+        Check if the provided access token is valid and return True if successful.
+        """
+        is_authenticated = False
+        response = self.httpx_client.get(
+            self.server_endpoints.protected_root.path,
+            headers={"Authorization": f"Bearer {access_token}"},
+        )
+
+        if response.status_code == 200:
+            is_authenticated = True
+
+        return is_authenticated
+
+    def register(
+            self,
+            email: str,
+            password: str,
+            password_confirm: str
+    ) -> (bool, str):
+        """
+        Register a new user with the provided credentials.
+
+        Parameters
+        ----------
+        email : str
+        password : str
+        password_confirm : str
+
+        Returns
+        -------
+        is_created : bool
+            True if the user is created successfully.
+        message : str
+            The message returned from the server.
+        """
+
+        response = self.httpx_client.post(
+            self.server_endpoints.register.path,
+            params={"email": email, "password": password, "password_confirm": password_confirm}
+        )
+
+        if response.status_code == 200:
+            is_created = True
+            message = response.json()["message"]
+        else:
+            is_created = False
+            message = response.json()["detail"]
+
+        return is_created, message
+
+    def login(self, email: str, password: str) -> str | None:
+        """
+        Login with the provided credentials and return the access token if successful.
+
+        Parameters
+        ----------
+        email : str
+        password : str
+
+        Returns
+        -------
+        access_token : str | None
+            The access token returned from the server. Return None if login fails.
+        """
+
+        access_token = None
+        response = self.httpx_client.post(
+            self.server_endpoints.login.path,
+            data=common_utils.to_oauth_request_form(email, password)
+        )
+
+        if response.status_code == 200:
+            access_token = response.json()["access_token"]
+
+        return access_token
+
+    def get_password_policy(self) -> {}:
+        """
+        Get the password policy from the server.
+
+        Returns
+        -------
+        password_policy : {}
+            The password policy returned from the server.
+ """ + + response = self.httpx_client.get( + self.server_endpoints.password_policy.path, + ) + if response.status_code != 200: + logger.error(f"Fail to call get_password_policy(), response status: {response.status_code}") + raise RuntimeError(f"Fail to call get_password_policy(), server response: {response.json()}") + + return response.json()["requirements"] diff --git a/tabpfn_client/constants.py b/tabpfn_client/constants.py new file mode 100644 index 0000000..2ebea67 --- /dev/null +++ b/tabpfn_client/constants.py @@ -0,0 +1,3 @@ +from pathlib import Path + +CACHE_DIR = Path(__file__).parent.resolve() / ".tabpfn" diff --git a/tabpfn_client/prompt_agent.py b/tabpfn_client/prompt_agent.py new file mode 100644 index 0000000..0343cf8 --- /dev/null +++ b/tabpfn_client/prompt_agent.py @@ -0,0 +1,96 @@ +import textwrap +import getpass + +from tabpfn_client.service_wrapper import UserAuthenticationClient + + +class PromptAgent: + @staticmethod + def indent(text: str): + indent_factor = 2 + indent_str = " " * indent_factor + return textwrap.indent(text, indent_str) + + def prompt_welcome(self): + prompt = "\n".join([ + "Welcome to TabPFN!", + "", + ]) + + print(self.indent(prompt)) + + def prompt_and_set_token(self, user_auth_handler: UserAuthenticationClient): + prompt = "\n".join([ + "Please choose one of the following options:", + "(1) Create a TabPFN account", + "(2) Login to your TabPFN account", + "", + "Please enter your choice: ", + ]) + + choice = input(self.indent(prompt)) + + if choice == "1": + # create account + email = input(self.indent("Please enter your email: ")) + + password_req = user_auth_handler.get_password_policy() + password_req_prompt = "\n".join([ + "", + "Password requirements (minimum):", + "\n".join([f". {req}" for req in password_req]), + "", + "Please enter your password: ", + ]) + + password = getpass.getpass(self.indent(password_req_prompt)) + password_confirm = getpass.getpass(self.indent("Please confirm your password: ")) + + user_auth_handler.set_token_by_registration(email, password, password_confirm) + + print(self.indent("Account created successfully!") + "\n") + + elif choice == "2": + # login to account + email = input(self.indent("Please enter your email: ")) + password = getpass.getpass(self.indent("Please enter your password: ")) + + user_auth_handler.set_token_by_login(email, password) + + print(self.indent("Login successful!") + "\n") + + else: + raise RuntimeError("Invalid choice") + + def prompt_terms_and_cond(self) -> bool: + t_and_c = "\n".join([ + "", + "By using TabPFN, you agree to the following terms and conditions:", + "", + "...", + "", + "Do you agree to the above terms and conditions? (y/n): ", + ]) + + choice = input(self.indent(t_and_c)) + + # retry for 3 attempts until valid choice is made + is_valid_choice = False + for _ in range(3): + if choice.lower() not in ["y", "n"]: + choice = input(self.indent("Invalid choice, please enter 'y' or 'n': ")) + else: + is_valid_choice = True + break + + if not is_valid_choice: + raise RuntimeError("Invalid choice") + + return choice.lower() == "y" + + def prompt_reusing_existing_token(self): + prompt = "\n".join([ + "Found existing access token, reusing it for authentication." 
+ ]) + + print(self.indent(prompt)) diff --git a/tabpfn_client/remote_tabpfn_classifier.py b/tabpfn_client/remote_tabpfn_classifier.py new file mode 100644 index 0000000..c6024fb --- /dev/null +++ b/tabpfn_client/remote_tabpfn_classifier.py @@ -0,0 +1,70 @@ +from sklearn.utils.validation import check_is_fitted +from sklearn.base import BaseEstimator, ClassifierMixin + +from tabpfn_client.service_wrapper import InferenceClient + + +class RemoteTabPFNClassifier(BaseEstimator, ClassifierMixin): + + def __init__( + self, + model=None, + device="cpu", + base_path=".", + model_string="", + batch_size_inference=4, + fp16_inference=False, + inference_mode=True, + c=None, + N_ensemble_configurations=10, + preprocess_transforms=("none", "power_all"), + feature_shift_decoder=False, + normalize_with_test=False, + average_logits=False, + categorical_features=tuple(), + optimize_metric=None, + seed=None, + transformer_predict_kwargs_init=None, + multiclass_decoder="permutation", + + # dependency injection (for testing) + inference_handler=InferenceClient() + ): + # TODO: + # These configs are ignored at the moment -> all clients share the same (default) on-server TabPFNClassifier. + # In the future version, these configs will be used to create per-user TabPFNClassifier, + # allowing the user to setup the desired TabPFNClassifier on the server. + # config for tabpfn + self.model = model + self.device = device + self.base_path = base_path + self.model_string = model_string + self.batch_size_inference = batch_size_inference + self.fp16_inference = fp16_inference + self.inference_mode = inference_mode + self.c = c + self.N_ensemble_configurations = N_ensemble_configurations + self.preprocess_transforms = preprocess_transforms + self.feature_shift_decoder = feature_shift_decoder + self.normalize_with_test = normalize_with_test + self.average_logits = average_logits + self.categorical_features = categorical_features + self.optimize_metric = optimize_metric + self.seed = seed + self.transformer_predict_kwargs_init = transformer_predict_kwargs_init + self.multiclass_decoder = multiclass_decoder + + self.inference_handler = inference_handler + + def fit(self, X, y): + self.inference_handler.fit(X, y) + self.fitted_ = True + return self + + def predict(self, X): + check_is_fitted(self) + return self.inference_handler.predict(X) + + def predict_proba(self, X): + check_is_fitted(self) + return self.inference_handler.predict_proba(X) diff --git a/tabpfn_client/service_wrapper.py b/tabpfn_client/service_wrapper.py new file mode 100644 index 0000000..f086135 --- /dev/null +++ b/tabpfn_client/service_wrapper.py @@ -0,0 +1,124 @@ +import logging + +from tabpfn_client.client import ServiceClient +from tabpfn_client.constants import CACHE_DIR + +logger = logging.getLogger(__name__) + + +class ServiceClientWrapper: + def __init__(self, service_client: ServiceClient): + self.service_client = service_client + + +class UserAuthenticationClient(ServiceClientWrapper): + """ + Singleton class for handling user authentication, including: + - user registration and login + - access token caching + + """ + CACHED_TOKEN_FILE = CACHE_DIR / "config" + + def is_accessible_connection(self) -> bool: + return self.service_client.try_connection() + + def set_token(self, access_token: str): + self.service_client.set_access_token(access_token) + self.CACHED_TOKEN_FILE.parent.mkdir(parents=True, exist_ok=True) + self.CACHED_TOKEN_FILE.write_text(access_token) + + def set_token_by_registration( + self, + email: str, + password: str, + 
password_confirm: str + ) -> None: + if password != password_confirm: + raise ValueError("Password and password_confirm must be the same.") + + is_created, message = self.service_client.register(email, password, password_confirm) + if not is_created: + raise RuntimeError(f"Failed to register user: {message}") + + # login after registration + self.set_token_by_login(email, password) + + def set_token_by_login(self, email: str, password: str) -> None: + access_token = self.service_client.login(email, password) + + if access_token is None: + raise RuntimeError("Failed to login, please check your email and password.") + + self.set_token(access_token) + + def try_reuse_existing_token(self) -> bool: + if self.service_client.access_token is None: + if not self.CACHED_TOKEN_FILE.exists(): + return False + + access_token = self.CACHED_TOKEN_FILE.read_text() + + else: + access_token = self.service_client.access_token + + is_valid = self.service_client.try_authenticate(access_token) + if not is_valid: + self._reset_token() + return False + + logger.debug(f"Reusing existing access token? {is_valid}") + self.set_token(access_token) + + return True + + def get_password_policy(self): + return self.service_client.get_password_policy() + + def reset_cache(self): + self._reset_token() + + def _reset_token(self): + self.service_client.reset_access_token() + self.CACHED_TOKEN_FILE.unlink() + + +class UserDataClient(ServiceClientWrapper): + """ + Singleton class for handling user data, including: + - query, or delete user account data + - query, download, or delete uploaded data + """ + pass + + +class InferenceClient(ServiceClientWrapper): + """ + Singleton class for handling inference, including: + - fitting + - prediction + """ + + def __init__(self, service_client = ServiceClient()): + super().__init__(service_client) + self.last_train_set_uid = None + + def fit(self, X, y) -> None: + if not self.service_client.is_initialized: + raise RuntimeError("Service client is not initialized.") + + self.last_train_set_uid = self.service_client.upload_train_set(X, y) + + def predict(self, X): + return self.service_client.predict( + train_set_uid=self.last_train_set_uid, + x_test=X + ) + + def predict_proba(self, X): + return self.service_client.predict_proba( + train_set_uid=self.last_train_set_uid, + x_test=X + ) + + diff --git a/tabpfn_client/tabpfn_classifier.py b/tabpfn_client/tabpfn_classifier.py index 6911252..3b7f613 100644 --- a/tabpfn_client/tabpfn_classifier.py +++ b/tabpfn_client/tabpfn_classifier.py @@ -1,18 +1,17 @@ import logging -from typing import Union from pathlib import Path -import getpass -import textwrap +import shutil from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.utils.validation import check_is_fitted -from tabpfn import TabPFNClassifier as TabPFNClassifierLocal -from tabpfn_client import tabpfn_service_client -from tabpfn_client.tabpfn_service_client import TabPFNServiceClient +from tabpfn import TabPFNClassifier as LocalTabPFNClassifier +from tabpfn_client.remote_tabpfn_classifier import RemoteTabPFNClassifier +from tabpfn_client.service_wrapper import UserAuthenticationClient, InferenceClient +from tabpfn_client.client import ServiceClient +from tabpfn_client.constants import CACHE_DIR +from tabpfn_client.prompt_agent import PromptAgent -CACHE_DIR = Path(__file__).parent.resolve() / ".tabpfn" -TOKEN_FILE = CACHE_DIR / "config" logger = logging.getLogger(__name__) @@ -20,50 +19,43 @@ class TabPFNConfig: is_initialized = None use_server = None + user_auth_handler = None 
+ inference_handler = None g_tabpfn_config = TabPFNConfig() -def init( - use_server=True, -): +def init(use_server=True): global g_tabpfn_config + prompt_agent = PromptAgent() if use_server: + prompt_agent.prompt_welcome() + + service_client = ServiceClient() + user_auth_handler = UserAuthenticationClient(service_client) + # check connection to server - if not TabPFNServiceClient.try_connection(): - raise RuntimeError("TabPFN is unaccessible at the moment, please try again later.") - - token = None - - # check previously saved token file (if exists) - if Path.exists(TOKEN_FILE): - print(f"Using previously saved access token from {str(TOKEN_FILE)}") - token = Path(TOKEN_FILE).read_text() - if not TabPFNServiceClient.try_authenticate(token): - # invalidate token and delete token file - logger.debug("Previously saved access token is invalid, deleting token file") - token = None - Path.unlink(TOKEN_FILE, missing_ok=True) - - if token is None: - # prompt for terms and conditions - if not prompt_for_terms_and_cond(): - raise RuntimeError("You must agree to the terms and conditions to use TabPFN") + if not user_auth_handler.is_accessible_connection(): + raise RuntimeError("TabPFN is inaccessible at the moment, please try again later.") - # prompt for token - token = prompt_for_token() - if not TabPFNServiceClient.try_authenticate(token): - raise RuntimeError("Invalid access token") - print(f"API key is saved to {str(TOKEN_FILE)} for future use.") - TOKEN_FILE.parent.mkdir(parents=True, exist_ok=True) - TOKEN_FILE.write_text(token) + is_valid_token_set = user_auth_handler.try_reuse_existing_token() - assert token is not None + if is_valid_token_set: + prompt_agent.prompt_reusing_existing_token() + else: + prompt_agent.prompt_welcome() + + if not prompt_agent.prompt_terms_and_cond(): + raise RuntimeError("You must agree to the terms and conditions to use TabPFN") + + # prompt for login / register + prompt_agent.prompt_and_set_token(user_auth_handler) g_tabpfn_config.use_server = True - tabpfn_service_client.init(token) + g_tabpfn_config.user_auth_handler = user_auth_handler + g_tabpfn_config.inference_handler = InferenceClient(service_client) else: g_tabpfn_config.use_server = False @@ -76,8 +68,12 @@ def reset(): global g_tabpfn_config g_tabpfn_config = TabPFNConfig() - # remove token file if exists - Path.unlink(TOKEN_FILE, missing_ok=True) + # reset user auth handler + if g_tabpfn_config.use_server: + g_tabpfn_config.user_auth_handler.reset_cache() + + # remove cache dir + shutil.rmtree(CACHE_DIR, ignore_errors=True) class TabPFNClassifier(BaseEstimator, ClassifierMixin): @@ -153,9 +149,12 @@ def fit(self, X, y): } if g_tabpfn_config.use_server: - self.classifier_ = TabPFNServiceClient(**classifier_cfg) + self.classifier_ = RemoteTabPFNClassifier( + **classifier_cfg, + inference_handler=g_tabpfn_config.inference_handler + ) else: - self.classifier_ = TabPFNClassifierLocal(**classifier_cfg) + self.classifier_ = LocalTabPFNClassifier(**classifier_cfg) self.classifier_.fit(X, y) return self @@ -169,90 +168,3 @@ def predict_proba(self, X): return self.classifier_.predict_proba(X) -def indent(text: str) -> str: - indent_factor = 2 - indent_str = " " * indent_factor - return textwrap.indent(text, indent_str) - - -def prompt_for_token(): - prompt = "\n".join([ - "", - "Welcome to TabPFN!", - "", - "You are not logged in yet.", - "", - "Please choose one of the following options:", - "(1) Create a TabPFN account", - "(2) Login to your TabPFN account", - "", - "Please enter your choice: ", - ]) - - 
choice = input(indent(prompt)) - - if choice == "1": - # create account - email = input(indent("Please enter your email: ")) - - password_req = TabPFNServiceClient.get_password_policy()["requirements"] - password_req_prompt = "\n".join([ - "", - "Password requirements (minimum):", - "\n".join([f". {req}" for req in password_req]), - "", - "Please enter your password: ", - ]) - - password = getpass.getpass(indent(password_req_prompt)) - password_confirm = getpass.getpass(indent("Please confirm your password: ")) - - if password != password_confirm: - raise RuntimeError("Fail to register account, mismatched password") - - success, message = TabPFNServiceClient.register(email, password, password_confirm) - if not success: - raise RuntimeError(f"Fail to register account, {message}") - - elif choice == "2": - # login to account - email = input(indent("Please enter your email: ")) - password = getpass.getpass(indent("Please enter your password: ")) - - else: - raise RuntimeError("Invalid choice") - - token = TabPFNServiceClient.login(email, password) - if token is None: - raise RuntimeError(f"Fail to login with the given email and password") - - return token - - -def prompt_for_terms_and_cond(): - t_and_c = "\n".join([ - "", - "By using TabPFN, you agree to the following terms and conditions:", - "", - "...", - "", - "Do you agree to the above terms and conditions? (y/n): ", - ]) - - choice = input(indent(t_and_c)) - - # retry for 3 attempts until valid choice is made - is_valid_choice = False - for _ in range(3): - if choice.lower() not in ["y", "n"]: - choice = input(indent("Invalid choice, please enter 'y' or 'n': ")) - else: - is_valid_choice = True - break - - if not is_valid_choice: - raise RuntimeError("Invalid choice") - - return choice.lower() == "y" - - diff --git a/tabpfn_client/tabpfn_service_client.py b/tabpfn_client/tabpfn_service_client.py deleted file mode 100644 index d08e892..0000000 --- a/tabpfn_client/tabpfn_service_client.py +++ /dev/null @@ -1,215 +0,0 @@ -import httpx -import logging -from pathlib import Path -from omegaconf import OmegaConf - -import numpy as np -from sklearn.utils.validation import check_is_fitted -from sklearn.base import BaseEstimator, ClassifierMixin - -from tabpfn_client.tabpfn_common_utils import utils as common_utils - -g_access_token = None - -SERVER_CONFIG_FILE = Path(__file__).parent.resolve() / "server_config.yaml" -SERVER_CONFIG = OmegaConf.load(SERVER_CONFIG_FILE) - - -def init(access_token: str): - if access_token is None or access_token == "": - raise RuntimeError("access_token must be provided") - TabPFNServiceClient.access_token = access_token - - -class TabPFNServiceClient(BaseEstimator, ClassifierMixin): - - server_config = SERVER_CONFIG - server_endpoints = SERVER_CONFIG["endpoints"] - HTTPX_TIMEOUT_S = 15 # temporary workaround for slow computation on server side - - httpx_client = httpx.Client( - base_url=f"https://{server_config.host}:{server_config.port}", - timeout=HTTPX_TIMEOUT_S, - ) - access_token = None - - def __init__( - self, - model=None, - device="cpu", - base_path=".", - model_string="", - batch_size_inference=4, - fp16_inference=False, - inference_mode=True, - c=None, - N_ensemble_configurations=10, - preprocess_transforms=("none", "power_all"), - feature_shift_decoder=False, - normalize_with_test=False, - average_logits=False, - categorical_features=tuple(), - optimize_metric=None, - seed=None, - transformer_predict_kwargs_init=None, - multiclass_decoder="permutation", - ): - if self.access_token is None or 
self.access_token == "": - raise RuntimeError("tabpfn_service_client.init() must be called before instantiating TabPFNServiceClient") - - # TODO: - # These configs are ignored at the moment -> all clients share the same (default) on-server TabPFNClassifier. - # In the future version, these configs will be used to create per-user TabPFNClassifier, - # allowing the user to setup the desired TabPFNClassifier on the server. - # config for tabpfn - self.model = model - self.device = device - self.base_path = base_path - self.model_string = model_string - self.batch_size_inference = batch_size_inference - self.fp16_inference = fp16_inference - self.inference_mode = inference_mode - self.c = c - self.N_ensemble_configurations = N_ensemble_configurations - self.preprocess_transforms = preprocess_transforms - self.feature_shift_decoder = feature_shift_decoder - self.normalize_with_test = normalize_with_test - self.average_logits = average_logits - self.categorical_features = categorical_features - self.optimize_metric = optimize_metric - self.seed = seed - self.transformer_predict_kwargs_init = transformer_predict_kwargs_init - self.multiclass_decoder = multiclass_decoder - - def fit(self, X, y): - self.last_per_user_train_set_id_ = None - - # TODO: (in the coming version) - # create a per-client TabPFN on the server (referred by self.tabpfn_id) if it doesn't exist yet - self.tabpfn_id_ = None - - X = common_utils.serialize_to_csv_formatted_bytes(X) - y = common_utils.serialize_to_csv_formatted_bytes(y) - - response = self.httpx_client.post( - url=self.server_endpoints.upload_train_set.path, - headers={"Authorization": f"Bearer {self.access_token}"}, - files=common_utils.to_httpx_post_file_format([ - ("x_file", X), - ("y_file", y) - ]) - ) - - if response.status_code != 200: - logging.error(f"Fail to call upload_train_set(), response status: {response.status_code}") - # TODO: error probably doesn't have json() method, check in unit test - logging.error(f"Fail to call fit(), server response: {response.json()}") - raise RuntimeError(f"Fail to call fit(), server response: {response.json()}") - - self.last_per_user_train_set_id_ = response.json()["per_user_train_set_id"] - - return self - - def predict(self, X): - check_is_fitted(self) - - X = common_utils.serialize_to_csv_formatted_bytes(X) - - response = self.httpx_client.post( - url=self.server_endpoints.predict.path, - headers={"Authorization": f"Bearer {self.access_token}"}, - params={"per_user_train_set_id": self.last_per_user_train_set_id_}, - files=common_utils.to_httpx_post_file_format([ - ("x_file", X) - ]) - ) - - if response.status_code != 200: - logging.error(f"Fail to call predict(), response status: {response.status_code}") - raise RuntimeError(f"Fail to call predict(), server response: {response.json()}") - - return np.array(response.json()["y_pred"]) - - def predict_proba(self, X): - check_is_fitted(self) - - X = common_utils.serialize_to_csv_formatted_bytes(X) - - response = self.httpx_client.post( - url=self.server_endpoints.predict_proba.path, - headers={"Authorization": f"Bearer {self.access_token}"}, - params={"per_user_train_set_id": self.last_per_user_train_set_id_}, - files=common_utils.to_httpx_post_file_format([ - ("x_file", X) - ]) - ) - - if response.status_code != 200: - logging.error(f"Fail to call predict_proba(), response status: {response.status_code}") - raise RuntimeError(f"Fail to call predict_proba(), server response: {response.json()}") - - return np.array(response.json()["y_pred_proba"]) - - @classmethod - def 
try_connection(cls) -> bool: - found_valid_connection = False - try: - response = cls.httpx_client.get(cls.server_endpoints.root.path) - if response.status_code == 200: - found_valid_connection = True - - except httpx.ConnectError: - found_valid_connection = False - - return found_valid_connection - - @classmethod - def try_authenticate(cls, access_token) -> bool: - is_authenticated = False - response = cls.httpx_client.get( - cls.server_endpoints.protected_root.path, - headers={"Authorization": f"Bearer {access_token}"}, - ) - - if response.status_code == 200: - is_authenticated = True - - return is_authenticated - - @classmethod - def register(cls, email, password, password_confirm) -> (bool, str): - is_created = False - response = cls.httpx_client.post( - cls.server_endpoints.register.path, - params={"email": email, "password": password, "password_confirm": password_confirm} - ) - if response.status_code == 200: - is_created = True - message = response.json()["message"] - - else: - message = response.json()["detail"] - - return is_created, message - - @classmethod - def login(cls, email, password) -> str: - access_token = None - response = cls.httpx_client.post( - cls.server_endpoints.login.path, - data=common_utils.to_oauth_request_form(email, password) - ) - if response.status_code == 200: - access_token = response.json()["access_token"] - - return access_token - - @classmethod - def get_password_policy(cls) -> {}: - response = cls.httpx_client.get( - cls.server_endpoints.password_policy.path, - ) - if response.status_code == 200: - return response.json() - else: - raise RuntimeError(f"Fail to call get_password_policy(), server response: {response.json()}") diff --git a/tabpfn_client/tests/integration/test_tabpfn_classifier.py b/tabpfn_client/tests/integration/test_tabpfn_classifier.py index 4b3638c..f4bea90 100644 --- a/tabpfn_client/tests/integration/test_tabpfn_classifier.py +++ b/tabpfn_client/tests/integration/test_tabpfn_classifier.py @@ -1,13 +1,13 @@ import unittest -from unittest.mock import patch from sklearn.datasets import load_breast_cancer from sklearn.model_selection import train_test_split -from tabpfn import TabPFNClassifier as TabPFNClassifierLocal +from tabpfn import TabPFNClassifier as LocalTabPFNClassifier -from tabpfn_client.tabpfn_classifier import TabPFNClassifier -from tabpfn_client import tabpfn_classifier +from tabpfn_client import tabpfn_classifier, TabPFNClassifier from tabpfn_client.tests.mock_tabpfn_server import with_mock_server +from tabpfn_client.service_wrapper import UserAuthenticationClient +from tabpfn_client.client import ServiceClient class TestTabPFNClassifier(unittest.TestCase): @@ -17,20 +17,24 @@ def setUp(self): def tearDown(self): tabpfn_classifier.reset() + ServiceClient().delete_instance() def test_use_local_tabpfn_classifier(self): tabpfn_classifier.init(use_server=False) tabpfn = TabPFNClassifier(device="cpu") tabpfn.fit(self.X_train, self.y_train) - self.assertTrue(isinstance(tabpfn.classifier_, TabPFNClassifierLocal)) + self.assertTrue(isinstance(tabpfn.classifier_, LocalTabPFNClassifier)) pred = tabpfn.predict(self.X_test) self.assertEqual(pred.shape[0], self.X_test.shape[0]) @with_mock_server() - @patch("tabpfn_client.tabpfn_classifier.prompt_for_token", side_effect=["dummy_token"]) - @patch("tabpfn_client.tabpfn_classifier.prompt_for_terms_and_cond", side_effect=[True]) - def test_use_remote_tabpfn_classifier(self, mock_server, mock_prompt_for_token, mock_prompt_for_terms_and_cond): + def test_use_remote_tabpfn_classifier(self, 
mock_server): + # create dummy token file + token_file = UserAuthenticationClient.CACHED_TOKEN_FILE + token_file.parent.mkdir(parents=True, exist_ok=True) + token_file.write_text("dummy token") + # mock connection and authentication mock_server.router.get(mock_server.endpoints.root.path).respond(200) mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200) @@ -40,13 +44,13 @@ def test_use_remote_tabpfn_classifier(self, mock_server, mock_prompt_for_token, # mock fitting mock_server.router.post(mock_server.endpoints.upload_train_set.path).respond( - 200, json={"per_user_train_set_id": 5}) + 200, json={"train_set_uid": 5}) tabpfn.fit(self.X_train, self.y_train) # mock prediction mock_server.router.post(mock_server.endpoints.predict.path).respond( 200, - json={"y_pred": TabPFNClassifierLocal().fit(self.X_train, self.y_train).predict(self.X_test).tolist()} + json={"y_pred": LocalTabPFNClassifier().fit(self.X_train, self.y_train).predict(self.X_test).tolist()} ) pred = tabpfn.predict(self.X_test) self.assertEqual(pred.shape[0], self.X_test.shape[0]) diff --git a/tabpfn_client/tests/mock_tabpfn_server.py b/tabpfn_client/tests/mock_tabpfn_server.py index 8ec53d4..4aa15c0 100644 --- a/tabpfn_client/tests/mock_tabpfn_server.py +++ b/tabpfn_client/tests/mock_tabpfn_server.py @@ -1,17 +1,9 @@ import respx from contextlib import AbstractContextManager -from tabpfn_client.tabpfn_service_client import SERVER_CONFIG +from tabpfn_client.client import SERVER_CONFIG -# class MockTabPFNServer(respx.MockRouter): -# def __init__(self): -# self.server_config = SERVER_CONFIG -# self.endpoints = self.server_config.endpoints -# self.base_url = f"http://{self.server_config.host}:{self.server_config.port}" -# -# super().__init__(base_url=self.base_url, assert_all_called=True) - class MockTabPFNServer(AbstractContextManager): def __init__(self): self.server_config = SERVER_CONFIG @@ -27,6 +19,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.router.stop() + def with_mock_server(): def decorator(func): def wrapper(test_class, *args, **kwargs): diff --git a/tabpfn_client/tests/unit/test_tabpfn_service_client.py b/tabpfn_client/tests/unit/test_client.py similarity index 70% rename from tabpfn_client/tests/unit/test_tabpfn_service_client.py rename to tabpfn_client/tests/unit/test_client.py index 7ecbf94..f3fd810 100644 --- a/tabpfn_client/tests/unit/test_tabpfn_service_client.py +++ b/tabpfn_client/tests/unit/test_client.py @@ -2,22 +2,20 @@ from sklearn.datasets import load_breast_cancer from sklearn.model_selection import train_test_split -from sklearn.exceptions import NotFittedError import numpy as np -from tabpfn_client import tabpfn_service_client -from tabpfn_client.tabpfn_service_client import TabPFNServiceClient +from tabpfn_client.client import ServiceClient from tabpfn_client.tests.mock_tabpfn_server import with_mock_server -class TestTabPFNServiceClient(unittest.TestCase): +class TestServiceClient(unittest.TestCase): def setUp(self): # setup data X, y = load_breast_cancer(return_X_y=True) - self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(X, y, test_size=0.33, random_state=42) + self.X_train, self.X_test, self.y_train, self.y_test = \ + train_test_split(X, y, test_size=0.33, random_state=42) - tabpfn_service_client.init("dummy_token") - self.client = TabPFNServiceClient() + self.client = ServiceClient() @with_mock_server() def test_try_connection(self, mock_server): @@ -51,25 +49,18 @@ def test_valid_auth_token(self, mock_server): 
@with_mock_server() def test_predict_with_valid_train_set_and_test_set(self, mock_server): - dummy_json = {"per_user_train_set_id": 5} + dummy_json = {"train_set_uid": 5} mock_server.router.post(mock_server.endpoints.upload_train_set.path).respond( 200, json=dummy_json) - self.client.fit(self.X_train, self.y_train) + self.client.upload_train_set(self.X_train, self.y_train) dummy_result = {"y_pred": [1, 2, 3]} mock_server.router.post(mock_server.endpoints.predict.path).respond( 200, json=dummy_result) - pred = self.client.predict(self.X_test) + pred = self.client.predict( + train_set_uid=dummy_json["train_set_uid"], + x_test=self.X_test + ) self.assertTrue(np.array_equal(pred, dummy_result["y_pred"])) - - def test_predict_with_conflicting_test_set(self): - # TODO: implement this - pass - - def test_call_predict_without_calling_fit_before(self): - self.assertRaises(NotFittedError, self.client.predict, self.X_test) - - def test_call_predict_proba_without_calling_fit_before(self): - self.assertRaises(NotFittedError, self.client.predict_proba, self.X_test) diff --git a/tabpfn_client/tests/unit/test_remote_tabpfn_classifier.py b/tabpfn_client/tests/unit/test_remote_tabpfn_classifier.py new file mode 100644 index 0000000..5c76b55 --- /dev/null +++ b/tabpfn_client/tests/unit/test_remote_tabpfn_classifier.py @@ -0,0 +1,64 @@ +import unittest +from unittest.mock import MagicMock, patch +import shutil + +from sklearn.datasets import load_breast_cancer +from sklearn.model_selection import train_test_split +from sklearn.exceptions import NotFittedError + +from tabpfn_client.remote_tabpfn_classifier import RemoteTabPFNClassifier +from tabpfn_client.client import ServiceClient +from tabpfn_client.service_wrapper import InferenceClient +from tabpfn_client.constants import CACHE_DIR + + +class TestRemoteTabPFNClassifier(unittest.TestCase): + + def setUp(self): + self.dummy_token = "dummy_token" + X, y = load_breast_cancer(return_X_y=True) + self.X_train, self.X_test, self.y_train, self.y_test = \ + train_test_split(X, y, test_size=0.33) + + # mock service client + self.mock_client = MagicMock(spec=ServiceClient) + self.mock_client.is_initialized.return_value = True + inference_handler = InferenceClient(service_client=self.mock_client) + + self.remote_tabpfn = RemoteTabPFNClassifier(inference_handler=inference_handler) + + def tearDown(self): + patch.stopall() + shutil.rmtree(CACHE_DIR, ignore_errors=True) + + def test_fit_and_predict_with_valid_datasets(self): + # mock responses + self.mock_client.upload_train_set.return_value = "dummy_train_set_uid" + + mock_predict_response = [1, 1, 0] + self.mock_client.predict.return_value = mock_predict_response + + self.remote_tabpfn.fit(self.X_train, self.y_train) + y_pred = self.remote_tabpfn.predict(self.X_test) + + self.assertEqual(mock_predict_response, y_pred) + self.mock_client.upload_train_set.called_once_with(self.X_train, self.y_train) + self.mock_client.predict.called_once_with(self.X_test) + + def test_call_predict_without_calling_fit_before(self): + self.assertRaises( + NotFittedError, + self.remote_tabpfn.predict, + self.X_test + ) + + def test_call_predict_proba_without_calling_fit_before(self): + self.assertRaises( + NotFittedError, + self.remote_tabpfn.predict_proba, + self.X_test + ) + + def test_predict_with_conflicting_test_set(self): + # TODO: implement this + pass diff --git a/tabpfn_client/tests/unit/test_tabpfn_classifier.py b/tabpfn_client/tests/unit/test_tabpfn_classifier.py index d72bb1f..b4c57ef 100644 --- 
a/tabpfn_client/tests/unit/test_tabpfn_classifier.py +++ b/tabpfn_client/tests/unit/test_tabpfn_classifier.py @@ -1,15 +1,18 @@ import unittest from unittest.mock import patch -from pathlib import Path +import shutil from sklearn.datasets import load_breast_cancer from sklearn.model_selection import train_test_split +from tabpfn import TabPFNClassifier as LocalTabPFNClassifier from tabpfn_client import tabpfn_classifier from tabpfn_client.tabpfn_classifier import TabPFNClassifier -from tabpfn_client.tabpfn_service_client import TabPFNServiceClient -from tabpfn import TabPFNClassifier as TabPFNClassifierLocal +from tabpfn_client.remote_tabpfn_classifier import RemoteTabPFNClassifier +from tabpfn_client.service_wrapper import UserAuthenticationClient +from tabpfn_client.client import ServiceClient from tabpfn_client.tests.mock_tabpfn_server import with_mock_server +from tabpfn_client.constants import CACHE_DIR class TestTabPFNClassifierInit(unittest.TestCase): @@ -19,52 +22,42 @@ class TestTabPFNClassifierInit(unittest.TestCase): def setUp(self): # set up dummy data X, y = load_breast_cancer(return_X_y=True) - self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(X, y, test_size=0.33) + self.X_train, self.X_test, self.y_train, self.y_test = \ + train_test_split(X, y, test_size=0.33) def tearDown(self): - # remove token file if exists - Path.unlink(tabpfn_classifier.TOKEN_FILE, missing_ok=True) + tabpfn_classifier.reset() + + # remove singleton instance of ServiceClient + ServiceClient().delete_instance() + + # remove cache dir + shutil.rmtree(CACHE_DIR, ignore_errors=True) def test_init_local_classifier(self): tabpfn_classifier.init(use_server=False) tabpfn = TabPFNClassifier().fit(self.X_train, self.y_train) - self.assertTrue(isinstance(tabpfn.classifier_, TabPFNClassifierLocal)) + self.assertTrue(isinstance(tabpfn.classifier_, LocalTabPFNClassifier)) @with_mock_server() - @patch("tabpfn_client.tabpfn_classifier.prompt_for_token", side_effect=[dummy_token]) - @patch("tabpfn_client.tabpfn_classifier.prompt_for_terms_and_cond", side_effect=[True]) - def test_init_remote_classifier(self, mock_server, mock_prompt_for_token, mock_prompt_for_terms_and_cond): - # mock connection, authentication, and fitting + @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token") + @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", + return_value=True) + def test_init_remote_classifier(self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token): + mock_prompt_and_set_token.side_effect = \ + lambda user_auth_handler: user_auth_handler.set_token(self.dummy_token) + + # mock server connection mock_server.router.get(mock_server.endpoints.root.path).respond(200) - mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200) mock_server.router.post(mock_server.endpoints.upload_train_set.path).respond( - 200, json={"per_user_train_set_id": 5} + 200, json={"train_set_uid": 5} ) tabpfn_classifier.init(use_server=True) tabpfn = TabPFNClassifier().fit(self.X_train, self.y_train) - self.assertTrue(isinstance(tabpfn.classifier_, TabPFNServiceClient)) - - # check if access token is saved - token_file = tabpfn_classifier.TOKEN_FILE - self.assertTrue(token_file.exists()) - self.assertEqual(token_file.read_text(), self.dummy_token) - - @with_mock_server() - @patch("tabpfn_client.tabpfn_classifier.prompt_for_token", side_effect=[dummy_token]) - @patch("tabpfn_client.tabpfn_classifier.prompt_for_terms_and_cond", side_effect=[True]) - def 
test_init_remote_classifier_with_invalid_token( - self, mock_server, mock_prompt_for_token, mock_prompt_for_terms_and_cond - ): - # mock connection and invalid authentication - mock_server.router.get(mock_server.endpoints.root.path).respond(200) - mock_server.router.get(mock_server.endpoints.protected_root.path).respond(401) - - self.assertRaises(RuntimeError, tabpfn_classifier.init, use_server=True) - - # check if access token is not saved - token_file = tabpfn_classifier.TOKEN_FILE - self.assertFalse(token_file.exists()) + self.assertTrue(isinstance(tabpfn.classifier_, RemoteTabPFNClassifier)) + self.assertTrue(mock_prompt_and_set_token.called) + self.assertTrue(mock_prompt_for_terms_and_cond.called) @with_mock_server() def test_reuse_saved_access_token(self, mock_server): @@ -73,26 +66,34 @@ def test_reuse_saved_access_token(self, mock_server): mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200) # create dummy token file - token_file = tabpfn_classifier.TOKEN_FILE + token_file = UserAuthenticationClient.CACHED_TOKEN_FILE + token_file.parent.mkdir(parents=True, exist_ok=True) token_file.write_text(self.dummy_token) # init is called without error tabpfn_classifier.init(use_server=True) + # check if access token still exists + self.assertTrue(UserAuthenticationClient.CACHED_TOKEN_FILE.exists()) + @with_mock_server() - @patch("tabpfn_client.tabpfn_classifier.prompt_for_token", side_effect=[RuntimeError("Invalid token")]) - @patch("tabpfn_client.tabpfn_classifier.prompt_for_terms_and_cond", side_effect=[True]) - def test_invalid_saved_access_token(self, mock_server, mock_prompt_for_token, mock_prompt_for_terms_and_cond): + @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token") + @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", + return_value=True) + def test_invalid_saved_access_token(self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token): + mock_prompt_and_set_token.side_effect = [RuntimeError] + # mock connection and invalid authentication mock_server.router.get(mock_server.endpoints.root.path).respond(200) mock_server.router.get(mock_server.endpoints.protected_root.path).respond(401) # create dummy token file - token_file = tabpfn_classifier.TOKEN_FILE + token_file = UserAuthenticationClient.CACHED_TOKEN_FILE + token_file.parent.mkdir(parents=True, exist_ok=True) token_file.write_text("invalid_token") self.assertRaises(RuntimeError, tabpfn_classifier.init, use_server=True) - self.assertTrue(mock_prompt_for_token.called) + self.assertTrue(mock_prompt_and_set_token.called) def test_reset_on_local_classifier(self): tabpfn_classifier.init(use_server=False) @@ -100,29 +101,32 @@ def test_reset_on_local_classifier(self): self.assertFalse(tabpfn_classifier.g_tabpfn_config.is_initialized) @with_mock_server() - @patch("tabpfn_client.tabpfn_classifier.prompt_for_token", side_effect=[dummy_token]) - @patch("tabpfn_client.tabpfn_classifier.prompt_for_terms_and_cond", side_effect=[True]) - def test_reset_on_remote_classifier(self, mock_server, mock_prompt_for_token, mock_prompt_for_terms_and_cond): + def test_reset_on_remote_classifier(self, mock_server): + # create dummy token file + token_file = UserAuthenticationClient.CACHED_TOKEN_FILE + token_file.parent.mkdir(parents=True, exist_ok=True) + token_file.write_text(self.dummy_token) + # init classifier as usual mock_server.router.get(mock_server.endpoints.root.path).respond(200) mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200) 
tabpfn_classifier.init(use_server=True) # check if access token is saved - token_file = tabpfn_classifier.TOKEN_FILE - self.assertTrue(token_file.exists()) + self.assertTrue(UserAuthenticationClient.CACHED_TOKEN_FILE.exists()) # reset tabpfn_classifier.reset() # check if access token is deleted - self.assertFalse(token_file.exists()) + self.assertFalse(UserAuthenticationClient.CACHED_TOKEN_FILE.exists()) # check if config is reset self.assertFalse(tabpfn_classifier.g_tabpfn_config.is_initialized) @with_mock_server() - @patch("tabpfn_client.tabpfn_classifier.prompt_for_terms_and_cond", side_effect=[False]) + @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", + return_value=False) def test_decline_terms_and_cond(self, mock_server, mock_prompt_for_terms_and_cond): # mock connection mock_server.router.get(mock_server.endpoints.root.path).respond(200) From f1f45f828247287056a62ce16d50be3bb1fa67ab Mon Sep 17 00:00:00 2001 From: "Liam, SB Hoo" Date: Fri, 20 Oct 2023 11:18:03 +0200 Subject: [PATCH 2/7] Extend client for GDPR-related APIs --- tabpfn_client/client.py | 116 +++++++++++++++++++++++++++-- tabpfn_client/server_config.yaml | 53 +++++++++++-- tabpfn_client/service_wrapper.py | 19 ++++- tabpfn_client/tabpfn_classifier.py | 2 - 4 files changed, 171 insertions(+), 19 deletions(-) diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index eb80b31..7ec4885 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -1,6 +1,7 @@ from pathlib import Path import httpx import logging +import copy import numpy as np from omegaconf import OmegaConf @@ -24,9 +25,10 @@ class ServiceClient: def __init__(self): self.server_config = SERVER_CONFIG self.server_endpoints = SERVER_CONFIG["endpoints"] + self.base_url = f"{self.server_config.protocol}://{self.server_config.host}:{self.server_config.port}" self.httpx_timeout_s = 30 # temporary workaround for slow computation on server side self.httpx_client = httpx.Client( - base_url=f"https://{self.server_config.host}:{self.server_config.port}", + base_url=self.base_url, timeout=self.httpx_timeout_s ) @@ -36,11 +38,15 @@ def __init__(self): def access_token(self): return self._access_token - def set_access_token(self, access_token: str): + def authorize(self, access_token: str): self._access_token = access_token + self.httpx_client.headers.update( + {"Authorization": f"Bearer {self.access_token}"} + ) - def reset_access_token(self): + def reset_authorization(self): self._access_token = None + self.httpx_client.headers.pop("Authorization", None) @property def is_initialized(self): @@ -69,7 +75,6 @@ def upload_train_set(self, X, y) -> str: response = self.httpx_client.post( url=self.server_endpoints.upload_train_set.path, - headers={"Authorization": f"Bearer {self.access_token}"}, files=common_utils.to_httpx_post_file_format([ ("x_file", "x_train_filename", X), ("y_file", "y_train_filename", y) @@ -104,7 +109,6 @@ def predict(self, train_set_uid: str, x_test): response = self.httpx_client.post( url=self.server_endpoints.predict.path, - headers={"Authorization": f"Bearer {self.access_token}"}, params={"train_set_uid": train_set_uid}, files=common_utils.to_httpx_post_file_format([ ("x_file", "x_test_filename", x_test) @@ -136,7 +140,6 @@ def predict_proba(self, train_set_uid: str, x_test): response = self.httpx_client.post( url=self.server_endpoints.predict_proba.path, - headers={"Authorization": f"Bearer {self.access_token}"}, params={"train_set_uid": train_set_uid}, files=common_utils.to_httpx_post_file_format([ ("x_file", 
"x_test_filename", x_test) @@ -260,3 +263,104 @@ def get_password_policy(self) -> {}: raise RuntimeError(f"Fail to call get_password_policy(), server response: {response.json()}") return response.json()["requirements"] + + def get_data_summary(self) -> {}: + """ + Get the data summary of the user from the server. + + Returns + ------- + data_summary : {} + The data summary returned from the server. + """ + response = self.httpx_client.get( + self.server_endpoints.get_data_summary.path, + ) + if response.status_code != 200: + logger.error(f"Fail to call get_data_summary(), response status: {response.status_code}") + raise RuntimeError(f"Fail to call get_data_summary(), server response: {response.json()}") + + return response.json() + + def download_all_data(self, save_dir: Path) -> Path | None: + """ + Download all data uploaded by the user from the server. + + Returns + ------- + save_path : Path | None + The path to the downloaded file. Return None if download fails. + + """ + + save_path = None + + full_url = self.base_url + self.server_endpoints.download_all_data.path + with httpx.stream("GET", full_url, headers={"Authorization": f"Bearer {self.access_token}"}) as response: + if response.status_code != 200: + logger.error(f"Fail to call download_all_data(), response status: {response.status_code}") + raise RuntimeError(f"Fail to call download_all_data(), server response: {response.json()}") + + filename = response.headers["Content-Disposition"].split("filename=")[1] + save_path = Path(save_dir) / filename + with open(save_path, "wb") as f: + for data in response.iter_bytes(): + f.write(data) + + return save_path + + def delete_dataset(self, dataset_uid: str) -> [str]: + """ + Delete the dataset with the provided UID from the server. + Note that deleting a train set with lead to deleting all associated test sets. + + Parameters + ---------- + dataset_uid : str + The UID of the dataset to be deleted. + + Returns + ------- + deleted_dataset_uids : [str] + The list of deleted dataset UIDs. + + """ + response = self.httpx_client.delete( + self.server_endpoints.delete_dataset.path, + params={"dataset_uid": dataset_uid} + ) + + if response.status_code != 200: + logger.error(f"Fail to call delete_dataset(), response status: {response.status_code}") + raise RuntimeError(f"Fail to call delete_dataset(), server response: {response.json()}") + + return response.json()["deleted_dataset_uids"] + + def delete_all_datasets(self) -> [str]: + """ + Delete all datasets uploaded by the user from the server. + + Returns + ------- + deleted_dataset_uids : [str] + The list of deleted dataset UIDs. 
+ """ + response = self.httpx_client.delete( + self.server_endpoints.delete_all_datasets.path, + ) + + if response.status_code != 200: + logger.error(f"Fail to call delete_all_datasets(), response status: {response.status_code}") + raise RuntimeError(f"Fail to call delete_all_datasets(), server response: {response.json()}") + + return response.json()["deleted_dataset_uids"] + + def delete_user_account(self, confirm_pass: str) -> None: + response = self.httpx_client.delete( + self.server_endpoints.delete_user_account.path, + params={"confirm_password": confirm_pass} + ) + + if response.status_code != 200: + logger.error(f"Fail to call delete_user_account(), response status: {response.status_code}") + raise RuntimeError(f"Fail to call delete_user_account(), server response: {response.json()}") diff --git a/tabpfn_client/server_config.yaml b/tabpfn_client/server_config.yaml index a024e76..982acdb 100644 --- a/tabpfn_client/server_config.yaml +++ b/tabpfn_client/server_config.yaml @@ -1,10 +1,16 @@ -host: "tabpfn.priorlabs.ai" +## testing +#protocol: "http" #host: "0.0.0.0" +#port: "8000" + +# production +protocol: "https" +host: "tabpfn.priorlabs.ai" port: "443" endpoints: root: path: "/" - methods: ["GET"] + methods: [ "GET" ] description: "Root endpoint" password_policy: @@ -14,30 +20,61 @@ endpoints: register: path: "/auth/register/" - methods: ["POST"] + methods: [ "POST" ] description: "User registration" login: path: "/auth/login/" - methods: ["POST"] + methods: [ "POST" ] description: "User login" protected_root: path: "/protected/" - methods: ["GET"] + methods: [ "GET" ] description: "Protected root" + upload_test_set: + path: "/upload/test_set/" + methods: [ "POST" ] + description: "Upload test set (for testing purpose)" + upload_train_set: path: "/upload/train_set/" - methods: ["POST"] + methods: [ "POST" ] description: "Upload train set" predict: path: "/predict/" - methods: ["POST"] + methods: [ "POST" ] description: "Predict" predict_proba: path: "/predict_proba/" - methods: ["POST"] + methods: [ "POST" ] description: "Predict probability" + + get_data_summary: + path: "/get_data_summary/" + methods: [ "GET" ] + description: "Get a summary of all uploaded data" + + download_all_data: + path: "/download_all_data/" + methods: [ "GET" ] + description: "Download all uploaded data" + + delete_dataset: + path: "/delete_dataset/" + methods: [ "DELETE" ] + description: "Delete dataset (can be train set or test set)" + + delete_all_datasets: + path: "/delete_all_datasets/" + methods: [ "DELETE" ] + description: "Delete all datasets (both train set and test set)" + + delete_user_account: + path: "/delete_user_account/" + methods: [ "DELETE" ] + description: "Delete user account" + diff --git a/tabpfn_client/service_wrapper.py b/tabpfn_client/service_wrapper.py index f086135..2bd5211 100644 --- a/tabpfn_client/service_wrapper.py +++ b/tabpfn_client/service_wrapper.py @@ -24,7 +24,7 @@ def is_accessible_connection(self) -> bool: return self.service_client.try_connection() def set_token(self, access_token: str): - self.service_client.set_access_token(access_token) + self.service_client.authorize(access_token) self.CACHED_TOKEN_FILE.parent.mkdir(parents=True, exist_ok=True) self.CACHED_TOKEN_FILE.write_text(access_token) @@ -79,7 +79,7 @@ def reset_cache(self): self._reset_token() def _reset_token(self): - self.service_client.reset_access_token() + self.service_client.reset_authorization() self.CACHED_TOKEN_FILE.unlink() @@ -89,7 +89,20 @@ class UserDataClient(ServiceClientWrapper): - query, 
or delete user account data - query, download, or delete uploaded data """ - pass + def get_data_summary(self): + pass + + def download_all_data(self): + pass + + def delete_all_dataset(self): + pass + + def delete_dataset(self): + pass + + def delete_user_account(self): + pass class InferenceClient(ServiceClientWrapper): diff --git a/tabpfn_client/tabpfn_classifier.py b/tabpfn_client/tabpfn_classifier.py index 3b7f613..54ea506 100644 --- a/tabpfn_client/tabpfn_classifier.py +++ b/tabpfn_client/tabpfn_classifier.py @@ -45,8 +45,6 @@ def init(use_server=True): if is_valid_token_set: prompt_agent.prompt_reusing_existing_token() else: - prompt_agent.prompt_welcome() - if not prompt_agent.prompt_terms_and_cond(): raise RuntimeError("You must agree to the terms and conditions to use TabPFN") From f964a43b26c650c31f20840d5be2fa411f92b080 Mon Sep 17 00:00:00 2001 From: "Liam, SB Hoo" Date: Fri, 20 Oct 2023 12:40:22 +0200 Subject: [PATCH 3/7] Add APIs to UserDataClient --- tabpfn_client/prompt_agent.py | 49 ++++++++++++++--------- tabpfn_client/service_wrapper.py | 62 +++++++++++++++++++++++++----- tabpfn_client/tabpfn_classifier.py | 9 ++--- 3 files changed, 88 insertions(+), 32 deletions(-) diff --git a/tabpfn_client/prompt_agent.py b/tabpfn_client/prompt_agent.py index 0343cf8..a73563e 100644 --- a/tabpfn_client/prompt_agent.py +++ b/tabpfn_client/prompt_agent.py @@ -1,8 +1,6 @@ import textwrap import getpass -from tabpfn_client.service_wrapper import UserAuthenticationClient - class PromptAgent: @staticmethod @@ -11,15 +9,17 @@ def indent(text: str): indent_str = " " * indent_factor return textwrap.indent(text, indent_str) - def prompt_welcome(self): + @classmethod + def prompt_welcome(cls): prompt = "\n".join([ "Welcome to TabPFN!", "", ]) - print(self.indent(prompt)) + print(cls.indent(prompt)) - def prompt_and_set_token(self, user_auth_handler: UserAuthenticationClient): + @classmethod + def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"): prompt = "\n".join([ "Please choose one of the following options:", "(1) Create a TabPFN account", @@ -28,11 +28,11 @@ def prompt_and_set_token(self, user_auth_handler: UserAuthenticationClient): "Please enter your choice: ", ]) - choice = input(self.indent(prompt)) + choice = input(cls.indent(prompt)) if choice == "1": # create account - email = input(self.indent("Please enter your email: ")) + email = input(cls.indent("Please enter your email: ")) password_req = user_auth_handler.get_password_policy() password_req_prompt = "\n".join([ @@ -43,26 +43,27 @@ def prompt_and_set_token(self, user_auth_handler: UserAuthenticationClient): "Please enter your password: ", ]) - password = getpass.getpass(self.indent(password_req_prompt)) - password_confirm = getpass.getpass(self.indent("Please confirm your password: ")) + password = getpass.getpass(cls.indent(password_req_prompt)) + password_confirm = getpass.getpass(cls.indent("Please confirm your password: ")) user_auth_handler.set_token_by_registration(email, password, password_confirm) - print(self.indent("Account created successfully!") + "\n") + print(cls.indent("Account created successfully!") + "\n") elif choice == "2": # login to account - email = input(self.indent("Please enter your email: ")) - password = getpass.getpass(self.indent("Please enter your password: ")) + email = input(cls.indent("Please enter your email: ")) + password = getpass.getpass(cls.indent("Please enter your password: ")) user_auth_handler.set_token_by_login(email, password) - print(self.indent("Login 
successful!") + "\n") + print(cls.indent("Login successful!") + "\n") else: raise RuntimeError("Invalid choice") - def prompt_terms_and_cond(self) -> bool: + @classmethod + def prompt_terms_and_cond(cls) -> bool: t_and_c = "\n".join([ "", "By using TabPFN, you agree to the following terms and conditions:", @@ -72,13 +73,13 @@ def prompt_terms_and_cond(self) -> bool: "Do you agree to the above terms and conditions? (y/n): ", ]) - choice = input(self.indent(t_and_c)) + choice = input(cls.indent(t_and_c)) # retry for 3 attempts until valid choice is made is_valid_choice = False for _ in range(3): if choice.lower() not in ["y", "n"]: - choice = input(self.indent("Invalid choice, please enter 'y' or 'n': ")) + choice = input(cls.indent("Invalid choice, please enter 'y' or 'n': ")) else: is_valid_choice = True break @@ -88,9 +89,21 @@ def prompt_terms_and_cond(self) -> bool: return choice.lower() == "y" - def prompt_reusing_existing_token(self): + @classmethod + def prompt_reusing_existing_token(cls): prompt = "\n".join([ "Found existing access token, reusing it for authentication." ]) - print(self.indent(prompt)) + print(cls.indent(prompt)) + + @classmethod + def prompt_confirm_password_for_user_account_deletion(cls) -> str: + print(cls.indent("You are about to delete your account.")) + confirm_pass = getpass.getpass(cls.indent("Please confirm by entering your password: ")) + + return confirm_pass + + @classmethod + def prompt_account_deleted(cls): + print(cls.indent("Your account has been deleted.")) diff --git a/tabpfn_client/service_wrapper.py b/tabpfn_client/service_wrapper.py index 2bd5211..f7777ff 100644 --- a/tabpfn_client/service_wrapper.py +++ b/tabpfn_client/service_wrapper.py @@ -1,7 +1,9 @@ import logging +from pathlib import Path from tabpfn_client.client import ServiceClient from tabpfn_client.constants import CACHE_DIR +from tabpfn_client.prompt_agent import PromptAgent logger = logging.getLogger(__name__) @@ -89,20 +91,62 @@ class UserDataClient(ServiceClientWrapper): - query, or delete user account data - query, download, or delete uploaded data """ - def get_data_summary(self): - pass + def __init__(self, service_client = ServiceClient()): + super().__init__(service_client) + + def get_data_summary(self) -> {}: + try: + summary = self.service_client.get_data_summary() + except RuntimeError as e: + logging.error(f"Failed to get data summary: {e}") + raise e + + return summary + + def download_all_data(self, save_dir: Path = Path(".")) -> Path: + try: + saved_path = self.service_client.download_all_data(save_dir) + except RuntimeError as e: + logging.error(f"Failed to download data: {e}") + raise e + + if saved_path is None: + raise RuntimeError("Failed to download data.") + + logging.info(f"Data saved to {saved_path}") + return saved_path + + def delete_dataset(self, dataset_uid: str) -> [str]: + try: + deleted_datasets = self.service_client.delete_dataset(dataset_uid) + except RuntimeError as e: + logging.error(f"Failed to delete dataset: {e}") + raise e + + logging.info(f"Deleted datasets: {deleted_datasets}") + + return deleted_datasets - def download_all_data(self): - pass + def delete_all_dataset(self) -> [str]: + try: + deleted_datasets = self.service_client.delete_all_datasets() + except RuntimeError as e: + logging.error(f"Failed to delete all datasets: {e}") + raise e - def delete_all_dataset(self): - pass + logging.info(f"Deleted datasets: {deleted_datasets}") - def delete_dataset(self): - pass + return deleted_datasets def delete_user_account(self): - pass + 
confirm_pass = PromptAgent.prompt_confirm_password_for_user_account_deletion() + try: + self.service_client.delete_user_account(confirm_pass) + except RuntimeError as e: + logging.error(f"Failed to delete user account: {e}") + raise e + + PromptAgent.prompt_account_deleted() class InferenceClient(ServiceClientWrapper): diff --git a/tabpfn_client/tabpfn_classifier.py b/tabpfn_client/tabpfn_classifier.py index 54ea506..459b189 100644 --- a/tabpfn_client/tabpfn_classifier.py +++ b/tabpfn_client/tabpfn_classifier.py @@ -28,10 +28,9 @@ class TabPFNConfig: def init(use_server=True): global g_tabpfn_config - prompt_agent = PromptAgent() if use_server: - prompt_agent.prompt_welcome() + PromptAgent.prompt_welcome() service_client = ServiceClient() user_auth_handler = UserAuthenticationClient(service_client) @@ -43,13 +42,13 @@ def init(use_server=True): is_valid_token_set = user_auth_handler.try_reuse_existing_token() if is_valid_token_set: - prompt_agent.prompt_reusing_existing_token() + PromptAgent.prompt_reusing_existing_token() else: - if not prompt_agent.prompt_terms_and_cond(): + if not PromptAgent.prompt_terms_and_cond(): raise RuntimeError("You must agree to the terms and conditions to use TabPFN") # prompt for login / register - prompt_agent.prompt_and_set_token(user_auth_handler) + PromptAgent.prompt_and_set_token(user_auth_handler) g_tabpfn_config.use_server = True g_tabpfn_config.user_auth_handler = user_auth_handler From b456bd3d52ebaecd25f31eac6289f7bf847fb676 Mon Sep 17 00:00:00 2001 From: "Liam, SB Hoo" Date: Fri, 20 Oct 2023 12:41:18 +0200 Subject: [PATCH 4/7] Update __init__.py --- tabpfn_client/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tabpfn_client/__init__.py b/tabpfn_client/__init__.py index 4552571..9d65ea2 100644 --- a/tabpfn_client/__init__.py +++ b/tabpfn_client/__init__.py @@ -1,2 +1,3 @@ from tabpfn_client import tabpfn_classifier from tabpfn_client.tabpfn_classifier import TabPFNClassifier +from tabpfn_client.service_wrapper import UserDataClient From 379f69e16b0c32b9c90cc67f4e6f99923044dcb0 Mon Sep 17 00:00:00 2001 From: "Liam, SB Hoo" Date: Wed, 25 Oct 2023 23:03:05 +0200 Subject: [PATCH 5/7] Add test and minor fix --- tabpfn_client/client.py | 18 +- tabpfn_client/service_wrapper.py | 10 +- tabpfn_client/tests/mock_tabpfn_server.py | 2 +- .../tests/unit/test_service_wrapper.py | 274 ++++++++++++++++++ 4 files changed, 289 insertions(+), 15 deletions(-) create mode 100644 tabpfn_client/tests/unit/test_service_wrapper.py diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index 7ec4885..b12c3c9 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -83,7 +83,7 @@ def upload_train_set(self, X, y) -> str: if response.status_code != 200: logger.error(f"Fail to call upload_train_set(), response status: {response.status_code}") - raise RuntimeError(f"Fail to call upload_train_set(), server response: {response.json()}") + raise RuntimeError(f"Fail to call upload_train_set()") train_set_uid = response.json()["train_set_uid"] return train_set_uid @@ -117,7 +117,7 @@ def predict(self, train_set_uid: str, x_test): if response.status_code != 200: logger.error(f"Fail to call predict(), response status: {response.status_code}") - raise RuntimeError(f"Fail to call predict(), server response: {response.json()}") + raise RuntimeError(f"Fail to call predict()") return np.array(response.json()["y_pred"]) @@ -148,7 +148,7 @@ def predict_proba(self, train_set_uid: str, x_test): if response.status_code != 200: logger.error(f"Fail to call 
predict_proba(), response status: {response.status_code}") - raise RuntimeError(f"Fail to call predict_proba(), server response: {response.json()}") + raise RuntimeError(f"Fail to call predict_proba()") return np.array(response.json()["y_pred_proba"]) @@ -260,7 +260,7 @@ def get_password_policy(self) -> {}: ) if response.status_code != 200: logger.error(f"Fail to call get_password_policy(), response status: {response.status_code}") - raise RuntimeError(f"Fail to call get_password_policy(), server response: {response.json()}") + raise RuntimeError(f"Fail to call get_password_policy()") return response.json()["requirements"] @@ -278,7 +278,7 @@ def get_data_summary(self) -> {}: ) if response.status_code != 200: logger.error(f"Fail to call get_data_summary(), response status: {response.status_code}") - raise RuntimeError(f"Fail to call get_data_summary(), server response: {response.json()}") + raise RuntimeError(f"Fail to call get_data_summary()") return response.json() @@ -299,7 +299,7 @@ def download_all_data(self, save_dir: Path) -> Path | None: with httpx.stream("GET", full_url, headers={"Authorization": f"Bearer {self.access_token}"}) as response: if response.status_code != 200: logger.error(f"Fail to call download_all_data(), response status: {response.status_code}") - raise RuntimeError(f"Fail to call download_all_data(), server response: {response.json()}") + raise RuntimeError(f"Fail to call download_all_data()") filename = response.headers["Content-Disposition"].split("filename=")[1] save_path = Path(save_dir) / filename @@ -332,7 +332,7 @@ def delete_dataset(self, dataset_uid: str) -> [str]: if response.status_code != 200: logger.error(f"Fail to call delete_dataset(), response status: {response.status_code}") - raise RuntimeError(f"Fail to call delete_dataset(), server response: {response.json()}") + raise RuntimeError(f"Fail to call delete_dataset()") return response.json()["deleted_dataset_uids"] @@ -351,7 +351,7 @@ def delete_all_datasets(self) -> [str]: if response.status_code != 200: logger.error(f"Fail to call delete_all_datasets(), response status: {response.status_code}") - raise RuntimeError(f"Fail to call delete_all_datasets(), server response: {response.json()}") + raise RuntimeError(f"Fail to call delete_all_datasets()") return response.json()["deleted_dataset_uids"] @@ -363,4 +363,4 @@ def delete_user_account(self, confirm_pass: str) -> None: if response.status_code != 200: logger.error(f"Fail to call delete_user_account(), response status: {response.status_code}") - raise RuntimeError(f"Fail to call delete_user_account(), server response: {response.json()}") + raise RuntimeError(f"Fail to call delete_user_account()") diff --git a/tabpfn_client/service_wrapper.py b/tabpfn_client/service_wrapper.py index f7777ff..731c256 100644 --- a/tabpfn_client/service_wrapper.py +++ b/tabpfn_client/service_wrapper.py @@ -15,7 +15,7 @@ def __init__(self, service_client: ServiceClient): class UserAuthenticationClient(ServiceClientWrapper): """ - Singleton class for handling user authentication, including: + Wrapper of ServiceClient to handle user authentication, including: - user registration and login - access token caching @@ -82,12 +82,12 @@ def reset_cache(self): def _reset_token(self): self.service_client.reset_authorization() - self.CACHED_TOKEN_FILE.unlink() + self.CACHED_TOKEN_FILE.unlink(missing_ok=True) class UserDataClient(ServiceClientWrapper): """ - Singleton class for handling user data, including: + Wrapper of ServiceClient to handle user data, including: - query, or 
delete user account data - query, download, or delete uploaded data """ @@ -127,7 +127,7 @@ def delete_dataset(self, dataset_uid: str) -> [str]: return deleted_datasets - def delete_all_dataset(self) -> [str]: + def delete_all_datasets(self) -> [str]: try: deleted_datasets = self.service_client.delete_all_datasets() except RuntimeError as e: @@ -151,7 +151,7 @@ def delete_user_account(self): class InferenceClient(ServiceClientWrapper): """ - Singleton class for handling inference, including: + Wrapper of ServiceClient to handle inference, including: - fitting - prediction """ diff --git a/tabpfn_client/tests/mock_tabpfn_server.py b/tabpfn_client/tests/mock_tabpfn_server.py index 4aa15c0..786539b 100644 --- a/tabpfn_client/tests/mock_tabpfn_server.py +++ b/tabpfn_client/tests/mock_tabpfn_server.py @@ -8,7 +8,7 @@ class MockTabPFNServer(AbstractContextManager): def __init__(self): self.server_config = SERVER_CONFIG self.endpoints = self.server_config.endpoints - self.base_url = f"https://{self.server_config.host}:{self.server_config.port}" + self.base_url = f"{self.server_config.protocol}://{self.server_config.host}:{self.server_config.port}" self.router = None def __enter__(self): diff --git a/tabpfn_client/tests/unit/test_service_wrapper.py b/tabpfn_client/tests/unit/test_service_wrapper.py new file mode 100644 index 0000000..c875c23 --- /dev/null +++ b/tabpfn_client/tests/unit/test_service_wrapper.py @@ -0,0 +1,274 @@ +import unittest +import zipfile +from unittest.mock import patch +from io import BytesIO +from pathlib import Path + +from tabpfn_client.tests.mock_tabpfn_server import with_mock_server +from tabpfn_client.service_wrapper import UserAuthenticationClient, UserDataClient +from tabpfn_client.client import ServiceClient + + +class TestUserAuthClient(unittest.TestCase): + """ + These test cases are meant to validate the interface between the client and the server. + They do not guarantee if the response from the server is correct. 
+ """ + + def tearDown(self): + ServiceClient().delete_instance() + + UserAuthenticationClient.CACHED_TOKEN_FILE.unlink(missing_ok=True) + + @with_mock_server() + def test_set_token_by_valid_login(self, mock_server): + # mock valid login response + dummy_token = "dummy_token" + mock_server.router.post(mock_server.endpoints.login.path).respond( + 200, + json={"access_token": dummy_token} + ) + + # assert no exception is raised + UserAuthenticationClient(ServiceClient()).set_token_by_login("dummy_email", "dummy_password") + + # assert token is set + self.assertEqual(dummy_token, ServiceClient().access_token) + + @with_mock_server() + def test_set_token_by_invalid_login(self, mock_server): + # mock invalid login response + mock_server.router.post(mock_server.endpoints.login.path).respond(400) + + # assert exception is raised + self.assertRaises( + RuntimeError, + UserAuthenticationClient(ServiceClient()).set_token_by_login, + "dummy_email", "dummy_password" + ) + + # assert token is not set + self.assertIsNone(ServiceClient().access_token) + + @with_mock_server() + def test_try_reusing_existing_token(self, mock_server): + # create dummy token file + dummy_token = "dummy_token" + token_file = UserAuthenticationClient.CACHED_TOKEN_FILE + token_file.parent.mkdir(parents=True, exist_ok=True) + token_file.write_text(dummy_token) + + # mock authentication + mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200) + + # assert no exception is raised + UserAuthenticationClient(ServiceClient()).try_reuse_existing_token() + + # assert token is set + self.assertEqual(dummy_token, ServiceClient().access_token) + + def test_try_reusing_non_existing_token(self): + # assert no exception is raised + UserAuthenticationClient(ServiceClient()).try_reuse_existing_token() + + # assert token is not set + self.assertIsNone(ServiceClient().access_token) + + @with_mock_server() + def test_set_token_by_valid_registration(self, mock_server): + # mock valid registration response, and valid login response + dummy_token = "dummy_token" + mock_server.router.post(mock_server.endpoints.register.path).respond( + 200, + json={"message": "doesn't matter"} + ) + mock_server.router.post(mock_server.endpoints.login.path).respond( + 200, + json={"access_token": dummy_token} + ) + + # assert no exception is raised + UserAuthenticationClient(ServiceClient()).set_token_by_registration( + "dummy_email", "dummy_password", "dummy_password" + ) + + # assert token is set + self.assertEqual(dummy_token, ServiceClient().access_token) + + @with_mock_server() + def test_set_token_by_invalid_registration(self, mock_server): + # mock invalid registration response + mock_server.router.post(mock_server.endpoints.register.path).respond( + 400, + json={"detail": "doesn't matter"} + ) + + # assert exception is raised + self.assertRaises( + RuntimeError, + UserAuthenticationClient(ServiceClient()).set_token_by_registration, + "dummy_email", "dummy_password", "dummy_password" + ) + + # assert token is not set + self.assertIsNone(ServiceClient().access_token) + + @with_mock_server() + def test_reset_cache_after_token_set(self, mock_server): + # set token from a dummy file + dummy_token = "dummy_token" + token_file = UserAuthenticationClient.CACHED_TOKEN_FILE + token_file.parent.mkdir(parents=True, exist_ok=True) + token_file.write_text(dummy_token) + + # mock authentication + mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200) + 
self.assertTrue(UserAuthenticationClient(ServiceClient()).try_reuse_existing_token()) + + # assert token is set + self.assertEqual(dummy_token, ServiceClient().access_token) + + # reset cache + UserAuthenticationClient(ServiceClient()).reset_cache() + + # assert token is not set + self.assertIsNone(ServiceClient().access_token) + + def test_reset_cache_without_token_set(self): + # assert no exception is raised + UserAuthenticationClient(ServiceClient()).reset_cache() + + # assert token is not set + self.assertIsNone(ServiceClient().access_token) + + +class TestUserDataClient(unittest.TestCase): + """ + These test cases are meant to validate the interface between the client and the server. + They do not guarantee if the response from the server is correct. + """ + + @staticmethod + def _is_zip_file_empty(zip_file_path: Path): + return not zipfile.ZipFile(zip_file_path, "r").namelist() + + @with_mock_server() + def test_get_data_summary_accepts_dict(self, mock_server): + # mock get_data_summary response + mock_summary = {"content": "does not matter as long as this is returned by the server"} + mock_server.router.get(mock_server.endpoints.get_data_summary.path).respond( + 200, + json=mock_summary + ) + + self.assertEqual(mock_summary, UserDataClient().get_data_summary()) + + @with_mock_server() + def test_download_all_data_accepts_empty_zip(self, mock_server): + # mock download_all_data response (with empty zip file) + zip_buffer = BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zip_file: + pass + zip_buffer.seek(0) + + mock_server.router.get(mock_server.endpoints.download_all_data.path).respond( + 200, + stream=zip_buffer, + headers={"Content-Disposition": "attachment; filename=all_data.zip"} + ) + + # assert no exception is raised, and zip file is empty + zip_file_path = UserDataClient().download_all_data(Path(".")) + self.assertTrue(self._is_zip_file_empty(zip_file_path)) + + # delete the zip file + zip_file_path.unlink() + + @with_mock_server() + def test_download_all_data_accepts_non_empty_zip(self, mock_server): + # mock download_all_data response (with non-empty zip file) + zip_buffer = BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zip_file: + zip_file.writestr("dummy_file.txt", "dummy content") + zip_buffer.seek(0) + + mock_server.router.get(mock_server.endpoints.download_all_data.path).respond( + 200, + stream=zip_buffer, + headers={"Content-Disposition": "attachment; filename=all_data.zip"} + ) + + # assert no exception is raised, and zip file is not empty + zip_file_path = UserDataClient().download_all_data(Path(".")) + self.assertFalse(self._is_zip_file_empty(zip_file_path)) + + # delete the zip file + zip_file_path.unlink() + + @with_mock_server() + def test_delete_datasets_accepts_empty_uid_list(self, mock_server): + # mock delete_dataset response (with empty list) + mock_server.router.delete(mock_server.endpoints.delete_dataset.path).respond( + 200, + json={"deleted_dataset_uids": []} + ) + + # assert no exception is raised + self.assertEqual([], UserDataClient().delete_dataset("dummy_uid")) + + @with_mock_server() + def test_delete_datasets_accepts_uid_list(self, mock_server): + # mock delete_dataset response (with non-empty list) + mock_server.router.delete(mock_server.endpoints.delete_dataset.path).respond( + 200, + json={"deleted_dataset_uids": ["dummy_uid"]} + ) + + # assert no exception is raised + self.assertEqual(["dummy_uid"], UserDataClient().delete_dataset("dummy_uid")) + + @with_mock_server() + def test_delete_all_datasets_accepts_empty_uid_list(self, 
mock_server): + # mock delete_all_datasets response (with empty list) + mock_server.router.delete(mock_server.endpoints.delete_all_datasets.path).respond( + 200, + json={"deleted_dataset_uids": []} + ) + + # assert no exception is raised + self.assertEqual([], UserDataClient().delete_all_datasets()) + + @with_mock_server() + def test_delete_all_datasets_accepts_uid_list(self, mock_server): + # mock delete_all_datasets response (with non-empty list) + mock_server.router.delete(mock_server.endpoints.delete_all_datasets.path).respond( + 200, + json={"deleted_dataset_uids": ["dummy_uid"]} + ) + + # assert no exception is raised + self.assertEqual(["dummy_uid"], UserDataClient().delete_all_datasets()) + + @with_mock_server() + @patch("tabpfn_client.service_wrapper.PromptAgent.prompt_confirm_password_for_user_account_deletion") + def test_delete_user_account_with_valid_password(self, mock_server, mock_prompt_confirm_password): + # mock delete_user_account response + mock_server.router.delete(mock_server.endpoints.delete_user_account.path).respond(200) + + # mock password prompting + mock_prompt_confirm_password.return_value = "dummy_password" + + # assert no exception is raised + UserDataClient().delete_user_account() + + @with_mock_server() + @patch("tabpfn_client.service_wrapper.PromptAgent.prompt_confirm_password_for_user_account_deletion") + def test_delete_user_account_with_invalid_password(self, mock_server, mock_prompt_confirm_password): + # mock delete_user_account response + mock_server.router.delete(mock_server.endpoints.delete_user_account.path).respond(400) + + # mock password prompting + mock_prompt_confirm_password.return_value = "dummy_password" + + # assert exception is raised + self.assertRaises(RuntimeError, UserDataClient().delete_user_account) From ba4b4b039cbfcab47f905bf4c92c1045cde0dfb7 Mon Sep 17 00:00:00 2001 From: "Liam, SB Hoo" Date: Wed, 25 Oct 2023 23:24:57 +0200 Subject: [PATCH 6/7] Update submodule --- tabpfn_client/tabpfn_common_utils | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tabpfn_client/tabpfn_common_utils b/tabpfn_client/tabpfn_common_utils index 8af9e62..a2df122 160000 --- a/tabpfn_client/tabpfn_common_utils +++ b/tabpfn_client/tabpfn_common_utils @@ -1 +1 @@ -Subproject commit 8af9e62f08a91628285c48b98bba8d656429d8bb +Subproject commit a2df122f2894369a444eb2335776d7dd5eade5d9 From 4fa419b654eb7aa747a871eddefa9292277340b4 Mon Sep 17 00:00:00 2001 From: Liam Hoo <44376667+liam-sbhoo@users.noreply.github.com> Date: Thu, 2 Nov 2023 15:02:14 +0100 Subject: [PATCH 7/7] Update server_config.yaml --- tabpfn_client/server_config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tabpfn_client/server_config.yaml b/tabpfn_client/server_config.yaml index 982acdb..09aba38 100644 --- a/tabpfn_client/server_config.yaml +++ b/tabpfn_client/server_config.yaml @@ -76,5 +76,5 @@ endpoints: delete_user_account: path: "/delete_user_account/" methods: [ "DELETE" ] - description: "Delete user account" + description: "Delete user account, alongside all associated data"
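
The user-data management APIs introduced in this series are re-exported from tabpfn_client/__init__.py as UserDataClient. A minimal usage sketch (not part of the patches themselves; it assumes tabpfn_classifier.init() has already been run with a valid access token, and the dataset UID below is a hypothetical placeholder):

    from pathlib import Path

    from tabpfn_client import tabpfn_classifier, UserDataClient

    # Authenticate: prompts for login/registration, or reuses a cached token.
    tabpfn_classifier.init(use_server=True)

    # UserDataClient wraps the ServiceClient singleton, so it reuses the
    # access token that was set during init().
    user_data = UserDataClient()

    # Summary of all datasets uploaded by this account.
    print(user_data.get_data_summary())

    # Download every uploaded dataset as a zip archive into the current directory.
    saved_path = user_data.download_all_data(save_dir=Path("."))

    # Delete a single dataset by its UID (placeholder value), or everything at once.
    user_data.delete_dataset("dummy_uid")
    user_data.delete_all_datasets()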