"RuntimeError: Unknown backend iree" #77

Open
WoongQ opened this issue Aug 8, 2023 · 2 comments

WoongQ commented Aug 8, 2023

Hello, I'm trying to install iree-jax to test GPT-2 on IREE. After running `python -m pip install -e '.[test,xla,cpu]' -f https://openxla.github.io/iree/pip-release-links.html`, I built jaxlib from source. However, when I run `lit -v tests/`, I get a `RuntimeError` with the message "Unknown backend iree". The same error occurs when running `models/gpt2/test_jax.py`. Did I miss something during the setup process? Any help would be greatly appreciated. The error log is attached below.

Using pure python filecheck: /home/woongq/jax/bin/filecheck
-- Testing: 5 tests, 5 workers --
FAIL: IREE_JAX :: program/trivial_kernel.py (1 of 5)
******************** TEST 'IREE_JAX :: program/trivial_kernel.py' FAILED ********************
Script:
--
: 'RUN: at line 15';   /home/woongq/jax/bin/python /home/woongq/iree-jax/tests/program/trivial_kernel.py | /home/woongq/jax/bin/filecheck /home/woongq/iree-jax/tests/program/trivial_kernel.py
--
Exit Code: 2

Command Output (stdout):
--
$ ":" "RUN: at line 15"
$ "/home/woongq/jax/bin/python" "/home/woongq/iree-jax/tests/program/trivial_kernel.py"
# command stderr:
WARNING:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002429485321044922 sec
DEBUG:jax._src.xla_bridge:Initializing backend 'cpu'
DEBUG:jax._src.xla_bridge:Backend 'cpu' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'cuda'
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'rocm'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'tpu'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING:jax._src.dispatch:Finished tracing + transforming jit(broadcast_in_dim) in 0.0002300739288330078 sec
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
WARNING:jax._src.dispatch:Finished jaxpr to MLIR module conversion jit(broadcast_in_dim) in 0.001964092254638672 sec
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
WARNING:jax._src.dispatch:Finished XLA compilation of jit(broadcast_in_dim) in 0.012798309326171875 sec
WARNING:jax._src.dispatch:Finished tracing + transforming fn for pjit in 0.0004911422729492188 sec
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[3,4]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
WARNING:jax._src.dispatch:Finished jaxpr to MLIR module conversion jit(fn) in 0.0016129016876220703 sec
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
WARNING:jax._src.dispatch:Finished XLA compilation of jit(fn) in 0.0081329345703125 sec
DEBUG:iree_jax:Create new Program subclass: trivial_kernel
DEBUG:root:DEFINE PY_ONLY: _linear = <Exportable Pure Func: <function TrivialKernel._linear at 0x7f91ee93ce50>>
DEBUG:iree_jax:def_global_tree: array _params$0=(3, 4):dtype('float32')
DEBUG:iree_jax:def_global_tree: array _params$1=(3, 4):dtype('float32')
DEBUG:iree_jax:def_global_tree: new tree=Params(x=ConcreteArray(ExportedGlobalArray(@_params$0 : tensor<3x4xf32>), dtype=float32), b=ConcreteArray(ExportedGlobalArray(@_params$1 : tensor<3x4xf32>), dtype=float32))
DEBUG:iree_jax:def_global_tree: array _x$0=(3, 4):dtype('float32')
DEBUG:iree_jax:def_global_tree: new tree=ExportedGlobalArray(@_params$0 : tensor<3x4xf32>)
Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/trivial_kernel.py", line 61, in <module>
    m = TrivialKernel()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/trivial_kernel.py", line 48, in run
    result = self._linear(multiplier, self._params.x, self._params.b)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
    donate_argnums) = infer_params_fn(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
    in_shardings = out_shardings = _create_sharding_with_device_backend(
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
    xb.get_backend(backend).get_default_device_assignment(1)[0])
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
    return _get_backend_uncached(platform)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
    raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/trivial_kernel.py", line 61, in <module>
    m = TrivialKernel()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/trivial_kernel.py", line 48, in run
    result = self._linear(multiplier, self._params.x, self._params.b)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree
nanobind: leaked 66 instances!
nanobind: leaked 16 types!
 - leaked type "iree._runtime.VmVariantList"
 - leaked type "iree._runtime.HalBufferView"
 - leaked type "iree._runtime.BufferUsage"
 - leaked type "iree._runtime.VmContext"
 - leaked type "iree._runtime.MappedMemory"
 - leaked type "iree._runtime.ArgumentPacker"
 - leaked type "iree._runtime.HalElementType"
 - leaked type "iree._runtime.VmRef"
 - leaked type "iree._runtime.VmModule"
 - leaked type "iree._runtime.HalDevice"
 - leaked type "iree._runtime._InvokeStatics"
 - ... skipped remainder
nanobind: leaked 78 functions!
 - leaked function ""
 - leaked function "lookup_function"
 - leaked function "__eq__"
 - leaked function ""
 - leaked function "__iree_vm_type__"
 - leaked function "__or__"
 - leaked function "__init__"
 - leaked function "create_device_by_uri"
 - leaked function ""
 - leaked function "invoke"
 - leaked function "__init__"
 - ... skipped remainder
nanobind: this is likely caused by a reference counting issue in the binding code.

error: command failed with exit status: 1
$ "/home/woongq/jax/bin/filecheck" "/home/woongq/iree-jax/tests/program/trivial_kernel.py"
# command output:
CHECK: FileCheck error: '-' is empty.
FileCheck command line: /home/woongq/iree-jax/tests/program/trivial_kernel.py

error: command failed with exit status: 2

--

********************
FAIL: IREE_JAX :: program/fft.py (2 of 5)
******************** TEST 'IREE_JAX :: program/fft.py' FAILED ********************
Script:
--
: 'RUN: at line 15';   /home/woongq/jax/bin/python /home/woongq/iree-jax/tests/program/fft.py | /home/woongq/jax/bin/filecheck /home/woongq/iree-jax/tests/program/fft.py
--
Exit Code: 2

Command Output (stdout):
--
$ ":" "RUN: at line 15"
$ "/home/woongq/jax/bin/python" "/home/woongq/iree-jax/tests/program/fft.py"
# command stderr:
DEBUG:iree_jax:Create new Program subclass: f_f_t
DEBUG:root:DEFINE PY_ONLY: _fft = <Exportable Pure Func: <function FFT._fft at 0x7f92544a2290>>
DEBUG:jax._src.xla_bridge:Initializing backend 'cpu'
DEBUG:jax._src.xla_bridge:Backend 'cpu' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'cuda'
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'rocm'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'tpu'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/fft.py", line 41, in <module>
    m = FFT()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/fft.py", line 33, in fft
    return self._fft(x)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
    donate_argnums) = infer_params_fn(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
    in_shardings = out_shardings = _create_sharding_with_device_backend(
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
    xb.get_backend(backend).get_default_device_assignment(1)[0])
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
    return _get_backend_uncached(platform)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
    raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/fft.py", line 41, in <module>
    m = FFT()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/fft.py", line 33, in fft
    return self._fft(x)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree

error: command failed with exit status: 1
$ "/home/woongq/jax/bin/filecheck" "/home/woongq/iree-jax/tests/program/fft.py"
# command output:
CHECK: FileCheck error: '-' is empty.
FileCheck command line: /home/woongq/iree-jax/tests/program/fft.py

error: command failed with exit status: 2

--

********************
PASS: IREE_JAX :: program/trivial_globals.py (3 of 5)
FAIL: IREE_JAX :: program/duplicate_helper.py (4 of 5)
******************** TEST 'IREE_JAX :: program/duplicate_helper.py' FAILED ********************
Script:
--
: 'RUN: at line 1';   /home/woongq/jax/bin/python /home/woongq/iree-jax/tests/program/duplicate_helper.py
--
Exit Code: 1

Command Output (stdout):
--
$ ":" "RUN: at line 1"
$ "/home/woongq/jax/bin/python" "/home/woongq/iree-jax/tests/program/duplicate_helper.py"
# command stderr:
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/duplicate_helper.py", line 67, in <module>
    print(str(Program.get_mlir_module(module)))
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 377, in get_mlir_module
    info = Program.get_info(Program._get_instance(m))
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 372, in _get_instance
    m = m()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/duplicate_helper.py", line 50, in encode
    return mdl._encode(x, y)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
    donate_argnums) = infer_params_fn(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
    in_shardings = out_shardings = _create_sharding_with_device_backend(
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
    xb.get_backend(backend).get_default_device_assignment(1)[0])
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
    return _get_backend_uncached(platform)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
    raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/duplicate_helper.py", line 67, in <module>
    print(str(Program.get_mlir_module(module)))
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 377, in get_mlir_module
    info = Program.get_info(Program._get_instance(m))
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 372, in _get_instance
    m = m()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/duplicate_helper.py", line 50, in encode
    return mdl._encode(x, y)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree

error: command failed with exit status: 1

--

********************
FAIL: IREE_JAX :: program/program_api_test.py (5 of 5)
******************** TEST 'IREE_JAX :: program/program_api_test.py' FAILED ********************
Script:
--
: 'RUN: at line 1';   /home/woongq/jax/bin/python /home/woongq/iree-jax/tests/program/program_api_test.py
--
Exit Code: 1

Command Output (stdout):
--
$ ":" "RUN: at line 1"
$ "/home/woongq/jax/bin/python" "/home/woongq/iree-jax/tests/program/program_api_test.py"
# command stderr:
.DEBUG:iree_jax:Create new Program subclass: hidden
.DEBUG:iree_jax:Create new Program subclass: nullary
DEBUG:iree_jax:Create new Program subclass: unary
.DEBUG:iree_jax:Create new Program subclass: Foobar
.DEBUG:iree_jax:Create new Program subclass: error
.DEBUG:iree_jax:Create new Program subclass: error
.DEBUG:iree_jax:Create new Program subclass: error
.DEBUG:iree_jax:Create new Program subclass: error
.DEBUG:iree_jax:Create new Program subclass: global
.DEBUG:iree_jax:Create new Program subclass: my_subclass
./home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py:288: DeprecationWarning: backend and device argument on jit is deprecated. You can use a `jax.sharding.Mesh` context manager or device_put the arguments before passing them to `jit`. Please see https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html for more information.
  warnings.warn(
DEBUG:iree_jax:Create new Program subclass: iree_jax
DEBUG:root:DEFINE PY_ONLY: _f = <Exportable Pure Func: <function ProgramApiTest.test_value_tracing_with_flax_frozen_dict.<locals>.IreeJaxProgram._f at 0x7f673b4e7760>>
DEBUG:jax._src.xla_bridge:Initializing backend 'cpu'
DEBUG:jax._src.xla_bridge:Backend 'cpu' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'cuda'
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'rocm'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'tpu'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
EDEBUG:iree_jax:Create new Program subclass: iree_jax
DEBUG:root:DEFINE PY_ONLY: _f = <Exportable Pure Func: <function ProgramApiTest.test_value_tracing_with_list.<locals>.IreeJaxProgram._f at 0x7f673b5384c0>>
E
======================================================================
ERROR: test_value_tracing_with_flax_frozen_dict (__main__.ProgramApiTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 163, in <module>
    unittest.main()
  File "/usr/lib/python3.10/unittest/main.py", line 101, in __init__
    self.runTests()
  File "/usr/lib/python3.10/unittest/main.py", line 271, in runTests
    self.result = testRunner.run(self.test)
  File "/usr/lib/python3.10/unittest/runner.py", line 184, in run
    test(result)
  File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/usr/lib/python3.10/unittest/case.py", line 650, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/usr/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 145, in test_value_tracing_with_flax_frozen_dict
    IreeJaxProgram()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 139, in f
    return self._f(x)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
    donate_argnums) = infer_params_fn(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
    in_shardings = out_shardings = _create_sharding_with_device_backend(
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
    xb.get_backend(backend).get_default_device_assignment(1)[0])
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
    return _get_backend_uncached(platform)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
    raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 145, in test_value_tracing_with_flax_frozen_dict
    IreeJaxProgram()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 139, in f
    return self._f(x)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree

======================================================================
ERROR: test_value_tracing_with_list (__main__.ProgramApiTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 163, in <module>
    unittest.main()
  File "/usr/lib/python3.10/unittest/main.py", line 101, in __init__
    self.runTests()
  File "/usr/lib/python3.10/unittest/main.py", line 271, in runTests
    self.result = testRunner.run(self.test)
  File "/usr/lib/python3.10/unittest/runner.py", line 184, in run
    test(result)
  File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/usr/lib/python3.10/unittest/case.py", line 650, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/usr/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 159, in test_value_tracing_with_list
    IreeJaxProgram()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 153, in f
    return self._f(x)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
    donate_argnums) = infer_params_fn(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
    in_shardings = out_shardings = _create_sharding_with_device_backend(
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
    xb.get_backend(backend).get_default_device_assignment(1)[0])
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
    return _get_backend_uncached(platform)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
    raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 159, in test_value_tracing_with_list
    IreeJaxProgram()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 153, in f
    return self._f(x)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree

----------------------------------------------------------------------
Ran 12 tests in 0.035s

FAILED (errors=2)

error: command failed with exit status: 1

--

********************
********************
Failed Tests (4):
  IREE_JAX :: program/duplicate_helper.py
  IREE_JAX :: program/fft.py
  IREE_JAX :: program/program_api_test.py
  IREE_JAX :: program/trivial_kernel.py


Testing Time: 0.73s
  Passed: 1
  Failed: 4

okkwon (Member) commented Aug 8, 2023

https://github.com/openxla/openxla-pjrt-plugin is the right way to use JAX+IREE.

ScottTodd (Member) commented:

> https://github.com/openxla/openxla-pjrt-plugin is the right way to use JAX+IREE.

The PJRT plugin is one way to use JAX+IREE, mostly for JIT scenarios from Python. This repository is another way, with a focus on AOT scenarios outside of Python. See https://openxla.github.io/iree/guides/ml-frameworks/jax/

> Did I miss something during the setup process?

Possibly. Take a look at what https://github.com/iree-org/iree-jax/blob/main/.github/workflows/test_gpt2_model.yaml does; that workflow runs nightly at https://github.com/iree-org/iree-jax/actions.
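A first step in comparing a local setup against that workflow is recording which versions of the relevant packages are actually installed. Below is a minimal sketch using only the standard library's `importlib.metadata`; the distribution names in `DEFAULT_PACKAGES` are assumptions about what is worth checking here, not a list taken from the workflow file.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Assumed distribution names; adjust to whatever the workflow actually installs.
DEFAULT_PACKAGES = ("jax", "jaxlib", "iree-compiler", "iree-runtime")


def installed_versions(
    packages: Iterable[str] = DEFAULT_PACKAGES,
) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:15} {version or 'not installed'}")
```

Running this in the same virtualenv that `lit` uses and comparing the output with the versions the workflow installs may show whether a mismatch (for example, a jaxlib built from source at a different revision than the installed jax) is involved.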
