Intel® Extension for OpenXLA* 0.2.0 Release
Major Features
Intel® Extension for OpenXLA* is an Intel-optimized PyPI package that extends the official OpenXLA on Intel GPUs. It is based on the PJRT plugin mechanism, which can seamlessly run JAX models on the Intel® Data Center GPU Max Series. This release contains the following major features:
- Upgraded the JAX version to v0.4.20.
- Added experimental support for JAX native distributed scale-up collectives based on JAX pmap.
- Continued to optimize common kernels, and optimized GEMM kernels with Intel® Xe Templates for Linear Algebra. Three inference models (Stable Diffusion, GPT-J, FLAN-T5) are verified on a single Intel® Data Center GPU Max Series device and added to the examples (see the sketch below).
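As a quick sanity check of the above, the minimal sketch below lists the visible devices and runs a jitted matrix multiplication, which exercises the optimized GEMM path on a single device. It assumes the intel-extension-for-openxla package is installed and the Intel GPU is visible to JAX; the shapes and dtype are illustrative only.

```python
# Minimal sketch: confirm JAX sees the Intel GPU through the PJRT plugin,
# then run a jitted matmul on a single device. Assumes the
# intel-extension-for-openxla package is installed and registered.
import jax
import jax.numpy as jnp

print(jax.devices())  # should list the Intel GPU device(s)

@jax.jit
def matmul(a, b):
    # jnp.dot lowers to an XLA dot op, which this release maps to the
    # XeTLA-optimized GEMM kernels on Intel GPU.
    return jnp.dot(a, b)

key = jax.random.PRNGKey(0)
ka, kb = jax.random.split(key)
a = jax.random.normal(ka, (1024, 1024), dtype=jnp.float32)
b = jax.random.normal(kb, (1024, 1024), dtype=jnp.float32)

out = matmul(a, b)
print(out.shape)  # (1024, 1024)
```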
Known Caveats
- The device count for the experimentally supported collectives is restricted to 2/4/6/8/10/12 on a single node.
- `XLA_ENABLE_MULTIPLE_STREAM=1` should be set when using JAX parallelization on multiple devices without collectives. It adds synchronization between the devices to avoid possible accuracy issues (see the sketch after this list).
- `MHA=0` should be set to disable MHA fusion in training. MHA fusion is not supported in training yet and will cause a runtime error like the following:
```
ir_emission_utils.cc:109] Check failed: lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)) == rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))
```
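The sketch below pulls these caveats together: both flags are exported before JAX initializes, and a pmap all-reduce runs on a supported device count. The flag names and the device-count restriction come from the notes above; the shapes and values are illustrative.

```python
# Sketch of the caveat settings. Both flags must be set before JAX
# initializes (e.g., exported in the shell or before the first JAX import).
import os
os.environ["XLA_ENABLE_MULTIPLE_STREAM"] = "1"  # only needed for multi-device runs without collectives
os.environ["MHA"] = "0"                         # disable MHA fusion when training

import jax
import jax.numpy as jnp

# The experimental collectives require 2/4/6/8/10/12 devices on one node.
n = jax.local_device_count()

# One scalar per device; psum all-reduces them across the pmap axis.
x = jnp.arange(n, dtype=jnp.float32)
total = jax.pmap(lambda v: jax.lax.psum(v, axis_name="i"), axis_name="i")(x)
print(total)  # each device holds the same reduced value
```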
Breaking Changes
- The previous JAX v0.4.13 is no longer supported. Please follow the JAX change log to update your application if you meet version errors.
- GCC 10.0.0 or newer is required when building from source. Please refer to the installation guide for more details.