
Commit

updates to parse_t5 model code
bjascob committed Nov 27, 2021
1 parent 05e9b43 commit 998b666
Showing 12 changed files with 321 additions and 202 deletions.
87 changes: 59 additions & 28 deletions amrlib/models/parse_t5/inference.py
@@ -14,59 +14,79 @@


class Inference(STOGInferenceBase):
def __init__(self, model_dir, model_fn=None, **kwargs):
def __init__(self, model_dir=None, model_fn=None, model=None, tokenizer=None, config=None, **kwargs):
default_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = kwargs.get('device', default_device)
self.device = torch.device(device)
self.model = T5ForConditionalGeneration.from_pretrained(model_dir).to(self.device)
self.max_sent_len = self.model.config.task_specific_params['translation_amr_to_text']['max_in_len']
self.max_graph_len = self.model.config.task_specific_params['translation_amr_to_text']['max_out_len']
tokenizer_name = kwargs.get('tokenizer_name', 't5-base') # name or path
self.tokenizer = T5Tokenizer.from_pretrained(tokenizer_name)
self.device = torch.device(kwargs.get('device', default_device))
self.batch_size = kwargs.get('batch_size', 12)
self.num_beams = kwargs.get('num_beams', 4) # 1 => greedy
self.num_beams = kwargs.get('num_beams', 4)
self.num_ret_seq = self.num_beams
self.ret_raw_gen = kwargs.get('ret_raw_gen', False) # Use only for debug
# Load the model from file
if model_dir is not None:
model = T5ForConditionalGeneration.from_pretrained(model_dir).to(self.device)
tok_name = kwargs.get('tok_name_or_path', 't5-base')
tokenizer = T5Tokenizer.from_pretrained(tok_name)
# model_parse_t5-v0_1_0 used the key "translation_amr_to_text" (copy error from generate_t5 code)
config = model.config.task_specific_params.get('parse_amr')
if config is None:
config = model.config.task_specific_params.get('translation_amr_to_text')
# Use the passed in values
elif model is not None and tokenizer is not None and config is not None:
pass
else:
raise ValueError('Either pass in the model directory or the model, tokenizer and config.')
# Add to the class
self.model = model.to(self.device)
self.tokenizer = tokenizer
self.max_sent_len = config['max_in_len']
self.max_graph_len = config['max_out_len']

# Parse a list of sentence strings into AMR graphs
# For generate params see https://huggingface.co/transformers/master/main_classes/model.html
@torch.no_grad()
def parse_sents(self, sents, add_metadata=True, disable_progress=True):
assert isinstance(sents, list)
# Sort by sentence length for faster batching
# Put the longest first so that inference speeds up as it progresses, instead of slowing down.
data = [(s, i) for i, s in enumerate(sents)]
data = sorted(data, key=lambda x:len(x[0]), reverse=True)
# Loop through batches
graphs_generated = []
clips = []
dataloader = torch.utils.data.DataLoader(sents, batch_size=self.batch_size)
for batch in tqdm(dataloader, disable=disable_progress):
# Form encodings and tokenize
# input_text = ['%s %s' % (sent, self.tokenizer.eos_token) for sent in batch]
input_text = ['%s' % sent for sent in batch]
input_encodings = self.tokenizer.batch_encode_plus(input_text, padding=True,
truncation=True, max_length=self.max_sent_len,
return_overflowing_tokens=True)
clips = []
graphs_generated = [None]*len(sents)*self.num_ret_seq
self.model.eval()
pbar = tqdm(total=len(sents), ncols=100, position=0, leave=True, disable=disable_progress)
for batch in self._chunk(data, self.batch_size):
input_text = [x[0] for x in batch]
sent_indxes = [x[1] for x in batch]
# Form encodings and tokenize (padding=True => pad to the longest)
input_encodings = self.tokenizer(input_text, padding=True, truncation=True,
max_length=self.max_sent_len, return_overflowing_tokens=True)
# Check if any input sentences were truncated (requires return_overflowing_tokens=True)
clip = [l > 0 for l in input_encodings['num_truncated_tokens']]
clips.extend(clip)
# Convert to tensors
input_ids = torch.LongTensor(input_encodings['input_ids']).to(self.device)
attention_mask = torch.LongTensor(input_encodings['attention_mask']).to(self.device)
# Generate
# Generate the batch ids and convert them back to tokens
outs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask,
max_length=self.max_graph_len, early_stopping=True,
num_beams=self.num_beams, num_return_sequences=self.num_ret_seq)
outs = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
graphs_generated.extend(outs)
# For debugging only ...
# Note: in this mode we're returning 2 lists of num_ret_seq * len(sents) instead of
# one list of len(sents) as in the default run-time mode
# De-sort the output token data. There are self.num_ret_seq results returned for each sentence
for bidx in range(len(batch)):
sidx = sent_indxes[bidx]
graphs_generated[self._group_slice(sidx)] = outs[self._group_slice(bidx)]
pbar.update(len(batch))
pbar.close()
# For debugging and sanity check
if self.ret_raw_gen:
return graphs_generated, clips
# Extract the top result that isn't clipped and will deserialize
# At this point "graphs_generated" holds num_ret_seq entries per sentence (num_ret_seq * len(sents) total)
assert not any(g is None for g in graphs_generated)
# Get the top result that properly deserializes. graphs_generated is len(sents)*num_ret_seq
graphs_final = [None]*len(sents)
for snum in range(len(sents)):
if clips[snum]:
logger.error('Sentence number %d was clipped for length' % snum)
raw_graphs = graphs_generated[snum*self.num_ret_seq:(snum+1)*self.num_ret_seq]
raw_graphs = graphs_generated[self._group_slice(snum)]
for bnum, g in enumerate(raw_graphs):
gstring = PenmanDeSerializer(g).get_graph_string()
if gstring is not None:
@@ -79,6 +99,17 @@ def parse_sents(self, sents, add_metadata=True, disable_progress=True):
graphs_final = ['# ::snt %s\n%s' % (s, g) if g is not None else None for s, g in zip(sents, graphs_final)]
return graphs_final

# Return a slice object to extract the model's output group based on the input index.
# The model returns self.num_ret_seq * len(input) results as a flat list.
def _group_slice(self, input_idx):
return slice(input_idx * self.num_ret_seq, (input_idx + 1) * self.num_ret_seq)

# Yield successive n-sized chunks from lst.
@staticmethod
def _chunk(lst, n):
for i in range(0, len(lst), n):
yield lst[i:i + n]

# Parse a list of spacy spans (i.e. each span has a list of tokens)
def parse_spans(self, spans, add_metadata=True):
sents = [s.text.strip() for s in spans]
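
For reference, a minimal usage sketch of the updated Inference API follows. It is illustrative rather than part of the commit: the model directory, sentence, and config values are assumptions (the length values mirror configs/model_parse_t5.json below), while the parameter names come from the diff above. The constructor now accepts either a model directory to load from, or a pre-loaded model, tokenizer, and config dict with 'max_in_len' and 'max_out_len' keys.

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from amrlib.models.parse_t5.inference import Inference

# Option 1: load everything from a model directory (hypothetical local path)
inference = Inference(model_dir='amrlib/data/model_parse_t5', tok_name_or_path='t5-base',
                      batch_size=12, num_beams=4)

# Option 2: pass pre-loaded objects; the config dict only needs the two length keys
model     = T5ForConditionalGeneration.from_pretrained('amrlib/data/model_parse_t5')
tokenizer = T5Tokenizer.from_pretrained('t5-base')
config    = {'max_in_len': 100, 'max_out_len': 512}
inference = Inference(model=model, tokenizer=tokenizer, config=config,
                      device='cuda:0' if torch.cuda.is_available() else 'cpu')

# parse_sents returns one graph string per sentence, or None if no beam deserialized
graphs = inference.parse_sents(['The boy wants to go.'], disable_progress=False)
print(graphs[0])

Internally, parse_sents now sorts the sentences by length (longest first) for faster batching and uses _group_slice to write each sentence's num_ret_seq beams back to their original positions before picking the first candidate that deserializes.
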
136 changes: 58 additions & 78 deletions amrlib/models/parse_t5/trainer.py
@@ -5,121 +5,101 @@
import torch
from torch.utils.data import Dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer, set_seed
from transformers import TrainingArguments
from transformers import Trainer as T5Trainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from .amr_t5_trainer import AMRT5Trainer
from .penman_serializer import load_and_serialize

logger = logging.getLogger(__name__)


# Torch "DataSet" used for feeding data to the training routine
# Keys are... input_ids, attention_mask, target_ids, target_attention_mask
class AMRDataset(Dataset):
def __init__(self, encodings, sents, bad_indexes):
self.encodings = encodings
self.sents = sents
self.bad_indexes = bad_indexes # in original file's index, not same as above

def __len__(self):
return len(self.encodings['input_ids'])

def __getitem__(self, idx):
return {k:v[idx] for k, v in self.encodings.items()}


# Take a list of samples from a Dataset and collate them into a batch, returning a dict.
# Prepares lm_labels from target_ids and returns examples with keys expected by the forward method.
# This is necessary because the trainer passes this dict directly as arguments to the model,
# so make sure the keys match the parameter names of the forward method.
class T2TDataCollator:
def __call__(self, batch):
input_ids = torch.stack([example['input_ids'] for example in batch])
lm_labels = torch.stack([example['target_ids'] for example in batch])
lm_labels[lm_labels[:, :] == 0] = -100
attention_mask = torch.stack([example['attention_mask'] for example in batch])
decoder_attention_mask = torch.stack([example['target_attention_mask'] for example in batch])
return {'input_ids': input_ids, 'attention_mask': attention_mask,
'labels': lm_labels, 'decoder_attention_mask': decoder_attention_mask }


# Note that for save_steps, steps means gradient updates (not batches), so if
# gradient_accumulation_steps=4 and save_steps=1000, then a checkpoint is saved every 4000 batches.
class Trainer(object):
def __init__(self, args):
# General arguments
self.gen_args = args['gen_args']
self.model_name_or_path = self.gen_args['model_name_or_path']
self.tok_name_or_path = self.gen_args.get('tok_name_or_path', self.model_name_or_path)
self.corpus_dir = self.gen_args['corpus_dir']
self.train_fn = self.gen_args['train_fn']
self.valid_fn = self.gen_args['valid_fn']
self.eval_fn = self.gen_args.get('eval_fn')
self.max_in_len = self.gen_args['max_in_len']
self.max_out_len = self.gen_args['max_out_len']
# HuggingFace trainer arguments
# See https://github.com/huggingface/transformers/blob/master/src/transformers/training_args.py
self.training_args = TrainingArguments(**args['hf_args'])
self.training_args = TrainingArguments(**args['hf_args'])
set_seed(self.training_args.seed)

def train(self):
# Create the output directory if needed
os.makedirs(self.training_args.output_dir, exist_ok=True)
# Load pretrained model and tokenizer
print('Loading model and tokenizer')
self.tokenizer = T5Tokenizer.from_pretrained(self.model_name_or_path)
self.tokenizer = T5Tokenizer.from_pretrained(self.tok_name_or_path)
self.model = T5ForConditionalGeneration.from_pretrained(self.model_name_or_path)
# Clear out the "task_specific_params" and add this one
self.model.config.task_specific_params = {'translation_amr_to_text':self.gen_args}
self.model.config.task_specific_params = {'parse_amr':self.gen_args}
# Save the tokenizer
if self.gen_args.get('save_tokenizer', False):
self.tokenizer.save_pretrained(self.training_args.output_dir)
# Load the datasets
print('Building datasets')
train_file_path = os.path.join(self.corpus_dir, self.train_fn)
valid_file_path = os.path.join(self.corpus_dir, self.valid_fn)
train_dataset = self.build_dataset(train_file_path)
train_dataset = self.build_dataset(train_file_path)
print('Training data is {:,} after removing {:,} long entries'.format( \
len(train_dataset), len(train_dataset.bad_indexes)))
valid_dataset = self.build_dataset(valid_file_path)
print('Validation data is {:,} after removing {:,} long entries'.format( \
len(valid_dataset), len(valid_dataset.bad_indexes)))
# Load the evaluation dataset
if self.eval_fn:
eval_file_path = os.path.join(self.corpus_dir, self.eval_fn)
eval_samples = load_and_serialize(eval_file_path)
print('Evaluation data is {:,} samples'.format(len(eval_samples['graphs'])))
else:
eval_samples = None
# Train the model
print('Training')
trainer = T5Trainer(model=self.model, args=self.training_args, train_dataset=train_dataset,
eval_dataset=valid_dataset, data_collator=T2TDataCollator())
trainer.train()
collator = DataCollatorForSeq2Seq(self.tokenizer, self.model, padding=True, max_length=self.max_out_len)
trainer = AMRT5Trainer(model=self.model, args=self.training_args, train_dataset=train_dataset,
data_collator=collator, eval_tokenizer=self.tokenizer, eval_samples=eval_samples)
# If resume_from_checkpoint is True it will look for the last checkpoint in the value of output_dir
# passed via TrainingArguments. If it's a path to a specific checkpoint it will use that saved
# checkpoint folder to resume the training from.
trainer.train(resume_from_checkpoint=self.training_args.resume_from_checkpoint)
# Save the results
print('Saving model')
trainer.save_model(self.training_args.output_dir)
#self.tokenizer.save_pretrained(self.training_args.output_dir)
if self.gen_args.get('save_at_end', False):
print('Saving model')
trainer.save_model(self.training_args.output_dir)

# Load the AMR corpus and convert the sentences and serialized graphs into tokenized encodings
def build_dataset(self, fpath):
# Load the raw data
entries = load_and_serialize(fpath)
# Convert to input and target sentences
entries['input_text'] = ['%s' % sent for sent in entries['sents']]
entries['target_text'] = ['%s' % graph for graph in entries['serials']]
# Form the input encodings
sents = entries['sents']
print('Batch encoding')
input_encodings = self.tokenizer.batch_encode_plus(entries['input_text'],
padding=True, truncation=True, max_length=self.max_in_len,
return_overflowing_tokens=True)
target_encodings = self.tokenizer.batch_encode_plus(entries['target_text'],
padding=True, truncation=True, max_length=self.max_out_len,
return_overflowing_tokens=True)
# Remove any graphs that are greater than max length after tokenization
# Find the bad indexes
entries = load_and_serialize(fpath) # returns a dict of lists
# Tokenize the target in order to strip off any training samples that are too long.
# Set return_overflowing_tokens=True to return 'overflowing_tokens' and 'num_truncated_tokens'
print('Tokenizing')
in_enc = self.tokenizer(entries['sents'], padding=False, truncation=True, max_length=self.max_in_len,
return_overflowing_tokens=True)
tgt_enc = self.tokenizer(entries['serials'], padding=False, truncation=True, max_length=self.max_out_len,
return_overflowing_tokens=True)
# Identify any truncated data
bi = set()
for i, (ie, te) in enumerate(zip(input_encodings['num_truncated_tokens'], target_encodings['num_truncated_tokens'])):
for i, (ie, te) in enumerate(zip(in_enc['num_truncated_tokens'], tgt_enc['num_truncated_tokens'])):
if ie > 0 or te > 0:
bi.add( i )
# Remove them
input_encodings['input_ids'] = [ie for i, ie in enumerate(input_encodings['input_ids']) if i not in bi]
target_encodings['input_ids'] = [te for i, te in enumerate(target_encodings['input_ids']) if i not in bi]
input_encodings['attention_mask'] = [ie for i, ie in enumerate(input_encodings['attention_mask']) if i not in bi]
target_encodings['attention_mask'] = [te for i, te in enumerate(target_encodings['attention_mask']) if i not in bi]
sents = [s for i, s in enumerate(sents) if i not in bi]
# Create the encodings
encodings = {'input_ids': torch.LongTensor(input_encodings['input_ids']),
'attention_mask': torch.LongTensor(input_encodings['attention_mask']),
'target_ids': torch.LongTensor(target_encodings['input_ids']),
'target_attention_mask': torch.LongTensor(target_encodings['attention_mask']) }
# Encapsulate the data and return
return AMRDataset(encodings, sents, bi)
# Compile the output encodings, stripped of bad indexes
# These will be passed directly to the model so make sure all the keys are correct, with no extras
encodings = {}
encodings['input_ids'] = [ie for i, ie in enumerate(in_enc['input_ids']) if i not in bi]
encodings['attention_mask'] = [ie for i, ie in enumerate(in_enc['attention_mask']) if i not in bi]
encodings['labels'] = [te for i, te in enumerate(tgt_enc['input_ids']) if i not in bi]
return AMRDataset(encodings, bi)


# Torch DataSet used for feeding data to the training routine
class AMRDataset(Dataset):
def __init__(self, encodings, bad_indexes):
self.encodings = encodings
self.bad_indexes = bad_indexes

def __len__(self):
return len(self.encodings['input_ids'])

def __getitem__(self, idx):
return {k:v[idx] for k, v in self.encodings.items()}
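
The old T2TDataCollator above stacked target_ids and replaced the pad id 0 with -100 by hand; the commit swaps it for HuggingFace's DataCollatorForSeq2Seq, which pads each batch dynamically and fills the label padding with -100 so those positions are ignored by the loss. A toy sketch of that behavior (the token ids are made up, not from the corpus; the trainer above additionally passes the model so decoder_input_ids are prepared from the labels):

from transformers import T5Tokenizer, DataCollatorForSeq2Seq

tokenizer = T5Tokenizer.from_pretrained('t5-base')
collator = DataCollatorForSeq2Seq(tokenizer, padding=True)

features = [
    {'input_ids': [100, 200, 1], 'attention_mask': [1, 1, 1], 'labels': [300, 1]},
    {'input_ids': [100, 1],      'attention_mask': [1, 1],    'labels': [300, 400, 1]},
]
batch = collator(features)
print(batch['labels'])
# tensor([[ 300,    1, -100],
#         [ 300,  400,    1]])   <- label padding is -100, not the pad id 0
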
27 changes: 17 additions & 10 deletions configs/model_parse_t5.json
@@ -1,27 +1,34 @@
{ "gen_args" :
{
"model_name_or_path" : "t5-base",
"corpus_dir" : "amrlib/data/tdata_gsii/",
"train_fn" : "train.txt.features.nowiki",
"valid_fn" : "dev.txt.features.nowiki",
"corpus_dir" : "amrlib/data/tdata_t5/",
"train_fn" : "train.txt.nowiki",
"eval_fn" : "dev.txt.nowiki",
"save_tokenizer" : false,
"save_at_end" : false,
"eval_batch_size" : 12,
"eval_num_beams" : 1,
"max_in_len" : 100,
"max_out_len" : 512

},
"hf_args" :
{
"output_dir" : "amrlib/data/model_parse_t5",
"save_strategy" : "epoch",
"evaluation_strategy" : "epoch",
"group_by_length" : true,
"do_train" : true,
"do_eval" : false,
"do_eval" : true,
"overwrite_output_dir" : false,
"prediction_loss_only" : true,
"num_train_epochs" : 8,
"save_steps" : 1000,
"save_total_limit" : 2,
"save_total_limit" : null,
"num_train_epochs" : 16,
"per_device_train_batch_size" : 4,
"per_device_eval_batch_size" : 4,
"gradient_accumulation_steps" : 4,
"weight_decay" : 0.004,
"learning_rate" : 1e-4,
"seed" : 42
"max_grad_norm" : 1.0,
"warmup_steps" : 3448,
"seed" : 0
}
}
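
The "gen_args" block above is what the trainer now stores in the model config under task_specific_params['parse_amr'] (the earlier model_parse_t5-v0_1_0 release used the 'translation_amr_to_text' key, which Inference still falls back to). A short sketch, assuming a model trained to the output_dir from the "hf_args" section, of how those values are read back at load time:

from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained('amrlib/data/model_parse_t5')   # output_dir above
tsp = model.config.task_specific_params
config = tsp.get('parse_amr') or tsp.get('translation_amr_to_text')                # new key, then legacy fallback
print(config['max_in_len'], config['max_out_len'])                                 # 100 512 with this config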
