Commit

This commit adds a statistical test to check the expected length of a simple regex. It has a frozen seed to avoid flakiness, but there is information in there to make it a proper probabilistic test. It's also possible to reduce the run length by taking fewer samples.
dpsimpson committed Oct 17, 2024
1 parent e2b8831 commit 47f5215
Showing 1 changed file with 28 additions and 38 deletions.
tests/statistical/test_generate.py
@@ -3,78 +3,68 @@
from typing import List, Optional
from outlines_core.fsm.guide import RegexGuide


def test_generate_length():
    class MockTokenizer:
        vocabulary = {"0": 1, "1": 2, "eos": 3}
        inverse_vocabulary = {1: "0", 2: "1", 3: ""}
        special_tokens = {"eos"}
        eos_token_id = 3

        def length(self):
            return len(self.vocabulary)

        def convert_token_to_string(self, token):
            return token

        def decode(self, token):
            return self.inverse_vocabulary[token]

    class NextToken:
        probs: dict[int, List[float]] = {
-            1: [0.3, 0.4, 0.0],
-            2: [0.4, 0.3, 0.1],
+            1: [0.2, 0.5, 0.3],
+            2: [0.3, 0.4, 0.3],
            3: [0, 0, 0]
        }
        p0: List[float] = [0.2, 0.8, 0.0]
        states: List[int] = [1, 2, 3]

        def __call__(self, token: Optional[int], *, mask: List[int]) -> int:
            if token is None:
                prob = [p * m for (p, m) in zip(self.p0, mask)]
            elif token in self.states:
                prob = [p * m for (p, m) in zip(self.probs[token], mask)]
            else:
                raise ValueError("Should not be reached")
-            return np.random.sample(self.states, p = prob)
+            return np.random.choice(self.states, p=prob / np.sum(prob))

    def generate(model, tokenizer, regex_str) -> str:
        out_str: str = ""
        n_tokens = tokenizer.length()

        fsm = RegexGuide.from_regex(regex_str, tokenizer)

-        state = 0
+        state: int = fsm.initial_state
        token = None
-        while not fsm.is_final_state(state):
-            allowed = fsm.get_next_instruction(state)
-            mask = [1 if s in allowed else 0 for s in range(1, n_tokens + 1)]
-            token = model(token, mask = mask)
-            out_str += tokenizer.inverse_vocabulary[token]
+        while state != -1:
+            allowed = fsm.get_next_instruction(state).tokens
+            mask: List[int] = [1 if s in allowed else 0 for s in range(1, n_tokens + 1)]
+            token = model(token, mask=mask)
+            out_str += tokenizer.decode(token)
            state = fsm.get_next_state(state, token)
        return out_str

    n_samples: int = 1000
    regex_str: str = r"11[01]+|0[01]*"
    tokenizer = MockTokenizer()
    model = NextToken()

-    tot = 0
+    np.random.seed(30127)
+
+    tot: int = 0
    for i in range(n_samples):
-        out = generate(model, tokenizer, regex_str)
+        out: str = generate(model, tokenizer, regex_str)
        # print(out)
        tot += len(out)

-    mean = tot / n_samples
-    print(mean)
-
-
-if __name__ == "__main__":
-    test_generate_length()
+    mean: float = tot / n_samples
+    # mean ~ N(4.93, (2.9/sqrt(n))^2)
+    assert mean == 4.88
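
The 4.93 in the new comment is consistent with a direct calculation from the transition probabilities above (this reading is an inference from the test code, not stated in the commit): a string starting with "1" (probability 0.8) is forced to three characters before eos becomes legal under 11[01]+, a string starting with "0" (probability 0.2) may stop after a single character under 0[01]*, and in both branches every later position ends the string with probability 0.3, contributing 0.7/0.3 ≈ 2.33 expected extra characters.

# Rough check of the expected length quoted in the test's comment
# (inferred from the transition probabilities; not part of the commit).
tail = 0.7 / 0.3                                   # expected extra characters once eos is legal
expected_length = 0.8 * (3 + tail) + 0.2 * (1 + tail)
print(expected_length)                             # 4.933..., matching the 4.93 in the comment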

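As the commit message notes, the frozen seed and exact-equality assertion could be replaced by a proper probabilistic test. A minimal sketch of what that could look like, using the N(4.93, (2.9/sqrt(n))^2) approximation from the comment in the diff; the 4-standard-error tolerance and the helper name are assumptions, not part of the commit:

# Hypothetical probabilistic variant (a sketch, not part of this commit):
# instead of freezing the seed and asserting an exact mean, bound the Monte Carlo error.
def check_mean_length(sample_mean: float, n_samples: int) -> None:
    expected_mean = 4.93                             # analytic expectation from the comment above
    expected_sd = 2.9                                # analytic per-sample standard deviation
    tolerance = 4 * expected_sd / n_samples ** 0.5   # ~4 standard errors (assumed cutoff)
    assert abs(sample_mean - expected_mean) < tolerance

check_mean_length(4.88, 1000)                        # the frozen-seed run passes comfortably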