From fb2d8ff858b41b96defcef417e02aa1907ec734e Mon Sep 17 00:00:00 2001 From: Ye Cao Date: Mon, 18 Nov 2024 17:33:46 +0800 Subject: [PATCH] Support cuda tensor in vineyard. Signed-off-by: Ye Cao --- python/vineyard/contrib/ml/torch.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/vineyard/contrib/ml/torch.py b/python/vineyard/contrib/ml/torch.py index 9694504fc..af3aa5348 100644 --- a/python/vineyard/contrib/ml/torch.py +++ b/python/vineyard/contrib/ml/torch.py @@ -145,7 +145,12 @@ def torch_tensor_builder(client, value, **kw): meta['typename'] = 'vineyard::Tensor<%s>' % str(value.dtype) meta['value_type_'] = str(value.dtype) - meta.add_member('buffer_', build_torch_buffer(client, value)) + if value.is_cuda: + meta['device_'] = str(value.device) + value_in_cpu = value.to('cpu') + meta.add_member('buffer_', build_torch_buffer(client, value_in_cpu)) + else: + meta.add_member('buffer_', build_torch_buffer(client, value)) return client.create_metadata(meta) @@ -157,6 +162,7 @@ def torch_tensor_resolver(obj): value_type = normalize_tensor_dtype(value_type_name) shape = from_json(meta['shape_']) order = from_json(meta.get('order_', 'C')) + device = meta.get('device_', 'cpu') if np.prod(shape) == 0: return torch.zeros(shape, dtype=value_type) @@ -167,6 +173,9 @@ def torch_tensor_resolver(obj): c_tensor = torch.frombuffer(buffer, dtype=value_type).reshape(shape) tensor = c_tensor if order == 'C' else c_tensor.contiguous() + if "cuda" in device: + cuda_device = torch.device(device) + tensor = c_tensor.to(cuda_device) return tensor