From b8c47bca8418d863a5a9edcce35c962b764933cc Mon Sep 17 00:00:00 2001 From: Sam Armstrong <88863522+Sam-Armstrong@users.noreply.github.com> Date: Tue, 29 Oct 2024 21:10:31 +0000 Subject: [PATCH] create prototype of kornia.set_backend --- kornia/__init__.py | 2 +- kornia/transpiler/__init__.py | 4 ++-- kornia/transpiler/transpiler.py | 39 +++++++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/kornia/__init__.py b/kornia/__init__.py index 15c6c8a78c..001e49a9a4 100644 --- a/kornia/__init__.py +++ b/kornia/__init__.py @@ -25,7 +25,7 @@ ) # Multi-framework support using ivy -from .transpiler import to_jax, to_numpy, to_tensorflow +from .transpiler import set_backend, to_jax, to_numpy, to_tensorflow # NOTE: we are going to expose to top level very few things from kornia.constants import pi diff --git a/kornia/transpiler/__init__.py b/kornia/transpiler/__init__.py index 5344ce6480..55b0f06553 100644 --- a/kornia/transpiler/__init__.py +++ b/kornia/transpiler/__init__.py @@ -1,3 +1,3 @@ -from .transpiler import to_jax, to_numpy, to_tensorflow +from .transpiler import set_backend, to_jax, to_numpy, to_tensorflow -__all__ = ["to_jax", "to_numpy", "to_tensorflow"] +__all__ = ["set_backend", "to_jax", "to_numpy", "to_tensorflow"] diff --git a/kornia/transpiler/transpiler.py b/kornia/transpiler/transpiler.py index 4d002140a2..27897ca197 100644 --- a/kornia/transpiler/transpiler.py +++ b/kornia/transpiler/transpiler.py @@ -80,3 +80,42 @@ def to_tensorflow(): source="torch", target="tensorflow", ) + + +def set_backend(backend: str = "torch") -> None: + """Converts Kornia to the chosen backend framework inplace. + + Transpiles the Kornia library to the chosen backend framework using [ivy](https://github.com/ivy-llc/ivy). + The transpilation process occurs lazily, so the transpilation on a given kornia function/class will only + occur when it's called or instantiated for the first time. This will make any functions/classes slow when + being used for the first time, but any subsequent uses should be as fast as expected. + + Args: + backend (str, optional): The backend framework to transpile Kornia to. + Must be one of ["jax", "numpy", "tensorflow", "torch"]. + Defaults to "torch". + + Example: + >>> import kornia + >>> kornia.set_backend("tensorflow") + >>> import tensorflow as tf + >>> input = tf.random.normal((2, 3, 4, 5)) + >>> gray = kornia.color.gray.rgb_to_grayscale(input) + """ + import sys + + kornia_module = sys.modules["kornia"] + backend = backend.lower() + + assert backend in ["jax", "numpy", "tensorflow", "torch"], 'Backend framework must be one of "jax", "numpy", "tensorflow", or "torch"' + + ivy.transpile( + kornia_module, + source="torch", + target=backend, + inplace=True, # TODO: add this functionality to ivy + ) + + # TODO: unwrap and re-wrap the kornia module if, say, it's already converted to jax and the user wants to convert it to tensorflow + # TODO: ensure that torch -> torch works fine by returning the existing module + # TODO: ensure that framework -> torch works fine by unwrapping the existing module