replay_buffer.py

import io
import random

import numpy as np
import torch
from torch.utils.data import IterableDataset
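
# The helpers and buffer below assume episodes saved as compressed .npz files
# whose arrays share a time-major first dimension, contain at least the keys
# 'observation', 'action', 'reward', 'discount', and 'physics', and store a
# dummy transition in the first row. Episode filenames are expected to carry
# the episode index and length as the last two '_'-separated fields of the
# stem, which _load() parses for worker sharding.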


def episode_len(episode):
    # Subtract 1 to account for the dummy first transition stored at index 0.
    return next(iter(episode.values())).shape[0] - 1


def save_episode(episode, fn):
    # Serialize the episode dict to a compressed .npz via an in-memory buffer.
    with io.BytesIO() as bs:
        np.savez_compressed(bs, **episode)
        bs.seek(0)
        with fn.open('wb') as f:
            f.write(bs.read())


def load_episode(fn):
    with fn.open('rb') as f:
        episode = np.load(f)
        # Materialize every array before the file handle closes.
        episode = {k: episode[k] for k in episode.keys()}
        return episode


def relable_episode(env, episode):
    # Recompute the reward for every stored physics state under the current
    # task, so a reward-free or differently labeled episode can be reused.
    rewards = []
    reward_spec = env.reward_spec()
    states = episode['physics']
    for i in range(states.shape[0]):
        with env.physics.reset_context():
            env.physics.set_state(states[i])
        reward = env.task.get_reward(env.physics)
        reward = np.full(reward_spec.shape, reward, reward_spec.dtype)
        rewards.append(reward)
    episode['reward'] = np.array(rewards, dtype=reward_spec.dtype)
    return episode


class OfflineReplayBuffer(IterableDataset):
    def __init__(self, env, replay_dir, max_size, num_workers, discount):
        self._env = env
        self._replay_dir = replay_dir
        self._size = 0
        self._max_size = max_size
        self._num_workers = max(1, num_workers)
        self._episode_fns = []
        self._episodes = dict()
        self._discount = discount
        self._loaded = False

    def _load(self, relable=True):
        print('Relabeling data...')
        # Each DataLoader worker loads its own shard of episodes, assigned
        # round-robin by episode index.
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        eps_fns = sorted(self._replay_dir.glob('*.npz'))
        for eps_fn in eps_fns:
            if self._size > self._max_size:
                break
            eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]]
            if eps_idx % self._num_workers != worker_id:
                continue
            episode = load_episode(eps_fn)
            if relable:
                episode = self._relable_reward(episode)
            self._episode_fns.append(eps_fn)
            self._episodes[eps_fn] = episode
            self._size += episode_len(episode)

    def _sample_episode(self):
        # Load lazily on first use so each DataLoader worker reads its own
        # shard instead of the main process loading everything up front.
        if not self._loaded:
            self._load()
            self._loaded = True
        eps_fn = random.choice(self._episode_fns)
        return self._episodes[eps_fn]

    def _relable_reward(self, episode):
        return relable_episode(self._env, episode)

    def _sample(self):
        episode = self._sample_episode()
        # add +1 for the first dummy transition
        idx = np.random.randint(0, episode_len(episode)) + 1
        obs = episode['observation'][idx - 1]
        action = episode['action'][idx]
        next_obs = episode['observation'][idx]
        reward = episode['reward'][idx]
        discount = episode['discount'][idx] * self._discount
        return (obs, action, reward, discount, next_obs)

    def __iter__(self):
        while True:
            yield self._sample()


def _worker_init_fn(worker_id):
    # Seed NumPy and random differently in each DataLoader worker so the
    # workers do not draw identical sample sequences.
    seed = np.random.get_state()[1][0] + worker_id
    np.random.seed(seed)
    random.seed(seed)


def make_replay_loader(env, replay_dir, max_size, batch_size, num_workers,
                       discount):
    # Episodes are sharded across workers, so cap each worker's share of the
    # total buffer capacity.
    max_size_per_worker = max_size // max(1, num_workers)

    iterable = OfflineReplayBuffer(env, replay_dir, max_size_per_worker,
                                   num_workers, discount)

    loader = torch.utils.data.DataLoader(iterable,
                                         batch_size=batch_size,
                                         num_workers=num_workers,
                                         pin_memory=True,
                                         worker_init_fn=_worker_init_fn)
    return loader
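

if __name__ == '__main__':
    # Minimal usage sketch, not part of the library: the `dmc` import, the
    # task name, and the dataset path below are illustrative assumptions;
    # substitute your own dm_control-style environment (anything exposing
    # reward_spec(), .physics, and .task) and a directory of *.npz episodes.
    from pathlib import Path

    import dmc  # hypothetical wrapper that builds a dm_control environment

    env = dmc.make('walker_walk')  # assumed factory, for illustration only
    loader = make_replay_loader(env,
                                Path('./datasets/walker_walk/buffer'),
                                max_size=1_000_000,
                                batch_size=256,
                                num_workers=4,
                                discount=0.99)
    # Each batch is the collated output of _sample():
    # (obs, action, reward, discount, next_obs) as torch tensors.
    obs, action, reward, discount, next_obs = next(iter(loader))
    print(obs.shape, action.shape, reward.shape)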