diff --git a/optillm/inference.py b/optillm/inference.py
index 128d6ea..3c8a4a4 100644
--- a/optillm/inference.py
+++ b/optillm/inference.py
@@ -1257,6 +1257,10 @@ def create(
         # Handle specialized decoding approaches
         if decoding:
             logger.info(f"Using specialized decoding approach: {decoding}")
+
+            # Ensure model is in eval mode and on correct device
+            pipeline.current_model.eval()
+            device = pipeline.current_model.device
 
             if decoding == "cot_decoding":
                 # Use directly available parameters for CoT
@@ -1284,32 +1288,46 @@ def create(
                 completion_tokens = len(pipeline.tokenizer.encode(result))
 
             elif decoding == "entropy_decoding":
-                # Configure generator for entropy decoding
-                generator = None
-                if seed is not None:
-                    generator = torch.Generator(device=pipeline.current_model.device)
-                    generator.manual_seed(seed)
-
-                # Use directly available parameters for entropy decoding
-                entropy_params = {
-                    "max_new_tokens": max_tokens if max_tokens is not None else 512,
-                    "temperature": 0.666,
-                    "top_p": 0.90,
-                    "top_k": top_k,
-                    "min_p": min_p,
-                    "generator": generator
-                }
+                # Ensure model is using full precision
+                original_dtype = pipeline.current_model.dtype
+                pipeline.current_model = pipeline.current_model.to(torch.float32)
+
+                try:
+                    # Configure generator for entropy decoding
+                    generator = None
+                    if seed is not None:
+                        generator = torch.Generator(device=device)
+                        generator.manual_seed(seed)
+                    else:
+                        generator = torch.Generator(device=device)
+                        generator.manual_seed(1337)  # Default seed as in original implementation
+
+                    # Use directly available parameters for entropy decoding
+                    entropy_params = {
+                        "max_new_tokens": max_tokens if max_tokens is not None else 4096,
+                        "temperature": temperature,
+                        "top_p": top_p,
+                        "top_k": top_k,
+                        "min_p": min_p,
+                        "generator": generator
+                    }
+
+                    # Disable autocast and run in full precision
+                    with torch.amp.autocast('cuda', enabled=False), torch.inference_mode():
+                        result = entropy_decode(
+                            pipeline.current_model,
+                            pipeline.tokenizer,
+                            messages,
+                            **entropy_params
+                        )
+                    responses = [result]
+                    logprobs_results = [None]
+                    completion_tokens = len(pipeline.tokenizer.encode(result))
 
-                result = entropy_decode(
-                    pipeline.current_model,
-                    pipeline.tokenizer,
-                    messages,
-                    **entropy_params
-                )
-                responses = [result]
-                logprobs_results = [None]
-                completion_tokens = len(pipeline.tokenizer.encode(result))
+                finally:
+                    # Restore original dtype
+                    pipeline.current_model = pipeline.current_model.to(original_dtype)
 
             else:
                 raise ValueError(f"Unknown specialized decoding approach: {decoding}")
diff --git a/setup.py b/setup.py
index d880f03..3abe328 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 
 setup(
     name="optillm",
-    version="0.0.12",
+    version="0.0.13",
     packages=find_packages(),
     py_modules=['optillm'],
     package_data={