-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add FFI example demonstrating the use of XLA's FFI state.
Support for this was added in JAX v0.5.0. PiperOrigin-RevId: 721785553
- Loading branch information
1 parent
c1e1360
commit fbb02e3
Showing
4 changed files
with
138 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
/* Copyright 2024 The JAX Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include <cstdint> | ||
#include <memory> | ||
|
||
#include "nanobind/nanobind.h" | ||
#include "third_party/gpus/cuda/include/driver_types.h" | ||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h" | ||
#include "xla/ffi/api/ffi.h" | ||
|
||
namespace nb = nanobind; | ||
namespace ffi = xla::ffi; | ||
|
||
struct State { | ||
static xla::ffi::TypeId id; | ||
explicit State(int32_t value) : value(value) {} | ||
int32_t value; | ||
}; | ||
ffi::TypeId State::id = {}; | ||
|
||
static ffi::ErrorOr<std::unique_ptr<State>> StateInstantiate() { | ||
return std::make_unique<State>(42); | ||
} | ||
|
||
static ffi::Error StateExecute(cudaStream_t stream, State* state, | ||
ffi::ResultBufferR0<ffi::S32> out) { | ||
cudaMemcpyAsync(out->typed_data(), &state->value, sizeof(int32_t), | ||
cudaMemcpyHostToDevice, stream); | ||
cudaStreamSynchronize(stream); | ||
return ffi::Error::Success(); | ||
} | ||
|
||
XLA_FFI_DEFINE_HANDLER(kStateInstantiate, StateInstantiate, | ||
ffi::Ffi::BindInstantiate()); | ||
XLA_FFI_DEFINE_HANDLER(kStateExecute, StateExecute, | ||
ffi::Ffi::Bind() | ||
.Ctx<ffi::PlatformStream<cudaStream_t>>() | ||
.Ctx<ffi::State<State>>() | ||
.Ret<ffi::BufferR0<ffi::S32>>()); | ||
|
||
NB_MODULE(_gpu_state, m) { | ||
m.def("type_id", | ||
[]() { return nb::capsule(reinterpret_cast<void*>(&State::id)); }); | ||
m.def("handler", []() { | ||
nb::dict d; | ||
d["instantiate"] = nb::capsule(reinterpret_cast<void*>(kStateInstantiate)); | ||
d["execute"] = nb::capsule(reinterpret_cast<void*>(kStateExecute)); | ||
return d; | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright 2024 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
|
||
from jax_ffi_example import _gpu_examples | ||
|
||
jax.ffi.register_ffi_target("state", _gpu_examples.handler(), platform="CUDA") | ||
jax.ffi.register_ffi_type_id("state", _gpu_examples.type_id(), platform="CUDA") | ||
|
||
def read_state(): | ||
return jax.ffi.ffi_call( | ||
"state", jax.ShapeDtypeStruct((), jnp.int32), has_side_effect=True | ||
)() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright 2025 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from absl.testing import absltest | ||
import jax | ||
from jax._src import test_util as jtu | ||
|
||
jax.config.parse_flags_with_absl() | ||
|
||
|
||
class GpuExamplesTest(jtu.JaxTestCase): | ||
|
||
|
||
def setUp(self): | ||
super().setUp() | ||
if not jtu.test_device_matches(["cuda"]): | ||
self.skipTest("Unsupported platform") | ||
|
||
# Import here to avoid trying to load the library when it's not built. | ||
from jax_ffi_example import gpu_examples # pylint: disable=g-import-not-at-top | ||
|
||
self.read_state = gpu_examples.read_state | ||
|
||
def test_basic(self): | ||
self.assertEqual(self.read_state(), 42) | ||
|
||
|
||
if __name__ == "__main__": | ||
absltest.main(testLoader=jtu.JaxTestLoader()) |