diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 86ca1948ef94a..4072616fd30e2 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port): for sz in test_sizes: for dtype in [torch.float32, torch.float16, torch.bfloat16]: - with graph_capture() as graph_capture_context: + with graph_capture(device=device) as graph_capture_context: # use integers so result matches NCCL exactly inp1 = torch.randint(1, 16, (sz, ), diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 3e9b0e10a11d8..36cfe42251384 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -107,7 +107,7 @@ def multiple_allreduce_with_vllm_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") ensure_model_parallel_initialized(2, 2) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) - with graph_capture(): + with graph_capture(device=device): # two tp groups can communicate independently if torch.distributed.get_rank() in [0, 1]: tensor = tensor_model_parallel_all_reduce(tensor) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index e6768467f4c27..a0d4235460f3b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -920,7 +920,7 @@ def get_kv_transfer_group() -> kv_transfer.KVTransferAgent: @contextmanager -def graph_capture(): +def graph_capture(device: torch.device): """ `graph_capture` is a context manager which should surround the code that is capturing the CUDA graph. Its main purpose is to ensure that the @@ -934,8 +934,9 @@ def graph_capture(): in order to explicitly distinguish the kernels to capture from other kernels possibly launched on background in the default stream. """ - with get_tp_group().graph_capture() as context, get_pp_group( - ).graph_capture(context): + context = GraphCaptureContext(torch.cuda.Stream(device=device)) + with get_tp_group().graph_capture(context), get_pp_group().graph_capture( + context): yield context diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a08a86d4007dc..6bd951838043b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -741,7 +741,7 @@ def capture_model(self) -> None: # Trigger CUDA graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. - with graph_capture(): + with graph_capture(device=self.device): for num_tokens in reversed(self.cudagraph_batch_sizes): for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2b545d1b28bd2..d5a08123619e4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1425,10 +1425,15 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: # Prepare dummy inputs. These will be reused for all batch sizes. max_batch_size = self.max_batchsize_to_capture - input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() - input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() + input_tokens = torch.zeros(max_batch_size, + dtype=torch.long, + device=self.device) + input_positions = torch.zeros(max_batch_size, + dtype=torch.long, + device=self.device) if self.model_config.uses_mrope: - input_positions = torch.tile(input_positions, (3, 1)) + input_positions = torch.tile(input_positions, + (3, 1)).cuda(device=self.device) # Prepare dummy previous_hidden_states only if needed by the model. # This is used by draft models such as EAGLE. previous_hidden_states = None @@ -1447,8 +1452,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: dtype=self.model_config.dtype, device=self.device) - with self.attn_state.graph_capture( - max_batch_size), graph_capture() as graph_capture_context: + with self.attn_state.graph_capture(max_batch_size), graph_capture( + self.device) as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. for virtual_engine in range( @@ -1548,10 +1553,12 @@ def _update_inputs_to_capture_for_enc_dec_model(self, """ # During the decode phase encoder_input_ids and encoder_positions are # unset. Do the same thing for graph capture. - capture_inputs["encoder_input_ids"] = torch.tensor( - [], dtype=torch.long).cuda() - capture_inputs["encoder_positions"] = torch.tensor( - [], dtype=torch.long).cuda() + capture_inputs["encoder_input_ids"] = torch.tensor([], + dtype=torch.long, + device=self.device) + capture_inputs["encoder_positions"] = torch.tensor([], + dtype=torch.long, + device=self.device) @property def vocab_size(self) -> int: