Merge pull request #3 from rhymes-ai/inference
add autocast to bf16 context during inference
aria-hacker authored Oct 1, 2024
2 parents efd4c54 + 9205db4, commit bb28cda
Showing 7 changed files with 7 additions and 7 deletions.
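The change itself is small: generation now runs inside a bfloat16 autocast context in addition to torch.inference_mode(), so autocast-eligible ops (matmuls, attention) run with bf16 kernels during inference. Below is a minimal sketch of the resulting call pattern, assuming a CUDA device; the model id, image path, and prompt are illustrative placeholders, while the processor call, dtype cast, and generate() lines mirror the hunks that follow.

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Illustrative setup (assumed, not part of this diff): load model and processor
# in bf16 on a CUDA device.
model_id = "rhymes-ai/Aria"  # placeholder identifier for this sketch
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda")
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

image = Image.open("example.jpg")        # placeholder image
text = "<image>\nDescribe this image."   # placeholder prompt

# These lines mirror the diff hunks below.
inputs = processor(text=text, images=image, return_tensors="pt")
inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

# The pattern this commit introduces: no-grad inference plus a bf16 autocast region.
with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
    output = model.generate(**inputs, max_new_tokens=500)

Here output holds generated token ids; decoding follows the repo's existing examples.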
README.md (2 changes: 1 addition & 1 deletion)
@@ -77,7 +77,7 @@ inputs = processor(text=text, images=image, return_tensors="pt")
inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

-with torch.inference_mode():
+with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
    output = model.generate(
        **inputs,
        max_new_tokens=500,
aria/inference.py (2 changes: 1 addition & 1 deletion)
@@ -112,7 +112,7 @@ def inference(
    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

-    with torch.inference_mode():
+    with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
        output = model.generate(
            **inputs,
            max_new_tokens=500,
docs/inference.md (2 changes: 1 addition & 1 deletion)
@@ -38,7 +38,7 @@ inputs = processor(text=text, images=image, return_tensors="pt")
inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

-with torch.inference_mode():
+with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
    output = model.generate(
        **inputs,
        max_new_tokens=500,
examples/nextqa/evaluation.py (2 changes: 1 addition & 1 deletion)
@@ -88,7 +88,7 @@ def load_model_and_tokenizer(args):

def process_batch(model, tokenizer, inputs, original_batch, prompts):
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
-    with torch.inference_mode():
+    with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
        output = model.generate(
            **inputs,
            max_new_tokens=20,
examples/nlvr2/evaluation.py (2 changes: 1 addition & 1 deletion)
@@ -90,7 +90,7 @@ def load_model_and_tokenizer(args):

def process_batch(model, tokenizer, inputs, original_batch, prompts):
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
-    with torch.inference_mode():
+    with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
        output = model.generate(
            **inputs,
            max_new_tokens=50,
examples/refcoco/evaluation.py (2 changes: 1 addition & 1 deletion)
@@ -92,7 +92,7 @@ def load_model_and_tokenizer(args):

def process_batch(model, tokenizer, inputs, original_batch, prompts):
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
-    with torch.inference_mode():
+    with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
        output = model.generate(
            **inputs,
            max_new_tokens=50,
examples/refcoco/inference.py (2 changes: 1 addition & 1 deletion)
@@ -87,7 +87,7 @@ def inference(
    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

-    with torch.inference_mode():
+    with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
        output = model.generate(
            **inputs,
            max_new_tokens=500,
