Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an example demonstrating input-output aliasing with the FFI #25042

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions examples/ffi/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ find_package(nanobind CONFIG REQUIRED)
set(
JAX_FFI_EXAMPLE_PROJECTS
"rms_norm"
"attrs"
"counter"
"cpu_examples"
)

foreach(PROJECT ${JAX_FFI_EXAMPLE_PROJECTS})
Expand All @@ -27,9 +26,9 @@ endforeach()

if(JAX_FFI_EXAMPLE_ENABLE_CUDA)
enable_language(CUDA)
add_library(_cuda_e2e SHARED "src/jax_ffi_example/cuda_e2e.cu")
set_target_properties(_cuda_e2e PROPERTIES POSITION_INDEPENDENT_CODE ON
CUDA_STANDARD 17)
target_include_directories(_cuda_e2e PUBLIC ${XLA_DIR})
install(TARGETS _cuda_e2e LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
add_library(_cuda_examples SHARED "src/jax_ffi_example/cuda_examples.cu")
set_target_properties(_cuda_examples PROPERTIES POSITION_INDEPENDENT_CODE ON
CUDA_STANDARD 17)
target_include_directories(_cuda_examples PUBLIC ${XLA_DIR})
install(TARGETS _cuda_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
endif()
27 changes: 14 additions & 13 deletions examples/ffi/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@ Within the example project, there are several example calls:
demonstrates the most basic use of the FFI. It also includes customization of
behavior under automatic differentiation using `jax.custom_vjp`.

2. `counter`: This example demonstrates a common pattern for how an FFI call can
use global cache to maintain state between calls. This pattern is useful when
an FFI call requires an expensive initialization step which shouldn't be
run on every execution, or if there is other shared state that could be
reused between calls. In this simple example we just count the number of
times the call was executed.
2. `cpu_examples`: This submodule includes several smaller examples:

3. `attrs`: An example demonstrating the different ways that attributes can be
passed to the FFI. For example, we can pass arrays, variadic attributes, and
user-defined types. Full support of user-defined types isn't yet supported
by XLA, so that example will be added in the future.
* `counter`: This example demonstrates a common pattern for how an FFI call
can use global cache to maintain state between calls. This pattern is
useful when an FFI call requires an expensive initialization step which
shouldn't be run on every execution, or if there is other shared state
that could be reused between calls. In this simple example we just count
the number of times the call was executed.
* `attrs`: An example demonstrating the different ways that attributes can be
passed to the FFI. For example, we can pass arrays, variadic attributes,
and user-defined types. Full support of user-defined types isn't yet
supported by XLA, so that example will be added in the future.

4. `cuda_e2e`: An end-to-end example demonstrating the use of the JAX FFI with
CUDA. The specifics of the kernels are not very important, but the general
structure, and packaging of the extension are useful for testing.
3. `cuda_examples`: An end-to-end example demonstrating the use of the JAX FFI
with CUDA. The specifics of the kernels are not very important, but the
general structure, and packaging of the extension are useful for testing.
66 changes: 0 additions & 66 deletions examples/ffi/src/jax_ffi_example/attrs.cc

This file was deleted.

53 changes: 0 additions & 53 deletions examples/ffi/src/jax_ffi_example/counter.cc

This file was deleted.

38 changes: 0 additions & 38 deletions examples/ffi/src/jax_ffi_example/counter.py

This file was deleted.

148 changes: 148 additions & 0 deletions examples/ffi/src/jax_ffi_example/cpu_examples.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/* Copyright 2024 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <mutex>
#include <string_view>
#include <unordered_map>

#include "nanobind/nanobind.h"
#include "xla/ffi/api/ffi.h"

namespace nb = nanobind;
namespace ffi = xla::ffi;

// ----------
// Attributes
// ----------
//
// An example demonstrating the different ways that attributes can be passed to
// the FFI.
//
// For example, we can pass arrays, variadic attributes, and user-defined types.
// Full support of user-defined types isn't yet supported by XLA, so that
// example will be added in the future.

ffi::Error ArrayAttrImpl(ffi::Span<const int32_t> array,
ffi::ResultBufferR0<ffi::S32> res) {
int64_t total = 0;
for (int32_t x : array) {
total += x;
}
res->typed_data()[0] = total;
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(ArrayAttr, ArrayAttrImpl,
ffi::Ffi::Bind()
.Attr<ffi::Span<const int32_t>>("array")
.Ret<ffi::BufferR0<ffi::S32>>());

ffi::Error DictionaryAttrImpl(ffi::Dictionary attrs,
ffi::ResultBufferR0<ffi::S32> secret,
ffi::ResultBufferR0<ffi::S32> count) {
auto maybe_secret = attrs.get<int64_t>("secret");
if (maybe_secret.has_error()) {
return maybe_secret.error();
}
secret->typed_data()[0] = maybe_secret.value();
count->typed_data()[0] = attrs.size();
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DictionaryAttr, DictionaryAttrImpl,
ffi::Ffi::Bind()
.Attrs()
.Ret<ffi::BufferR0<ffi::S32>>()
.Ret<ffi::BufferR0<ffi::S32>>());

// -------
// Counter
// -------
//
// An example demonstrating how an FFI call can maintain "state" between calls
//
// In this case, the ``Counter`` call simply accumulates the number of times it
// was executed, but this pattern can also be used for more advanced use cases.
// For example, this pattern is used in jaxlib for:
//
// 1. The GPU solver linear algebra kernels which require an expensive "handler"
// initialization, and
// 2. The ``triton_call`` function which caches the compiled triton modules
// after their first use.

ffi::Error CounterImpl(int64_t index, ffi::ResultBufferR0<ffi::S32> out) {
static std::mutex mutex;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, how I'd love to see it converted to FFI state :)

static auto &cache = *new std::unordered_map<int64_t, int32_t>();
{
const std::lock_guard<std::mutex> lock(mutex);
auto it = cache.find(index);
if (it != cache.end()) {
out->typed_data()[0] = ++it->second;
} else {
cache.insert({index, 0});
out->typed_data()[0] = 0;
}
}
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(
Counter, CounterImpl,
ffi::Ffi::Bind().Attr<int64_t>("index").Ret<ffi::BufferR0<ffi::S32>>());

// --------
// Aliasing
// --------
//
// This example demonstrates how input-output aliasing works. The handler
// doesn't do anything except to check that the input and output pointers
// address the same data.

ffi::Error AliasingImpl(ffi::AnyBuffer input,
ffi::Result<ffi::AnyBuffer> output) {
if (input.element_type() != output->element_type() ||
input.element_count() != output->element_count()) {
return ffi::Error::InvalidArgument(
"The input and output data types and sizes must match.");
}
if (input.untyped_data() != output->untyped_data()) {
return ffi::Error::InvalidArgument(
"When aliased, the input and output buffers should point to the same "
"data.");
}
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(
Aliasing, AliasingImpl,
ffi::Ffi::Bind().Arg<ffi::AnyBuffer>().Ret<ffi::AnyBuffer>());

// Boilerplate for exposing handlers to Python
NB_MODULE(_cpu_examples, m) {
m.def("registrations", []() {
nb::dict registrations;
registrations["array_attr"] =
nb::capsule(reinterpret_cast<void *>(ArrayAttr));
registrations["dictionary_attr"] =
nb::capsule(reinterpret_cast<void *>(DictionaryAttr));

registrations["counter"] = nb::capsule(reinterpret_cast<void *>(Counter));

registrations["aliasing"] = nb::capsule(reinterpret_cast<void *>(Aliasing));

return registrations;
});
}
Loading
Loading