From 83bf36547ec2a069c9707b04bde483668ac33842 Mon Sep 17 00:00:00 2001 From: Justin Date: Tue, 7 Jan 2025 20:02:21 -0600 Subject: [PATCH] cv2.imshow is replaced with a pygame window --- lerobot/common/robot_devices/control_utils.py | 58 +++++++++++++++++-- pyproject.toml | 1 + 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 8cc0f3260..afa8aaeaf 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -24,7 +24,20 @@ from lerobot.common.robot_devices.utils import busy_wait from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed from lerobot.scripts.eval import get_pretrained_policy_path +import platform +import os +import pygame +import numpy as np +def init_pygame_display(): + """Initialize pygame display for showing camera feeds.""" + pygame.init() + pygame.font.init() # Initialize the font module + return None # Will create windows as needed + +def cleanup_pygame(): + """Cleanup pygame resources.""" + pygame.quit() def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): log_items = [] @@ -255,6 +268,9 @@ def control_loop( timestamp = 0 start_episode_t = time.perf_counter() + init_pygame_display() + main_window = None + while timestamp < control_time_s: start_loop_t = time.perf_counter() @@ -276,9 +292,43 @@ def control_loop( if display_cameras and not is_headless(): image_keys = [key for key in observation if "image" in key] + + # Calculate total width and maximum height for the window + total_width = sum(observation[key].shape[1] for key in image_keys) + (len(image_keys) - 1) * 10 + max_height = max(observation[key].shape[0] for key in image_keys) + + # Create or get main window + if main_window is None: + main_window = pygame.display.set_mode((total_width, max_height + 30)) # Add 30 pixels for text + pygame.display.set_caption("Camera Feeds (Press 'q' to quit)") + + # Clear the window + main_window.fill((0, 0, 0)) + + # Draw images side by side + x_offset = 0 for key in image_keys: - cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) - cv2.waitKey(1) + img = observation[key].numpy() + img_rgb = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + surface = pygame.surfarray.make_surface(np.transpose(img_rgb, (1, 0, 2))) + main_window.blit(surface, (x_offset, 0)) + # Draw camera name above each image + font = pygame.font.Font(None, 24) + text = font.render(key, True, (255, 255, 255)) + main_window.blit(text, (x_offset, max_height + 5)) + x_offset += img.shape[1] + 10 # Add 10 pixels gap + + pygame.display.flip() + + # Handle pygame events + for event in pygame.event.get(): + if event.type == pygame.QUIT: + cleanup_pygame() + events["exit_early"] = True + elif event.type == pygame.KEYDOWN: + if event.key == pygame.K_q: + cleanup_pygame() + events["exit_early"] = True if fps is not None: dt_s = time.perf_counter() - start_loop_t @@ -314,14 +364,14 @@ def reset_environment(robot, events, reset_time_s): def stop_recording(robot, listener, display_cameras): + """Stop recording and cleanup resources.""" robot.disconnect() if not is_headless(): if listener is not None: listener.stop() - if display_cameras: - cv2.destroyAllWindows() + cleanup_pygame() def sanity_check_dataset_name(repo_id, policy): diff --git a/pyproject.toml b/pyproject.toml index 59c2de8bc..44973be1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ zarr = ">=2.17.0" numba = ">=0.59.0" torch = ">=2.2.1" opencv-python = ">=4.9.0" +pygame = ">=2.5.2" diffusers = ">=0.27.2" torchvision = ">=0.17.1" h5py = ">=3.10.0"