Skip to content

Commit

Permalink
fix global var bug
Browse files (browse the repository at this point in the history)
  • Loading branch information
cyber-pioneer committed Dec 16, 2024
1 parent c601d9d commit 85c5662
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
8 changes: 3 additions & 5 deletions examples/qwen/conf/serve/serve_qwen2.5_multiple_models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@ model_args:
     gpu_memory_utilization: 0.9
     max_model_len: 32768
     max_num_seqs: 256
-    port: 4567
     trust_remote_code: true
     enable_chunked_prefill: true
   model2:
-    model: /models/Qwen2.5-3B-Instruct
-    tensor_parallel_size: 1
+    model: /models/Qwen2.5-Coder-32B-Instruct
+    tensor_parallel_size: 2
     gpu_memory_utilization: 0.9
     max_model_len: 32768
     max_num_seqs: 256
-    port: 4567
     trust_remote_code: true
     enable_chunked_prefill: true

Expand All @@ -23,4 +21,4 @@ deploy:
model1:
num_gpus: 1
model2:
-num_gpus: 1
+num_gpus: 2
13 changes: 7 additions & 6 deletions flagscale/serve/serve_multiple_qwens.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

 @serve.remote(name="model1")
 class ModelWorker1:
-    def __init__(self):
-        model_config = serve.task_config["serve"]["model_args"]["model1"]
+    def __init__(self, args):
+        model_config = args["serve"]["model_args"]["model1"]
         self.llm = LLM(**model_config)
         self.sampling_params = SamplingParams(temperature=0.7, top_p=0.95)

Expand All @@ -21,8 +21,8 @@ def generate(self, prompt):

 @serve.remote(name="model2")
 class ModelWorker2:
-    def __init__(self):
-        model_config = serve.task_config["serve"]["model_args"]["model2"]
+    def __init__(self, args):
+        model_config = args["serve"]["model_args"]["model2"]
         self.llm = LLM(**model_config)
         self.sampling_params = SamplingParams(temperature=0.7, top_p=0.95)

Expand All @@ -31,8 +31,9 @@ def generate(self, prompt):
         return [output.text for output in outputs]


-model_worker1 = ModelWorker1.remote()
-model_worker2 = ModelWorker2.remote()
+ray.init()
+model_worker1 = ModelWorker1.remote(serve.task_config)
+model_worker2 = ModelWorker2.remote(serve.task_config)


 app = FastAPI()
Expand Down

0 comments on commit 85c5662

Please sign in to comment.