Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG FIX][FEATURE] Rework backend & device selection logic #568

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions genesis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def fake_print(*args, **kwargs):
from .constants import backend as gs_backend
from .logging import Logger
from .version import __version__
from .utils import set_random_seed, get_platform, get_cpu_device, get_gpu_device
from .utils import set_random_seed, get_platform, get_device

_initialized = False
backend = None
Expand All @@ -43,7 +43,6 @@ def init(
theme="dark",
logger_verbose_time=False,
):

# genesis._initialized
global _initialized
if _initialized:
Expand Down Expand Up @@ -72,19 +71,16 @@ def init(

first_init = False

# get default device and compute total device memory
# genesis.backend
global platform
global device
platform = get_platform()
if backend == gs_backend.cpu:
device, device_name, total_mem = get_cpu_device()
else:
device, device_name, total_mem = get_gpu_device()

# genesis.backend
if backend not in GS_ARCH[platform]:
raise_exception(f"backend ~~<{backend}>~~ not supported for platform ~~<{platform}>~~")
backend = GS_ARCH[platform][backend]

# get default device and compute total device memory
global device
device, device_name, total_mem, backend = get_device(backend)

_globalize_backend(backend)

logger.info(
Expand Down
48 changes: 32 additions & 16 deletions genesis/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch

import genesis as gs
from genesis.constants import backend as gs_backend


def raise_exception(msg="Something went wrong."):
Expand Down Expand Up @@ -96,32 +97,47 @@ def get_cpu_name():
return platform.processor()


def get_cpu_device():
device_name = get_cpu_name()
total_mem = psutil.virtual_memory().total / 1024**3
device = torch.device("cpu")
return device, device_name, total_mem
def get_device(backend: gs_backend):
    """Resolve a requested backend to a concrete torch device.

    Parameters
    ----------
    backend : gs_backend
        Requested backend: ``cuda``, ``metal``, ``vulkan``, ``gpu`` (auto-pick
        the best GPU backend), or anything else (treated as cpu).

    Returns
    -------
    tuple
        ``(device, device_name, total_mem, backend)`` where ``device`` is a
        ``torch.device``, ``device_name`` is a human-readable device name,
        ``total_mem`` is the total device memory in GiB, and ``backend`` is
        the backend actually selected (refined when ``gpu`` was requested).

    Raises
    ------
    Exception (via ``gs.raise_exception``)
        If an explicitly requested cuda/metal device is not available.
    """
    if backend == gs_backend.cuda:
        if not torch.cuda.is_available():
            gs.raise_exception("cuda device not available")

        device = torch.device("cuda")
        device_property = torch.cuda.get_device_properties(0)
        device_name = device_property.name
        total_mem = device_property.total_memory / 1024**3

    elif backend == gs_backend.metal:
        if not torch.backends.mps.is_available():
            gs.raise_exception("metal device not available")

        # on mac, cpu and gpu are in the same device (unified memory),
        # so reuse the cpu name / memory stats for the mps device
        _, device_name, total_mem, _ = get_device(gs_backend.cpu)
        device = torch.device("mps")

    elif backend == gs_backend.vulkan:
        # torch.xpu only exists on pytorch 2.5+; guard with hasattr so older
        # torch versions fall back to cpu tensors instead of raising
        # AttributeError on attribute access.
        if hasattr(torch, "xpu") and torch.xpu.is_available():  # pytorch 2.5+ Intel XPU device
            device = torch.device("xpu")
            device_property = torch.xpu.get_device_properties(0)
            device_name = device_property.name
            total_mem = device_property.total_memory / 1024**3
        else:  # pytorch tensors on cpu
            device, device_name, total_mem, _ = get_device(gs_backend.cpu)

    elif backend == gs_backend.gpu:
        # generic "gpu" request: prefer cuda, then metal on macOS, else vulkan
        if torch.cuda.is_available():
            return get_device(gs_backend.cuda)
        elif get_platform() == "macOS":
            return get_device(gs_backend.metal)
        else:
            return get_device(gs_backend.vulkan)

    else:
        # cpu fallback: report the host processor and total system RAM (GiB)
        device_name = get_cpu_name()
        total_mem = psutil.virtual_memory().total / 1024**3
        device = torch.device("cpu")

    return device, device_name, total_mem, backend


def get_src_dir():
Expand Down