train_rppo.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""Training script using sb3-contrib RecurrentPPO (RPPO).

Trains a recurrent PPO agent on a custom gym environment with periodic
checkpointing and evaluation, then saves the final model and reports its
mean evaluation reward.

@Modify Time      @Author     @Version  @Description
------------      -------     --------  -----------
3/29/23 7:50 PM   yinzikang   1.0       None
"""
import time
import torch as th
from sb3_contrib import RecurrentPPO as RPPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback, CallbackList
from stable_baselines3.common.env_util import make_vec_env
import gym
from gym_custom.envs.env_kwargs import env_kwargs
# Experiment configuration
env_name = 'TrainEnvVariableStiffnessAndPosture-v6'
test_name = 'cabinet surface with plan'
rl_name = 'RPPO'
time_name = time.strftime("%m-%d-%H-%M")
path_name = 'train_results/' + test_name + '/' + rl_name + '/' + time_name + '/'
env_num = 4

# The third return value of env_kwargs() holds the environment construction
# kwargs; a vectorized env is used for training, a single env for evaluation.
_, _, rl_kwargs = env_kwargs(test_name, save_flag=True, save_path=path_name)
train_env = make_vec_env(env_id=env_name, n_envs=env_num, env_kwargs=rl_kwargs)
eval_env = gym.make(env_name, **rl_kwargs)

total_timesteps = 1_000_000
# Recurrent policy with one LSTM layer shared between actor and critic.
# Note: sb3-contrib asserts that shared_lstm and enable_critic_lstm are
# mutually exclusive, so the critic's separate LSTM must be disabled here.
policy_kwargs = dict(activation_fn=th.nn.ReLU, net_arch=dict(pi=[512, 512], vf=[512, 512]),
                     n_lstm_layers=1, shared_lstm=True, enable_critic_lstm=False)
# Checkpoint and evaluate 10 times over the course of training; the
# frequencies are counted per environment, hence the division by env_num.
checkpoint_callback = CheckpointCallback(save_freq=int(total_timesteps / 10 / env_num),
                                         save_path=path_name, name_prefix="model",
                                         save_replay_buffer=False, save_vecnormalize=False)
eval_callback = EvalCallback(eval_env, best_model_save_path=path_name, log_path=path_name,
                             eval_freq=int(total_timesteps / 10 / env_num))
callback = CallbackList([checkpoint_callback, eval_callback])
model = RPPO("MlpLstmPolicy", train_env, learning_rate=0.0003, policy_kwargs=policy_kwargs, verbose=1, seed=None,
             device='cuda', _init_setup_model=True,
             tensorboard_log='log/' + test_name + '/' + rl_name + '/' + time_name,
             use_sde=True, sde_sample_freq=-1,
             # on-policy specific
             n_steps=2048, batch_size=2048, gamma=0.99, gae_lambda=0.95, ent_coef=0.0, vf_coef=0.5, max_grad_norm=0.5,
             # PPO-specific parameters
             n_epochs=10, clip_range=0.2, clip_range_vf=None, normalize_advantage=True, target_kl=None)
model.learn(total_timesteps=total_timesteps, callback=callback, log_interval=4, tb_log_name="",
            reset_num_timesteps=True, progress_bar=True)
model.save(path=path_name + 'model', exclude=None, include=None)
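# Side note: to resume training later from the checkpoint saved above (a
# sketch; the path mirrors the model.save() call, and the vectorized training
# env must be re-supplied since it is not serialized with the model):
#   model = RPPO.load(path_name + 'model', env=train_env, device='cuda')
#   model.learn(total_timesteps=total_timesteps, callback=callback,
#               reset_num_timesteps=False, progress_bar=True)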
# Final sanity-check evaluation on a single deterministic episode.
mean_reward, std_reward = evaluate_policy(model=model, env=eval_env, n_eval_episodes=1, deterministic=True,
                                          render=False, callback=None, reward_threshold=None,
                                          return_episode_rewards=False, warn=True)
print(mean_reward, std_reward)
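
# A minimal manual-rollout sketch (an illustration, not part of the original
# training flow): unlike the stateless PPO predict(), RecurrentPPO.predict()
# threads the LSTM hidden state and episode-start flags through the loop.
import numpy as np

obs = eval_env.reset()
lstm_states = None                          # hidden state, initialized by predict()
episode_start = np.ones((1,), dtype=bool)   # True at episode boundaries resets the LSTM
done = False
episode_reward = 0.0
while not done:
    action, lstm_states = model.predict(obs, state=lstm_states,
                                        episode_start=episode_start, deterministic=True)
    obs, reward, done, info = eval_env.step(action)
    episode_start = np.array([done], dtype=bool)
    episode_reward += reward
print('manual rollout reward:', episode_reward)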