diff --git a/genesis/__init__.py b/genesis/__init__.py
index 37a81f1e..6b21d0de 100644
--- a/genesis/__init__.py
+++ b/genesis/__init__.py
@@ -23,7 +23,7 @@ def fake_print(*args, **kwargs):
 from .constants import backend as gs_backend
 from .logging import Logger
 from .version import __version__
-from .utils import set_random_seed, get_platform, get_cpu_device, get_gpu_device
+from .utils import set_random_seed, get_platform, get_device
 
 _initialized = False
 backend = None
@@ -43,7 +43,6 @@ def init(
     theme="dark",
     logger_verbose_time=False,
 ):
-    # genesis._initialized
     global _initialized
 
     if _initialized:
@@ -72,19 +71,16 @@ def init(
 
     first_init = False
 
-    # get default device and compute total device memory
+    # genesis.backend
     global platform
-    global device
     platform = get_platform()
-    if backend == gs_backend.cpu:
-        device, device_name, total_mem = get_cpu_device()
-    else:
-        device, device_name, total_mem = get_gpu_device()
-
-    # genesis.backend
     if backend not in GS_ARCH[platform]:
         raise_exception(f"backend ~~<{backend}>~~ not supported for platform ~~<{platform}>~~")
-    backend = GS_ARCH[platform][backend]
+
+    # get default device and compute total device memory
+    global device
+    device, device_name, total_mem, backend = get_device(backend)
+
     _globalize_backend(backend)
 
     logger.info(
diff --git a/genesis/utils/misc.py b/genesis/utils/misc.py
index 7ec3843f..1c0082f3 100644
--- a/genesis/utils/misc.py
+++ b/genesis/utils/misc.py
@@ -11,6 +11,7 @@
 import torch
 
 import genesis as gs
+from genesis.constants import backend as gs_backend
 
 
 def raise_exception(msg="Something went wrong."):
@@ -96,32 +97,47 @@ def get_cpu_name():
     return platform.processor()
 
 
-def get_cpu_device():
-    device_name = get_cpu_name()
-    total_mem = psutil.virtual_memory().total / 1024**3
-    device = torch.device("cpu")
-    return device, device_name, total_mem
+def get_device(backend: gs_backend):
+    if backend == gs_backend.cuda:
+        if not torch.cuda.is_available():
+            gs.raise_exception("cuda device not available")
+        device = torch.device("cuda")
+        device_property = torch.cuda.get_device_properties(0)
+        device_name = device_property.name
+        total_mem = device_property.total_memory / 1024**3
 
-
-def get_gpu_device():
-    if get_platform() == "macOS":
+    elif backend == gs_backend.metal:
         if not torch.backends.mps.is_available():
             gs.raise_exception("metal device not available")
         # on mac, cpu and gpu are in the same device
-        _, device_name, total_mem = get_cpu_device()
+        _, device_name, total_mem, _ = get_device(gs_backend.cpu)
         device = torch.device("mps")
-    else:
-        if not torch.cuda.is_available():
-            gs.raise_exception("cuda device not available")
+
+    elif backend == gs_backend.vulkan:
+        if torch.xpu.is_available():  # pytorch 2.5+ Intel XPU device
+            device = torch.device("xpu")
+            device_property = torch.xpu.get_device_properties(0)
+            device_name = device_property.name
+            total_mem = device_property.total_memory / 1024**3
+        else:  # pytorch tensors on cpu
+            device, device_name, total_mem, _ = get_device(gs_backend.cpu)
+
+    elif backend == gs_backend.gpu:
+        if torch.cuda.is_available():
+            return get_device(gs_backend.cuda)
+        elif get_platform() == "macOS":
+            return get_device(gs_backend.metal)
+        else:
+            return get_device(gs_backend.vulkan)
 
-        device = torch.device("cuda")
-        device_property = torch.cuda.get_device_properties(0)
-        device_name = device_property.name
-        total_mem = device_property.total_memory / 1024**3
+    else:
+        device_name = get_cpu_name()
+        total_mem = psutil.virtual_memory().total / 1024**3
+        device = torch.device("cpu")
 
-    return device, device_name, total_mem
+    return device, device_name, total_mem, backend
 
 
 def get_src_dir():
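
Usage sketch (illustrative only, not part of the patch): how the consolidated get_device() behaves after this change, assuming the import paths shown in the diff (genesis.utils re-exports get_device; genesis.constants exposes the backend enum). Passing the generic gpu backend resolves it to a concrete one, and the resolved backend is returned alongside the torch device:

    import torch

    from genesis.constants import backend as gs_backend
    from genesis.utils import get_device

    # `gpu` resolves to cuda when available, metal on macOS, otherwise vulkan
    # (Intel XPU if present, else torch tensors fall back to cpu).
    device, device_name, total_mem, backend = get_device(gs_backend.gpu)
    print(f"{backend}: {device_name}, {total_mem:.1f} GB -> {device}")

    # init() now globalizes this resolved backend directly instead of
    # remapping it via GS_ARCH[platform][backend].
    tensor = torch.zeros(3, device=device)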