From 78ad9c97dab600a2a378a66bde43b6440182a6ab Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 26 Nov 2024 13:27:23 -0500 Subject: [PATCH] docs: update exporting_to_jax.md (#1107) --- docs/src/manual/exporting_to_jax.md | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/docs/src/manual/exporting_to_jax.md b/docs/src/manual/exporting_to_jax.md index 50014b99e..dc23905a1 100644 --- a/docs/src/manual/exporting_to_jax.md +++ b/docs/src/manual/exporting_to_jax.md @@ -59,7 +59,7 @@ end Now we define a python script to run the model using EnzymeJAX. ```python -from enzyme_ad.jax import primitives +from enzyme_ad.jax import hlo_call import jax import jax.numpy as jnp @@ -81,7 +81,7 @@ def run_lux_model( weight6_3, bias6_3, ): - return primitives.ffi_call( + return hlo_call( x, weight1, bias1, @@ -93,13 +93,7 @@ def run_lux_model( bias6_2, weight6_3, bias6_3, - out_shapes=[ - jax.core.ShapedArray([4, 10], jnp.float32), - ], - fn="main", source=code, - lang=primitives.LANG_MHLO, - pipeline_options=primitives.JaXPipeline(""), )