From b4e9528f9569d6eb8c29624771a4058fe794cb5a Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 9 Aug 2024 00:06:36 -0700 Subject: [PATCH] [Core] Streamline stream termination in `AsyncLLMEngine` (#7336) --- tests/async_engine/test_request_tracker.py | 6 ++-- vllm/engine/async_llm_engine.py | 41 ++++++++++++---------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/tests/async_engine/test_request_tracker.py b/tests/async_engine/test_request_tracker.py index c66bdd5f9003d..5668cc30d32c3 100644 --- a/tests/async_engine/test_request_tracker.py +++ b/tests/async_engine/test_request_tracker.py @@ -47,8 +47,10 @@ async def test_request_tracker(): assert tracker.new_requests_event.is_set() await tracker.wait_for_new_requests() new, aborted = tracker.get_new_and_aborted_requests() - assert len(aborted) == 1 - assert "4" in aborted + # aborted new requests will cancel each other out - + # there's no need for them to propagate into the + # engine + assert not aborted assert not new assert stream_4.finished diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 6af347def475e..809eb6de9f173 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -85,11 +85,14 @@ def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, return self._queue.put_nowait(item) - def finish(self, cancelled: bool = False) -> None: + def finish( + self, + exception: Optional[Union[BaseException, Type[BaseException]]] = None, + ) -> None: if not self._finished: self._finished = True self._queue.put_nowait( - asyncio.CancelledError if cancelled else STOP_ITERATION) + exception if exception is not None else STOP_ITERATION) @property def finished(self) -> bool: @@ -133,14 +136,12 @@ def propagate_exception(self, """Propagate an exception to request streams (all if request_id is None).""" if request_id is not None: - self._request_streams[request_id].put(exc) - self.abort_request(request_id) + self.abort_request(request_id, exception=exc) else: - # NB: list() used here because self.abort_request pops the stream + # NB: tuple() used here because self.abort_request pops the stream # out of self._request_streams, so we can't iterate on it directly - for rid, stream in list(self._request_streams.items()): - stream.put(exc) - self.abort_request(rid) + for rid in tuple(self._request_streams.keys()): + self.abort_request(rid, exception=exc) def process_request_output(self, request_output: Union[RequestOutput, @@ -167,14 +168,13 @@ def process_request_output(self, def process_exception(self, request_id: str, - exception: Exception, + exception: BaseException, *, verbose: bool = False) -> None: """Propagate an exception from the engine.""" - self._request_streams[request_id].put(exception) if verbose: logger.info("Finished request %s.", request_id) - self.abort_request(request_id) + self.abort_request(request_id, exception=exception) def add_request(self, request_id: str, @@ -203,7 +203,8 @@ def add_request(self, def abort_request(self, request_id: str, *, - cancelled: bool = False, + exception: Optional[Union[BaseException, + Type[BaseException]]] = None, verbose: bool = False) -> None: """Abort a request during next background loop iteration.""" if verbose: @@ -213,7 +214,7 @@ def abort_request(self, stream = self._request_streams.pop(request_id, None) if stream is not None: - stream.finish(cancelled=cancelled) + stream.finish(exception=exception) def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]: """Get the new requests and finished requests to be @@ -227,12 +228,14 @@ def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]: while not self._new_requests.empty(): stream, new_request = self._new_requests.get_nowait() - if stream.request_id in finished_requests: + request_id = stream.request_id + if request_id in finished_requests: # The request has already been aborted. - stream.finish(cancelled=True) - continue - self._request_streams[stream.request_id] = stream - new_requests.append(new_request) + stream.finish(asyncio.CancelledError) + finished_requests.discard(request_id) + else: + self._request_streams[request_id] = stream + new_requests.append(new_request) return new_requests, finished_requests @@ -1015,7 +1018,7 @@ def _abort(self, request_id: str) -> None: request_id: The unique id of the request. """ self._request_tracker.abort_request(request_id, - cancelled=True, + exception=asyncio.CancelledError, verbose=self.log_requests) async def get_model_config(self) -> ModelConfig: