Run yapf and ruff
Signed-off-by: wchen61 <[email protected]>
wchen61 committed Nov 16, 2024
1 parent f33f01d commit 3c27e97
Showing 1 changed file with 40 additions and 23 deletions.
63 changes: 40 additions & 23 deletions examples/offline_inference.py
@@ -5,6 +5,7 @@
from vllm.utils import FlexibleArgumentParser
from transformers import AutoTokenizer


def get_prompts(args):
# The default sample prompts.
prompts = [
@@ -32,22 +33,22 @@ def get_prompts(args):
prompts = [prompt for _ in range(args.batch_size)]

if args.batch_size != len(prompts):
prompts = (prompts * ((args.batch_size // len(prompts)) + 1))[:args.batch_size]
prompts = (prompts *
((args.batch_size // len(prompts)) + 1))[:args.batch_size]

return prompts
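
The slicing expression above pads the prompt list out to exactly args.batch_size entries: tile the list one more time than the integer ratio requires, then cut it to size. A minimal sketch of the same trick with illustrative values (not the script's defaults):

    prompts = ["a", "b", "c"]
    batch_size = 5

    # Tile past the target length, then slice down to it.
    padded = (prompts * ((batch_size // len(prompts)) + 1))[:batch_size]
    print(padded)  # ['a', 'b', 'c', 'a', 'b']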


def main(args):
# Create prompts
prompts = get_prompts(args)

# Create a sampling params object.
sampling_params = SamplingParams(
n=args.n,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
max_tokens=args.output_len
)
sampling_params = SamplingParams(n=args.n,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
max_tokens=args.output_len)

# Create an LLM.
# The default model is 'facebook/opt-125m', ensured by the default parameters of EngineArgs

Check failure on line 54 in examples/offline_inference.py (GitHub Actions / ruff (3.12)): examples/offline_inference.py:54:81: E501 Line too long (95 > 80)
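
The flagged line is the 95-character comment above about the default model. One way to satisfy the 80-column E501 limit, sketched here as a suggestion rather than taken from a follow-up commit, is to wrap the comment:

    # Create an LLM.
    # The default model is 'facebook/opt-125m', ensured by the default
    # parameters of EngineArgs.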
@@ -63,25 +64,41 @@ def main(args):
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
parser = FlexibleArgumentParser()
parser.add_argument("--batch-size", type=int, default=4,
help="Batch size for inference, default is lenght of sample prompts")
parser.add_argument("--input-len", type=int, default=None,
help="Use fake fixed-length prompt as input if set")
parser.add_argument("--output-len", type=int, default=16,
help="Output length for sampling")
parser.add_argument('--n', type=int, default=1,
help='Number of generated sequences per prompt')
parser.add_argument('--temperature', type=float, default=0.8,
help='Temperature for text generation')
parser.add_argument('--top-p', type=float, default=0.95,
help='top_p for text generation')
parser.add_argument('--top-k', type=int, default=-1,
help='top_k for text generation')
parser.add_argument(
"--batch-size",
type=int,
default=4,
help="Batch size for inference, default is lenght of sample prompts")
parser.add_argument("--input-len",
type=int,
default=None,
help="Use fake fixed-length prompt as input if set")
parser.add_argument("--output-len",
type=int,
default=16,
help="Output length for sampling")
parser.add_argument('--n',
type=int,
default=1,
help='Number of generated sequences per prompt')
parser.add_argument('--temperature',
type=float,
default=0.8,
help='Temperature for text generation')
parser.add_argument('--top-p',
type=float,
default=0.95,
help='top_p for text generation')
parser.add_argument('--top-k',
type=int,
default=-1,
help='top_k for text generation')

EngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
main(args)
main(args)
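
Because EngineArgs.add_cli_args(parser) merges the engine's own flags (such as --model) into the same parser, one command line drives both the script and the engine. A minimal sketch of that resolution, assuming vLLM is installed; the argv values are illustrative:

    from vllm import EngineArgs
    from vllm.utils import FlexibleArgumentParser

    parser = FlexibleArgumentParser()
    parser.add_argument("--batch-size", type=int, default=4)
    EngineArgs.add_cli_args(parser)

    # Parse an illustrative command line instead of sys.argv.
    args = parser.parse_args(["--batch-size", "8", "--model", "facebook/opt-125m"])
    print(args.batch_size, args.model)  # 8 facebook/opt-125m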
