Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ancalita committed Sep 19, 2023
1 parent 138bd07 commit 12973fa
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 12 deletions.
2 changes: 2 additions & 0 deletions tests/cdu/stack/test_dialogue_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def test_dialogue_stack_as_dict():
"frame_id": "some-other-id",
"step_id": "__start__",
"flow_id": "pattern_ask_collect_information",
"rejections": None,
},
]

Expand Down Expand Up @@ -201,6 +202,7 @@ def test_get_current_context():
"step_id": "first_step",
"type": "flow",
"collect_information": "foo",
"rejections": None,
}


Expand Down
14 changes: 7 additions & 7 deletions tests/core/featurizers/test_tracker_featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def test_featurize_trackers_with_full_dialogue_tracker_featurizer(
for actual, expected in zip(actual_features, expected_features):
assert compare_featurized_states(actual, expected)

expected_labels = np.array([[0, 20, 0, 17, 18, 0, 19]])
expected_labels = np.array([[0, 21, 0, 18, 19, 0, 20]])
assert actual_labels is not None
assert len(actual_labels) == 1
for actual, expected in zip(actual_labels, expected_labels):
Expand Down Expand Up @@ -255,7 +255,7 @@ def test_trackers_ignore_action_unlikely_intent_with_full_dialogue_tracker_featu
for actual, expected in zip(actual_features, expected_features):
assert compare_featurized_states(actual, expected)

expected_labels = np.array([[0, 20, 0, 17, 18, 0, 19]])
expected_labels = np.array([[0, 21, 0, 18, 19, 0, 20]])
assert actual_labels is not None
assert len(actual_labels) == 1
for actual, expected in zip(actual_labels, expected_labels):
Expand Down Expand Up @@ -324,7 +324,7 @@ def test_trackers_keep_action_unlikely_intent_with_full_dialogue_tracker_featuri
for actual, expected in zip(actual_features, expected_features):
assert compare_featurized_states(actual, expected)

expected_labels = np.array([[0, 9, 20, 0, 9, 17, 18, 0, 9, 19]])
expected_labels = np.array([[0, 9, 21, 0, 9, 18, 19, 0, 9, 20]])
assert actual_labels is not None
assert len(actual_labels) == 1
for actual, expected in zip(actual_labels, expected_labels):
Expand Down Expand Up @@ -832,7 +832,7 @@ def test_featurize_trackers_with_max_history_tracker_featurizer(
for actual, expected in zip(actual_features, expected_features):
assert compare_featurized_states(actual, expected)

expected_labels = np.array([[0, 20, 0, 17, 18, 0, 19]]).T
expected_labels = np.array([[0, 21, 0, 18, 19, 0, 20]]).T

assert actual_labels is not None
assert actual_labels.shape == expected_labels.shape
Expand Down Expand Up @@ -899,7 +899,7 @@ def test_featurize_trackers_ignore_action_unlikely_intent_max_history_featurizer
for actual, expected in zip(actual_features, expected_features):
assert compare_featurized_states(actual, expected)

expected_labels = np.array([[0, 20, 0]]).T
expected_labels = np.array([[0, 21, 0]]).T
assert actual_labels.shape == expected_labels.shape
for actual, expected in zip(actual_labels, expected_labels):
assert np.all(actual == expected)
Expand Down Expand Up @@ -971,7 +971,7 @@ def test_featurize_trackers_keep_action_unlikely_intent_max_history_featurizer(
for actual, expected in zip(actual_features, expected_features):
assert compare_featurized_states(actual, expected)

expected_labels = np.array([[0, 9, 20, 0]]).T
expected_labels = np.array([[0, 9, 21, 0]]).T
assert actual_labels is not None
assert actual_labels.shape == expected_labels.shape
for actual, expected in zip(actual_labels, expected_labels):
Expand Down Expand Up @@ -1088,7 +1088,7 @@ def test_deduplicate_featurize_trackers_with_max_history_tracker_featurizer(
for actual, expected in zip(actual_features, expected_features):
assert compare_featurized_states(actual, expected)

expected_labels = np.array([[0, 20, 0, 17, 18, 0, 19]]).T
expected_labels = np.array([[0, 21, 0, 18, 19, 0, 20]]).T
if not remove_duplicates:
expected_labels = np.vstack([expected_labels] * 2)

Expand Down
10 changes: 6 additions & 4 deletions tests/core/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
ACTION_CANCEL_FLOW,
ACTION_CLARIFY_FLOWS,
ACTION_CORRECT_FLOW_SLOT,
ACTION_EVALUATE_PREDICATE_REJECTIONS_NAME,
USER_INTENT_SESSION_START,
ACTION_LISTEN_NAME,
ACTION_RESTART_NAME,
Expand Down Expand Up @@ -145,7 +146,7 @@ def test_domain_action_instantiation():
for action_name in domain.action_names_or_texts
]

assert len(instantiated_actions) == 20
assert len(instantiated_actions) == 21
assert instantiated_actions[0].name() == ACTION_LISTEN_NAME
assert instantiated_actions[1].name() == ACTION_RESTART_NAME
assert instantiated_actions[2].name() == ACTION_SESSION_START_NAME
Expand All @@ -163,9 +164,10 @@ def test_domain_action_instantiation():
assert instantiated_actions[14].name() == ACTION_CANCEL_FLOW
assert instantiated_actions[15].name() == ACTION_CORRECT_FLOW_SLOT
assert instantiated_actions[16].name() == ACTION_CLARIFY_FLOWS
assert instantiated_actions[17].name() == "my_module.ActionTest"
assert instantiated_actions[18].name() == "utter_test"
assert instantiated_actions[19].name() == "utter_chitchat"
assert instantiated_actions[17].name() == ACTION_EVALUATE_PREDICATE_REJECTIONS_NAME
assert instantiated_actions[18].name() == "my_module.ActionTest"
assert instantiated_actions[19].name() == "utter_test"
assert instantiated_actions[20].name() == "utter_chitchat"


@pytest.mark.parametrize(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ def test_generate_training_data_with_cycles(domain: Domain):
# if new default actions are added the keys of the actions will be changed

all_label_ids = [id for ids in label_ids for id in ids]
assert Counter(all_label_ids) == {0: 6, 20: 1, 18: num_tens, 1: 2, 19: 3}
assert Counter(all_label_ids) == {0: 6, 20: 3, 19: num_tens, 1: 2, 21: 1}


def test_generate_training_data_with_unused_checkpoints(domain: Domain):
Expand Down

0 comments on commit 12973fa

Please sign in to comment.