Skip to content

Commit

Permalink
Update test_engine.py
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS authored Dec 22, 2024
1 parent 60b0155 commit bf37d78
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3909,12 +3909,14 @@ def test_predict_regression_output_shape():
# 1-round model
bst = lgb.train(params, dtrain, num_boost_round=1)
assert bst.predict(X).shape == (n_samples,)
assert bst.predict(X, raw_score=True).shape == (n_samples,)
assert bst.predict(X, pred_contrib=True).shape == (n_samples, n_features + 1)
assert bst.predict(X, pred_leaf=True).shape == (n_samples, 1)

# 2-round model
bst = lgb.train(params, dtrain, num_boost_round=2)
assert bst.predict(X).shape == (n_samples,)
assert bst.predict(X, raw_score=True).shape == (n_samples,)
assert bst.predict(X, pred_contrib=True).shape == (n_samples, n_features + 1)
assert bst.predict(X, pred_leaf=True).shape == (n_samples, 2)

Expand All @@ -3929,12 +3931,14 @@ def test_predict_binary_classification_output_shape():
# 1-round model
bst = lgb.train(params, dtrain, num_boost_round=1)
assert bst.predict(X).shape == (n_samples,)
assert bst.predict(X, raw_score=True).shape == (n_samples,)
assert bst.predict(X, pred_contrib=True).shape == (n_samples, n_features + 1)
assert bst.predict(X, pred_leaf=True).shape == (n_samples, 1)

# 2-round model
bst = lgb.train(params, dtrain, num_boost_round=2)
assert bst.predict(X).shape == (n_samples,)
assert bst.predict(X, raw_score=True).shape == (n_samples,)
assert bst.predict(X, pred_contrib=True).shape == (n_samples, n_features + 1)
assert bst.predict(X, pred_leaf=True).shape == (n_samples, 2)

Expand All @@ -3950,12 +3954,14 @@ def test_predict_multiclass_classification_output_shape():
# 1-round model
bst = lgb.train(params, dtrain, num_boost_round=1)
assert bst.predict(X).shape == (n_samples, n_classes)
assert bst.predict(X, raw_score=True).shape == (n_samples, n_classes)
assert bst.predict(X, pred_contrib=True).shape == (n_samples, n_classes * (n_features + 1))
assert bst.predict(X, pred_leaf=True).shape == (n_samples, n_classes)

# 2-round model
bst = lgb.train(params, dtrain, num_boost_round=2)
assert bst.predict(X).shape == (n_samples, n_classes)
assert bst.predict(X, raw_score=True).shape == (n_samples, n_classes)
assert bst.predict(X, pred_contrib=True).shape == (n_samples, n_classes * (n_features + 1))
assert bst.predict(X, pred_leaf=True).shape == (n_samples, n_classes * 2)

Expand Down

0 comments on commit bf37d78

Please sign in to comment.