Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for MPS on macOS with ARM chips #29

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 219 additions & 0 deletions image_demo.ipynb

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions spiga/demo/analyze/extract/spiga_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ class SPIGAProcessor(pr.Processor):
def __init__(self,
dataset='wflw',
features=('lnd', 'pose'),
gpus=[0]):
gpus=[0],
device='cuda'):

super().__init__()

# Configure and load processor
self.processor_cfg = model_cfg.ModelConfig(dataset)
self.processor = SPIGAFramework(self.processor_cfg, gpus=gpus)
self.processor = SPIGAFramework(self.processor_cfg, gpus=gpus, device=device)

# Define attributes
if 'lnd' in features:
Expand Down
4 changes: 2 additions & 2 deletions spiga/demo/analyze/track/get_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
zoos = [zoo_rs]


def get_tracker(model_name):
def get_tracker(model_name, device='cuda'):
for zoo in zoos:
model = zoo.get_tracker(model_name)
model = zoo.get_tracker(model_name, device=device)
if model is not None:
return model

Expand Down
6 changes: 4 additions & 2 deletions spiga/demo/analyze/track/retinasort/face_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@

class RetinaSortTracker(tracker.Tracker):

def __init__(self, config=cfg.cfg_retinasort):
def __init__(self, config=cfg.cfg_retinasort, device='cuda'):
super().__init__()

self.detector = retinaface.RetinaFaceDetector(model=config['retina']['model_name'],
device=device,
extra_features=config['retina']['extra_features'],
cfg_postreat=config['retina']['postreat'])
cfg_postreat=config['retina']['postreat']
)

self.associator = sort_tracker.Sort(max_age=config['sort']['max_age'],
min_hits=config['sort']['min_hits'],
Expand Down
10 changes: 5 additions & 5 deletions spiga/demo/analyze/track/retinasort/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@
import spiga.demo.analyze.track.retinasort.config as cfg_tr


def get_tracker(model_name):
def get_tracker(model_name, device='cuda'):

# MobileNet Backbone
if model_name == 'RetinaSort':
return tr.RetinaSortTracker()
return tr.RetinaSortTracker(device=device)
# ResNet50 Backbone
if model_name == 'RetinaSort_Res50':
return tr.RetinaSortTracker(cfg_tr.cfg_retinasort_res50)
return tr.RetinaSortTracker(cfg_tr.cfg_retinasort_res50, device=device)
# Config CAV3D: https://ict.fbk.eu/units/speechtek/cav3d/
if model_name == 'RetinaSort_cav3d':
return tr.RetinaSortTracker(cfg_tr.cfg_retinasort_cav3d)
return tr.RetinaSortTracker(cfg_tr.cfg_retinasort_cav3d, device=device)
# Config AV16: https://ict.fbk.eu/units/speechtek/cav3d/
if model_name == 'RetinaSort_av16':
return tr.RetinaSortTracker(cfg_tr.cfg_retinasort_av16)
return tr.RetinaSortTracker(cfg_tr.cfg_retinasort_av16, device=device)

return None

Expand Down
38 changes: 29 additions & 9 deletions spiga/demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def main():
pars.add_argument('--outpath', type=str, default=video_out_path_dft, help='Video output directory')
pars.add_argument('--fps', type=int, default=30, help='Frames per second')
pars.add_argument('--shape', nargs='+', type=int, help='Visualizer shape (W,H)')
pars.add_argument('--device', type=str, default='cuda', help='torch device to use, eg: cpu, cuda, cuda:0, mps, mps:0')
args = pars.parse_args()

if args.shape:
Expand All @@ -44,13 +45,32 @@ def main():
if not args.noview and not args.save:
raise ValueError('No results will be saved neither shown')

video_app(args.input, spiga_dataset=args.dataset, tracker=args.tracker, fps=args.fps,
save=args.save, output_path=args.outpath, video_shape=video_shape, visualize=args.noview, plot=args.show)


def video_app(input_name, spiga_dataset=None, tracker=None, fps=30, save=False,
output_path=video_out_path_dft, video_shape=None, visualize=True, plot=()):

video_app(
args.input,
spiga_dataset=args.dataset,
tracker=args.tracker,
fps=args.fps,
save=args.save,
output_path=args.outpath,
video_shape=video_shape,
visualize=args.noview,
plot=args.show,
device=args.device
)


def video_app(
input_name,
spiga_dataset=None,
tracker=None,
fps=30,
save=False,
output_path=video_out_path_dft,
video_shape=None,
visualize=True,
plot=(),
device='cuda'
):
# Load video
try:
capture = cv2.VideoCapture(int(input_name))
Expand All @@ -77,10 +97,10 @@ def video_app(input_name, spiga_dataset=None, tracker=None, fps=30, save=False,
viewer.record_video(output_path, video_name)

# Initialize face tracker
faces_tracker = tr.get_tracker(tracker)
faces_tracker = tr.get_tracker(tracker, device=device)
faces_tracker.detector.set_input_shape(capture.get(4), capture.get(3))
# Initialize processors
processor = pr_spiga.SPIGAProcessor(dataset=spiga_dataset)
processor = pr_spiga.SPIGAProcessor(dataset=spiga_dataset, device=device)
# Initialize Analyzer
faces_analyzer = VideoAnalyzer(faces_tracker, processor=processor)

Expand Down
30 changes: 27 additions & 3 deletions spiga/inference/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,29 @@

class SPIGAFramework:

def __init__(self, model_cfg: ModelConfig(), gpus=[0], load3DM=True):
def __init__(self, model_cfg: ModelConfig(), gpus=[0], load3DM=True, device='cuda'):

# Parameters
self.model_cfg = model_cfg
self.gpus = gpus

self.device = torch.device('cpu')

if 'cuda' in device or device=='gpu':
if torch.cuda.is_available():
self.device = torch.device('cuda:{}'.format(gpus[0]))
print('Using CUDA')
else:
print('CUDA is not available, will use CPU')
elif device=='mps':
if torch.mps.is_available():
self.device = torch.device('mps')
print('Using MPS')
else:
print('MPS is not available, will use CPU')
else:
print('Using CPU')

# Pretreatment initialization
self.transforms = pretreat.get_transformers(self.model_cfg)

Expand All @@ -42,7 +59,12 @@ def __init__(self, model_cfg: ModelConfig(), gpus=[0], load3DM=True):
model_state_dict = torch.load(weights_file)

self.model.load_state_dict(model_state_dict)
self.model = self.model.cuda(gpus[0])

# self.model = self.model.cuda(gpus[0])

if self.device != torch.device('cpu'):
self.model = self.model.to(self.device)

self.model.eval()
print('SPIGA model loaded!')

Expand Down Expand Up @@ -133,5 +155,7 @@ def _data2device(self, data):
data[k] = self._data2device(v)
else:
with torch.no_grad():
data_var = data.cuda(device=self.gpus[0], non_blocking=True)
# data_var = data.cuda(device=self.gpus[0], non_blocking=True)
data_var = data.to(self.device, non_blocking=True)

return data_var
15 changes: 9 additions & 6 deletions spiga/models/gnn/pose_proj.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ def euler_to_rotation_matrix(euler):
sr = torch.sin(rad[:, 2])

# Init R matrix tensors
working_device = None
if euler.is_cuda:
working_device = euler.device
working_device = euler.device
# working_device = None
# if euler.is_cuda:
# working_device = euler.device
Ry = torch.zeros((euler.shape[0], 3, 3), device=working_device)
Rp = torch.zeros((euler.shape[0], 3, 3), device=working_device)
Rr = torch.zeros((euler.shape[0], 3, 3), device=working_device)
Expand Down Expand Up @@ -54,9 +55,11 @@ def euler_to_rotation_matrix(euler):
def projectPoints(pts, rot, trl, cam_matrix):

# Get working device
working_device = None
if pts.is_cuda:
working_device = pts.device
working_device = pts.device

# working_device = None
# if pts.is_cuda:
# working_device = pts.device

# Perspective projection model
trl = trl.unsqueeze(2)
Expand Down
3 changes: 2 additions & 1 deletion spiga/models/spiga.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def extract_visual_embedded(self, pts_proj, receptive_field, step):
grid = grid.reshape(B, L, self.kwindow * self.kwindow, 2)

# Crop windows
crops = torch.nn.functional.grid_sample(receptive_field, grid, padding_mode="border") # BxCxLxK*K
padding_mode = "reflection" if torch.mps.is_available() else "border"
crops = torch.nn.functional.grid_sample(receptive_field, grid, padding_mode=padding_mode) # BxCxLxK*K
crops = crops.transpose(1, 2) # BxLxCxK*K
crops = crops.reshape(B * L, C, self.kwindow, self.kwindow)

Expand Down