diff --git a/FlagEmbedding/abc/inference/AbsReranker.py b/FlagEmbedding/abc/inference/AbsReranker.py index 9a3f11a9..e9ba61b3 100644 --- a/FlagEmbedding/abc/inference/AbsReranker.py +++ b/FlagEmbedding/abc/inference/AbsReranker.py @@ -96,7 +96,7 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s elif is_torch_npu_available(): return [f"npu:{i}" for i in range(torch.npu.device_count())] elif torch.backends.mps.is_available(): - return [f"mps:{i}" for i in range(torch.mps.device_count())] + return ["mps"] else: return ["cpu"] elif isinstance(devices, str):