Skip to content

Commit

Permalink
ADD: complementary tests
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentAuriau committed Oct 23, 2024
1 parent 131faea commit 55cc305
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 2 deletions.
4 changes: 3 additions & 1 deletion tests/unit_tests/models/test_nlogit.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,12 @@ def test_fit_adam_specific_specification():
epochs=1,
batch_size=-1,
shared_gammas_over_nests=False,
regulariation="l1",
regularization_strength=0.0001,
)

model.instantiate(test_dataset)
nll_b = model.evaluate(test_dataset)
model.fit(test_dataset, get_report=False)
model.fit(test_dataset, get_report=True)
nll_a = model.evaluate(test_dataset)
assert nll_a < nll_b
60 changes: 59 additions & 1 deletion tests/unit_tests/models/test_rumnet_unit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Basic tests for the RUMnet model."""

import numpy as np
import pytest
import tensorflow as tf

from choice_learn.data import ChoiceDataset
Expand Down Expand Up @@ -121,6 +122,44 @@ def test_assortment_parallel_dense():
assert w.shape == (8, 2)


def test_paper_rumnet_errors():
"""Tests errors raisded by PaperRUMnet model."""
with pytest.raises(ValueError):
model = PaperRUMnet(
num_products_features=0,
num_customer_features=3,
width_eps_x=3,
depth_eps_x=2,
heterogeneity_x=2,
width_eps_z=3,
depth_eps_z=2,
heterogeneity_z=2,
width_u=3,
depth_u=1,
tol=1e-5,
optimizer="adam",
lr=0.001,
)
model.instantiate()
with pytest.raises(ValueError):
model = PaperRUMnet(
num_products_features=2,
num_customer_features=0,
width_eps_x=3,
depth_eps_x=2,
heterogeneity_x=2,
width_eps_z=3,
depth_eps_z=2,
heterogeneity_z=2,
width_u=3,
depth_u=1,
tol=1e-5,
optimizer="adam",
lr=0.001,
)
model.instantiate()


def test_paper_rumnet():
"""Tests the PaperRUMnet model."""
tf.config.run_functions_eagerly(True)
Expand Down Expand Up @@ -184,6 +223,14 @@ def test_cpu_rumnet():
None,
)[1].shape == (4, 3)

assert model.batch_predict(
(dataset.shared_features_by_choice[0],),
(dataset.items_features_by_choice[0],),
np.ones((4, 3)),
dataset.choices,
None,
)[1].shape == (4, 3)


def test_gpu_rumnet():
"""Tests the GPURUMNet model."""
Expand All @@ -205,7 +252,6 @@ def test_gpu_rumnet():
optimizer="adam",
lr=0.001,
)
print(dataset.items_features_by_choice[0].dtype)
model.instantiate()
assert model.batch_predict(
dataset.shared_features_by_choice[0],
Expand All @@ -214,3 +260,15 @@ def test_gpu_rumnet():
dataset.choices,
None,
)[1].shape == (4, 3)
nll_a = model.evaluate(dataset)
model.fit(dataset)
nll_b = model.evaluate(dataset)
assert nll_b < nll_a

assert model.batch_predict(
(dataset.shared_features_by_choice[0],),
(dataset.items_features_by_choice[0],),
np.ones((4, 3)),
dataset.choices,
None,
)[1].shape == (4, 3)

0 comments on commit 55cc305

Please sign in to comment.