diff --git a/infinigen/datagen/util/submitit_emulator.py b/infinigen/datagen/util/submitit_emulator.py index b79c1f4b..9b186fcb 100644 --- a/infinigen/datagen/util/submitit_emulator.py +++ b/infinigen/datagen/util/submitit_emulator.py @@ -11,6 +11,7 @@ import itertools import logging import os +import platform import re import subprocess import sys @@ -26,6 +27,14 @@ CUDA_VARNAME = "CUDA_VISIBLE_DEVICES" NVIDIA_SMI_PATH = "/bin/nvidia-smi" +NVIDIA_SMI_WSL_PATH = "/usr/lib/wsl/lib/nvidia-smi" + +def is_wsl(v: str = platform.uname().release) -> int: + # WSL v1 and v2 + if v.endswith("-Microsoft") or v.endswith("microsoft-standard-WSL2"): + return True + + return False @dataclass @@ -131,10 +140,11 @@ def instance(cls): cls._inst = cls() return cls._inst - def __init__(self, jobs_per_gpu=1, use_gpu=True): + def __init__(self, jobs_per_gpu=1, use_gpu=True, nvidia_smi_path=None): self.queue = [] self.jobs_per_gpu = jobs_per_gpu self.use_gpu = use_gpu + self.nvidia_smi_path = nvidia_smi_path def enqueue(self, func, args, kwargs, params, log_folder): job = LocalJob(job_id=get_fake_job_id(), process=None) @@ -156,13 +166,15 @@ def total_resources(self): resources = {} if self.use_gpu: - if which(NVIDIA_SMI_PATH) is None: + nvidia_smi_path = self.nvidia_smi_path or (NVIDIA_SMI_WSL_PATH if is_wsl() else NVIDIA_SMI_PATH) + + if which(nvidia_smi_path) is None: raise ValueError( - f"LocalScheduleHandler.use_gpu=True yet could not find {NVIDIA_SMI_PATH}, " - "please use --pipeline_overrides LocalScheduleHandler.use_gpu=False if your machine does not have a supported GPU" + f"LocalScheduleHandler.use_gpu=True yet could not find `nvidia-smi` by this path: {nvidia_smi_path} " + "Please use --pipeline_overrides LocalScheduleHandler.use_gpu=False if your machine does not have a supported GPU" ) - result = subprocess.check_output(f"{NVIDIA_SMI_PATH} -L".split()).decode() + result = subprocess.check_output(f"{nvidia_smi_path} -L".split()).decode() gpus_uuids = set(i for i in range(len(result.splitlines()))) if CUDA_VARNAME in os.environ: