Commit

Merge pull request #95 from codelion/fix-entropy-decoding-in-local-server

Fix entropy decoding in local server
codelion authored Nov 14, 2024
2 parents c3535c4 + db31686 commit 97eb708
Showing 2 changed files with 43 additions and 25 deletions.
optillm/inference.py (66 changes: 42 additions & 24 deletions)
@@ -1257,6 +1257,10 @@ def create(
         # Handle specialized decoding approaches
         if decoding:
             logger.info(f"Using specialized decoding approach: {decoding}")
+
+            # Ensure model is in eval mode and on correct device
+            pipeline.current_model.eval()
+            device = pipeline.current_model.device

             if decoding == "cot_decoding":
                 # Use directly available parameters for CoT
@@ -1284,32 +1288,46 @@ def create(
                 completion_tokens = len(pipeline.tokenizer.encode(result))

             elif decoding == "entropy_decoding":
-                # Configure generator for entropy decoding
-                generator = None
-                if seed is not None:
-                    generator = torch.Generator(device=pipeline.current_model.device)
-                    generator.manual_seed(seed)
-
-                # Use directly available parameters for entropy decoding
-                entropy_params = {
-                    "max_new_tokens": max_tokens if max_tokens is not None else 512,
-                    "temperature": 0.666,
-                    "top_p": 0.90,
-                    "top_k": top_k,
-                    "min_p": min_p,
-                    "generator": generator
-                }
+                # Ensure model is using full precision
+                original_dtype = pipeline.current_model.dtype
+                pipeline.current_model = pipeline.current_model.to(torch.float32)
+
+                try:
+                    # Configure generator for entropy decoding
+                    generator = None
+                    if seed is not None:
+                        generator = torch.Generator(device=device)
+                        generator.manual_seed(seed)
+                    else:
+                        generator = torch.Generator(device=device)
+                        generator.manual_seed(1337)  # Default seed as in original implementation
+
+                    # Use directly available parameters for entropy decoding
+                    entropy_params = {
+                        "max_new_tokens": max_tokens if max_tokens is not None else 4096,
+                        "temperature": temperature,
+                        "top_p": top_p,
+                        "top_k": top_k,
+                        "min_p": min_p,
+                        "generator": generator
+                    }
+
+                    # Disable autocast and run in full precision
+                    with torch.amp.autocast('cuda', enabled=False), torch.inference_mode():
+                        result = entropy_decode(
+                            pipeline.current_model,
+                            pipeline.tokenizer,
+                            messages,
+                            **entropy_params
+                        )
+                    responses = [result]
+                    logprobs_results = [None]
+                    completion_tokens = len(pipeline.tokenizer.encode(result))

-                result = entropy_decode(
-                    pipeline.current_model,
-                    pipeline.tokenizer,
-                    messages,
-                    **entropy_params
-                )
-                responses = [result]
-                logprobs_results = [None]
-                completion_tokens = len(pipeline.tokenizer.encode(result))
+                finally:
+                    # Restore original dtype
+                    pipeline.current_model = pipeline.current_model.to(original_dtype)

             else:
                 raise ValueError(f"Unknown specialized decoding approach: {decoding}")
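In short, the patched path casts the model to float32, seeds a torch.Generator (defaulting to 1337), runs entropy_decode with autocast disabled inside inference_mode, and restores the original dtype in a finally block, while honouring the request's temperature, top_p, and max_tokens instead of the previously hardcoded values. A minimal standalone sketch of that pattern follows; it assumes a Hugging Face transformers model (which exposes .device and .dtype), and the wrapper name and import path are illustrative rather than the repository's exact code:

import torch
from optillm.entropy_decoding import entropy_decode  # assumed module path for the repo's helper

def run_entropy_decoding(model, tokenizer, messages, max_tokens=4096,
                         temperature=1.0, top_p=1.0, seed=None, **extra):
    """Illustrative wrapper mirroring the fixed server code path."""
    original_dtype = model.dtype
    model = model.to(torch.float32)  # entropy calculations run in full precision
    try:
        # Seed a generator so sampling is reproducible (1337 is the patch's default)
        generator = torch.Generator(device=model.device)
        generator.manual_seed(seed if seed is not None else 1337)

        # Disable autocast and gradient tracking so logits stay in float32
        with torch.amp.autocast('cuda', enabled=False), torch.inference_mode():
            result = entropy_decode(
                model,
                tokenizer,
                messages,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                generator=generator,
                **extra,  # e.g. top_k, min_p
            )
        return result
    finally:
        model.to(original_dtype)  # restore the caller's dtype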
setup.py (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@

 setup(
     name="optillm",
-    version="0.0.12",
+    version="0.0.13",
     packages=find_packages(),
     py_modules=['optillm'],
     package_data={
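On the client side, the fix is exercised whenever a request selects entropy decoding against a locally served model. The sketch below is a hedged example using the OpenAI Python client: the base URL, API key, and model name are placeholders for whatever the local optillm server is actually running, and the decoding field is taken from the create() signature in the diff above, passed through extra_body (the OpenAI client's mechanism for extra request fields):

from openai import OpenAI

# Placeholder endpoint and key for a locally running optillm proxy
client = OpenAI(api_key="optillm", base_url="http://localhost:8000/v1")

response = client.chat.completions.create(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder: any HF model served locally by optillm
    messages=[{"role": "user", "content": "Which is larger, 9.11 or 9.9?"}],
    max_tokens=512,
    temperature=0.7,
    extra_body={"decoding": "entropy_decoding"},  # routes into the code path patched above
)
print(response.choices[0].message.content)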
