From 2327933c3203029902dbfbce6cdf0126ca28c552 Mon Sep 17 00:00:00 2001 From: LeoGrin Date: Mon, 6 Jan 2025 22:14:35 +0100 Subject: [PATCH] move quick_test in tests --- quick_test.py => tabpfn_client/tests/quick_test.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) rename quick_test.py => tabpfn_client/tests/quick_test.py (82%) diff --git a/quick_test.py b/tabpfn_client/tests/quick_test.py similarity index 82% rename from quick_test.py rename to tabpfn_client/tests/quick_test.py index 9e45e6f..f5dcf8f 100644 --- a/quick_test.py +++ b/tabpfn_client/tests/quick_test.py @@ -1,3 +1,11 @@ +""" +TabPFN Client Example Usage +-------------------------- +Toy script to check that the TabPFN client is working. +Use the breast cancer dataset for classification and the diabetes dataset for regression, +and try various prediction types. +""" + import logging from sklearn.datasets import load_breast_cancer, load_diabetes @@ -20,9 +28,8 @@ X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.33, random_state=42 ) - # make y text - y_train = y_train.astype(str) + "a" - y_test = y_test.astype(str) + "a" + y_train = y_train + y_test = y_test tabpfn = TabPFNClassifier(n_estimators=3) # print("checking estimator", check_estimator(tabpfn)) @@ -47,4 +54,5 @@ print(UserDataClient.get_data_summary()) # test predict_full + print("predicting ") print(tabpfn.predict(X_test[:30], output_type="full", quantiles=[0.1, 0.5, 0.9]))