
Make LitAPI.predict optional and validate API implementation #394

Merged: 4 commits merged into main on Dec 12, 2024

Conversation

aniketmaurya (Collaborator) commented Dec 11, 2024

What does this PR do?

This PR enables the continuous batching loop to use the LitAPI.step method and validates the API implementation via the pre_setup method.

Continuous batching with the vLLM engine in about 40 lines of code:

from vllm import LLMEngine, EngineArgs, SamplingParams
from litserve.loops import ContinuousBatchingLoop, Output
from litserve.utils import LitAPIStatus
import litserve as ls

class LitvLLM(ls.LitAPI):
    def setup(self, device):
        self.uids = {}  # Maps vLLM request_id (str) to original uid
        engine_args = EngineArgs(model="meta-llama/Llama-3.2-1B", device=device, max_model_len=2048)
        self.engine = LLMEngine.from_engine_args(engine_args)

    def decode_request(self, request: dict) -> dict:
        return request
    
    def step(self, prev_outputs):
        # Advance the vLLM engine by one iteration and collect newly generated tokens.
        request_outputs = self.engine.step()
        outputs = []
        for request_output in request_outputs:
            if request_output.request_id not in self.uids:
                continue
            original_uid = self.uids[request_output.request_id]
            token_id = request_output.outputs[0].token_ids[-1]
            outputs.append(Output(uid=original_uid, output=token_id, status=LitAPIStatus.OK))
            if request_output.finished:
                outputs.append(Output(uid=original_uid, output="", status=LitAPIStatus.FINISH_STREAMING))
                del self.uids[request_output.request_id]
        return outputs

    def encode_response(self, x):
        return self.engine.get_tokenizer().decode([x])

class vLLMLoop(ContinuousBatchingLoop):
    def add_request(self, uid: str, request, lit_api: LitvLLM, lit_spec) -> None:
        super().add_request(uid, request, lit_api, lit_spec)
        # Forward the new request to the vLLM engine, keyed by the uid.
        request_id = str(uid)
        sampling_params = SamplingParams(
            temperature=request.get("temperature", 0.0), max_tokens=request.get("max_tokens", 10)
        )
        lit_api.engine.add_request(request_id=request_id, prompt=request["prompt"], params=sampling_params)
        lit_api.uids[request_id] = uid

if __name__ == "__main__":
    loop = vLLMLoop()
    api = LitvLLM()
    server = ls.LitServer(api, loop=loop, max_batch_size=4, stream=True)
    server.run()
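
For the validation side of the PR, here is a rough sketch (an assumption for illustration, not the actual litserve implementation) of the kind of check pre_setup can now perform, given that predict becomes optional when a loop drives LitAPI.step:

# Hypothetical helper, not litserve's real code: verify the API implements
# the method that the selected loop actually needs.
def validate_api(lit_api, continuous_batching: bool) -> None:
    if continuous_batching:
        if not callable(getattr(lit_api, "step", None)):
            raise ValueError("Continuous batching requires LitAPI.step to be implemented.")
    elif not callable(getattr(lit_api, "predict", None)):
        raise ValueError("LitAPI.predict must be implemented when not using continuous batching.")
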
Before submitting
  • Was this discussed/agreed via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

codecov bot commented Dec 11, 2024

Codecov Report

Attention: Patch coverage is 23.33333% with 23 lines in your changes missing coverage. Please review.

Project coverage is 88%. Comparing base (35129f7) to head (1fa71ab).
Report is 1 commit behind head on main.

Additional details and impacted files
@@         Coverage Diff         @@
##           main   #394   +/-   ##
===================================
- Coverage    88%    88%   -0%     
===================================
  Files        25     25           
  Lines      1765   1774    +9     
===================================
+ Hits       1558   1560    +2     
- Misses      207    214    +7     

aniketmaurya merged commit ca9d8c9 into main on Dec 12, 2024
21 checks passed
aniketmaurya deleted the aniket/optional-predict branch on December 12, 2024 at 18:08

Review comment on the changed line:

- responses = []
+ responses: List[Output] = []

  # Process each sequence's new token
  for uid, token in zip(self.active_sequences.keys(), new_tokens):
      seq = self.active_sequences[uid]
Borda (Member) commented Dec 12, 2024:

A bit off-topic for this PR, but how about using a dataclass for seq instead of a dict? It would be easier to work with in an IDE, since a dataclass's fields can be tracked, unlike a dict that can hold anything.

aniketmaurya (Collaborator, Author) replied:

Yes, I agree @Borda. A dataclass here would be good for IDE support and would leave less room for mistakes.
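
A minimal sketch of what such a dataclass could look like; the field names below are assumptions for illustration, not the actual litserve structure:

from dataclasses import dataclass, field
from typing import List

@dataclass
class Sequence:
    """Hypothetical per-request state for the continuous batching loop."""
    uid: str
    prompt: str
    generated_tokens: List[int] = field(default_factory=list)
    max_tokens: int = 10

# Unlike a plain dict, the IDE can autocomplete and type-check seq.generated_tokens,
# and a typo such as seq.generated_token fails at attribute access instead of
# silently creating a new key.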
