-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplay.py
62 lines (52 loc) · 1.62 KB
/
play.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import pybullet_envs # noqa
import gym
import numpy as np
import os
from gerel.util.datastore import DataStore
from gerel.model.model import Model
from gym.wrappers import Monitor
import click
import time
# https://www.etedal.net/2020/04/pybullet-panda_2.html
# Gym environment id handed to gym.make() inside play().
ENV_NAME = 'AntBulletEnv-v0'
# Cap on the number of environment steps in a single episode.
STEPS = 1000
# Default working folder for the --dir command-line option.
DIR = './assets/example/'
def play(genome, record=False, steps=1000):
    """Run one rendered episode of ``ENV_NAME`` controlled by ``genome``.

    Args:
        genome: Genome passed to ``Model`` to build the controller.
        record: When true, wrap the environment in ``Monitor`` so the
            episode is written to ``./video`` (existing output is
            overwritten because ``force=True``).
        steps: Maximum number of environment steps to run. ``None`` falls
            back to the module-level ``STEPS`` constant.

    Returns:
        The sum of the per-step rewards collected until the environment
        reported ``done`` or the step limit was reached.
    """
    # Honour the caller's limit; previously the global STEPS was always used
    # and this argument was silently ignored.
    max_steps = STEPS if steps is None else steps

    model = Model(genome)
    env = gym.make(ENV_NAME)
    if record:
        env = Monitor(env, './video', force=True)

    rewards = 0
    try:
        # render() is deliberately kept ahead of reset(), as in the original
        # ordering -- presumably the Bullet env must have rendering requested
        # before its first reset to open the viewer (TODO confirm).
        env.render()
        state = env.reset()

        done = False
        i = 0
        while not done and i < max_steps:
            i += 1
            # Short pause per step; presumably paces the viewer so the
            # rollout is watchable.
            time.sleep(0.007)
            # Squash the raw model output with tanh before stepping.
            action = np.tanh(np.array(model(state)))
            state, reward, done, _ = env.step(action)
            rewards += reward
            env.render()
    finally:
        # Release the environment (and let Monitor finish its recording)
        # even if the rollout raises.
        env.close()
    return rewards
@click.command()
@click.option('--record', '-r', is_flag=True,
              help='Record roleout')
@click.option('--steps', '-s', default=1000, type=int,
              help='Max number of steps per episode')
@click.option('--generation', '-g', default=None, type=int,
              help='Generation to play')
@click.option('--dir', '-d', default=DIR,
              help='working folder')
def cli(record, steps, generation, dir):
    """Load a stored generation and play back its best genome.

    Args:
        record: Forwarded to ``play`` to enable recording.
        steps: Forwarded to ``play`` as the per-episode step limit.
        generation: Generation number to load. When omitted, the largest
            integer-named entry found in ``dir`` is used.
        dir: Working folder; listed to discover generations and passed to
            ``DataStore`` as its ``name``.

    Raises:
        click.ClickException: If no generation was given and ``dir``
            contains no integer-named entries.
    """
    # `is None` rather than truthiness, so an explicit `--generation 0`
    # is respected instead of being replaced by the latest generation.
    if generation is None:
        # Skip entries that are not plain integers (hidden files, etc.)
        # instead of crashing in int().
        available = [int(entry) for entry in os.listdir(dir)
                     if entry.isdecimal()]
        if not available:
            raise click.ClickException(
                f'no stored generations found in {dir}')
        generation = max(available)

    ds = DataStore(name=dir)
    data = ds.load(generation)
    rewards = play(data['best_genome'], record, steps)
    print(f'generation: {generation}, rewards: {rewards}')
# Invoke the click command only when this file is executed as a script.
if __name__ == '__main__':
    cli()