Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Announce new features and version compatibility check #16

Merged
merged 11 commits into from
Mar 11, 2024
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

models_diff/
tabpfn_client/.tabpfn/
3 changes: 1 addition & 2 deletions quick_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@
tabpfn_classifier.init()
tabpfn = TabPFNClassifier(model="latest_tabpfn_hosted")
# print("checking estimator", check_estimator(tabpfn))
print(X_train.shape[0]*100)
tabpfn.fit(np.repeat(X_train, 100, axis=0), np.repeat(y_train, 100, axis=0))
tabpfn.fit(X_train[:99], y_train[:99])
print("predicting")
print(tabpfn.predict(X_test))
print("predicting_proba")
Expand Down
96 changes: 63 additions & 33 deletions tabpfn_client/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from pathlib import Path
import httpx
import logging
import copy

from importlib.metadata import version, PackageNotFoundError
import numpy as np
from omegaconf import OmegaConf
import json

from tabpfn_client.tabpfn_common_utils import utils as common_utils

Expand All @@ -15,6 +15,15 @@
SERVER_CONFIG = OmegaConf.load(SERVER_CONFIG_FILE)


def get_client_version() -> str:
    """
    Return the version string of the installed ``tabpfn_client`` package.

    The value is looked up in the installed package metadata. If the package
    metadata cannot be found, a fixed placeholder version is returned instead
    of raising.
    """
    try:
        return version('tabpfn_client')
    except PackageNotFoundError:
        # Package not found; this should only happen during development.
        # Run 'pip install -e .' to get the actual version number during
        # development. Otherwise, simply return a version number that is
        # large enough.
        return '5.5.5'
davidotte marked this conversation as resolved.
Show resolved Hide resolved


@common_utils.singleton
class ServiceClient:
"""
Expand All @@ -29,7 +38,8 @@ def __init__(self):
self.httpx_timeout_s = 30 # temporary workaround for slow computation on server side
self.httpx_client = httpx.Client(
base_url=self.base_url,
timeout=self.httpx_timeout_s
timeout=self.httpx_timeout_s,
headers={"client-version": get_client_version()}
)

self._access_token = None
Expand Down Expand Up @@ -81,7 +91,7 @@ def upload_train_set(self, X, y) -> str:
])
)

self.error_raising(response, "upload_train_set")
self._validate_response(response, "upload_train_set")

train_set_uid = response.json()["train_set_uid"]
return train_set_uid
Expand Down Expand Up @@ -113,23 +123,37 @@ def predict(self, train_set_uid: str, x_test):
])
)

self.error_raising(response, "predict")
self._validate_response(response, "predict")

return np.array(response.json()["y_pred"])

def error_raising(self, response, method_name):
if response.status_code != 200:
load = None
try:
load = response.json()
except Exception:
pass
@staticmethod
def _validate_response(response, method_name, only_version_check=False):
    """
    Raise a RuntimeError if the server response signals a failure.

    Parameters
    ----------
    response:
        The HTTP response object; ``status_code``, ``json()``, ``text`` and
        ``reason_phrase`` are read from it.
    method_name:
        Name of the calling client method, used in log and error messages.
    only_version_check:
        If True, only a 426 response (client version rejected by the server)
        raises; every other non-200 status is left for the caller to handle.

    Raises
    ------
    RuntimeError
        On a 426 response (always), and on any other non-200 response unless
        ``only_version_check`` is True.
    """
    # If status code is 200, no errors occurred on the server side.
    if response.status_code == 200:
        return

    # Read response.
    load = None
    try:
        load = response.json()
    except ValueError as e:
        # json.JSONDecodeError is a ValueError subclass; catching the base
        # class also covers bodies that cannot be decoded as text at all.
        logger.error(f"Failed to parse JSON from response in {method_name}: {e}")

    # Check if the server requires a newer client version.
    if response.status_code == 426:
        logger.error(f"Fail to call {method_name}, response status: {response.status_code}")
        # The body may be missing or not a JSON object; never let that mask
        # the version error with an AttributeError.
        detail = load.get("detail") if isinstance(load, dict) else None
        raise RuntimeError(
            detail
            or f"Fail to call {method_name} with error: {response.status_code} and reason: "
               f"{response.reason_phrase}"
        )

    # Unless only the version compatibility is being checked, raise on all other errors as well.
    if not only_version_check:
        if load is not None:
            raise RuntimeError(f"Fail to call {method_name} with error: {load}")
        logger.error(f"Fail to call {method_name}, response status: {response.status_code}")
        response_split_up = response.text.split("The following exception has occurred:")
        if len(response_split_up) > 1:
            raise RuntimeError(f"Fail to call {method_name} with error: {response_split_up[1]}")
        raise RuntimeError(
            f"Fail to call {method_name} with error: {response.status_code} and reason: "
            f"{response.reason_phrase}"
        )


def predict_proba(self, train_set_uid: str, x_test):
Expand Down Expand Up @@ -157,17 +181,18 @@ def predict_proba(self, train_set_uid: str, x_test):
])
)

self.error_raising(response, "predict_proba")
self._validate_response(response, "predict_proba")

return np.array(response.json()["y_pred_proba"])

def try_connection(self) -> bool:
"""
Check if server is reachable and return True if successful.
Check if server is reachable and accepts the connection.
"""
found_valid_connection = False
try:
response = self.httpx_client.get(self.server_endpoints.root.path)
self._validate_response(response, "try_connection", only_version_check=True)
if response.status_code == 200:
found_valid_connection = True

Expand All @@ -186,6 +211,8 @@ def try_authenticate(self, access_token) -> bool:
headers={"Authorization": f"Bearer {access_token}"},
)

self._validate_response(response, "try_authenticate", only_version_check=True)

if response.status_code == 200:
is_authenticated = True

Expand Down Expand Up @@ -221,6 +248,7 @@ def register(
params={"email": email, "password": password, "password_confirm": password_confirm, "validation_link": validation_link}
)

self._validate_response(response, "register", only_version_check=True)
if response.status_code == 200:
is_created = True
message = response.json()["message"]
Expand Down Expand Up @@ -251,6 +279,7 @@ def login(self, email: str, password: str) -> str | None:
data=common_utils.to_oauth_request_form(email, password)
)

self._validate_response(response, "login", only_version_check=False)
if response.status_code == 200:
access_token = response.json()["access_token"]

Expand All @@ -269,12 +298,23 @@ def get_password_policy(self) -> {}:
response = self.httpx_client.get(
self.server_endpoints.password_policy.path,
)
if response.status_code != 200:
logger.error(f"Fail to call get_password_policy(), response status: {response.status_code}")
raise RuntimeError(f"Fail to call get_password_policy()")
self._validate_response(response, "get_password_policy", only_version_check=True)

return response.json()["requirements"]

def retrieve_greeting_messages(self) -> list[str]:
    """
    Retrieve greeting messages that are new for the user.

    Returns
    -------
    list[str]
        The entries under the ``"messages"`` key of the server response, or
        an empty list if the server did not answer with status 200.
    """
    response = self.httpx_client.get(self.server_endpoints.retrieve_greeting_messages.path)

    # Only the version-compatibility error is raised here; any other failure
    # simply results in no messages.
    self._validate_response(response, "retrieve_greeting_messages", only_version_check=True)
    if response.status_code != 200:
        return []

    return response.json()["messages"]

def get_data_summary(self) -> {}:
"""
Get the data summary of the user from the server.
Expand All @@ -287,9 +327,7 @@ def get_data_summary(self) -> {}:
response = self.httpx_client.get(
self.server_endpoints.get_data_summary.path,
)
if response.status_code != 200:
logger.error(f"Fail to call get_data_summary(), response status: {response.status_code}")
raise RuntimeError(f"Fail to call get_data_summary()")
self._validate_response(response, "get_data_summary")

return response.json()

Expand All @@ -308,9 +346,7 @@ def download_all_data(self, save_dir: Path) -> Path | None:

full_url = self.base_url + self.server_endpoints.download_all_data.path
with httpx.stream("GET", full_url, headers={"Authorization": f"Bearer {self.access_token}"}) as response:
if response.status_code != 200:
logger.error(f"Fail to call download_all_data(), response status: {response.status_code}")
raise RuntimeError(f"Fail to call download_all_data()")
self._validate_response(response, "download_all_data")

filename = response.headers["Content-Disposition"].split("filename=")[1]
save_path = Path(save_dir) / filename
Expand Down Expand Up @@ -341,9 +377,7 @@ def delete_dataset(self, dataset_uid: str) -> [str]:
params={"dataset_uid": dataset_uid}
)

if response.status_code != 200:
logger.error(f"Fail to call delete_dataset(), response status: {response.status_code}")
raise RuntimeError(f"Fail to call delete_dataset()")
self._validate_response(response, "delete_dataset")

return response.json()["deleted_dataset_uids"]

Expand All @@ -360,9 +394,7 @@ def delete_all_datasets(self) -> [str]:
self.server_endpoints.delete_all_datasets.path,
)

if response.status_code != 200:
logger.error(f"Fail to call delete_all_datasets(), response status: {response.status_code}")
raise RuntimeError(f"Fail to call delete_all_datasets()")
self._validate_response(response, "delete_all_datasets")

return response.json()["deleted_dataset_uids"]

Expand All @@ -372,6 +404,4 @@ def delete_user_account(self, confirm_pass: str) -> None:
params={"confirm_password": confirm_pass}
)

if response.status_code != 200:
logger.error(f"Fail to call delete_user_account(), response status: {response.status_code}")
raise RuntimeError(f"Fail to call delete_user_account()")
self._validate_response(response, "delete_user_account")
6 changes: 6 additions & 0 deletions tabpfn_client/prompt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ def prompt_reusing_existing_token(cls):

print(cls.indent(prompt))

@classmethod
def prompt_retrieved_greeting_messages(cls, greeting_messages: list[str]):
    """
    Print every greeting message, formatted with ``cls.indent``.

    Nothing is printed when ``greeting_messages`` is empty.
    """
    for message in greeting_messages:
        print(cls.indent(message))


@classmethod
def prompt_confirm_password_for_user_account_deletion(cls) -> str:
print(cls.indent("You are about to delete your account."))
Expand Down
5 changes: 5 additions & 0 deletions tabpfn_client/server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ endpoints:
methods: [ "POST" ]
description: "User login"

retrieve_greeting_messages:
path: "/retrieve_greeting_messages/"
methods: [ "GET" ]
description: "Retrieve new greeting messages"

protected_root:
path: "/protected/"
methods: [ "GET" ]
Expand Down
3 changes: 3 additions & 0 deletions tabpfn_client/service_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def _reset_token(self):
self.service_client.reset_authorization()
self.CACHED_TOKEN_FILE.unlink(missing_ok=True)

def retrieve_greeting_messages(self):
    """
    Return the greeting messages obtained from the underlying service client.
    """
    return self.service_client.retrieve_greeting_messages()


class UserDataClient(ServiceClientWrapper):
"""
Expand Down
3 changes: 3 additions & 0 deletions tabpfn_client/tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def init(use_server=True):
# prompt for login / register
PromptAgent.prompt_and_set_token(user_auth_handler)

# Print new greeting messages. If there are no new messages, nothing will be printed.
PromptAgent.prompt_retrieved_greeting_messages(user_auth_handler.retrieve_greeting_messages())

g_tabpfn_config.use_server = True
g_tabpfn_config.user_auth_handler = user_auth_handler
g_tabpfn_config.inference_handler = InferenceClient(service_client)
Expand Down
2 changes: 2 additions & 0 deletions tabpfn_client/tests/integration/test_tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def test_use_remote_tabpfn_classifier(self, mock_server):
# mock connection and authentication
mock_server.router.get(mock_server.endpoints.root.path).respond(200)
mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200)
mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond(
200, json={"messages": []})
tabpfn_classifier.init(use_server=True)

tabpfn = TabPFNClassifier()
Expand Down
53 changes: 53 additions & 0 deletions tabpfn_client/tests/unit/test_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from unittest.mock import Mock

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
Expand Down Expand Up @@ -27,6 +28,14 @@ def test_try_connection_with_invalid_server(self, mock_server):
mock_server.router.get(mock_server.endpoints.root.path).respond(404)
self.assertFalse(self.client.try_connection())

@with_mock_server()
def test_try_connection_with_outdated_client_raises_runtime_error(self, mock_server):
    """A 426 answer from the root endpoint must surface as a RuntimeError."""
    mock_server.router.get(mock_server.endpoints.root.path).respond(
        426, json={"detail": "Client version too old. ..."})
    with self.assertRaises(RuntimeError) as cm:
        self.client.try_connection()
    # Checked after the context manager exits; statements placed after the
    # raising call inside the `with` body would never run.
    self.assertTrue(str(cm.exception).startswith("Client version too old."))

@with_mock_server()
def test_register_user(self, mock_server):
mock_server.router.post(mock_server.endpoints.register.path).respond(200, json={"message": "dummy_message"})
Expand Down Expand Up @@ -57,6 +66,12 @@ def test_valid_auth_token(self, mock_server):
mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200)
self.assertTrue(self.client.try_authenticate("true_token"))

@with_mock_server()
def test_retrieve_greeting_messages(self, mock_server):
    """The client returns the list found under the "messages" key."""
    mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond(
        200, json={"messages": ["message_1", "message_2"]})
    self.assertEqual(self.client.retrieve_greeting_messages(), ["message_1", "message_2"])

@with_mock_server()
def test_predict_with_valid_train_set_and_test_set(self, mock_server):
dummy_json = {"train_set_uid": 5}
Expand All @@ -74,3 +89,41 @@ def test_predict_with_valid_train_set_and_test_set(self, mock_server):
x_test=self.X_test
)
self.assertTrue(np.array_equal(pred, dummy_result["y_pred"]))

def test_validate_response_no_error(self):
    """A 200 response passes validation and yields None."""
    response = Mock()
    response.status_code = 200
    r = self.client._validate_response(response, "test")
    self.assertIsNone(r)

def test_validate_response(self):
    """Non-200 responses raise RuntimeError with the expected message."""
    response = Mock()
    # Test for "Client version too old." error
    response.status_code = 426
    response.json.return_value = {"detail": "Client version too old."}
    with self.assertRaises(RuntimeError) as cm:
        self.client._validate_response(response, "test")
    # Assertions on the exception sit outside the `with` body so they execute.
    self.assertEqual(str(cm.exception), "Client version too old.")

    # Test for "Some other error" which is translated to a generic failure message
    response.status_code = 400
    response.json.return_value = {"detail": "Some other error"}
    with self.assertRaises(RuntimeError) as cm:
        self.client._validate_response(response, "test")
    self.assertTrue(str(cm.exception).startswith("Fail to call test"))

def test_validate_response_only_version_check(self):
    """With only_version_check=True, only the 426 version error is raised."""
    response = Mock()
    response.status_code = 426
    response.json.return_value = {"detail": "Client version too old."}
    with self.assertRaises(RuntimeError) as cm:
        self.client._validate_response(response, "test", only_version_check=True)
    # Assertion on the exception sits outside the `with` body so it executes.
    self.assertEqual(str(cm.exception), "Client version too old.")

    # Errors that have nothing to do with client version should be skipped.
    response = Mock()
    response.status_code = 400
    response.json.return_value = {"detail": "Some other error"}
    r = self.client._validate_response(response, "test", only_version_check=True)
    self.assertIsNone(r)

6 changes: 6 additions & 0 deletions tabpfn_client/tests/unit/test_tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def test_init_remote_classifier(self, mock_server, mock_prompt_for_terms_and_con
mock_server.router.post(mock_server.endpoints.upload_train_set.path).respond(
200, json={"train_set_uid": 5}
)
mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond(
200, json={"messages": []})

tabpfn_classifier.init(use_server=True)
tabpfn = TabPFNClassifier().fit(self.X_train, self.y_train)
Expand All @@ -64,6 +66,8 @@ def test_reuse_saved_access_token(self, mock_server):
# mock connection and authentication
mock_server.router.get(mock_server.endpoints.root.path).respond(200)
mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200)
mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond(
200, json={"messages": []})

# create dummy token file
token_file = UserAuthenticationClient.CACHED_TOKEN_FILE
Expand Down Expand Up @@ -110,6 +114,8 @@ def test_reset_on_remote_classifier(self, mock_server):
# init classifier as usual
mock_server.router.get(mock_server.endpoints.root.path).respond(200)
mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200)
mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond(
200, json={"messages": []})
tabpfn_classifier.init(use_server=True)

# check if access token is saved
Expand Down
Loading