Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prototype of kornia.set_backend #6

Open
wants to merge 1 commit into
base: set-backend
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kornia/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)

# Multi-framework support using ivy
from .transpiler import to_jax, to_numpy, to_tensorflow
from .transpiler import set_backend, to_jax, to_numpy, to_tensorflow

# NOTE: we are going to expose to top level very few things
from kornia.constants import pi
Expand Down
4 changes: 2 additions & 2 deletions kornia/transpiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .transpiler import to_jax, to_numpy, to_tensorflow
from .transpiler import set_backend, to_jax, to_numpy, to_tensorflow

__all__ = ["to_jax", "to_numpy", "to_tensorflow"]
__all__ = ["set_backend", "to_jax", "to_numpy", "to_tensorflow"]
39 changes: 39 additions & 0 deletions kornia/transpiler/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,42 @@ def to_tensorflow():
source="torch",
target="tensorflow",
)


def set_backend(backend: str = "torch") -> None:
"""Converts Kornia to the chosen backend framework inplace.

Transpiles the Kornia library to the chosen backend framework using [ivy](https://github.com/ivy-llc/ivy).
The transpilation process occurs lazily, so the transpilation on a given kornia function/class will only
occur when it's called or instantiated for the first time. This will make any functions/classes slow when
being used for the first time, but any subsequent uses should be as fast as expected.

Args:
backend (str, optional): The backend framework to transpile Kornia to.
Must be one of ["jax", "numpy", "tensorflow", "torch"].
Defaults to "torch".

Example:
>>> import kornia
>>> kornia.set_backend("tensorflow")
>>> import tensorflow as tf
>>> input = tf.random.normal((2, 3, 4, 5))
>>> gray = kornia.color.gray.rgb_to_grayscale(input)
"""
import sys

kornia_module = sys.modules["kornia"]
backend = backend.lower()

assert backend in ["jax", "numpy", "tensorflow", "torch"], 'Backend framework must be one of "jax", "numpy", "tensorflow", or "torch"'

ivy.transpile(
kornia_module,
source="torch",
target=backend,
inplace=True, # TODO: add this functionality to ivy
)

# TODO: unwrap and re-wrap the kornia module if, say, it's already converted to jax and the user wants to convert it to tensorflow
# TODO: ensure that torch -> torch works fine by returning the existing module
# TODO: ensure that framework -> torch works fine by unwrapping the existing module

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does this point refer to exactly? for kornia, isn't source=torch always going to be the case?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is referring to the case that a user does something like:

import  kornia
kornia.set_backend("tensorflow")
kornia.set_backend("torch")

so in this case we'd need to just unwrap the module