Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update xla to use mlir rather than backend-specific-translations #314

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

pseudo-rnd-thoughts
Copy link

Description

Fixes #313

Motivation and Context

EnvPool XLA doesn't work with Jax 0.4.29+

Types of changes

What types of changes does your code introduce? Put an x in all the boxes that apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • New environment (non-breaking change which adds 3rd-party environment)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of example)

@pseudo-rnd-thoughts
Copy link
Author

@Trinkle23897 The linting error is unrelated to the PR

@pseudo-rnd-thoughts
Copy link
Author

Testing this throws errors however there appears to be limited documentation on what the suggested change should be

@JagerHoHo
Copy link

JagerHoHo commented Oct 13, 2024

After implementing the suggested changes, I encountered an error when running the XLA example with JAX 0.4.34(latest). The error reads:

TypeError: CustomCallWithLayout(): incompatible function arguments. The following argument types are supported:
    1. CustomCallWithLayout(builder: jaxlib.xla_extension.XlaBuilder, call_target_name: bytes, operands: Span[jaxlib.xla_extension.XlaOp], shape_with_layout: jaxlib.xla_extension.Shape, operand_shapes_with_layout: Span[jaxlib.xla_extension.Shape], opaque: bytes = b'', has_side_effect: bool = False, schedule: jaxlib.xla_extension.ops.CustomCallSchedule = CustomCallSchedule.SCHEDULE_NONE, api_version: jaxlib.xla_extension.ops.CustomCallApiVersion = CustomCallApiVersion.API_VERSION_ORIGINAL) -> jaxlib.xla_extension.XlaOp
Invoked with types: jax._src.interpreters.mlir.LoweringRuleContext, bytes, kwargs = { operands: tuple, operand_shapes_with_layout: tuple, shape_with_layout: jaxlib.xla_extension.Shape, opaque: bytes, has_side_effect: bool }

I installed envpool using pip install envpool and manually applied the changes as instructed. It seems there may be an issue with how the CustomCallWithLayout is invoked in the current context with the latest JAX.

def translation(c: Any, *args: Any, platform: str = "cpu") -> Any:
output_shape_with_layout = _shape_with_layout(out_specs)
if len(out_specs) == 1:
output_shape = output_shape_with_layout[0]
else:
output_shape = xla_client.Shape.tuple_shape(output_shape_with_layout)
return xla_client.ops.CustomCallWithLayout(
c,
f"{type(obj).__name__}_{id(obj)}_{name}_{platform}".encode(),
operands=args,
operand_shapes_with_layout=_shape_with_layout(in_specs),
shape_with_layout=output_shape,
opaque=handle,
has_side_effect=True,
)

Would appreciate any guidance on resolving this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] XLA is incompatible with jax 0.4.29
2 participants