From 477a23a1e41257d8aa3e87331c043b8838708c52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 25 Mar 2024 20:50:47 +0100 Subject: [PATCH] some fixes (#156) --- Makefile | 4 ++-- README.md | 23 +++++++++++++---------- jat/modeling_jat.py | 4 +++- pyproject.toml | 8 +++++--- scripts/eval_jat.py | 4 +++- 5 files changed, 26 insertions(+), 17 deletions(-) diff --git a/Makefile b/Makefile index 3d934b55..a48c4c6a 100644 --- a/Makefile +++ b/Makefile @@ -6,12 +6,12 @@ DIRS = data examples jat scripts tests # Check that source code meets quality standards quality: black --check $(DIRS) setup.py - ruff $(DIRS) setup.py + ruff check $(DIRS) setup.py # Format source code automatically style: black $(DIRS) setup.py - ruff $(DIRS) setup.py --fix + ruff check $(DIRS) setup.py --fix # Run tests for the library test: diff --git a/README.md b/README.md index 8baae665..73414da3 100644 --- a/README.md +++ b/README.md @@ -19,18 +19,20 @@ To get started with JAT, follow these steps: 1. Clone this repository onto your local machine. + + ```shell + git clone https://github.com/huggingface/jat.git + cd jat ``` - git clone https://github.com/huggingface/jat.git - cd jat - ``` + 2. Create a new virtual environment and activate it, and install required dependencies via pip. - ``` + + ```shell python3 -m venv env source env/bin/activate pip install . ``` - ## Demonstration of the trained agent The trained JAT agent is available [here](https://huggingface.co/jat-project/jat). The following script gives an example of the use of this agent on the Pong environment @@ -65,27 +67,28 @@ env.close() % GIF of trained agent here - ## Usage Examples + Here are some examples of how you might use JAT in both evaluation and fine-tuning modes. More detailed information about each example is provided within the corresponding script files. * **Evaluation Mode**: Evaluate pretrained JAT models on specific downstream tasks - ``` + + ```shell python scripts/eval_jat.py --model_name_or_path jat-project/jat --tasks atari-pong --trust_remote_code ``` + * **Training Mode**: Train your own JAT model from scratch - ``` + + ```shell python scripts/train_jat.py %TODO ``` For further details regarding usage, consult the documentation included with individual script files. - ## Dataset % TODO - ## Citation Please ensure proper citations when incorporating this work into your projects. diff --git a/jat/modeling_jat.py b/jat/modeling_jat.py index d67aa78e..6d131d75 100644 --- a/jat/modeling_jat.py +++ b/jat/modeling_jat.py @@ -807,7 +807,9 @@ def to_list(x): # Context window if context_window is not None: - self._last_key_values = tuple(tuple(pkv[:, :, -context_window:] for pkv in pkvs) for pkvs in self._last_key_values) + self._last_key_values = tuple( + tuple(pkv[:, :, -context_window:] for pkv in pkvs) for pkvs in self._last_key_values + ) # Return the predicted action if continuous_actions is not None: diff --git a/pyproject.toml b/pyproject.toml index 73e12b64..4e4b7de5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,10 +3,12 @@ line-length = 119 target-version = ['py38'] [tool.ruff] +line-length = 119 + +[tool.ruff.lint] ignore = ["C901"] select = ["C", "E", "F", "I", "W"] -line-length = 119 -[tool.ruff.isort] +[tool.ruff.lint.isort] lines-after-imports = 2 -known-first-party = ["jat"] \ No newline at end of file +known-first-party = ["jat"] diff --git a/scripts/eval_jat.py b/scripts/eval_jat.py index 034a9882..bc373e82 100755 --- a/scripts/eval_jat.py +++ b/scripts/eval_jat.py @@ -86,7 +86,9 @@ def eval_rl(model, processor, task, eval_args): done = False model.reset_rl() # remove KV Cache while not done: - action = model.get_next_action(processor, **observation, reward=reward, action_space=env.action_space, context_window=context_window) + action = model.get_next_action( + processor, **observation, reward=reward, action_space=env.action_space, context_window=context_window + ) observation, reward, termined, truncated, info = env.step(action) done = termined or truncated