diff --git a/quick_test.py b/quick_test.py deleted file mode 100644 index 71cb385..0000000 --- a/quick_test.py +++ /dev/null @@ -1,45 +0,0 @@ -import logging - -from sklearn.datasets import load_breast_cancer, load_diabetes -from sklearn.model_selection import train_test_split - -from tabpfn_client import UserDataClient -from tabpfn_client.estimator import TabPFNClassifier, TabPFNRegressor - -logging.basicConfig(level=logging.DEBUG) - - -if __name__ == "__main__": - # set logging level to debug - # logging.basicConfig(level=logging.DEBUG) - - use_server = True - # use_server = False - - X, y = load_breast_cancer(return_X_y=True) - X_train, X_test, y_train, y_test = train_test_split( - X, y, test_size=0.33, random_state=42 - ) - - tabpfn = TabPFNClassifier(n_estimators=3) - # print("checking estimator", check_estimator(tabpfn)) - tabpfn.fit(X_train[:99], y_train[:99]) - print("predicting") - print(tabpfn.predict(X_test)) - print("predicting_proba") - print(tabpfn.predict_proba(X_test)) - - print(UserDataClient.get_data_summary()) - - X, y = load_diabetes(return_X_y=True) - X_train, X_test, y_train, y_test = train_test_split( - X, y, test_size=0.33, random_state=42 - ) - - tabpfn = TabPFNRegressor(n_estimators=3) - # print("checking estimator", check_estimator(tabpfn)) - tabpfn.fit(X_train[:99], y_train[:99]) - print("predicting reg") - print(tabpfn.predict(X_test)) - - print(UserDataClient.get_data_summary()) diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index d22ef76..8f2997d 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -243,6 +243,7 @@ def predict( train_set_uid: str, x_test, task: Literal["classification", "regression"], + predict_params: Union[dict, None] = None, tabpfn_config: Union[dict, None] = None, X_train=None, y_train=None, @@ -265,7 +266,11 @@ def predict( x_test_serialized = common_utils.serialize_to_csv_formatted_bytes(x_test) - params = {"train_set_uid": train_set_uid, "task": task} + params = { + "train_set_uid": train_set_uid, + "task": task, + "predict_params": json.dumps(predict_params), + } if tabpfn_config is not None: paper_version = tabpfn_config.pop("paper_version") params["tabpfn_config"] = json.dumps( @@ -377,15 +382,12 @@ def run_progress(): if cached_test_set_uid is None: cls.dataset_uid_cache_manager.add_dataset_uid(dataset_hash, test_set_uid) - # The results contain different things for the different tasks - # - classification: probas_array - # - regression: {"mean": mean_array, "median": median_array, "mode": mode_array, ...} - # So, if the result is not a dictionary, we add a "probas" key to it. if not isinstance(result, dict): - result = {"probas": result} - - for k in result: - result[k] = np.array(result[k]) + result = np.array(result) + else: + for k in result: + if isinstance(result[k], list): + result[k] = np.array(result[k]) return result diff --git a/tabpfn_client/estimator.py b/tabpfn_client/estimator.py index 1d83aac..7ce65f8 100644 --- a/tabpfn_client/estimator.py +++ b/tabpfn_client/estimator.py @@ -1,6 +1,5 @@ -from typing import Optional, Tuple, Literal, Dict, Union +from typing import Optional, Literal, Dict, Union import logging -from dataclasses import dataclass, asdict import numpy as np from tabpfn_client.config import init @@ -16,110 +15,10 @@ MAX_COLS = 500 -@dataclass(eq=True, frozen=True) -class PreprocessorConfig: - """ - Configuration for data preprocessors. - - Attributes: - name (Literal): Name of the preprocessor. - categorical_name (Literal): Name of the categorical encoding method. 
Valid options are "none", "numeric", - "onehot", "ordinal", "ordinal_shuffled". Default is "none". - append_original (bool): Whether to append the original features to the transformed features. Default is False. - subsample_features (float): Fraction of features to subsample. -1 means no subsampling. Default is -1. - global_transformer_name (str): Name of the global transformer to use. Default is None. - """ - - name: Literal[ - "per_feature", # a different transformation for each feature - "power", # a standard sklearn power transformer - "safepower", # a power transformer that prevents some numerical issues - "power_box", - "safepower_box", - "quantile_uni_coarse", # different quantile transformations with few quantiles up to a lot - "quantile_norm_coarse", - "quantile_uni", - "quantile_norm", - "quantile_uni_fine", - "quantile_norm_fine", - "robust", # a standard sklearn robust scaler - "kdi", - "none", # no transformation (inside the transformer we anyways do a standardization) - "kdi_random_alpha", - "kdi_uni", - "kdi_random_alpha_uni", - "adaptive", - "norm_and_kdi", - # KDI with alpha collection - "kdi_alpha_0.3_uni", - "kdi_alpha_0.5_uni", - "kdi_alpha_0.8_uni", - "kdi_alpha_1.0_uni", - "kdi_alpha_1.2_uni", - "kdi_alpha_1.5_uni", - "kdi_alpha_2.0_uni", - "kdi_alpha_3.0_uni", - "kdi_alpha_5.0_uni", - "kdi_alpha_0.3", - "kdi_alpha_0.5", - "kdi_alpha_0.8", - "kdi_alpha_1.0", - "kdi_alpha_1.2", - "kdi_alpha_1.5", - "kdi_alpha_2.0", - "kdi_alpha_3.0", - "kdi_alpha_5.0", - ] - categorical_name: Literal[ - "none", - "numeric", - "onehot", - "ordinal", - "ordinal_shuffled", - "ordinal_very_common_categories_shuffled", - ] = "none" - # categorical_name meanings: - # "none": categorical features are pretty much treated as ordinal, just not resorted - # "numeric": categorical features are treated as numeric, that means they are also power transformed for example - # "onehot": categorical features are onehot encoded - # "ordinal": categorical features are sorted and encoded as integers from 0 to n_categories - 1 - # "ordinal_shuffled": categorical features are encoded as integers from 0 to n_categories - 1 in a random order - append_original: bool = False - subsample_features: Optional[float] = -1 - global_transformer_name: Optional[str] = None - # if True, the transformed features (e.g. 
power transformed) are appended to the original features - - def __str__(self): - return ( - f"{self.name}_cat:{self.categorical_name}" - + ("_and_none" if self.append_original else "") - + ( - "_subsample_feats_" + str(self.subsample_features) - if self.subsample_features > 0 - else "" - ) - + ( - f"_global_transformer_{self.global_transformer_name}" - if self.global_transformer_name is not None - else "" - ) - ) - - def can_be_cached(self): - return not self.subsample_features > 0 - - def to_dict(self): - return { - k: str(v) if not isinstance(v, (str, int, float, list, dict)) else v - for k, v in asdict(self).items() - } - - class TabPFNModelSelection: """Base class for TabPFN model selection and path handling.""" _AVAILABLE_MODELS: list[str] = [] - _BASE_PATH = "/home/venv/lib/python3.9/site-packages/tabpfn/model_cache/model_hans" _VALID_TASKS = {"classification", "regression"} @classmethod @@ -139,10 +38,10 @@ def _model_name_to_path( cls, task: Literal["classification", "regression"], model_name: str ) -> str: cls._validate_model_name(model_name) - + model_name_task = "classifier" if task == "classification" else "regressor" if model_name == "default": - return f"{cls._BASE_PATH}_{task}.ckpt" - return f"{cls._BASE_PATH}_{task}_{model_name}.ckpt" + return f"tabpfn-v2-{model_name_task}.ckpt" + return f"tabpfn-v2-{model_name_task}-{model_name}.ckpt" class TabPFNClassifier(BaseEstimator, ClassifierMixin, TabPFNModelSelection): @@ -157,100 +56,87 @@ class TabPFNClassifier(BaseEstimator, ClassifierMixin, TabPFNModelSelection): def __init__( self, - model="default", + model_path: str = "default", n_estimators: int = 4, - preprocess_transforms: Tuple[PreprocessorConfig, ...] = ( - PreprocessorConfig( - "quantile_uni_coarse", - append_original=True, - categorical_name="ordinal_very_common_categories_shuffled", - global_transformer_name="svd", - subsample_features=-1, - ), - PreprocessorConfig( - "none", categorical_name="numeric", subsample_features=-1 - ), - ), - feature_shift_decoder: str = "shuffle", - normalize_with_test: bool = False, - average_logits: bool = False, - optimize_metric: Literal[ - "auroc", "roc", "auroc_ovo", "balanced_acc", "acc", "log_loss", None - ] = "roc", - transformer_predict_kwargs: Optional[dict] = None, - multiclass_decoder="shuffle", - softmax_temperature: Optional[float] = -0.1, - use_poly_features=False, - max_poly_features=50, - remove_outliers=12.0, - add_fingerprint_features=True, - subsample_samples=-1, - paper_version=False, + softmax_temperature: float = 0.9, + balance_probabilities: bool = False, + average_before_softmax: bool = False, + ignore_pretraining_limits: bool = False, + inference_precision: Literal["autocast", "auto"] = "auto", + random_state: Optional[ + Union[int, np.random.RandomState, np.random.Generator] + ] = None, + inference_config: Optional[Dict] = None, + paper_version: bool = False, ): + """Initialize TabPFNClassifier. + + Parameters + ---------- + model_path: str, default="default" + The name of the model to use. + n_estimators: int, default=4 + The number of estimators in the TabPFN ensemble. We aggregate the + predictions of `n_estimators`-many forward passes of TabPFN. Each forward + pass has (slightly) different input data. Think of this as an ensemble of + `n_estimators`-many "prompts" of the input data. + softmax_temperature: float, default=0.9 + The temperature for the softmax function. This is used to control the + confidence of the model's predictions. Lower values make the model's + predictions more confident. 
This is only applied when predicting during a + post-processing step. Set `softmax_temperature=1.0` for no effect. + balance_probabilities: bool, default=False + Whether to balance the probabilities based on the class distribution + in the training data. This can help to improve predictive performance + when the classes are highly imbalanced. This is only applied when predicting + during a post-processing step. + average_before_softmax: bool, default=False + Only used if `n_estimators > 1`. Whether to average the predictions of the + estimators before applying the softmax function. This can help to improve + predictive performance when there are many classes or when calibrating the + model's confidence. This is only applied when predicting during a + post-processing step. + ignore_pretraining_limits: bool, default=False + Whether to ignore the pre-training limits of the model. The TabPFN models + have been pre-trained on a specific range of input data. If the input data + is outside of this range, the model may not perform well. You may ignore + our limits to use the model on data outside the pre-training range. + inference_precision: "autocast" or "auto", default="auto" + The precision to use for inference. This can dramatically affect the + speed and reproducibility of the inference. + random_state: int or RandomState or RandomGenerator or None, default=None + Controls the randomness of the model. Pass an int for reproducible results. + inference_config: dict or None, default=None + Additional advanced arguments for the model interface. + paper_version: bool, default=False + If True, will use the model described in the paper, instead of the newest + version available on the API, which, e.g., handles text features better. + """ - Parameters: - model: The model string is the path to the model. - n_estimators: The number of ensemble configurations to use, the most important setting. - preprocess_transforms: A tuple of strings, specifying the preprocessing steps to use. - You can use the following strings as elements '(none|power|quantile|robust)[_all][_and_none]', where the first - part specifies the preprocessing step and the second part specifies the features to apply it to and - finally '_and_none' specifies that the original features should be added back to the features in plain. - Finally, you can combine all strings without `_all` with `_onehot` to apply one-hot encoding to the categorical - features specified with `self.fit(..., categorical_features=...)`. - feature_shift_decoder: ["shuffle", "none", "local_shuffle", "rotate", "auto_rotate"] Whether to shift features for each ensemble configuration. - normalize_with_test: If True, the test set is used to normalize the data, otherwise the training set is used only. - average_logits: Whether to average logits or probabilities for ensemble members. - optimize_metric: The optimization metric to use. - transformer_predict_kwargs: Additional keyword arguments to pass to the transformer predict method. - multiclass_decoder: The multiclass decoder to use. - softmax_temperature: A log spaced temperature, it will be applied as logits <- logits/exp(softmax_temperature). - use_poly_features: Whether to use polynomial features as the last preprocessing step. - max_poly_features: Maximum number of polynomial features to use. - remove_outliers: If not 0.0, will remove outliers from the input features, where values with a standard deviation larger than remove_outliers will be removed.
- add_fingerprint_features: If True, will add one feature of random values, that will be added to the input features. This helps discern duplicated samples in the transformer model. - subsample_samples: If not None, will use a random subset of the samples for training in each ensemble configuration. If 1 or above, this will subsample to the specified number of samples. If in 0 to 1, the value is viewed as a fraction of the training set size. - paper_version: If True, will use the model described in the paper. Otherwise, will use a better model. Default is False. - """ - self.model = model + self.model_path = model_path self.n_estimators = n_estimators - self.preprocess_transforms = preprocess_transforms - self.feature_shift_decoder = feature_shift_decoder - self.normalize_with_test = normalize_with_test - self.average_logits = average_logits - self.optimize_metric = optimize_metric - self.transformer_predict_kwargs = transformer_predict_kwargs - self.multiclass_decoder = multiclass_decoder self.softmax_temperature = softmax_temperature - self.use_poly_features = use_poly_features - self.max_poly_features = max_poly_features - self.remove_outliers = remove_outliers - self.add_fingerprint_features = add_fingerprint_features - self.subsample_samples = subsample_samples + self.balance_probabilities = balance_probabilities + self.average_before_softmax = average_before_softmax + self.ignore_pretraining_limits = ignore_pretraining_limits + self.inference_precision = inference_precision + self.random_state = random_state + self.inference_config = inference_config self.paper_version = paper_version self.last_train_set_uid = None self.last_train_X = None self.last_train_y = None - def _validate_targets_and_classes(self, y) -> np.ndarray: - from sklearn.utils import column_or_1d - from sklearn.utils.multiclass import check_classification_targets - - y_ = column_or_1d(y, warn=True) - check_classification_targets(y) - - # Get classes and encode before type conversion to guarantee correct class labels. - not_nan_mask = ~np.isnan(y) - self.classes_ = np.unique(y_[not_nan_mask]) - def fit(self, X, y): # assert init() is called init() validate_data_size(X, y) - self._validate_targets_and_classes(y) _check_paper_version(self.paper_version, X) estimator_param = self.get_params() + estimator_param["model_path"] = TabPFNClassifier._model_name_to_path( + "classification", self.model_path + ) if Config.use_server: self.last_train_set_uid = InferenceClient.fit(X, y, config=estimator_param) self.last_train_X = X @@ -263,31 +149,45 @@ def fit(self, X, y): return self def predict(self, X): - probas = self.predict_proba(X) - y = np.argmax(probas, axis=1) - y = self.classes_.take(np.asarray(y, dtype=int)) - return y + """Predict class labels for samples in X. + + Args: + X: The input samples. + + Returns: + The predicted class labels. + """ + return self._predict(X, output_type="preds") def predict_proba(self, X): + """Predict class probabilities for X. + + Args: + X: The input samples. + + Returns: + The class probabilities of the input samples. 
+ """ + return self._predict(X, output_type="probas") + + def _predict(self, X, output_type): + check_is_fitted(self) validate_data_size(X) _check_paper_version(self.paper_version, X) estimator_param = self.get_params() - if "model" in estimator_param: - # replace model by model_path since in TabPFN defines model as model_path - estimator_param["model_path"] = self._model_name_to_path( - "classification", estimator_param.pop("model") - ) + estimator_param["model_path"] = TabPFNClassifier._model_name_to_path( + "classification", self.model_path + ) - return InferenceClient.predict( + res = InferenceClient.predict( X, task="classification", train_set_uid=self.last_train_set_uid, config=estimator_param, - X_train=self.last_train_X, - y_train=self.last_train_y, - )["probas"] + predict_params={"output_type": output_type}, + ) + return res class TabPFNRegressor(BaseEstimator, RegressorMixin, TabPFNModelSelection): @@ -301,110 +201,67 @@ class TabPFNRegressor(BaseEstimator, RegressorMixin, TabPFNModelSelection): def __init__( self, - model: str = "default", + model_path: str = "default", n_estimators: int = 8, - preprocess_transforms: Tuple[PreprocessorConfig, ...] = ( - PreprocessorConfig( - "quantile_uni", - append_original=True, - categorical_name="ordinal_very_common_categories_shuffled", - global_transformer_name="svd", - ), - PreprocessorConfig("safepower", categorical_name="onehot"), - ), - feature_shift_decoder: str = "shuffle", - normalize_with_test: bool = False, - average_logits: bool = False, - optimize_metric: Literal[ - "mse", "rmse", "mae", "r2", "mean", "median", "mode", "exact_match", None - ] = "rmse", - transformer_predict_kwargs: Optional[Dict] = None, - softmax_temperature: Optional[float] = -0.1, - use_poly_features=False, - max_poly_features=50, - remove_outliers=-1, - regression_y_preprocess_transforms: Optional[ - Tuple[ - Union[ - None, - Literal[ - "safepower", - "power", - "quantile_norm", - ], - ], - ..., - ] - ] = ( - None, - "safepower", - ), - add_fingerprint_features: bool = True, - cancel_nan_borders: bool = True, - super_bar_dist_averaging: bool = False, - subsample_samples: float = -1, + softmax_temperature: float = 0.9, + average_before_softmax: bool = False, + ignore_pretraining_limits: bool = False, + inference_precision: Literal["autocast", "auto"] = "auto", + random_state: Optional[ + Union[int, np.random.RandomState, np.random.Generator] + ] = None, + inference_config: Optional[Dict] = None, + paper_version: bool = False, ): + """Initialize TabPFNRegressor. + + Parameters + ---------- + model_path: str, default="default" + The name of the model to use. + n_estimators: int, default=8 + The number of estimators in the TabPFN ensemble. We aggregate the + predictions of `n_estimators`-many forward passes of TabPFN. Each forward + pass has (slightly) different input data. Think of this as an ensemble of + `n_estimators`-many "prompts" of the input data. + softmax_temperature: float, default=0.9 + The temperature for the softmax function. This is used to control the + confidence of the model's predictions. Lower values make the model's + predictions more confident. This is only applied when predicting during a + post-processing step. Set `softmax_temperature=1.0` for no effect. + average_before_softmax: bool, default=False + Only used if `n_estimators > 1`. Whether to average the predictions of the + estimators before applying the softmax function. This can help to improve + predictive performance when calibrating the model's confidence.
This is only + applied when predicting during a post-processing step. + ignore_pretraining_limits: bool, default=False + Whether to ignore the pre-training limits of the model. The TabPFN models + have been pre-trained on a specific range of input data. If the input data + is outside of this range, the model may not perform well. You may ignore + our limits to use the model on data outside the pre-training range. + inference_precision: "autocast" or "auto", default="auto" + The precision to use for inference. This can dramatically affect the + speed and reproducibility of the inference. + random_state: int or RandomState or RandomGenerator or None, default=None + Controls the randomness of the model. Pass an int for reproducible results. + inference_config: dict or None, default=None + Additional advanced arguments for the model interface. + paper_version: bool, default=False + If True, will use the model described in the paper, instead of the newest + version available on the API, which, e.g., handles text features better. + """ - Parameters: - model: The model string is the path to the model. - n_estimators: The number of ensemble configurations to use, the most important setting. - preprocess_transforms: A tuple of strings, specifying the preprocessing steps to use. - You can use the following strings as elements '(none|power|quantile_norm|quantile_uni|quantile_uni_coarse|robust...)[_all][_and_none]', where the first - part specifies the preprocessing step (see `.preprocessing.ReshapeFeatureDistributionsStep.get_all_preprocessors()`) and the second part specifies the features to apply it to and - finally '_and_none' specifies that the original features should be added back to the features in plain. - Finally, you can combine all strings without `_all` with `_onehot` to apply one-hot encoding to the categorical - features specified with `self.fit(..., categorical_features=...)`. - feature_shift_decoder: ["shuffle", "none", "local_shuffle", "rotate", "auto_rotate"] Whether to shift features for each ensemble configuration. - normalize_with_test: If True, the test set is used to normalize the data, otherwise the training set is used only. - average_logits: Whether to average logits or probabilities for ensemble members. - optimize_metric: The optimization metric to use. - transformer_predict_kwargs: Additional keyword arguments to pass to the transformer predict method. - softmax_temperature: A log spaced temperature, it will be applied as logits <- logits/exp(softmax_temperature). - use_poly_features: Whether to use polynomial features as the last preprocessing step. - max_poly_features: Maximum number of polynomial features to use, None means unlimited. - remove_outliers: If not 0.0, will remove outliers from the input features, where values with a standard deviation - larger than remove_outliers will be removed. - regression_y_preprocess_transforms: Preprocessing transforms for the target variable. This can be one from `.preprocessing.ReshapeFeatureDistributionsStep.get_all_preprocessors()`, e.g. "power". - This can also be None to not transform the targets, beside a simple mean/variance normalization. - add_fingerprint_features: If True, will add one feature of random values, that will be added to - the input features. This helps discern duplicated samples in the transformer model. - cancel_nan_borders: Whether to ignore buckets that are tranformed to nan values by inverting a `regression_y_preprocess_transform`.
- This should be set to True, only set this to False if you know what you are doing. - super_bar_dist_averaging: If we use `regression_y_preprocess_transforms` we need to average the predictions over the different configurations. - The different configurations all come with different bar_distributions (Riemann distributions), though. - The default is for us to aggregate all bar distributions using simply scaled borders in the bar distribution, scaled by the mean and std of the target variable. - If you set this to True, a new bar distribution will be built using all the borders generated in the different configurations. - subsample_samples: If not None, will use a random subset of the samples for training in each ensemble configuration. - If 1 or above, this will subsample to the specified number of samples. - If in 0 to 1, the value is viewed as a fraction of the training set size. - paper_version: If True, will use the model described in the paper. Otherwise, will use a better model. Default is False. - """ - - if model not in self._AVAILABLE_MODELS: - raise ValueError(f"Invalid model name: {model}") - - self.model = model + self.model_path = model_path self.n_estimators = n_estimators - self.preprocess_transforms = preprocess_transforms - self.feature_shift_decoder = feature_shift_decoder - self.normalize_with_test = normalize_with_test - self.average_logits = average_logits - self.optimize_metric = optimize_metric - self.transformer_predict_kwargs = transformer_predict_kwargs self.softmax_temperature = softmax_temperature - self.use_poly_features = use_poly_features - self.max_poly_features = max_poly_features - self.remove_outliers = remove_outliers - self.regression_y_preprocess_transforms = regression_y_preprocess_transforms - self.add_fingerprint_features = add_fingerprint_features - self.cancel_nan_borders = cancel_nan_borders - self.super_bar_dist_averaging = super_bar_dist_averaging - self.subsample_samples = subsample_samples + self.average_before_softmax = average_before_softmax + self.ignore_pretraining_limits = ignore_pretraining_limits + self.inference_precision = inference_precision + self.random_state = random_state + self.inference_config = inference_config + self.paper_version = paper_version self.last_train_set_uid = None self.last_train_X = None self.last_train_y = None - self.paper_version = paper_version def fit(self, X, y): # assert init() is called @@ -414,6 +271,9 @@ def fit(self, X, y): _check_paper_version(self.paper_version, X) estimator_param = self.get_params() + estimator_param["model_path"] = TabPFNRegressor._model_name_to_path( + "regression", self.model_path + ) if Config.use_server: self.last_train_set_uid = InferenceClient.fit(X, y, config=estimator_param) self.last_train_X = X @@ -423,36 +283,59 @@ def fit(self, X, y): raise NotImplementedError( "Only server mode is supported at the moment for init(use_server=False)" ) - return self - - def predict(self, X): - full_prediction_dict = self.predict_full(X) - if self.optimize_metric in ("mse", "rmse", "r2", "mean", None): - return full_prediction_dict["mean"] - elif self.optimize_metric in ("mae", "median"): - return full_prediction_dict["median"] - elif self.optimize_metric in ("mode", "exact_match"): - return full_prediction_dict["mode"] - else: - raise ValueError(f"Optimize metric {self.optimize_metric} not supported") - def predict_full(self, X): + def predict( + self, + X: np.ndarray, + output_type: Literal[ + "mean", "median", "mode", "quantiles", "full", "main" + ] = "mean", + quantiles: 
Optional[list[float]] = None, + ) -> Union[np.ndarray, list[np.ndarray], dict[str, np.ndarray]]: + """Predict regression target for X. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + The input samples. + output_type : str, default="mean" + The type of prediction to return: + - "mean": Return mean prediction + - "median": Return median prediction + - "mode": Return mode prediction + - "quantiles": Return predictions for specified quantiles + - "full": Return full prediction details + - "main": Return main prediction metrics + quantiles : list[float] or None, default=None + Quantiles to compute when output_type="quantiles". + Default is [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + + Returns + ------- + array-like or dict + The predicted values. + """ check_is_fitted(self) validate_data_size(X) _check_paper_version(self.paper_version, X) + # Add new parameters + predict_params = { + "output_type": output_type, + "quantiles": quantiles, + } + estimator_param = self.get_params() - if "model" in estimator_param: - # replace model by model_path since in TabPFN defines model as model_path - estimator_param["model_path"] = self._model_name_to_path( - "regression", estimator_param.pop("model") - ) + estimator_param["model_path"] = TabPFNRegressor._model_name_to_path( + "regression", self.model_path + ) return InferenceClient.predict( X, task="regression", train_set_uid=self.last_train_set_uid, config=estimator_param, + predict_params=predict_params, X_train=self.last_train_X, y_train=self.last_train_y, ) @@ -480,12 +363,4 @@ def validate_data_size(X: np.ndarray, y: Union[np.ndarray, None] = None): def _check_paper_version(paper_version, X): - if paper_version: - # check if X can be converted to numerical values - try: - np.array(X, dtype=np.float32) - except ValueError: - raise ValueError( - """X must be numerical to use the paper version of the model. - Preprocess your data or use `paper_version=False`.""" - ) + pass diff --git a/tabpfn_client/service_wrapper.py b/tabpfn_client/service_wrapper.py index 96661e5..74311c3 100644 --- a/tabpfn_client/service_wrapper.py +++ b/tabpfn_client/service_wrapper.py @@ -228,6 +228,7 @@ def predict( task: Literal["classification", "regression"], train_set_uid: str, config=None, + predict_params=None, X_train=None, y_train=None, ): @@ -235,6 +236,7 @@ def predict( train_set_uid=train_set_uid, x_test=X, tabpfn_config=config, + predict_params=predict_params, task=task, X_train=X_train, y_train=y_train, diff --git a/tabpfn_client/tests/quick_test.py b/tabpfn_client/tests/quick_test.py new file mode 100644 index 0000000..4f5cc44 --- /dev/null +++ b/tabpfn_client/tests/quick_test.py @@ -0,0 +1,58 @@ +""" +TabPFN Client Example Usage +-------------------------- +Toy script to check that the TabPFN client is working. +Use the breast cancer dataset for classification and the diabetes dataset for regression, +and try various prediction types. 
+""" + +import logging +from unittest.mock import patch + +from sklearn.datasets import load_breast_cancer, load_diabetes +from sklearn.model_selection import train_test_split + +from tabpfn_client import UserDataClient +from tabpfn_client.estimator import TabPFNClassifier, TabPFNRegressor + +logging.basicConfig(level=logging.DEBUG) + + +if __name__ == "__main__": + # Patch webbrowser.open to prevent browser login + with patch("webbrowser.open", return_value=False): + use_server = True + # use_server = False + + X, y = load_breast_cancer(return_X_y=True) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.33, random_state=42 + ) + + tabpfn = TabPFNClassifier(n_estimators=3) + # print("checking estimator", check_estimator(tabpfn)) + tabpfn.fit(X_train[:99], y_train[:99]) + print("predicting") + print(tabpfn.predict(X_test)) + print("predicting_proba") + print(tabpfn.predict_proba(X_test)) + + print(UserDataClient.get_data_summary()) + + X, y = load_diabetes(return_X_y=True) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.33, random_state=42 + ) + + tabpfn = TabPFNRegressor(n_estimators=3) + # print("checking estimator", check_estimator(tabpfn)) + tabpfn.fit(X_train[:99], y_train[:99]) + print("predicting reg") + print(tabpfn.predict(X_test, output_type="mean")) + + print(UserDataClient.get_data_summary()) + # test full prediction output (replaces the removed predict_full) + print("predicting full") + print( + tabpfn.predict(X_test[:30], output_type="full", quantiles=[0.1, 0.5, 0.9]) + ) diff --git a/tabpfn_client/tests/unit/test_client.py b/tabpfn_client/tests/unit/test_client.py index 740bc72..35911cf 100644 --- a/tabpfn_client/tests/unit/test_client.py +++ b/tabpfn_client/tests/unit/test_client.py @@ -211,7 +211,7 @@ def test_predict_with_valid_train_set_and_test_set(self, mock_server): x_test=self.X_test, task="classification", ) - self.assertTrue(np.array_equal(pred["probas"], dummy_result["classification"])) + self.assertTrue(np.array_equal(pred, dummy_result["classification"])) def test_validate_response_no_error(self): response = Mock() @@ -343,7 +343,7 @@ def side_effect(*args, **kwargs): ) # The predictions should be the same - self.assertTrue(np.array_equal(pred1["probas"], pred2["probas"])) + self.assertTrue(np.array_equal(pred1, pred2)) # The predict endpoint should have been called twice self.assertEqual( @@ -436,7 +436,7 @@ def side_effect_counter(*args, **kwargs): ) # The predictions should be as expected - self.assertTrue(np.array_equal(pred["probas"], [1, 2, 3])) + self.assertTrue(np.array_equal(pred, [1, 2, 3])) # The predict endpoint should have been called twice due to retry self.assertEqual( diff --git a/tabpfn_client/tests/unit/test_tabpfn_classifier.py b/tabpfn_client/tests/unit/test_tabpfn_classifier.py index c530b6d..8dfcdb0 100644 --- a/tabpfn_client/tests/unit/test_tabpfn_classifier.py +++ b/tabpfn_client/tests/unit/test_tabpfn_classifier.py @@ -3,8 +3,6 @@ import shutil import json -import pandas as pd - import numpy as np from sklearn.datasets import load_breast_cancer @@ -57,7 +55,7 @@ def test_init_remote_classifier( mock_server.endpoints.retrieve_greeting_messages.path ).respond(200, json={"messages": []}) - mock_predict_response = [[1, 0.0], [0.9, 0.1], [0.01, 0.99]] + mock_predict_response = [1, 0, 1] predict_route = mock_server.router.post(mock_server.endpoints.predict.path) predict_route.respond( 200, @@ -73,7 +71,7 @@ tabpfn.fit(self.X_train, self.y_train) self.assertTrue(mock_prompt_and_set_token.called) y_pred =
tabpfn.predict(self.X_test) - self.assertTrue(np.all(np.argmax(mock_predict_response, axis=1) == y_pred)) + self.assertTrue(np.all(mock_predict_response == y_pred)) self.assertIn( "n_estimators%22%3A%2010", @@ -358,6 +356,84 @@ def test_data_check_on_predict_with_valid_data_pass(self): mock_predict.return_value = {"probas": np.random.rand(10, 2)} tabpfn.predict(test_X) + def test_only_allowed_parameters_passed_to_config(self): + """Test that only allowed parameters are passed to the config.""" + ALLOWED_PARAMS = { + "n_estimators", + # TODO: put it back + # "categorical_features_indices", + "softmax_temperature", + "average_before_softmax", + "ignore_pretraining_limits", + "inference_precision", + "random_state", + "inference_config", + "model_path", + "balance_probabilities", + "paper_version", + } + + # Create classifier with various parameters + classifier = TabPFNClassifier( + n_estimators=5, + softmax_temperature=0.8, + paper_version=True, + random_state=42, + balance_probabilities=True, + ) + + # Skip fitting + classifier.fitted_ = True + classifier.last_train_set_uid = "dummy_uid" + + test_X = np.random.randn(10, 5) + + # Mock predict and capture config + with patch.object(InferenceClient, "predict") as mock_predict: + mock_predict.return_value = np.random.rand(10, 2) + classifier.predict(test_X) + + # Get the config that was passed to predict + actual_config = mock_predict.call_args[1]["config"] + + # Check that only allowed parameters are present + config_params = set(actual_config.keys()) + unexpected_params = config_params - ALLOWED_PARAMS + missing_params = ALLOWED_PARAMS - config_params + + self.assertEqual( + unexpected_params, + set(), + f"Found unexpected parameters in config: {unexpected_params}", + ) + self.assertEqual( + missing_params, + set(), + f"Missing required parameters in config: {missing_params}", + ) + + def test_predict_params_output_type(self): + """Test that predict_params contains correct output_type.""" + classifier = TabPFNClassifier() + classifier.fitted_ = True # Skip fitting + test_X = np.random.randn(10, 5) + + # Test predict() sets output_type to "preds" + with patch.object(InferenceClient, "predict") as mock_predict: + mock_predict.return_value = np.random.rand(10) + classifier.predict(test_X) + + predict_params = mock_predict.call_args[1]["predict_params"] + self.assertEqual(predict_params, {"output_type": "preds"}) + + # Test predict_proba() sets output_type to "probas" + with patch.object(InferenceClient, "predict") as mock_predict: + mock_predict.return_value = np.random.rand(10, 2) + classifier.predict_proba(test_X) + + predict_params = mock_predict.call_args[1]["predict_params"] + self.assertEqual(predict_params, {"output_type": "probas"}) + class TestTabPFNModelSelection(unittest.TestCase): def setUp(self): @@ -390,17 +466,15 @@ def test_validate_model_name_with_invalid_model_raises_error(self): TabPFNClassifier._validate_model_name("invalid_model") def test_model_name_to_path_returns_expected_path(self): - base_path = TabPFNClassifier._BASE_PATH - # Test default model path - expected_default_path = f"{base_path}_classification.ckpt" + expected_default_path = "tabpfn-v2-classifier.ckpt" self.assertEqual( TabPFNClassifier._model_name_to_path("classification", "default"), expected_default_path, ) # Test specific model path - expected_specific_path = f"{base_path}_classification_gn2p4bpt.ckpt" + expected_specific_path = "tabpfn-v2-classifier-gn2p4bpt.ckpt" self.assertEqual( TabPFNClassifier._model_name_to_path("classification", "gn2p4bpt"), 
expected_specific_path, @@ -415,7 +489,7 @@ def test_predict_proba_uses_correct_model_path(self): X = np.random.rand(10, 5) y = np.random.randint(0, 2, 10) - tabpfn = TabPFNClassifier(model="gn2p4bpt") + tabpfn = TabPFNClassifier(model_path="gn2p4bpt") # Mock the inference client with patch.object(InferenceClient, "predict") as mock_predict: @@ -430,9 +504,7 @@ def test_predict_proba_uses_correct_model_path(self): # Verify the model path was correctly passed to predict predict_kwargs = mock_predict.call_args[1] - expected_model_path = ( - f"{TabPFNClassifier._BASE_PATH}_classification_gn2p4bpt.ckpt" - ) + expected_model_path = "tabpfn-v2-classifier-gn2p4bpt.ckpt" self.assertEqual( predict_kwargs["config"]["model_path"], expected_model_path @@ -461,37 +533,3 @@ def test_paper_version_behavior(self, mock_predict, mock_fit): tabpfn_false.fit(X, y) y_pred_false = tabpfn_false.predict(test_X) self.assertIsNotNone(y_pred_false) - - @patch.object(InferenceClient, "fit", return_value="dummy_uid") - @patch.object( - InferenceClient, "predict", return_value={"probas": np.random.rand(10, 2)} - ) - def test_check_paper_version_with_non_numerical_data_raises_error( - self, mock_predict, mock_fit - ): - # Create a TabPFNClassifier with paper_version=True - tabpfn = TabPFNClassifier(paper_version=True) - - # Create non-numerical data - X = pd.DataFrame({"feature1": ["a", "b", "c"], "feature2": ["d", "e", "f"]}) - y = np.array([0, 1, 0]) - - with self.assertRaises(ValueError) as context: - tabpfn.fit(X, y) - - self.assertIn( - "X must be numerical to use the paper version of the model", - str(context.exception), - ) - - # check that it works with paper_version=False - tabpfn = TabPFNClassifier(paper_version=False) - tabpfn.fit(X, y) - - # check that paper_version=True works with numerical data even with the wrong type - X = np.random.rand(10, 5).astype(str) - y = np.random.randint(0, 2, 10) - tabpfn = TabPFNClassifier(paper_version=True) - tabpfn.fit(X, y) - X = pd.DataFrame(X).astype(str) - tabpfn.predict(X) diff --git a/tabpfn_client/tests/unit/test_tabpfn_regressor.py b/tabpfn_client/tests/unit/test_tabpfn_regressor.py index 378dfb9..ea3a610 100644 --- a/tabpfn_client/tests/unit/test_tabpfn_regressor.py +++ b/tabpfn_client/tests/unit/test_tabpfn_regressor.py @@ -14,7 +14,6 @@ from tabpfn_client.constants import CACHE_DIR from tabpfn_client import config import json -import pandas as pd class TestTabPFNRegressorInit(unittest.TestCase): @@ -54,30 +53,29 @@ def test_init_remote_regressor( mock_server.endpoints.retrieve_greeting_messages.path ).respond(200, json={"messages": []}) - mock_predict_response = { + mock_predict_responses = { "mean": [100, 200, 300], "median": [110, 210, 310], "mode": [120, 220, 320], } - predict_route = mock_server.router.post(mock_server.endpoints.predict.path) - predict_route.respond( - 200, - content=f'data: {json.dumps({"event": "result", "data": {"regression": mock_predict_response, "test_set_uid": "6"}})}\n\n', - headers={"Content-Type": "text/event-stream"}, - ) + for metric, response in mock_predict_responses.items(): + predict_route = mock_server.router.post(mock_server.endpoints.predict.path) + predict_route.respond( + 200, + content=f'data: {json.dumps({"event": "result", "data": {"regression": response, "test_set_uid": "6"}})}\n\n', + headers={"Content-Type": "text/event-stream"}, + ) - init(use_server=True) - self.assertTrue(mock_prompt_and_set_token.called) + init(use_server=True) + self.assertTrue(mock_prompt_and_set_token.called) tabpfn = 
TabPFNRegressor(n_estimators=10) self.assertRaises(NotFittedError, tabpfn.predict, self.X_test) tabpfn.fit(self.X_train, self.y_train) self.assertTrue(mock_prompt_and_set_token.called) - for metric in ["mean", "median", "mode"]: - tabpfn.optimize_metric = metric - y_pred = tabpfn.predict(self.X_test) - self.assertTrue(np.all(np.array(mock_predict_response[metric]) == y_pred)) + y_pred = tabpfn.predict(self.X_test, output_type=metric) + self.assertTrue(np.all(np.array(response) == y_pred)) self.assertIn( "n_estimators%22%3A%2010", @@ -85,17 +83,6 @@ def test_init_remote_regressor( "check that n_estimators is passed to the server", ) - def test_valid_model_config(self): - # Test with valid model configuration - model_name = TabPFNRegressor.list_available_models()[0] - valid_config = TabPFNRegressor(model=model_name) - self.assertEqual(valid_config.model, model_name) - - def test_invalid_model_config(self): - # Test with invalid model configuration - with self.assertRaises(ValueError): - TabPFNRegressor(model="invalid_model_name") - @with_mock_server() def test_reuse_saved_access_token(self, mock_server): # mock connection and authentication @@ -371,6 +358,85 @@ def test_data_check_on_predict_with_valid_data_pass(self): mock_predict.return_value = {"mean": np.random.randn(10)} tabpfn.predict(test_X) + def test_only_allowed_parameters_passed_to_config(self): + """Test that only allowed parameters are passed to the config.""" + ALLOWED_PARAMS = { + "n_estimators", + # TODO: put it back + # "categorical_features_indices", + "softmax_temperature", + "average_before_softmax", + "ignore_pretraining_limits", + "inference_precision", + "random_state", + "inference_config", + "model_path", + "paper_version", + } + + # Create regressor with various parameters + regressor = TabPFNRegressor( + n_estimators=8, + softmax_temperature=0.9, + paper_version=True, + random_state=42, + ) + + # Skip fitting + regressor.fitted_ = True + regressor.last_train_set_uid = "dummy_uid" + + test_X = np.random.randn(10, 5) + + # Mock predict and capture config + with patch.object(InferenceClient, "predict") as mock_predict: + mock_predict.return_value = {"mean": np.random.randn(10)} + regressor.predict(test_X) + + # Get the config that was passed to predict + actual_config = mock_predict.call_args[1]["config"] + + # Check that only allowed parameters are present + config_params = set(actual_config.keys()) + unexpected_params = config_params - ALLOWED_PARAMS + missing_params = ALLOWED_PARAMS - config_params + + self.assertEqual( + unexpected_params, + set(), + f"Found unexpected parameters in config: {unexpected_params}", + ) + self.assertEqual( + missing_params, + set(), + f"Missing required parameters in config: {missing_params}", + ) + + def test_predict_params_output_type(self): + """Test that predict_params contains correct output_type and quantiles.""" + regressor = TabPFNRegressor() + regressor.fitted_ = True # Skip fitting + test_X = np.random.randn(10, 5) + + # Test default predict() sets output_type to "mean" + with patch.object(InferenceClient, "predict") as mock_predict: + mock_predict.return_value = {"mean": np.random.randn(10)} + regressor.predict(test_X) + + predict_params = mock_predict.call_args[1]["predict_params"] + self.assertEqual(predict_params, {"output_type": "mean", "quantiles": None}) + + # Test predict() with quantiles + with patch.object(InferenceClient, "predict") as mock_predict: + mock_predict.return_value = {"quantiles": np.random.randn(10, 3)} + quantiles = [0.1, 0.5, 0.9] + 
regressor.predict(test_X, output_type="quantiles", quantiles=quantiles) + + predict_params = mock_predict.call_args[1]["predict_params"] + self.assertEqual( + predict_params, {"output_type": "quantiles", "quantiles": quantiles} + ) + class TestTabPFNModelSelection(unittest.TestCase): def setUp(self): @@ -396,17 +462,15 @@ def test_validate_model_name_with_invalid_model_raises_error(self): TabPFNRegressor._validate_model_name("invalid_model") def test_model_name_to_path_returns_expected_path(self): - base_path = TabPFNRegressor._BASE_PATH - # Test default model path - expected_default_path = f"{base_path}_regression.ckpt" + expected_default_path = "tabpfn-v2-regressor.ckpt" self.assertEqual( TabPFNRegressor._model_name_to_path("regression", "default"), expected_default_path, ) # Test specific model path - expected_specific_path = f"{base_path}_regression_2noar4o2.ckpt" + expected_specific_path = "tabpfn-v2-regressor-2noar4o2.ckpt" self.assertEqual( TabPFNRegressor._model_name_to_path("regression", "2noar4o2"), expected_specific_path, @@ -417,16 +481,19 @@ def test_model_name_to_path_with_invalid_model_raises_error(self): TabPFNRegressor._model_name_to_path("regression", "invalid_model") def test_predict_uses_correct_model_path(self): + # First verify available models are as expected + expected_models = ["default", "2noar4o2", "5wof9ojf", "09gpqh39", "wyl4o83o"] + self.assertEqual(TabPFNRegressor._AVAILABLE_MODELS, expected_models) + # Setup X = np.random.rand(10, 5) y = np.random.rand(10) - tabpfn = TabPFNRegressor(model="2noar4o2") + tabpfn = TabPFNRegressor(model_path="2noar4o2") # Mock the inference client with patch.object(InferenceClient, "predict") as mock_predict: mock_predict.return_value = {"mean": np.random.rand(10)} - with patch.object(InferenceClient, "fit") as mock_fit: mock_fit.return_value = "dummy_uid" @@ -436,9 +503,7 @@ def test_predict_uses_correct_model_path(self): # Verify the model path was correctly passed to predict predict_kwargs = mock_predict.call_args[1] - expected_model_path = ( - f"{TabPFNRegressor._BASE_PATH}_regression_2noar4o2.ckpt" - ) + expected_model_path = "tabpfn-v2-regressor-2noar4o2.ckpt" self.assertEqual( predict_kwargs["config"]["model_path"], expected_model_path @@ -465,35 +530,3 @@ def test_paper_version_behavior(self, mock_predict, mock_fit): tabpfn_false.fit(X, y) y_pred_false = tabpfn_false.predict(test_X) self.assertIsNotNone(y_pred_false) - - @patch.object(InferenceClient, "fit", return_value="dummy_uid") - @patch.object(InferenceClient, "predict", return_value={"mean": np.random.rand(10)}) - def test_check_paper_version_with_non_numerical_data_raises_error( - self, mock_predict, mock_fit - ): - # Create a TabPFNRegressor with paper_version=True - tabpfn = TabPFNRegressor(paper_version=True) - - # Create non-numerical data - X = pd.DataFrame({"feature1": ["a", "b", "c"], "feature2": ["d", "e", "f"]}) - y = np.array([0.1, 0.2, 0.3]) - - with self.assertRaises(ValueError) as context: - tabpfn.fit(X, y) - - self.assertIn( - "X must be numerical to use the paper version of the model", - str(context.exception), - ) - - # check that it works with paper_version=False - tabpfn = TabPFNRegressor(paper_version=False) - tabpfn.fit(X, y) - - # check that paper_version=True works with numerical data even with the wrong type - X = np.random.rand(10, 5).astype(str) - y = np.random.rand(10) # Continuous target for regression - tabpfn = TabPFNRegressor(paper_version=True) - tabpfn.fit(X, y) - X = pd.DataFrame(X).astype(str) - tabpfn.predict(X)
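
A quick sanity check of the new `tabpfn-v2-*` checkpoint naming introduced in `_model_name_to_path`; the expected strings below are taken directly from the updated unit tests in this diff:

# Checkpoint-name mapping: "default" maps to the bare v2 file,
# named models get a suffixed file.
from tabpfn_client.estimator import TabPFNClassifier, TabPFNRegressor

assert (
    TabPFNClassifier._model_name_to_path("classification", "default")
    == "tabpfn-v2-classifier.ckpt"
)
assert (
    TabPFNClassifier._model_name_to_path("classification", "gn2p4bpt")
    == "tabpfn-v2-classifier-gn2p4bpt.ckpt"
)
assert (
    TabPFNRegressor._model_name_to_path("regression", "2noar4o2")
    == "tabpfn-v2-regressor-2noar4o2.ckpt"
)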
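
Since `predict_full` is removed, downstream code now selects the statistic via `predict(..., output_type=...)`. A minimal migration sketch, assuming `init()` has been called with a valid API token (the exact contents of the "full" dict depend on the server response):

import numpy as np
from tabpfn_client.config import init
from tabpfn_client.estimator import TabPFNRegressor

init()  # authenticate against the API (server mode)

X_train, y_train = np.random.rand(50, 4), np.random.rand(50)
X_test = np.random.rand(10, 4)

reg = TabPFNRegressor(n_estimators=3)
reg.fit(X_train, y_train)

mean = reg.predict(X_test)  # output_type defaults to "mean"
median = reg.predict(X_test, output_type="median")
qs = reg.predict(X_test, output_type="quantiles", quantiles=[0.1, 0.5, 0.9])
full = reg.predict(X_test, output_type="full")  # dict of prediction details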
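
On the wire, `predict_params` travels as a JSON string next to `train_set_uid` and `task`, mirroring the new `params` dict built in `client.predict`; the values below are illustrative only:

import json

predict_params = {"output_type": "quantiles", "quantiles": [0.1, 0.5, 0.9]}
params = {
    "train_set_uid": "dummy_uid",  # uid returned by a previous fit
    "task": "regression",
    "predict_params": json.dumps(predict_params),
}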
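
The new result handling in `client.predict` can be summarized by this standalone sketch (`_to_arrays` is a hypothetical name; the real code mutates `result` in place):

import numpy as np

def _to_arrays(result):
    # Non-dict payloads (class labels, probabilities, point predictions)
    # become a single ndarray; dict payloads (e.g. output_type="full")
    # have their list values converted key by key.
    if not isinstance(result, dict):
        return np.array(result)
    return {k: np.array(v) if isinstance(v, list) else v for k, v in result.items()}

assert _to_arrays([1, 0, 1]).shape == (3,)
assert isinstance(_to_arrays({"mean": [1.0, 2.0]})["mean"], np.ndarray)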