Skip to content

Commit

Permalink
move quick_test in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoGrin committed Jan 6, 2025
1 parent c68c33e commit 2327933
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions quick_test.py → tabpfn_client/tests/quick_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
"""
TabPFN Client Example Usage
--------------------------
Toy script to check that the TabPFN client is working.
Use the breast cancer dataset for classification and the diabetes dataset for regression,
and try various prediction types.
"""

import logging

from sklearn.datasets import load_breast_cancer, load_diabetes
Expand All @@ -20,9 +28,8 @@
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.33, random_state=42
)
# make y text
y_train = y_train.astype(str) + "a"
y_test = y_test.astype(str) + "a"
y_train = y_train
y_test = y_test

tabpfn = TabPFNClassifier(n_estimators=3)
# print("checking estimator", check_estimator(tabpfn))
Expand All @@ -47,4 +54,5 @@

print(UserDataClient.get_data_summary())
# test predict_full
print("predicting ")
print(tabpfn.predict(X_test[:30], output_type="full", quantiles=[0.1, 0.5, 0.9]))

0 comments on commit 2327933

Please sign in to comment.