Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

{ai}[gfbf/2024a] jax v0.4.35, ml_dtypes v0.5.0 w/ CUDA 12.6.0 WIP #21924

Open
wants to merge 15 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild
# Author: Denis Kristak
# Updated by: Alex Domingo (Vrije Universiteit Brussel)
# Updated by: Pavel Tománek (INUITS)
# Updated by: Thomas Hoffmann (EMBL Heidelberg)
# Bundle-style easyconfig: jax 0.4.35 for the gfbf/2024a toolchain, with CUDA support.
easyblock = 'PythonBundle'

name = 'jax'
version = '0.4.35'
# the CUDA version is recorded in the version suffix
versionsuffix = '-CUDA-%(cudaver)s'

homepage = 'https://jax.readthedocs.io/'
description = """Composable transformations of Python+NumPy programs:
differentiate, vectorize, JIT to GPU/TPU, and more"""

toolchain = {'name': 'gfbf', 'version': '2024a'}
# CUDA compute capabilities to build for
cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"]

# dependencies that are only needed at build/test time
builddependencies = [
# ('Bazel', '7.4.1'), TODO: problems with @@local_config_python//:py3_runtime:
# Error in fail: interpreter_path must be an absolute path
# Bazel 6.5.0 (download) works.
('pybind11', '2.13.6'), # 2.12.0 ? SciPy-bundle has pybind/2.12.0.
# Fix: change to builddependency in SciPy-bundle?
# NOTE(review): an earlier note said pybind11 was temporarily moved to dependencies
# (TODO: move back); the entry in dependencies below is commented out - confirm placement
('pytest-xdist', '3.6.1'),
('git', '2.45.1'), # bazel uses git to fetch repositories
('matplotlib', '3.9.2'), # required for tests/lobpcg_test.py
('poetry', '1.8.3'),
('Clang', '18.1.8')
]

# runtime dependencies; entries are (name, version[, versionsuffix[, toolchain]])
dependencies = [
('CUDA', '12.6.0', '', SYSTEM), # 12.6.2 ?
('cuDNN', '9.5.0.50', versionsuffix, SYSTEM),
('NCCL', '2.22.3', versionsuffix),
('Python', '3.12.3'),
('SciPy-bundle', '2024.05'), # 2024.11 ?
('absl-py', '2.1.0'),
('flatbuffers-python', '24.3.25'),
('ml_dtypes', '0.5.0'),
('zlib', '1.3.1'),
# ('pybind11', '2.13.6'), # override 2.12.0. SciPy-bundle has pybind/2.12.0. Fix:
# change to builddependency in SciPy-bundle? (TODO)
]

# xla and other tarballs are downloaded by EasyBuild so that Bazel does not have to
# download them during the build; they are not unpacked, only copied to %(builddir)s/archives
local_extract_cmd = 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives'
# note: the following commits *must* be exactly the same ones used upstream
# XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl
local_xla_commit = '76da730179313b3bebad6dea6861768421b7358c'
# TFRT_COMMIT from xla: third_party/tsl/third_party/tf_runtime/workspace.bzl
local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25'  # TODO: still required?
# TODO: add other downloads

# get rid of the .devDate version suffix by hard-coding the version in setup.py
# TODO: find a better way;
# """ export JAX_RELEASE && export JAXLIB_RELEASE && """ does not work (?)
_no_devtag = ' sed -i "s/version=__version__/version=\'%(version)s\'/g" setup.py && '

# options for the jaxlib build, one list entry per option
_jaxlib_buildopts_parts = [
    # Use sources downloaded by EasyBuild
    '--bazel_options="--distdir=%(builddir)s/archives"',
    # Use dependencies from EasyBuild
    '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11"',
    '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include:$EBROOTCUDA/extras/CUPTI/include"',
    # Avoid warning (treated as error) in upb/table.c
    '--bazel_options="--copt=-Wno-maybe-uninitialized"',  # TODO: still required?
    # '--nouse_clang',  # TODO: avoid clang (?)
    '--cuda_version=%(cudaver)s',
    '--python_bin_path=$EBROOTPYTHON/bin/python3',
    # Do not use hermetic CUDA/cuDNN/NCCL; this requires
    # action_env=CPATH=$EBROOTCUDA/extras/CUPTI/include and a patch of
    # external/xla/xla/tsl/cuda/cupti_stub.cc and jaxlib/gpu/vendor.h (#include <cupti.h>)
    '--bazel_options=--repo_env=LOCAL_CUDNN_PATH="$EBROOTCUDNN"',
    '--bazel_options=--repo_env=LOCAL_NCCL_PATH="$EBROOTNCCL"',
    '--bazel_options=--repo_env=LOCAL_CUDA_PATH="$EBROOTCUDA"',
    '--bazel_options="--copt=-Ithird_party/gpus/cuda/extras/CUPTI/include"',
    '--bazel_options="--action_env=JAXLIB_RELEASE=1"',  # required?
]
# every option is followed by exactly one space, including the last one
_jaxlib_buildopts = ' '.join(_jaxlib_buildopts_parts) + ' '

# jaxlib is built as a bundle component from the jax source tarball
components = [
('jaxlib', version, {
'sources': [
# jax source tree, which also contains jaxlib
{
'source_urls': ['https://github.com/google/jax/archive/'],
'filename': 'jax-v%(version)s.tar.gz',
},
# XLA tarball at the pinned commit; handled by local_extract_cmd instead of being unpacked
{
'source_urls': ['https://github.com/openxla/xla/archive'],
'download_filename': '%s.tar.gz' % local_xla_commit,
'filename': 'xla-%s.tar.gz' % local_xla_commit[:8],
'extract_cmd': local_extract_cmd,
},
# tf_runtime tarball at the pinned commit; handled the same way as XLA
{
'source_urls': ['https://github.com/tensorflow/runtime/archive'],
'download_filename': '%s.tar.gz' % local_tfrt_commit,
'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit[:8],
'extract_cmd': local_extract_cmd,
},
],
'patches': [
'jax-0.4.35_easyblock_compat.patch',
'jax-0.4.35_fix-pybind11-systemlib_cupti.patch',
'jax-0.4.35_version.patch',
],
# one checksum per source, followed by one per patch, each keyed by filename
'checksums': [
{'jax-v0.4.35.tar.gz':
'65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'},
{'xla-76da7301.tar.gz':
'd67ced09b69ab8d7b26fa4cd5f48b22db57eb330294a35f6e1d462ee17066757'},
{'tf_runtime-0aeefb16.tar.gz':
'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'},
{'jax-0.4.35_easyblock_compat.patch':
'cbf4ad92b8438c4ce2a975efce1c47c57d4c3b117bceee071ab660f964057223'},
{'jax-0.4.35_fix-pybind11-systemlib_cupti.patch':
'78efe6b5108a5da1935258286c94dea8438fd03651533c34023eeba27f514130'},
{'jax-0.4.35_version.patch':
'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'},
],
'start_dir': 'jax-jax-v%(version)s',
'buildopts': _jaxlib_buildopts,
# before building: create third_party/gpus/cuda/extras, symlink the CUPTI directory of the
# CUDA installation into it, then run the _no_devtag command (sed on setup.py)
'prebuildopts': ' mkdir third_party/gpus/cuda/extras/ -p && ' +
'ln -s $EBROOTCUDA/extras/CUPTI third_party/gpus/cuda/extras --relative &&' +
_no_devtag
}),
]

# Tests that have to run on their own, outside the main pytest run. TODO: still required?
local_isolated_tests = [
    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp',
    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc',
    'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest'
    '::testScipySpecialFun_gammainc_s_2x1x4_float32_float32',
]
# Environment for the test run:
# - NVIDIA_TF32_OVERRIDE=0 disables TF32 Tensor Cores to avoid losing numerical precision;
# - CUDA_VISIBLE_DEVICES=0 avoids failing tests on systems with multiple GPUs;
# - XLA_PYTHON_CLIENT_ALLOCATOR=platform allocates and deallocates GPU memory during testing,
#   see https://github.com/google/jax/issues/7323 and
#   https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst
local_test_exports = [
    "NVIDIA_TF32_OVERRIDE=0",
    "CUDA_VISIBLE_DEVICES=0",
    "XLA_PYTHON_CLIENT_ALLOCATOR=platform",
    "JAX_ENABLE_X64=true",
]
# Tests are deliberately not run in parallel, as that results in (additional) failing tests.
# The first command runs everything except local_isolated_tests; the remaining commands
# run each isolated test separately.
local_test_deselect = ' '.join('--deselect ' + local_t for local_t in local_isolated_tests)
local_test_cmds = ['pytest -vv tests ' + local_test_deselect]
local_test_cmds.extend('pytest -vv ' + local_t for local_t in local_isolated_tests)
local_test = ''.join('export ' + local_e + ';' for local_e in local_test_exports)
local_test += ' && '.join(local_test_cmds)

use_pip = True

# jax itself is installed as the single extension of this bundle
exts_list = [
(name, version, {
'source_tmpl': '%(name)s-v%(version)s.tar.gz',
'source_urls': ['https://github.com/google/jax/archive/'],
# 'patches': ['jax-0.4.25_fix_env_test_no_log_spam.patch'], # TODO: still required? update?
'patches': ['jax-0.4.35_version.patch'],
'checksums': [
{'jax-v0.4.35.tar.gz': '65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'},
{'jax-0.4.35_version.patch': 'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'},
],
# the test run is temporarily disabled; the intended command is local_test
# 'runtest': local_test,
'runtest': False, # tmp
# run the _no_devtag command (sed on setup.py) before installing
'preinstallopts': _no_devtag
}),
]

sanity_pip_check = True

moduleclass = 'ai'
21 changes: 21 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.4.35_easyblock_compat.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Thomas Hoffmann, EMBL Heidelberg, [email protected], 2024/11
# add dummy parameters to build/build.py for cudnn_path and cuda_path, which are set by default by the jaxlib easyblock.
diff -ru jax-jax-v0.4.35/build/build.py jax-jax-v0.4.35_easyblockcompat/build/build.py
--- jax-jax-v0.4.35/build/build.py 2024-10-22 21:00:23.000000000 +0200
+++ jax-jax-v0.4.35_easyblockcompat/build/build.py 2024-11-19 12:35:46.524479324 +0100
@@ -549,6 +549,15 @@
help_str="Same as update_requirements, but will consider dev, nightly "
"and pre-release versions of packages.")

+ parser.add_argument(
+ "--cuda_path",
+ default="dummy",
+ help="compatibility with jaxlib.py easyblock")
+ parser.add_argument(
+ "--cudnn_path",
+ default="dummy",
+ help="compatibility with jaxlib.py easyblock")
+
args = parser.parse_args()

logging.basicConfig()
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
jax-0.4.25_fix-pybind11-systemlib.patch: Add missing value for System Pybind11 Bazel config
jax-0.4.25_fix-pybind11-systemlib.patch: Author: Alexander Grund (TU Dresden)

THEMBL: fix cupti include path.

diff --git a/third_party/xla/fix-pybind11-systemlib.patch b/third_party/xla/fix-pybind11-systemlib.patch
new file mode 100644
index 000000000..68bd2063d
--- /dev/null
+++ b/third_party/xla/fix-pybind11-systemlib.patch
@@ -0,0 +1,13 @@
+--- xla-orig/third_party/tsl/third_party/systemlibs/pybind11.BUILD
++++ xla-4ccfe33c71665ddcbca5b127fefe8baa3ed632d4/third_party/tsl/third_party/systemlibs/pybind11.BUILD
+@@ -6,3 +6,10 @@
+ "@tsl//third_party/python_runtime:headers",
+ ],
+ )
++
++# Needed by pybind11_bazel.
++config_setting(
++ name = "osx",
++ constraint_values = ["@platforms//os:osx"],
++)
++
diff -ruN jax-jax-v0.4.35/jaxlib/gpu/vendor.h jax-jax-v0.4.35_jaxlib_cupti__fix-pybind11-systemlib/jaxlib/gpu/vendor.h
--- jax-jax-v0.4.35/jaxlib/gpu/vendor.h 2024-10-22 21:00:23.000000000 +0200
+++ jax-jax-v0.4.35_jaxlib_cupti__fix-pybind11-systemlib/jaxlib/gpu/vendor.h 2024-11-26 10:56:20.396087442 +0100
@@ -23,7 +23,7 @@
#if defined(JAX_GPU_CUDA)

// IWYU pragma: begin_exports
-#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
+#include <cupti.h>
#include "third_party/gpus/cuda/include/cooperative_groups.h"
#include "third_party/gpus/cuda/include/cuComplex.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
diff -ruN jax-jax-v0.4.35/third_party/xla/workspace.bzl jax-jax-v0.4.35_jaxlib_cupti__fix-pybind11-systemlib/third_party/xla/workspace.bzl
--- jax-jax-v0.4.35/third_party/xla/workspace.bzl 2024-10-22 21:00:23.000000000 +0200
+++ jax-jax-v0.4.35_jaxlib_cupti__fix-pybind11-systemlib/third_party/xla/workspace.bzl 2024-11-27 12:17:37.913466273 +0100
@@ -30,6 +30,11 @@
sha256 = XLA_SHA256,
strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT),
urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
+ patch_file = [
+ "//third_party/xla:xla-76da73_cupti.patch",
+ "//third_party/xla:fix-pybind11-systemlib.patch",
+ ],
+
)

# For development, one often wants to make changes to the TF repository as well
diff -ruN jax-jax-v0.4.35/third_party/xla/xla-76da73_cupti.patch jax-jax-v0.4.35_jaxlib_cupti__fix-pybind11-systemlib/third_party/xla/xla-76da73_cupti.patch
--- jax-jax-v0.4.35/third_party/xla/xla-76da73_cupti.patch 1970-01-01 01:00:00.000000000 +0100
+++ jax-jax-v0.4.35_jaxlib_cupti__fix-pybind11-systemlib/third_party/xla/xla-76da73_cupti.patch 2024-11-27 12:18:26.668582799 +0100
@@ -0,0 +1,12 @@
+diff -ru xla-76da730179313b3bebad6dea6861768421b7358c/xla/tsl/cuda/cupti_stub.cc xla-76da730179313b3bebad6dea6861768421b7358c_cupti/xla/tsl/cuda/cupti_stub.cc
+--- xla-76da730179313b3bebad6dea6861768421b7358c/xla/tsl/cuda/cupti_stub.cc 2024-10-21 20:29:31.000000000 +0200
++++ xla-76da730179313b3bebad6dea6861768421b7358c_cupti/xla/tsl/cuda/cupti_stub.cc 2024-11-26 12:04:11.695539146 +0100
+@@ -13,7 +13,7 @@
+ limitations under the License.
+ ==============================================================================*/
+
+-#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
++#include <cupti.h>
+ #include "third_party/gpus/cuda/include/cuda.h"
+ #include "tsl/platform/dso_loader.h"
+ #include "tsl/platform/load_library.h"
19 changes: 19 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.4.35_version.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
diff -ru jax-jax-v0.4.35/jax/version.py jax-jax-v0.4.35_version/jax/version.py
--- jax-jax-v0.4.35/jax/version.py 2024-10-22 21:00:23.000000000 +0200
+++ jax-jax-v0.4.35_version/jax/version.py 2024-11-28 13:10:52.508536023 +0100
@@ -33,6 +33,7 @@
def _get_version_string() -> str:
# The build/source distribution for jax & jaxlib overwrites _release_version.
# In this case we return it directly.
+ return _version
if _release_version is not None:
return _release_version
return _version_from_git_tree(_version) or _version_from_todays_date(_version)
@@ -71,6 +72,7 @@
- if JAX_NIGHTLY or JAXLIB_NIGHTLY are set: version looks like "0.4.16.dev20230906"
- if none are set: version looks like "0.4.16.dev20230906+ge58560fdc
"""
+ return _version
if _release_version is not None:
return _release_version
if os.environ.get('JAX_NIGHTLY') or os.environ.get('JAXLIB_NIGHTLY'):
56 changes: 56 additions & 0 deletions easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.5.0-gfbf-2024a.eb
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Thomas Hoffmann, EMBL Heidelberg, [email protected], 2024/11
# Bundle-style easyconfig: ml_dtypes 0.5.0 plus opt_einsum and etils, for gfbf/2024a.
easyblock = 'PythonBundle'

name = 'ml_dtypes'
version = '0.5.0'

homepage = 'https://github.com/jax-ml/ml_dtypes'
description = """
ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used
in machine learning libraries, including:

bfloat16: an alternative to the standard float16 format
float8_*: several experimental 8-bit floating point representations including:
float8_e4m3b11fnuz
float8_e4m3fn
float8_e4m3fnuz
float8_e5m2
float8_e5m2fnuz
"""

toolchain = {'name': 'gfbf', 'version': '2024a'}

builddependencies = [
('poetry', '1.8.3'),
]

dependencies = [
('Python', '3.12.3'),
# ('SciPy-bundle', '2024.11'), ?
('SciPy-bundle', '2024.05'),
]


use_pip = True

# NOTE(review): no components are defined in this file, so this setting may be unnecessary - confirm
default_easyblock = 'PythonPackage'

# extensions are listed in this order: opt_einsum, etils, then ml_dtypes itself
exts_list = [
('opt_einsum', '3.4.0', {
'checksums': ['96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac'],
}),
('etils', '1.10.0', {
'checksums': ['4eaa9d7248fd4eeb75e44d47ca29875a5ccea044cc14a17435794bf8ac116a05'],
}),
(name, version, {
# patch file name carries version 0.3.2 but is applied to 0.5.0;
# presumably the 1 in the tuple is the patch level - TODO confirm
'patches': [('ml_dtypes-0.3.2_EigenAvx512.patch', 1)],
'checksums': [
{'ml_dtypes-0.5.0.tar.gz': '3e7d3a380fe73a63c884f06136f8baa7a5249cc8e9fdec677997dd78549f8128'},
{'ml_dtypes-0.3.2_EigenAvx512.patch': '197b05b0b7f611749824369f026099f6a172f9e8eab6ebb6504a16573746c892'},
],
}),
]

sanity_pip_check = True

moduleclass = 'tools'
33 changes: 33 additions & 0 deletions easybuild/easyconfigs/p/pybind11/pybind11-2.13.6-GCC-13.3.0.eb
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Easyconfig for pybind11 2.13.6 with the GCC 13.3.0 toolchain.
name = 'pybind11'
version = '2.13.6'

homepage = 'https://pybind11.readthedocs.io'
description = """pybind11 is a lightweight header-only library that exposes C++ types in Python and vice versa,
mainly to create Python bindings of existing C++ code."""

toolchain = {'name': 'GCC', 'version': '13.3.0'}

# source tarball is the GitHub archive of tag v%(version)s
source_urls = ['https://github.com/pybind/pybind11/archive/']
sources = ['v%(version)s.tar.gz']
# patch file name carries version 2.10.3 but is applied to 2.13.6
patches = [
'pybind11-2.10.3_require-catch.patch',
]
# one checksum for the source tarball and one for the patch, each keyed by filename
checksums = [
{'v2.13.6.tar.gz': 'e08cb87f4773da97fa7b5f035de8763abc656d87d5773e62f6da0587d1f0ec20'},
{'pybind11-2.10.3_require-catch.patch': '4a27ba3ef1d5c535d120d6178a6e876ae678e4899a07500aab37908357b0b60b'},
]

builddependencies = [
('CMake', '3.29.3'),
# Test dependencies
('Eigen', '3.4.0'),
('Catch2', '2.13.10'),
('Python-bundle-PyPI', '2024.06'), # to provide pytest
]

dependencies = [
('Boost', '1.85.0'),
('Python', '3.12.3'),
]

moduleclass = 'lib'
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ builddependencies = [
('Ninja', '1.12.1'),
('pkgconf', '2.2.0'), # required by scipy
('Cython', '3.0.10'), # required by numpy and scipy
('pybind11', '2.12.0'), # required by scipy
]

dependencies = [
('Python', '3.12.3'),
('Python-bundle-PyPI', '2024.06'),
('pybind11', '2.12.0'), # required by scipy
]

use_pip = True
Expand Down
3 changes: 3 additions & 0 deletions test/easyconfigs/easyconfigs.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,9 @@ def check_dep_vars(self, gen, dep, dep_vars):
# OpenFOAM 5.0 requires older ParaView, CFDEMcoupling depends on OpenFOAM 5.0
(r'5\.4\.1', [r'CFDEMcoupling-3\.8\.0', r'OpenFOAM-5\.0-20180606']),
],
'pybind11': [
('2.13.6;', ['jax-0.4.35-']),
],
'pydantic': [
# GTDB-Tk v2.3.2 requires pydantic 1.x (see https://github.com/Ecogenomics/GTDBTk/pull/530)
('1.10.13;', ['GTDB-Tk-2.3.2-', 'GTDB-Tk-2.4.0-']),
Expand Down
Loading