Commit

Merge pull request #20 from McGill-NLP/main
Remove dependency on tf
xhluca authored May 29, 2024
2 parents edd5983 + 29b1e27 commit f31b698
Showing 1 changed file with 0 additions and 78 deletions.
instruct_qa/evaluation/metrics.py: 0 additions, 78 deletions
@@ -22,9 +22,6 @@
     AutoModelWithLMHead,
     AutoModelForQuestionAnswering,
 )
-import tensorflow as tf
-import tensorflow_hub as hub
-import tensorflow_text as text

 import numpy as np
 from scipy.special import softmax
@@ -80,81 +77,6 @@ def __call__(self, predictions, references, questions=None, ids=None):
         }


-class BEMScore(Metric):
-    def __init__(self, name, **kwargs):
-        super().__init__(name, **kwargs)
-        vocab_path = "gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12/vocab.txt"
-        vocab_table = tf.lookup.StaticVocabularyTable(
-            tf.lookup.TextFileInitializer(
-                filename=vocab_path,
-                key_dtype=tf.string,
-                key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
-                value_dtype=tf.int64,
-                value_index=tf.lookup.TextFileIndex.LINE_NUMBER,
-            ),
-            num_oov_buckets=1,
-        )
-        self.cls_id, self.sep_id = vocab_table.lookup(
-            tf.convert_to_tensor(["[CLS]", "[SEP]"])
-        )
-        self.tokenizer = text.BertTokenizer(
-            vocab_lookup_table=vocab_table,
-            token_out_type=tf.int64,
-            preserve_unused_token=True,
-            lower_case=True,
-        )
-        self.bem = hub.load("https://tfhub.dev/google/answer_equivalence/bem/1")
-
-    def bertify_example(self, example):
-        question = self.tokenizer.tokenize(example["question"]).merge_dims(1, 2)
-        reference = self.tokenizer.tokenize(example["reference"]).merge_dims(1, 2)
-        candidate = self.tokenizer.tokenize(example["candidate"]).merge_dims(1, 2)
-        input_ids, segment_ids = text.combine_segments(
-            (candidate, reference, question), self.cls_id, self.sep_id
-        )
-        return {"input_ids": input_ids.numpy(), "segment_ids": segment_ids.numpy()}
-
-    def bertify_examples(self, examples):
-        input_ids = []
-        segment_ids = []
-        for example in examples:
-            example_inputs = self.bertify_example(example)
-            input_ids.append(self.pad(example_inputs["input_ids"]))
-            segment_ids.append(self.pad(example_inputs["segment_ids"]))
-
-        return {"input_ids": np.stack(input_ids), "segment_ids": np.stack(segment_ids)}
-
-    def pad(self, a, length=512):
-        if a.shape[-1] >= length:
-            return a[0][:length]
-        else:
-            return np.append(a, np.zeros(length - a.shape[-1], np.int32))
-
-    def __call__(self, predictions, references, questions, ids=None):
-        assert len(predictions) == len(references)
-
-        scores = []
-        for i in tqdm(range(len(predictions))):
-            examples = [
-                {
-                    "question": questions[i],
-                    "reference": reference,
-                    "candidate": predictions[i],
-                }
-                for reference in references[i]
-            ]
-            inputs = self.bertify_examples(examples)
-            raw_outputs = self.bem(inputs)
-            score = float(max(softmax(raw_outputs, axis=1)[:, 1]))
-            scores.append(score)
-
-        if self.store_individual_scores:
-            individual_scores = [{"bem": score} for score in scores]
-            self.save_individual_scores(ids, individual_scores)
-
-        return {"bem": np.mean(scores)}
-
-
 class Rouge(Metric):
     def __init__(self, name, **kwargs):
         super().__init__(name, **kwargs)
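For context on what was dropped: the removed BEMScore class wrapped the BERT answer-equivalence (BEM) model from TF Hub and scored each prediction as the highest softmax probability of the "equivalent" class across its references, averaging over examples. Below is a minimal usage sketch of the removed metric, assuming the old tensorflow, tensorflow_hub, and tensorflow_text packages are still installed and that the base Metric constructor only needs a name; the example strings are illustrative.

# Hypothetical usage of the removed BEMScore metric; it needs the
# TensorFlow stack that this commit drops from the package.
bem = BEMScore(name="bem")
result = bem(
    predictions=["Paris is the capital of France."],   # model outputs
    references=[["Paris"]],                            # gold answers per example
    questions=["What is the capital of France?"],      # aligned questions
)
print(result)  # {"bem": <mean over examples of max P(equivalent) across references>}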
