Skip to content

Commit

Permalink
Merge pull request #533 from KafuuChikai/main
Browse files Browse the repository at this point in the history
[FEATURE] update HoverEnv, change hyper params, and visualization while training
  • Loading branch information
zhouxian authored Jan 12, 2025
2 parents 70fbb72 + 77f893b commit 10a1078
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 4 deletions.
3 changes: 3 additions & 0 deletions examples/drone/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ Run with:
python hover_eval.py -e drone-hovering --ckpt 500 --record
```

**Note**: If you experience slow performance or encounter other issues
during evaluation, try removing the `--record` option.

## Technical Details

- The drone model used is the Crazyflie 2.X (`urdf/drones/cf2x.urdf`)
Expand Down
2 changes: 1 addition & 1 deletion examples/drone/hover_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_vie
camera_lookat=(0.0, 0.0, 1.0),
camera_fov=40,
),
vis_options=gs.options.VisOptions(n_rendered_envs=1),
vis_options=gs.options.VisOptions(n_rendered_envs=10),
rigid_options=gs.options.RigidOptions(
dt=self.dt,
constraint_solver=gs.constraint_solver.Newton,
Expand Down
4 changes: 4 additions & 0 deletions examples/drone/hover_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,8 @@ def main():
"""
# evaluation
python examples/drone/hover_eval.py
# Note
If you experience slow performance or encounter other issues
during evaluation, try removing the --record option.
"""
15 changes: 12 additions & 3 deletions examples/drone/hover_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def get_train_cfg(exp_name, max_iterations):
"algorithm": {
"clip_param": 0.2,
"desired_kl": 0.01,
"entropy_coef": 0.002,
"entropy_coef": 0.004,
"gamma": 0.99,
"lam": 0.95,
"learning_rate": 0.0003,
Expand Down Expand Up @@ -109,11 +109,12 @@ def get_cfgs():
def main():
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--exp_name", type=str, default="drone-hovering")
parser.add_argument("-v", "--vis", action="store_true", default=False)
parser.add_argument("-B", "--num_envs", type=int, default=8192)
parser.add_argument("--max_iterations", type=int, default=500)
args = parser.parse_args()

gs.init(logging_level="warning")
gs.init(logging_level="error")

log_dir = f"logs/{args.exp_name}"
env_cfg, obs_cfg, reward_cfg, command_cfg = get_cfgs()
Expand All @@ -123,8 +124,16 @@ def main():
shutil.rmtree(log_dir)
os.makedirs(log_dir, exist_ok=True)

if args.vis:
env_cfg["visualize_target"] = True

env = HoverEnv(
num_envs=args.num_envs, env_cfg=env_cfg, obs_cfg=obs_cfg, reward_cfg=reward_cfg, command_cfg=command_cfg
num_envs=args.num_envs,
env_cfg=env_cfg,
obs_cfg=obs_cfg,
reward_cfg=reward_cfg,
command_cfg=command_cfg,
show_viewer=args.vis,
)

runner = OnPolicyRunner(env, train_cfg, log_dir, device="cuda:0")
Expand Down

0 comments on commit 10a1078

Please sign in to comment.