Adding return_lowest_perplexity #206

Open · wants to merge 1 commit into master
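This PR adds a return_lowest_perplexity flag to generate_simple: the single input prompt is repeated across the cache batch, batch_size completions are sampled in parallel, and only the completion with the lowest perplexity (judged by mean token log-probability) is decoded and returned. A minimal usage sketch, assuming a generator whose cache was allocated with batch_size > 1; the setup objects below are illustrative, and only the flag itself comes from this diff:

# Hypothetical usage; "generator" is assumed to be an ExLlamaV2BaseGenerator
# whose cache was created with batch_size > 1. Only the
# return_lowest_perplexity flag is new in this PR.
settings = ExLlamaV2Sampler.Settings()

text = generator.generate_simple(
    "Once upon a time,",             # must be a single string (asserted below)
    settings,
    num_tokens = 128,
    return_lowest_perplexity = True)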
35 changes: 28 additions & 7 deletions exllamav2/generator/base.py
@@ -52,7 +52,8 @@ def generate_simple(self, prompt: str or list,
                         encode_special_tokens = False,
                         decode_special_tokens = False,
                         loras = None,
-                        stop_token = -1):
+                        stop_token = -1,
+                        return_lowest_perplexity = False):
 
         # Default stop token
 
@@ -68,14 +69,20 @@ def generate_simple(self, prompt: str or list,
 
         # Tokenize input and produce padding mask if needed
 
-        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        if return_lowest_perplexity:
+            batch_size = self.cache.batch_size
+        else:
+            batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        assert batch_size > 1 or not return_lowest_perplexity, "When return_lowest_perplexity is set, batch_size should be greater than 1"
+        assert isinstance(prompt, str) or not return_lowest_perplexity, "When return_lowest_perplexity is set, the prompt should be a single string"
         ids, position_offsets = self.tokenizer.encode(prompt, encode_special_tokens = encode_special_tokens, return_offsets = True)
-        if batch_size == 1: position_offsets = None
+        if batch_size == 1 or return_lowest_perplexity: position_offsets = None
+        if return_lowest_perplexity: ids = ids.repeat(batch_size, 1)
 
         overflow = ids.shape[-1] + num_tokens - self.model.config.max_seq_len
         if overflow > 0: ids = ids[:, overflow:]
 
-        mask = self.tokenizer.padding_mask(ids) if batch_size > 1 else None
+        mask = self.tokenizer.padding_mask(ids) if batch_size > 1 and not return_lowest_perplexity else None
 
         # Prepare for healing
 
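The ids.repeat call above turns the single tokenized prompt of shape (1, seq_len) into batch_size identical rows, one per cache slot, so every slot samples its own candidate completion. A tiny standalone illustration of that tiling (token values made up):

import torch

ids = torch.tensor([[101, 7592, 2088]])   # (1, seq_len): one tokenized prompt
batch = ids.repeat(4, 1)                  # (4, seq_len): identical row per slot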
@@ -102,16 +109,21 @@ def generate_simple(self, prompt: str or list,
         # Generate tokens
 
         batch_eos = [False] * batch_size
+        if return_lowest_perplexity:
+            logprob_sum = torch.zeros(batch_size)
+            sequence_length = torch.zeros(batch_size)
 
         for i in range(num_tokens):
 
             logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, input_mask = mask, loras = loras, position_offsets = position_offsets).float().cpu()
-            token, _, _ = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids, random.random(), self.tokenizer, prefix_token = unhealed_token)
+            token, output_probs, _ = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids, random.random(), self.tokenizer, prefix_token = unhealed_token)
 
             eos = False
             if stop_token is not None:
                 for b in range(batch_size):
                     if token[b, 0].item() == stop_token:
+                        if return_lowest_perplexity and not batch_eos[b]:
+                            sequence_length[b] = i
                         batch_eos[b] = True
                         if all(batch_eos): eos = True
                     if batch_eos[b]:
@@ -120,14 +132,23 @@ def generate_simple(self, prompt: str or list,
             self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
             gen_settings.feed_filters(token)
 
+            if return_lowest_perplexity:
+                logprob_sum = torch.add(logprob_sum,
+                                        torch.log(torch.squeeze(output_probs, -1)))
+
             unhealed_token = None
             if eos: break
 
         # Decode
 
-        text = self.tokenizer.decode(self.sequence_ids, decode_special_tokens = decode_special_tokens)
+        if return_lowest_perplexity:
+            mean_log_prob = torch.div(logprob_sum, sequence_length)
+            lowest_perplexed_index = torch.argmin(mean_log_prob).item()
+            text = self.tokenizer.decode(self.sequence_ids[lowest_perplexed_index], decode_special_tokens = decode_special_tokens)
+        else:
+            text = self.tokenizer.decode(self.sequence_ids, decode_special_tokens = decode_special_tokens)
+            text = text[0] if isinstance(prompt, str) else text
 
-        if isinstance(prompt, str): return text[0]
         return text
 
 
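For reference on the selection step: perplexity is exp of the negative mean token log-probability, so the lowest-perplexity candidate is the one whose mean log-probability is highest. A standalone sketch of that ranking, independent of the exllamav2 internals and using made-up per-token probabilities:

import torch

# Per-step probabilities of the sampled token for two candidate sequences.
token_probs = torch.tensor([[0.9, 0.8, 0.7],
                            [0.5, 0.4, 0.6]])

mean_log_prob = torch.log(token_probs).mean(dim = -1)  # higher = more likely
perplexity = torch.exp(-mean_log_prob)                 # lower = less perplexed
best = torch.argmin(perplexity).item()                 # 0, the likelier sequence

This is equivalent to torch.argmax over mean_log_prob; note that the decode branch in the diff takes torch.argmin over mean_log_prob, which ranks in the opposite direction and would select the most perplexed candidate.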