Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FFI example demonstrating the use of XLA's FFI state. #26235

Merged
merged 1 commit into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions examples/ffi/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,29 @@ message(STATUS "XLA include directory: ${XLA_DIR}")
find_package(nanobind CONFIG REQUIRED)

set(
JAX_FFI_EXAMPLE_PROJECTS
JAX_FFI_EXAMPLE_CPU_PROJECTS
"rms_norm"
"cpu_examples"
)

foreach(PROJECT ${JAX_FFI_EXAMPLE_PROJECTS})
foreach(PROJECT ${JAX_FFI_EXAMPLE_CPU_PROJECTS})
nanobind_add_module("_${PROJECT}" NB_STATIC "src/jax_ffi_example/${PROJECT}.cc")
target_include_directories("_${PROJECT}" PUBLIC ${XLA_DIR})
install(TARGETS "_${PROJECT}" LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
endforeach()

if(JAX_FFI_EXAMPLE_ENABLE_CUDA)
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)

add_library(_cuda_examples SHARED "src/jax_ffi_example/cuda_examples.cu")
set_target_properties(_cuda_examples PROPERTIES POSITION_INDEPENDENT_CODE ON
CUDA_STANDARD 17)
target_include_directories(_cuda_examples PUBLIC ${XLA_DIR})
install(TARGETS _cuda_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})

nanobind_add_module(_gpu_examples NB_STATIC "src/jax_ffi_example/gpu_examples.cc")
target_include_directories(_gpu_examples PUBLIC ${XLA_DIR})
target_link_libraries(_gpu_examples PRIVATE CUDA::cudart)
install(TARGETS _gpu_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
endif()
62 changes: 62 additions & 0 deletions examples/ffi/src/jax_ffi_example/gpu_examples.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/* Copyright 2025 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <memory>

#include "nanobind/nanobind.h"
#include "cuda_runtime_api.h"
#include "xla/ffi/api/ffi.h"

namespace nb = nanobind;
namespace ffi = xla::ffi;

// Per-execution-context state managed by the XLA runtime. One instance is
// created by StateInstantiate the first time the handler runs, and the same
// instance is handed back to StateExecute on every subsequent call.
struct State {
// Unique type id used by XLA to look this state type up; its address is
// exported to Python through the `type_id` capsule below.
static xla::ffi::TypeId id;
explicit State(int32_t value) : value(value) {}
// The value StateExecute copies to the device output buffer.
int32_t value;
};
// Definition of the static type id (zero-initialized; XLA assigns the actual
// id during registration).
ffi::TypeId State::id = {};

// Instantiate callback: invoked once by the XLA runtime to create the state
// object it will own and pass to StateExecute. The example seeds it with 42.
static ffi::ErrorOr<std::unique_ptr<State>> StateInstantiate() {
  auto state = std::make_unique<State>(42);
  return state;
}

// Execute callback: copies the host-side `state->value` into the device
// output buffer on the stream XLA provides, then synchronizes so the copy of
// the stack-independent host value has completed before returning.
//
// Returns ffi::Error::Success() on success, or an Internal error carrying the
// CUDA error string if either runtime call fails (the original ignored both
// return codes, so a failed copy would have been reported as success).
static ffi::Error StateExecute(cudaStream_t stream, State* state,
                               ffi::ResultBufferR0<ffi::S32> out) {
  cudaError_t status =
      cudaMemcpyAsync(out->typed_data(), &state->value, sizeof(int32_t),
                      cudaMemcpyHostToDevice, stream);
  if (status != cudaSuccess) {
    return ffi::Error::Internal(cudaGetErrorString(status));
  }
  status = cudaStreamSynchronize(stream);
  if (status != cudaSuccess) {
    return ffi::Error::Internal(cudaGetErrorString(status));
  }
  return ffi::Error::Success();
}

// Wrap the callbacks as XLA FFI handlers. The instantiate handler takes no
// arguments; the execute handler binds the platform stream, the runtime-owned
// State, and a rank-0 int32 result buffer — the order here must match the
// parameter order of StateExecute above.
XLA_FFI_DEFINE_HANDLER(kStateInstantiate, StateInstantiate,
ffi::Ffi::BindInstantiate());
XLA_FFI_DEFINE_HANDLER(kStateExecute, StateExecute,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<cudaStream_t>>()
.Ctx<ffi::State<State>>()
.Ret<ffi::BufferR0<ffi::S32>>());

// Python bindings: expose the state type id and the handler bundle as opaque
// capsules so the Python side can register them with jax.ffi.
NB_MODULE(_gpu_examples, m) {
  m.def("type_id",
        []() { return nb::capsule(reinterpret_cast<void*>(&State::id)); });
  m.def("handler", []() {
    nb::dict handlers;
    handlers["instantiate"] =
        nb::capsule(reinterpret_cast<void*>(kStateInstantiate));
    handlers["execute"] = nb::capsule(reinterpret_cast<void*>(kStateExecute));
    return handlers;
  });
}
24 changes: 24 additions & 0 deletions examples/ffi/src/jax_ffi_example/gpu_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
from jax_ffi_example import _gpu_examples
import jax.numpy as jnp

# Register the handler bundle (instantiate + execute capsules) and the state
# type id with JAX under the name "state", for the CUDA platform only.
jax.ffi.register_ffi_target("state", _gpu_examples.handler(), platform="CUDA")
jax.ffi.register_ffi_type_id("state", _gpu_examples.type_id(), platform="CUDA")


def read_state():
  """Reads the int32 value held in the FFI handler's state.

  Returns the scalar that the C++ instantiate callback stored (42 in this
  example); the value is copied from host state to a device buffer by the
  "state" FFI target registered above.
  """
  return jax.ffi.ffi_call("state", jax.ShapeDtypeStruct((), jnp.int32))()
41 changes: 41 additions & 0 deletions examples/ffi/tests/gpu_examples_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
import jax
from jax._src import test_util as jtu

jax.config.parse_flags_with_absl()


class GpuExamplesTest(jtu.JaxTestCase):
  """Tests for the FFI state example; requires a CUDA device."""

  def setUp(self):
    super().setUp()
    if not jtu.test_device_matches(["cuda"]):
      self.skipTest("Unsupported platform")

    # Imported lazily so the extension is only loaded when it has been built.
    from jax_ffi_example import gpu_examples  # pylint: disable=g-import-not-at-top

    self.read_state = gpu_examples.read_state

  def test_basic(self):
    # The instantiate callback stores 42; both eager and jitted calls see it.
    self.assertEqual(self.read_state(), 42)
    self.assertEqual(jax.jit(self.read_state)(), 42)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
Loading