Provide wheels for macOS ARM #5501

Closed
ericmjl opened this issue Jan 24, 2021 · 132 comments
Labels
contributions welcome The JAX team has not prioritized work on this. Community contributions are welcome. enhancement New feature or request open Issues intentionally left open, with no schedule for next steps. P2 (eventual) This ought to be addressed, but has no schedule at the moment. (Assignee optional)

Comments

@ericmjl

ericmjl commented Jan 24, 2021

Hi all,

I was digging around to see what might need to happen to allow JAX to work on Apple Silicon. Knowing that JAX gets compiled to XLA, my guess here is that XLA would need to be made Apple Silicon-compatible first before JAX could run on it. May I ask, do you all know if there are plans on the XLA team to make that happen, or is it being ignored completely? (Knowing the answer can help me make some decisions on how I should set up my development environment mostly.)

Cheers,
Eric

@8bitmp3
Contributor

8bitmp3 commented Jan 24, 2021

Check out this post #5084 by @hawkinsp (cc @rxwei)

@hawkinsp
Collaborator

Targeting the M1's ARM CPU shouldn't be difficult.

XLA already supports AArch64 and has done for a long time. I suspect that tensorflow/tensorflow#45404 already did most of the work to adapt XLA to build on the M1 and all that is left is a few small changes to the .bazelrc file that JAX's build.py script generates, analogous to the changes in that TF PR.

That said, I don't have access to any M1 hardware, so this is in the "contributions welcome" category.

Targeting the GPU or the Neural Engine is likely a lot more difficult. For GPU, one would probably need to target Metal (probably doable, but not trivial), and I'm unsure how we could target the Neural Engine at this time.

@hawkinsp hawkinsp added contributions welcome The JAX team has not prioritized work on this. Community contributions are welcome. enhancement New feature or request open Issues intentionally left open, with no schedule for next steps. labels Jan 24, 2021
@jotsif

jotsif commented Feb 28, 2021

Hi! I have tried to get jaxlib working on my Apple M1, and have managed to build it with some minor changes based on tensorflow/tensorflow#45404 (branch on https://github.com/jotsif/jax/tree/jax_for_darwin_arm64).

However when loading jaxlib in python3 (3.9.2) I get the import error ImportError: dlopen(/opt/homebrew/lib/python3.9/site-packages/jaxlib/xla_extension.so, 2): Symbol not found: _LLVMInitializeAArch64AsmPrinter
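One way to get just the dlopen message (including the missing symbol) without a long traceback is to import the extension module directly. This is a minimal sketch that assumes the module path jaxlib.xla_extension from the error above; the helpers try_import and hint_for are hypothetical, and the hints are rough heuristics.

```python
import importlib


def try_import(module_name):
    """Try to import module_name; return (ok, message)."""
    try:
        importlib.import_module(module_name)
    except ImportError as exc:  # also catches ModuleNotFoundError
        return False, str(exc)
    return True, "ok"


def hint_for(message):
    """Map an import failure message to a rough hint (heuristic only)."""
    if "Symbol not found" in message:
        return "extension references a symbol missing from the build"
    if "No module named" in message:
        return "module not found in this interpreter"
    return "unrecognized import failure"


ok, message = try_import("jaxlib.xla_extension")
print("ok" if ok else f"failed: {message}\nhint: {hint_for(message)}")
```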

This seems to be a known issue; see tensorflow/tensorflow#45404 (comment) and elixir-nx/nx#217 (comment).

I am investigating further, but any tips would be welcome.

@jotsif

jotsif commented Mar 6, 2021

Update: this TF PR will probably fix the issue: tensorflow/tensorflow#47594

@mattjj
Collaborator

mattjj commented Mar 6, 2021

@jotsif thanks for that information, that's really helpful!

@jotsif

jotsif commented Mar 7, 2021

Confirmed that it works 🎉. Built with Bazel master and https://github.com/freedomtan/tensorflow/tree/bazel_native_build_on_m1:

>>> import platform
>>> platform.uname()
uname_result(system='Darwin', node='Josefs-MBP-2.lan', release='20.3.0', version='Darwin Kernel Version 20.3.0: Thu Jan 21 00:06:51 PST 2021; root:xnu-7195.81.3~1/RELEASE_ARM64_T8101', machine='arm64')
>>> import numpy as np
>>> from jax.lib import xla_client as xc
>>> xops = xc.ops
>>> c = xc.XlaBuilder("simple_scalar")
>>> param_shape = xc.Shape.array_shape(np.dtype(np.float32), ())
>>> x = xops.Parameter(c, 0, param_shape)
>>> y = xops.Sin(x)
>>> computation = c.Build()
>>> cpu_backend = xc.get_local_backend("cpu")
2021-03-07 09:26:16.684549: W external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
>>> compiled_computation = cpu_backend.compile(computation)
>>> host_input = np.array(3.0, dtype=np.float32)
>>> device_input = cpu_backend.buffer_from_pyval(host_input)
>>> device_out = compiled_computation.execute([device_input])
>>> device_out[0].to_py()
array(0.14112, dtype=float32)

@mattjj @hawkinsp If you want a PR, I would be happy to create one, but maybe it makes more sense to wait until Bazel has released a working native arm64 build and TensorFlow has the necessary code in master.

@erwincoumans
Contributor

https://github.com/freedomtan/tensorflow/tree/bazel_native_build_on_m1

Congrats! Do you mind sharing a jax/jaxlib wheel for M1?

@akbir
Contributor

akbir commented May 10, 2021

Hi @jotsif -

Bazel 4.1 works natively on arm64 (bazelbuild/bazel#13099), and TF has the necessary code in master.

If you have an example branch, I'm happy to help towards a PR now!

@hawkinsp
Collaborator

@akbir I think we still need to wait for Bazel to actually release 4.1. But that should be soon I think! At that point we can probably just bump the Bazel dependency to 4.1 and hopefully everything should work on Mac ARM.

We can look into releasing Mac ARM wheels as well, although we don't yet have a way to test them (we personally do not have Mac ARM hardware yet), which gives me some pause.

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue May 19, 2021
…res on Mac M1.

jax-ml/jax#5501

PiperOrigin-RevId: 374601482
Change-Id: If7b36f6963c74770522d06865d0800811ce63f81
@hawkinsp hawkinsp mentioned this issue May 19, 2021
@hawkinsp
Collaborator

I believe now that at head jaxlib will build from source on a Mac M1 and pretty much everything works (*). We still don't have a great way to provide pre-built wheels yet, but hopefully this is enough to unblock everyone!

(*) It's a bit annoying to install jax still, mostly because there aren't any prebuilt scipy wheels, so you'll have to build scipy yourself to build jaxlib or use jax. I followed these instructions scipy/scipy#13409 (comment) which worked for me.

@akbir
Contributor

akbir commented May 22, 2021

Hi @hawkinsp - tried following this but hit the following error when running build.py (included full logs at the bottom).

ERROR: Analysis of target '//build:build_wheel' failed; build aborted: Analysis of target '@local_config_cc//:toolchain' failed

I'm unsure why the toolchain isn't working; a similar error is discussed here: bazelbuild/bazel#13099 (comment).

I've also run the build against Bazel 4.1.0rc5 and still get the same error.

Full error logs:

Bazel binary path: ./bazel-4.1.0rc4-darwin-arm64
Python binary path: /Users/akbirkhan/jax/venv/bin/python
Python version: 3.9
MKL-DNN enabled: yes
Target CPU features: release
CUDA enabled: no
TPU enabled: no
ROCm enabled: no

Building XLA and installing it in the jaxlib source tree...
./bazel-4.1.0rc4-darwin-arm64 run --verbose_failures=true --config=short_logs --config=mkl_open_source_only :build_wheel -- --output_path=/Users/akbirkhan/jax/dist
INFO: Options provided by the client:
  Inherited 'common' options: --isatty=0 --terminal_columns=80
INFO: Reading rc options for 'run' from /Users/akbirkhan/jax/.bazelrc:
  Inherited 'common' options: --experimental_repo_remote_exec
INFO: Reading rc options for 'run' from /Users/akbirkhan/jax/.bazelrc:
  Inherited 'build' options: --repo_env PYTHON_BIN_PATH=/Users/akbirkhan/jax/venv/bin/python --action_env=PYENV_ROOT --python_path=/Users/akbirkhan/jax/venv/bin/python --repo_env TF_NEED_CUDA=0 --action_env TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0 --repo_env TF_NEED_ROCM=0 --action_env TF_ROCM_AMDGPU_TARGETS=gfx803,gfx900,gfx906,gfx1010 --distinct_host_configuration=false -c opt --apple_platform_type=macos --macos_minimum_os=10.9 --announce_rc --define open_source_build=true --define=no_kafka_support=true --define=no_ignite_support=true --define=grpc_no_ares=true --spawn_strategy=standalone --strategy=Genrule=standalone --enable_platform_specific_config
INFO: Found applicable config definition build:short_logs in file /Users/akbirkhan/jax/.bazelrc: --output_filter=DONT_MATCH_ANYTHING
INFO: Found applicable config definition build:mkl_open_source_only in file /Users/akbirkhan/jax/.bazelrc: --define=tensorflow_mkldnn_contraction_kernel=1
INFO: Found applicable config definition build:macos in file /Users/akbirkhan/jax/.bazelrc: --config=posix
INFO: Found applicable config definition build:posix in file /Users/akbirkhan/jax/.bazelrc: --copt=-Wno-sign-compare --define=no_aws_support=true --define=no_gcp_support=true --define=no_hdfs_support=true --cxxopt=-std=c++14 --host_cxxopt=-std=c++14
Loading: 
Loading: 0 packages loaded
WARNING: Download from http://mirror.tensorflow.org/github.com/tensorflow/runtime/archive/3f4cd5e8a34eb2179537b8f71b1484bb0d26701f.tar.gz failed: class com.google.devtools.build.lib.bazel.repository.downloader.UnrecoverableHttpException GET returned 404 Not Found
DEBUG: /private/var/tmp/_bazel_akbirkhan/30d03974cb2cfe2f871935b707b3981a/external/tf_runtime/third_party/cuda/dependencies.bzl:51:10: The following command will download NVIDIA proprietary software. By using the software you agree to comply with the terms of the license agreement that accompanies the software. If you do not agree to the terms of the license agreement, do not use the software.
Analyzing: target //build:build_wheel (0 packages loaded, 0 targets configured)
ERROR: /private/var/tmp/_bazel_akbirkhan/30d03974cb2cfe2f871935b707b3981a/external/local_config_cc/BUILD:48:19: in cc_toolchain_suite rule @local_config_cc//:toolchain: cc_toolchain_suite '@local_config_cc//:toolchain' does not contain a toolchain for cpu 'darwin_arm64'
DEBUG: Rule 'io_bazel_rules_docker' indicated that a canonical reproducible form can be obtained by modifying arguments shallow_since = "1556410077 -0400"
DEBUG: Repository io_bazel_rules_docker instantiated at:
  /Users/akbirkhan/jax/WORKSPACE:34:10: in <toplevel>
  /private/var/tmp/_bazel_akbirkhan/30d03974cb2cfe2f871935b707b3981a/external/org_tensorflow/tensorflow/workspace0.bzl:108:34: in workspace
  /private/var/tmp/_bazel_akbirkhan/30d03974cb2cfe2f871935b707b3981a/external/bazel_toolchains/repositories/repositories.bzl:37:23: in repositories
Repository rule git_repository defined at:
  /private/var/tmp/_bazel_akbirkhan/30d03974cb2cfe2f871935b707b3981a/external/bazel_tools/tools/build_defs/repo/git.bzl:199:33: in <toplevel>
ERROR: Analysis of target '//build:build_wheel' failed; build aborted: Analysis of target '@local_config_cc//:toolchain' failed
INFO: Elapsed time: 0.155s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (0 packages loaded, 0 targets configured)
ERROR: Build failed. Not running target
FAILED: Build did NOT complete successfully (0 packages loaded, 0 targets configured),
subprocess.CalledProcessError: Command '['./bazel-4.1.0rc4-darwin-arm64', 'run', '--verbose_failures=true', '--config=short_logs', '--config=mkl_open_source_only', ':build_wheel', '--', '--output_path=/Users/akbirkhan/jax/dist']' returned non-zero exit status 1.

@Noahyt

Noahyt commented May 22, 2021

Getting same error as @akbir .

ERROR: /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/local_config_cc/BUILD:48:19: in cc_toolchain_suite rule @local_config_cc//:toolchain: cc_toolchain_suite '@local_config_cc//:toolchain' does not contain a toolchain for cpu 'darwin_arm64'
INFO: Repository com_google_absl instantiated at:
  /Users/noah/harvard/2021/network_repair/jax/WORKSPACE:30:10: in <toplevel>
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/org_tensorflow/tensorflow/workspace2.bzl:1090:28: in workspace
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/org_tensorflow/tensorflow/workspace2.bzl:56:9: in _initialize_third_party
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/org_tensorflow/third_party/absl/workspace.bzl:12:20: in repo
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/org_tensorflow/third_party/repo.bzl:112:21: in tf_http_archive
Repository rule _tf_http_archive defined at:
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/org_tensorflow/third_party/repo.bzl:65:35: in <toplevel>
Analyzing: target //build:build_wheel (35 packages loaded, 264 targets configured)
INFO: Repository cython instantiated at:
  /Users/noah/harvard/2021/network_repair/jax/WORKSPACE:30:10: in <toplevel>
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/org_tensorflow/tensorflow/workspace2.bzl:1097:21: in workspace
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/org_tensorflow/tensorflow/workspace2.bzl:845:20: in _tf_repositories
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/org_tensorflow/third_party/repo.bzl:112:21: in tf_http_archive
Repository rule _tf_http_archive defined at:
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/org_tensorflow/third_party/repo.bzl:65:35: in <toplevel>
INFO: Repository pocketfft instantiated at:
  /Users/noah/harvard/2021/network_repair/jax/WORKSPACE:24:10: in <toplevel>
  /Users/noah/harvard/2021/network_repair/jax/third_party/pocketfft/workspace.bzl:20:17: in repo
Repository rule http_archive defined at:
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/bazel_tools/tools/build_defs/repo/http.bzl:336:31: in <toplevel>
ERROR: Analysis of target '//build:build_wheel' failed; build aborted: Analysis of target '@local_config_cc//:toolchain' failed
INFO: Elapsed time: 10.071s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (35 packages loaded, 264 targets configured)
ERROR: Build failed. Not running target

@Noahyt

Noahyt commented May 22, 2021

FYI for those looking for a quick and dirty workaround: you can install jax and jaxlib with pip in a Miniconda environment running under Rosetta 2. The most recent versions of jax and jaxlib don't work (they fail with an error like "zsh: illegal hardware instruction"), so I ended up using jax==0.2.10 and jaxlib==0.1.60.

@hawkinsp
Collaborator

@akbir Do you have a working Xcode installation, including the command line tools?
https://jax.readthedocs.io/en/latest/developer.html#building-jaxlib-from-source
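A quick way to see which developer directory is active is `xcode-select -p`. The sketch below assumes the default install locations (/Applications/Xcode.app/Contents/Developer for full Xcode, /Library/Developer/CommandLineTools for the standalone command line tools); the helper names are hypothetical. Later in this thread, installing full Xcode is what resolved the missing darwin_arm64 toolchain error for akbir, so this is a reasonable first thing to check.

```python
import subprocess


def active_developer_dir():
    """Return the path printed by `xcode-select -p`, or None if unavailable."""
    try:
        result = subprocess.run(
            ["xcode-select", "-p"], capture_output=True, text=True, check=True
        )
    except (OSError, subprocess.CalledProcessError):
        return None  # not macOS, or no developer directory selected
    return result.stdout.strip()


def is_full_xcode(path):
    """Heuristic: full Xcode lives inside an .app bundle; the CLT do not."""
    return path is not None and ".app/Contents/Developer" in path


path = active_developer_dir()
print(path, "-> full Xcode" if is_full_xcode(path) else "-> not full Xcode")
```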

I don't think the error you are seeing is related to the Bazel version. We're currently pinning Bazel 4.1.0rc4 because that was the newest version last week. If you like you can try a different Bazel version, but 4.1.0rc4 worked for me. The easiest way to do that is to install Bazel yourself and pass --bazel_path=/somewhere/bazel to the build.py command line.

@Noahyt jaxlib 0.1.62 and newer on x86 use AVX, which Rosetta does not support (https://github.com/google/jax/blob/master/CHANGELOG.md#jaxlib-0162-march-9-2021). All recent x86 CPUs support AVX and have for a long time. We don't intend to ship wheels without AVX, although if you like you can build a jaxlib from source that does not require AVX. But I don't think we should worry about that too much, since we want a native ARM version, anyway.
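Since Rosetta 2 does not support AVX, it helps to confirm whether the current Python runs natively or under translation. This is a minimal sketch, assuming the macOS sysctl key sysctl.proc_translated (1 for a translated process, 0 for native; absent on systems where it does not apply), queried in-process via ctypes; the function names are hypothetical.

```python
import ctypes
import ctypes.util
import platform


def rosetta_translated():
    """True if translated by Rosetta 2, False if native, None if unknown."""
    if platform.system() != "Darwin":
        return None
    libc = ctypes.CDLL(ctypes.util.find_library("c"))
    value = ctypes.c_int(0)
    size = ctypes.c_size_t(ctypes.sizeof(value))
    rc = libc.sysctlbyname(b"sysctl.proc_translated", ctypes.byref(value),
                           ctypes.byref(size), None, ctypes.c_size_t(0))
    if rc != 0:  # key missing, e.g. where Rosetta does not apply
        return None
    return value.value == 1


def describe(system, machine, translated):
    """Summarize the interpreter's architecture for JAX purposes."""
    if system != "Darwin":
        return "not macOS"
    if machine == "arm64":
        return "native arm64"
    if translated:
        return "x86_64 under Rosetta 2 (no AVX, so x86 jaxlib >= 0.1.62 will crash)"
    return "x86_64"


print(describe(platform.system(), platform.machine(), rosetta_translated()))
```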

@akbir
Contributor

akbir commented May 24, 2021

Got this to finally build!! Thank you @hawkinsp

For others: I installed Xcode (not just the command-line tools) and Bazel.

Also updated .bazelversion to 4.1.0 (should this be updated in the repo?)

sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
sudo xcodebuild -license
bazel sync --configure

python build.py --bazel_path=/somewhere/bazel
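The steps above can be scripted, for example to rerun them after a fresh checkout. This is a minimal sketch that only prints the commands unless dry_run=False; it assumes Xcode is installed at the default /Applications path and reuses the /somewhere/bazel placeholder. Two details are assumptions rather than part of the original steps: `sync --configure` is run with the same Bazel binary passed to build.py, and build.py is invoked with the current interpreter (sys.executable). Note that `sudo xcodebuild -license` is interactive.

```python
import shlex
import subprocess
import sys


def build_steps(bazel_path):
    """The commands from the comment above, as argument lists."""
    return [
        ["sudo", "xcode-select", "-s", "/Applications/Xcode.app/Contents/Developer"],
        ["sudo", "xcodebuild", "-license"],
        [bazel_path, "sync", "--configure"],
        [sys.executable, "build.py", f"--bazel_path={bazel_path}"],
    ]


def run_steps(steps, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in steps:
        print("$", shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing step


run_steps(build_steps("/somewhere/bazel"))
```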

@hawkinsp
Collaborator

@akbir If you want to send a PR that updates the Bazel version for Mac ARM, that sounds great! The version of bazel is chosen here: https://cs.opensource.google/jax/jax/+/master:build/build.py;drc=dacf31f2020175181014745cdabc240a10031227;l=119

@akbir
Contributor

akbir commented May 24, 2021

Would love to!

Quick question: why does JAX also specify the version here: https://github.com/google/jax/blob/master/.bazelversion ?

@hawkinsp
Collaborator

@akbir If I remember correctly, that's for folks using Bazelisk (https://github.com/bazelbuild/bazelisk). I don't know if it's possible to specify a separate version for Mac arm via Bazelisk.

We can probably upgrade to 4.1.0 for all platforms, but let's not do that right away. So fixing build.py is probably enough.
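Because Bazelisk reads the version from .bazelversion at the workspace root while build.py picks its own Bazel, the two pins can drift apart. This is a minimal sketch for checking that they agree; the expected value "4.1.0" and the helper names are hypothetical and are not read from build.py.

```python
from pathlib import Path


def parse_bazelversion(text):
    """Return the first non-empty line of a .bazelversion file, or None."""
    for line in text.splitlines():
        if line.strip():
            return line.strip()
    return None


def check_pin(pinned, build_py_version):
    """Compare the Bazelisk pin with the version build.py is expected to use."""
    if pinned is None:
        return "no .bazelversion found"
    if pinned != build_py_version:
        return f"mismatch: .bazelversion={pinned}, build.py={build_py_version}"
    return "ok"


def read_bazelversion(workspace_root="."):
    path = Path(workspace_root) / ".bazelversion"
    return parse_bazelversion(path.read_text()) if path.is_file() else None


print(check_pin(read_bazelversion("."), "4.1.0"))  # hypothetical build.py pin
```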

@yiyaz

yiyaz commented Jan 27, 2022

I have tried the solution proposed by dfm and yashk2810 with no success.
dfm's approach yields ERROR: Could not find a version that satisfies the requirement jaxlib==0.1.74 (from versions: 0.1.75) ERROR: No matching distribution found for jaxlib==0.1.74

while yashk2810's approach gives me
ERROR: jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl is not a supported wheel on this platform.
Updating pip to the latest version and Python to 3.10.2 did not help either.
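pip only installs a wheel whose filename tags match the running interpreter. The name jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl says CPython 3.9 on macOS arm64, so a Python 3.10 interpreter rejects it, and so does an x86_64 Python running under Rosetta 2. Below is a simplified sketch of that check using the wheel filename layout (name-version[-build]-python-abi-platform.whl); it ignores the ABI tag and the macOS version component, and the function names are hypothetical. The packaging library's `packaging.utils.parse_wheel_filename` and `packaging.tags.sys_tags()` handle this properly.

```python
import platform
import sys


def wheel_tags(filename):
    """Split a wheel filename into name, version and tag lists."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) not in (5, 6):  # 6 when an optional build tag is present
        raise ValueError(f"unexpected wheel filename: {filename}")
    py_tag, abi_tag, plat_tag = parts[-3:]
    return {"name": parts[0], "version": parts[1],
            "python": py_tag.split("."), "abi": abi_tag.split("."),
            "platform": plat_tag.split(".")}


def mismatches(filename, py_version=None, machine=None):
    """List reasons the wheel would not fit this (or the given) interpreter."""
    tags = wheel_tags(filename)
    major, minor = py_version or sys.version_info[:2]
    machine = machine or platform.machine()
    problems = []
    accepted = {f"cp{major}{minor}", f"py{major}", f"py{major}{minor}"}
    if not accepted.intersection(tags["python"]):
        problems.append(f"python tag {tags['python']} does not match {major}.{minor}")
    if not any(p == "any" or p.endswith("_" + machine) for p in tags["platform"]):
        problems.append(f"platform tag {tags['platform']} does not match {machine}")
    return problems


print(mismatches("jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl"))
```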

@phinate

phinate commented Jan 27, 2022

You should be able to just pip install jax jaxlib now -- as the first error states, jaxlib==0.1.75 is now provided on pypi for ARM.

From yashk:

I am going to close this issue as it has been fixed, and in the next release we will have official jaxlib macOS arm64 wheels on PyPI.

@nlp4whp

nlp4whp commented Jan 29, 2022

Thanks! Will it be supported for Python 3.8 on macOS arm64?

@hawkinsp
Collaborator

We don't provide 3.8 wheels for mac arm64, only 3.9 and 3.10. (I think originally we were under the impression that 3.8 was never released for Mac ARM, although I guess that's not true.) I guess we could, though.

@dwyatte

dwyatte commented Jan 30, 2022

@hawkinsp please see #9065 for a specific request for Python 3.8 macOS arm64 wheels, if it's not a ton of effort for maintainers.

@nickreich

I had the same problem. Since I already had Anaconda installed and didn't want to clutter up my space with Anaconda + miniconda + homebrew and whatever, what worked for me was installing jax and jaxlib via conda-forge directly:

conda install -c conda-forge jaxlib
conda install -c conda-forge jax

@gabrieldernbach

gabrieldernbach commented Feb 18, 2022

The conda packages from -c conda-forge did not work for me, and pip fails when installing the scipy dependency. In the end, this was the shortest way I could find:

conda create --name venv python=3.10
conda activate venv
conda install -y scipy
pip install jax jaxlib
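After installing, a quick report can confirm that the interpreter is native arm64 and show which versions were resolved. This is a minimal sketch using only the standard library (importlib.metadata, Python 3.8+); the install_report name is hypothetical, and because it only reads package metadata it does not prove that jaxlib actually loads.

```python
import importlib.metadata
import platform


def install_report(packages=("jax", "jaxlib", "scipy")):
    """Collect interpreter architecture and installed versions (None if missing)."""
    report = {"machine": platform.machine(), "python": platform.python_version()}
    for name in packages:
        try:
            report[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            report[name] = None
    return report


for key, value in install_report().items():
    print(f"{key}: {value}")
# On a native Apple Silicon install, "machine" should read arm64.
```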

@xhochy

xhochy commented Feb 18, 2022

@gabrieldernbach Can you tell me what didn't work with the conda packages? Since yesterday, new versions are up that should work.

@gabrielhuang

gabrielhuang commented Nov 4, 2022

By far the easiest solution for me was to reinstall Python using the latest arm64 version of Miniforge (not Miniconda), then pip install jaxlib jax.

@ddrous

ddrous commented Nov 30, 2022

I've tried this, and it works!! This video might be useful for people already using Miniconda who want Miniforge on the side or as the default.

@timhdesilva

Hi all,

I just upgraded from an Intel Mac to an M2 Mac and read this thread. What is the best way for me to proceed with installing JAX on the M2? Is it possible to build JAX to work with the Apple GPU in addition to the CPU?

Thanks and apologies in advance if I missed this in the discussion above!

@kechan

kechan commented Mar 26, 2023

I had the same problem. Since I already had Anaconda installed and didn't want to clutter up my space with Anaconda + miniconda + homebrew and whatever, what worked for me was installing jax and jaxlib via conda-forge directly:

conda install -c conda-forge jaxlib
conda install -c conda-forge jax

Is this CPU only? How about Apple GPU cores, or maybe even the Neural Engines?
