Remove unnecessary pad token #2428

Open · wants to merge 4 commits into base: main
43 changes: 33 additions & 10 deletions composer/datasets/in_context_learning_evaluation.py
@@ -236,7 +236,7 @@ def __len__(self):

     def collate_fn(self, data):
         inputs, answers = [], []
-
+        max_batch_len = 1
         for sample in data:
             preamble, context, aliases = (sample['preamble'], sample['context'], sample['aliases'])
             context_enc = preamble['input_ids'] + context['input_ids']
@@ -247,9 +247,13 @@ def collate_fn(self, data):

             inputs.append(inp)
             answers.append(aliases)
+            max_batch_len = max(max_batch_len, len(context_enc))

+        max_batch_len = min(max_batch_len, self.max_seq_len)
+        # Truncate pad token from left padding
+        input_ids = torch.stack(inputs)[..., -max_batch_len:]
         batch = {
-            'input_ids': torch.stack(inputs),
+            'input_ids': input_ids,
             'mode': 'generate',
             'labels': answers,
             'generation_length': self.max_answer_length,
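
The slice added above works because these generate-mode batches are left-padded: the padding precedes the real tokens, so keeping only the last max_batch_len columns drops columns that are padding in every row without losing a real token. A toy illustration of the effect (a standalone sketch with made-up values, not the dataloader code itself):

import torch

pad_tok_id = 0
# Two left-padded prompts; the longest real prompt has 3 tokens.
inputs = [
    torch.tensor([pad_tok_id, pad_tok_id, 5, 6, 7]),
    torch.tensor([pad_tok_id, pad_tok_id, pad_tok_id, 8, 9]),
]
max_batch_len = 3

input_ids = torch.stack(inputs)[..., -max_batch_len:]
print(input_ids)
# tensor([[5, 6, 7],
#         [0, 8, 9]])  <- the two all-pad columns were dropped
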
@@ -415,6 +419,7 @@ def __len__(self):
     def collate_fn(self, data):
         inputs = []
         continuation_indices = []
+        max_batch_len = 1
         for data_pair in data:
             preamble, context, continuation = (data_pair['preamble'], data_pair['context'], data_pair['continuation'])

@@ -426,12 +431,15 @@ def collate_fn(self, data):

             inputs.append(inp)
             continuation_indices.append(continuation_span)
+            max_batch_len = max(max_batch_len, len(context_enc) + len(continuation_enc))

+        max_batch_len = min(max_batch_len, self.max_seq_len)
+        input_ids = torch.stack(inputs)[..., :max_batch_len] # right padding
         batch = {
-            'input_ids': torch.stack(inputs),
+            'input_ids': input_ids,

[Contributor comment on the 'input_ids' line above] I believe the test error you are getting is because input_ids and labels should not be the same tensor: labels gets modified in place, with a -100 written in when the labels get rolled so they are aligned for the next-token objective.

             'continuation_indices': continuation_indices,
             'mode': 'icl_task',
-            'labels': torch.stack(inputs),
+            'labels': input_ids,
         }

         batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id)
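
The contributor comment above flags a real aliasing hazard: in this hunk batch['labels'] and batch['input_ids'] point at the same tensor, so any in-place write to the labels silently corrupts the inputs. A minimal standalone sketch of the problem and the usual fix; the -100 masking here stands in for whatever in-place edit the trainer performs when rolling the labels:

import torch

pad_tok_id = 0
input_ids = torch.tensor([[5, 6, 7, pad_tok_id]])

# Buggy: labels aliases input_ids, so the in-place write hits both tensors.
labels = input_ids
labels[labels == pad_tok_id] = -100
assert input_ids[0, -1] == -100  # input_ids silently corrupted

# Fix: give labels its own storage before any in-place modification.
input_ids = torch.tensor([[5, 6, 7, pad_tok_id]])
labels = input_ids.clone()
labels[labels == pad_tok_id] = -100
assert input_ids[0, -1] == pad_tok_id  # inputs stay intact
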
@@ -592,6 +600,7 @@ def collate_fn(self, data):
         continuation_indices = []
         gold_idxs = []
         choice_groupings = []
+        max_batch_len = 1
         for data_pair in data:

             choice_start_idx = len(continuation_indices)
@@ -606,6 +615,7 @@ def collate_fn(self, data):

             inputs.append(inp)
             continuation_indices.append(continuation_span)
+            max_batch_len = max(max_batch_len, len(context_enc) + len(continuation_enc))

             gold_idxs.append(gold_idx)
             choice_end_idx = len(continuation_indices)
@@ -618,11 +628,14 @@ def collate_fn(self, data):
         # since the batch may consist of multiple questions, the choice_groupings indicates
         # which contiguous sequences of elements in the batch correspond to which question
         # gold_indices indicates which of the [0, N-1] choices is the correct one for each question.
+
+        max_batch_len = min(max_batch_len, self.max_seq_len)
+        input_ids = torch.stack(inputs)[..., :max_batch_len] # right padding
         batch = {
-            'input_ids': torch.stack(inputs),
+            'input_ids': input_ids,
             'continuation_indices': continuation_indices,
             'mode': 'icl_task',
-            'labels': torch.stack(inputs),
+            'labels': input_ids,
             'gold_indices': gold_idxs,
             'choice_groupings': choice_groupings
         }
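
For these right-padded icl_task batches the slice keeps the first max_batch_len columns instead. The truncation cannot cut into a continuation, because max_batch_len is the maximum of len(context_enc) + len(continuation_enc) over the batch (capped at max_seq_len), so every continuation position still lands inside the truncated tensor. A quick sketch of that invariant with made-up lengths and spans:

import torch

pad_tok_id = 0
# Right-padded context+continuation rows with real lengths 4 and 3.
inputs = [
    torch.tensor([1, 2, 3, 4, pad_tok_id, pad_tok_id]),
    torch.tensor([5, 6, 7, pad_tok_id, pad_tok_id, pad_tok_id]),
]
continuation_indices = [torch.tensor([2, 3]), torch.tensor([2])]  # illustrative spans
max_batch_len = 4  # max over len(context_enc) + len(continuation_enc)

input_ids = torch.stack(inputs)[..., :max_batch_len]
# Every continuation position still lies inside the truncated tensor.
assert all(int(span.max()) < input_ids.shape[-1] for span in continuation_indices)
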
@@ -811,6 +824,7 @@ def collate_fn(self, data):
         continuation_indices = []
         gold_idxs = []
         choice_groupings = []
+        max_batch_len = 1
         for data_pair in data:

             continuation_start_idx = len(continuation_indices)
@@ -825,6 +839,7 @@ def collate_fn(self, data):

             inputs.append(inp)
             continuation_indices.append(continuation_span)
+            max_batch_len = max(max_batch_len, len(context_enc) + len(continuation_enc))

             gold_idxs.append(gold_idx)
             continuation_end_idx = len(continuation_indices)
@@ -837,11 +852,14 @@ def collate_fn(self, data):
         # since the batch may consist of multiple questions, the choice_groupings indicates
         # which contiguous sequences of elements in the batch correspond to which question
         # gold_indices indicates which of the [0, N-1] choices is the correct one for each question.
+
+        max_batch_len = min(max_batch_len, self.max_seq_len)
+        input_ids = torch.stack(inputs)[..., :max_batch_len] # right padding
         batch = {
-            'input_ids': torch.stack(inputs),
+            'input_ids': input_ids,
             'continuation_indices': continuation_indices,
             'mode': 'icl_task',
-            'labels': torch.stack(inputs),
+            'labels': input_ids,
             'gold_indices': gold_idxs,
             'choice_groupings': choice_groupings
         }
@@ -1001,6 +1019,7 @@ def __len__(self):
         return len(self.encoded_dataset)

     def collate_fn(self, data):
+        max_batch_len = 1
         inputs, prompts, tests, canonical_solutions, entry_points, test_inputs, test_outputs, languages = [], [], [], [], [], [], [], []
         for sample in data:
             preamble, prompt, text_prompt, canonical_solution, test, entry_point, test_input, test_output, language = (
@@ -1027,10 +1046,14 @@ def collate_fn(self, data):
             entry_points.append(entry_point)
             test_inputs.append(test_input)
             test_outputs.append(test_output)
+            max_batch_len = max(max_batch_len, len(context_enc))
             languages.append(language)

+        max_batch_len = min(max_batch_len, self.max_seq_len)
+        # Truncate pad token from left padding
+        input_ids = torch.stack(inputs)[..., -max_batch_len:]
         batch = {
-            'input_ids': torch.stack(inputs),
+            'input_ids': input_ids,
             'mode': 'generate',
             'labels': canonical_solutions,
             'prompts': prompts, # list of prompts
@@ -1335,4 +1358,4 @@ def get_icl_task_dataloader(
         question_prelimiter,
         fewshot_random_seed,
         generations_per_sample,
-    )
+    )