diff --git a/tests/eval/test_token_positions.py b/tests/eval/test_token_positions.py index 826e7c04..c584b931 100644 --- a/tests/eval/test_token_positions.py +++ b/tests/eval/test_token_positions.py @@ -13,7 +13,7 @@ def mock_data(): {"tokens": [[1, 2, 3], [4, 5, 6], [7, 8, 9]]} ).with_format("torch") selected_tokens = [2, 4, 6, 8] - metrics = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]) + metrics = torch.tensor([[-1, 0.45, -0.33], [-1.31, 2.3, 0.6], [0.2, 0.8, 0.1]]) return token_ids, selected_tokens, metrics @@ -26,19 +26,29 @@ def test_get_all_tok_metrics_in_label(mock_data): ) # key: (prompt_pos, tok_pos), value: logprob expected = { - (0, 1): 0.2, - (1, 0): 0.4, + (0, 1): 0.45, + (1, 0): -1.31, (1, 2): 0.6, (2, 1): 0.8, } - # use isclose to compare floating point numbers + + # compare keys + assert result.keys() == expected.keys() + # compare values for k in result: assert isclose(cast(float, result[k]), expected[k], rel_tol=1e-6) # type: ignore # test with quantile filtering result_q = get_all_tok_metrics_in_label( - token_ids["tokens"], selected_tokens, metrics, q_start=0.3, q_end=1.0 + token_ids["tokens"], selected_tokens, metrics, q_start=0.6, q_end=1.0 ) - expected_q = {(1, 2): 0.6, (2, 1): 0.8, (1, 0): 0.4} + expected_q = { + (1, 2): 0.6, + (2, 1): 0.8, + } + + # compare keys + assert result_q.keys() == expected_q.keys() + # compare values for k in result_q: assert isclose(cast(float, result_q[k]), expected_q[k], rel_tol=1e-6) # type: ignore diff --git a/tests/eval/test_utils.py b/tests/eval/test_utils.py index ad0f54b8..a259d16b 100644 --- a/tests/eval/test_utils.py +++ b/tests/eval/test_utils.py @@ -61,7 +61,22 @@ def test_load_validation_dataset(): def test_dict_filter_quantile(): d = {1: 0.1, 2: 0.2, 3: 0.3, 4: 0.4, 5: 0.5} result = dict_filter_quantile(d, 0.2, 0.6) - expected = {2: 0.2, 3: 0.3, 4: 0.4} + expected = {2: 0.2, 3: 0.3} + + # compare keys + assert result.keys() == expected.keys() + # compare values + for k in result: + assert isclose(result[k], expected[k], rel_tol=1e-6) + + # test with negative values + d = {1: -0.1, 2: -0.2, 3: -0.3, 4: -0.4, 5: -0.5} + result = dict_filter_quantile(d, 0.2, 0.6) + expected = {3: -0.3, 4: -0.4} + + # compare keys + assert result.keys() == expected.keys() + # compare values for k in result: assert isclose(result[k], expected[k], rel_tol=1e-6)