Skip to content

Commit

Permalink
Update test
Browse files Browse the repository at this point in the history
  • Loading branch information
liam-sbhoo committed Dec 3, 2023
1 parent ec002d3 commit d0a567b
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 16 deletions.
4 changes: 2 additions & 2 deletions tabpfn_client/tests/integration/test_tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_use_remote_tabpfn_classifier(self, mock_server):
# mock prediction
mock_server.router.post(mock_server.endpoints.predict.path).respond(
200,
json={"y_pred": LocalTabPFNClassifier().fit(self.X_train, self.y_train).predict(self.X_test).tolist()}
json={"res": LocalTabPFNClassifier().fit(self.X_train, self.y_train).predict(self.X_test).tolist()}
)
pred = tabpfn.predict(self.X_test)
self.assertEqual(pred.shape[0], self.X_test.shape[0])
Expand Down Expand Up @@ -78,7 +78,7 @@ def test_use_remote_tabpfn_classifier_with_explicit_tabpfn_config(self, mock_ser
dummy_pred = "content doesn't matter"
mock_server.router.post(mock_server.endpoints.predict.path).respond(
200,
json={"y_pred": dummy_pred}
json={"res": dummy_pred}
)
pred = tabpfn.predict(self.X_test)

Expand Down
4 changes: 2 additions & 2 deletions tabpfn_client/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ def test_predict_with_valid_train_set_and_test_set(self, mock_server):

self.client.upload_train_set(self.X_train, self.y_train)

dummy_result = {"y_pred": [1, 2, 3]}
dummy_result = {"res": [1, 2, 3]}
mock_server.router.post(mock_server.endpoints.predict.path).respond(
200, json=dummy_result)

pred = self.client.predict(
train_set_uid=dummy_json["train_set_uid"],
x_test=self.X_test
)
self.assertTrue(np.array_equal(pred, dummy_result["y_pred"]))
self.assertTrue(np.array_equal(pred, dummy_result["res"]))
24 changes: 12 additions & 12 deletions tabpfn_client/tests/unit/test_service_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,45 +290,45 @@ def test_predict_with_explicit_tabpfn_config(self, mock_server):
# mock predict response
mock_server.router.post(mock_server.endpoints.predict.path).respond(
200,
json={"y_pred": [0]}
json={"res": [0]}
)

# assert no exception is raised
tabpfn_config = {"dummy_config": "dummy_value"}
self.inference_client.config_tabpfn(**tabpfn_config)
self.assertEqual([0], self.inference_client.predict(self.dummy_x))
self.assertEqual([0], self.inference_client.predict(
self.dummy_x, with_proba=False, tabpfn_config=tabpfn_config)
)

@with_mock_server()
def test_predict_with_default_tabpfn_config(self, mock_server):
# mock predict response
mock_server.router.post(mock_server.endpoints.predict.path).respond(
200,
json={"y_pred": [0]}
json={"res": [0]}
)

# assert no exception is raised
self.assertEqual([0],self.inference_client.predict(self.dummy_x))
self.assertEqual([0], self.inference_client.predict(self.dummy_x, with_proba=False))

@with_mock_server()
def test_predict_proba_with_explicit_tabpfn_config(self, mock_server):
# mock predict_proba response
mock_server.router.post(mock_server.endpoints.predict_proba.path).respond(
mock_server.router.post(mock_server.endpoints.predict.path).respond(
200,
json={"y_pred_proba": [[0]]}
json={"res": [[0]]}
)

# assert no exception is raised
tabpfn_config = {"dummy_config": "dummy_value"}
self.inference_client.config_tabpfn(**tabpfn_config)
self.assertEqual([[0]], self.inference_client.predict_proba(self.dummy_x))
self.assertEqual([[0]], self.inference_client.predict(self.dummy_x, tabpfn_config=tabpfn_config))

@with_mock_server()
def test_predict_proba_with_default_tabpfn_config(self, mock_server):
# mock predict_proba response
mock_server.router.post(mock_server.endpoints.predict_proba.path).respond(
mock_server.router.post(mock_server.endpoints.predict.path).respond(
200,
json={"y_pred_proba": [[0]]}
json={"res": [[0]]}
)

# assert no exception is raised
self.assertEqual([[0]], self.inference_client.predict_proba(self.dummy_x))
self.assertEqual([[0]], self.inference_client.predict(self.dummy_x))

0 comments on commit d0a567b

Please sign in to comment.