From 7b89e05cd9292c41cca490c306192f41d60236af Mon Sep 17 00:00:00 2001
From: chenzhuo
Date: Fri, 27 Dec 2024 16:31:04 +0800
Subject: [PATCH] finish whole router process

---
 flagscale/serve/core/dag.py              | 89 +++++++++----------
 flagscale/serve/run_serve.py             |  9 +-
 tests/unit_tests/serve/build_dag/main.py | 48 +++-------
 .../serve/build_dag/serve/config.yaml    |  2 +-
 4 files changed, 58 insertions(+), 90 deletions(-)

diff --git a/flagscale/serve/core/dag.py b/flagscale/serve/core/dag.py
index c4d32c5c..03b4ee87 100644
--- a/flagscale/serve/core/dag.py
+++ b/flagscale/serve/core/dag.py
@@ -1,47 +1,24 @@
 import importlib
+import uvicorn
 import networkx as nx
 import matplotlib.pyplot as plt
 import ray
 from ray import workflow
 import omegaconf
+import logging as logger
 
-from pydantic import BaseModel
+
+from pydantic import create_model
 from typing import Callable, Any
 from fastapi import FastAPI, HTTPException, Request
 
-class RequestData(BaseModel):
-    prompt: str
-
-def create_route(path: str, func: Callable, method="post"):
-    app = FastAPI()
-
-    if method.lower() == 'post':
-        @app.post(path)
-        async def route_handler(request_data: RequestData):
-            try:
-                response = func(request_data.prompt)
-                return response
-            except Exception as e:
-                raise HTTPException(status_code=400, detail=str(e))
-    else:
-        raise ValueError(f"Unsupported HTTP method: {method}")
-
-
-#final_result = build_and_run_dag(config["deploy"], tasks, input_data)
-#print(f"res: {final_result}")
-
-#ray.shutdown()
-create_route('/process', 'post', build_and_run_dag)
-
-if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="127.0.0.1", port=8000)
 
 class Builder:
     def __init__(self, config):
         self.config = config
         self.check_config(config)
+        self.tasks = {}
 
     def check_config(self, config):
         if not config.get("deploy", None):
@@ -106,21 +83,20 @@ def check_dag(self, visibilization=False):
             )
 
     def build_task(self):
-        tasks = {}
         for model_alias, model_config in self.config["deploy"]["models"].items():
             module_name = model_config["module"]
             model_name = model_config["entrypoint"]
-            print(module_name, model_name)
             module = importlib.import_module(module_name)
             model = getattr(module, model_name)
             num_gpus = model_config.get("num_gpus", 0)
-            tasks[model_alias] = ray.remote(model).options(num_gpus=num_gpus)
+            self.tasks[model_alias] = ray.remote(model).options(num_gpus=num_gpus)
             # tasks[model_alias] = ray.remote(num_gpus=num_gpus)(model)
             # models[model_alias] = model
         self.check_dag()
-        return tasks
+        return
 
-    def run_task(self, tasks, input_data):
+    def run_task(self, input_data):
+        assert len(self.tasks) > 0
         ray.init(
             num_gpus=6,
             storage="/tmp/ray_workflow",
@@ -152,32 +128,49 @@ def run_task(self, tasks, input_data):
             if dependencies:
                 if len(dependencies) > 1:
                     inputs = [model_nodes[dep] for dep in dependencies]
-                    model_nodes[model_alias] = tasks[model_alias].bind(*inputs)
+                    model_nodes[model_alias] = self.tasks[model_alias].bind(*inputs)
                 else:
-                    model_nodes[model_alias] = tasks[model_alias].bind(
+                    model_nodes[model_alias] = self.tasks[model_alias].bind(
                         model_nodes[dependencies[0]]
                     )
             else:
-                model_nodes[model_alias] = tasks[model_alias].bind(input_data)
+                model_nodes[model_alias] = self.tasks[model_alias].bind(input_data)
                 models_to_process.remove(model_alias)
                 progress = True
         if not progress:
             raise ValueError("Circular dependency detected in model configuration")
-        print("model_nodes ************** ", model_nodes, flush=True)
-        print(
-            " config['deploy']['models'] ************ ",
-            self.config["deploy"]["models"],
-            flush=True,
-        )
+        logger.info("deploy model_nodes: %s", model_nodes)
+        final_node = model_nodes[self.config["deploy"]["exit"]]
+        final_result = workflow.run(final_node)
+        return final_result
+
+    def run_router_task(self, method="post"):
         router_config = self.config["deploy"].get("router")
-        if router_config and len(router_config) > 0:
-            name = router_config["name"]
-            port = router_config["port"]
-            create_route(name, port, workflow.run)
+        assert router_config and len(router_config) > 0
+        name = router_config["name"]
+        port = router_config["port"]
+        request_config = router_config["request"]
+
+        RequestData = create_model(
+            "Request", **{field: (type_, ...) for field, type_ in request_config.items()}
+        )
+        app = FastAPI()
+
+        if method.lower() == "post":
+
+            @app.post(name)
+            async def route_handler(request_data: RequestData):
+                try:
+                    response = self.run_task(request_data.prompt)
+                    return response
+                except Exception as e:
+                    raise HTTPException(status_code=400, detail=str(e))
+
         else:
-            final_result = workflow.run(final_node)
-            return final_result
+            raise ValueError(f"Unsupported HTTP method: {method}")
+        uvicorn.run(app, host="127.0.0.1", port=port)
diff --git a/flagscale/serve/run_serve.py b/flagscale/serve/run_serve.py
index 8f943267..06068755 100644
--- a/flagscale/serve/run_serve.py
+++ b/flagscale/serve/run_serve.py
@@ -9,10 +9,11 @@ def main():
     project_path = config["root_path"]
     sys.path.append(project_path)
     builder = Builder(config)
-    tasks = builder.build_task()
-    res = builder.run_task(tasks, input_data="Introduce Bruce Lee")
-    print("**************** res ****************", res)
-
+    builder.build_task()
+    if config["deploy"].get("router"):
+        builder.run_router_task()
+    else:
+        result = builder.run_task(input_data="Introduce Bruce Lee")
 
 if __name__ == "__main__":
     main()
diff --git a/tests/unit_tests/serve/build_dag/main.py b/tests/unit_tests/serve/build_dag/main.py
index 62cc8a47..b86af993 100644
--- a/tests/unit_tests/serve/build_dag/main.py
+++ b/tests/unit_tests/serve/build_dag/main.py
@@ -1,50 +1,24 @@
-import os
-from pydantic import BaseModel
 from vllm import LLM, SamplingParams
-
 from custom.models import fn
 
-#os.environ['CUDA_VISIBLE_DEVICES']="0,1,2,3,4,5"
-class GenerateRequest(BaseModel):
-    prompt: str
-
-# class LLMActor:
-#     def __init__(self):
-#         # Initialize the LLM inside the actor to avoid serialization
-#         self.llm = LLM(
-#             model="/models/Qwen2.5-0.5B-Instruct",
-#             tensor_parallel_size=1,
-#             gpu_memory_utilization=0.5
-#         )
-#     def generate(self, prompt: str) -> str:
-#         sampling_params = SamplingParams(
-#             temperature=0.7,
-#             top_p=0.95,
-#             max_tokens=1000
-#         )
-#         result = self.llm.generate([prompt], sampling_params=sampling_params)
-#         return result[0].outputs[0].text
-#llm = LLM(model="/models/Qwen2.5-0.5B-Instruct", tensor_parallel_size=1, gpu_memory_utilization=0.5)
-#actor = LLMActor()
-#prompt="introduce Bruce Lee"
 def model_A(prompt):
-    #prompt="introduce Bruce Lee"
-    llm = LLM(model="/models/Qwen2.5-0.5B-Instruct", tensor_parallel_size=1, gpu_memory_utilization=0.5)
+    llm = LLM(
+        model="/models/Qwen2.5-0.5B-Instruct",
+        tensor_parallel_size=1,
+        gpu_memory_utilization=0.5,
+    )
     sampling_params = SamplingParams(temperature=0.7, top_p=0.95, max_tokens=1000)
-
+
     result = llm.generate([prompt], sampling_params=sampling_params)
-    return result[0].outputs[0].text
+    return fn(result[0].outputs[0].text)
+
 
 def model_B(input_data):
     res = input_data + "__add_model_B"
     return res
 
+
 if __name__ == "__main__":
-    print(model_A())
+    prompt = "introduce Bruce Lee"
Lee" + print(model_A(prompt)) diff --git a/tests/unit_tests/serve/build_dag/serve/config.yaml b/tests/unit_tests/serve/build_dag/serve/config.yaml index 43ae48ad..bcac9481 100644 --- a/tests/unit_tests/serve/build_dag/serve/config.yaml +++ b/tests/unit_tests/serve/build_dag/serve/config.yaml @@ -30,4 +30,4 @@ deploy: name: generate port: 8000 request: - key: prompt + prompt: str