diff --git a/lerobot/common/robot_devices/control_context.py b/lerobot/common/robot_devices/control_context.py new file mode 100644 index 000000000..acc6062e5 --- /dev/null +++ b/lerobot/common/robot_devices/control_context.py @@ -0,0 +1,299 @@ +import base64 +import time +from dataclasses import dataclass +from typing import Dict, Optional + +import cv2 +import numpy as np +import torch +import zmq + +from lerobot.common.robot_devices.control_utils import log_control_info, serialize_log_items +from lerobot.common.robot_devices.robots.utils import Robot +from lerobot.common.robot_devices.utils import busy_wait + + +class ControlPhase: + TELEOPERATE = "Teleoperate" + WARMUP = "Warmup" + RECORD = "Record" + RESET = "Reset" + SAVING = "Saving" + PROCESSING_DATASET = "Processing Dataset" + UPLOADING_DATASET_TO_HUB = "Uploading Dataset to Hub" + RECORDING_COMPLETE = "Recording Complete" + + +@dataclass +class ControlContextConfig: + assign_rewards: bool = False + control_phase: str = ControlPhase.TELEOPERATE + num_episodes: int = 0 + robot: Robot = None + fps: Optional[int] = None + + +class ControlContext: + def __init__(self, config: ControlContextConfig): + self.config = config + self.modes_with_no_observation = [ + ControlPhase.RESET, + ControlPhase.SAVING, + ControlPhase.PROCESSING_DATASET, + ControlPhase.UPLOADING_DATASET_TO_HUB, + ControlPhase.RECORDING_COMPLETE, + ] + self.last_observation = None + self._initialize_communication() + self._initialize_state() + + def _initialize_state(self): + self.events = { + "exit_early": False, + "rerecord_episode": False, + "stop_recording": False, + "next_reward": 0, + } + + if self.config.assign_rewards: + self.events["next_reward"] = 0 + + self.current_episode_index = 0 + + # Define the control instructions + self.controls = [ + ("Right Arrow", "Exit Early"), + ("Left Arrow", "Rerecord"), + ("Escape", "Stop"), + ("Space", "Toggle Reward"), + ] + + def _initialize_communication(self): + self.zmq_context = zmq.Context() + self.publisher_socket = self.zmq_context.socket(zmq.PUB) + self.publisher_socket.bind("tcp://127.0.0.1:5555") + + self.command_sub_socket = self.zmq_context.socket(zmq.SUB) + self.command_sub_socket.connect("tcp://127.0.0.1:5556") + self.command_sub_socket.setsockopt_string(zmq.SUBSCRIBE, "") + + def _handle_browser_events(self): + try: + # Set a non-blocking polls + if self.command_sub_socket.poll(timeout=0): # Check if there's a message + msg = self.command_sub_socket.recv_json() + + if msg.get("type") == "command" and msg.get("command") == "keydown": + key_pressed = msg.get("key_pressed") + + if key_pressed == "ArrowRight": + print("Received 'ArrowRight' from browser -> Exit Early") + self.events["exit_early"] = True + elif key_pressed == "ArrowLeft": + print("Received 'ArrowLeft' from browser -> Rerecord Episode") + self.events["rerecord_episode"] = True + self.events["exit_early"] = True + elif key_pressed == "Escape": + print("Received 'Escape' from browser -> Stop") + self.events["stop_recording"] = True + self.events["exit_early"] = True + elif key_pressed == "Space": + # Toggle "next_reward" + self.events["next_reward"] = 1 if self.events["next_reward"] == 0 else 0 + print(f"Space toggled reward to {self.events['next_reward']}") + else: + # No message available, continue + pass + + except zmq.Again: + # No message received within timeout + pass + except Exception as e: + print(f"Error while polling for commands: {e}") + + def update_config(self, config: ControlContextConfig): + """Update configuration and reinitialize UI components as needed""" + self.config = config + + # Update ZMQ message with new config + self._publish_config_update() + + return self + + def _publish_config_update(self): + """Publish configuration update to ZMQ subscribers""" + config_data = { + "assign_rewards": self.config.assign_rewards, + "control_phase": self.config.control_phase, + "num_episodes": self.config.num_episodes, + "current_episode": self.current_episode_index, + } + + message = { + "type": "config_update", + "timestamp": time.time(), + "config": config_data, + } + + self.publisher_socket.send_json(message) + + def update_with_observations( + self, observation: Dict[str, np.ndarray], start_loop_t: int, countdown_time: int + ): + if observation is not None: + self.last_observation = observation + + if self.config.control_phase in self.modes_with_no_observation: + observation = self.last_observation + + log_items = self.log_control_info(start_loop_t) + self._publish_observations(observation, log_items, countdown_time) + self._handle_browser_events() + return self + + def _publish_observations(self, observation: Dict[str, np.ndarray], log_items: list, countdown_time: int): + """Encode and publish observation data with current configuration""" + processed_data = {} + for key, value in observation.items(): + if "image" in key: + image = value.numpy() if torch.is_tensor(value) else value + bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + success, buffer = cv2.imencode(".jpg", bgr_image) + if success: + b64_jpeg = base64.b64encode(buffer).decode("utf-8") + processed_data[key] = { + "type": "image", + "encoding": "jpeg_base64", + "data": b64_jpeg, + "shape": image.shape, + } + else: + tensor_data = value.detach().cpu().numpy() if torch.is_tensor(value) else value + processed_data[key] = { + "type": "tensor", + "data": tensor_data.tolist(), + "shape": tensor_data.shape, + } + + # Include current configuration in observation update + config_data = { + "assign_rewards": self.config.assign_rewards, + "control_phase": self.config.control_phase, + "num_episodes": self.config.num_episodes, + "current_episode": self.current_episode_index, + } + + # Sanitize countdown time. if inf set to max 32-bit int + countdown_time = int(countdown_time) if countdown_time != float("inf") else 2 ** 31 - 1 + if self.config.control_phase == ControlPhase.TELEOPERATE: + countdown_time = 0 + + message = { + "type": "observation_update", + "timestamp": time.time(), + "data": processed_data, + "events": self.get_events(), + "config": config_data, + "log_items": serialize_log_items(log_items), + "countdown_time": countdown_time, + } + + self.publisher_socket.send_json(message) + + def update_current_episode(self, episode_index): + self.current_episode_index = episode_index + return self + + def get_events(self): + return self.events.copy() + + def log_control_info(self, start_loop_t): + log_items = [] + fps = self.config.fps + if fps is not None: + dt_s = time.perf_counter() - start_loop_t + busy_wait(1 / fps - dt_s) + + dt_s = time.perf_counter() - start_loop_t + log_items = log_control_info(self.config.robot, dt_s, fps=fps) + + return log_items + + def log_say(self, message): + self._publish_log_say(message) + + def _publish_log_say(self, message): + message = { + "type": "log_say", + "timestamp": time.time(), + "message": message, + } + + self.publisher_socket.send_json(message) + + def cleanup(self, robot=None): + """Clean up resources and connections""" + if robot: + robot.disconnect() + + self.publisher_socket.close() + self.command_sub_socket.close() + self.zmq_context.term() + + +if __name__ == "__main__": + import time + + import cv2 + import numpy as np + import torch + + def read_image_from_camera(cap): + ret, frame = cap.read() + if not ret: + print("Failed to grab frame") + return None + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + return torch.tensor(frame_rgb).float() + + config = ControlContextConfig( + assign_rewards=True, + control_phase=ControlPhase.RECORD, + num_episodes=200, + fps=30, + ) + context = ControlContext(config) + context.update_current_episode(199) + + cameras = {"main": cv2.VideoCapture(0), "top": cv2.VideoCapture(4)} + + for name, cap in cameras.items(): + if not cap.isOpened(): + raise Exception(f"Error: Could not open {name} camera") + + while True: + images = {} + camera_logs = {} + for name, cap in cameras.items(): + before_camread_t = time.perf_counter() + images[name] = read_image_from_camera(cap) + camera_logs[f"read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t + + # Create state tensor (simulating follower positions) + state = torch.tensor([10.0195, 128.9355, 173.0566, -13.2715, -7.2070, 34.4531]) + + obs_dict = {"observation.state": state} + + for name in cameras: + obs_dict[f"observation.images.{name}"] = images[name] + + # Update context with observations + context.update_with_observations(obs_dict, time.perf_counter(), countdown_time=10) + events = context.get_events() + + if events["exit_early"]: + break + + for cap in cameras.values(): + cap.release() + cv2.destroyAllWindows() diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 8cc0f3260..de44265a7 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -8,9 +8,10 @@ import traceback from contextlib import nullcontext from copy import copy +from dataclasses import asdict, dataclass from functools import cache +from typing import Any, Dict, List, Optional -import cv2 import torch import tqdm from deepdiff import DeepDiff @@ -21,53 +22,87 @@ from lerobot.common.datasets.utils import get_features_from_robot from lerobot.common.policies.factory import make_policy from lerobot.common.robot_devices.robots.utils import Robot -from lerobot.common.robot_devices.utils import busy_wait from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed from lerobot.scripts.eval import get_pretrained_policy_path -def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): - log_items = [] - if episode_index is not None: - log_items.append(f"ep:{episode_index}") - if frame_index is not None: - log_items.append(f"frame:{frame_index}") +@dataclass +class LogItem: + name: str + value: float + unit: str + color: str = "white" - def log_dt(shortname, dt_val_s): - nonlocal log_items, fps - info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)" - if fps is not None: - actual_fps = 1 / dt_val_s - if actual_fps < fps - 1: - info_str = colored(info_str, "yellow") - log_items.append(info_str) + def to_dict(self): + return asdict(self) - # total step time displayed in milliseconds and its frequency - log_dt("dt", dt_s) +def stringify_and_log(log_items: List[LogItem]): + parts = [] + for item in log_items: + if item.unit: + info_str = f"{item.name}:{item.value:.2f} {item.unit}" + else: + info_str = f"{item.name}:{int(item.value)}" + + if item.color != "white": + info_str = colored(info_str, item.color) + + parts.append(info_str) + + info_str = " ".join(parts) + logging.info(info_str) - # TODO(aliberts): move robot-specific logs logic in robot.print_logs() +def serialize_log_items(log_items: List[LogItem]) -> List[Dict[str, Any]]: + return [item.to_dict() for item in log_items] + +def log_control_info(robot: Robot, dt_s: float, fps: Optional[float] = None, + episode_index: Optional[int] = None, + frame_index: Optional[int] = None) -> List[LogItem]: + log_items: List[LogItem] = [] + + # Add episode and frame information if provided + if episode_index is not None: + log_items.append(LogItem(name="ep", value=float(episode_index), unit="")) + if frame_index is not None: + log_items.append(LogItem(name="frame", value=float(frame_index), unit="")) + + # Helper function to create LogItem instances + def create_log_item(shortname: str, dt_val_s: float, base_fps: Optional[float]) -> LogItem: + value_ms = dt_val_s * 1000 + frequency = 1 / dt_val_s if dt_val_s > 0 else 0.0 + unit = f"ms ({frequency:.1f}Hz)" + color = "white" + if base_fps is not None and frequency < (base_fps - 1): + color = "yellow" + return LogItem(name=shortname, value=value_ms, unit=unit, color=color) + + # Log total step time + log_items.append(create_log_item("dt", dt_s, fps)) + + # Robot-specific logs if not robot.robot_type.startswith("stretch"): for name in robot.leader_arms: key = f"read_leader_{name}_pos_dt_s" if key in robot.logs: - log_dt("dtRlead", robot.logs[key]) + log_items.append(create_log_item("dtRlead", robot.logs[key], fps)) for name in robot.follower_arms: - key = f"write_follower_{name}_goal_pos_dt_s" - if key in robot.logs: - log_dt("dtWfoll", robot.logs[key]) + key_write = f"write_follower_{name}_goal_pos_dt_s" + if key_write in robot.logs: + log_items.append(create_log_item("dtWfoll", robot.logs[key_write], fps)) - key = f"read_follower_{name}_pos_dt_s" - if key in robot.logs: - log_dt("dtRfoll", robot.logs[key]) + key_read = f"read_follower_{name}_pos_dt_s" + if key_read in robot.logs: + log_items.append(create_log_item("dtRfoll", robot.logs[key_read], fps)) for name in robot.cameras: key = f"read_camera_{name}_dt_s" if key in robot.logs: - log_dt(f"dtR{name}", robot.logs[key]) + log_items.append(create_log_item(f"dtR{name}", robot.logs[key], fps)) + + stringify_and_log(log_items) + return log_items - info_str = " ".join(log_items) - logging.info(info_str) @cache @@ -183,44 +218,40 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides): def warmup_record( robot, - events, enable_teleoperation, warmup_time_s, - display_cameras, fps, + control_context ): control_loop( robot=robot, control_time_s=warmup_time_s, - display_cameras=display_cameras, - events=events, fps=fps, teleoperate=enable_teleoperation, + control_context=control_context, ) def record_episode( robot, dataset, - events, episode_time_s, - display_cameras, policy, device, use_amp, fps, + control_context ): control_loop( robot=robot, control_time_s=episode_time_s, - display_cameras=display_cameras, dataset=dataset, - events=events, policy=policy, device=device, use_amp=use_amp, fps=fps, teleoperate=policy is None, + control_context=control_context, ) @@ -229,14 +260,15 @@ def control_loop( robot, control_time_s=None, teleoperate=False, - display_cameras=False, dataset: LeRobotDataset | None = None, - events=None, policy=None, device=None, use_amp=None, fps=None, + control_context=None, ): + events = control_context.get_events() if control_context is not None else None + # TODO(rcadene): Add option to record logs if not robot.is_connected: robot.connect() @@ -255,50 +287,49 @@ def control_loop( timestamp = 0 start_episode_t = time.perf_counter() - while timestamp < control_time_s: - start_loop_t = time.perf_counter() + total_time = 0 + try: + while timestamp < control_time_s: + start_loop_t = time.perf_counter() - if teleoperate: - observation, action = robot.teleop_step(record_data=True) - else: - observation = robot.capture_observation() + if teleoperate: + observation, action = robot.teleop_step(record_data=True) + else: + observation = robot.capture_observation() - if policy is not None: - pred_action = predict_action(observation, policy, device, use_amp) - # Action can eventually be clipped using `max_relative_target`, - # so action actually sent is saved in the dataset. - action = robot.send_action(pred_action) - action = {"action": action} + if policy is not None: + pred_action = predict_action(observation, policy, device, use_amp) + # Action can eventually be clipped using `max_relative_target`, + # so action actually sent is saved in the dataset. + action = robot.send_action(pred_action) + action = {"action": action} - if dataset is not None: - frame = {**observation, **action} - dataset.add_frame(frame) + if dataset is not None: + frame = {**observation, **action} + dataset.add_frame(frame) - if display_cameras and not is_headless(): - image_keys = [key for key in observation if "image" in key] - for key in image_keys: - cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) - cv2.waitKey(1) + timestamp = time.perf_counter() - start_episode_t + total_time += timestamp + countdown_time = max(0, control_time_s - timestamp) - if fps is not None: - dt_s = time.perf_counter() - start_loop_t - busy_wait(1 / fps - dt_s) + control_context.update_with_observations(observation, start_loop_t, countdown_time) - dt_s = time.perf_counter() - start_loop_t - log_control_info(robot, dt_s, fps=fps) + if events["exit_early"]: + events["exit_early"] = False + break - timestamp = time.perf_counter() - start_episode_t - if events["exit_early"]: - events["exit_early"] = False - break + except Exception as e: + print(f"Error in control loop: {e}") -def reset_environment(robot, events, reset_time_s): +def reset_environment(robot, control_context, reset_time_s): # TODO(rcadene): refactor warmup_record and reset_environment # TODO(alibets): allow for teleop during reset if has_method(robot, "teleop_safety_stop"): robot.teleop_safety_stop() + events = control_context.get_events() + timestamp = 0 start_vencod_t = time.perf_counter() @@ -307,23 +338,14 @@ def reset_environment(robot, events, reset_time_s): while timestamp < reset_time_s: time.sleep(1) timestamp = time.perf_counter() - start_vencod_t + countdown_time = max(0, reset_time_s - timestamp) + control_context.update_with_observations(None, 0, countdown_time) pbar.update(1) + if events["exit_early"]: events["exit_early"] = False break - -def stop_recording(robot, listener, display_cameras): - robot.disconnect() - - if not is_headless(): - if listener is not None: - listener.stop() - - if display_cameras: - cv2.destroyAllWindows() - - def sanity_check_dataset_name(repo_id, policy): _, dataset_name = repo_id.split("/") # either repo_id doesnt start with "eval_" and there is no policy diff --git a/lerobot/scripts/browser_ui_server.py b/lerobot/scripts/browser_ui_server.py new file mode 100644 index 000000000..c2074f292 --- /dev/null +++ b/lerobot/scripts/browser_ui_server.py @@ -0,0 +1,148 @@ +import threading +import time +from pathlib import Path + +import zmq +from flask import Flask, render_template +from flask_socketio import SocketIO + +# Get template directory path +template_dir = Path(__file__).resolve().parent.parent / "templates" + +# Initialize Flask with custom template directory +app = Flask(__name__, template_folder=str(template_dir)) +socketio = SocketIO(app, cors_allowed_origins="*") + +# Global dictionary to hold the latest data from ZeroMQ +latest_data = { + "observation": {}, + "config": {} +} + +zmq_context = zmq.Context() + +# For receiving updates from ControlContext +subscriber_socket = zmq_context.socket(zmq.SUB) +subscriber_socket.connect("tcp://127.0.0.1:5555") +subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "") + +# For sending keydown events to ControlContext +command_publisher = zmq_context.socket(zmq.PUB) +command_publisher.bind("tcp://127.0.0.1:5556") + +def zmq_consumer(): + while True: + try: + message = subscriber_socket.recv_json() + + if message.get("type") == "observation_update": + processed_data = { + "timestamp": message.get("timestamp"), + "images": {}, + "state": {}, + "events": message.get("events", {}), + "config": message.get("config", {}), + "log_items": message.get("log_items", []), + "countdown_time": message.get("countdown_time") + } + + # Process observation data + observation_data = message.get("data", {}) + for key, value in observation_data.items(): + if "image" in key: + if value["type"] == "image": + processed_data["images"][key.split(".")[-1]] = value["data"] + else: + if value["type"] == "tensor": + processed_data["state"][key] = { + "data": value["data"], + "shape": value["shape"] + } + + # Update latest observation and config + latest_data["observation"].update(processed_data) + latest_data["config"].update(processed_data.get("config", {})) + + # # Emit the observation data to the browser + socketio.emit("observation_update", processed_data) + + + elif message.get("type") == "config_update": + # Handle dedicated config updates + config_data = message.get("config", {}) + latest_data["config"].update(config_data) + + # Emit configuration update to browser + socketio.emit("config_update", { + "timestamp": message.get("timestamp"), + "config": config_data + }) + elif message.get("type") == "log_say": + data = message.get("message") + timestamp = message.get("timestamp") + socketio.emit("log_say", { + "timestamp": timestamp, + "message": data + }) + + except Exception as e: + print(f"ZMQ consumer error: {e}") + time.sleep(1) + + +@socketio.on("keydown_event") +def handle_keydown_event(data): + """ + When the browser sends a keydown_event, we publish it over ZeroMQ. + """ + key_pressed = data.get("key") + + # Publish over ZeroMQ + message = { + "type": "command", + "command": "keydown", + "key_pressed": key_pressed + } + command_publisher.send_json(message) + +@app.route("/") +def index(): + """Render the main page.""" + return render_template("browser_ui.html") + +@socketio.on("connect") +def handle_connect(): + """Handle client connection.""" + print("Client connected") + # Send current state if available + if latest_data["observation"]: + socketio.emit("observation_update", latest_data["observation"]) + if latest_data["config"]: + socketio.emit("config_update", { + "timestamp": time.time(), + "config": latest_data["config"] + }) + +@socketio.on("disconnect") +def handle_disconnect(): + """Handle client disconnection.""" + print("Client disconnected") + +def run_server(host="0.0.0.0", port=8000): + """Run the Flask-SocketIO server.""" + # Start ZMQ consumer in a background thread + zmq_thread = threading.Thread(target=zmq_consumer, daemon=True) + zmq_thread.start() + + # Run Flask-SocketIO app + socketio.run(app, host=host, port=port) + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="0.0.0.0", help="Host IP address") + parser.add_argument("--port", type=int, default=8000, help="Port number") + args = parser.parse_args() + + print(f"Starting server at {args.host}:{args.port}") + run_server(host=args.host, port=args.port) \ No newline at end of file diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 12eaf146f..aebf60760 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -101,17 +101,20 @@ # from safetensors.torch import load_file, save_file from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.robot_devices.control_context import ( + ControlContext, + ControlContextConfig, + ControlPhase, +) from lerobot.common.robot_devices.control_utils import ( control_loop, has_method, - init_keyboard_listener, init_policy, log_control_info, record_episode, reset_environment, sanity_check_dataset_name, sanity_check_dataset_robot_compatibility, - stop_recording, warmup_record, ) from lerobot.common.robot_devices.robots.factory import make_robot @@ -172,14 +175,21 @@ def calibrate(robot: Robot, arms: list[str] | None): @safe_disconnect def teleoperate( - robot: Robot, fps: int | None = None, teleop_time_s: float | None = None, display_cameras: bool = False + robot: Robot, fps: int | None = None, teleop_time_s: float | None = None ): + control_context = ControlContext( + config=ControlContextConfig( + control_phase=ControlPhase.TELEOPERATE, + robot=robot, + fps=fps, + ) + ) control_loop( robot, control_time_s=teleop_time_s, fps=fps, teleoperate=True, - display_cameras=display_cameras, + control_context=control_context, ) @@ -209,8 +219,6 @@ def record( local_files_only: bool = False, ) -> LeRobotDataset: # TODO(rcadene): Add option to record logs - listener = None - events = None policy = None device = None use_amp = None @@ -259,15 +267,24 @@ def record( if not robot.is_connected: robot.connect() - listener, events = init_keyboard_listener() - # Execute a few seconds without recording to: # 1. teleoperate the robot to move it in starting position if no policy provided, # 2. give times to the robot devices to connect and start synchronizing, # 3. place the cameras windows on screen enable_teleoperation = policy is None - log_say("Warmup record", play_sounds) - warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, fps) + + control_context = ControlContext( + config=ControlContextConfig( + robot=robot, + control_phase=ControlPhase.WARMUP, + assign_rewards=False, + num_episodes=num_episodes, + fps=fps, + ) + ) + control_context.log_say("Warmup record") + + warmup_record(robot, enable_teleoperation, warmup_time_s, fps, control_context) if has_method(robot, "teleop_safety_stop"): robot.teleop_safety_stop() @@ -282,19 +299,31 @@ def record( # if multi_task: # task = input("Enter your task description: ") - log_say(f"Recording episode {dataset.num_episodes}", play_sounds) + control_context = control_context.update_config( + ControlContextConfig( + robot=robot, + control_phase=ControlPhase.RECORD, + assign_rewards=False, + num_episodes=num_episodes, + fps=fps, + ) + ) + + control_context.log_say(f"Recording episode {dataset.num_episodes + 1}") record_episode( dataset=dataset, robot=robot, - events=events, episode_time_s=episode_time_s, - display_cameras=display_cameras, policy=policy, device=device, use_amp=use_amp, fps=fps, + control_context=control_context, ) + # Events will be updated by control loop + events = control_context.get_events() + # Execute a few seconds without recording to give time to manually reset the environment # Current code logic doesn't allow to teleoperate during this time. # TODO(rcadene): add an option to enable teleoperation during reset @@ -302,24 +331,53 @@ def record( if not events["stop_recording"] and ( (dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"] ): - log_say("Reset the environment", play_sounds) - reset_environment(robot, events, reset_time_s) + control_context = control_context.update_config( + ControlContextConfig( + robot=robot, + control_phase=ControlPhase.RESET, + assign_rewards=False, + num_episodes=num_episodes, + fps=fps, + ) + ) + control_context.log_say("Reset the environment") + reset_environment(robot, control_context=control_context, reset_time_s=reset_time_s) if events["rerecord_episode"]: - log_say("Re-record episode", play_sounds) + control_context.log_say("Re-record episode") events["rerecord_episode"] = False events["exit_early"] = False dataset.clear_episode_buffer() continue + control_context = control_context.update_config( + ControlContextConfig( + robot=robot, + control_phase=ControlPhase.SAVING, + assign_rewards=False, + num_episodes=num_episodes, + fps=fps, + ) + ) dataset.save_episode(task) recorded_episodes += 1 + control_context.update_current_episode(recorded_episodes) if events["stop_recording"]: break - log_say("Stop recording", play_sounds, blocking=True) - stop_recording(robot, listener, display_cameras) + control_context.log_say("Stop recording") + control_context.cleanup(robot) + + control_context = control_context.update_config( + ControlContextConfig( + robot=robot, + control_phase=ControlPhase.PROCESSING_DATASET, + assign_rewards=False, + num_episodes=num_episodes, + fps=fps, + ) + ) if run_compute_stats: logging.info("Computing dataset statistics") @@ -327,9 +385,27 @@ def record( dataset.consolidate(run_compute_stats) if push_to_hub: + control_context = control_context.update_config( + ControlContextConfig( + robot=robot, + control_phase=ControlPhase.UPLOADING_DATASET_TO_HUB, + assign_rewards=False, + num_episodes=num_episodes, + fps=fps, + ) + ) dataset.push_to_hub(tags=tags) - log_say("Exiting", play_sounds) + control_context.log_say("Exiting") + control_context = control_context.update_config( + ControlContextConfig( + robot=robot, + control_phase=ControlPhase.RECORDING_COMPLETE, + assign_rewards=False, + num_episodes=num_episodes, + fps=fps, + ) + ) return dataset @@ -397,12 +473,6 @@ def replay( parser_teleop.add_argument( "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" ) - parser_teleop.add_argument( - "--display-cameras", - type=int, - default=1, - help="Display all cameras on screen (set to 1 to display or 0).", - ) parser_record = subparsers.add_parser("record", parents=[base_parser]) task_args = parser_record.add_mutually_exclusive_group(required=True) diff --git a/lerobot/scripts/control_sim_robot.py b/lerobot/scripts/control_sim_robot.py index 4fffa8c75..ae839d3e1 100644 --- a/lerobot/scripts/control_sim_robot.py +++ b/lerobot/scripts/control_sim_robot.py @@ -299,6 +299,7 @@ def record( dataset.add_frame(frame) + # @TODO(jackvial): Update to use ControlContext if display_cameras and not is_headless(): for key in image_keys: cv2.imshow(key, cv2.cvtColor(observation[key], cv2.COLOR_RGB2BGR)) diff --git a/lerobot/templates/browser_ui.html b/lerobot/templates/browser_ui.html new file mode 100644 index 000000000..aaa1fd3fc --- /dev/null +++ b/lerobot/templates/browser_ui.html @@ -0,0 +1,388 @@ + + +
+ +