-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_ppo_halfcheetah.py
executable file
·34 lines (28 loc) · 1.05 KB
/
test_ppo_halfcheetah.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#!/usr/bin/python3
import tensorflow as tf;
from tf_agents.environments import tf_py_environment, suite_mujoco; # environment
from tf_agents.policies import policy_saver; # policy
import cv2;
def main(num_episodes=10, policy_dir='checkpoints/policy_9500/', env_name='HalfCheetah-v2'):
    """Roll out a saved policy, visualize it, and print its average return.

    Loads the MuJoCo environment and the SavedModel policy, runs the policy
    for ``num_episodes`` episodes while showing every rendered frame in an
    OpenCV window titled 'halfcheetah', then prints the mean episode return.

    Args:
        num_episodes: Number of evaluation episodes to average over.
        policy_dir: Directory of the saved policy passed to
            ``tf.compat.v2.saved_model.load``.
        env_name: Environment name passed to ``suite_mujoco.load``.

    Raises:
        ValueError: If ``num_episodes`` is not a positive number.
    """
    if num_episodes <= 0:
        raise ValueError("num_episodes must be positive, got %r" % (num_episodes,))

    # environment
    eval_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
    # deserialize saved policy
    saved_policy = tf.compat.v2.saved_model.load(policy_dir)

    # apply policy and visualize
    total_return = 0.0
    try:
        for _ in range(num_episodes):
            episode_return = 0.0
            status = eval_env.reset()
            policy_state = saved_policy.get_initial_state(eval_env.batch_size)
            while not status.is_last():
                action = saved_policy.action(status, policy_state)
                status = eval_env.step(action.action)
                policy_state = action.state
                # Render the first (only) wrapped Python environment. The
                # default render mode yields an RGB frame, while cv2.imshow
                # interprets channels as BGR, so convert before displaying
                # to avoid swapped red/blue colours.
                frame = eval_env.pyenv.envs[0].render()
                cv2.imshow('halfcheetah', cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                # waitKey both paces playback (~25 ms per frame) and lets
                # the OpenCV GUI process its event queue.
                cv2.waitKey(25)
                episode_return += status.reward
            total_return += episode_return
    finally:
        # Release the display window even if the rollout is interrupted.
        cv2.destroyAllWindows()

    avg_return = total_return / num_episodes
    print("average return is %f" % avg_return)
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()