From 93791f9eacd1caa2443fce54e4be89bb4ff4e9ce Mon Sep 17 00:00:00 2001 From: chenli Date: Thu, 7 Mar 2024 02:14:15 -0800 Subject: [PATCH] Enable Apple Silicon GPU Acceleration for d2l --- d2l/tensorflow.py | 16 +++++++++++++++- d2l/torch.py | 17 +++++++++++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/d2l/tensorflow.py b/d2l/tensorflow.py index fd9ca23fda..cc4ee915cf 100644 --- a/d2l/tensorflow.py +++ b/d2l/tensorflow.py @@ -3,6 +3,7 @@ import numpy as np import tensorflow as tf +import platform nn_Module = tf.keras.Model @@ -479,11 +480,24 @@ def cpu(): Defined in :numref:`sec_use_gpu`""" return tf.device('/CPU:0') + +def is_apple_silicon(): + """Check if the code is running on Apple devices + with Apple Silicon. + + Return True only if the code is running on Apple Silicon devices""" + return 'Mac' in platform.uname().node and 'arm' in platform.uname().machine + def gpu(i=0): """Get a GPU device. Defined in :numref:`sec_use_gpu`""" - return tf.device(f'/GPU:{i}') + if is_apple_silicon(): + if num_gpus() == 0: + raise RuntimeError('Install TensorFlow-Metal!') + return tf.device(f'/physical_device:GPU:{i}') + else: + return tf.device(f'/GPU:{i}') def num_gpus(): """Get the number of available GPUs. diff --git a/d2l/torch.py b/d2l/torch.py index 84ce7da901..f953515b2e 100644 --- a/d2l/torch.py +++ b/d2l/torch.py @@ -517,17 +517,30 @@ def cpu(): Defined in :numref:`sec_use_gpu`""" return torch.device('cpu') +def is_apple_silicon(): + """Check if the code is running on Apple devices + with Apple Silicon. + + Return True only if the code is running on Apple Silicon devices""" + return torch.backends.mps.is_available() + def gpu(i=0): """Get a GPU device. Defined in :numref:`sec_use_gpu`""" - return torch.device(f'cuda:{i}') + if is_apple_silicon(): + return torch.device('mps') + else: + return torch.device(f'cuda:{i}') def num_gpus(): """Get the number of available GPUs. Defined in :numref:`sec_use_gpu`""" - return torch.cuda.device_count() + if is_apple_silicon(): + return 1 + else: + return torch.cuda.device_count() def try_gpu(i=0): """Return gpu(i) if exists, otherwise return cpu().