diff --git a/plugin/sycl/device_manager.cc b/plugin/sycl/device_manager.cc index 0ddbf144083b..021ced67ecaf 100644 --- a/plugin/sycl/device_manager.cc +++ b/plugin/sycl/device_manager.cc @@ -20,18 +20,25 @@ ::sycl::device DeviceManager::GetDevice(const DeviceOrd& device_spec) const { (collective::IsDistributed()); if (not_use_default_selector) { DeviceRegister& device_register = GetDevicesRegister(); - const int device_idx = - collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal; if (device_spec.IsSyclDefault()) { auto& devices = device_register.devices; + const int device_idx = collective::IsDistributed() + ? collective::GetRank() % devices.size() + : device_spec.ordinal; CHECK_LT(device_idx, devices.size()); return devices[device_idx]; } else if (device_spec.IsSyclCPU()) { auto& cpu_devices = device_register.cpu_devices; + const int device_idx = collective::IsDistributed() + ? collective::GetRank() % cpu_devices.size() + : device_spec.ordinal; CHECK_LT(device_idx, cpu_devices.size()); return cpu_devices[device_idx]; } else { auto& gpu_devices = device_register.gpu_devices; + const int device_idx = collective::IsDistributed() + ? collective::GetRank() % gpu_devices.size() + : device_spec.ordinal; CHECK_LT(device_idx, gpu_devices.size()); return gpu_devices[device_idx]; } @@ -63,18 +70,25 @@ ::sycl::queue DeviceManager::GetQueue(const DeviceOrd& device_spec) const { std::lock_guard guard(queue_registering_mutex); if (not_use_default_selector) { DeviceRegister& device_register = GetDevicesRegister(); - const int device_idx = - collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal; if (device_spec.IsSyclDefault()) { auto& devices = device_register.devices; + const int device_idx = collective::IsDistributed() + ? collective::GetRank() % devices.size() + : device_spec.ordinal; CHECK_LT(device_idx, devices.size()); queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]); } else if (device_spec.IsSyclCPU()) { auto& cpu_devices = device_register.cpu_devices; + const int device_idx = collective::IsDistributed() + ? collective::GetRank() % cpu_devices.size() + : device_spec.ordinal; CHECK_LT(device_idx, cpu_devices.size()); queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]); } else if (device_spec.IsSyclGPU()) { auto& gpu_devices = device_register.gpu_devices; + const int device_idx = collective::IsDistributed() + ? collective::GetRank() % gpu_devices.size() + : device_spec.ordinal; CHECK_LT(device_idx, gpu_devices.size()); queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]); }