diff --git a/examples/drone/README.md b/examples/drone/README.md index 39eb37c1..44eb7cee 100644 --- a/examples/drone/README.md +++ b/examples/drone/README.md @@ -71,6 +71,9 @@ Run with: python hover_eval.py -e drone-hovering --ckpt 500 --record ``` +**Note**: If you experience slow performance or encounter other issues +during evaluation, try removing the `--record` option. + ## Technical Details - The drone model used is the Crazyflie 2.X (`urdf/drones/cf2x.urdf`) diff --git a/examples/drone/hover_env.py b/examples/drone/hover_env.py index 0f9c7553..1c7ca886 100644 --- a/examples/drone/hover_env.py +++ b/examples/drone/hover_env.py @@ -39,7 +39,7 @@ def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_vie camera_lookat=(0.0, 0.0, 1.0), camera_fov=40, ), - vis_options=gs.options.VisOptions(n_rendered_envs=1), + vis_options=gs.options.VisOptions(n_rendered_envs=10), rigid_options=gs.options.RigidOptions( dt=self.dt, constraint_solver=gs.constraint_solver.Newton, diff --git a/examples/drone/hover_eval.py b/examples/drone/hover_eval.py index 51a8fe31..095b840c 100644 --- a/examples/drone/hover_eval.py +++ b/examples/drone/hover_eval.py @@ -66,4 +66,8 @@ def main(): """ # evaluation python examples/drone/hover_eval.py + +# Note +If you experience slow performance or encounter other issues +during evaluation, try removing the --record option. """ diff --git a/examples/drone/hover_train.py b/examples/drone/hover_train.py index 49e3c891..f84c3eb9 100644 --- a/examples/drone/hover_train.py +++ b/examples/drone/hover_train.py @@ -14,7 +14,7 @@ def get_train_cfg(exp_name, max_iterations): "algorithm": { "clip_param": 0.2, "desired_kl": 0.01, - "entropy_coef": 0.002, + "entropy_coef": 0.004, "gamma": 0.99, "lam": 0.95, "learning_rate": 0.0003, @@ -109,11 +109,12 @@ def get_cfgs(): def main(): parser = argparse.ArgumentParser() parser.add_argument("-e", "--exp_name", type=str, default="drone-hovering") + parser.add_argument("-v", "--vis", action="store_true", default=False) parser.add_argument("-B", "--num_envs", type=int, default=8192) parser.add_argument("--max_iterations", type=int, default=500) args = parser.parse_args() - gs.init(logging_level="warning") + gs.init(logging_level="error") log_dir = f"logs/{args.exp_name}" env_cfg, obs_cfg, reward_cfg, command_cfg = get_cfgs() @@ -123,8 +124,16 @@ def main(): shutil.rmtree(log_dir) os.makedirs(log_dir, exist_ok=True) + if args.vis: + env_cfg["visualize_target"] = True + env = HoverEnv( - num_envs=args.num_envs, env_cfg=env_cfg, obs_cfg=obs_cfg, reward_cfg=reward_cfg, command_cfg=command_cfg + num_envs=args.num_envs, + env_cfg=env_cfg, + obs_cfg=obs_cfg, + reward_cfg=reward_cfg, + command_cfg=command_cfg, + show_viewer=args.vis, ) runner = OnPolicyRunner(env, train_cfg, log_dir, device="cuda:0")