forked from lukalabs/cakechat
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathranking_quality.py
80 lines (57 loc) · 3.06 KB
/
ranking_quality.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from cakechat.utils.env import init_cuda_env
init_cuda_env()
from collections import defaultdict
from cakechat.config import INPUT_SEQUENCE_LENGTH, INPUT_CONTEXT_SIZE, OUTPUT_SEQUENCE_LENGTH, TEST_CORPUS_NAME, \
TEST_DATA_DIR
from cakechat.dialog_model.inference.utils import get_sequence_score
from cakechat.dialog_model.quality import compute_retrieval_metric_mean, compute_average_precision, compute_recall_k
from cakechat.dialog_model.model_utils import transform_lines_to_token_ids, transform_contexts_to_token_ids
from cakechat.dialog_model.factory import get_trained_model
from cakechat.utils.text_processing import get_tokens_sequence
from cakechat.utils.files_utils import load_file
from cakechat.utils.data_structures import flatten
def _read_testset():
    """Load the test corpus and group responses by the context they follow.

    The file ``<TEST_CORPUS_NAME>.txt`` inside ``TEST_DATA_DIR`` is read via
    ``load_file`` and treated as alternating lines: a context line at every
    even position, immediately followed by its response line. Both lines are
    stripped of surrounding whitespace before being stored. A trailing line
    that has no partner is ignored.

    Returns:
        defaultdict(set): maps each context string to the set of response
        strings paired with it in the corpus.
    """
    lines = load_file(os.path.join(TEST_DATA_DIR, '{}.txt'.format(TEST_CORPUS_NAME)))

    responses_by_context = defaultdict(set)
    # Step over the even positions only; stopping before the last index
    # guarantees that ``pos + 1`` is always a valid response line.
    for pos in range(0, len(lines) - 1, 2):
        context_line, response_line = lines[pos], lines[pos + 1]
        responses_by_context[context_line.strip()].add(response_line.strip())

    return responses_by_context
def _get_context_to_weighted_responses(nn_model, testset, all_utterances):
    """Score every candidate utterance against every context of the test set.

    Args:
        nn_model: trained model. Its ``token_to_index`` attribute is used to
            convert tokens to ids, and the model itself is handed to
            ``get_sequence_score``.
        testset: iterable of context strings (iterating a dict yields its keys).
        all_utterances: list of candidate response strings.

    Returns:
        dict: maps each context to a ``{utterance: score}`` dict, where the
        scores are the values returned by ``get_sequence_score`` zipped with
        ``all_utterances`` in order.
    """
    vocab = nn_model.token_to_index

    # The candidate responses are the same for every context, so they are
    # tokenized and converted to ids only once, before the loop.
    tokenized_utterances = [get_tokens_sequence(utterance) for utterance in all_utterances]
    candidate_ids = transform_lines_to_token_ids(
        tokenized_utterances, vocab, OUTPUT_SEQUENCE_LENGTH, add_start_end=True)

    candidates_count = len(all_utterances)
    scores_by_context = {}

    for context in testset:
        # Wrap the tokenized context into a one-element list and repeat it once
        # per candidate, so each candidate is scored against the same context.
        # NOTE(review): presumably a context is a list of utterances and this
        # one holds a single utterance - confirm against
        # transform_contexts_to_token_ids.
        single_utterance_context = [get_tokens_sequence(context)]
        tiled_context_ids = transform_contexts_to_token_ids(
            [single_utterance_context] * candidates_count, vocab, INPUT_SEQUENCE_LENGTH, INPUT_CONTEXT_SIZE)

        candidate_scores = get_sequence_score(nn_model, tiled_context_ids, candidate_ids)
        scores_by_context[context] = dict(zip(all_utterances, candidate_scores))

    return scores_by_context
def _compute_metrics(model, testset):
    """Compute the ranking metrics for the model on the test set and print them.

    Three values are computed with ``compute_retrieval_metric_mean``:
    ``mean_ap`` (average precision, top_count = number of candidates),
    ``mean_recall@10`` (recall, top_count = 10) and ``mean_recall@25%``
    (recall, top_count = a quarter of the candidates, rounded down).

    Args:
        model: trained model, forwarded to ``_get_context_to_weighted_responses``.
        testset: mapping of context -> collection of reference responses.

    Returns:
        None. The test set size and every metric are printed to stdout.
    """
    # Pool of candidate responses gathered from all contexts; the original
    # author describes it as "all unique responses".
    candidates = list(flatten(testset.values(), set))
    scores_by_context = _get_context_to_weighted_responses(model, testset, candidates)
    candidates_count = len(candidates)

    def _mean_of(metric_fn, top_count):
        # Every metric is averaged over the same test set and the same scores;
        # only the per-context metric function and the cut-off differ.
        return compute_retrieval_metric_mean(metric_fn, testset, scores_by_context, top_count=top_count)

    metrics = {}
    metrics['mean_ap'] = _mean_of(compute_average_precision, candidates_count)
    metrics['mean_recall@10'] = _mean_of(compute_recall_k, 10)
    metrics['mean_recall@25%'] = _mean_of(compute_recall_k, candidates_count // 4)

    print('Test set size = {}'.format(candidates_count))
    for name, value in metrics.items():
        print('{} = {}'.format(name, value))
if __name__ == '__main__':
    # Script entry point: the trained model is loaded first, then the test
    # corpus is read, and finally the ranking metrics are computed and printed.
    _compute_metrics(get_trained_model(), _read_testset())