# Env_Runner.py
# forked from Hauf3n/MuZero-PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

# run all tensors on the first CUDA device in float32
device = torch.device("cuda:0")
dtype = torch.float
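# Note: the line above assumes a CUDA device is present. A common fallback
# (not in the original file) would be:
#   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")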
class Logger:
    # minimal CSV logger: truncates the file on creation, appends one line per call
    def __init__(self, filename):
        self.filename = filename
        # create / empty the log file
        open(f"{self.filename}.csv", "w").close()

    def log(self, msg):
        with open(f"{self.filename}.csv", "a+") as f:
            f.write(f"{msg}\n")
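# Illustrative usage (file name and message format are arbitrary):
#   logger = Logger("episode_returns")       # creates episode_returns.csv
#   logger.log("training_step, return")      # header row
#   logger.log(f"{step},{episode_return}")   # one row per finished episode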
class Env_Runner:
    # plays full episodes in the environment with a given agent and
    # collects the data needed to build a MuZero training trajectory
    def __init__(self, env):
        super().__init__()
        self.env = env
        self.num_actions = self.env.action_space.n

        self.logger = Logger("episode_returns")
        self.logger.log("training_step, return")

        self.ob = self.env.reset()
        self.total_eps = 0
    def run(self, agent):
        # per-episode storage
        self.obs = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.pis = []
        self.vs = []

        self.ob = self.env.reset()
        self.obs.append(torch.tensor(self.ob))

        done = False
        while not done:
            # the agent selects an action via MCTS on the current observation
            # and also returns its search policy (pi) and value estimate (v)
            action, pi, v = agent.mcts_inference(torch.tensor(self.ob).to(device).to(dtype))
            self.ob, r, done, info = self.env.step(action)

            self.obs.append(torch.tensor(self.ob))
            self.actions.append(action)
            self.pis.append(torch.tensor(pi))
            self.vs.append(v)
            self.rewards.append(torch.tensor(r))
            self.dones.append(done)

            if done:  # episode finished, environment will be reset on the next run
                if "return" in info:
                    self.logger.log(f'{self.total_eps},{info["return"]}')
                # self.env.render()

        self.total_eps += 1
        return self.make_trajectory()
    def make_trajectory(self):
        # package the episode into a dict of aligned lists;
        # "length" counts observations (one more than the number of steps taken)
        traj = {}
        traj["obs"] = self.obs
        traj["actions"] = self.actions
        traj["rewards"] = self.rewards
        traj["dones"] = self.dones
        traj["pis"] = self.pis
        traj["vs"] = self.vs
        traj["length"] = len(self.obs)
        return traj
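# Illustrative usage (not part of the original file): the surrounding training
# loop is expected to supply a gym-style `env` (old 4-tuple step API) and an
# `agent` exposing `mcts_inference(observation) -> (action, pi, value)`.
#
#   runner = Env_Runner(env)
#   trajectory = runner.run(agent)      # play one full episode
#   replay_buffer.add(trajectory)       # `replay_buffer` is hypothetical here
#
# In the returned dict, trajectory["obs"] has trajectory["length"] entries,
# while actions/rewards/dones/pis/vs each have one entry fewer, since the
# initial observation has no associated step data.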