Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for custom replacement_fn #687

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

ByrdOfAFeather
Copy link

@ByrdOfAFeather ByrdOfAFeather commented Aug 7, 2022

Rather than removing text which can create oddities, we may want to consider ways to replace tokens that would otherwise be removed. I added a support for a custom replacement_fn, which is similar to the classifier_fn. My particular use case was using T5, as such, I modified the generation of perturbed data to be in batch style rather than going one at a time.

This solves partially #648

Example replacement_fn:

def t5_wrapper(text_as_list: List[str], masks: list[list[bool]]):
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    out_refs = []
    masker_idxs = []
    outs = []
    for mask in masks:
        local_out = ""
        local_out_ref = ""
        local_masker_idx = 0
        for idx in range(len(mask)):
            if mask[idx]:
                local_out += text_as_list[idx]
                local_out_ref += text_as_list[idx]
            else:
                try:
                    local_out += tokenizer.additional_special_tokens[local_masker_idx]
                    local_masker_idx += 1
                except IndexError:
                    continue
        masker_idxs.append(local_masker_idx)
        outs.append(local_out)
        out_refs.append(local_out_ref)

    model.cuda()
    batch_size = 50
    if len(outs) > batch_size:
        input_ids = tokenizer(outs, return_tensors="pt", padding=True, max_length=512, truncation=True)
        model_suggestions = []
        for idx in range(0, len(input_ids.input_ids), batch_size):
            local_inputs = {}
            for key, value in input_ids.items():
                local_inputs[key] = value[idx: idx+batch_size]
            for key, value in local_inputs.items():
                local_inputs[key] = value.cuda()
            outputs = model.generate(**local_inputs)
            model_suggestions.extend(tokenizer.batch_decode(outputs, skip_special_tokens=False))
    else:
        input_ids = tokenizer(outs, return_tensors="pt", padding=True)
        for key, value in input_ids.items():
            input_ids[key] = value.cuda()
        outputs = model.generate(**input_ids)
        model_suggestions = tokenizer.batch_decode(outputs, skip_special_tokens=False)

    inversed_data = []
    for idx, suggestion in enumerate(model_suggestions):
        local_out = outs[idx]
        local_masker_idx = masker_idxs[idx]
        present_tokens = [tokenizer.additional_special_tokens[idx] for idx in range(local_masker_idx) if
                          tokenizer.additional_special_tokens[idx] in suggestion]
        for idx, present in enumerate(present_tokens):
            if idx == len(present_tokens) - 1:
                index = suggestion.find(present)
                start_idx = index + len(present)
                local_out = local_out.replace(present, suggestion[start_idx:])
            else:
                base_index = suggestion.find(present)
                start_idx = base_index + len(present)
                upper_index = suggestion.find(present_tokens[idx + 1])
                local_out = local_out.replace(present, suggestion[start_idx:upper_index])
        for item in tokenizer.additional_special_tokens:
            local_out = local_out.replace(item, "")
        inversed_data.append(local_out)
    return inversed_data

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant