diff --git a/benchmarks/__init__.py b/bindings/python/benchmarks/__init__.py similarity index 100% rename from benchmarks/__init__.py rename to bindings/python/benchmarks/__init__.py diff --git a/benchmarks/asv.conf.json b/bindings/python/benchmarks/asv.conf.json similarity index 100% rename from benchmarks/asv.conf.json rename to bindings/python/benchmarks/asv.conf.json diff --git a/benchmarks/bench_json_schema.py b/bindings/python/benchmarks/bench_json_schema.py similarity index 100% rename from benchmarks/bench_json_schema.py rename to bindings/python/benchmarks/bench_json_schema.py diff --git a/benchmarks/bench_numba_compile.py b/bindings/python/benchmarks/bench_numba_compile.py similarity index 100% rename from benchmarks/bench_numba_compile.py rename to bindings/python/benchmarks/bench_numba_compile.py diff --git a/benchmarks/bench_regex_guide.py b/bindings/python/benchmarks/bench_regex_guide.py similarity index 100% rename from benchmarks/bench_regex_guide.py rename to bindings/python/benchmarks/bench_regex_guide.py diff --git a/benchmarks/common.py b/bindings/python/benchmarks/common.py similarity index 100% rename from benchmarks/common.py rename to bindings/python/benchmarks/common.py diff --git a/tests/__init__.py b/bindings/python/tests/__init__.py similarity index 100% rename from tests/__init__.py rename to bindings/python/tests/__init__.py diff --git a/tests/fsm/partial_python.lark b/bindings/python/tests/fsm/partial_python.lark similarity index 100% rename from tests/fsm/partial_python.lark rename to bindings/python/tests/fsm/partial_python.lark diff --git a/bindings/python/tests/fsm/test_fsm.py b/bindings/python/tests/fsm/test_fsm.py new file mode 100644 index 00000000..aeb7060c --- /dev/null +++ b/bindings/python/tests/fsm/test_fsm.py @@ -0,0 +1,91 @@ +import pytest +from outlines_core.fsm.fsm import RegexFSM, StopAtEosFSM + + +def assert_expected_tensor_ids(tensor, ids): + assert len(tensor) == len(ids) + norm_tensor = sorted(map(int, tensor)) + 
norm_ids = sorted(map(int, ids)) + assert norm_tensor == norm_ids, (norm_tensor, norm_ids) + + +def test_stop_at_eos(): + class MockTokenizer: + vocabulary = {"a": 1, "eos": 2} + eos_token_id = 2 + + with pytest.warns(UserWarning): + fsm = StopAtEosFSM(MockTokenizer()) + + assert fsm.allowed_token_ids(fsm.start_state) is None + assert fsm.allowed_token_ids(fsm.final_state) == [2] + assert fsm.next_state(fsm.start_state, 2) == fsm.final_state + assert fsm.next_state(fsm.start_state, 1) == fsm.start_state + assert fsm.is_final_state(fsm.start_state) is False + assert fsm.is_final_state(fsm.final_state) is True + + +def test_regex_vocabulary_error(): + class MockTokenizer: + vocabulary = {"a": 1} + special_tokens = {"eos"} + eos_token_id = 3 + + def convert_token_to_string(self, token): + return token + + regex_str = "[1-9]" + + with pytest.raises(ValueError, match="The vocabulary"): + RegexFSM(regex_str, MockTokenizer()) + + +def test_regex(): + class MockTokenizer: + vocabulary = {"1": 1, "a": 2, "eos": 3} + special_tokens = {"eos"} + eos_token_id = 3 + + def convert_token_to_string(self, token): + return token + + regex_str = "[1-9]" + tokenizer = MockTokenizer() + + with pytest.warns(UserWarning): + fsm = RegexFSM(regex_str, tokenizer) + + assert fsm.states_to_token_maps == {0: {1: 1}} + assert_expected_tensor_ids(fsm.allowed_token_ids(state=0), [1]) + assert fsm.next_state(state=0, token_id=1) == 1 + assert fsm.next_state(state=0, token_id=tokenizer.eos_token_id) == -1 + + assert fsm.is_final_state(0) is False + + for state in fsm.final_states: + assert fsm.is_final_state(state) is True + + +def test_regex_final_state(): + """Make sure that the FSM stays in the final state as we keep generating""" + + class MockTokenizer: + vocabulary = {"`": 101, ".": 102, "\n": 103, "eos": 104} + special_tokens = {"eos"} + eos_token_id = 104 + + def convert_token_to_string(self, token): + return token + + regex_str = r"`\n(\.\n)?`\n" + tokenizer = MockTokenizer() + + with 
pytest.warns(UserWarning): + fsm = RegexFSM(regex_str, tokenizer) + + state = fsm.next_state(state=4, token_id=103) + assert state == 5 + assert fsm.is_final_state(state) + + state = fsm.next_state(state=5, token_id=103) + assert fsm.is_final_state(state) diff --git a/bindings/python/tests/fsm/test_guide.py b/bindings/python/tests/fsm/test_guide.py new file mode 100644 index 00000000..0bd28d4f --- /dev/null +++ b/bindings/python/tests/fsm/test_guide.py @@ -0,0 +1,189 @@ +import pytest +from outlines_core.fsm.guide import Generate, RegexGuide, StopAtEOSGuide, Write + + +def assert_expected_tensor_ids(tensor, ids): + assert len(tensor) == len(ids) + norm_tensor = sorted(map(int, tensor)) + norm_ids = sorted(map(int, ids)) + assert norm_tensor == norm_ids, (norm_tensor, norm_ids) + + +def test_stop_at_eos(): + class MockTokenizer: + vocabulary = {"a": 1, "eos": 2} + eos_token_id = 2 + + fsm = StopAtEOSGuide(MockTokenizer()) + + instruction = fsm.get_next_instruction(fsm.start_state) + assert isinstance(instruction, Generate) + assert instruction.tokens is None + + instruction = fsm.get_next_instruction(fsm.final_state) + assert isinstance(instruction, Write) + assert instruction.tokens == [2] + + assert fsm.get_next_state(fsm.start_state, 2) == fsm.final_state + assert fsm.get_next_state(fsm.start_state, 1) == fsm.start_state + assert fsm.is_final_state(fsm.start_state) is False + assert fsm.is_final_state(fsm.final_state) is True + + +def test_regex_vocabulary_error(): + class MockTokenizer: + vocabulary = {"a": 1} + special_tokens = {"eos"} + eos_token_id = 3 + + def convert_token_to_string(self, token): + return token + + regex_str = "[1-9]" + + with pytest.raises(ValueError, match="The vocabulary"): + RegexGuide(regex_str, MockTokenizer()) + + +def test_regex(): + class MockTokenizer: + vocabulary = {"1": 1, "a": 2, "eos": 3} + special_tokens = {"eos"} + eos_token_id = 3 + + def convert_token_to_string(self, token): + return token + + regex_str = "[1-9]" + 
tokenizer = MockTokenizer() + fsm = RegexGuide(regex_str, tokenizer) + + assert fsm.states_to_token_maps == {0: {1: 1}} + + instruction = fsm.get_next_instruction(0) + assert isinstance(instruction, Generate) + assert_expected_tensor_ids(instruction.tokens, [1]) + + assert fsm.get_next_state(state=0, token_id=1) == 1 + assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1 + + assert fsm.is_final_state(0) is False + + for state in fsm.final_states: + assert fsm.is_final_state(state) is True + + +def test_regex_multi_byte_llama_like(): + class MockTokenizer: + vocabulary = { + "1": 1, + "a": 2, + "eos": 3, + "😍": 4, + "<0xF0>": 5, + "<0x9F>": 6, + "<0x98>": 7, + "<0x88>": 8, # 😈 + "\ufffd": 9, + "\ufffd\ufffd": 10, + } + special_tokens = {"eos"} + eos_token_id = 3 + + def convert_token_to_string(self, token): + if token[0] == "<": + return "\ufffd" + return token + + regex_str = "[😁-😎]" + tokenizer = MockTokenizer() + fsm = RegexGuide(regex_str, tokenizer) + + assert fsm.states_to_token_maps == { + 0: {5: 1, 4: 2}, + 1: {6: 3}, + 3: {7: 4}, + 4: {8: 2}, + } + + instruction = fsm.get_next_instruction(0) + assert isinstance(instruction, Generate) + assert_expected_tensor_ids(instruction.tokens, [5, 4]) + + assert fsm.get_next_state(state=0, token_id=5) == 1 + assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1 + + assert fsm.is_final_state(0) is False + + for state in fsm.final_states: + assert fsm.is_final_state(state) is True + + +def test_regex_multi_byte_gpt2_like(): + class MockTokenizer: + vocabulary = { + "1": 1, + "a": 2, + "eos": 3, + "😍": 4, + " ": 5, + "\ufffd": 6, + "\ufffd\ufffd": 7, + "ðŁĺ": 8, + "Δͺ": 9, # '😈' + "Δ Γ°": 10, + "ŁĺΔͺ": 11, # ' 😈' + } + special_tokens = {"eos"} + eos_token_id = 3 + + def convert_token_to_string(self, token): + if self.vocabulary[token] >= 8: + return "\ufffd" + return token + + regex_str = " [😁-😎]" + tokenizer = MockTokenizer() + fsm = RegexGuide(regex_str, tokenizer) + + assert 
fsm.states_to_token_maps == { + 0: {5: 1, 10: 2}, + 1: {8: 5, 4: 3}, + 2: {11: 3}, + 5: {9: 3}, + } + + instruction = fsm.get_next_instruction(0) + assert isinstance(instruction, Generate) + assert_expected_tensor_ids(instruction.tokens, [5, 10]) + + assert fsm.get_next_state(state=0, token_id=5) == 1 + assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1 + + assert fsm.is_final_state(0) is False + + for state in fsm.final_states: + assert fsm.is_final_state(state) is True + + +def test_regex_final_state(): + """Make sure that the FSM stays in the final state as we keep generating""" + + class MockTokenizer: + vocabulary = {"`": 101, ".": 102, "\n": 103, "eos": 104} + special_tokens = {"eos"} + eos_token_id = 104 + + def convert_token_to_string(self, token): + return token + + regex_str = r"`\n(\.\n)?`\n" + tokenizer = MockTokenizer() + fsm = RegexGuide(regex_str, tokenizer) + + state = fsm.get_next_state(state=4, token_id=103) + assert state == 5 + assert fsm.is_final_state(state) + + state = fsm.get_next_state(state=5, token_id=103) + assert fsm.is_final_state(state) diff --git a/bindings/python/tests/fsm/test_json_schema.py b/bindings/python/tests/fsm/test_json_schema.py new file mode 100644 index 00000000..3fa3d79c --- /dev/null +++ b/bindings/python/tests/fsm/test_json_schema.py @@ -0,0 +1,1040 @@ +import json +import re +from typing import List, Literal, Union + +import interegular +import pytest +from outlines_core.fsm.json_schema import ( + BOOLEAN, + DATE, + DATE_TIME, + INTEGER, + NULL, + NUMBER, + STRING, + STRING_INNER, + TIME, + UUID, + WHITESPACE, + build_regex_from_schema, + get_schema_from_signature, + to_regex, +) +from pydantic import BaseModel, Field, constr + + +def test_function_basic(): + def test_function(foo: str, bar: List[int]): + pass + + result = get_schema_from_signature(test_function) + assert result["type"] == "object" + assert list(result["properties"].keys()) == ["foo", "bar"] + assert 
result["properties"]["foo"]["type"] == "string" + assert result["properties"]["bar"]["type"] == "array" + assert result["properties"]["bar"]["items"]["type"] == "integer" + + +def test_function_no_type(): + def test_function(foo, bar: List[int]): + pass + + with pytest.raises(ValueError): + get_schema_from_signature(test_function) + + +def test_from_pydantic(): + class User(BaseModel): + user_id: int + name: str + maxlength_name: constr(max_length=10) + minlength_name: constr(min_length=10) + value: float + is_true: bool + + schema = json.dumps(User.model_json_schema()) + schedule = build_regex_from_schema(schema) + assert isinstance(schedule, str) + + +@pytest.mark.parametrize( + "pattern,does_match", + [ + ({"integer": "0"}, True), + ({"integer": "1"}, True), + ({"integer": "-1"}, True), + ({"integer": "01"}, False), + ({"integer": "1.3"}, False), + ({"integer": "t"}, False), + ], +) +def test_match_integer(pattern, does_match): + step = {"title": "Foo", "type": "integer"} + regex = to_regex(None, step) + assert regex == INTEGER + + value = pattern["integer"] + match = re.fullmatch(regex, value) + if does_match: + assert match[0] == value + assert match.span() == (0, len(value)) + else: + assert match is None + + +@pytest.mark.parametrize( + "pattern,does_match", + [ + ({"number": "1"}, True), + ({"number": "0"}, True), + ({"number": "01"}, False), + ({"number": ".3"}, False), + ({"number": "1.3"}, True), + ({"number": "-1.3"}, True), + ({"number": "1.3e9"}, False), + ({"number": "1.3e+9"}, True), + ], +) +def test_match_number(pattern, does_match): + step = {"title": "Foo", "type": "number"} + regex = to_regex(None, step) + assert regex == NUMBER + + value = pattern["number"] + match = re.fullmatch(regex, value) + if does_match: + assert match[0] == value + assert match.span() == (0, len(value)) + else: + assert match is None + + +@pytest.mark.parametrize( + "schema,regex,examples", + [ + # String + ( + {"title": "Foo", "type": "string"}, + STRING, + [ + 
("unquotedstring", False), + ('"(parenthesized_string)"', True), + ('"malformed) parenthesis (((() string"', True), + ('"quoted_string"', True), + (r'"escape_\character"', False), + (r'"double_\\escape"', True), + (r'"\n"', False), + (r'"\\n"', True), + (r'"unescaped " quote"', False), + (r'"escaped \" quote"', True), + ], + ), + # String with maximum length + ( + {"title": "Foo", "type": "string", "maxLength": 3}, + f'"{STRING_INNER}{{,3}}"', + [('"ab"', True), ('"a""', False), ('"abcd"', False)], + ), + # String with minimum length + ( + {"title": "Foo", "type": "string", "minLength": 3}, + f'"{STRING_INNER}{{3,}}"', + [('"ab"', False), ('"abcd"', True), ('"abc""', False)], + ), + # String with both minimum and maximum length + ( + {"title": "Foo", "type": "string", "minLength": 3, "maxLength": 5}, + f'"{STRING_INNER}{{3,5}}"', + [('"ab"', False), ('"abcd"', True), ('"abcdef""', False)], + ), + # String defined by a regular expression + ( + {"title": "Foo", "type": "string", "pattern": r"^[a-z]$"}, + r'("[a-z]")', + [('"a"', True), ('"1"', False)], + ), + # Boolean + ( + {"title": "Foo", "type": "boolean"}, + BOOLEAN, + [ + ("true", True), + ("false", True), + ("null", False), + ("0", False), + ], + ), + # Null + ( + {"title": "Foo", "type": "null"}, + NULL, + [ + ("null", True), + ("true", False), + ("0", False), + ], + ), + # Const string + ( + {"title": "Foo", "const": "Marc", "type": "string"}, + '"Marc"', + [('"Marc"', True), ('"Jean"', False), ('"John"', False)], + ), + # Make sure strings are escaped with regex escaping + ( + {"title": "Foo", "const": ".*", "type": "string"}, + r'"\.\*"', + [('".*"', True), (r'"\s*"', False), (r'"\.\*"', False)], + ), + # Make sure strings are escaped with JSON escaping + ( + {"title": "Foo", "const": '"', "type": "string"}, + r'"\\""', + [('"\\""', True), ('"""', False)], + ), + # Const integer + ( + {"title": "Foo", "const": 0, "type": "integer"}, + "0", + [("0", True), ("1", False), ("a", False)], + ), + # Const float + 
( + {"title": "Foo", "const": 0.2, "type": "float"}, + r"0\.2", + [("0.2", True), ("032", False)], + ), + # Const boolean + ( + {"title": "Foo", "const": True, "type": "boolean"}, + "true", + [("true", True), ("True", False)], + ), + # Const null + ( + {"title": "Foo", "const": None, "type": "null"}, + "null", + [("null", True), ("None", False), ("", False)], + ), + # Enum string + ( + {"title": "Foo", "enum": ["Marc", "Jean"], "type": "string"}, + '("Marc"|"Jean")', + [('"Marc"', True), ('"Jean"', True), ('"John"', False)], + ), + # Make sure strings are escaped with regex and JSON escaping + ( + {"title": "Foo", "enum": [".*", r"\s*"], "type": "string"}, + r'("\.\*"|"\\\\s\*")', + [('".*"', True), (r'"\\s*"', True), (r'"\.\*"', False)], + ), + # Enum integer + ( + {"title": "Foo", "enum": [0, 1], "type": "integer"}, + "(0|1)", + [("0", True), ("1", True), ("a", False)], + ), + # Enum mix of types + ( + {"title": "Foo", "enum": [6, 5.3, "potato", True, None]}, + r'(6|5\.3|"potato"|true|null)', + [ + ("6", True), + ("5.3", True), + ('"potato"', True), + ("true", True), + ("null", True), + ("523", False), + ("True", False), + ("None", False), + ], + ), + # integer + ( + { + "title": "Foo", + "type": "object", + "properties": {"count": {"title": "Count", "type": "integer"}}, + "required": ["count"], + }, + '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)[ ]?\\}', + [('{ "count": 100 }', True)], + ), + # integer with minimum digits + ( + { + "title": "Foo", + "type": "object", + "properties": { + "count": {"title": "Count", "type": "integer", "minDigits": 3} + }, + "required": ["count"], + }, + '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{2,})[ ]?\\}', + [('{ "count": 10 }', False), ('{ "count": 100 }', True)], + ), + # integer with maximum digits + ( + { + "title": "Foo", + "type": "object", + "properties": { + "count": {"title": "Count", "type": "integer", "maxDigits": 3} + }, + "required": ["count"], + }, + '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{,2})[ ]?\\}', + [('{ 
"count": 100 }', True), ('{ "count": 1000 }', False)], + ), + # integer with minimum and maximum digits + ( + { + "title": "Foo", + "type": "object", + "properties": { + "count": { + "title": "Count", + "type": "integer", + "minDigits": 3, + "maxDigits": 5, + } + }, + "required": ["count"], + }, + '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{2,4})[ ]?\\}', + [ + ('{ "count": 10 }', False), + ('{ "count": 100 }', True), + ('{ "count": 10000 }', True), + ('{ "count": 100000 }', False), + ], + ), + # number + ( + { + "title": "Foo", + "type": "object", + "properties": {"count": {"title": "Count", "type": "number"}}, + "required": ["count"], + }, + '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\\.[0-9]+)?([eE][+-][0-9]+)?[ ]?\\}', + [('{ "count": 100 }', True), ('{ "count": 100.5 }', True)], + ), + # number with min and max integer digits + ( + { + "title": "Foo", + "type": "object", + "properties": { + "count": { + "title": "Count", + "type": "number", + "minDigitsInteger": 3, + "maxDigitsInteger": 5, + } + }, + "required": ["count"], + }, + '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]{2,4}))(\\.[0-9]+)?([eE][+-][0-9]+)?[ ]?\\}', + [ + ('{ "count": 10.005 }', False), + ('{ "count": 100.005 }', True), + ('{ "count": 10000.005 }', True), + ('{ "count": 100000.005 }', False), + ], + ), + # number with min and max fraction digits + ( + { + "title": "Foo", + "type": "object", + "properties": { + "count": { + "title": "Count", + "type": "number", + "minDigitsFraction": 3, + "maxDigitsFraction": 5, + } + }, + "required": ["count"], + }, + '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\\.[0-9]{3,5})?([eE][+-][0-9]+)?[ ]?\\}', + [ + ('{ "count": 1.05 }', False), + ('{ "count": 1.005 }', True), + ('{ "count": 1.00005 }', True), + ('{ "count": 1.000005 }', False), + ], + ), + # number with min and max exponent digits + ( + { + "title": "Foo", + "type": "object", + "properties": { + "count": { + "title": "Count", + "type": "number", + "minDigitsExponent": 3, + "maxDigitsExponent": 5, 
+ } + }, + "required": ["count"], + }, + '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\\.[0-9]+)?([eE][+-][0-9]{3,5})?[ ]?\\}', + [ + ('{ "count": 1.05e1 }', False), + ('{ "count": 1.05e+001 }', True), + ('{ "count": 1.05e-00001 }', True), + ('{ "count": 1.05e0000001 }', False), + ], + ), + # number with min and max integer, fraction and exponent digits + ( + { + "title": "Foo", + "type": "object", + "properties": { + "count": { + "title": "Count", + "type": "number", + "minDigitsInteger": 3, + "maxDigitsInteger": 5, + "minDigitsFraction": 3, + "maxDigitsFraction": 5, + "minDigitsExponent": 3, + "maxDigitsExponent": 5, + } + }, + "required": ["count"], + }, + '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]{2,4}))(\\.[0-9]{3,5})?([eE][+-][0-9]{3,5})?[ ]?\\}', + [ + ('{ "count": 1.05e1 }', False), + ('{ "count": 100.005e+001 }', True), + ('{ "count": 10000.00005e-00001 }', True), + ('{ "count": 100000.000005e0000001 }', False), + ], + ), + # array + ( + {"title": "Foo", "type": "array", "items": {"type": "number"}}, + rf"\[{WHITESPACE}(({NUMBER})(,{WHITESPACE}({NUMBER})){{0,}})?{WHITESPACE}\]", + [("[1e+9,1.3]", True), ("[]", True), ("[1", False)], + ), + # array with a set length of 1 + ( + { + "title": "Foo", + "type": "array", + "items": {"type": "integer"}, + "minItems": 1, + "maxItems": 1, + }, + rf"\[{WHITESPACE}(({INTEGER})(,{WHITESPACE}({INTEGER})){{0,0}}){WHITESPACE}\]", + [("[1]", True), ("[1,2]", False), ('["a"]', False), ("[]", False)], + ), + # array with a set length greather than 1 + ( + { + "title": "Foo", + "type": "array", + "items": {"type": "integer"}, + "minItems": 3, + "maxItems": 3, + }, + rf"\[{WHITESPACE}(({INTEGER})(,{WHITESPACE}({INTEGER})){{2,2}}){WHITESPACE}\]", + [("[1]", False), ("[]", False), ("[1,2,3]", True), ("[1,2,3,4]", False)], + ), + # array with length 0 + ( + { + "title": "Foo", + "type": "array", + "items": {"type": "integer"}, + "minItems": 0, + "maxItems": 0, + }, + rf"\[{WHITESPACE}\]", + [("[1]", False), ("[]", True), 
("[1,2,3]", False), ("[1,2,3,4]", False)], + ), + # object + ( + { + "title": "TestSchema", + "type": "object", + "properties": { + "test_dict": { + "title": "Test Dict", + "additionalProperties": {"type": "string"}, + "type": "object", + } + }, + "required": ["test_dict"], + }, + rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{STRING}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{STRING}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""", + [ + ("""{ "test_dict":{"foo":"bar","baz": "bif"}}""", True), + ("""{ "test_dict":{"foo":"bar" }}""", True), + ("""{ "test_dict":{}}""", True), + ("""{ "WRONG_KEY":{}}""", False), + ("""{ "test_dict":{"wrong_type" 1}}""", False), + ], + ), + # object containing object + ( + { + "title": "TestSchema", + "type": "object", + "properties": { + "test_dict": { + "title": "Test Dict", + "additionalProperties": { + "additionalProperties": {"type": "integer"}, + "type": "object", + }, + "type": "object", + } + }, + "required": ["test_dict"], + }, + rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""", + [ + ( + """{"test_dict": {"foo": {"bar": 123, "apple": 99}, "baz": {"bif": 456}}}""", + True, + ), + ( + """{"test_dict": {"anykey": {"anykey": 123}, "anykey2": {"bif": 456}}}""", + True, + ), + ("""{"test_dict": {}}""", True), + ("""{"test_dict": {"dict of empty dicts are ok": {} }}""", True), + ( + """{"test_dict": {"anykey": {"ONLY Dict[Dict]": 123}, "No Dict[int]" 1: }}""", + False, + ), + ], + ), + # oneOf + ( + { + "title": 
"Foo", + "oneOf": [{"type": "string"}, {"type": "number"}, {"type": "boolean"}], + }, + rf'((?:"{STRING_INNER}*")|(?:{NUMBER})|(?:{BOOLEAN}))', + [ + ("12.3", True), + ("true", True), + ('"a"', True), + ("null", False), + ("", False), + ("12true", False), + ('1.3"a"', False), + ('12.3true"a"', False), + ], + ), + # anyOf + ( + { + "title": "Foo", + "anyOf": [{"type": "string"}, {"type": "integer"}], + }, + rf"({STRING}|{INTEGER})", + [("12", True), ('"a"', True), ('1"a"', False)], + ), + # allOf + ( + { + "title": "Foo", + "allOf": [{"type": "string"}, {"type": "integer"}], + }, + rf"({STRING}{INTEGER})", + [('"a"1', True), ('"a"', False), ('"1"', False)], + ), + # Tuple / prefixItems + ( + { + "title": "Foo", + "prefixItems": [{"type": "string"}, {"type": "integer"}], + }, + rf"\[{WHITESPACE}{STRING}{WHITESPACE},{WHITESPACE}{INTEGER}{WHITESPACE}\]", + [('["a", 1]', True), ('["a", 1, 1]', False), ("[]", False)], + ), + # Nested schema + ( + { + "title": "Bar", + "type": "object", + "properties": { + "fuzz": { + "title": "Foo", + "type": "object", + "properties": {"spam": {"title": "Spam", "type": "integer"}}, + "required": ["spam"], + } + }, + "required": ["fuzz"], + }, + f'\\{{[ ]?"fuzz"[ ]?:[ ]?\\{{[ ]?"spam"[ ]?:[ ]?{INTEGER}[ ]?\\}}[ ]?\\}}', + [('{ "fuzz": { "spam": 100 }}', True)], + ), + # Schema with a reference + ( + { + "title": "User", + "type": "object", + "properties": { + "user_id": {"title": "User Id", "type": "integer"}, + "name": {"title": "Name", "type": "string"}, + "a": {"$ref": "#/properties/name"}, + }, + "required": ["user_id", "name", "a"], + }, + f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"a"[ ]?:[ ]?{STRING}[ ]?\\}}', + [('{"user_id": 100, "name": "John", "a": "Marc"}', True)], + ), + ( + { + "title": "User", + "type": "object", + "$defs": {"name": {"title": "Name2", "type": "string"}}, + "properties": { + "user_id": {"title": "User Id", "type": "integer"}, + "name": {"title": "Name", "type": "string"}, + 
"name2": {"$ref": "#/$defs/name"}, + }, + "required": ["user_id", "name", "name2"], + }, + f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"name2"[ ]?:[ ]?{STRING}[ ]?\\}}', + [('{"user_id": 100, "name": "John", "name2": "Marc"}', True)], + ), + ( + { + "$id": "customer", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "Customer", + "type": "object", + "properties": { + "name": {"type": "string"}, + "last_name": {"type": "string"}, + "address": {"$ref": "customer#/$defs/address"}, + }, + "required": [ + "name", + "first_name", + "last_name", + "address", + "shipping_address", + "billing_address", + ], + "$defs": { + "address": { + "title": "Address", + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "city": {"type": "string"}, + }, + "required": ["street_address", "city", "state"], + "definitions": { + "state": { + "type": "object", + "title": "State", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + } + }, + } + }, + }, + f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"last_name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"address"[ ]?:[ ]?\\{{[ ]?"city"[ ]?:[ ]?{STRING}[ ]?\\}}[ ]?\\}}', + [ + ( + '{"name": "John", "last_name": "Doe", "address": {"city": "Paris"}}', + True, + ) + ], + ), + # Optional properties + # Last required property in first position + ( + { + "properties": { + "name": {"type": "string"}, + "age": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, + "weapon": {"anyOf": [{"type": "string"}, {"type": "null"}]}, + }, + "required": ["name"], + "title": "Character", + "type": "object", + }, + f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"weapon"[ ]?:[ ]?({STRING}|null))?[ ]?\\}}', + [ + ('{ "name" : "Player" }', True), + ('{ "name" : "Player", "weapon" : "sword" }', True), + ('{ "age" : 10, "weapon" : "sword" }', False), + ], + ), + # Last required property in middle position + ( + { + "properties": { + 
"name": {"type": "string"}, + "age": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, + "weapon": {"type": "string"}, + "strength": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, + }, + "required": ["name", "weapon"], + "title": "Character", + "type": "object", + }, + f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}', + [ + ('{ "name" : "Player" , "weapon" : "sword" }', True), + ( + '{ "name" : "Player", "age" : 10, "weapon" : "sword" , "strength" : 10 }', + True, + ), + ('{ "weapon" : "sword" }', False), + ], + ), + # Last required property in last position + ( + { + "properties": { + "name": {"anyOf": [{"type": "string"}, {"type": "null"}]}, + "age": {"type": "integer"}, + "armor": {"type": "string"}, + "strength": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, + "weapon": {"title": "Weapon", "type": "string"}, + }, + "required": ["age", "armor", "weapon"], + "title": "Character", + "type": "object", + }, + f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"armor"[ ]?:[ ]?{STRING}[ ]?,([ ]?"strength"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}[ ]?\\}}', + [ + ( + '{ "name" : "Player", "age" : 10, "armor" : "plate", "strength" : 11, "weapon" : "sword" }', + True, + ), + ('{ "age" : 10, "armor" : "plate", "weapon" : "sword" }', True), + ( + '{ "name" : "Kahlhanbeh", "armor" : "plate", "weapon" : "sword" }', + False, + ), + ], + ), + # All properties are optional + ( + { + "properties": { + "name": {"anyOf": [{"type": "string"}, {"type": "null"}]}, + "age": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, + "strength": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, + }, + "title": "Character", + "type": "object", + }, + f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ 
]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?({INTEGER}|null)([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}', + [ + ('{ "name" : "Player" }', True), + ('{ "name" : "Player", "age" : 10, "strength" : 10 }', True), + ('{ "age" : 10, "strength" : 10 }', True), + ("{ }", True), + ], + ), + ], +) +def test_match(schema, regex, examples): + interegular.parse_pattern(regex) + schema = json.dumps(schema) + test_regex = build_regex_from_schema(schema) + assert test_regex == regex + + for string, does_match in examples: + match = re.fullmatch(test_regex, string) + if does_match: + if match is None: + raise ValueError(f"Expected match for '{string}'") + assert match[0] == string + assert match.span() == (0, len(string)) + else: + assert match is None + + +@pytest.mark.parametrize( + "schema,regex,examples", + [ + # UUID + ( + {"title": "Foo", "type": "string", "format": "uuid"}, + UUID, + [ + ("123e4567-e89b-12d3-a456-426614174000", False), + ('"123e4567-e89b-12d3-a456-426614174000"', True), + ('"123e4567-e89b-12d3-a456-42661417400"', False), + ('"123e4567-e89b-12d3-a456-42661417400g"', False), + ('"123e4567-e89b-12d3-a456-42661417400-"', False), + ('""', False), + ], + ), + # DATE-TIME + ( + {"title": "Foo", "type": "string", "format": "date-time"}, + DATE_TIME, + [ + ("2018-11-13T20:20:39Z", False), + ('"2018-11-13T20:20:39Z"', True), + ('"2016-09-18T17:34:02.666Z"', True), + ('"2008-05-11T15:30:00Z"', True), + ('"2021-01-01T00:00:00"', True), + ('"2022-01-10 07:19:30"', False), # missing T + ('"2022-12-10T10-04-29"', False), # incorrect separator + ('"2023-01-01"', False), + ], + ), + # DATE + ( + {"title": "Foo", "type": "string", "format": "date"}, + DATE, + [ + ("2018-11-13", False), + ('"2018-11-13"', True), + ('"2016-09-18"', True), + ('"2008-05-11"', True), + ('"2015-13-01"', False), # incorrect month + ('"2022-01"', False), # missing 
day + ('"2022/12/01"', False), # incorrect separator" + ], + ), + # TIME + ( + {"title": "Foo", "type": "string", "format": "time"}, + TIME, + [ + ("20:20:39Z", False), + ('"20:20:39Z"', True), + ('"15:30:00Z"', True), + ('"25:30:00"', False), # incorrect hour + ('"15:30"', False), # missing seconds + ('"15:30:00.000"', False), # missing Z + ('"15-30-00"', False), # incorrect separator + ('"15:30:00+01:00"', False), # incorrect separator + ], + ), + ], +) +def test_format(schema, regex, examples): + interegular.parse_pattern(regex) + schema = json.dumps(schema) + test_regex = build_regex_from_schema(schema) + assert test_regex == regex + + for string, does_match in examples: + match = re.fullmatch(test_regex, string) + if does_match: + assert match[0] == string + assert match.span() == (0, len(string)) + else: + assert match is None + + +@pytest.mark.parametrize( + "schema,examples", + [ + # NESTED UUID + ( + { + "title": "Foo", + "type": "object", + "properties": {"uuid": {"type": "string", "format": "uuid"}}, + }, + [ + ('{"uuid": "123e4567-e89b-12d3-a456-426614174000"}', True), + ('{"uuid":"123e4567-e89b-12d3-a456-42661417400"}', False), + ('{"uuid":"123e4567-e89b-12d3-a456-42661417400g"}', False), + ('{"uuid":"123e4567-e89b-12d3-a456-42661417400-"}', False), + ( + '{"uuid":123e4567-e89b-12d3-a456-426614174000}', + False, + ), # missing quotes for value + ('{"uuid":""}', False), + ], + ), + # NESTED DATE-TIME + ( + { + "title": "Foo", + "type": "object", + "properties": {"dateTime": {"type": "string", "format": "date-time"}}, + }, + [ + ('{"dateTime": "2018-11-13T20:20:39Z"}', True), + ('{"dateTime":"2016-09-18T17:34:02.666Z"}', True), + ('{"dateTime":"2008-05-11T15:30:00Z"}', True), + ('{"dateTime":"2021-01-01T00:00:00"}', True), + ('{"dateTime":"2022-01-10 07:19:30"}', False), # missing T + ('{"dateTime":"2022-12-10T10-04-29"}', False), # incorrect separator + ( + '{"dateTime":2018-11-13T20:20:39Z}', + False, + ), # missing quotes for value + 
('{"dateTime":"2023-01-01"}', False), + ], + ), + # NESTED DATE + ( + { + "title": "Foo", + "type": "object", + "properties": {"date": {"type": "string", "format": "date"}}, + }, + [ + ('{"date": "2018-11-13"}', True), + ('{"date":"2016-09-18"}', True), + ('{"date":"2008-05-11"}', True), + ('{"date":"2015-13-01"}', False), # incorrect month + ('{"date":"2022-01"}', False), # missing day + ('{"date":"2022/12/01"}', False), # incorrect separator" + ('{"date":2018-11-13}', False), # missing quotes for value + ], + ), + # NESTED TIME + ( + { + "title": "Foo", + "type": "object", + "properties": {"time": {"type": "string", "format": "time"}}, + }, + [ + ('{"time": "20:20:39Z"}', True), + ('{"time":"15:30:00Z"}', True), + ('{"time":"25:30:00"}', False), # incorrect hour + ('{"time":"15:30"}', False), # missing seconds + ('{"time":"15:30:00.000"}', False), # missing Z + ('{"time":"15-30-00"}', False), # incorrect separator + ('{"time":"15:30:00+01:00"}', False), # incorrect separator + ('{"time":20:20:39Z}', False), # missing quotes for value + ], + ), + # Unconstrained Object + ( + { + "title": "Foo", + "type": "object", + }, + [ + ("{}", True), + ('{"a": 1, "b": null}', True), + ('{"a": {"z": {"g": 4}}, "b": null}', True), + ("1234", False), # not an object + ('["a", "a"]', False), # not an array + ], + ), + # Unconstrained Array + ( + { + "type": "array", + }, + [ + ("[1, {}, false]", True), + ("[{}]", True), + ('[{"a": {"z": "q"}, "b": null}]', True), + ('[{"a": [1, 2, true], "b": null}]', True), + ('[{"a": [1, 2, true], "b": {"a": "b"}}, 1, true, [1, [2]]]', True), + # too deep, default unconstrained depth limit = 2 + ( + '[{"a": [1, 2, true], "b": {"a": "b"}}, 1, true, [1, [2, [3]]]]', + False, + ), + ('[{"a": {"z": {"g": 4}}, "b": null}]', False), + ("[[[[1]]]]", False), + # not an array + ("{}", False), + ('{"a": 1, "b": null}', False), + ('{"a": {"z": {"g": 4}}, "b": null}', False), + ("1234", False), # not an array + ('{"a": "a"}', False), # not an array + ], + 
), + # No schema / unconstrained value + ( + {}, + [ + ('"aaabbuecuh"', True), # string + ("5.554", True), # number + ("true", True), # boolean + ("null", True), # null + ("5999", True), # integer + ('["a", "b"]', True), # array + ('{"key": {"k2": "value"}}', True), # nested object + ("this isnt valid json", False), + ], + ), + ], +) +def test_format_without_regex(schema, examples): + schema = json.dumps(schema) + test_regex = build_regex_from_schema(schema) + for string, does_match in examples: + match = re.fullmatch(test_regex, string) + if does_match: + assert match[0] == string + assert match.span() == (0, len(string)) + else: + assert match is None + + +@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]*", "abc"]) +def test_json_schema_custom_whitespace_pattern(whitespace_pattern): + """assert whitespace_pattern setting respected""" + + class MockModel(BaseModel): + foo: int + bar: str + + schema = json.dumps(MockModel.model_json_schema()) + + # assert any ws pattern can be used + if whitespace_pattern == "abc": + build_regex_from_schema(schema, whitespace_pattern) + return + + pattern = build_regex_from_schema(schema, whitespace_pattern) + + mock_result_mult_ws = ( + """{ "foo" : 4, \n\n\n "bar": "baz baz baz bar"\n\n}""" + ) + mock_result_maybe_ws = """{"foo" : 4 ,"bar":"baz baz baz bar"}""" + + match_default_ws = re.fullmatch(pattern, mock_result_maybe_ws) + if whitespace_pattern is None: + assert match_default_ws + else: + assert re.fullmatch(pattern, mock_result_mult_ws) + + +def test_one_of_doesnt_produce_illegal_lookaround(): + """Reproduces failure in https://github.com/outlines-dev/outlines/issues/823""" + + class Cat(BaseModel): + pet_type: Literal["cat"] + meows: int + + class Dog(BaseModel): + pet_type: Literal["dog"] + barks: float + + class Model(BaseModel): + pet: Union[Cat, Dog] = Field(..., discriminator="pet_type") + n: int + + json_schema = Model.schema_json() + + json_schema = Model.schema_json() + pattern = 
build_regex_from_schema(json_schema, whitespace_pattern=None) + + # check if the pattern uses lookarounds incompatible with interegular.Pattern.to_fsm() + interegular.parse_pattern(pattern).to_fsm() diff --git a/bindings/python/tests/fsm/test_regex.py b/bindings/python/tests/fsm/test_regex.py new file mode 100644 index 00000000..d1e676cc --- /dev/null +++ b/bindings/python/tests/fsm/test_regex.py @@ -0,0 +1,717 @@ +import interegular +import numba +import numpy as np +import pytest +from outlines_core.fsm.regex import ( + _walk_fsm, + create_fsm_index_end_to_end, + create_fsm_index_tokenizer, + fsm_union, + get_sub_fsms_from_seq, + get_token_transition_keys, + get_vocabulary_transition_keys, + make_byte_level_better_fsm, + make_byte_level_fsm, + make_deterministic_fsm, + reduced_vocabulary, + walk_fsm, +) +from outlines_core.integrations.utils import adapt_tokenizer +from outlines_core.models.transformers import TransformerTokenizer +from transformers import AutoTokenizer + + +def identity(s): + return s + + +def to_bytes(s): + return [chr(b) if b < 0x80 else f"\x00{b:02X}" for b in s.encode("utf-8")] + + +def merge_symbols(byte_hexs): + return "".join(["\x00" + b if len(b) == 2 else b for b in byte_hexs]) + + +def token_str_to_trans_key(fsm, input_string): + return get_token_transition_keys( + fsm.fsm_info.alphabet_symbol_mapping, + fsm.fsm_info.alphabet_anything_value, + input_string, + ) + + +def walk_fsm_from_token_str( + fsm, + input_string: str, + start_state: int, + full_match: bool = True, +): + return walk_fsm( + fsm, + token_str_to_trans_key(fsm, input_string), + start_state, + full_match, + ) + + +def walk_fsm_from_token_str_numba( + fsm, + input_string: str, + start_state: int, + full_match: bool = True, +): + return _walk_fsm( + fsm.fsm_info.transitions, + fsm.fsm_info.initial, + fsm.fsm_info.finals, + token_str_to_trans_key(fsm, input_string), + start_state, + full_match=full_match, + ) + + +@pytest.mark.parametrize( + "function", + [ + 
walk_fsm_from_token_str, + walk_fsm_from_token_str_numba, + ], +) +def test_walk_fsm(function): + regex_pattern = interegular.parse_pattern("0|[1-9][2-9]*") + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + + res = tuple(function(regex_fsm, "0", regex_fsm.initial, full_match=True)) + assert res == (1,) + + res = tuple(function(regex_fsm, "00", regex_fsm.initial, full_match=False)) + assert res == (1,) + + res = tuple(function(regex_fsm, "!", regex_fsm.initial, full_match=True)) + assert res == tuple() + + res = tuple(function(regex_fsm, "00", regex_fsm.initial, full_match=True)) + assert res == tuple() + + # This should fail, because state `1` reads nothing + res = tuple(function(regex_fsm, "0", 1, full_match=True)) + assert res == tuple() + + regex_pattern = interegular.parse_pattern("0|[1-9][2-9]+") + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + + res = tuple(function(regex_fsm, "1", regex_fsm.initial, full_match=True)) + assert res == tuple() + + res = tuple(function(regex_fsm, "1", regex_fsm.initial, full_match=False)) + assert res == (2,) + + res = tuple(function(regex_fsm, "12", regex_fsm.initial, full_match=True)) + assert res == (2, 3) + + pattern = interegular.parse_pattern(r"(?:[^\W\d]\w*|[\t \x0c]+)") + fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce()) + + res = tuple(function(fsm, "x ", fsm.initial, full_match=False)) + assert res == (2,) + + start_state = list(fsm.finals)[0] + res = tuple(function(fsm, "!", start_state, full_match=False)) + assert res == tuple() + + +@pytest.mark.parametrize( + "function", + [ + walk_fsm_from_token_str, + walk_fsm_from_token_str_numba, + ], +) +@pytest.mark.parametrize( + "transform", + [ + identity, + to_bytes, + ], +) +def test_walk_fsm_multi_bytes(function, transform): + regex_pattern = interegular.parse_pattern("πŸ˜‚|[πŸ˜‡-😍][😈-😍]*") + str_regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + regex_fsm = 
make_byte_level_better_fsm(str_regex_fsm, keep_utf8=True) + + res = tuple( + function( + regex_fsm, merge_symbols(transform("πŸ˜‚")), regex_fsm.initial, full_match=True + ) + ) + assert res[-1:] == (1,) + + res = tuple( + function( + regex_fsm, + merge_symbols(transform("πŸ˜‚πŸ˜‚")), + regex_fsm.initial, + full_match=False, + ) + ) + assert res[-1:] == (1,) + + res = tuple( + function( + regex_fsm, merge_symbols(transform("!")), regex_fsm.initial, full_match=True + ) + ) + assert res == tuple() + + res = tuple( + function( + regex_fsm, + merge_symbols(transform("πŸ˜‚πŸ˜‚")), + regex_fsm.initial, + full_match=True, + ) + ) + assert res == tuple() + + +def test_get_sub_fsms_from_seq(): + name_pattern = interegular.parse_pattern(r"[^\W\d]\w*") + name_fsm, _ = make_deterministic_fsm(name_pattern.to_fsm().reduce()) + + def_pattern = interegular.parse_pattern("def") + def_fsm, _ = make_deterministic_fsm(def_pattern.to_fsm().reduce()) + + match_pattern = interegular.parse_pattern("match") + match_fsm, _ = make_deterministic_fsm(match_pattern.to_fsm().reduce()) + + peq_pattern = interegular.parse_pattern(r"\+=") + peq_fsm, _ = make_deterministic_fsm(peq_pattern.to_fsm().reduce()) + + plus_pattern = interegular.parse_pattern(r"\+") + plus_fsm, _ = make_deterministic_fsm(plus_pattern.to_fsm().reduce()) + + fsms = [def_fsm, match_fsm, name_fsm, peq_fsm, plus_fsm] + + fsm, fsms_to_trans_finals = fsm_union(fsms) + + assert fsms_to_trans_finals == { + 0: ({(0, 3), (3, 9), (9, 10)}, {10}, {0: {0}, 1: {3}, 2: {9}, 3: {10}}), + 1: ( + {(0, 4), (4, 5), (5, 6), (6, 7), (7, 8)}, + {8}, + {0: {0}, 1: {4}, 2: {5}, 3: {6}, 4: {7}, 5: {8}}, + ), + 2: ( + { + (0, 2), + (0, 3), + (0, 4), + (2, 2), + (3, 2), + (3, 9), + (4, 2), + (4, 5), + (5, 2), + (5, 6), + (6, 2), + (6, 7), + (7, 2), + (7, 8), + (8, 2), + (9, 2), + (9, 10), + (10, 2), + }, + {2, 3, 4, 5, 6, 7, 8, 9, 10}, + {0: {0}, 1: {2, 3, 4, 5, 6, 7, 8, 9, 10}}, + ), + 3: ({(0, 1), (1, 11)}, {11}, {0: {0}, 1: {1}, 2: {11}}), + 4: ({(0, 
1)}, {1}, {0: {0}, 1: {1}}), + } + + assert not fsm.accepts("1a") + assert fsm.accepts("a1") + assert fsm.accepts("def") + assert fsm.accepts("match") + assert fsm.accepts("+=") + assert fsm.accepts("+") + + state_seq = walk_fsm_from_token_str(fsm, "def", fsm.initial) + state_seq.insert(0, fsm.fsm_info.initial) + + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(0, False, True), (2, True, True)] + + # Make sure the old-to-new state map is correct + def_state_seq = walk_fsm_from_token_str(def_fsm, "def", fsm.initial) + def_state_seq.insert(0, fsm.fsm_info.initial) + + def_old_to_new_states = fsms_to_trans_finals[0][2] + assert all( + new_state in def_old_to_new_states[old_state] + for old_state, new_state in zip(def_state_seq, state_seq) + ) + + state_seq = walk_fsm_from_token_str(fsm, "ef", fsm.initial) + state_seq.insert(0, fsm.initial) + + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(2, True, True)] + + name_state_seq = walk_fsm_from_token_str(name_fsm, "ef", fsm.initial) + name_state_seq.insert(0, fsm.initial) + + name_old_to_new_states = fsms_to_trans_finals[2][2] + assert all( + new_state in name_old_to_new_states[old_state] + for old_state, new_state in zip(name_state_seq, state_seq) + ) + + state_seq = walk_fsm_from_token_str(fsm, "match", fsm.initial) + state_seq.insert(0, fsm.initial) + + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(1, False, True), (2, True, True)] + + match_state_seq = walk_fsm_from_token_str(match_fsm, "match", fsm.initial) + match_state_seq.insert(0, fsm.initial) + + match_old_to_new_states = fsms_to_trans_finals[1][2] + assert all( + new_state in match_old_to_new_states[old_state] + for old_state, new_state in zip(match_state_seq, state_seq) + ) + + state_seq = walk_fsm_from_token_str(fsm, "defa", fsm.initial) + state_seq.insert(0, fsm.initial) + + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + 
assert res == [(2, True, True)] + + state_seq = walk_fsm_from_token_str(fsm, "de", fsm.initial) + state_seq.insert(0, fsm.initial) + + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(0, True, False), (2, True, True)] + + state_seq = walk_fsm_from_token_str(fsm, "+", fsm.initial, False) + state_seq.insert(0, fsm.initial) + + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(3, True, False), (4, False, True)] + + state_seq = walk_fsm_from_token_str(fsm, "+=", fsm.initial) + state_seq.insert(0, fsm.initial) + + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(3, False, True)] + + # Test some overlapping patterns + join_fsms = [ + interegular.parse_pattern(r"JOIN").to_fsm().reduce(), + interegular.parse_pattern(r"JOIN LEFT").to_fsm().reduce(), + ] + fsm, fsms_to_trans_finals = fsm_union(join_fsms) + + # Matching "OI" + state_seq = [1, 2, 3] + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(0, True, False), (1, True, False)] + + # Matching "N" + state_seq = [3, 4] + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(0, False, True), (1, True, False)] + + # Matching " " + state_seq = [4, 5] + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(1, True, False)] + + +def test_create_fsm_index_end_to_end(): + regex_str = "0|[1-9][0-9]*" + + regex_pattern = interegular.parse_pattern(regex_str) + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + + vocabulary = { + "blah": numba.typed.List([0]), + "1a": numba.typed.List([1]), + "2": numba.typed.List([2]), + "0": numba.typed.List([3]), + "": numba.typed.List([4]), + } + + vocabulary_nb = numba.typed.List.empty_list( + numba.types.Tuple( + ( + numba.types.unicode_type, + numba.int64[:], + ) + ) + ) + for token_tuple, token_ids in vocabulary.items(): + token = merge_symbols(token_tuple) + 
token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) + vocabulary_nb.append((token, token_ids_np)) + + res = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary_nb) + + assert res == {0: {(2, 2), (3, 1)}, 2: {(2, 2), (3, 2)}} + + +def test_create_fsm_index_end_to_end_multi_byte(): + regex_str = "πŸ˜‡| [😈-😍][πŸ˜‡-😎]*" + + regex_pattern = interegular.parse_pattern(regex_str) + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + byte_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True) + + vocabulary = { + "blah": numba.typed.List([0]), + "😈a": numba.typed.List([1]), + "πŸ˜‡": numba.typed.List([2]), + "😍": numba.typed.List([3]), + merge_symbols(("F0", "9F", "98", "8D")): numba.typed.List([4]), # '😍' + " 😍": numba.typed.List([5]), + merge_symbols((" ", "F0", "9F", "98", "8D")): numba.typed.List([6]), # ' 😍' + merge_symbols((" ", "F0", "9F", "98")): numba.typed.List( + [7] + ), # ' 😍' incomplete + "": numba.typed.List([8]), + } + + vocabulary_nb = numba.typed.List.empty_list( + numba.types.Tuple( + ( + numba.types.unicode_type, + numba.int64[:], + ) + ) + ) + for token_tuple, token_ids in vocabulary.items(): + token_tuple_np = merge_symbols(token_tuple) + token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) + vocabulary_nb.append((token_tuple_np, token_ids_np)) + + res = create_fsm_index_end_to_end(byte_fsm.fsm_info, vocabulary_nb) + + assert res == {0: {(5, 3), (6, 3), (7, 7), (2, 2)}, 3: {(2, 3), (3, 3), (4, 3)}} + + +@pytest.mark.parametrize( + "hf_tokenizer_uri", + [ + "gpt2", + "microsoft/phi-2", + "Qwen/Qwen1.5-0.5B-Chat", + "NousResearch/Hermes-2-Pro-Llama-3-8B", + ], +) +def test_create_fsm_index_tokenizer(hf_tokenizer_uri): + # The combined regular expressions of a lexer state in a Python grammar + regex_str = 
"(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" + + regex_pattern = interegular.parse_pattern(regex_str) + # Not reduced, so that there are many states + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm()) + bytes_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True) + + num_fsm_states = len(regex_fsm.states) + assert num_fsm_states == 220 + + num_bytes_fsm_states = len(bytes_fsm.states) + assert num_bytes_fsm_states == 235 + + tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri) + tokenizer = TransformerTokenizer(tokenizer) + + states_to_token_subsets, empty_token_ids = create_fsm_index_tokenizer( + bytes_fsm, tokenizer + ) + + assert not empty_token_ids + assert len(states_to_token_subsets) / num_fsm_states > 0.94 + + +@pytest.mark.parametrize( + "regex,string,should_accept", + [ + ("[a-c]+", "πŸ˜€", False), + ("[^a-c]+", "πŸ˜€", True), + ("πŸ˜€+", "πŸ˜€πŸ˜€πŸ˜€", True), + ("πŸ˜€+", "a", False), + ("[πŸ˜€-😍]{2}", "😈😈", True), + ("[πŸ˜€-😍]{2}", "aa", False), + ("[^πŸ˜€-😍]{2}", "aa", True), + ("[^πŸ˜€-😍]{2}", "😈😈", False), + ("[^πŸ˜€-😍]{2}", "😎😎", True), + ("[^πŸ˜€-😍]{2}", "πŸ˜ŽπŸ˜“", True), 
+ ("[^πŸ˜€-😍]{2}", "😎😈", False), + ("[πŸ˜€-πŸ™Œ]{2}", "😎😈", True), + ("[^πŸ˜€-πŸ™Œ]{2}", "😎😈", False), + ("[^πŸ˜€-πŸ™Œ]{2}", "πŸ™πŸ™", True), + ("[^πŸ˜€-πŸ™Œ]{2}", "πŸ™πŸ˜Ž", False), + ], +) +def test_make_byte_level_fsm(regex, string, should_accept): + str_fsm = interegular.parse_pattern(regex).to_fsm() + str_accepts = str_fsm.accepts(string) + assert str_accepts == should_accept + + byte_fsm = make_byte_level_fsm(str_fsm) + byte_accepts = byte_fsm.accepts(to_bytes(string)) # type: ignore + assert byte_accepts == str_accepts + + mix_fsm = make_byte_level_fsm(str_fsm, keep_utf8=True) + mix_accepts = mix_fsm.accepts(to_bytes(string)) # type: ignore + assert mix_accepts == str_accepts + + mix_accepts_utf8 = mix_fsm.accepts(string) # type: ignore + assert mix_accepts_utf8 == str_accepts + + def advance(fsm, state, seq): + for symbol in seq: + if state is None: + return None + key = fsm.alphabet[symbol] + state = fsm.map[state].get(key) + return state + + # verify each state along the pattern + str_state = str_fsm.initial + byte_state = byte_fsm.initial + mix_state = byte_fsm.initial + for symbol in string: + str_state = advance(str_fsm, str_state, symbol) + byte_state = advance(byte_fsm, byte_state, to_bytes(symbol)) + mix_state_utf8 = advance(mix_fsm, mix_state, symbol) + mix_state = advance(mix_fsm, mix_state, to_bytes(symbol)) + assert byte_state == str_state + assert mix_state == str_state + assert mix_state_utf8 == str_state + + +@pytest.mark.skip(reason="Only for local profiling") +def test_regex_index_performance(): + from line_profiler import LineProfiler # type: ignore [import] + + regex_str = 
"(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" + + regex_pattern = interegular.parse_pattern(regex_str) + # Not reduced, so that there are many states + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm()) + + num_fsm_states = len(regex_fsm.states) + assert num_fsm_states == 220 + + tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer = TransformerTokenizer(tokenizer) + + # Pre-compile Numba functions + res, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer) + assert len(res) > 1 + + profiler = LineProfiler(create_fsm_index_end_to_end) + + profiler.runctx( + "create_fsm_index_tokenizer(regex_fsm, tokenizer)", + globals(), + locals(), + ) + profiler.dump_stats("line-profiler-create_fsm_index.pkl") + profiler.print_stats(output_unit=1e-3, summarize=True, stripzeros=True) + + +@pytest.mark.skip(reason="Only for local profiling") +def test_json_index_performance(): + import json + from enum import Enum + + import outlines_core + from line_profiler import LineProfiler # type: ignore [import] + from pydantic import BaseModel, constr + + class Weapon(str, Enum): + sword = "sword" + axe = "axe" + mace 
= "mace" + spear = "spear" + bow = "bow" + crossbow = "crossbow" + + class Armor(str, Enum): + leather = "leather" + chainmail = "chainmail" + plate = "plate" + + class Character(BaseModel): + name: constr(max_length=10) + # TODO: Add support for conint + age: int # conint(int, ge=18, le=100) + armor: Armor + weapon: Weapon + # TODO: Add support for conint + strength: int # conint(int, ge=0, le=100) + + model = outlines_core.models.transformers("gpt2", device="cuda") + json_schema = json.dumps(Character.model_json_schema()) + + def build_regex(): + regex_str = outlines_core.index.json_schema.build_regex_from_object(json_schema) + outlines_core.generate.regex(model, regex_str) + + profiler = LineProfiler(create_fsm_index_end_to_end) + profiler.add_function(create_fsm_index_tokenizer) + profiler.add_function(outlines_core.index.index.RegexFSM.__init__) + + profiler.runctx( + "build_regex()", + globals(), + locals(), + ) + profiler.dump_stats("line-profiler-build-json-regex.pkl") + profiler.print_stats(output_unit=1e-3, summarize=True, stripzeros=True) + + +def test_token_trans_keys_identical(): + """assert two tokens w/ identical behavior wrt FSM have same trans key seq""" + + class MockTokenizer: + vocabulary = {"a": 1, "b": 2, "z": 3, "eos": 4} + special_tokens = {"eos"} + eos_token_id = 4 + + def convert_token_to_string(self, token): + return token + + tokenizer = MockTokenizer() + + pattern = r"z[ab]z" + regex_pattern = interegular.parse_pattern(pattern) + interegular_fsm = regex_pattern.to_fsm().reduce() + regex_fsm, _ = make_deterministic_fsm(interegular_fsm) + vocabulary, _ = reduced_vocabulary(tokenizer) + token_trans_keys = get_vocabulary_transition_keys( + regex_fsm.fsm_info.alphabet_symbol_mapping, + regex_fsm.fsm_info.alphabet_anything_value, + vocabulary, + numba.typed.List.empty_list(numba.types.unicode_type), + ) + + token_str_to_tranition_keys = { + token_str: trans_key_seq + for (token_str, _), trans_key_seq in zip(vocabulary, token_trans_keys) + } + 
# `a` and `b` both are workable, but `z` has distinct transition rules + assert interegular_fsm.accepts("zaz") + assert interegular_fsm.accepts("zbz") + assert (token_str_to_tranition_keys["a"] == token_str_to_tranition_keys["b"]).all() + assert not ( + token_str_to_tranition_keys["a"] == token_str_to_tranition_keys["z"] + ).all() + + +def test_token_trans_keys_walk_fsm(): + """assert _walk_fsm works using transition keys""" + + class MockTokenizer: + vocabulary = {"ab": 1, "ac": 2, "az": 3, "eos": 4} + special_tokens = {"eos"} + eos_token_id = 4 + + def convert_token_to_string(self, token): + return token + + tokenizer = MockTokenizer() + + pattern = r"a[bc]z" + regex_pattern = interegular.parse_pattern(pattern) + interegular_fsm = regex_pattern.to_fsm().reduce() + regex_fsm, _ = make_deterministic_fsm(interegular_fsm) + vocabulary, _ = reduced_vocabulary(tokenizer) + token_trans_keys = get_vocabulary_transition_keys( + regex_fsm.fsm_info.alphabet_symbol_mapping, + regex_fsm.fsm_info.alphabet_anything_value, + vocabulary, + numba.typed.List.empty_list(numba.types.unicode_type), + ) + + token_str_trans_key_seq = { + token_str: trans_key_seq + for (token_str, _), trans_key_seq in zip(vocabulary, token_trans_keys) + } + + # verify initial state valid only for "ab" and "ac" using transition key seq + token_acceptance = {"ab": True, "ac": True, "az": False} + for token, should_accept in token_acceptance.items(): + token_trans_key_seq = token_str_trans_key_seq[token] + state_seq = _walk_fsm( + regex_fsm.fsm_info.transitions, + regex_fsm.fsm_info.initial, + regex_fsm.fsm_info.finals, + token_trans_key_seq, + regex_fsm.fsm_info.initial, + False, + ) + is_accepted = len(state_seq) >= len(token_trans_key_seq) + assert should_accept == is_accepted + + +def test_numba_leading_null_byte_UnicodeCharSeq_remains_broken(): + """Assert numba UnicodeCharSeq w/ leading \x00 is still broken""" + # EXPLANATION: + # 
https://github.com/outlines_core-dev/outlines/pull/930#issuecomment-2143535968 + + # from https://github.com/numba/numba/issues/9542 + d = numba.typed.typeddict.Dict.empty(numba.types.UnicodeCharSeq(1), numba.int64) + d["δΈ€"] = 10 # \xe4\xb8\x80 + with pytest.raises(KeyError): + str(d) + + # most characters are fine, but "\x00" is converted to "" + l = np.fromiter(["\x99", "\x00"], dtype=np.dtype("U2")) + assert str(l[0]) == "\x99" # fine + assert str(l[1]) == "" # 1-byte null converted to 0-bytes + + +@pytest.mark.parametrize("input_key", ["δΈ€", "\x00"]) +def test_numba_leading_null_byte_unicode_type_sane(input_key): + """Assert numba unicode_type w/ leading \x00 is working""" + # EXPLANATION: + # https://github.com/outlines_core-dev/outlines/pull/930#issuecomment-2143535968 + + # from https://github.com/numba/numba/issues/9542 + d = numba.typed.typeddict.Dict.empty(numba.types.unicode_type, numba.int64) + d["δΈ€"] = 10 # \xe4\xb8\x80 + str(d) # assert successfully interprets + + +@pytest.mark.parametrize( + "rare_token", + [ + "οΏ½", + "οΏ½οΏ½", + "οΏ½.", + "οΏ½..", + "▁�", + "▁▁�", + "▁�.", + "▁�.", + "▁▁�..", + ], +) +def test_reduced_vocabulary_with_rare_tokens(rare_token): + """Assert reduced_vocabulary works with rare tokens. + + See [1] and [2] for context. 
+ + [1]: https://github.com/outlines-dev/outlines/pull/763 + [2]: https://github.com/outlines-dev/outlines/pull/948 + """ + tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + tokenizer = adapt_tokenizer(tokenizer=tokenizer) + tokenizer.vocabulary[rare_token] = max(tokenizer.vocabulary.values()) + 1 + reduced_vocabulary(tokenizer) diff --git a/bindings/python/tests/fsm/test_types.py b/bindings/python/tests/fsm/test_types.py new file mode 100644 index 00000000..fc66bd3f --- /dev/null +++ b/bindings/python/tests/fsm/test_types.py @@ -0,0 +1,28 @@ +import datetime + +import pytest +from outlines_core.fsm.types import ( + BOOLEAN, + DATE, + DATETIME, + FLOAT, + INTEGER, + TIME, + python_types_to_regex, +) + + +@pytest.mark.parametrize( + "python_type,regex", + [ + (int, INTEGER), + (float, FLOAT), + (bool, BOOLEAN), + (datetime.date, DATE), + (datetime.time, TIME), + (datetime.datetime, DATETIME), + ], +) +def test_python_types(python_type, regex): + test_regex, _ = python_types_to_regex(python_type) + assert regex == test_regex diff --git a/bindings/python/tests/models/test_tokenizer.py b/bindings/python/tests/models/test_tokenizer.py new file mode 100644 index 00000000..9457bda5 --- /dev/null +++ b/bindings/python/tests/models/test_tokenizer.py @@ -0,0 +1,7 @@ +import pytest +from outlines_core.models.tokenizer import Tokenizer + + +def test_tokenizer(): + with pytest.raises(TypeError, match="instantiate abstract"): + Tokenizer() diff --git a/bindings/python/tests/models/test_transformers.py b/bindings/python/tests/models/test_transformers.py new file mode 100644 index 00000000..799f7a5b --- /dev/null +++ b/bindings/python/tests/models/test_transformers.py @@ -0,0 +1,116 @@ +import pytest +import torch +from outlines_core.models.transformers import TransformerTokenizer, transformers +from transformers import AutoTokenizer +from transformers.models.gpt2 import GPT2TokenizerFast + +TEST_MODEL = "hf-internal-testing/tiny-random-GPTJForCausalLM" + + +def 
test_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL, padding_side="left") + tokenizer = TransformerTokenizer(tokenizer) + assert tokenizer.eos_token_id == 0 + assert tokenizer.pad_token_id == 0 + assert isinstance(tokenizer.tokenizer, GPT2TokenizerFast) + + token_ids, attention_mask = tokenizer.encode("Test") + assert token_ids.ndim == 2 + assert token_ids.shape[0] == 1 + assert isinstance(token_ids, torch.LongTensor) + assert token_ids.shape == attention_mask.shape + + token_ids, attention_mask = tokenizer.encode(["Test", "Test"]) + assert token_ids.ndim == 2 + assert token_ids.shape[0] == 2 + assert isinstance(token_ids, torch.LongTensor) + assert token_ids.shape == attention_mask.shape + + token_ids, attention_mask = tokenizer.encode(["Test", "A long sentence"]) + assert token_ids.shape == attention_mask.shape + assert attention_mask[0][0] == tokenizer.pad_token_id + + text = tokenizer.decode(torch.tensor([[0, 1, 2]])) + isinstance(text, str) + + text = tokenizer.decode(torch.tensor([[0, 1, 2], [3, 4, 5]])) + isinstance(text, list) + isinstance(text[0], str) + isinstance(text[1], str) + + tokenizer = AutoTokenizer.from_pretrained( + TEST_MODEL, additional_special_tokens=["", ""] + ) + tokenizer = TransformerTokenizer(tokenizer) + assert "" in tokenizer.special_tokens + assert "" in tokenizer.special_tokens + + +def test_llama_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + tokenizer = TransformerTokenizer(tokenizer) + + # Broken + assert tokenizer.tokenizer.convert_tokens_to_string(["▁baz"]) == "baz" + assert tokenizer.tokenizer.convert_tokens_to_string(["<0x20>"]) == "" + assert tokenizer.tokenizer.convert_tokens_to_string(["▁▁▁"]) == " " + + # Not broken + assert tokenizer.convert_token_to_string("▁baz") == " baz" + assert tokenizer.convert_token_to_string("<0x20>") == " " + assert tokenizer.convert_token_to_string("▁▁▁") == " " + + +def test_model(): + model = transformers(TEST_MODEL, 
device="cpu") + assert isinstance(model.tokenizer, TransformerTokenizer) + assert model.model.device.type == "cpu" + + model = transformers(TEST_MODEL, model_kwargs={"device_map": "cpu"}) + assert isinstance(model.tokenizer, TransformerTokenizer) + assert model.model.device.type == "cpu" + + model = transformers(TEST_MODEL, device="cpu", model_kwargs={"device_map": "cuda"}) + assert isinstance(model.tokenizer, TransformerTokenizer) + assert model.model.device.type == "cpu" + + input_ids = torch.tensor([[0, 1, 2]]) + logits, kv_cache = model(input_ids, torch.ones_like(input_ids)) + assert logits.type() == "torch.FloatTensor" + assert logits.ndim == 2 + assert logits.shape[0] == 1 + assert len(kv_cache) == model.model.config.n_layer + assert len(kv_cache[0]) == 2 + assert kv_cache[0][0].shape[1] == model.model.config.n_head + assert kv_cache[0][0].shape[2] == 3 # number of tokens + + input_ids = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) + logits, kv_cache = model(input_ids, torch.ones_like(input_ids)) + assert logits.type() == "torch.FloatTensor" + assert logits.ndim == 2 + assert logits.shape[0] == 3 + assert len(kv_cache) == model.model.config.n_layer + assert len(kv_cache[0]) == 2 + assert kv_cache[0][0].shape[1] == model.model.config.n_head + assert kv_cache[0][0].shape[2] == 3 # number of tokens + + with pytest.raises(AssertionError): + input_ids = torch.tensor([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [0, 1, 2]]]) + logits = model(input_ids, torch.ones_like(input_ids)) + + +def test_tokenizer_eq_hash(): + tokenizer_hf = AutoTokenizer.from_pretrained("gpt2") + + tokenizer = TransformerTokenizer(tokenizer_hf) + tokenizer_2 = TransformerTokenizer(tokenizer_hf) + + assert tokenizer == tokenizer_2 + assert hash(tokenizer) == hash(tokenizer_2) + + tokenizer_hf_2 = AutoTokenizer.from_pretrained("gpt2") + tokenizer_hf_2.add_tokens(["test_token"]) + + tokenizer_3 = TransformerTokenizer(tokenizer_hf_2) + assert tokenizer != tokenizer_3 + assert hash(tokenizer) != 
hash(tokenizer_3)