Skip to content

Commit

Permalink
Add MKL_RT env to bypass library search
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Dec 15, 2023
1 parent f3a7e6b commit 2c02890
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions sparse_dot_mkl/_mkl_interface/_load_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,27 @@
import ctypes.util as _ctypes_util
import os

IMPORT_ERRORS = (OSError, ImportError)


def _try_load_mkl_rt(path=None):
# Check each of these library names
# Also include derivatives because windows find_library implementation
# won't match partials
for so_file in ["libmkl_rt.so", "libmkl_rt.dylib", "mkl_rt.dll"] + [
for so_file in [
"libmkl_rt.so",
"libmkl_rt.dylib",
"mkl_rt.dll"
] + [
f"mkl_rt.{i}.dll" for i in range(5, 0, -1)
] + [
f"libmkl_rt.so.{i}" for i in range(5, 0, -1)
]:
try:
# If this finds anything, break out of the loop
return _ctypes.cdll.LoadLibrary(os.path.join(path, so_file))

except (OSError, ImportError):
except IMPORT_ERRORS:
pass

return None
Expand All @@ -24,6 +32,15 @@ def mkl_library():

# Load mkl_spblas through the libmkl_rt common interface
_libmkl = None

# Use MKL_RT env (useful if there are multiple MKL binaries in path)
if 'MKL_RT' in os.environ:
try:
_libmkl = _ctypes.cdll.LoadLibrary(os.environ['MKL_RT'])
return _libmkl
except IMPORT_ERRORS:
pass

try:
_so_file = _ctypes_util.find_library("mkl_rt")

Expand Down Expand Up @@ -68,7 +85,7 @@ def mkl_library():

# Couldn't find anything to import
# Raise the ImportError
except (OSError, ImportError) as err:
except IMPORT_ERRORS as err:
raise ImportError(
"Unable to load the MKL libraries through "
"libmkl_rt. Try setting $LD_LIBRARY_PATH. " + str(err)
Expand Down

0 comments on commit 2c02890

Please sign in to comment.