Commit: Context window (#154)
* gia -> jat

* gia -> jat

* Gia -> Jat

* GIA -> JAT

* rename files and folder

* additional file renaming

* Generally Intelligent Agents -> Jack of All Trades

* fix scores_dict path

* Drop relative path for normalization (#150)

* train_jat_tokenized

* New pretty names and random scores (#151)

* take literature random scores

* New pretty names

* download tokenized

* fix os exists train

* tokenize stream

* remove save to dict

* train_jat_tokenized

* batch encoding

* handle pad for already tokenized

* wandb entity qgallouedec

* try allow multi worker loader

* revert

* allow multiple loader workers

* new interleave datasets

* fix none values

* fix

* seed

* remove features

* don't use iterable

* support n_contiguous for map_style

* fix take/select

* try cast reward

* simplify

* even simpler

* fix train-dataset

* hf offline

* small fix in processing

* fix last sample

* try sequential sampling

* tqdm loading

* to tensor

* try to fix processing of image in eval

* handle normalization failure case

* handle none in eval

* use rliable for metric computations

* fix to_tensor

* fix human-normalized Atari score calculation

* allow eval without scores for every domain

* style

* tmp readme

* readme

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Context window
qgallouedec authored Mar 25, 2024
1 parent 09a9a21 commit 28154e7
Showing 2 changed files with 8 additions and 1 deletion.
jat/modeling_jat.py (5 changes: 5 additions & 0 deletions)
@@ -711,6 +711,7 @@ def get_next_action(
         action_space: Union[spaces.Box, spaces.Discrete] = None,
         reward: Optional[float] = None,
         deterministic: bool = False,
+        context_window: Optional[int] = None,
     ):
         # Get the maximum sequence length
         max_length = self.config.max_position_embeddings // 2
@@ -804,6 +805,10 @@ def to_list(x):
         # We remove the last two values, as the inputs are [s_0, 0], [s_0, a_0, s_1, 0], [s_1, a_1, s_2, 0], ...
         self._last_key_values = tuple(tuple(pkv[:, :, :-2] for pkv in pkvs) for pkvs in self._last_key_values)
+
+        # Context window: keep only the most recent positions in the KV cache
+        if context_window is not None:
+            self._last_key_values = tuple(tuple(pkv[:, :, -context_window:] for pkv in pkvs) for pkvs in self._last_key_values)

         # Return the predicted action
         if continuous_actions is not None:
             self.last_continuous_action = outputs.pred_actions[0, -1].cpu().tolist()
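For reference, a minimal standalone sketch of the truncation this hunk adds (not the repository's code; the (batch, num_heads, seq_len, head_dim) cache layout is an assumption inferred from the slicing above):

import torch

def trim_kv_cache(past_key_values, context_window):
    # Keep only the last `context_window` positions of every cached key/value
    # tensor, mirroring pkv[:, :, -context_window:] in the diff above.
    if context_window is None:
        return past_key_values
    return tuple(tuple(pkv[:, :, -context_window:] for pkv in layer) for layer in past_key_values)

# Example: a 2-layer cache holding 300 positions, trimmed to the last 256.
cache = tuple((torch.zeros(1, 8, 300, 64), torch.zeros(1, 8, 300, 64)) for _ in range(2))
trimmed = trim_kv_cache(cache, context_window=256)
print(trimmed[0][0].shape)  # torch.Size([1, 8, 256, 64])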
scripts/eval_jat.py (4 changes: 3 additions & 1 deletion)
@@ -75,6 +75,8 @@ def eval_rl(model, processor, task, eval_args):

     env = make(task, **env_kwargs)

+    context_window = 32 if task.startswith("atari") else 256
+
     scores = []
     frames = []
     for episode in tqdm(range(eval_args.num_episodes), desc=task, unit="episode", leave=False):
@@ -84,7 +86,7 @@ def eval_rl(model, processor, task, eval_args):
         done = False
         model.reset_rl()  # remove KV cache
         while not done:
-            action = model.get_next_action(processor, **observation, reward=reward, action_space=env.action_space)
+            action = model.get_next_action(processor, **observation, reward=reward, action_space=env.action_space, context_window=context_window)
             observation, reward, terminated, truncated, info = env.step(action)
             done = terminated or truncated

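The per-domain window (32 for Atari, 256 elsewhere) keeps the KV cache bounded over long episodes. A toy illustration (assumed shapes and step counts, not the repository's code) of how the trim caps cache growth when applied on every call:

import torch

context_window = 32  # the Atari value chosen above
cache = torch.zeros(1, 8, 0, 64)  # (batch, heads, positions, head_dim), initially empty

for step in range(100):
    new_tokens = torch.zeros(1, 8, 2, 64)  # roughly one observation and one action token per step
    cache = torch.cat([cache, new_tokens], dim=2)
    cache = cache[:, :, -context_window:]  # the truncation added in this commit
print(cache.shape[2])  # 32, instead of 200 without the trim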
