Skip to content

Commit

Permalink
Update inference.py
Browse files (browse the repository at this point in the history)
- Do entropy decoding in full precision
  • Loading branch information
codelion committed Nov 14, 2024
1 parent c3535c4 commit 51f09af
Showing 1 changed file with 42 additions and 24 deletions.
66 changes: 42 additions & 24 deletions optillm/inference.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1257,6 +1257,10 @@ def create(
# Handle specialized decoding approaches
if decoding:
logger.info(f"Using specialized decoding approach: {decoding}")

# Ensure model is in eval mode and on correct device
pipeline.current_model.eval()
device = pipeline.current_model.device

if decoding == "cot_decoding":
# Use directly available parameters for CoT
Expand Down Expand Up @@ -1284,32 +1288,46 @@ def create(
completion_tokens = len(pipeline.tokenizer.encode(result))

elif decoding == "entropy_decoding":
# Configure generator for entropy decoding
generator = None
if seed is not None:
generator = torch.Generator(device=pipeline.current_model.device)
generator.manual_seed(seed)

# Use directly available parameters for entropy decoding

entropy_params = {
"max_new_tokens": max_tokens if max_tokens is not None else 512,
"temperature": 0.666,
"top_p": 0.90,
"top_k": top_k,
"min_p": min_p,
"generator": generator
}
# Ensure model is using full precision
original_dtype = pipeline.current_model.dtype
pipeline.current_model = pipeline.current_model.to(torch.float32)

try:
# Configure generator for entropy decoding
generator = None
if seed is not None:
generator = torch.Generator(device=device)
generator.manual_seed(seed)
else:
generator = torch.Generator(device=device)
generator.manual_seed(1337) # Default seed as in original implementation

# Use directly available parameters for entropy decoding
entropy_params = {
"max_new_tokens": max_tokens if max_tokens is not None else 4096,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"min_p": min_p,
"generator": generator
}

# Disable autocast and run in full precision
with torch.amp.autocast('cuda', enabled=False), torch.inference_mode():
result = entropy_decode(
pipeline.current_model,
pipeline.tokenizer,
messages,
**entropy_params
)
responses = [result]
logprobs_results = [None]
completion_tokens = len(pipeline.tokenizer.encode(result))

result = entropy_decode(
pipeline.current_model,
pipeline.tokenizer,
messages,
**entropy_params
)
responses = [result]
logprobs_results = [None]
completion_tokens = len(pipeline.tokenizer.encode(result))
finally:
# Restore original dtype
pipeline.current_model = pipeline.current_model.to(original_dtype)

else:
raise ValueError(f"Unknown specialized decoding approach: {decoding}")
Expand Down

0 comments on commit 51f09af

Please sign in to comment.