# evidence_selection.py
import itertools
from typing import Any, Dict, List

import torch
from sentence_transformers import CrossEncoder

# Cross-encoder re-ranker used to score (question, evidence) pairs.
PASSAGE_RANKER = CrossEncoder(
    "cross-encoder/ms-marco-MiniLM-L-6-v2",
    max_length=512,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
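# Note (illustrative, not in the original file): CrossEncoder.predict takes a
# list of (query, passage) string pairs and returns one relevance score per
# pair, so PASSAGE_RANKER.predict([(q, e)]) yields a length-1 array whose
# entry is higher the more relevant evidence `e` is to question `q`.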
def compute_score_matrix(
    questions: List[str], evidences: List[str]
) -> List[List[float]]:
    """Scores the relevance of all evidences against all questions using a CrossEncoder.

    Args:
        questions: A list of unique questions.
        evidences: A list of unique evidences.
    Returns:
        score_matrix: A 2D list of question x evidence relevance scores.
    """
    score_matrix = []
    for q in questions:
        evidence_scores = PASSAGE_RANKER.predict([(q, e) for e in evidences]).tolist()
        score_matrix.append(evidence_scores)
    return score_matrix
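# Shape note (illustrative): with 2 questions and 3 evidences, the returned
# matrix is a 2 x 3 nested list, e.g.
#   [[8.1, -2.3, 0.4],
#    [-1.7, 6.9, 5.2]]
# where score_matrix[i][j] is the score of evidence j against question i.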
def question_coverage_objective_fn(
    score_matrix: List[List[float]], evidence_indices: List[int]
) -> float:
    """Given (question, evidence) scores and a subset of evidence, return the coverage.

    Given all pairwise question and evidence scores, and a subset of the evidence
    specified by indices, return a value indicating how well this subset of evidence
    covers (i.e., helps answer) all questions.

    Args:
        score_matrix: A 2D list of question x evidence relevance scores.
        evidence_indices: Indices of the subset of evidence to get the coverage
            score of.
    Returns:
        total: The coverage we would get by using the subset of evidence in
            `evidence_indices` over all questions.
    """
    # Compute sum_{question q} max_{selected evidence e} score(q, e).
    # This encourages all questions to be explained by at least one evidence.
    total = 0.0
    for scores_for_question in score_matrix:
        total += max(scores_for_question[j] for j in evidence_indices)
    return total
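# Worked example (illustrative): given
#   score_matrix = [[0.9, 0.1],
#                   [0.2, 0.8]]
# the coverage of evidence_indices=[0] is max(0.9) + max(0.2) = 1.1, while
# evidence_indices=[0, 1] gives max(0.9, 0.1) + max(0.2, 0.8) = 1.7, since the
# second evidence is what best answers the second question.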
def select_evidences(
    example: Dict[str, Any], max_selected: int = 5, prefer_fewer: bool = False
) -> List[Dict[str, Any]]:
    """Selects the set of evidences that maximizes information coverage over the claim.

    Args:
        example: The result of running the editing pipeline on one claim.
        max_selected: Maximum number of evidences to select.
        prefer_fewer: If True and the maximum objective value can be achieved by
            fewer evidences than `max_selected`, prefer selecting fewer evidences.
    Returns:
        selected_evidences: Selected evidences that serve as the attribution report.
    """
    questions = sorted(set(example["questions"]))
    evidences = sorted(set(e["text"] for e in example["revisions"][0]["evidences"]))
    num_evidences = len(evidences)
    if not num_evidences:
        return []

    score_matrix = compute_score_matrix(questions, evidences)

    best_combo = tuple()
    best_objective_value = float("-inf")
    max_selected = min(max_selected, num_evidences)
    min_selected = 1 if prefer_fewer else max_selected
    # Exhaustively search all evidence subsets of the allowed sizes and keep
    # the one with the highest question coverage.
    for num_selected in range(min_selected, max_selected + 1):
        for combo in itertools.combinations(range(num_evidences), num_selected):
            objective_value = question_coverage_objective_fn(score_matrix, combo)
            if objective_value > best_objective_value:
                best_combo = combo
                best_objective_value = objective_value

    selected_evidences = [{"text": evidences[idx]} for idx in best_combo]
    return selected_evidences
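# Minimal usage sketch (hypothetical data, not part of the original module):
# `select_evidences` expects the editing pipeline's output format, i.e. a dict
# with "questions" and "revisions", where each revision carries a list of
# evidences with a "text" field.
if __name__ == "__main__":
    example = {
        "questions": [
            "When was the Eiffel Tower completed?",
            "How tall is the Eiffel Tower?",
        ],
        "revisions": [
            {
                "evidences": [
                    {"text": "The Eiffel Tower was completed in March 1889."},
                    {"text": "The Eiffel Tower is about 330 metres tall."},
                    {"text": "Paris is the capital of France."},
                ]
            }
        ],
    }
    # Should print the two evidences that actually answer the questions.
    for evidence in select_evidences(example, max_selected=2, prefer_fewer=True):
        print(evidence["text"])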