Skip to content

Commit

Permalink
Fix: Client
Browse files Browse the repository at this point in the history
  • Loading branch information
Anshul Gupta committed Jan 25, 2024
1 parent c39cff1 commit 3f20466
Showing 1 changed file with 67 additions and 67 deletions.
134 changes: 67 additions & 67 deletions tabpfn_client/tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ class TabPFNConfig:
g_tabpfn_config = TabPFNConfig()


# def init(use_server=True):
def init():
use_server = True
def init(use_server=True):
# def init():
use_server = use_server
global g_tabpfn_config

if use_server:
Expand Down Expand Up @@ -76,50 +76,50 @@ def reset():


class TabPFNClassifier(BaseEstimator, ClassifierMixin):
def __init__(self):
# Configuration for TabPFNClassifier is still under development.
pass

# def __init__(
# self,
# model=None,
# device="cpu",
# base_path=Path(__file__).parent.parent.resolve(),
# model_string="",
# batch_size_inference=4,
# fp16_inference=False,
# inference_mode=True,
# c=None,
# N_ensemble_configurations=10,
# preprocess_transforms=("none", "power_all"),
# feature_shift_decoder=False,
# normalize_with_test=False,
# average_logits=False,
# categorical_features=tuple(),
# optimize_metric=None,
# seed=None,
# transformer_predict_kwargs_init=None,
# multiclass_decoder="permutation",
# ):
# # config for tabpfn
# self.model = model
# self.device = device
# self.base_path = base_path
# self.model_string = model_string
# self.batch_size_inference = batch_size_inference
# self.fp16_inference = fp16_inference
# self.inference_mode = inference_mode
# self.c = c
# self.N_ensemble_configurations = N_ensemble_configurations
# self.preprocess_transforms = preprocess_transforms
# self.feature_shift_decoder = feature_shift_decoder
# self.normalize_with_test = normalize_with_test
# self.average_logits = average_logits
# self.categorical_features = categorical_features
# self.optimize_metric = optimize_metric
# self.seed = seed
# self.transformer_predict_kwargs_init = transformer_predict_kwargs_init
# self.multiclass_decoder = multiclass_decoder
# def __init__(self):
# # Configuration for TabPFNClassifier is still under development.
# pass

    def __init__(
        self,
        model=None,
        device="cpu",
        base_path=Path(__file__).parent.parent.resolve(),
        model_string="",
        batch_size_inference=4,
        fp16_inference=False,
        inference_mode=True,
        c=None,
        N_ensemble_configurations=10,
        preprocess_transforms=("none", "power_all"),
        feature_shift_decoder=False,
        normalize_with_test=False,
        average_logits=False,
        categorical_features=tuple(),
        optimize_metric=None,
        seed=None,
        transformer_predict_kwargs_init=None,
        multiclass_decoder="permutation",
    ):
        """Record the TabPFN configuration on the estimator.

        Following the scikit-learn estimator convention, this constructor
        only stores each argument verbatim as an attribute of the same
        name; no validation, I/O, or model construction happens here.

        NOTE(review): only a subset of these parameters is actually
        forwarded to the underlying classifier in ``fit`` (several are
        commented out there as "not supported until new TabPFN interface
        is released") — the rest are stored but currently unused.
        """
        # config for tabpfn
        self.model = model
        self.device = device
        # Defaults to the package's parent directory, resolved at import
        # time of this module (evaluated once, when the class is defined).
        self.base_path = base_path
        self.model_string = model_string
        self.batch_size_inference = batch_size_inference
        self.fp16_inference = fp16_inference
        self.inference_mode = inference_mode
        self.c = c
        self.N_ensemble_configurations = N_ensemble_configurations
        self.preprocess_transforms = preprocess_transforms
        self.feature_shift_decoder = feature_shift_decoder
        self.normalize_with_test = normalize_with_test
        self.average_logits = average_logits
        self.categorical_features = categorical_features
        self.optimize_metric = optimize_metric
        self.seed = seed
        self.transformer_predict_kwargs_init = transformer_predict_kwargs_init
        self.multiclass_decoder = multiclass_decoder

def fit(self, X, y):
# assert init() is called
Expand All @@ -130,26 +130,26 @@ def fit(self, X, y):
if not hasattr(self, "classifier"):
# arguments that are commented out are not used at the moment
# (not supported until new TabPFN interface is released)
# classifier_cfg = {
# # "model": self.model,
# "device": self.device,
# "base_path": self.base_path,
# "model_string": self.model_string,
# "batch_size_inference": self.batch_size_inference,
# # "fp16_inference": self.fp16_inference,
# # "inference_mode": self.inference_mode,
# # "c": self.c,
# "N_ensemble_configurations": self.N_ensemble_configurations,
# # "preprocess_transforms": self.preprocess_transforms,
# "feature_shift_decoder": self.feature_shift_decoder,
# # "normalize_with_test": self.normalize_with_test,
# # "average_logits": self.average_logits,
# # "categorical_features": self.categorical_features,
# # "optimize_metric": self.optimize_metric,
# "seed": self.seed,
# # "transformer_predict_kwargs_init": self.transformer_predict_kwargs_init,
# "multiclass_decoder": self.multiclass_decoder
# }
classifier_cfg = {
# "model": self.model,
"device": self.device,
"base_path": self.base_path,
"model_string": self.model_string,
"batch_size_inference": self.batch_size_inference,
# "fp16_inference": self.fp16_inference,
# "inference_mode": self.inference_mode,
# "c": self.c,
"N_ensemble_configurations": self.N_ensemble_configurations,
# "preprocess_transforms": self.preprocess_transforms,
"feature_shift_decoder": self.feature_shift_decoder,
# "normalize_with_test": self.normalize_with_test,
# "average_logits": self.average_logits,
# "categorical_features": self.categorical_features,
# "optimize_metric": self.optimize_metric,
"seed": self.seed,
# "transformer_predict_kwargs_init": self.transformer_predict_kwargs_init,
"multiclass_decoder": self.multiclass_decoder
}
classifier_cfg = {}

if g_tabpfn_config.use_server:
Expand Down

0 comments on commit 3f20466

Please sign in to comment.