Skip to content

Commit

Permalink
Specify ASCEND NPU for inference.
Browse files Browse the repository at this point in the history
  • Loading branch information
as12138 committed Nov 29, 2024
1 parent 1cd4b74 commit 83d36e2
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 0 deletions.
7 changes: 7 additions & 0 deletions fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,13 @@ def add_model_args(parser):
help="A single GPU like 1 or multiple GPUs like 0,2",
)
parser.add_argument("--num-gpus", type=int, default=1)
parser.add_argument(
"--npus",
type=str,
default=None,
help="A single NPU like 1 or multiple NPUs like 0,2",
)
parser.add_argument("--num-npus", type=int, default=1)
parser.add_argument(
"--max-gpu-memory",
type=str,
Expand Down
9 changes: 9 additions & 0 deletions fastchat/serve/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
- Type "!!save <filename>" to save the conversation history to a json file.
- Type "!!load <filename>" to load a conversation history from a json file.
"""

import argparse
import os
import re
Expand Down Expand Up @@ -197,6 +198,14 @@ def main(args):
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
os.environ["XPU_VISIBLE_DEVICES"] = args.gpus
if args.npus:
if len(args.npus.split(",")) < args.num_npus:
raise ValueError(
f"Larger --num_npus ({args.num_npus}) than --npus {args.npus}!"
)
if len(args.npus.split(",")) == 1:
import torch_npu
torch.npu.set_device(int(args.npus))
if args.enable_exllama:
exllama_config = ExllamaConfig(
max_seq_len=args.exllama_max_seq_len,
Expand Down
9 changes: 9 additions & 0 deletions fastchat/serve/model_worker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
A model worker that executes the model.
"""

import argparse
import base64
import gc
Expand Down Expand Up @@ -351,6 +352,14 @@ def create_model_worker():
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
if args.npus:
if len(args.npus.split(",")) < args.num_npus:
raise ValueError(
f"Larger --num_npus ({args.num_npus}) than --npus {args.npus}!"
)
if len(args.npus.split(",")) == 1:
import torch_npu
torch.npu.set_device(int(args.npus))

gptq_config = GptqConfig(
ckpt=args.gptq_ckpt or args.model_path,
Expand Down
9 changes: 9 additions & 0 deletions fastchat/serve/multi_model_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
We recommend using this with multiple Peft models (with `peft` in the name)
where all Peft models are trained on the exact same base model.
"""

import argparse
import asyncio
import dataclasses
Expand Down Expand Up @@ -206,6 +207,14 @@ def create_multi_model_worker():
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
if args.npus:
if len(args.npus.split(",")) < args.num_npus:
raise ValueError(
f"Larger --num_npus ({args.num_npus}) than --npus {args.npus}!"
)
if len(args.npus.split(",")) == 1:
import torch_npu
torch.npu.set_device(int(args.npus))

gptq_config = GptqConfig(
ckpt=args.gptq_ckpt or args.model_path,
Expand Down

0 comments on commit 83d36e2

Please sign in to comment.