From 3a7120e75cae4b3a439a043d9306f6a112815a48 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 21 Nov 2024 12:44:27 -0500 Subject: [PATCH 1/2] Refactor FFI examples to consolidate several examples into one submodule. --- examples/ffi/CMakeLists.txt | 13 ++--- examples/ffi/README.md | 27 ++++----- examples/ffi/src/jax_ffi_example/counter.cc | 53 ------------------ examples/ffi/src/jax_ffi_example/counter.py | 38 ------------- .../{attrs.cc => cpu_examples.cc} | 55 ++++++++++++++++++- .../{attrs.py => cpu_examples.py} | 17 +++--- .../{cuda_e2e.cu => cuda_examples.cu} | 0 .../{cuda_e2e.py => cuda_examples.py} | 2 +- examples/ffi/tests/counter_test.py | 55 ------------------- .../{attrs_test.py => cpu_examples_test.py} | 47 +++++++++++++--- ...cuda_e2e_test.py => cuda_examples_test.py} | 4 +- 11 files changed, 122 insertions(+), 189 deletions(-) delete mode 100644 examples/ffi/src/jax_ffi_example/counter.cc delete mode 100644 examples/ffi/src/jax_ffi_example/counter.py rename examples/ffi/src/jax_ffi_example/{attrs.cc => cpu_examples.cc} (57%) rename examples/ffi/src/jax_ffi_example/{attrs.py => cpu_examples.py} (73%) rename examples/ffi/src/jax_ffi_example/{cuda_e2e.cu => cuda_examples.cu} (100%) rename examples/ffi/src/jax_ffi_example/{cuda_e2e.py => cuda_examples.py} (99%) delete mode 100644 examples/ffi/tests/counter_test.py rename examples/ffi/tests/{attrs_test.py => cpu_examples_test.py} (55%) rename examples/ffi/tests/{cuda_e2e_test.py => cuda_examples_test.py} (96%) diff --git a/examples/ffi/CMakeLists.txt b/examples/ffi/CMakeLists.txt index 9f9090e2b7ef..843c2cda0e3b 100644 --- a/examples/ffi/CMakeLists.txt +++ b/examples/ffi/CMakeLists.txt @@ -15,8 +15,7 @@ find_package(nanobind CONFIG REQUIRED) set( JAX_FFI_EXAMPLE_PROJECTS "rms_norm" - "attrs" - "counter" + "cpu_examples" ) foreach(PROJECT ${JAX_FFI_EXAMPLE_PROJECTS}) @@ -27,9 +26,9 @@ endforeach() if(JAX_FFI_EXAMPLE_ENABLE_CUDA) enable_language(CUDA) - add_library(_cuda_e2e SHARED 
"src/jax_ffi_example/cuda_e2e.cu") - set_target_properties(_cuda_e2e PROPERTIES POSITION_INDEPENDENT_CODE ON - CUDA_STANDARD 17) - target_include_directories(_cuda_e2e PUBLIC ${XLA_DIR}) - install(TARGETS _cuda_e2e LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) + add_library(_cuda_examples SHARED "src/jax_ffi_example/cuda_examples.cu") + set_target_properties(_cuda_examples PROPERTIES POSITION_INDEPENDENT_CODE ON + CUDA_STANDARD 17) + target_include_directories(_cuda_examples PUBLIC ${XLA_DIR}) + install(TARGETS _cuda_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) endif() diff --git a/examples/ffi/README.md b/examples/ffi/README.md index eb730b483b76..bd45408e50d8 100644 --- a/examples/ffi/README.md +++ b/examples/ffi/README.md @@ -11,18 +11,19 @@ Within the example project, there are several example calls: demonstrates the most basic use of the FFI. It also includes customization of behavior under automatic differentiation using `jax.custom_vjp`. -2. `counter`: This example demonstrates a common pattern for how an FFI call can - use global cache to maintain state between calls. This pattern is useful when - an FFI call requires an expensive initialization step which shouldn't be - run on every execution, or if there is other shared state that could be - reused between calls. In this simple example we just count the number of - times the call was executed. +2. `cpu_examples`: This submodule includes several smaller examples: -3. `attrs`: An example demonstrating the different ways that attributes can be - passed to the FFI. For example, we can pass arrays, variadic attributes, and - user-defined types. Full support of user-defined types isn't yet supported - by XLA, so that example will be added in the future. + * `counter`: This example demonstrates a common pattern for how an FFI call + can use global cache to maintain state between calls. 
This pattern is + useful when an FFI call requires an expensive initialization step which + shouldn't be run on every execution, or if there is other shared state + that could be reused between calls. In this simple example we just count + the number of times the call was executed. + * `attrs`: An example demonstrating the different ways that attributes can be + passed to the FFI. For example, we can pass arrays, variadic attributes, + and user-defined types. User-defined types aren't yet fully + supported by XLA, so that example will be added in the future. -4. `cuda_e2e`: An end-to-end example demonstrating the use of the JAX FFI with - CUDA. The specifics of the kernels are not very important, but the general - structure, and packaging of the extension are useful for testing. +3. `cuda_examples`: An end-to-end example demonstrating the use of the JAX FFI + with CUDA. The specifics of the kernels are not very important, but the + general structure and packaging of the extension are useful for testing. diff --git a/examples/ffi/src/jax_ffi_example/counter.cc b/examples/ffi/src/jax_ffi_example/counter.cc deleted file mode 100644 index d7f17e730fd6..000000000000 --- a/examples/ffi/src/jax_ffi_example/counter.cc +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2024 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License.
-==============================================================================*/ - -#include -#include -#include -#include - -#include "nanobind/nanobind.h" -#include "xla/ffi/api/ffi.h" - -namespace nb = nanobind; -namespace ffi = xla::ffi; - -ffi::Error CounterImpl(int64_t index, ffi::ResultBufferR0 out) { - static std::mutex mutex; - static auto& cache = *new std::unordered_map(); - { - const std::lock_guard lock(mutex); - auto it = cache.find(index); - if (it != cache.end()) { - out->typed_data()[0] = ++it->second; - } else { - cache.insert({index, 0}); - out->typed_data()[0] = 0; - } - } - return ffi::Error::Success(); -} - -XLA_FFI_DEFINE_HANDLER_SYMBOL( - Counter, CounterImpl, - ffi::Ffi::Bind().Attr("index").Ret>()); - -NB_MODULE(_counter, m) { - m.def("registrations", []() { - nb::dict registrations; - registrations["counter"] = nb::capsule(reinterpret_cast(Counter)); - return registrations; - }); -} diff --git a/examples/ffi/src/jax_ffi_example/counter.py b/examples/ffi/src/jax_ffi_example/counter.py deleted file mode 100644 index 12c7f015bf58..000000000000 --- a/examples/ffi/src/jax_ffi_example/counter.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""An example demonstrating how an FFI call can maintain "state" between calls - -In this case, the ``counter`` call simply accumulates the number of times it -was executed, but this pattern can also be used for more advanced use cases. 
-For example, this pattern is used in jaxlib for: - -1. The GPU solver linear algebra kernels which require an expensive "handler" - initialization, and -2. The ``triton_call`` function which caches the compiled triton modules after - their first use. -""" - -import jax -import jax.extend as jex - -from jax_ffi_example import _counter - -for name, target in _counter.registrations().items(): - jex.ffi.register_ffi_target(name, target) - - -def counter(index): - return jex.ffi.ffi_call( - "counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index)) diff --git a/examples/ffi/src/jax_ffi_example/attrs.cc b/examples/ffi/src/jax_ffi_example/cpu_examples.cc similarity index 57% rename from examples/ffi/src/jax_ffi_example/attrs.cc rename to examples/ffi/src/jax_ffi_example/cpu_examples.cc index 7ff5c98e52e1..3832c86b29b2 100644 --- a/examples/ffi/src/jax_ffi_example/attrs.cc +++ b/examples/ffi/src/jax_ffi_example/cpu_examples.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include +#include +#include +#include #include "nanobind/nanobind.h" #include "xla/ffi/api/ffi.h" @@ -21,6 +24,17 @@ limitations under the License. namespace nb = nanobind; namespace ffi = xla::ffi; +// ---------- +// Attributes +// ---------- +// +// An example demonstrating the different ways that attributes can be passed to +// the FFI. +// +// For example, we can pass arrays, variadic attributes, and user-defined types. +// Full support of user-defined types isn't yet supported by XLA, so that +// example will be added in the future. 
+
 ffi::Error ArrayAttrImpl(ffi::Span<const int32_t> array,
                          ffi::ResultBufferR0<ffi::S32> res) {
   int64_t total = 0;
@@ -54,13 +68,52 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DictionaryAttr, DictionaryAttrImpl,
                                   .Ret<ffi::BufferR0<ffi::S32>>()
                                   .Ret<ffi::BufferR0<ffi::S32>>());
 
-NB_MODULE(_attrs, m) {
+// -------
+// Counter
+// -------
+//
+// An example demonstrating how an FFI call can maintain "state" between calls.
+//
+// In this case, the ``Counter`` call simply accumulates the number of times it
+// was executed, but this pattern can also be used for more advanced use cases.
+// For example, this pattern is used in jaxlib for:
+//
+// 1. The GPU solver linear algebra kernels which require an expensive "handler"
+//    initialization, and
+// 2. The ``triton_call`` function which caches the compiled triton modules
+//    after their first use.
+
+ffi::Error CounterImpl(int64_t index, ffi::ResultBufferR0<ffi::S32> out) {
+  static std::mutex mutex;
+  static auto &cache = *new std::unordered_map<int64_t, int32_t>();
+  {
+    const std::lock_guard<std::mutex> lock(mutex);
+    auto it = cache.find(index);
+    if (it != cache.end()) {
+      out->typed_data()[0] = ++it->second;
+    } else {
+      cache.insert({index, 0});
+      out->typed_data()[0] = 0;
+    }
+  }
+  return ffi::Error::Success();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    Counter, CounterImpl,
+    ffi::Ffi::Bind().Attr<int64_t>("index").Ret<ffi::BufferR0<ffi::S32>>());
+
+// Boilerplate for exposing handlers to Python
+NB_MODULE(_cpu_examples, m) {
   m.def("registrations", []() {
     nb::dict registrations;
     registrations["array_attr"] =
         nb::capsule(reinterpret_cast<void *>(ArrayAttr));
     registrations["dictionary_attr"] =
         nb::capsule(reinterpret_cast<void *>(DictionaryAttr));
+
+    registrations["counter"] = nb::capsule(reinterpret_cast<void *>(Counter));
+
     return registrations;
   });
 }
diff --git a/examples/ffi/src/jax_ffi_example/attrs.py b/examples/ffi/src/jax_ffi_example/cpu_examples.py
similarity index 73%
rename from examples/ffi/src/jax_ffi_example/attrs.py
rename to examples/ffi/src/jax_ffi_example/cpu_examples.py
index 2f215e8e25b1..7771237e41d1 100644
--- a/examples/ffi/src/jax_ffi_example/attrs.py
+++
b/examples/ffi/src/jax_ffi_example/cpu_examples.py @@ -12,22 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""An example demonstrating the different ways that attributes can be passed to -the FFI. - -For example, we can pass arrays, variadic attributes, and user-defined types. -Full support of user-defined types isn't yet supported by XLA, so that example -will be added in the future. -""" - import numpy as np import jax import jax.extend as jex -from jax_ffi_example import _attrs +from jax_ffi_example import _cpu_examples -for name, target in _attrs.registrations().items(): +for name, target in _cpu_examples.registrations().items(): jex.ffi.register_ffi_target(name, target) @@ -43,3 +35,8 @@ def dictionary_attr(**kwargs): "dictionary_attr", (jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)), )(**kwargs) + + +def counter(index): + return jex.ffi.ffi_call( + "counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index)) diff --git a/examples/ffi/src/jax_ffi_example/cuda_e2e.cu b/examples/ffi/src/jax_ffi_example/cuda_examples.cu similarity index 100% rename from examples/ffi/src/jax_ffi_example/cuda_e2e.cu rename to examples/ffi/src/jax_ffi_example/cuda_examples.cu diff --git a/examples/ffi/src/jax_ffi_example/cuda_e2e.py b/examples/ffi/src/jax_ffi_example/cuda_examples.py similarity index 99% rename from examples/ffi/src/jax_ffi_example/cuda_e2e.py rename to examples/ffi/src/jax_ffi_example/cuda_examples.py index 500677050a4b..b60b12af577e 100644 --- a/examples/ffi/src/jax_ffi_example/cuda_e2e.py +++ b/examples/ffi/src/jax_ffi_example/cuda_examples.py @@ -27,7 +27,7 @@ import jax.extend as jex # Load the shared library with the FFI target definitions -SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "lib_cuda_e2e.so") +SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "lib_cuda_examples.so") library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY) 
jex.ffi.register_ffi_target("foo-fwd", jex.ffi.pycapsule(library.FooFwd), diff --git a/examples/ffi/tests/counter_test.py b/examples/ffi/tests/counter_test.py deleted file mode 100644 index 1e2ad38a363f..000000000000 --- a/examples/ffi/tests/counter_test.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest - -import jax -from jax._src import test_util as jtu - -from jax_ffi_example import counter - -jax.config.parse_flags_with_absl() - - -class CounterTests(jtu.JaxTestCase): - def setUp(self): - super().setUp() - if not jtu.test_device_matches(["cpu"]): - self.skipTest("Unsupported platform") - - def test_basic(self): - self.assertEqual(counter.counter(0), 0) - self.assertEqual(counter.counter(0), 1) - self.assertEqual(counter.counter(0), 2) - self.assertEqual(counter.counter(1), 0) - self.assertEqual(counter.counter(0), 3) - - def test_jit(self): - @jax.jit - def counter_fun(x): - return x, counter.counter(2) - - self.assertEqual(counter_fun(0)[1], 0) - self.assertEqual(counter_fun(0)[1], 1) - - # Persists across different cache hits - self.assertEqual(counter_fun(1)[1], 2) - - # Persists after the cache is cleared - counter_fun.clear_cache() - self.assertEqual(counter_fun(0)[1], 3) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/examples/ffi/tests/attrs_test.py b/examples/ffi/tests/cpu_examples_test.py similarity index 55% 
rename from examples/ffi/tests/attrs_test.py rename to examples/ffi/tests/cpu_examples_test.py index 2eef1f627006..cb2653d2e928 100644 --- a/examples/ffi/tests/attrs_test.py +++ b/examples/ffi/tests/cpu_examples_test.py @@ -18,7 +18,7 @@ import jax.numpy as jnp from jax._src import test_util as jtu -from jax_ffi_example import attrs +from jax_ffi_example import cpu_examples jax.config.parse_flags_with_absl() @@ -30,11 +30,11 @@ def setUp(self): self.skipTest("Unsupported platform") def test_array_attr(self): - self.assertEqual(attrs.array_attr(5), jnp.arange(5).sum()) - self.assertEqual(attrs.array_attr(3), jnp.arange(3).sum()) + self.assertEqual(cpu_examples.array_attr(5), jnp.arange(5).sum()) + self.assertEqual(cpu_examples.array_attr(3), jnp.arange(3).sum()) def test_array_attr_jit_cache(self): - jit_array_attr = jax.jit(attrs.array_attr, static_argnums=(0,)) + jit_array_attr = jax.jit(cpu_examples.array_attr, static_argnums=(0,)) with jtu.count_jit_and_pmap_lowerings() as count: jit_array_attr(5) self.assertEqual(count[0], 1) # compiles once the first time @@ -44,22 +44,51 @@ def test_array_attr_jit_cache(self): def test_array_attr_no_jit(self): with jax.disable_jit(): - attrs.array_attr(5) # doesn't crash + cpu_examples.array_attr(5) # doesn't crash def test_dictionary_attr(self): - secret, count = attrs.dictionary_attr(secret=5) + secret, count = cpu_examples.dictionary_attr(secret=5) self.assertEqual(secret, 5) self.assertEqual(count, 1) - secret, count = attrs.dictionary_attr(secret=3, a_string="hello") + secret, count = cpu_examples.dictionary_attr(secret=3, a_string="hello") self.assertEqual(secret, 3) self.assertEqual(count, 2) with self.assertRaisesRegex(Exception, "Unexpected attribute"): - attrs.dictionary_attr() + cpu_examples.dictionary_attr() with self.assertRaisesRegex(Exception, "Wrong attribute type"): - attrs.dictionary_attr(secret="invalid") + cpu_examples.dictionary_attr(secret="invalid") + + +class CounterTests(jtu.JaxTestCase): + def 
setUp(self): + super().setUp() + if not jtu.test_device_matches(["cpu"]): + self.skipTest("Unsupported platform") + + def test_basic(self): + self.assertEqual(cpu_examples.counter(0), 0) + self.assertEqual(cpu_examples.counter(0), 1) + self.assertEqual(cpu_examples.counter(0), 2) + self.assertEqual(cpu_examples.counter(1), 0) + self.assertEqual(cpu_examples.counter(0), 3) + + def test_jit(self): + @jax.jit + def counter_fun(x): + return x, cpu_examples.counter(2) + + self.assertEqual(counter_fun(0)[1], 0) + self.assertEqual(counter_fun(0)[1], 1) + + # Persists across different cache hits + self.assertEqual(counter_fun(1)[1], 2) + + # Persists after the cache is cleared + counter_fun.clear_cache() + self.assertEqual(counter_fun(0)[1], 3) if __name__ == "__main__": diff --git a/examples/ffi/tests/cuda_e2e_test.py b/examples/ffi/tests/cuda_examples_test.py similarity index 96% rename from examples/ffi/tests/cuda_e2e_test.py rename to examples/ffi/tests/cuda_examples_test.py index 83397f7ff5d7..f4a736599ce4 100644 --- a/examples/ffi/tests/cuda_e2e_test.py +++ b/examples/ffi/tests/cuda_examples_test.py @@ -28,8 +28,8 @@ def setUp(self): self.skipTest("Unsupported platform") # Import here to avoid trying to load the library when it's not built. - from jax_ffi_example import cuda_e2e - self.foo = cuda_e2e.foo + from jax_ffi_example import cuda_examples + self.foo = cuda_examples.foo def test_fwd_interpretable(self): shape = (2, 3) From 9b735935bf4fde2ea036800b60d3b654b063db7b Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 21 Nov 2024 13:08:50 -0500 Subject: [PATCH 2/2] Add an example demonstrating input-output aliasing with the FFI. 
---
 .../ffi/src/jax_ffi_example/cpu_examples.cc   | 29 +++++++++++++++++++
 .../ffi/src/jax_ffi_example/cpu_examples.py   |  6 ++++
 examples/ffi/tests/cpu_examples_test.py       | 13 ++++++++-
 3 files changed, 47 insertions(+), 1 deletion(-)

diff --git a/examples/ffi/src/jax_ffi_example/cpu_examples.cc b/examples/ffi/src/jax_ffi_example/cpu_examples.cc
index 3832c86b29b2..c67eab72dd64 100644
--- a/examples/ffi/src/jax_ffi_example/cpu_examples.cc
+++ b/examples/ffi/src/jax_ffi_example/cpu_examples.cc
@@ -103,6 +103,33 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
     Counter, CounterImpl,
     ffi::Ffi::Bind().Attr<int64_t>("index").Ret<ffi::BufferR0<ffi::S32>>());
 
+// --------
+// Aliasing
+// --------
+//
+// This example demonstrates how input-output aliasing works. The handler only
+// checks that the input and output have matching element types and counts,
+// and that their data pointers refer to the same memory.
+
+ffi::Error AliasingImpl(ffi::AnyBuffer input,
+                        ffi::Result<ffi::AnyBuffer> output) {
+  if (input.element_type() != output->element_type() ||
+      input.element_count() != output->element_count()) {
+    return ffi::Error::InvalidArgument(
+        "The input and output data types and sizes must match.");
+  }
+  if (input.untyped_data() != output->untyped_data()) {
+    return ffi::Error::InvalidArgument(
+        "When aliased, the input and output buffers should point to the same "
+        "data.");
+  }
+  return ffi::Error::Success();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    Aliasing, AliasingImpl,
+    ffi::Ffi::Bind().Arg<ffi::AnyBuffer>().Ret<ffi::AnyBuffer>());
+
 // Boilerplate for exposing handlers to Python
 NB_MODULE(_cpu_examples, m) {
   m.def("registrations", []() {
@@ -114,6 +141,8 @@ NB_MODULE(_cpu_examples, m) {
 
     registrations["counter"] = nb::capsule(reinterpret_cast<void *>(Counter));
 
+    registrations["aliasing"] = nb::capsule(reinterpret_cast<void *>(Aliasing));
+
     return registrations;
   });
 }
diff --git a/examples/ffi/src/jax_ffi_example/cpu_examples.py b/examples/ffi/src/jax_ffi_example/cpu_examples.py
index 7771237e41d1..a5e7ec69d25c 100644
--- a/examples/ffi/src/jax_ffi_example/cpu_examples.py
+++
b/examples/ffi/src/jax_ffi_example/cpu_examples.py @@ -40,3 +40,9 @@ def dictionary_attr(**kwargs): def counter(index): return jex.ffi.ffi_call( "counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index)) + + +def aliasing(x): + return jex.ffi.ffi_call( + "aliasing", jax.ShapeDtypeStruct(x.shape, x.dtype), + input_output_aliases={0: 0})(x) diff --git a/examples/ffi/tests/cpu_examples_test.py b/examples/ffi/tests/cpu_examples_test.py index cb2653d2e928..d9091278749f 100644 --- a/examples/ffi/tests/cpu_examples_test.py +++ b/examples/ffi/tests/cpu_examples_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from absl.testing import absltest +from absl.testing import absltest, parameterized import jax import jax.numpy as jnp @@ -91,5 +91,16 @@ def counter_fun(x): self.assertEqual(counter_fun(0)[1], 3) +class AliasingTests(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["cpu"]): + self.skipTest("Unsupported platform") + + @parameterized.parameters((jnp.linspace(0, 0.5, 10),), (jnp.int32(6),)) + def test_basic(self, x): + self.assertAllClose(cpu_examples.aliasing(x), x) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())