Commit

outlines + transformers added
staru09 committed Nov 13, 2024
1 parent 99fc20b commit 6c28c10
Showing 2 changed files with 126 additions and 185 deletions.
104 changes: 29 additions & 75 deletions docetl/operations/hf_outlines.py
@@ -1,106 +1,60 @@
from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel
from pydantic import BaseModel, create_model
from docetl.operations.base import BaseOperation
from outlines import generate, models
from transformers import AutoModelForCausalLM, AutoTokenizer
import json

class HuggingFaceMapOperation(BaseOperation):
class schema(BaseOperation.schema):
name: str
type: str = "hf_map"
model_path: str
use_local_model: bool = False
device: str = "cuda"
output_schema: Dict[str, Any]
prompt_template: str
batch_size: Optional[int] = 10
max_tokens: int = 4096

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, config: Dict[str, Any], runner=None, *args, **kwargs):
super().__init__(
config=config,
default_model=config.get('default_model', config['model_path']),
max_threads=config.get('max_threads', 1),
runner=runner
)

self.model = models.transformers(
self.config["model_path"]
)

# Create a dynamic Pydantic model from the output schema
field_definitions = {
k: (eval(v) if isinstance(v, str) else v, ...)
for k, v in self.config["output_schema"].items()
}
output_model = create_model('OutputModel', **field_definitions)

if self.config["use_local_model"]:
llm = AutoModelForCausalLM.from_pretrained(
self.config["model_path"],
device_map=self.config["device"]
)
tokenizer = AutoTokenizer.from_pretrained(self.config["model_path"])
self.model = models.Transformers(llm, tokenizer)
self.tokenizer = tokenizer
else:
self.model = models.transformers(
self.config["model_path"],
device=self.config["device"]
)
self.tokenizer = self.model.tokenizer

output_model = BaseModel.model_validate(self.config["output_schema"])
self.processor = generate.json(
self.model,
output_model,
max_tokens=self.config["max_tokens"]
output_model
)

def syntax_check(self) -> None:
"""Validate the operation configuration."""
config = self.schema(**self.config)

if not config.model_path:
raise ValueError("model_path is required")

if not config.output_schema:
raise ValueError("output_schema is required")

if not config.prompt_template:
raise ValueError("prompt_template is required")

def create_prompt(self, item: Dict[str, Any]) -> str:
"""Create a prompt from the template and input data."""
messages = [
{
'role': 'user',
'content': self.config["prompt_template"]
},
{
'role': 'assistant',
'content': "I understand and will process the input as requested."
},
{
'role': 'user',
'content': str(item)
}
]
return self.tokenizer.apply_chat_template(
messages,
tokenize=False
)
self.schema(**self.config)

def process_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
"""Process a single item through the model."""
prompt = self.create_prompt(item)
try:
result = self.processor(prompt)
result = self.processor(self.config["prompt_template"] + "\n" + str(item))
result_dict = result.model_dump()
final_dict = {**item, **result_dict}
return json.loads(json.dumps(final_dict, indent=2))
return final_dict
except Exception as e:
self.console.print(f"Error processing item: {e}")
return json.loads(json.dumps(item, indent=2))
return item

def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
@classmethod
def execute(cls, config: Dict[str, Any], input_data: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], float]:
"""Execute the operation on the input data."""
if self.status:
self.status.stop()

results = []
batch_size = self.config.get("batch_size", 10)

for i in range(0, len(input_data), batch_size):
batch = input_data[i:i + batch_size]
batch_results = [self.process_item(item) for item in batch]
results.extend(batch_results)

if self.status:
self.status.start()

instance = cls(config)
results = [instance.process_item(item) for item in input_data]
return results, 0.0
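
For context, a minimal sketch of how the reworked operation might be driven after this commit. The config values, the operation name, and the record contents below are illustrative placeholders rather than values taken from the commit, and running it for real assumes the referenced Hugging Face model can be loaded.

from docetl.operations.hf_outlines import HuggingFaceMapOperation

# Illustrative config; model path, prompt, and output schema are placeholders.
config = {
    "name": "customer_extractor",
    "type": "hf_map",
    "model_path": "meta-llama/Llama-3.2-1B-Instruct",
    "output_schema": {"first_name": "str", "last_name": "str"},
    "prompt_template": "Extract customer information from this text",
    "max_tokens": 4096,
}

# execute is now a classmethod: it instantiates the operation from the config,
# runs process_item over every record, and returns (results, cost) with a
# fixed cost of 0.0.
records = [{"message": "Customer John Doe ordered item #12345"}]
results, cost = HuggingFaceMapOperation.execute(config, records)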
207 changes: 97 additions & 110 deletions tests/test_hf_outlines.py
@@ -1,103 +1,82 @@
import pytest
from unittest.mock import Mock, patch
from unittest.mock import Mock, patch, MagicMock
from docetl.operations.hf_outlines import HuggingFaceMapOperation

@pytest.fixture
def mock_runner():
return Mock()

@pytest.fixture
def sample_config():
return {
"name": "test_hf_operation",
"type": "hf_map",
"model_path": "microsoft/Phi-3-mini-4k-instruct",
"use_local_model": False,
"device": "cuda",
"model_path": "meta-llama/Llama-3.2-1B-Instruct",
"output_schema": {
"first_name": "str",
"last_name": "str",
"order_number": "str",
"department": "str"
"last_name": "str"
},
"prompt_template": "Extract customer information from this text",
"batch_size": 2,
"max_tokens": 4096
}

@pytest.fixture
def mock_processor_output():
def research_config():
return {
"name": "research_analyzer",
"type": "hf_map",
"model_path": "meta-llama/Llama-3.2-1B-Instruct",
"output_schema": {
"title": "str",
"authors": "list",
"methodology": "str",
"findings": "list",
"limitations": "list",
"future_work": "list"
},
"prompt_template": "Analyze the following research paper abstract.\nExtract key components and summarize findings.",
"max_tokens": 4096
}

@pytest.fixture
def mock_research_output():
class MockOutput:
def model_dump(self):
return {
"first_name": "John",
"last_name": "Doe",
"order_number": "12345",
"department": "Sales"
"title": "Deep Learning in Natural Language Processing",
"authors": ["John Smith", "Jane Doe"],
"methodology": "Comparative analysis of transformer architectures",
"findings": [
"Improved accuracy by 15%",
"Reduced training time by 30%"
],
"limitations": [
"Limited dataset size",
"Computational constraints"
],
"future_work": [
"Extend to multilingual models",
"Optimize for edge devices"
]
}
return MockOutput()

@pytest.fixture
def sample_input_data():
return [
{"message": "Customer John Doe ordered item #12345"},
{"message": "Customer Jane Smith from Sales department"}
]

def test_initialization_remote_model(sample_config):
with patch('outlines.models.transformers') as mock_transformers:
operation = HuggingFaceMapOperation(sample_config)
assert operation.config == sample_config
assert operation.config["use_local_model"] is False
assert mock_transformers.called

def test_initialization_local_model(sample_config):
sample_config["use_local_model"] = True
with patch('transformers.AutoModelForCausalLM.from_pretrained') as mock_model, \
patch('transformers.AutoTokenizer.from_pretrained') as mock_tokenizer:
operation = HuggingFaceMapOperation(sample_config)
assert operation.config["use_local_model"] is True
assert mock_model.called
assert mock_tokenizer.called

@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_device_configuration(sample_config, device):
sample_config["device"] = device
with patch('outlines.models.transformers'):
operation = HuggingFaceMapOperation(sample_config)
assert operation.config["device"] == device

def test_syntax_check(sample_config):
with patch('outlines.models.transformers'):
operation = HuggingFaceMapOperation(sample_config)
operation.syntax_check()

@pytest.mark.parametrize("missing_field", [
"model_path",
"output_schema",
"prompt_template"
])
def test_syntax_check_missing_fields(sample_config, missing_field):
with patch('outlines.models.transformers'):
invalid_config = sample_config.copy()
invalid_config[missing_field] = ""
operation = HuggingFaceMapOperation(invalid_config)
with pytest.raises(ValueError):
operation.syntax_check()

def test_create_prompt(sample_config):
with patch('outlines.models.transformers') as mock_transformers:
mock_tokenizer = Mock()
mock_tokenizer.apply_chat_template.return_value = "mocked prompt"
mock_transformers.return_value.tokenizer = mock_tokenizer

operation = HuggingFaceMapOperation(sample_config)
test_item = {"message": "test message"}
prompt = operation.create_prompt(test_item)

assert isinstance(prompt, str)
assert mock_tokenizer.apply_chat_template.called

def test_process_item(sample_config, mock_processor_output):
with patch('outlines.models.transformers'):
operation = HuggingFaceMapOperation(sample_config)
operation.processor = Mock(return_value=mock_processor_output)
def test_process_item(sample_config, mock_runner):
mock_model = MagicMock()

class MockOutput:
def model_dump(self):
return {
"first_name": "John",
"last_name": "Doe"
}

mock_processor = Mock(return_value=MockOutput())

with patch('outlines.models.transformers', return_value=mock_model) as mock_transformers, \
patch('outlines.generate.json', return_value=mock_processor):

operation = HuggingFaceMapOperation(sample_config, runner=mock_runner)
test_item = {"message": "test message"}
result = operation.process_item(test_item)

@@ -106,42 +85,50 @@ def test_process_item(sample_config, mock_processor_output):
assert "last_name" in result
assert "message" in result

def test_process_item_error_handling(sample_config):
with patch('outlines.models.transformers'):
operation = HuggingFaceMapOperation(sample_config)
operation.processor = Mock(side_effect=Exception("Test error"))
def test_research_paper_analysis(research_config, mock_research_output, mock_runner):
mock_model = MagicMock()
mock_processor = Mock(return_value=mock_research_output)

with patch('outlines.models.transformers', return_value=mock_model) as mock_transformers, \
patch('outlines.generate.json', return_value=mock_processor):

test_item = {"message": "test message"}
operation = HuggingFaceMapOperation(research_config, runner=mock_runner)
test_item = {
"abstract": """
This paper presents a comprehensive analysis of deep learning approaches
in natural language processing. We compare various transformer architectures
and their performance on standard NLP tasks.
"""
}
result = operation.process_item(test_item)

# Verify structure and types
assert isinstance(result, dict)
assert "message" in result

def test_execute(sample_config, sample_input_data):
with patch('outlines.models.transformers'):
operation = HuggingFaceMapOperation(sample_config)
operation.process_item = Mock(return_value={"processed": True})

results, timing = operation.execute(sample_input_data)
assert "title" in result
assert isinstance(result["title"], str)
assert "authors" in result
assert isinstance(result["authors"], list)
assert "methodology" in result
assert isinstance(result["methodology"], str)
assert "findings" in result
assert isinstance(result["findings"], list)
assert len(result["findings"]) > 0
assert "limitations" in result
assert isinstance(result["limitations"], list)
assert "future_work" in result
assert isinstance(result["future_work"], list)

assert len(results) == len(sample_input_data)
assert isinstance(timing, float)
# Verify original input is preserved
assert "abstract" in result

def test_batch_processing(sample_config, sample_input_data):
with patch('outlines.models.transformers'):
operation = HuggingFaceMapOperation(sample_config)
operation.process_item = Mock(return_value={"processed": True})
def test_execute(sample_config, mock_runner):
mock_model = MagicMock()
mock_processor = Mock(return_value={"first_name": "John", "last_name": "Doe"})

with patch('outlines.models.transformers', return_value=mock_model) as mock_transformers, \
patch('outlines.generate.json', return_value=mock_processor):

# Test with different batch sizes
sample_config["batch_size"] = 1
results1, _ = operation.execute(sample_input_data)
assert len(results1) == len(sample_input_data)

sample_config["batch_size"] = 2
results2, _ = operation.execute(sample_input_data)
assert len(results2) == len(sample_input_data)

def test_max_tokens_configuration(sample_config):
with patch('outlines.models.transformers'):
operation = HuggingFaceMapOperation(sample_config)
assert operation.config["max_tokens"] == 4096
input_data = [{"message": "test message"}]
results, timing = HuggingFaceMapOperation.execute(sample_config, input_data)
assert len(results) == 1
assert isinstance(timing, float)
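
For reference, the dynamic Pydantic model that the new __init__ builds from research_config's output_schema would be roughly equivalent to the hand-written model below. This is a sketch under the assumption that eval maps the strings "str" and "list" to the corresponding builtins, as the new field_definitions/create_model code does; the class name is illustrative.

from pydantic import BaseModel

# Roughly what create_model('OutputModel', **field_definitions) produces for
# research_config: every schema value string becomes a builtin type and every
# field is required (the ... default in field_definitions).
class ResearchOutput(BaseModel):
    title: str
    authors: list
    methodology: str
    findings: list
    limitations: list
    future_work: list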
