Skip to content

Commit

Permalink
fix LAMMPS wheel with CUDA wheels
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Oct 1, 2023
1 parent cf61140 commit 8f2a7c7
Showing 1 changed file with 27 additions and 17 deletions.
44 changes: 27 additions & 17 deletions deepmd/lmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,15 @@ def get_env(paths: List[Optional[str]]) -> str:
return ":".join(p for p in paths if p is not None)


def get_library_path(module: str) -> List[str]:
def get_library_path(module: str, filename: str) -> List[str]:
"""Get library path from a module.
Parameters
----------
module : str
The module name.
filename : str
The library filename pattern.
Returns
-------
Expand All @@ -53,7 +55,8 @@ def get_library_path(module: str) -> List[str]:
except ModuleNotFoundError:
return []
else:
return [str(Path(m.__file__).parent)]
libs = sorted(Path(m.__path__[0]).glob(filename))
return [str(lib) for lib in libs]


if platform.system() == "Linux":
Expand All @@ -63,6 +66,13 @@ def get_library_path(module: str) -> List[str]:
else:
raise RuntimeError("Unsupported platform")

if platform.system() == "Linux":
preload_env = "LD_PRELOAD"
elif platform.system() == "Darwin":
preload_env = "DYLD_INSERT_LIBRARIES"
else:
raise RuntimeError("Unsupported platform")

tf_dir = tf.sysconfig.get_lib()
op_dir = str((Path(__file__).parent / "lib").absolute())

Expand All @@ -71,37 +81,37 @@ def get_library_path(module: str) -> List[str]:
if platform.system() == "Linux":
cuda_library_paths.extend(
[
*get_library_path("nvidia.cuda_runtime.lib"),
*get_library_path("nvidia.cublas.lib"),
*get_library_path("nvidia.cublas.lib"),
*get_library_path("nvidia.cufft.lib"),
*get_library_path("nvidia.curand.lib"),
*get_library_path("nvidia.cusolver.lib"),
*get_library_path("nvidia.cusparse.lib"),
*get_library_path("nvidia.cudnn.lib"),
*get_library_path("nvidia.cuda_runtime.lib", "libcudart.so*"),
*get_library_path("nvidia.cublas.lib", "libcublasLt.so*"),
*get_library_path("nvidia.cublas.lib", "libcublas.so*"),
*get_library_path("nvidia.cufft.lib", "libcufft.so*"),
*get_library_path("nvidia.curand.lib", "libcurand.so*"),
*get_library_path("nvidia.cusolver.lib", "libcusolver.so*"),
*get_library_path("nvidia.cusparse.lib", "libcusparse.so*"),
*get_library_path("nvidia.cudnn.lib", "libcudnn.so*"),
]
)

os.environ[preload_env] = get_env(
[
os.environ.get(preload_env),
*cuda_library_paths,
]
)

# set LD_LIBRARY_PATH
os.environ[lib_env] = get_env(
[
os.environ.get(lib_env),
tf_dir,
os.path.join(tf_dir, "python"),
op_dir,
*cuda_library_paths,
]
)

# preload python library, only for TF<2.12
if find_libpython is not None:
libpython = find_libpython()
if platform.system() == "Linux":
preload_env = "LD_PRELOAD"
elif platform.system() == "Darwin":
preload_env = "DYLD_INSERT_LIBRARIES"
else:
raise RuntimeError("Unsupported platform")
os.environ[preload_env] = get_env(
[
os.environ.get(preload_env),
Expand Down

0 comments on commit 8f2a7c7

Please sign in to comment.