Major Features
Intel® Extension for OpenXLA* is an Intel-optimized PyPI package that extends the official OpenXLA to Intel GPUs. It is based on the PJRT plugin mechanism, which can seamlessly run JAX models on Intel® Data Center GPU Max Series and Intel® Data Center GPU Flex Series. This release contains the following major features:
- JAX Upgrade: Upgrade JAX to v0.4.30 and support compatibility between jax and jaxlib, which allows the Extension to work with multiple versions of jax. Please refer to <How are jax and jaxlib versioned?> for more details on versioning between jax and jaxlib.

  | intel-extension-for-openxla | jaxlib | jax |
  | --- | --- | --- |
  | 0.5.0 | 0.4.30 | >= 0.4.30, <= 0.4.31 |

- Feature Support:
  - Support for Python 3.9, 3.10, 3.11, and 3.12.
  - Continue to improve jax native distributed scale-up collectives.
  - Support accuracy for GPT-J with different numbers of layers.
  - Continue to improve support of the FMHA backward fusion.
- Bug Fix:
  - Fix Forward MHA accuracy error.
  - Fix known caveat: the fix-in-place error that occurred on the Stable Diffusion model.
  - Fix known caveat: a hang related to a deadlock when working with Toolkit 2025.0.
  - Fix known caveat: some unit test failures with the latest graphics driver.
  - Fix known caveat: OOM related to the deprecated API clear_backends.
- Toolkit Support: Support Intel® oneAPI Base Toolkit 2025.0.
- Driver Support: Support upgraded Driver LTS release 2350.125.
- oneDNN Support: Support oneDNN v3.6.1.
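The jax and jaxlib compatibility listed for Extension 0.5.0 can be checked before a workload starts. The following is a minimal sketch, not part of the Extension: the helper names (`parse_version`, `is_supported_pair`, `check_installed`) are hypothetical, and treating the jaxlib entry as an exact match is an assumption based on reading the compatibility row literally.

```python
import re
from importlib import metadata

# Compatibility row for intel-extension-for-openxla 0.5.0, as listed above.
# Assumption: the jaxlib entry is read as an exact version requirement.
REQUIRED_JAXLIB = (0, 4, 30)
JAX_MIN = (0, 4, 30)
JAX_MAX = (0, 4, 31)


def parse_version(text):
    """Turn a string such as '0.4.30' into the tuple (0, 4, 30)."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return tuple(int(part) for part in match.groups())


def is_supported_pair(jaxlib_version, jax_version):
    """Check a jaxlib/jax version pair against the row above."""
    jaxlib_ok = parse_version(jaxlib_version) == REQUIRED_JAXLIB
    jax_ok = JAX_MIN <= parse_version(jax_version) <= JAX_MAX
    return jaxlib_ok and jax_ok


def check_installed():
    """Return True or False for the installed packages, or None if one is missing."""
    try:
        return is_supported_pair(metadata.version("jaxlib"), metadata.version("jax"))
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("jax/jaxlib pair supported:", check_installed())
```

A check like this could run at application startup so that a version mismatch is reported before any model code executes.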
Known Caveats
- Flan T5 and Gemma models have a dependency on TensorFlow-Text, which doesn't support Python 3.12.
- Multi-process API support is still experimental and may cause hang issues with collectives.
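Scripts that run the Flan T5 or Gemma models may want to detect the Python 3.12 limitation early. The sketch below is illustrative only: `tensorflow_text_usable` is a hypothetical helper, and it assumes TensorFlow-Text is imported under the module name `tensorflow_text`.

```python
import importlib.util
import sys


def tensorflow_text_usable(version_info=sys.version_info):
    """Return False on Python 3.12 or newer, or when tensorflow_text is not installed."""
    # Per the caveat above, TensorFlow-Text does not support Python 3.12.
    if tuple(version_info[:2]) >= (3, 12):
        return False
    # find_spec returns None when the top-level module cannot be found.
    return importlib.util.find_spec("tensorflow_text") is not None


if __name__ == "__main__":
    if not tensorflow_text_usable():
        print("TensorFlow-Text is unavailable; skip Flan T5 and Gemma models.")
```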
Breaking Changes
- JAX v0.4.26, supported previously, is no longer supported by this release. Please follow the JAX change log to update your application if it encounters version errors. Please roll back the Extension version if you want to use it with an older JAX version.
Documents
- Introduction to Intel® Extension for OpenXLA*
- Accelerate JAX models on Intel GPUs via PJRT
- How JAX and OpenXLA Enabled an Argonne Workload and Quality Assurance on the Aurora Supercomputer
- JAX and OpenXLA* Part 1: Run Process and Underlying Logic
- JAX and OpenXLA Part 2: Run Process and Underlying Logic
- How are jax and jaxlib versioned?