Skip to content

Commit

Permalink
Announce new features and version compatibility check (#16)
Browse files Browse the repository at this point in the history
* Fix: Test Cases in Client and Add Model to TabPFN Classifier

* Retrieve and print new messages

* Update error raising in client.py

* Remove .idea from repo

* Add further tests for client

* Fix tests and client version number retrieval

* Remove models_diff and add to .gitignore

* Fix mistake during rebase and update naming of greeting message retrieval

* Add comments and improve code for error raising

* Fix version check and quick_test.py

* Fix minor things

---------

Co-authored-by: Anshul Gupta <[email protected]>
  • Loading branch information
davidotte and Anshul Gupta authored Mar 11, 2024
1 parent 0501aba commit 87fa78f
Show file tree
Hide file tree
Showing 10 changed files with 146 additions and 36 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

models_diff/
tabpfn_client/.tabpfn/
3 changes: 1 addition & 2 deletions quick_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@
tabpfn_classifier.init()
tabpfn = TabPFNClassifier(model="latest_tabpfn_hosted")
# print("checking estimator", check_estimator(tabpfn))
print(X_train.shape[0]*100)
tabpfn.fit(np.repeat(X_train, 100, axis=0), np.repeat(y_train, 100, axis=0))
tabpfn.fit(X_train[:99], y_train[:99])
print("predicting")
print(tabpfn.predict(X_test))
print("predicting_proba")
Expand Down
96 changes: 63 additions & 33 deletions tabpfn_client/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from pathlib import Path
import httpx
import logging
import copy

from importlib.metadata import version, PackageNotFoundError
import numpy as np
from omegaconf import OmegaConf
import json

from tabpfn_client.tabpfn_common_utils import utils as common_utils

Expand All @@ -15,6 +15,15 @@
SERVER_CONFIG = OmegaConf.load(SERVER_CONFIG_FILE)


def get_client_version() -> str:
try:
return version('tabpfn_client')
except PackageNotFoundError:
# Package not found, should only happen during development. Execute 'pip install -e .' to use the actual
# version number during development. Otherwise, simply return a version number that is large enough.
return '5.5.5'


@common_utils.singleton
class ServiceClient:
"""
Expand All @@ -29,7 +38,8 @@ def __init__(self):
self.httpx_timeout_s = 30 # temporary workaround for slow computation on server side
self.httpx_client = httpx.Client(
base_url=self.base_url,
timeout=self.httpx_timeout_s
timeout=self.httpx_timeout_s,
headers={"client-version": get_client_version()}
)

self._access_token = None
Expand Down Expand Up @@ -81,7 +91,7 @@ def upload_train_set(self, X, y) -> str:
])
)

self.error_raising(response, "upload_train_set")
self._validate_response(response, "upload_train_set")

train_set_uid = response.json()["train_set_uid"]
return train_set_uid
Expand Down Expand Up @@ -113,23 +123,37 @@ def predict(self, train_set_uid: str, x_test):
])
)

self.error_raising(response, "predict")
self._validate_response(response, "predict")

return np.array(response.json()["y_pred"])

def error_raising(self, response, method_name):
if response.status_code != 200:
load = None
try:
load = response.json()
except Exception:
pass
@staticmethod
def _validate_response(response, method_name, only_version_check=False):
# If status code is 200, no errors occurred on the server side.
if response.status_code == 200:
return

# Read response.
load = None
try:
load = response.json()
except json.JSONDecodeError as e:
logging.error(f"Failed to parse JSON from response in {method_name}: {e}")

# Check if the server requires a newer client version.
if response.status_code == 426:
logger.error(f"Fail to call {method_name}, response status: {response.status_code}")
raise RuntimeError(load.get("detail"))

# If we not only want to check the version compatibility, also raise other errors.
if not only_version_check:
if load is not None:
raise RuntimeError(f"Fail to call {method_name} with error: {load}")
logger.error(f"Fail to call {method_name}, response status: {response.status_code}")
if len(reponse_split_up:=response.text.split("The following exception has occurred:")) > 1:
raise RuntimeError(f"Fail to call {method_name} with error: {reponse_split_up[1]}")
raise RuntimeError(f"Fail to call {method_name} with error: {response.status_code} and reason: {response.reason_phrase}")
raise RuntimeError(f"Fail to call {method_name} with error: {response.status_code} and reason: "
f"{response.reason_phrase}")


def predict_proba(self, train_set_uid: str, x_test):
Expand Down Expand Up @@ -157,17 +181,18 @@ def predict_proba(self, train_set_uid: str, x_test):
])
)

self.error_raising(response, "predict_proba")
self._validate_response(response, "predict_proba")

return np.array(response.json()["y_pred_proba"])

def try_connection(self) -> bool:
"""
Check if server is reachable and return True if successful.
Check if server is reachable and accepts the connection.
"""
found_valid_connection = False
try:
response = self.httpx_client.get(self.server_endpoints.root.path)
self._validate_response(response, "try_connection", only_version_check=True)
if response.status_code == 200:
found_valid_connection = True

Expand All @@ -186,6 +211,8 @@ def try_authenticate(self, access_token) -> bool:
headers={"Authorization": f"Bearer {access_token}"},
)

self._validate_response(response, "try_authenticate", only_version_check=True)

if response.status_code == 200:
is_authenticated = True

Expand Down Expand Up @@ -221,6 +248,7 @@ def register(
params={"email": email, "password": password, "password_confirm": password_confirm, "validation_link": validation_link}
)

self._validate_response(response, "register", only_version_check=True)
if response.status_code == 200:
is_created = True
message = response.json()["message"]
Expand Down Expand Up @@ -251,6 +279,7 @@ def login(self, email: str, password: str) -> str | None:
data=common_utils.to_oauth_request_form(email, password)
)

self._validate_response(response, "login", only_version_check=False)
if response.status_code == 200:
access_token = response.json()["access_token"]

Expand All @@ -269,12 +298,23 @@ def get_password_policy(self) -> {}:
response = self.httpx_client.get(
self.server_endpoints.password_policy.path,
)
if response.status_code != 200:
logger.error(f"Fail to call get_password_policy(), response status: {response.status_code}")
raise RuntimeError(f"Fail to call get_password_policy()")
self._validate_response(response, "get_password_policy", only_version_check=True)

return response.json()["requirements"]

def retrieve_greeting_messages(self) -> list[str]:
"""
Retrieve greeting messages that are new for the user.
"""
response = self.httpx_client.get(self.server_endpoints.retrieve_greeting_messages.path)

self._validate_response(response, "retrieve_greeting_messages", only_version_check=True)
if response.status_code != 200:
return []

greeting_messages = response.json()["messages"]
return greeting_messages

def get_data_summary(self) -> {}:
"""
Get the data summary of the user from the server.
Expand All @@ -287,9 +327,7 @@ def get_data_summary(self) -> {}:
response = self.httpx_client.get(
self.server_endpoints.get_data_summary.path,
)
if response.status_code != 200:
logger.error(f"Fail to call get_data_summary(), response status: {response.status_code}")
raise RuntimeError(f"Fail to call get_data_summary()")
self._validate_response(response, "get_data_summary")

return response.json()

Expand All @@ -308,9 +346,7 @@ def download_all_data(self, save_dir: Path) -> Path | None:

full_url = self.base_url + self.server_endpoints.download_all_data.path
with httpx.stream("GET", full_url, headers={"Authorization": f"Bearer {self.access_token}"}) as response:
if response.status_code != 200:
logger.error(f"Fail to call download_all_data(), response status: {response.status_code}")
raise RuntimeError(f"Fail to call download_all_data()")
self._validate_response(response, "download_all_data")

filename = response.headers["Content-Disposition"].split("filename=")[1]
save_path = Path(save_dir) / filename
Expand Down Expand Up @@ -341,9 +377,7 @@ def delete_dataset(self, dataset_uid: str) -> [str]:
params={"dataset_uid": dataset_uid}
)

if response.status_code != 200:
logger.error(f"Fail to call delete_dataset(), response status: {response.status_code}")
raise RuntimeError(f"Fail to call delete_dataset()")
self._validate_response(response, "delete_dataset")

return response.json()["deleted_dataset_uids"]

Expand All @@ -360,9 +394,7 @@ def delete_all_datasets(self) -> [str]:
self.server_endpoints.delete_all_datasets.path,
)

if response.status_code != 200:
logger.error(f"Fail to call delete_all_datasets(), response status: {response.status_code}")
raise RuntimeError(f"Fail to call delete_all_datasets()")
self._validate_response(response, "delete_all_datasets")

return response.json()["deleted_dataset_uids"]

Expand All @@ -372,6 +404,4 @@ def delete_user_account(self, confirm_pass: str) -> None:
params={"confirm_password": confirm_pass}
)

if response.status_code != 200:
logger.error(f"Fail to call delete_user_account(), response status: {response.status_code}")
raise RuntimeError(f"Fail to call delete_user_account()")
self._validate_response(response, "delete_user_account")
6 changes: 6 additions & 0 deletions tabpfn_client/prompt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ def prompt_reusing_existing_token(cls):

print(cls.indent(prompt))

@classmethod
def prompt_retrieved_greeting_messages(cls, greeting_messages: list[str]):
for message in greeting_messages:
print(cls.indent(message))


@classmethod
def prompt_confirm_password_for_user_account_deletion(cls) -> str:
print(cls.indent("You are about to delete your account."))
Expand Down
5 changes: 5 additions & 0 deletions tabpfn_client/server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ endpoints:
methods: [ "POST" ]
description: "User login"

retrieve_greeting_messages:
path: "/retrieve_greeting_messages/"
methods: [ "GET" ]
description: "Retrieve new greeting messages"

protected_root:
path: "/protected/"
methods: [ "GET" ]
Expand Down
3 changes: 3 additions & 0 deletions tabpfn_client/service_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def _reset_token(self):
self.service_client.reset_authorization()
self.CACHED_TOKEN_FILE.unlink(missing_ok=True)

def retrieve_greeting_messages(self):
return self.service_client.retrieve_greeting_messages()


class UserDataClient(ServiceClientWrapper):
"""
Expand Down
3 changes: 3 additions & 0 deletions tabpfn_client/tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def init(use_server=True):
# prompt for login / register
PromptAgent.prompt_and_set_token(user_auth_handler)

# Print new greeting messages. If there are no new messages, nothing will be printed.
PromptAgent.prompt_retrieved_greeting_messages(user_auth_handler.retrieve_greeting_messages())

g_tabpfn_config.use_server = True
g_tabpfn_config.user_auth_handler = user_auth_handler
g_tabpfn_config.inference_handler = InferenceClient(service_client)
Expand Down
2 changes: 2 additions & 0 deletions tabpfn_client/tests/integration/test_tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def test_use_remote_tabpfn_classifier(self, mock_server):
# mock connection and authentication
mock_server.router.get(mock_server.endpoints.root.path).respond(200)
mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200)
mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond(
200, json={"messages": []})
tabpfn_classifier.init(use_server=True)

tabpfn = TabPFNClassifier()
Expand Down
53 changes: 53 additions & 0 deletions tabpfn_client/tests/unit/test_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from unittest.mock import Mock

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
Expand Down Expand Up @@ -27,6 +28,14 @@ def test_try_connection_with_invalid_server(self, mock_server):
mock_server.router.get(mock_server.endpoints.root.path).respond(404)
self.assertFalse(self.client.try_connection())

@with_mock_server()
def test_try_connection_with_outdated_client_raises_runtime_error(self, mock_server):
mock_server.router.get(mock_server.endpoints.root.path).respond(
426, json={"detail": "Client version too old. ..."})
with self.assertRaises(RuntimeError) as cm:
self.client.try_connection()
self.assertTrue(str(cm.exception).startswith("Client version too old."))

@with_mock_server()
def test_register_user(self, mock_server):
mock_server.router.post(mock_server.endpoints.register.path).respond(200, json={"message": "dummy_message"})
Expand Down Expand Up @@ -57,6 +66,12 @@ def test_valid_auth_token(self, mock_server):
mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200)
self.assertTrue(self.client.try_authenticate("true_token"))

@with_mock_server()
def test_retrieve_greeting_messages(self, mock_server):
mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond(
200, json={"messages": ["message_1", "message_2"]})
self.assertEqual(self.client.retrieve_greeting_messages(), ["message_1", "message_2"])

@with_mock_server()
def test_predict_with_valid_train_set_and_test_set(self, mock_server):
dummy_json = {"train_set_uid": 5}
Expand All @@ -74,3 +89,41 @@ def test_predict_with_valid_train_set_and_test_set(self, mock_server):
x_test=self.X_test
)
self.assertTrue(np.array_equal(pred, dummy_result["y_pred"]))

def test_validate_response_no_error(self):
response = Mock()
response.status_code = 200
r = self.client._validate_response(response, "test")
self.assertIsNone(r)

def test_validate_response(self):
response = Mock()
# Test for "Client version too old." error
response.status_code = 426
response.json.return_value = {"detail": "Client version too old."}
with self.assertRaises(RuntimeError) as cm:
self.client._validate_response(response, "test")
self.assertEqual(str(cm.exception), "Client version too old.")

# Test for "Some other error" which is translated to a generic failure message
response.status_code = 400
response.json.return_value = {"detail": "Some other error"}
with self.assertRaises(RuntimeError) as cm:
self.client._validate_response(response, "test")
self.assertTrue(str(cm.exception).startswith("Fail to call test"))

def test_validate_response_only_version_check(self):
response = Mock()
response.status_code = 426
response.json.return_value = {"detail": "Client version too old."}
with self.assertRaises(RuntimeError) as cm:
self.client._validate_response(response, "test", only_version_check=True)
self.assertEqual(str(cm.exception), "Client version too old.")

# Errors that have nothing to do with client version should be skipped.
response = Mock()
response.status_code = 400
response.json.return_value = {"detail": "Some other error"}
r = self.client._validate_response(response, "test", only_version_check=True)
self.assertIsNone(r)

6 changes: 6 additions & 0 deletions tabpfn_client/tests/unit/test_tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def test_init_remote_classifier(self, mock_server, mock_prompt_for_terms_and_con
mock_server.router.post(mock_server.endpoints.upload_train_set.path).respond(
200, json={"train_set_uid": 5}
)
mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond(
200, json={"messages": []})

tabpfn_classifier.init(use_server=True)
tabpfn = TabPFNClassifier().fit(self.X_train, self.y_train)
Expand All @@ -64,6 +66,8 @@ def test_reuse_saved_access_token(self, mock_server):
# mock connection and authentication
mock_server.router.get(mock_server.endpoints.root.path).respond(200)
mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200)
mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond(
200, json={"messages": []})

# create dummy token file
token_file = UserAuthenticationClient.CACHED_TOKEN_FILE
Expand Down Expand Up @@ -110,6 +114,8 @@ def test_reset_on_remote_classifier(self, mock_server):
# init classifier as usual
mock_server.router.get(mock_server.endpoints.root.path).respond(200)
mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200)
mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond(
200, json={"messages": []})
tabpfn_classifier.init(use_server=True)

# check if access token is saved
Expand Down

0 comments on commit 87fa78f

Please sign in to comment.