Add Ruff #25

Merged
merged 5 commits on Jun 4, 2024
Changes from 3 commits
10 changes: 10 additions & 0 deletions .github/workflows/pull_request.yml
@@ -5,6 +5,16 @@ on:
- main

jobs:
check_python_linting:
name: Ruff Linting
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: chartboost/ruff-action@v1
with:
src: "./"
version: 0.3.3

test:
name: Run unit and integration tests
runs-on: ubuntu-latest
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,9 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.3
hooks:
# Run the linter.
- id: ruff
# Run the formatter.
- id: ruff-format
7 changes: 7 additions & 0 deletions README.md
@@ -43,3 +43,10 @@ tabpfn.fit(X_train, y_train)
tabpfn.predict(X_test)
# or you can also use tabpfn.predict_proba(X_test)
```

# Development

To encourage better coding practices, `ruff` has been added to the pre-commit hooks. This ensures that the code is linted and formatted properly before being committed. To enable pre-commit (if you haven't already), run the following command:
```sh
pre-commit install
```
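
Once enabled, the hooks run automatically on `git commit`. To check the whole repository by hand (an optional step, assuming `pre-commit` and the hooks above are installed), you can run:
```sh
# Run all configured hooks (ruff lint + ruff format) against every file in the repo
pre-commit run --all-files
```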
16 changes: 10 additions & 6 deletions quick_test.py
@@ -1,14 +1,14 @@
import logging
import numpy as np

logging.basicConfig(level=logging.DEBUG)

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

from tabpfn_client import tabpfn_classifier, UserDataClient
from tabpfn_client.tabpfn_classifier import TabPFNClassifier

logging.basicConfig(level=logging.DEBUG)


if __name__ == "__main__":
# set logging level to debug
@@ -18,13 +18,17 @@
# use_server = False

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.33, random_state=42
)

if not use_server:
tabpfn_classifier.init(use_server=False)
tabpfn = TabPFNClassifier(device="cpu", N_ensemble_configurations=4, model="tabpfn_1_local")
tabpfn = TabPFNClassifier(
device="cpu", N_ensemble_configurations=4, model="tabpfn_1_local"
)
# check_estimator(tabpfn)
tabpfn.fit(np.repeat(X_train,100,axis=0), np.repeat(y_train,100,axis=0))
tabpfn.fit(np.repeat(X_train, 100, axis=0), np.repeat(y_train, 100, axis=0))
print("predicting")
print(tabpfn.predict(X_test))
print("predicting_proba")
@@ -40,4 +44,4 @@
print("predicting_proba")
print(tabpfn.predict_proba(X_test))

print(UserDataClient().get_data_summary())
print(UserDataClient().get_data_summary())
6 changes: 5 additions & 1 deletion requirements.txt
@@ -6,4 +6,8 @@ scikit-learn
torch

# for testing
respx
respx

# development tool
pre-commit
ruff == 0.3.3
3 changes: 3 additions & 0 deletions ruff.toml
@@ -0,0 +1,3 @@
line-length = 88
indent-width = 4
target-version = "py310"
2 changes: 2 additions & 0 deletions tabpfn_client/__init__.py
@@ -1,2 +1,4 @@
from tabpfn_client.tabpfn_classifier import init, TabPFNClassifier
from tabpfn_client.service_wrapper import UserDataClient

__all__ = ["init", "TabPFNClassifier", "UserDataClient"]
115 changes: 69 additions & 46 deletions tabpfn_client/client.py
@@ -20,11 +20,11 @@

def get_client_version() -> str:
try:
return version('tabpfn_client')
return version("tabpfn_client")
except PackageNotFoundError:
# Package not found, should only happen during development. Execute 'pip install -e .' to use the actual
# version number during development. Otherwise, simply return a version number that is large enough.
return '5.5.5'
return "5.5.5"


@common_utils.singleton
@@ -38,11 +38,13 @@ def __init__(self):
self.server_config = SERVER_CONFIG
self.server_endpoints = SERVER_CONFIG["endpoints"]
self.base_url = f"{self.server_config.protocol}://{self.server_config.host}:{self.server_config.port}"
self.httpx_timeout_s = 30 # temporary workaround for slow computation on server side
self.httpx_timeout_s = (
30 # temporary workaround for slow computation on server side
)
self.httpx_client = httpx.Client(
base_url=self.base_url,
timeout=self.httpx_timeout_s,
headers={"client-version": get_client_version()}
headers={"client-version": get_client_version()},
)

self._access_token = None
@@ -63,8 +65,7 @@ def reset_authorization(self):

@property
def is_initialized(self):
return self.access_token is not None \
and self.access_token != ""
return self.access_token is not None and self.access_token != ""

def upload_train_set(self, X, y) -> str:
"""
@@ -88,18 +89,17 @@ def upload_train_set(self, X, y) -> str:

response = self.httpx_client.post(
url=self.server_endpoints.upload_train_set.path,
files=common_utils.to_httpx_post_file_format([
("x_file", "x_train_filename", X),
("y_file", "y_train_filename", y)
])
files=common_utils.to_httpx_post_file_format(
[("x_file", "x_train_filename", X), ("y_file", "y_train_filename", y)]
),
)

self._validate_response(response, "upload_train_set")

train_set_uid = response.json()["train_set_uid"]
return train_set_uid

def predict(self, train_set_uid: str, x_test, tabpfn_config: dict | None=None):
def predict(self, train_set_uid: str, x_test, tabpfn_config: dict | None = None):
"""
Predict the class labels for the provided data (test set).

@@ -121,14 +121,16 @@ def predict(self, train_set_uid: str, x_test, tabpfn_config: dict | None=None):
params = {"train_set_uid": train_set_uid}

if tabpfn_config is not None:
params["tabpfn_config"] = json.dumps(tabpfn_config, default=lambda x: x.to_dict())
params["tabpfn_config"] = json.dumps(
tabpfn_config, default=lambda x: x.to_dict()
)

response = self.httpx_client.post(
url=self.server_endpoints.predict.path,
params=params,
files=common_utils.to_httpx_post_file_format([
("x_file", "x_test_filename", x_test)
])
files=common_utils.to_httpx_post_file_format(
[("x_file", "x_test_filename", x_test)]
),
)

self._validate_response(response, "predict")
@@ -150,19 +152,33 @@ def _validate_response(response, method_name, only_version_check=False):

# Check if the server requires a newer client version.
if response.status_code == 426:
logger.error(f"Fail to call {method_name}, response status: {response.status_code}")
logger.error(
f"Fail to call {method_name}, response status: {response.status_code}"
)
raise RuntimeError(load.get("detail"))

# If we not only want to check the version compatibility, also raise other errors.
if not only_version_check:
if load is not None:
raise RuntimeError(f"Fail to call {method_name} with error: {load}")
logger.error(f"Fail to call {method_name}, response status: {response.status_code}")
if len(reponse_split_up:=response.text.split("The following exception has occurred:")) > 1:
raise RuntimeError(f"Fail to call {method_name} with error: {reponse_split_up[1]}")
raise RuntimeError(f"Fail to call {method_name} with error: {response.status_code} and reason: "
f"{response.reason_phrase}")

logger.error(
f"Fail to call {method_name}, response status: {response.status_code}"
)
if (
len(
reponse_split_up := response.text.split(
"The following exception has occurred:"
)
)
> 1
):
raise RuntimeError(
f"Fail to call {method_name} with error: {reponse_split_up[1]}"
)
raise RuntimeError(
f"Fail to call {method_name} with error: {response.status_code} and reason: "
f"{response.reason_phrase}"
)

def predict_proba(self, train_set_uid: str, x_test):
"""
@@ -184,9 +200,9 @@ def predict_proba(self, train_set_uid: str, x_test):
response = self.httpx_client.post(
url=self.server_endpoints.predict_proba.path,
params={"train_set_uid": train_set_uid},
files=common_utils.to_httpx_post_file_format([
("x_file", "x_test_filename", x_test)
])
files=common_utils.to_httpx_post_file_format(
[("x_file", "x_test_filename", x_test)]
),
)

self._validate_response(response, "predict_proba")
@@ -244,8 +260,7 @@ def validate_email(self, email: str) -> tuple[bool, str]:
The message returned from the server.
"""
response = self.httpx_client.post(
self.server_endpoints.validate_email.path,
params={"email": email}
self.server_endpoints.validate_email.path, params={"email": email}
)

self._validate_response(response, "validate_email", only_version_check=True)
@@ -259,12 +274,12 @@ def validate_email(self, email: str) -> tuple[bool, str]:
return is_valid, message

def register(
self,
email: str,
password: str,
password_confirm: str,
validation_link: str,
additional_info: dict
self,
email: str,
password: str,
password_confirm: str,
validation_link: str,
additional_info: dict,
) -> tuple[bool, str]:
"""
Register a new user with the provided credentials.
@@ -288,12 +303,12 @@ def register(
response = self.httpx_client.post(
self.server_endpoints.register.path,
params={
"email": email,
"password": password,
"email": email,
"password": password,
"password_confirm": password_confirm,
"validation_link": validation_link,
**additional_info
}
"validation_link": validation_link,
**additional_info,
},
)

self._validate_response(response, "register", only_version_check=True)
@@ -326,7 +341,7 @@ def login(self, email: str, password: str) -> tuple[str, str]:
access_token = None
response = self.httpx_client.post(
self.server_endpoints.login.path,
data=common_utils.to_oauth_request_form(email, password)
data=common_utils.to_oauth_request_form(email, password),
)

self._validate_response(response, "login", only_version_check=True)
@@ -351,7 +366,9 @@ def get_password_policy(self) -> {}:
response = self.httpx_client.get(
self.server_endpoints.password_policy.path,
)
self._validate_response(response, "get_password_policy", only_version_check=True)
self._validate_response(
response, "get_password_policy", only_version_check=True
)

return response.json()["requirements"]

@@ -361,7 +378,7 @@ def send_reset_password_email(self, email: str) -> tuple[bool, str]:
"""
response = self.httpx_client.post(
self.server_endpoints.send_reset_password_email.path,
params={"email": email}
params={"email": email},
)
if response.status_code == 200:
sent = True
@@ -375,9 +392,13 @@ def retrieve_greeting_messages(self) -> list[str]:
"""
Retrieve greeting messages that are new for the user.
"""
response = self.httpx_client.get(self.server_endpoints.retrieve_greeting_messages.path)
response = self.httpx_client.get(
self.server_endpoints.retrieve_greeting_messages.path
)

self._validate_response(response, "retrieve_greeting_messages", only_version_check=True)
self._validate_response(
response, "retrieve_greeting_messages", only_version_check=True
)
if response.status_code != 200:
return []

@@ -414,7 +435,9 @@ def download_all_data(self, save_dir: Path) -> Path | None:
save_path = None

full_url = self.base_url + self.server_endpoints.download_all_data.path
with httpx.stream("GET", full_url, headers={"Authorization": f"Bearer {self.access_token}"}) as response:
with httpx.stream(
"GET", full_url, headers={"Authorization": f"Bearer {self.access_token}"}
) as response:
self._validate_response(response, "download_all_data")

filename = response.headers["Content-Disposition"].split("filename=")[1]
@@ -443,7 +466,7 @@ def delete_dataset(self, dataset_uid: str) -> [str]:
"""
response = self.httpx_client.delete(
self.server_endpoints.delete_dataset.path,
params={"dataset_uid": dataset_uid}
params={"dataset_uid": dataset_uid},
)

self._validate_response(response, "delete_dataset")
@@ -470,7 +493,7 @@ def delete_all_datasets(self) -> [str]:
def delete_user_account(self, confirm_pass: str) -> None:
response = self.httpx_client.delete(
self.server_endpoints.delete_user_account.path,
params={"confirm_password": confirm_pass}
params={"confirm_password": confirm_pass},
)

self._validate_response(response, "delete_user_account")