diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index cf8897f..9eb3f6e 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -5,6 +5,21 @@ on: - main jobs: + check_python_linting: + name: Ruff Linting & Formatting + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: chartboost/ruff-action@v1 + with: + src: "./" + version: 0.3.3 + - uses: chartboost/ruff-action@v1 + with: + src: "./" + version: 0.3.3 + args: 'format --check' + test: name: Run unit and integration tests runs-on: ubuntu-latest @@ -13,10 +28,10 @@ jobs: python-version: ["3.10"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} architecture: x64 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..3c10b0e --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,9 @@ +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.3.3 + hooks: + # Run the linter. + - id: ruff + # Run the formatter. + - id: ruff-format \ No newline at end of file diff --git a/README.md b/README.md index ef2d8fc..ad644d7 100644 --- a/README.md +++ b/README.md @@ -43,3 +43,11 @@ tabpfn.fit(X_train, y_train) tabpfn.predict(X_test) # or you can also use tabpfn.predict_proba(X_test) ``` + +# Development + +To encourage better coding practices, `ruff` has been added to the pre-commit hooks. This will ensure that the code is formatted properly before being committed. To enable pre-commit (if you haven't), run the following command: +```sh +pre-commit install +``` +Additionally, it is recommended that developers install the ruff extension in their preferred editor. For installation instructions, refer to the [Ruff Integrations Documentation](https://docs.astral.sh/ruff/integrations/). diff --git a/quick_test.py b/quick_test.py index 128b1af..0bded35 100644 --- a/quick_test.py +++ b/quick_test.py @@ -1,14 +1,14 @@ import logging import numpy as np -logging.basicConfig(level=logging.DEBUG) - from sklearn.datasets import load_breast_cancer from sklearn.model_selection import train_test_split from tabpfn_client import tabpfn_classifier, UserDataClient from tabpfn_client.tabpfn_classifier import TabPFNClassifier +logging.basicConfig(level=logging.DEBUG) + if __name__ == "__main__": # set logging level to debug @@ -18,13 +18,17 @@ # use_server = False X, y = load_breast_cancer(return_X_y=True) - X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.33, random_state=42 + ) if not use_server: tabpfn_classifier.init(use_server=False) - tabpfn = TabPFNClassifier(device="cpu", N_ensemble_configurations=4, model="tabpfn_1_local") + tabpfn = TabPFNClassifier( + device="cpu", N_ensemble_configurations=4, model="tabpfn_1_local" + ) # check_estimator(tabpfn) - tabpfn.fit(np.repeat(X_train,100,axis=0), np.repeat(y_train,100,axis=0)) + tabpfn.fit(np.repeat(X_train, 100, axis=0), np.repeat(y_train, 100, axis=0)) print("predicting") print(tabpfn.predict(X_test)) print("predicting_proba") @@ -40,4 +44,4 @@ print("predicting_proba") print(tabpfn.predict_proba(X_test)) - print(UserDataClient().get_data_summary()) \ No newline at end of file + print(UserDataClient().get_data_summary()) diff --git a/requirements.txt b/requirements.txt index 6909bd3..bf57ffc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,8 @@ scikit-learn torch # for testing -respx \ No newline at end of file +respx + +# development tool +pre-commit +ruff == 0.3.3 \ No newline at end of file diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..d0cbe53 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,3 @@ +line-length = 88 +indent-width = 4 +target-version = "py310" \ No newline at end of file diff --git a/tabpfn_client/__init__.py b/tabpfn_client/__init__.py index 4f163df..25439a2 100644 --- a/tabpfn_client/__init__.py +++ b/tabpfn_client/__init__.py @@ -1,2 +1,4 @@ from tabpfn_client.tabpfn_classifier import init, TabPFNClassifier from tabpfn_client.service_wrapper import UserDataClient + +__all__ = ["init", "TabPFNClassifier", "UserDataClient"] diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index 105059c..0836f6a 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -20,11 +20,11 @@ def get_client_version() -> str: try: - return version('tabpfn_client') + return version("tabpfn_client") except PackageNotFoundError: # Package not found, should only happen during development. Execute 'pip install -e .' to use the actual # version number during development. Otherwise, simply return a version number that is large enough. - return '5.5.5' + return "5.5.5" @common_utils.singleton @@ -38,11 +38,13 @@ def __init__(self): self.server_config = SERVER_CONFIG self.server_endpoints = SERVER_CONFIG["endpoints"] self.base_url = f"{self.server_config.protocol}://{self.server_config.host}:{self.server_config.port}" - self.httpx_timeout_s = 30 # temporary workaround for slow computation on server side + self.httpx_timeout_s = ( + 30 # temporary workaround for slow computation on server side + ) self.httpx_client = httpx.Client( base_url=self.base_url, timeout=self.httpx_timeout_s, - headers={"client-version": get_client_version()} + headers={"client-version": get_client_version()}, ) self._access_token = None @@ -63,8 +65,7 @@ def reset_authorization(self): @property def is_initialized(self): - return self.access_token is not None \ - and self.access_token != "" + return self.access_token is not None and self.access_token != "" def upload_train_set(self, X, y) -> str: """ @@ -88,10 +89,9 @@ def upload_train_set(self, X, y) -> str: response = self.httpx_client.post( url=self.server_endpoints.upload_train_set.path, - files=common_utils.to_httpx_post_file_format([ - ("x_file", "x_train_filename", X), - ("y_file", "y_train_filename", y) - ]) + files=common_utils.to_httpx_post_file_format( + [("x_file", "x_train_filename", X), ("y_file", "y_train_filename", y)] + ), ) self._validate_response(response, "upload_train_set") @@ -99,7 +99,7 @@ def upload_train_set(self, X, y) -> str: train_set_uid = response.json()["train_set_uid"] return train_set_uid - def predict(self, train_set_uid: str, x_test, tabpfn_config: dict | None=None): + def predict(self, train_set_uid: str, x_test, tabpfn_config: dict | None = None): """ Predict the class labels for the provided data (test set). @@ -121,14 +121,16 @@ def predict(self, train_set_uid: str, x_test, tabpfn_config: dict | None=None): params = {"train_set_uid": train_set_uid} if tabpfn_config is not None: - params["tabpfn_config"] = json.dumps(tabpfn_config, default=lambda x: x.to_dict()) + params["tabpfn_config"] = json.dumps( + tabpfn_config, default=lambda x: x.to_dict() + ) response = self.httpx_client.post( url=self.server_endpoints.predict.path, params=params, - files=common_utils.to_httpx_post_file_format([ - ("x_file", "x_test_filename", x_test) - ]) + files=common_utils.to_httpx_post_file_format( + [("x_file", "x_test_filename", x_test)] + ), ) self._validate_response(response, "predict") @@ -150,19 +152,33 @@ def _validate_response(response, method_name, only_version_check=False): # Check if the server requires a newer client version. if response.status_code == 426: - logger.error(f"Fail to call {method_name}, response status: {response.status_code}") + logger.error( + f"Fail to call {method_name}, response status: {response.status_code}" + ) raise RuntimeError(load.get("detail")) # If we not only want to check the version compatibility, also raise other errors. if not only_version_check: if load is not None: raise RuntimeError(f"Fail to call {method_name} with error: {load}") - logger.error(f"Fail to call {method_name}, response status: {response.status_code}") - if len(reponse_split_up:=response.text.split("The following exception has occurred:")) > 1: - raise RuntimeError(f"Fail to call {method_name} with error: {reponse_split_up[1]}") - raise RuntimeError(f"Fail to call {method_name} with error: {response.status_code} and reason: " - f"{response.reason_phrase}") - + logger.error( + f"Fail to call {method_name}, response status: {response.status_code}" + ) + if ( + len( + reponse_split_up := response.text.split( + "The following exception has occurred:" + ) + ) + > 1 + ): + raise RuntimeError( + f"Fail to call {method_name} with error: {reponse_split_up[1]}" + ) + raise RuntimeError( + f"Fail to call {method_name} with error: {response.status_code} and reason: " + f"{response.reason_phrase}" + ) def predict_proba(self, train_set_uid: str, x_test): """ @@ -184,9 +200,9 @@ def predict_proba(self, train_set_uid: str, x_test): response = self.httpx_client.post( url=self.server_endpoints.predict_proba.path, params={"train_set_uid": train_set_uid}, - files=common_utils.to_httpx_post_file_format([ - ("x_file", "x_test_filename", x_test) - ]) + files=common_utils.to_httpx_post_file_format( + [("x_file", "x_test_filename", x_test)] + ), ) self._validate_response(response, "predict_proba") @@ -244,8 +260,7 @@ def validate_email(self, email: str) -> tuple[bool, str]: The message returned from the server. """ response = self.httpx_client.post( - self.server_endpoints.validate_email.path, - params={"email": email} + self.server_endpoints.validate_email.path, params={"email": email} ) self._validate_response(response, "validate_email", only_version_check=True) @@ -259,12 +274,12 @@ def validate_email(self, email: str) -> tuple[bool, str]: return is_valid, message def register( - self, - email: str, - password: str, - password_confirm: str, - validation_link: str, - additional_info: dict + self, + email: str, + password: str, + password_confirm: str, + validation_link: str, + additional_info: dict, ) -> tuple[bool, str]: """ Register a new user with the provided credentials. @@ -288,12 +303,12 @@ def register( response = self.httpx_client.post( self.server_endpoints.register.path, params={ - "email": email, - "password": password, + "email": email, + "password": password, "password_confirm": password_confirm, - "validation_link": validation_link, - **additional_info - } + "validation_link": validation_link, + **additional_info, + }, ) self._validate_response(response, "register", only_version_check=True) @@ -326,7 +341,7 @@ def login(self, email: str, password: str) -> tuple[str, str]: access_token = None response = self.httpx_client.post( self.server_endpoints.login.path, - data=common_utils.to_oauth_request_form(email, password) + data=common_utils.to_oauth_request_form(email, password), ) self._validate_response(response, "login", only_version_check=True) @@ -351,7 +366,9 @@ def get_password_policy(self) -> {}: response = self.httpx_client.get( self.server_endpoints.password_policy.path, ) - self._validate_response(response, "get_password_policy", only_version_check=True) + self._validate_response( + response, "get_password_policy", only_version_check=True + ) return response.json()["requirements"] @@ -361,7 +378,7 @@ def send_reset_password_email(self, email: str) -> tuple[bool, str]: """ response = self.httpx_client.post( self.server_endpoints.send_reset_password_email.path, - params={"email": email} + params={"email": email}, ) if response.status_code == 200: sent = True @@ -375,9 +392,13 @@ def retrieve_greeting_messages(self) -> list[str]: """ Retrieve greeting messages that are new for the user. """ - response = self.httpx_client.get(self.server_endpoints.retrieve_greeting_messages.path) + response = self.httpx_client.get( + self.server_endpoints.retrieve_greeting_messages.path + ) - self._validate_response(response, "retrieve_greeting_messages", only_version_check=True) + self._validate_response( + response, "retrieve_greeting_messages", only_version_check=True + ) if response.status_code != 200: return [] @@ -414,7 +435,9 @@ def download_all_data(self, save_dir: Path) -> Path | None: save_path = None full_url = self.base_url + self.server_endpoints.download_all_data.path - with httpx.stream("GET", full_url, headers={"Authorization": f"Bearer {self.access_token}"}) as response: + with httpx.stream( + "GET", full_url, headers={"Authorization": f"Bearer {self.access_token}"} + ) as response: self._validate_response(response, "download_all_data") filename = response.headers["Content-Disposition"].split("filename=")[1] @@ -443,7 +466,7 @@ def delete_dataset(self, dataset_uid: str) -> [str]: """ response = self.httpx_client.delete( self.server_endpoints.delete_dataset.path, - params={"dataset_uid": dataset_uid} + params={"dataset_uid": dataset_uid}, ) self._validate_response(response, "delete_dataset") @@ -470,7 +493,7 @@ def delete_all_datasets(self) -> [str]: def delete_user_account(self, confirm_pass: str) -> None: response = self.httpx_client.delete( self.server_endpoints.delete_user_account.path, - params={"confirm_password": confirm_pass} + params={"confirm_password": confirm_pass}, ) self._validate_response(response, "delete_user_account") diff --git a/tabpfn_client/prompt_agent.py b/tabpfn_client/prompt_agent.py index 527b3bb..714332d 100644 --- a/tabpfn_client/prompt_agent.py +++ b/tabpfn_client/prompt_agent.py @@ -1,6 +1,10 @@ import textwrap import getpass from password_strength import PasswordPolicy +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from tabpfn_client.tabpfn_classifier import UserAuthenticationClient class PromptAgent: @@ -19,33 +23,37 @@ def password_req_to_policy(password_req: list[str]): """ requirements = {} for req in password_req: - word_part, number_part = req.split('(') + word_part, number_part = req.split("(") number = int(number_part[:-1]) requirements[word_part.lower()] = number return PasswordPolicy.from_names(**requirements) @classmethod def prompt_welcome(cls): - prompt = "\n".join([ - "Welcome to TabPFN!", - "", - "TabPFN is still under active development, and we are working hard to make it better.", - "Please bear with us if you encounter any issues.", - "" - ]) + prompt = "\n".join( + [ + "Welcome to TabPFN!", + "", + "TabPFN is still under active development, and we are working hard to make it better.", + "Please bear with us if you encounter any issues.", + "", + ] + ) print(cls.indent(prompt)) @classmethod def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"): # Choose between registration and login - prompt = "\n".join([ - "Please choose one of the following options:", - "(1) Create a TabPFN account", - "(2) Login to your TabPFN account", - "", - "Please enter your choice: ", - ]) + prompt = "\n".join( + [ + "Please choose one of the following options:", + "(1) Create a TabPFN account", + "(2) Login to your TabPFN account", + "", + "Please enter your choice: ", + ] + ) choice = cls._choice_with_retries(prompt, ["1", "2"]) # Registration @@ -63,13 +71,15 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"): password_req = user_auth_handler.get_password_policy() password_policy = cls.password_req_to_policy(password_req) - password_req_prompt = "\n".join([ - "", - "Password requirements (minimum):", - "\n".join([f". {req}" for req in password_req]), - "", - "Please enter your password: ", - ]) + password_req_prompt = "\n".join( + [ + "", + "Password requirements (minimum):", + "\n".join([f". {req}" for req in password_req]), + "", + "Please enter your password: ", + ] + ) while True: password = getpass.getpass(cls.indent(password_req_prompt)) password_req_prompt = "Please enter your password: " @@ -77,18 +87,30 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"): print(cls.indent("Password requirements not satisfied.\n")) continue - password_confirm = getpass.getpass(cls.indent("Please confirm your password: ")) + password_confirm = getpass.getpass( + cls.indent("Please confirm your password: ") + ) if password == password_confirm: break else: - print(cls.indent("Entered password and confirmation password do not match, please try again.\n")) + print( + cls.indent( + "Entered password and confirmation password do not match, please try again.\n" + ) + ) additional_info = cls.prompt_add_user_information() is_created, message = user_auth_handler.set_token_by_registration( - email, password, password_confirm, validation_link, additional_info) + email, password, password_confirm, validation_link, additional_info + ) if not is_created: raise RuntimeError("User registration failed: " + str(message) + "\n") - print(cls.indent("Account created successfully! To start using TabPFN please click on the link in the verification email we sent you.") + "\n") + print( + cls.indent( + "Account created successfully! To start using TabPFN please click on the link in the verification email we sent you." + ) + + "\n" + ) # Login elif choice == "2": @@ -97,49 +119,70 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"): email = input(cls.indent("Please enter your email: ")) password = getpass.getpass(cls.indent("Please enter your password: ")) - successful, message = user_auth_handler.set_token_by_login(email, password) + successful, message = user_auth_handler.set_token_by_login( + email, password + ) if successful: break print(cls.indent("Login failed: " + message) + "\n") - prompt = "\n".join([ - "Please choose one of the following options:", - "(1) Retry login", - "(2) Reset your password", - "", - "Please enter your choice: ", - ]) + prompt = "\n".join( + [ + "Please choose one of the following options:", + "(1) Retry login", + "(2) Reset your password", + "", + "Please enter your choice: ", + ] + ) choice = cls._choice_with_retries(prompt, ["1", "2"]) if choice == "1": continue elif choice == "2": sent = False - print(cls.indent("We will send you an email with a link " - "that allows you to reset your password. \n")) + print( + cls.indent( + "We will send you an email with a link " + "that allows you to reset your password. \n" + ) + ) while not sent: email = input(cls.indent("Please enter your email address: ")) - sent, message = user_auth_handler.send_reset_password_email(email) + sent, message = user_auth_handler.send_reset_password_email( + email + ) print("\n" + cls.indent(message)) - print(cls.indent("Once you have reset your password, you will be able to login here: ")) + print( + cls.indent( + "Once you have reset your password, you will be able to login here: " + ) + ) print(cls.indent("Login successful!") + "\n") @classmethod def prompt_terms_and_cond(cls) -> bool: - t_and_c = "\n".join([ - "Please refer to our terms and conditions at: https://www.priorlabs.ai/terms-eu-en " - "By using TabPFN, you agree to the following terms and conditions:", - "Do you agree to the above terms and conditions? (y/n): ", - ]) + t_and_c = "\n".join( + [ + "Please refer to our terms and conditions at: https://www.priorlabs.ai/terms-eu-en " + "By using TabPFN, you agree to the following terms and conditions:", + "Do you agree to the above terms and conditions? (y/n): ", + ] + ) choice = cls._choice_with_retries(t_and_c, ["y", "n"]) return choice == "y" @classmethod def prompt_add_user_information(cls) -> dict: - print(cls.indent("To help us tailor our support and services to your needs, we have a few optional questions. " - "Feel free to skip any question by leaving it blank.") + "\n") + print( + cls.indent( + "To help us tailor our support and services to your needs, we have a few optional questions. " + "Feel free to skip any question by leaving it blank." + ) + + "\n" + ) company = input(cls.indent("Where do you work? ")) role = input(cls.indent("What is your role? ")) use_case = input(cls.indent("What do you want to use TabPFN for? ")) @@ -153,14 +196,14 @@ def prompt_add_user_information(cls) -> dict: "company": company, "role": role, "use_case": use_case, - "contact_via_email": contact_via_email + "contact_via_email": contact_via_email, } @classmethod def prompt_reusing_existing_token(cls): - prompt = "\n".join([ - "Found existing access token, reusing it for authentication." - ]) + prompt = "\n".join( + ["Found existing access token, reusing it for authentication."] + ) print(cls.indent(prompt)) @@ -169,11 +212,12 @@ def prompt_retrieved_greeting_messages(cls, greeting_messages: list[str]): for message in greeting_messages: print(cls.indent(message)) - @classmethod def prompt_confirm_password_for_user_account_deletion(cls) -> str: print(cls.indent("You are about to delete your account.")) - confirm_pass = getpass.getpass(cls.indent("Please confirm by entering your password: ")) + confirm_pass = getpass.getpass( + cls.indent("Please confirm by entering your password: ") + ) return confirm_pass @@ -193,8 +237,13 @@ def _choice_with_retries(cls, prompt: str, choices: list) -> str: # retry until valid choice is made while True: if choice.lower() not in choices: - choices_str = ", ".join(f"'{item}'" for item in choices[:-1]) + f" or '{choices[-1]}'" - choice = input(cls.indent(f"Invalid choice, please enter {choices_str}: ")) + choices_str = ( + ", ".join(f"'{item}'" for item in choices[:-1]) + + f" or '{choices[-1]}'" + ) + choice = input( + cls.indent(f"Invalid choice, please enter {choices_str}: ") + ) else: break diff --git a/tabpfn_client/service_wrapper.py b/tabpfn_client/service_wrapper.py index e7c7ec5..902d972 100644 --- a/tabpfn_client/service_wrapper.py +++ b/tabpfn_client/service_wrapper.py @@ -20,6 +20,7 @@ class UserAuthenticationClient(ServiceClientWrapper): - access token caching """ + CACHED_TOKEN_FILE = CACHE_DIR / "config" def is_accessible_connection(self) -> bool: @@ -35,15 +36,16 @@ def validate_email(self, email: str) -> tuple[bool, str]: return is_valid, message def set_token_by_registration( - self, - email: str, - password: str, - password_confirm: str, - validation_link: str, - additional_info: dict + self, + email: str, + password: str, + password_confirm: str, + validation_link: str, + additional_info: dict, ) -> tuple[bool, str]: - - is_created, message = self.service_client.register(email, password, password_confirm, validation_link, additional_info) + is_created, message = self.service_client.register( + email, password, password_confirm, validation_link, additional_info + ) return is_created, message def set_token_by_login(self, email: str, password: str) -> tuple[bool, str]: @@ -92,13 +94,15 @@ def send_reset_password_email(self, email: str) -> tuple[bool, str]: sent, message = self.service_client.send_reset_password_email(email) return sent, message + class UserDataClient(ServiceClientWrapper): """ Wrapper of ServiceClient to handle user data, including: - query, or delete user account data - query, download, or delete uploaded data """ - def __init__(self, service_client = ServiceClient()): + + def __init__(self, service_client=ServiceClient()): super().__init__(service_client) def get_data_summary(self) -> {}: @@ -163,19 +167,19 @@ class InferenceClient(ServiceClientWrapper): - prediction """ - def __init__(self, service_client = ServiceClient()): + def __init__(self, service_client=ServiceClient()): super().__init__(service_client) self.last_train_set_uid = None def fit(self, X, y) -> None: if not self.service_client.is_initialized: - raise RuntimeError("Either email is not verified or Service client is not initialized. Please Verify your email and try again!") + raise RuntimeError( + "Either email is not verified or Service client is not initialized. Please Verify your email and try again!" + ) self.last_train_set_uid = self.service_client.upload_train_set(X, y) def predict(self, X, config=None): return self.service_client.predict( - train_set_uid=self.last_train_set_uid, - x_test=X, - tabpfn_config=config + train_set_uid=self.last_train_set_uid, x_test=X, tabpfn_config=config ) diff --git a/tabpfn_client/tabpfn_classifier.py b/tabpfn_client/tabpfn_classifier.py index edce3db..e6252b6 100644 --- a/tabpfn_client/tabpfn_classifier.py +++ b/tabpfn_client/tabpfn_classifier.py @@ -39,7 +39,9 @@ def init(use_server=True): # check connection to server if not user_auth_handler.is_accessible_connection(): - raise RuntimeError("TabPFN is inaccessible at the moment, please try again later.") + raise RuntimeError( + "TabPFN is inaccessible at the moment, please try again later." + ) is_valid_token_set = user_auth_handler.try_reuse_existing_token() @@ -47,13 +49,17 @@ def init(use_server=True): PromptAgent.prompt_reusing_existing_token() else: if not PromptAgent.prompt_terms_and_cond(): - raise RuntimeError("You must agree to the terms and conditions to use TabPFN") + raise RuntimeError( + "You must agree to the terms and conditions to use TabPFN" + ) # prompt for login / register PromptAgent.prompt_and_set_token(user_auth_handler) # Print new greeting messages. If there are no new messages, nothing will be printed. - PromptAgent.prompt_retrieved_greeting_messages(user_auth_handler.retrieve_greeting_messages()) + PromptAgent.prompt_retrieved_greeting_messages( + user_auth_handler.retrieve_greeting_messages() + ) g_tabpfn_config.use_server = True g_tabpfn_config.user_auth_handler = user_auth_handler @@ -171,7 +177,10 @@ def can_be_cached(self): return not self.subsample_features > 0 def to_dict(self): - return {k: str(v) if not isinstance(v, (str, int, float, list, dict)) else v for k, v in asdict(self).items()} + return { + k: str(v) if not isinstance(v, (str, int, float, list, dict)) else v + for k, v in asdict(self).items() + } ClassificationOptimizationMetricType = Literal[ @@ -228,17 +237,23 @@ def __init__( def fit(self, X, y): # assert init() is called if not g_tabpfn_config.is_initialized: - raise RuntimeError("tabpfn_client.init() must be called before using TabPFNClassifier") + raise RuntimeError( + "tabpfn_client.init() must be called before using TabPFNClassifier" + ) if g_tabpfn_config.use_server: try: - assert self.model == "latest_tabpfn_hosted", "Only 'latest_tabpfn_hosted' model is supported at the moment for tabpfn_classifier.init(use_server=True)" + assert ( + self.model == "latest_tabpfn_hosted" + ), "Only 'latest_tabpfn_hosted' model is supported at the moment for tabpfn_classifier.init(use_server=True)" except AssertionError as e: print(e) g_tabpfn_config.inference_handler.fit(X, y) self.fitted_ = True else: - raise NotImplementedError("Only server mode is supported at the moment for tabpfn_classifier.init(use_server=False)") + raise NotImplementedError( + "Only server mode is supported at the moment for tabpfn_classifier.init(use_server=False)" + ) return self def predict(self, X): @@ -248,5 +263,3 @@ def predict(self, X): def predict_proba(self, X): check_is_fitted(self) return g_tabpfn_config.inference_handler.predict(X, config=self.get_params()) - - diff --git a/tabpfn_client/tests/integration/test_tabpfn_classifier.py b/tabpfn_client/tests/integration/test_tabpfn_classifier.py index 5255111..616fd86 100644 --- a/tabpfn_client/tests/integration/test_tabpfn_classifier.py +++ b/tabpfn_client/tests/integration/test_tabpfn_classifier.py @@ -13,7 +13,9 @@ class TestTabPFNClassifier(unittest.TestCase): def setUp(self): X, y = load_breast_cancer(return_X_y=True) - self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(X, y, test_size=0.33, random_state=42) + self.X_train, self.X_test, self.y_train, self.y_test = train_test_split( + X, y, test_size=0.33, random_state=42 + ) def tearDown(self): tabpfn_classifier.reset() @@ -29,21 +31,27 @@ def test_use_remote_tabpfn_classifier(self, mock_server): # mock connection and authentication mock_server.router.get(mock_server.endpoints.root.path).respond(200) mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200) - mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond( - 200, json={"messages": []}) + mock_server.router.get( + mock_server.endpoints.retrieve_greeting_messages.path + ).respond(200, json={"messages": []}) tabpfn_classifier.init(use_server=True) tabpfn = TabPFNClassifier() # mock fitting mock_server.router.post(mock_server.endpoints.upload_train_set.path).respond( - 200, json={"train_set_uid": 5}) + 200, json={"train_set_uid": 5} + ) tabpfn.fit(self.X_train, self.y_train) # mock prediction mock_server.router.post(mock_server.endpoints.predict.path).respond( 200, - json={"y_pred_proba": np.random.rand(len(self.X_test), len(np.unique(self.y_train))).tolist()} + json={ + "y_pred_proba": np.random.rand( + len(self.X_test), len(np.unique(self.y_train)) + ).tolist() + }, ) pred = tabpfn.predict(self.X_test) self.assertEqual(pred.shape[0], self.X_test.shape[0]) diff --git a/tabpfn_client/tests/mock_tabpfn_server.py b/tabpfn_client/tests/mock_tabpfn_server.py index 786539b..f9ac292 100644 --- a/tabpfn_client/tests/mock_tabpfn_server.py +++ b/tabpfn_client/tests/mock_tabpfn_server.py @@ -25,5 +25,7 @@ def decorator(func): def wrapper(test_class, *args, **kwargs): with MockTabPFNServer() as mock_server: return func(test_class, mock_server, *args, **kwargs) + return wrapper - return decorator \ No newline at end of file + + return decorator diff --git a/tabpfn_client/tests/unit/test_client.py b/tabpfn_client/tests/unit/test_client.py index 6b82dd0..b04cbcd 100644 --- a/tabpfn_client/tests/unit/test_client.py +++ b/tabpfn_client/tests/unit/test_client.py @@ -13,8 +13,9 @@ class TestServiceClient(unittest.TestCase): def setUp(self): # setup data X, y = load_breast_cancer(return_X_y=True) - self.X_train, self.X_test, self.y_train, self.y_test = \ - train_test_split(X, y, test_size=0.33, random_state=42) + self.X_train, self.X_test, self.y_train, self.y_test = train_test_split( + X, y, test_size=0.33, random_state=42 + ) self.client = ServiceClient() @@ -29,63 +30,110 @@ def test_try_connection_with_invalid_server(self, mock_server): self.assertFalse(self.client.try_connection()) @with_mock_server() - def test_try_connection_with_outdated_client_raises_runtime_error(self, mock_server): + def test_try_connection_with_outdated_client_raises_runtime_error( + self, mock_server + ): mock_server.router.get(mock_server.endpoints.root.path).respond( - 426, json={"detail": "Client version too old. ..."}) + 426, json={"detail": "Client version too old. ..."} + ) with self.assertRaises(RuntimeError) as cm: self.client.try_connection() self.assertTrue(str(cm.exception).startswith("Client version too old.")) @with_mock_server() def test_validate_email(self, mock_server): - mock_server.router.post(mock_server.endpoints.validate_email.path).respond(200, json={"message": "dummy_message"}) + mock_server.router.post(mock_server.endpoints.validate_email.path).respond( + 200, json={"message": "dummy_message"} + ) self.assertTrue(self.client.validate_email("dummy_email")[0]) @with_mock_server() def test_validate_email_invalid(self, mock_server): - mock_server.router.post(mock_server.endpoints.validate_email.path).respond(401, json={"detail": "dummy_message"}) + mock_server.router.post(mock_server.endpoints.validate_email.path).respond( + 401, json={"detail": "dummy_message"} + ) self.assertFalse(self.client.validate_email("dummy_email")[0]) self.assertEqual("dummy_message", self.client.validate_email("dummy_email")[1]) @with_mock_server() def test_register_user(self, mock_server): - mock_server.router.post(mock_server.endpoints.register.path).respond(200, json={"message": "dummy_message"}) - self.assertTrue(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation", { - "company": "dummy_company", - "use_case": "dummy_usecase", - "role": "dummy_role", - "contact_via_email": False - })[0]) + mock_server.router.post(mock_server.endpoints.register.path).respond( + 200, json={"message": "dummy_message"} + ) + self.assertTrue( + self.client.register( + "dummy_email", + "dummy_password", + "dummy_password", + "dummy_validation", + { + "company": "dummy_company", + "use_case": "dummy_usecase", + "role": "dummy_role", + "contact_via_email": False, + }, + )[0] + ) @with_mock_server() def test_register_user_with_invalid_email(self, mock_server): - mock_server.router.post(mock_server.endpoints.register.path).respond(401, json={"detail": "dummy_message"}) - self.assertFalse(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation", { - "company": "dummy_company", - "use_case": "dummy_usecase", - "role": "dummy_role", - "contact_via_email": False - })[0]) + mock_server.router.post(mock_server.endpoints.register.path).respond( + 401, json={"detail": "dummy_message"} + ) + self.assertFalse( + self.client.register( + "dummy_email", + "dummy_password", + "dummy_password", + "dummy_validation", + { + "company": "dummy_company", + "use_case": "dummy_usecase", + "role": "dummy_role", + "contact_via_email": False, + }, + )[0] + ) @with_mock_server() def test_register_user_with_invalid_validation_link(self, mock_server): - mock_server.router.post(mock_server.endpoints.register.path).respond(401, json={"detail": "dummy_message"}) - self.assertFalse(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation", { - "company": "dummy_company", - "use_case": "dummy_usecase", - "role": "dummy_role", - "contact_via_email": False - })[0]) + mock_server.router.post(mock_server.endpoints.register.path).respond( + 401, json={"detail": "dummy_message"} + ) + self.assertFalse( + self.client.register( + "dummy_email", + "dummy_password", + "dummy_password", + "dummy_validation", + { + "company": "dummy_company", + "use_case": "dummy_usecase", + "role": "dummy_role", + "contact_via_email": False, + }, + )[0] + ) @with_mock_server() def test_register_user_with_limit_reached(self, mock_server): - mock_server.router.post(mock_server.endpoints.register.path).respond(401, json={"detail": "dummy_message"}) - self.assertFalse(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation", { - "company": "dummy_company", - "use_case": "dummy_usecase", - "role": "dummy_role", - "contact_via_email": False - })[0]) + mock_server.router.post(mock_server.endpoints.register.path).respond( + 401, json={"detail": "dummy_message"} + ) + self.assertFalse( + self.client.register( + "dummy_email", + "dummy_password", + "dummy_password", + "dummy_validation", + { + "company": "dummy_company", + "use_case": "dummy_usecase", + "role": "dummy_role", + "contact_via_email": False, + }, + )[0] + ) @with_mock_server() def test_invalid_auth_token(self, mock_server): @@ -99,31 +147,39 @@ def test_valid_auth_token(self, mock_server): @with_mock_server() def test_send_reset_password_email(self, mock_server): - mock_server.router.post(mock_server.endpoints.send_reset_password_email.path).respond( - 200, json={"message": "Password reset email sent!"}) - self.assertEqual(self.client.send_reset_password_email("test"), (True, "Password reset email sent!")) + mock_server.router.post( + mock_server.endpoints.send_reset_password_email.path + ).respond(200, json={"message": "Password reset email sent!"}) + self.assertEqual( + self.client.send_reset_password_email("test"), + (True, "Password reset email sent!"), + ) @with_mock_server() def test_retrieve_greeting_messages(self, mock_server): - mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond( - 200, json={"messages": ["message_1", "message_2"]}) - self.assertEqual(self.client.retrieve_greeting_messages(), ["message_1", "message_2"]) + mock_server.router.get( + mock_server.endpoints.retrieve_greeting_messages.path + ).respond(200, json={"messages": ["message_1", "message_2"]}) + self.assertEqual( + self.client.retrieve_greeting_messages(), ["message_1", "message_2"] + ) @with_mock_server() def test_predict_with_valid_train_set_and_test_set(self, mock_server): dummy_json = {"train_set_uid": 5} mock_server.router.post(mock_server.endpoints.upload_train_set.path).respond( - 200, json=dummy_json) + 200, json=dummy_json + ) self.client.upload_train_set(self.X_train, self.y_train) dummy_result = {"y_pred_proba": [1, 2, 3]} mock_server.router.post(mock_server.endpoints.predict.path).respond( - 200, json=dummy_result) + 200, json=dummy_result + ) pred = self.client.predict( - train_set_uid=dummy_json["train_set_uid"], - x_test=self.X_test + train_set_uid=dummy_json["train_set_uid"], x_test=self.X_test ) self.assertTrue(np.array_equal(pred, dummy_result["y_pred_proba"])) @@ -163,4 +219,3 @@ def test_validate_response_only_version_check(self): response.json.return_value = {"detail": "Some other error"} r = self.client._validate_response(response, "test", only_version_check=True) self.assertIsNone(r) - diff --git a/tabpfn_client/tests/unit/test_prompt_agent.py b/tabpfn_client/tests/unit/test_prompt_agent.py index 7c17370..93c4e40 100644 --- a/tabpfn_client/tests/unit/test_prompt_agent.py +++ b/tabpfn_client/tests/unit/test_prompt_agent.py @@ -12,40 +12,57 @@ def test_password_req_to_policy(self): self.assertEqual(password_req, requirements) @with_mock_server() - @patch('getpass.getpass', side_effect=['Password123!', 'Password123!']) - @patch('builtins.input', side_effect=['1', 'user@example.com', 'test', 'test', 'test', 'y']) - def test_prompt_and_set_token_registration(self, mock_input, mock_getpass, mock_server): + @patch("getpass.getpass", side_effect=["Password123!", "Password123!"]) + @patch( + "builtins.input", + side_effect=["1", "user@example.com", "test", "test", "test", "y"], + ) + def test_prompt_and_set_token_registration( + self, mock_input, mock_getpass, mock_server + ): mock_auth_client = MagicMock() - mock_auth_client.get_password_policy.return_value = ['Length(8)', 'Uppercase(1)', 'Numbers(1)', 'Special(1)'] - mock_auth_client.set_token_by_registration.return_value = (True, 'Registration successful') - mock_auth_client.validate_email.return_value = (True, '') + mock_auth_client.get_password_policy.return_value = [ + "Length(8)", + "Uppercase(1)", + "Numbers(1)", + "Special(1)", + ] + mock_auth_client.set_token_by_registration.return_value = ( + True, + "Registration successful", + ) + mock_auth_client.validate_email.return_value = (True, "") PromptAgent.prompt_and_set_token(user_auth_handler=mock_auth_client) mock_auth_client.set_token_by_registration.assert_called_once() - @patch('getpass.getpass', side_effect = ['password123']) - @patch('builtins.input', side_effect=['2', 'user@example.com']) + @patch("getpass.getpass", side_effect=["password123"]) + @patch("builtins.input", side_effect=["2", "user@example.com"]) def test_prompt_and_set_token_login(self, mock_input, mock_getpass): mock_auth_client = MagicMock() - mock_auth_client.set_token_by_login.return_value = (True, 'Login successful') + mock_auth_client.set_token_by_login.return_value = (True, "Login successful") PromptAgent.prompt_and_set_token(user_auth_handler=mock_auth_client) mock_auth_client.set_token_by_login.assert_called_once() - @patch('builtins.input', return_value='y') + @patch("builtins.input", return_value="y") def test_prompt_terms_and_cond_returns_true(self, mock_input): result = PromptAgent.prompt_terms_and_cond() self.assertTrue(result) - @patch('builtins.input', return_value='n') + @patch("builtins.input", return_value="n") def test_prompt_terms_and_cond_returns_false(self, mock_input): result = PromptAgent.prompt_terms_and_cond() self.assertFalse(result) - @patch('builtins.input', return_value='1') + @patch("builtins.input", return_value="1") def test_choice_with_retries_valid_first_try(self, mock_input): - result = PromptAgent._choice_with_retries("Please enter your choice: ", ["1", "2"]) - self.assertEqual(result, '1') + result = PromptAgent._choice_with_retries( + "Please enter your choice: ", ["1", "2"] + ) + self.assertEqual(result, "1") - @patch('builtins.input', side_effect=['3', '3', '1']) + @patch("builtins.input", side_effect=["3", "3", "1"]) def test_choice_with_retries_valid_third_try(self, mock_input): - result = PromptAgent._choice_with_retries("Please enter your choice: ", ["1", "2"]) - self.assertEqual(result, '1') + result = PromptAgent._choice_with_retries( + "Please enter your choice: ", ["1", "2"] + ) + self.assertEqual(result, "1") diff --git a/tabpfn_client/tests/unit/test_service_wrapper.py b/tabpfn_client/tests/unit/test_service_wrapper.py index ee44800..7ddd584 100644 --- a/tabpfn_client/tests/unit/test_service_wrapper.py +++ b/tabpfn_client/tests/unit/test_service_wrapper.py @@ -25,12 +25,14 @@ def test_set_token_by_valid_login(self, mock_server): # mock valid login response dummy_token = "dummy_token" mock_server.router.post(mock_server.endpoints.login.path).respond( - 200, - json={"access_token": dummy_token} + 200, json={"access_token": dummy_token} ) - self.assertTrue(UserAuthenticationClient(ServiceClient()).set_token_by_login( - "dummy_email", "dummy_password")[0]) + self.assertTrue( + UserAuthenticationClient(ServiceClient()).set_token_by_login( + "dummy_email", "dummy_password" + )[0] + ) # assert token is set self.assertEqual(dummy_token, ServiceClient().access_token) @@ -38,11 +40,14 @@ def test_set_token_by_valid_login(self, mock_server): @with_mock_server() def test_set_token_by_invalid_login(self, mock_server): # mock invalid login response - mock_server.router.post(mock_server.endpoints.login.path).respond(401, json={ - "detail": "Incorrect email or password"}) + mock_server.router.post(mock_server.endpoints.login.path).respond( + 401, json={"detail": "Incorrect email or password"} + ) self.assertEqual( (False, "Incorrect email or password"), - UserAuthenticationClient(ServiceClient()).set_token_by_login("dummy_email", "dummy_password") + UserAuthenticationClient(ServiceClient()).set_token_by_login( + "dummy_email", "dummy_password" + ), ) # assert token is not set @@ -75,18 +80,23 @@ def test_try_reusing_non_existing_token(self): @with_mock_server() def test_set_token_by_invalid_registration(self, mock_server): # mock invalid registration response - mock_server.router.post(mock_server.endpoints.register.path).respond(401, json={ - "detail": "Password mismatch"}) + mock_server.router.post(mock_server.endpoints.register.path).respond( + 401, json={"detail": "Password mismatch"} + ) self.assertEqual( (False, "Password mismatch"), UserAuthenticationClient(ServiceClient()).set_token_by_registration( - "dummy_email", "dummy_password", "dummy_password", - "dummy_validation", { - "company": "dummy_company", - "use_case": "dummy_usecase", - "role": "dummy_role", - "contact_via_email": False - }) + "dummy_email", + "dummy_password", + "dummy_password", + "dummy_validation", + { + "company": "dummy_company", + "use_case": "dummy_usecase", + "role": "dummy_role", + "contact_via_email": False, + }, + ), ) # assert token is not set @@ -102,7 +112,9 @@ def test_reset_cache_after_token_set(self, mock_server): # mock authentication mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200) - self.assertTrue(UserAuthenticationClient(ServiceClient()).try_reuse_existing_token()) + self.assertTrue( + UserAuthenticationClient(ServiceClient()).try_reuse_existing_token() + ) # assert token is set self.assertEqual(dummy_token, ServiceClient().access_token) @@ -134,10 +146,11 @@ def _is_zip_file_empty(zip_file_path: Path): @with_mock_server() def test_get_data_summary_accepts_dict(self, mock_server): # mock get_data_summary response - mock_summary = {"content": "does not matter as long as this is returned by the server"} + mock_summary = { + "content": "does not matter as long as this is returned by the server" + } mock_server.router.get(mock_server.endpoints.get_data_summary.path).respond( - 200, - json=mock_summary + 200, json=mock_summary ) self.assertEqual(mock_summary, UserDataClient().get_data_summary()) @@ -146,14 +159,14 @@ def test_get_data_summary_accepts_dict(self, mock_server): def test_download_all_data_accepts_empty_zip(self, mock_server): # mock download_all_data response (with empty zip file) zip_buffer = BytesIO() - with zipfile.ZipFile(zip_buffer, "w") as zip_file: + with zipfile.ZipFile(zip_buffer, "w"): pass zip_buffer.seek(0) mock_server.router.get(mock_server.endpoints.download_all_data.path).respond( 200, stream=zip_buffer, - headers={"Content-Disposition": "attachment; filename=all_data.zip"} + headers={"Content-Disposition": "attachment; filename=all_data.zip"}, ) # assert no exception is raised, and zip file is empty @@ -174,7 +187,7 @@ def test_download_all_data_accepts_non_empty_zip(self, mock_server): mock_server.router.get(mock_server.endpoints.download_all_data.path).respond( 200, stream=zip_buffer, - headers={"Content-Disposition": "attachment; filename=all_data.zip"} + headers={"Content-Disposition": "attachment; filename=all_data.zip"}, ) # assert no exception is raised, and zip file is not empty @@ -188,8 +201,7 @@ def test_download_all_data_accepts_non_empty_zip(self, mock_server): def test_delete_datasets_accepts_empty_uid_list(self, mock_server): # mock delete_dataset response (with empty list) mock_server.router.delete(mock_server.endpoints.delete_dataset.path).respond( - 200, - json={"deleted_dataset_uids": []} + 200, json={"deleted_dataset_uids": []} ) # assert no exception is raised @@ -199,8 +211,7 @@ def test_delete_datasets_accepts_empty_uid_list(self, mock_server): def test_delete_datasets_accepts_uid_list(self, mock_server): # mock delete_dataset response (with non-empty list) mock_server.router.delete(mock_server.endpoints.delete_dataset.path).respond( - 200, - json={"deleted_dataset_uids": ["dummy_uid"]} + 200, json={"deleted_dataset_uids": ["dummy_uid"]} ) # assert no exception is raised @@ -209,10 +220,9 @@ def test_delete_datasets_accepts_uid_list(self, mock_server): @with_mock_server() def test_delete_all_datasets_accepts_empty_uid_list(self, mock_server): # mock delete_all_datasets response (with empty list) - mock_server.router.delete(mock_server.endpoints.delete_all_datasets.path).respond( - 200, - json={"deleted_dataset_uids": []} - ) + mock_server.router.delete( + mock_server.endpoints.delete_all_datasets.path + ).respond(200, json={"deleted_dataset_uids": []}) # assert no exception is raised self.assertEqual([], UserDataClient().delete_all_datasets()) @@ -220,19 +230,24 @@ def test_delete_all_datasets_accepts_empty_uid_list(self, mock_server): @with_mock_server() def test_delete_all_datasets_accepts_uid_list(self, mock_server): # mock delete_all_datasets response (with non-empty list) - mock_server.router.delete(mock_server.endpoints.delete_all_datasets.path).respond( - 200, - json={"deleted_dataset_uids": ["dummy_uid"]} - ) + mock_server.router.delete( + mock_server.endpoints.delete_all_datasets.path + ).respond(200, json={"deleted_dataset_uids": ["dummy_uid"]}) # assert no exception is raised self.assertEqual(["dummy_uid"], UserDataClient().delete_all_datasets()) @with_mock_server() - @patch("tabpfn_client.service_wrapper.PromptAgent.prompt_confirm_password_for_user_account_deletion") - def test_delete_user_account_with_valid_password(self, mock_server, mock_prompt_confirm_password): + @patch( + "tabpfn_client.service_wrapper.PromptAgent.prompt_confirm_password_for_user_account_deletion" + ) + def test_delete_user_account_with_valid_password( + self, mock_server, mock_prompt_confirm_password + ): # mock delete_user_account response - mock_server.router.delete(mock_server.endpoints.delete_user_account.path).respond(200) + mock_server.router.delete( + mock_server.endpoints.delete_user_account.path + ).respond(200) # mock password prompting mock_prompt_confirm_password.return_value = "dummy_password" @@ -241,10 +256,16 @@ def test_delete_user_account_with_valid_password(self, mock_server, mock_prompt_ UserDataClient().delete_user_account() @with_mock_server() - @patch("tabpfn_client.service_wrapper.PromptAgent.prompt_confirm_password_for_user_account_deletion") - def test_delete_user_account_with_invalid_password(self, mock_server, mock_prompt_confirm_password): + @patch( + "tabpfn_client.service_wrapper.PromptAgent.prompt_confirm_password_for_user_account_deletion" + ) + def test_delete_user_account_with_invalid_password( + self, mock_server, mock_prompt_confirm_password + ): # mock delete_user_account response - mock_server.router.delete(mock_server.endpoints.delete_user_account.path).respond(400) + mock_server.router.delete( + mock_server.endpoints.delete_user_account.path + ).respond(400) # mock password prompting mock_prompt_confirm_password.return_value = "dummy_password" diff --git a/tabpfn_client/tests/unit/test_tabpfn_classifier.py b/tabpfn_client/tests/unit/test_tabpfn_classifier.py index 5653a25..657fcb9 100644 --- a/tabpfn_client/tests/unit/test_tabpfn_classifier.py +++ b/tabpfn_client/tests/unit/test_tabpfn_classifier.py @@ -16,14 +16,14 @@ class TestTabPFNClassifierInit(unittest.TestCase): - dummy_token = "dummy_token" def setUp(self): # set up dummy data X, y = load_breast_cancer(return_X_y=True) - self.X_train, self.X_test, self.y_train, self.y_test = \ - train_test_split(X, y, test_size=0.33) + self.X_train, self.X_test, self.y_train, self.y_test = train_test_split( + X, y, test_size=0.33 + ) def tearDown(self): tabpfn_classifier.reset() @@ -36,33 +36,33 @@ def tearDown(self): @with_mock_server() @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token") - @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", - return_value=True) - def test_init_remote_classifier(self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token): - mock_prompt_and_set_token.side_effect = \ + @patch( + "tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", + return_value=True, + ) + def test_init_remote_classifier( + self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token + ): + mock_prompt_and_set_token.side_effect = ( lambda user_auth_handler: user_auth_handler.set_token(self.dummy_token) + ) # mock server connection mock_server.router.get(mock_server.endpoints.root.path).respond(200) mock_server.router.post(mock_server.endpoints.upload_train_set.path).respond( 200, json={"train_set_uid": 5} ) - mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond( - 200, json={"messages": []}) - mock_predict_response = [[1, 0.],[.9, .1],[0.01, 0.99]] + mock_server.router.get( + mock_server.endpoints.retrieve_greeting_messages.path + ).respond(200, json={"messages": []}) + mock_predict_response = [[1, 0.0], [0.9, 0.1], [0.01, 0.99]] predict_route = mock_server.router.post(mock_server.endpoints.predict.path) - predict_route.respond( - 200, json={"y_pred_proba": mock_predict_response} - ) + predict_route.respond(200, json={"y_pred_proba": mock_predict_response}) tabpfn_classifier.init(use_server=True) tabpfn = TabPFNClassifier(n_estimators=10) - self.assertRaises( - NotFittedError, - tabpfn.predict, - self.X_test - ) + self.assertRaises(NotFittedError, tabpfn.predict, self.X_test) tabpfn.fit(self.X_train, self.y_train) self.assertTrue(mock_prompt_and_set_token.called) self.assertTrue(mock_prompt_for_terms_and_cond.called) @@ -70,15 +70,20 @@ def test_init_remote_classifier(self, mock_server, mock_prompt_for_terms_and_con y_pred = tabpfn.predict(self.X_test) self.assertTrue(np.all(np.argmax(mock_predict_response, axis=1) == y_pred)) - self.assertIn('n_estimators%22%3A%2010', str(predict_route.calls.last.request.url), "check that n_estimators is passed to the server") + self.assertIn( + "n_estimators%22%3A%2010", + str(predict_route.calls.last.request.url), + "check that n_estimators is passed to the server", + ) @with_mock_server() def test_reuse_saved_access_token(self, mock_server): # mock connection and authentication mock_server.router.get(mock_server.endpoints.root.path).respond(200) mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200) - mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond( - 200, json={"messages": []}) + mock_server.router.get( + mock_server.endpoints.retrieve_greeting_messages.path + ).respond(200, json={"messages": []}) # create dummy token file token_file = UserAuthenticationClient.CACHED_TOKEN_FILE @@ -93,9 +98,13 @@ def test_reuse_saved_access_token(self, mock_server): @with_mock_server() @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token") - @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", - return_value=True) - def test_invalid_saved_access_token(self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token): + @patch( + "tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", + return_value=True, + ) + def test_invalid_saved_access_token( + self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token + ): mock_prompt_and_set_token.side_effect = [RuntimeError] # mock connection and invalid authentication @@ -120,8 +129,9 @@ def test_reset_on_remote_classifier(self, mock_server): # init classifier as usual mock_server.router.get(mock_server.endpoints.root.path).respond(200) mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200) - mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond( - 200, json={"messages": []}) + mock_server.router.get( + mock_server.endpoints.retrieve_greeting_messages.path + ).respond(200, json={"messages": []}) tabpfn_classifier.init(use_server=True) # check if access token is saved @@ -137,8 +147,10 @@ def test_reset_on_remote_classifier(self, mock_server): self.assertFalse(tabpfn_classifier.g_tabpfn_config.is_initialized) @with_mock_server() - @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", - return_value=False) + @patch( + "tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", + return_value=False, + ) def test_decline_terms_and_cond(self, mock_server, mock_prompt_for_terms_and_cond): # mock connection mock_server.router.get(mock_server.endpoints.root.path).respond(200)