diff --git a/Makefile b/Makefile
index 69f0a725..bb15b8e3 100644
--- a/Makefile
+++ b/Makefile
@@ -1,7 +1,7 @@
 .PHONY: quality style test
 
 # Define directories variable
-DIRS = data examples gia scripts tests
+DIRS = data examples gia gia2 scripts tests
 
 # Check that source code meets quality standards
 quality:
diff --git a/data/conceptual_captions/generate_conceptual_caption.py b/data/conceptual_captions/generate_conceptual_caption.py
index fb9faf9e..88ee2301 100644
--- a/data/conceptual_captions/generate_conceptual_caption.py
+++ b/data/conceptual_captions/generate_conceptual_caption.py
@@ -1,36 +1,43 @@
-import concurrent.futures
 import io
-import os
-import urllib
+import multiprocessing
+from typing import Dict, List, Union
+from urllib.request import Request, urlopen
 
 import PIL.Image
-from datasets import load_dataset
 from datasets.utils.file_utils import get_datasets_user_agent
 
 
 USER_AGENT = get_datasets_user_agent()
 
-PATH = "data/test"  # or "data/train"
-MAX_WORKERS = 10  # adjust to your needs
-MAX_QUEUE_SIZE = 2 * MAX_WORKERS  # adjust to your needs
 
+def fetch_image(image_url: str, timeout: float = 0.5) -> PIL.Image.Image:
+    """
+    Fetches a single image from a given URL and returns it as a PIL Image object.
 
-def fetch_single_image(image_url, timeout=1):
-    print(image_url)
-    try:
-        request = urllib.request.Request(
-            image_url,
-            data=None,
-            headers={"user-agent": USER_AGENT},
-        )
-        with urllib.request.urlopen(request, timeout=timeout) as req:
-            image = PIL.Image.open(io.BytesIO(req.read()))
-    except Exception:
-        image = None
+    Args:
+        image_url (str): The URL of the image to fetch.
+        timeout (float): The timeout value for the request (in seconds).
+
+    Returns:
+        A PIL Image object representing the fetched image. Errors are not caught here; they propagate to the
+        caller (see `fetch_and_resize`).
+    """
+    request = Request(image_url, data=None, headers={"user-agent": USER_AGENT})
+    with urlopen(request, timeout=timeout) as req:
+        image = PIL.Image.open(io.BytesIO(req.read()))
     return image
 
 
-def resize_single_image(image: PIL.Image):
+def resize_image(image: PIL.Image.Image) -> PIL.Image.Image:
+    """
+    Resize a single image so that its larger side is at most 352 pixels, while maintaining the aspect ratio.
+    Also remove metadata from the image.
+
+    Args:
+        image (PIL.Image.Image): The image to be resized.
+
+    Returns:
+        PIL.Image.Image: The resized image without metadata.
+    """
     # Resize so that the bigger size is at most 352
     width, height = image.size
     if width > height:
@@ -40,43 +47,68 @@ def resize_single_image(image: PIL.Image):
         new_height = 352
         new_width = int(width * 352 / height)
     image = image.resize((new_width, new_height), PIL.Image.BILINEAR)
-    image = image.convert("RGB")
+    image = image.convert("RGB")  # Make sure the image is RGB
+    data = list(image.getdata())  # Get only the image data, and place it in a new image to remove metadata
+    image_without_exif = PIL.Image.new(image.mode, image.size)
+    image_without_exif.putdata(data)
+    return image_without_exif
+
+
+def fetch_and_resize(img_url: str) -> Union[PIL.Image.Image, None]:
+    """
+    Fetches an image from a given URL and resizes it.
+
+    Args:
+        img_url (str): The URL of the image to fetch.
+
+    Returns:
+        PIL.Image.Image: The resized image, or None if an error occurred.
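+
+    Example (illustrative; any URL that fails to download or decode simply yields ``None``):
+        >>> image = fetch_and_resize("https://example.com/picture.jpg")
+        >>> image is None or max(image.size) <= 352
+        True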
+ """ + try: + image = fetch_image(img_url) + image = resize_image(image) + except Exception: + image = None return image -dataset = load_dataset("conceptual_captions", split="validation") # or "train" -if not os.path.exists(f"{PATH}/metadata.csv"): - with open(f"{PATH}/metadata.csv", "w") as f: - f.write("file_name,caption,idx\n") - dataset_idx = 0 - image_idx = 0 -else: # get the lastest index - with open(f"{PATH}/metadata.csv", "r") as f: - lines = f.readlines() - image_idx = len(lines) - 1 - dataset_idx = int(lines[-1].split(",")[-1]) + 1 - print(image_idx, dataset_idx) - -with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: - future_to_idx = {executor.submit(fetch_single_image, dataset[dataset_idx]["image_url"]): dataset_idx} - dataset_idx += 1 - while dataset_idx < len(dataset): - done, _ = concurrent.futures.wait(future_to_idx, return_when=concurrent.futures.FIRST_COMPLETED) - for future in done: - idx = future_to_idx.pop(future) +def process(example: Dict[str, List[str]]) -> Dict[str, List[Union[str, PIL.Image.Image]]]: + output = {"images": [], "text": []} + + with multiprocessing.Pool() as pool: + images = pool.starmap(fetch_and_resize, [(url,) for url in example["image_url"]]) + + for idx, image in enumerate(images): + if image is not None: + output["images"].append(image) + output["text"].append(example["caption"][idx]) + + return output + + +if __name__ == "__main__": + from datasets import Dataset, features, load_dataset + + for split in ["train", "test"]: + dataset = load_dataset("conceptual_captions", split="train" if split == "train" else "validation") + num_cpu = multiprocessing.cpu_count() // 2 + dataset = dataset.map( + process, + batched=True, + batch_size=200, + remove_columns=["caption", "image_url"], + num_proc=num_cpu, + load_from_cache_file=True, + features=features.Features({"images": features.Image(decode=True), "text": features.Value("string")}), + ) + dataset.save_to_disk(f"conceptual-captions-{split}") + dataset = Dataset.load_from_disk(f"conceptual-captions-{split}") + + retry = 500 + + for i in range(retry): try: - image = future.result() - if image is not None: - image = resize_single_image(image) - sample = dataset[idx] - caption = sample["caption"].replace(",", "").replace(";", "").replace("\n", "").replace("\t", "") - image.save(f"{PATH}/{image_idx:07d}.png", "PNG") - with open(f"{PATH}/metadata.csv", "a") as f: - f.write(f"{image_idx:07d}.png,{caption},{idx}\n") - image_idx += 1 - except Exception as exc: - print(f"Generated an exception: {exc}") - - while len(future_to_idx) < MAX_QUEUE_SIZE and dataset_idx < len(dataset): - future_to_idx[executor.submit(fetch_single_image, dataset[dataset_idx]["image_url"])] = dataset_idx - dataset_idx += 1 + dataset.push_to_hub("gia-project/gia-dataset-parquet", "conceptual-captions", split=split) + break + except Exception: + print(f"Retry {i+1}/{retry}") diff --git a/data/envs/atari/create_atari_dataset.py b/data/envs/atari/create_atari_dataset.py index 04122b63..372cb289 100644 --- a/data/envs/atari/create_atari_dataset.py +++ b/data/envs/atari/create_atari_dataset.py @@ -1,9 +1,9 @@ -import time -from collections import deque - +import datasets import numpy as np import torch +from datasets import Dataset, concatenate_datasets from huggingface_hub import HfApi, upload_folder +from PIL import Image from sample_factory.algo.learning.learner import Learner from sample_factory.algo.sampling.batched_sampling import preprocess_actions from sample_factory.algo.utils.action_distributions 
import argmax_actions @@ -12,7 +12,7 @@ from sample_factory.algo.utils.rl_utils import make_dones, prepare_and_normalize_obs from sample_factory.algo.utils.tensor_utils import unsqueeze_tensor from sample_factory.cfg.arguments import load_from_checkpoint -from sample_factory.enjoy import render_frame, visualize_policy_inputs +from sample_factory.enjoy import visualize_policy_inputs from sample_factory.model.actor_critic import create_actor_critic from sample_factory.model.model_utils import get_rnn_size from sample_factory.utils.attr_dict import AttrDict @@ -20,8 +20,6 @@ from sample_factory.utils.utils import log from sf_examples.envpool.atari.train_envpool_atari import parse_atari_args, register_atari_components -from gia.datasets.to_hub import add_dataset_to_hub - def push_to_hf(dir_path: str, repo_name: str): _ = HfApi().create_repo(repo_id=repo_name, private=False, exist_ok=True, repo_type="dataset") @@ -33,7 +31,6 @@ def push_to_hf(dir_path: str, repo_name: str): # most of this function is redundant as it is copied from sample.enjoy.enjoy def create_atari_dataset(cfg: Config): - verbose = False cfg = load_from_checkpoint(cfg) @@ -75,37 +72,21 @@ def create_atari_dataset(cfg: Config): checkpoint_dict = Learner.load_checkpoint(checkpoints, device) actor_critic.load_state_dict(checkpoint_dict["model"]) - episode_rewards = [deque([], maxlen=100) for _ in range(env.num_agents)] - true_objectives = [deque([], maxlen=100) for _ in range(env.num_agents)] num_frames = 0 - last_render_start = time.time() - - def max_frames_reached(frames): - return cfg.max_num_frames is not None and frames > cfg.max_num_frames - - reward_list = [] - obs, infos = env.reset() rnn_states = torch.zeros([env.num_agents, get_rnn_size(cfg)], dtype=torch.float32, device=device) - episode_reward = None - finished_episode = [False for _ in range(env.num_agents)] - - video_frames = [] - num_episodes = 0 - env.action_space.n - - dataset_image_observations = [] - dataset_rewards = [] - dataset_discrete_actions = [] + image_observations = [] + rewards = [] + discrete_actions = [] ep_image_observations = [] ep_rewards = [] ep_discrete_actions = [] with torch.no_grad(): - while not max_frames_reached(num_frames): - obs = {k: v[0] for k, v in obs.items()} + while num_frames < cfg.max_num_frames: + obs["obs"] = obs["obs"][0] normalized_obs = prepare_and_normalize_obs(actor_critic, obs) if not cfg.no_render: @@ -126,103 +107,31 @@ def max_frames_reached(frames): rnn_states = policy_outputs["new_rnn_states"] - for _ in range(render_action_repeat): # this is 1 for all atari envs - last_render_start = render_frame(cfg, env, video_frames, num_episodes, last_render_start) - - # store s in buffer - if num_frames < cfg.max_num_frames: - ep_image_observations.append(obs["obs"].cpu().numpy()) - - obs, rew, terminated, truncated, infos = env.step([actions]) - - dones = make_dones(terminated, truncated) - - # store a,r, d in buffer - if num_frames < cfg.max_num_frames: - ep_rewards.append(rew.item()) - ep_discrete_actions.append(actions) - - infos = [{} for _ in range(env_info.num_agents)] if infos is None else infos - - if episode_reward is None: - episode_reward = rew.float().clone() - else: - episode_reward += rew.float() - - num_frames += 1 - - dones = dones.cpu().numpy() - for agent_i, done_flag in enumerate(dones): - if done_flag: - finished_episode[agent_i] = True - rew = episode_reward[agent_i].item() - episode_rewards[agent_i].append(rew) - - true_objective = rew - if isinstance(infos, (list, tuple)): - true_objective = 
infos[agent_i].get("true_objective", rew)
-                    true_objectives[agent_i].append(true_objective)
-
-                    if verbose:
-                        log.info(
-                            "Episode finished for agent %d at %d frames. Reward: %.3f, true_objective: %.3f",
-                            agent_i,
-                            num_frames,
-                            episode_reward[agent_i],
-                            true_objectives[agent_i][-1],
-                        )
-                    rnn_states[agent_i] = torch.zeros([get_rnn_size(cfg)], dtype=torch.float32, device=device)
-                    episode_reward[agent_i] = 0
-
-                    if cfg.use_record_episode_statistics:
-                        # we want the scores from the full episode not a single agent death
-                        # (due to EpisodicLifeEnv wrapper)
-                        if "episode" in infos[agent_i].keys():
-                            num_episodes += 1
-                            reward_list.append(infos[agent_i]["episode"]["r"])
-                    else:
-                        num_episodes += 1
-                        reward_list.append(true_objective)
-
-            # if episode terminated synchronously for all agents, pause a bit before starting a new one
-            if all(dones):
-                render_frame(cfg, env, video_frames, num_episodes, last_render_start)
-                time.sleep(0.05)
-
-            if all(finished_episode):
-                dataset_image_observations.append(np.squeeze(np.array(ep_image_observations), axis=1))
-                dataset_discrete_actions.append(np.squeeze(np.array(ep_discrete_actions).astype(np.int64), axis=1))
-                dataset_rewards.append(np.array(ep_rewards).astype(np.float32))
+            # store s in buffer
+            ep_image_observations.append(Image.fromarray(np.transpose(obs["obs"][0].cpu().numpy(), (1, 2, 0))))
+
+            obs, rew, terminated, truncated, infos = env.step([actions])
+
+            done = make_dones(terminated, truncated).item()
+
+            # store a, r, d in buffer
+            ep_rewards.append(rew.item())
+            ep_discrete_actions.append(actions.item())
+
+            num_frames += 1
+
+            if done:  # possibly a fictitious done (e.g. loss of life)
+                rnn_states[0] = torch.zeros([get_rnn_size(cfg)], dtype=torch.float32, device=device)
+
+            if infos[0]["terminated"].item():
+                image_observations.append(ep_image_observations)
+                discrete_actions.append(np.array(ep_discrete_actions).astype(np.int64))
+                rewards.append(np.array(ep_rewards).astype(np.float32))
                 ep_image_observations = []
                 ep_discrete_actions = []
                 ep_rewards = []
-                finished_episode = [False] * env.num_agents
 
-                avg_episode_rewards_str, avg_true_objective_str = "", ""
-                for agent_i in range(env.num_agents):
-                    avg_rew = np.mean(episode_rewards[agent_i])
-                    avg_true_obj = np.mean(true_objectives[agent_i])
-
-                    if not np.isnan(avg_rew):
-                        if avg_episode_rewards_str:
-                            avg_episode_rewards_str += ", "
-                        avg_episode_rewards_str += f"#{agent_i}: {avg_rew:.3f}"
-                    if not np.isnan(avg_true_obj):
-                        if avg_true_objective_str:
-                            avg_true_objective_str += ", "
-                        avg_true_objective_str += f"#{agent_i}: {avg_true_obj:.3f}"
-
-                log.info(
-                    "Avg episode rewards: %s, true rewards: %s", avg_episode_rewards_str, avg_true_objective_str
-                )
-                log.info(
-                    "Avg episode reward: %.3f, avg true_objective: %.3f",
-                    np.mean([np.mean(episode_rewards[i]) for i in range(env.num_agents)]),
-                    np.mean([np.mean(true_objectives[i]) for i in range(env.num_agents)]),
-                )
-
-            if num_episodes >= cfg.max_num_episodes:
-                break
+                log.info(f"Episode reward: {np.sum(rewards[-1]):.3f}")
 
     env.close()
 
@@ -232,13 +141,30 @@ def max_frames_reached(frames):
     task = "kungfumaster" if task == "kongfumaster" else task
     task = "montezumarevenge" if task == "montezuma" else task
     task = "privateeye" if task == "privateye" else task
-    add_dataset_to_hub(
-        "atari",
-        task,
-        image_observations=dataset_image_observations,
-        discrete_actions=dataset_discrete_actions,
-        rewards=dataset_rewards,
-        push_to_hub=cfg.push_to_hub,
+    d = {
+        "image_observations": image_observations,
+        "discrete_actions": discrete_actions,
+        "rewards": rewards,
+    }
+    
features = datasets.Features( + { + "image_observations": datasets.Sequence(datasets.Image()), + "discrete_actions": datasets.Sequence(datasets.Value("int64")), + "rewards": datasets.Sequence(datasets.Value("float32")), + } + ) + + ds = [ + Dataset.from_dict({k: [v[idx]] for k, v in d.items()}, features=features) + for idx in range(len(d["image_observations"])) + ] + dataset = concatenate_datasets(ds) + dataset = dataset.train_test_split(test_size=0.1, writer_batch_size=1) + HfApi().create_branch("gia-project/gia-dataset-parquet", branch="new_breakout", exist_ok=True, repo_type="dataset") + dataset.push_to_hub( + "gia-project/gia-dataset-parquet", + config_name=f"atari-{task}", + branch="new_breakout", ) diff --git a/data/envs/babyai/bot_agent.py b/data/envs/babyai/bot_agent.py index 3432d2c2..392e4e2d 100644 --- a/data/envs/babyai/bot_agent.py +++ b/data/envs/babyai/bot_agent.py @@ -742,9 +742,9 @@ def _breadth_first_search(self, initial_states, accept_fn, ignore_blockers): # Location to which the bot can get without turning # are put in the queue first - for k, l in [(di, dj), (dj, di), (-dj, -di), (-di, -dj)]: - next_pos = (i + k, j + l) - next_dir_vec = (k, l) + for k, m in [(di, dj), (dj, di), (-dj, -di), (-di, -dj)]: + next_pos = (i + k, j + m) + next_dir_vec = (k, m) next_state = (*next_pos, *next_dir_vec) queue.append((next_state, (i, j))) @@ -819,8 +819,8 @@ def match_unblock(pos, cell): # We want to ensure that empty cells are connected, and that one can reach # any object cell from any other object cell. cell_class = [] - for k, l in [(-1, -1), (0, -1), (1, -1), (1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0)]: - nb_pos = (i + k, j + l) + for k, m in [(-1, -1), (0, -1), (1, -1), (1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0)]: + nb_pos = (i + k, j + m) cell = grid.get(*nb_pos) # completely blocked if self.vis_mask[nb_pos] and cell and cell.type == "wall": diff --git a/gia/eval/rl/envs/core.py b/gia/eval/rl/envs/core.py index a166ea2c..bd3e0d2d 100644 --- a/gia/eval/rl/envs/core.py +++ b/gia/eval/rl/envs/core.py @@ -203,6 +203,7 @@ def make_atari(task_name: str, episodic_life: bool = True, clip_reward: bool = T if task_name == "atari-montezumarevenge": kwargs["max_episode_steps"] = 18_000 env = gym.make(TASK_NAME_TO_ENV_ID[task_name], **kwargs) + env.metadata["render_fps"] = 30 env = gym.wrappers.RecordEpisodeStatistics(env) env = NoopResetEnv(env, noop_max=30) env = MaxAndSkipEnv(env, skip=4) diff --git a/gia/eval/rl/scores_dict.json b/gia/eval/rl/scores_dict.json index de059790..b8668b10 100644 --- a/gia/eval/rl/scores_dict.json +++ b/gia/eval/rl/scores_dict.json @@ -7,6 +7,9 @@ "random": { "mean": 205.5, "std": 111.9676284790039 + }, + "human": { + "mean": 7127.7 } }, "atari-amidar": { @@ -17,6 +20,9 @@ "random": { "mean": 2.380000114440918, "std": 2.50111985206604 + }, + "human": { + "mean": 1719.5 } }, "atari-assault": { @@ -27,6 +33,9 @@ "random": { "mean": 262.5, "std": 89.6136703491211 + }, + "human": { + "mean": 742.0 } }, "atari-asterix": { @@ -37,6 +46,9 @@ "random": { "mean": 213.5, "std": 110.87267303466797 + }, + "human": { + "mean": 8503.3 } }, "atari-asteroids": { @@ -47,6 +59,9 @@ "random": { "mean": 856.4000244140625, "std": 434.3236083984375 + }, + "human": { + "mean": 47388.7 } }, "atari-atlantis": { @@ -57,6 +72,9 @@ "random": { "mean": 17764.0, "std": 6662.42529296875 + }, + "human": { + "mean": 29028.1 } }, "atari-bankheist": { @@ -67,6 +85,9 @@ "random": { "mean": 13.399999618530273, "std": 11.06525993347168 + }, + "human": { + "mean": 753.1 } }, "atari-battlezone": { 
@@ -77,6 +98,9 @@ "random": { "mean": 2170.0, "std": 2121.579345703125 + }, + "human": { + "mean": 37187.5 } }, "atari-beamrider": { @@ -87,6 +111,9 @@ "random": { "mean": 357.2799987792969, "std": 143.96542358398438 + }, + "human": { + "mean": 16926.5 } }, "atari-berzerk": { @@ -97,6 +124,9 @@ "random": { "mean": 160.10000610351562, "std": 118.86543273925781 + }, + "human": { + "mean": 2630.4 } }, "atari-bowling": { @@ -107,6 +137,9 @@ "random": { "mean": 23.809999465942383, "std": 6.074034690856934 + }, + "human": { + "mean": 160.7 } }, "atari-boxing": { @@ -117,6 +150,9 @@ "random": { "mean": 0.5199999809265137, "std": 4.36916446685791 + }, + "human": { + "mean": 12.1 } }, "atari-breakout": { @@ -127,6 +163,9 @@ "random": { "mean": 1.2400000095367432, "std": 1.2970736026763916 + }, + "human": { + "mean": 30.5 } }, "atari-centipede": { @@ -137,6 +176,9 @@ "random": { "mean": 2150.06005859375, "std": 1113.2806396484375 + }, + "human": { + "mean": 12017.0 } }, "atari-choppercommand": { @@ -147,6 +189,9 @@ "random": { "mean": 875.0, "std": 416.98321533203125 + }, + "human": { + "mean": 7387.8 } }, "atari-crazyclimber": { @@ -157,6 +202,9 @@ "random": { "mean": 7376.0, "std": 2253.092041015625 + }, + "human": { + "mean": 35829.4 } }, "atari-defender": { @@ -167,6 +215,9 @@ "random": { "mean": 3417.5, "std": 1443.4051513671875 + }, + "human": { + "mean": 18688.9 } }, "atari-demonattack": { @@ -177,6 +228,9 @@ "random": { "mean": 165.5500030517578, "std": 92.92710876464844 + }, + "human": { + "mean": 1971.0 } }, "atari-doubledunk": { @@ -187,6 +241,9 @@ "random": { "mean": -18.540000915527344, "std": 3.0705699920654297 + }, + "human": { + "mean": 16.4 } }, "atari-enduro": { @@ -197,6 +254,9 @@ "random": { "mean": 0.0, "std": 0.0 + }, + "human": { + "mean": 860.5 } }, "atari-fishingderby": { @@ -207,6 +267,9 @@ "random": { "mean": -93.9000015258789, "std": 3.514256477355957 + }, + "human": { + "mean": 38.7 } }, "atari-freeway": { @@ -217,6 +280,9 @@ "random": { "mean": 0.009999999776482582, "std": 0.09949874877929688 + }, + "human": { + "mean": 29.6 } }, "atari-frostbite": { @@ -227,6 +293,9 @@ "random": { "mean": 67.5999984741211, "std": 37.606380462646484 + }, + "human": { + "mean": 4334.7 } }, "atari-gopher": { @@ -237,6 +306,9 @@ "random": { "mean": 319.3999938964844, "std": 228.23594665527344 + }, + "human": { + "mean": 2412.5 } }, "atari-gravitar": { @@ -247,6 +319,9 @@ "random": { "mean": 188.5, "std": 203.32916259765625 + }, + "human": { + "mean": 3351.4 } }, "atari-hero": { @@ -257,6 +332,9 @@ "random": { "mean": 475.25, "std": 894.95263671875 + }, + "human": { + "mean": 30826.4 } }, "atari-icehockey": { @@ -267,6 +345,9 @@ "random": { "mean": -9.829999923706055, "std": 3.240540027618408 + }, + "human": { + "mean": 0.9 } }, "atari-jamesbond": { @@ -277,6 +358,9 @@ "random": { "mean": 28.5, "std": 45.41750717163086 + }, + "human": { + "mean": 302.8 } }, "atari-kangaroo": { @@ -287,6 +371,9 @@ "random": { "mean": 52.0, "std": 108.1480484008789 + }, + "human": { + "mean": 3035.0 } }, "atari-krull": { @@ -297,6 +384,9 @@ "random": { "mean": 1754.0, "std": 583.5597534179688 + }, + "human": { + "mean": 2665.5 } }, "atari-kungfumaster": { @@ -307,6 +397,9 @@ "random": { "mean": 390.0, "std": 359.0264587402344 + }, + "human": { + "mean": 22736.3 } }, "atari-montezumarevenge": { @@ -317,6 +410,9 @@ "random": { "mean": 0.0, "std": 0.0 + }, + "human": { + "mean": 4753.3 } }, "atari-mspacman": { @@ -327,6 +423,9 @@ "random": { "mean": 246.39999389648438, "std": 121.22309112548828 + }, + "human": { 
+ "mean": 6951.6 } }, "atari-namethisgame": { @@ -337,6 +436,9 @@ "random": { "mean": 2447.39990234375, "std": 888.9675903320312 + }, + "human": { + "mean": 8049.0 } }, "atari-phoenix": { @@ -347,6 +449,9 @@ "random": { "mean": 776.7999877929688, "std": 635.8551025390625 + }, + "human": { + "mean": 7242.6 } }, "atari-pitfall": { @@ -357,6 +462,9 @@ "random": { "mean": -259.75, "std": 384.2554016113281 + }, + "human": { + "mean": 6463.7 } }, "atari-pong": { @@ -367,6 +475,9 @@ "random": { "mean": -20.219999313354492, "std": 0.9547774791717529 + }, + "human": { + "mean": 14.6 } }, "atari-privateeye": { @@ -377,6 +488,9 @@ "random": { "mean": 41.650001525878906, "std": 191.82740783691406 + }, + "human": { + "mean": 69571.3 } }, "atari-qbert": { @@ -387,6 +501,9 @@ "random": { "mean": 164.25, "std": 151.7915802001953 + }, + "human": { + "mean": 13455.0 } }, "atari-riverraid": { @@ -397,6 +514,9 @@ "random": { "mean": 1474.4000244140625, "std": 314.5864562988281 + }, + "human": { + "mean": 17118.0 } }, "atari-roadrunner": { @@ -407,6 +527,9 @@ "random": { "mean": 11.0, "std": 42.178192138671875 + }, + "human": { + "mean": 7845.0 } }, "atari-robotank": { @@ -417,6 +540,9 @@ "random": { "mean": 1.8700000047683716, "std": 1.5852760076522827 + }, + "human": { + "mean": 11.9 } }, "atari-seaquest": { @@ -427,6 +553,9 @@ "random": { "mean": 73.19999694824219, "std": 57.91166305541992 + }, + "human": { + "mean": 42054.7 } }, "atari-skiing": { @@ -437,6 +566,9 @@ "random": { "mean": -16299.51953125, "std": 1850.6959228515625 + }, + "human": { + "mean": 4336.9 } }, "atari-solaris": { @@ -447,6 +579,9 @@ "random": { "mean": 2360.39990234375, "std": 1852.0279541015625 + }, + "human": { + "mean": 12326.7 } }, "atari-spaceinvaders": { @@ -457,6 +592,9 @@ "random": { "mean": 137.1999969482422, "std": 95.818359375 + }, + "human": { + "mean": 1668.7 } }, "atari-stargunner": { @@ -467,6 +605,9 @@ "random": { "mean": 652.0, "std": 312.2434997558594 + }, + "human": { + "mean": 10250.0 } }, "atari-surround": { @@ -477,6 +618,9 @@ "random": { "mean": -9.989999771118164, "std": 0.09949874877929688 + }, + "human": { + "mean": 6.5 } }, "atari-tennis": { @@ -487,6 +631,9 @@ "random": { "mean": -23.950000762939453, "std": 0.21794496476650238 + }, + "human": { + "mean": 8.3 } }, "atari-timepilot": { @@ -497,6 +644,9 @@ "random": { "mean": 3396.0, "std": 2128.845458984375 + }, + "human": { + "mean": 5229.2 } }, "atari-tutankham": { @@ -507,6 +657,9 @@ "random": { "mean": 12.729999542236328, "std": 17.39991569519043 + }, + "human": { + "mean": 167.6 } }, "atari-upndown": { @@ -517,6 +670,9 @@ "random": { "mean": 358.8999938964844, "std": 380.1102294921875 + }, + "human": { + "mean": 11693.2 } }, "atari-venture": { @@ -527,6 +683,9 @@ "random": { "mean": 0.0, "std": 0.0 + }, + "human": { + "mean": 1187.5 } }, "atari-videopinball": { @@ -537,6 +696,9 @@ "random": { "mean": 23917.169921875, "std": 19449.591796875 + }, + "human": { + "mean": 17667.9 } }, "atari-wizardofwor": { @@ -547,6 +709,9 @@ "random": { "mean": 620.0, "std": 837.8544311523438 + }, + "human": { + "mean": 4756.5 } }, "atari-yarsrevenge": { @@ -557,6 +722,9 @@ "random": { "mean": 3503.909912109375, "std": 906.144775390625 + }, + "human": { + "mean": 54576.9 } }, "atari-zaxxon": { @@ -567,6 +735,9 @@ "random": { "mean": 21.0, "std": 102.26924896240234 + }, + "human": { + "mean": 9173.3 } }, "babyai-action-obj-door": { diff --git a/gia2/__init__.py b/gia2/__init__.py new file mode 100644 index 00000000..1fa62ac8 --- /dev/null +++ b/gia2/__init__.py @@ -0,0 +1,5 
@@
+from .configuration_gia2 import Gia2Config
+from .modeling_gia2 import Gia2Model
+
+
+__all__ = ["Gia2Model", "Gia2Config"]
diff --git a/gia2/configuration_gia2.py b/gia2/configuration_gia2.py
new file mode 100644
index 00000000..83ef51a6
--- /dev/null
+++ b/gia2/configuration_gia2.py
@@ -0,0 +1,126 @@
+from transformers import GPTNeoConfig
+
+
+class Gia2Config(GPTNeoConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Gia2Model`]. It is used to instantiate a Gia2
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with
+    the defaults will yield a similar configuration to that of the ... (TODO)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50257):
+            Vocabulary size of the GPT Neo model. Defines the number of different tokens that can be represented by
+            the `inputs_ids` passed when calling [`GPTNeoModel`].
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_layers (`int`, *optional*, defaults to 24):
+            Number of hidden layers in the Transformer encoder.
+        attention_types (`List`, *optional*, defaults to `[[["global", "local"], 12]]`):
+            The type of attention for each layer in a `List` of the following format `[[["attention_type"],
+            num_layers]]` e.g. for a 24 layer model `[[["global"], 24]]` or `[[["global", "local"], 12]]`. Choose the
+            value of `attention_type` from `["global", "local"]`.
+        num_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 8192):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        window_size (`int`, *optional*, defaults to 256):
+            The size of the sliding window for local attention.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu_new"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        resid_dropout (`float`, *optional*, defaults to 0.0):
+            Residual dropout used in the attention pattern.
+        embed_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        classifier_dropout (`float`, *optional*, defaults to 0.1):
+            Argument used when doing token classification, used in the model [`GPTNeoForTokenClassification`]. The
+            dropout ratio for the hidden layer.
+        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
+            The epsilon used by the layer normalization layers.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        bos_token_id (`int`, *optional*, defaults to 50256):
+            The id of the beginning of sentence token in the vocabulary.
+        eos_token_id (`int`, *optional*, defaults to 50256):
+            The id of the end of sentence token in the vocabulary.
+        max_continuous_size (`int`, *optional*, defaults to 377):
+            The maximum size of the continuous values.
+        max_discrete_value (`int`, *optional*, defaults to 18):
+            The maximum value of the discrete values.
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of channels of the input images.
+        patch_size (`int`, *optional*, defaults to 16):
+            The size (resolution) of each patch.
+    """
+
+    model_type = "gia2"
+
+    def __init__(
+        self,
+        vocab_size=50257,
+        max_position_embeddings=2048,
+        hidden_size=2048,
+        num_layers=24,
+        attention_types=[[["global", "local"], 12]],
+        num_heads=16,
+        intermediate_size=None,
+        window_size=256,
+        activation_function="gelu_new",
+        resid_dropout=0.0,
+        embed_dropout=0.0,
+        attention_dropout=0.0,
+        classifier_dropout=0.1,
+        layer_norm_epsilon=1e-5,
+        initializer_range=0.02,
+        use_cache=True,
+        bos_token_id=50256,
+        eos_token_id=50256,
+        max_continuous_size=377,
+        max_discrete_value=18,
+        image_size=224,
+        num_channels=3,
+        patch_size=16,
+        **kwargs,
+    ):
+        super().__init__(
+            vocab_size,
+            max_position_embeddings,
+            hidden_size,
+            num_layers,
+            attention_types,
+            num_heads,
+            intermediate_size,
+            window_size,
+            activation_function,
+            resid_dropout,
+            embed_dropout,
+            attention_dropout,
+            classifier_dropout,
+            layer_norm_epsilon,
+            initializer_range,
+            use_cache,
+            bos_token_id,
+            eos_token_id,
+            **kwargs,
+        )
+        self.max_continuous_size = max_continuous_size
+        self.max_discrete_value = max_discrete_value
+        self.image_size = image_size
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+
+
+Gia2Config.register_for_auto_class()
diff --git a/gia2/modeling_gia2.py b/gia2/modeling_gia2.py
new file mode 100644
index 00000000..cc7a5acd
--- /dev/null
+++ b/gia2/modeling_gia2.py
@@ -0,0 +1,822 @@
+import warnings
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from gymnasium import spaces
+from torch import BoolTensor, FloatTensor, LongTensor, Tensor, nn
+from transformers import GPTNeoModel, GPTNeoPreTrainedModel
+from transformers.modeling_outputs import ModelOutput
+from transformers.models.vit.modeling_vit import ViTPatchEmbeddings
+
+from .configuration_gia2 import Gia2Config
+from .processing_gia2 import Gia2Processor
+
+
+def compute_mse_loss(
+    predicted: FloatTensor, true: FloatTensor, mask: Optional[BoolTensor], weights: Optional[FloatTensor] = None
+) -> FloatTensor:
+    """
+    Compute the Mean Squared Error (MSE) loss between predicted and true observations, considering valid timesteps.
+
+    Args:
+        predicted (`FloatTensor` of shape `(batch_size, max_seq_len, ...)`):
+            Predicted observations at the output of the model.
+        true (`FloatTensor` of shape `(batch_size, max_seq_len, ...)`):
+            Ground truth observations.
+        mask (`BoolTensor` of shape `(batch_size, max_seq_len)`, *optional*):
+            Boolean mask indicating valid timesteps.
+        weights (`FloatTensor` of shape `(batch_size, max_seq_len)`, *optional*):
+            Weights to be applied to the loss.
+
+    Returns:
+        loss (`FloatTensor` of shape `(,)`):
+            MSE loss between predicted and true observations.
+    """
+    # Compute element-wise MSE loss
+    loss = F.mse_loss(predicted, true, reduction="none")
+
+    # Average the loss over all dimensions after the second one
+    for dim in reversed(range(2, loss.dim())):
+        loss = loss.mean(dim=dim)
+
+    # Use the mask to zero out invalid entries
+    if mask is not None:
+        loss = loss * mask
+
+    # Apply weights if provided
+    if weights is not None:
+        loss = loss * weights
+
+    # Sum the loss and normalize by the number of valid elements
+    loss = loss.sum() / mask.sum() if mask is not None else loss.mean()
+
+    return loss
+
+
+def compute_ce_loss(
+    predicted: FloatTensor, true: torch.LongTensor, mask: Optional[BoolTensor], weights: Optional[FloatTensor] = None
+) -> FloatTensor:
+    """
+    Compute the Cross Entropy (CE) loss between predicted logits and true class labels, considering valid timesteps.
+
+    Args:
+        predicted (`FloatTensor` of shape `(batch_size, max_seq_len, num_classes)`):
+            Predicted logits at the output of the model.
+        true (`torch.LongTensor` of shape `(batch_size, max_seq_len)`):
+            Ground truth class labels.
+        mask (`BoolTensor` of shape `(batch_size, max_seq_len)`, *optional*):
+            Boolean mask indicating valid timesteps.
+        weights (`FloatTensor` of shape `(batch_size, max_seq_len)`, *optional*):
+            Weights to be applied to the loss.
+
+    Returns:
+        loss (`FloatTensor` of shape `(,)`):
+            CE loss between predicted logits and true class labels.
+    """
+
+    # Compute element-wise CE loss
+    loss = F.cross_entropy(predicted.view(-1, predicted.size(-1)), true.view(-1), reduction="none")
+    loss = loss.view(true.size())
+
+    # Use the mask to zero out invalid entries
+    if mask is not None:
+        loss = loss * mask
+
+    # Apply weights if provided
+    if weights is not None:
+        loss = loss * weights
+
+    # Sum the loss and normalize by the number of valid elements
+    loss = loss.sum() / mask.sum() if mask is not None else loss.mean()
+
+    return loss
+
+
+def cyclic_expand_dim(tensor: Tensor, expanded_dim_size: int) -> Tensor:
+    """
+    Expands the last dimension of a tensor cyclically to a specified size.
+
+    Args:
+        tensor (`torch.Tensor` of shape `(batch_size, seq_len, ...)`):
+            Input tensor whose last dimension is to be expanded cyclically.
+        expanded_dim_size (`int`):
+            The desired size of the last dimension after expansion.
+
+    Returns:
+        `torch.Tensor` of shape `(batch_size, seq_len, expanded_dim_size)`:
+            A tensor with its last dimension expanded cyclically to the specified size.
+
+    Examples:
+        >>> tensor = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+        >>> cyclic_expand_dim(tensor, 5)
+        tensor([[[1, 2, 1, 2, 1], [3, 4, 3, 4, 3]], [[5, 6, 5, 6, 5], [7, 8, 7, 8, 7]]])
+    """
+    B, L, X = tensor.shape
+    if expanded_dim_size < X:
+        raise ValueError(
+            f"Expanded dimension size ({expanded_dim_size}) must be greater than or equal to the original "
+            f"dimension size ({X})."
+        )
+    indices = torch.arange(expanded_dim_size) % X
+    return tensor[..., indices]
+
+
+class ResidualBlock(nn.Module):
+    """
+    A residual block module that consists of two convolutional layers with a residual connection.
+
+    Args:
+        in_shape (`Tuple[int, int, int]`):
+            Shape of the input tensor.
+        out_channels (`int`):
+            Number of output channels.
+
+    Returns:
+        `torch.Tensor` of shape `(batch_size, out_channels, in_shape[1], in_shape[2])`:
+            Output tensor.
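+
+    Example (shape-only sketch; the channel count and spatial size are arbitrary):
+        >>> block = ResidualBlock((64, 84, 84), out_channels=128)
+        >>> block(torch.rand(8, 64, 84, 84)).shape
+        torch.Size([8, 128, 84, 84])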
+    """
+
+    def __init__(self, in_shape: Tuple[int, int, int], out_channels: int) -> None:
+        super().__init__()
+        out_shape = (out_channels, in_shape[1], in_shape[2])
+
+        self.conv1 = nn.Conv2d(in_shape[0], out_channels, kernel_size=3, stride=1, padding=1)
+        self.norm1 = nn.LayerNorm(out_shape)
+        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.norm2 = nn.LayerNorm(out_shape)
+
+        # Handling the change in dimensions with a 1x1 convolution
+        self.shortcut = nn.Sequential(
+            nn.Conv2d(in_shape[0], out_channels, kernel_size=1, stride=1), nn.LayerNorm(out_shape)
+        )
+
+    def forward(self, x: FloatTensor) -> FloatTensor:
+        out = F.leaky_relu(self.norm1(self.conv1(x)))
+        out = self.norm2(self.conv2(out))
+        out += self.shortcut(x)
+        return F.leaky_relu(out, inplace=True)
+
+
+class AttentionLayer(nn.Module):
+    """
+    Channel-attention layer (squeeze-and-excitation style) that rescales each channel of the input tensor by a
+    learned, input-dependent weight.
+
+    Args:
+        num_channels (`int`):
+            Number of channels.
+
+    Returns:
+        `torch.Tensor`:
+            Output tensor of the same shape as the input tensor.
+    """
+
+    def __init__(self, num_channels: int) -> None:
+        super().__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(num_channels, num_channels // 8, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(num_channels // 8, num_channels, bias=False),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x: FloatTensor) -> FloatTensor:
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y.expand_as(x)
+
+
+class ImageEncoder(nn.Module):
+    """
+    Image encoder that encodes a batch of images.
+
+    Args:
+        hidden_size (`int`):
+            Size of the output hidden state.
+
+    Returns:
+        `torch.Tensor` of shape `(batch_size, hidden_size)`:
+            Output tensor.
+    """
+
+    def __init__(self, hidden_size: int) -> None:
+        super().__init__()
+        self.conv1 = nn.Conv2d(4, 32, kernel_size=3, stride=2, padding=1)  # 42x42
+        self.norm1 = nn.InstanceNorm2d(32)
+        self.att1 = AttentionLayer(32)
+        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)  # 21x21
+        self.norm2 = nn.InstanceNorm2d(64)
+        self.att2 = AttentionLayer(64)
+        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)  # 11x11
+        self.norm3 = nn.InstanceNorm2d(128)
+        self.att3 = AttentionLayer(128)
+        self.fc = nn.Linear(128 * 11 * 11, hidden_size)  # Adjusted to the new spatial dimension
+
+    def forward(self, x: FloatTensor) -> FloatTensor:
+        x = F.leaky_relu(self.norm1(self.conv1(x)), inplace=True)
+        x = self.att1(x)
+        x = F.leaky_relu(self.norm2(self.conv2(x)), inplace=True)
+        x = self.att2(x)
+        x = F.leaky_relu(self.norm3(self.conv3(x)), inplace=True)
+        x = self.att3(x)
+        x = x.view(x.size(0), -1)  # Flatten the tensor
+        x = self.fc(x)
+        return x
+
+
+class ImageDecoder(nn.Module):
+    """
+    Image decoder that decodes a batch of encoded representations.
+
+    Args:
+        hidden_size (`int`):
+            Size of the input hidden state.
+
+    Returns:
+        `torch.Tensor` of shape `(batch_size, 4, 84, 84)`:
+            Output tensor representing the reconstructed images.
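+
+    Example (shape-only sketch):
+        >>> decoder = ImageDecoder(hidden_size=2048)
+        >>> decoder(torch.rand(8, 2048)).shape
+        torch.Size([8, 4, 84, 84])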
+    """
+
+    def __init__(self, hidden_size: int) -> None:
+        super().__init__()
+        self.fc = nn.Linear(hidden_size, 128 * 11 * 11)
+        self.deconv1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)  # 22x22
+        self.norm1 = nn.InstanceNorm2d(64)
+        self.att1 = AttentionLayer(64)
+        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)  # 42x42
+        self.norm2 = nn.InstanceNorm2d(32)
+        self.att2 = AttentionLayer(32)
+        self.deconv3 = nn.ConvTranspose2d(32, 4, kernel_size=3, stride=2, padding=1, output_padding=1)  # 84x84
+
+    def forward(self, x: FloatTensor) -> FloatTensor:
+        x = self.fc(x)
+        x = x.view(x.size(0), 128, 11, 11)  # Reshape to the spatial dimension of encoder's last conv layer
+        x = F.leaky_relu(self.norm1(self.deconv1(x)), inplace=True)  # 22x22
+        x = F.interpolate(x, size=(21, 21))  # 21x21
+        x = self.att1(x)
+        x = F.leaky_relu(self.norm2(self.deconv2(x)), inplace=True)
+        x = self.att2(x)
+        x = torch.tanh(self.deconv3(x))
+        return x
+
+
+class DualBatchReshapeWrapper(nn.Module):
+    """
+    Wrapper to make a module designed for a single batch work with a dual batch.
+
+    Args:
+        module (`nn.Module`):
+            Module to be wrapped.
+    """
+
+    def __init__(self, module: nn.Module) -> None:
+        super().__init__()
+        self.module = module
+
+    def forward(self, x: FloatTensor) -> FloatTensor:
+        n1, n2 = x.shape[:2]
+        x = x.view(n1 * n2, *x.shape[2:])
+        x = self.module(x)
+        x = x.view(n1, n2, *x.shape[1:])
+        return x
+
+
+@dataclass
+class Gia2Output(ModelOutput):
+    """
+    Output of the Gia2 model.
+
+    The model can be used for both RL and NLP tasks. For RL tasks, the model takes in observations and actions
+    (`continuous_observations`, `discrete_actions`, etc.). For textual tasks, the model takes in a sequence of tokens
+    and/or images (`input_ids`, `pixel_values`). The output depends on the type of input.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+            For RL input, the loss is the sum of the observation loss and the action loss.
+            For textual input, the causal language modeling loss.
+        observation_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+            Only returned when RL input is provided. The MSE loss between predicted and true observations for
+            continuous observations and the cross-entropy loss for discrete observations.
+        action_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+            Only returned when RL input is provided. The MSE loss between predicted and true actions for
+            continuous actions and the cross-entropy loss for discrete actions.
+        pred_observations (`torch.FloatTensor` of shape `(batch_size, max_seq_len, ...)`):
+            Only returned when RL input is provided. Predicted observations from t=1 to t=max_seq_len+1.
+        pred_actions (`torch.FloatTensor` of shape `(batch_size, max_seq_len, ...)`):
+            Only returned when RL input is provided. Predicted actions from t=0 to t=max_seq_len. When input actions
+            are discrete, the predicted actions are logits.
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
+            Only returned when textual input is provided. Prediction scores of the language modeling head (scores
+            for each vocabulary token before softmax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or + when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or + when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[FloatTensor] = None + observation_loss: Optional[FloatTensor] = None + action_loss: Optional[FloatTensor] = None + pred_observations: Optional[FloatTensor] = None + pred_actions: Optional[FloatTensor] = None + logits: Optional[FloatTensor] = None + past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None + hidden_states: Optional[Tuple[FloatTensor]] = None + attentions: Optional[Tuple[FloatTensor]] = None + + +class Gia2Model(GPTNeoPreTrainedModel): + """ + Gia2 model. 
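+
+    A single GPT-Neo backbone handles both textual inputs (token ids and images) and RL inputs (observations,
+    actions and rewards); see `forward` for the accepted keyword arguments.
+
+    Example (illustrative; a deliberately small configuration, not a released checkpoint):
+        >>> config = Gia2Config(hidden_size=512, num_layers=2, attention_types=[[["global"], 2]], num_heads=8)
+        >>> model = Gia2Model(config)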
+    """
+
+    config_class = Gia2Config
+
+    def __init__(self, config: Gia2Config) -> None:
+        super().__init__(config)
+
+        vocab_size = config.vocab_size
+        hidden_size = config.hidden_size
+        max_discrete_value = config.max_discrete_value
+        max_continuous_size = config.max_continuous_size
+
+        # Transformer
+        self.transformer = GPTNeoModel(config)
+
+        # Encoders
+        self.vit_encoder = ViTPatchEmbeddings(config)
+        self.single_discrete_encoder = self.transformer.wte
+        self.continuous_encoder = nn.Linear(max_continuous_size, hidden_size)
+        self.multi_discrete_encoder = nn.Sequential(
+            self.single_discrete_encoder,  # (B, L, X, H)
+            nn.Linear(hidden_size, hidden_size // 50),  # (B, L, X, H // 50)
+            nn.ReLU(),
+            nn.Flatten(start_dim=2),  # (B, L, X * (H // 50))
+            nn.Linear(max_discrete_value * (hidden_size // 50), hidden_size - 1),  # (B, L, H - 1)
+        )  # -1 to account for the reward
+        self.image_encoder = DualBatchReshapeWrapper(ImageEncoder(hidden_size))
+
+        # Decoders
+        self.single_discrete_decoder = nn.Linear(hidden_size, vocab_size, bias=False)
+        self.continuous_decoder = nn.Linear(hidden_size, max_continuous_size)
+        self.multi_discrete_decoder = nn.Sequential(
+            nn.Linear(hidden_size, max_discrete_value * (hidden_size // 50)),  # (B, L, X * (H // 50))
+            nn.Unflatten(dim=2, unflattened_size=(max_discrete_value, hidden_size // 50)),  # (B, L, X, H // 50)
+            nn.ReLU(),
+            nn.Linear(hidden_size // 50, hidden_size),  # (B, L, X, H)
+            nn.ReLU(),
+            self.single_discrete_decoder,  # (B, L, X, V)
+        )
+        self.image_decoder = DualBatchReshapeWrapper(ImageDecoder(hidden_size))
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def embed_textual(
+        self,
+        input_ids: Optional[LongTensor],
+        pixel_values: Optional[FloatTensor] = None,
+        attention_mask: Optional[BoolTensor] = None,
+    ) -> Tuple[Tensor, Optional[BoolTensor]]:
+        text_inputs_embeds = self.single_discrete_encoder(input_ids) if input_ids is not None else None
+        image_inputs_embeds = self.vit_encoder(pixel_values) if pixel_values is not None else None
+        # Concatenate text and image inputs
+        if image_inputs_embeds is not None and text_inputs_embeds is not None:
+            inputs_embeds = torch.cat((image_inputs_embeds, text_inputs_embeds), dim=1)
+            # Add attention mask for image inputs
+            image_mask = torch.ones(image_inputs_embeds.shape[:2], dtype=torch.bool, device=self.device)
+            if attention_mask is None:
+                attention_mask = torch.ones(text_inputs_embeds.shape[:2], dtype=torch.bool, device=self.device)
+            attention_mask = torch.cat((image_mask, attention_mask), dim=1)
+        elif image_inputs_embeds is not None:
+            inputs_embeds = image_inputs_embeds
+        elif text_inputs_embeds is not None:
+            inputs_embeds = text_inputs_embeds
+        else:
+            raise ValueError("At least one of `input_ids` or `pixel_values` must be provided.")
+        return inputs_embeds, attention_mask
+
+    def embed_rl(
+        self,
+        continuous_observations: Optional[FloatTensor] = None,
+        discrete_observations: Optional[LongTensor] = None,
+        image_observations: Optional[FloatTensor] = None,
+        continuous_actions: Optional[FloatTensor] = None,
+        discrete_actions: Optional[LongTensor] = None,
+        rewards: Optional[FloatTensor] = None,
+        attention_mask: Optional[BoolTensor] = None,
+    ):
+        # Prepare RL inputs (pad and cat rewards to observations)
+        assert rewards is not None
+        if continuous_observations is not None:
+            # Modify the rewards to move from [r_1, r_2, ..., r_T] to [0, r_1, r_2, ..., r_T-1]
+            rewards = torch.cat((torch.zeros_like(rewards[:, :1]), rewards[:, :-1]), dim=1)
+            continuous_observations = 
torch.cat((continuous_observations, rewards.unsqueeze(-1)), dim=-1) + continuous_observations = cyclic_expand_dim(continuous_observations, self.config.max_continuous_size) + if continuous_actions is not None: + continuous_actions = cyclic_expand_dim(continuous_actions, self.config.max_continuous_size) + + # Encode + if continuous_observations is not None: + batch_size, seq_len = continuous_observations.shape[:2] + inputs_embeds_observations = self.continuous_encoder(continuous_observations) + elif discrete_observations is not None: + batch_size, seq_len = discrete_observations.shape[:2] + inputs_embeds_observations = self.multi_discrete_encoder(discrete_observations) + # Modify the rewards to move from [r_1, r_2, ..., r_T] to [0, r_1, r_2, ..., r_T-1] + rewards = torch.cat((torch.zeros_like(rewards[:, :1]), rewards[:, :-1]), dim=1) + inputs_embeds_observations = torch.cat((inputs_embeds_observations, rewards.unsqueeze(-1)), dim=-1) + elif image_observations is not None: + batch_size, seq_len = image_observations.shape[:2] + inputs_embeds_observations = self.image_encoder(image_observations) + else: + raise ValueError("Missing observations.") + if continuous_actions is not None: + inputs_embeds_actions = self.continuous_encoder(continuous_actions) + elif discrete_actions is not None: + inputs_embeds_actions = self.single_discrete_encoder(discrete_actions) + else: + raise ValueError("Missing actions.") + + # Concatenate observations and actions + inputs_embeds = torch.cat((inputs_embeds_observations, inputs_embeds_actions), dim=2) + inputs_embeds = inputs_embeds.view(batch_size, 2 * seq_len, self.config.hidden_size) + if attention_mask is not None: + attention_mask = torch.repeat_interleave(attention_mask, repeats=2, dim=1) + return inputs_embeds, attention_mask + + def output_textual( + self, + transformer_outputs, + input_ids: Optional[LongTensor] = None, + attention_mask: Optional[BoolTensor] = None, + return_loss: bool = True, + return_dict: Optional[bool] = None, + ): + hidden_states = transformer_outputs[0] + loss = None + # Get only textual hidden states + lm_logits = self.single_discrete_decoder(hidden_states) + if return_loss: + if input_ids is None: + raise ValueError("Input IDs must be provided when `return_loss=True`.") + + # Shift so that tokens < n predict n + num_text_tokens = input_ids.shape[1] + shift_logits = lm_logits[:, -num_text_tokens:-1, :].contiguous() + shift_labels = input_ids[:, 1:].contiguous() + if attention_mask is not None: + shift_attention_mask = attention_mask[:, -num_text_tokens:] + shift_attention_mask = shift_attention_mask[:, 1:] + else: + shift_attention_mask = torch.ones(shift_labels.shape, dtype=bool, device=self.device) + shift_logits = shift_logits[shift_attention_mask.bool()] + shift_labels = shift_labels[shift_attention_mask.bool()] + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Gia2Output( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def output_rl( + self, + transformer_outputs, + continuous_observations: Optional[FloatTensor] = None, + discrete_observations: Optional[LongTensor] = None, + image_observations: Optional[FloatTensor] = None, + continuous_actions: Optional[FloatTensor] = None, + discrete_actions: 
Optional[LongTensor] = None,
+        rewards: Optional[FloatTensor] = None,
+        attention_mask: Optional[BoolTensor] = None,
+        return_loss: bool = True,
+        return_dict: Optional[bool] = None,
+        loss_weight: Optional[FloatTensor] = None,
+    ):
+        hidden_states = transformer_outputs.last_hidden_state
+        loss, observation_loss, action_loss = None, None, None
+        # Observations
+        assert rewards is not None
+        observations_mask = attention_mask[:, 1::2] if attention_mask is not None else None
+        if continuous_observations is not None:
+            # Modify the rewards to move from [r_1, r_2, ..., r_T] to [0, r_1, r_2, ..., r_T-1]
+            obs_size = continuous_observations.shape[-1]
+            rewards = torch.cat((torch.zeros_like(rewards[:, :1]), rewards[:, :-1]), dim=1)
+            continuous_observations = torch.cat((continuous_observations, rewards.unsqueeze(-1)), dim=-1)
+            continuous_observations = cyclic_expand_dim(continuous_observations, self.config.max_continuous_size)
+            pred_observations = self.continuous_decoder(hidden_states[:, 1::2])
+            if return_loss:
+                observation_loss = compute_mse_loss(
+                    pred_observations[:, :-1],
+                    continuous_observations[:, 1:],
+                    observations_mask[:, 1:] if observations_mask is not None else None,
+                    weights=loss_weight[:, 1:] if loss_weight is not None else None,
+                )
+            pred_observations = pred_observations[..., :obs_size]
+        elif discrete_observations is not None:  # Note: reward is not predicted
+            warnings.warn("Observations are not predicted, as predicting them is highly memory-demanding.")
+            pred_observations = None
+            observation_loss = 0.0
+            # pred_observations = self.multi_discrete_decoder(hidden_states[:, 1::2])
+            # if return_loss:
+            #     observation_loss = compute_ce_loss(
+            #         pred_observations[:, :-1],
+            #         discrete_observations[:, 1:],
+            #         observations_mask[:, 1:] if observations_mask is not None else None,
+            #         weights=loss_weight[:, 1:] if loss_weight is not None else None,
+            #     )
+        elif image_observations is not None:
+            warnings.warn("Observations are not predicted, as predicting them is highly memory-demanding.")
+            pred_observations = None
+            observation_loss = 0.0
+            # pred_observations = self.image_decoder(hidden_states[:, 1::2])
+            # if return_loss:
+            #     observation_loss = compute_mse_loss(
+            #         pred_observations[:, :-1],
+            #         image_observations[:, 1:],
+            #         observations_mask[:, 1:] if observations_mask is not None else None,
+            #         weights=loss_weight[:, 1:] if loss_weight is not None else None,
+            #     )
+
+        # Actions
+        actions_mask = attention_mask[:, ::2] if attention_mask is not None else None
+        if continuous_actions is not None:
+            act_size = continuous_actions.shape[-1]
+            continuous_actions = cyclic_expand_dim(continuous_actions, self.config.max_continuous_size)
+            pred_actions = self.continuous_decoder(hidden_states[:, ::2])
+            if return_loss:
+                action_loss = compute_mse_loss(pred_actions, continuous_actions, actions_mask, weights=loss_weight)
+            pred_actions = pred_actions[..., :act_size]
+        elif discrete_actions is not None:
+            pred_actions = self.single_discrete_decoder(hidden_states[:, ::2])
+            if return_loss:
+                action_loss = compute_ce_loss(pred_actions, discrete_actions, actions_mask, weights=loss_weight)
+
+        # Return output
+        if return_loss:
+            # Observation prediction is currently disabled: only the action loss contributes to the total loss
+            loss = 0.0 * observation_loss + 1.0 * action_loss
+
+        if not return_dict:
+            output = (pred_observations, pred_actions) + transformer_outputs[1:]
+            return ((loss, observation_loss, action_loss) + output) if loss is not None else output
+
+        return Gia2Output(
+            loss=loss,
+            observation_loss=observation_loss,
+            action_loss=action_loss,
+            pred_observations=pred_observations,
+            
pred_actions=pred_actions, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def forward( + self, + input_ids: Optional[LongTensor] = None, + pixel_values: Optional[FloatTensor] = None, + continuous_observations: Optional[FloatTensor] = None, + discrete_observations: Optional[LongTensor] = None, + image_observations: Optional[FloatTensor] = None, + continuous_actions: Optional[FloatTensor] = None, + discrete_actions: Optional[LongTensor] = None, + rewards: Optional[FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None, + attention_mask: Optional[BoolTensor] = None, + token_type_ids: Optional[LongTensor] = None, + position_ids: Optional[LongTensor] = None, + return_loss: bool = True, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + loss_weight: Optional[FloatTensor] = None, + ) -> Gia2Output: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Textual tasks + if input_ids is not None or pixel_values is not None: + inputs_embeds, attention_mask = self.embed_textual(input_ids, pixel_values, attention_mask) + # RL tasks + elif ( + continuous_observations is not None or discrete_observations is not None or image_observations is not None + ): + inputs_embeds, attention_mask = self.embed_rl( + continuous_observations, + discrete_observations, + image_observations, + continuous_actions, + discrete_actions, + rewards, + attention_mask, + ) + else: + raise ValueError("Input not provided.") + + # Pass through transformer + transformer_outputs = self.transformer( + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if input_ids is not None or pixel_values is not None: + return self.output_textual(transformer_outputs, input_ids, attention_mask, return_loss, return_dict) + else: + return self.output_rl( + transformer_outputs, + continuous_observations, + discrete_observations, + image_observations, + continuous_actions, + discrete_actions, + rewards, + attention_mask, + return_loss, + return_dict, + loss_weight, + ) + + def reset_rl(self): + self._last_key_values = None + self.last_discrete_observation = None + self.last_continuous_observation = None + self.last_text_observation = None + self.last_image_observation = None + self.last_discrete_action = None + self.last_continuous_action = None + self.last_reward = None + + @torch.no_grad() + def get_next_action( + self, + processor: Gia2Processor, + continuous_observations: Optional[List[List[float]]] = None, + discrete_observations: Optional[List[List[int]]] = None, + text_observations: Optional[List[str]] = None, + image_observations: Optional[List[np.ndarray]] = None, + action_space: Union[spaces.Box, spaces.Discrete] = None, + rewards: Optional[List[float]] = None, + deterministic: bool = False, + ): + # Get the maximum sequence length + max_length = self.config.max_position_embeddings // 2 + + # Convert everything to lists + def to_list(x): + return x.tolist() if isinstance(x, np.ndarray) else x + + continuous_observations = to_list(continuous_observations) + discrete_observations = 
to_list(discrete_observations) + + # Add a fake action and reward to the end of the sequence + if isinstance(action_space, spaces.Box): + fake_continuous_actions = [0.0 for _ in range(action_space.shape[0])] + fake_discrete_actions = None + elif isinstance(action_space, spaces.Discrete): + fake_continuous_actions = None + fake_discrete_actions = 0 + fake_rewards = 0.0 + + continuous_observations = [continuous_observations] if continuous_observations is not None else None + discrete_observations = [discrete_observations] if discrete_observations is not None else None + text_observations = [text_observations] if text_observations is not None else None + image_observations = [image_observations] if image_observations is not None else None + continuous_actions = [fake_continuous_actions] if fake_continuous_actions is not None else None + discrete_actions = [fake_discrete_actions] if fake_discrete_actions is not None else None + rewards = [fake_rewards] + if self._last_key_values is not None: + # We concatenate the last observation with the current one + continuous_observations = ( + [self.last_continuous_observation] + continuous_observations + if continuous_observations is not None + else None + ) + discrete_observations = ( + [self.last_discrete_observation] + discrete_observations if discrete_observations is not None else None + ) + text_observations = ( + [self.last_text_observation] + text_observations if text_observations is not None else None + ) + image_observations = ( + [self.last_image_observation] + image_observations if image_observations is not None else None + ) + continuous_actions = ( + [self.last_continuous_action] + continuous_actions if continuous_actions is not None else None + ) + discrete_actions = [self.last_discrete_action] + discrete_actions if discrete_actions is not None else None + rewards = [self.last_reward] + rewards + + # Store the last observation + self.last_continuous_observation = continuous_observations[-1] if continuous_observations is not None else None + self.last_discrete_observation = discrete_observations[-1] if discrete_observations is not None else None + self.last_text_observation = text_observations[-1] if text_observations is not None else None + self.last_image_observation = image_observations[-1] if image_observations is not None else None + self.last_reward = rewards[-1] + + # Add the batch dimension + continuous_observations = [continuous_observations] if continuous_observations is not None else None + discrete_observations = [discrete_observations] if discrete_observations is not None else None + text_observations = [text_observations] if text_observations is not None else None + image_observations = [image_observations] if image_observations is not None else None + continuous_actions = [continuous_actions] if continuous_actions is not None else None + discrete_actions = [discrete_actions] if discrete_actions is not None else None + rewards = [rewards] + + # Process the inputs + processed = processor( + continuous_observations=continuous_observations, + discrete_observations=discrete_observations, + text_observations=text_observations, + image_observations=image_observations, + continuous_actions=continuous_actions, + discrete_actions=discrete_actions, + rewards=rewards, + truncation=True, + truncation_side="left", + max_length=max_length, + return_tensors="pt", + ) + processed.to(self.device) + + # Forward pass + outputs = self(**processed, past_key_values=self._last_key_values, return_loss=False) + + # Truncate the past key-values + 
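+ # The clipping below keeps the most recent max_position_embeddings - 2 cache entries; two more + # are then dropped, since the current observation and the placeholder action are re-fed, + # together with the next observation/action pair, on the following call.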
self._last_key_values = tuple( + tuple(pkv[:, :, -self.config.max_position_embeddings + 2 :] for pkv in pkvs) + for pkvs in outputs.past_key_values + ) + # Remove the last two cached values: the inputs are [s_0, 0], [s_0, a_0, s_1, 0], [s_1, a_1, s_2, 0], ..., + # so the current observation and the placeholder action are re-fed at the next call + self._last_key_values = tuple(tuple(pkv[:, :, :-2] for pkv in pkvs) for pkvs in self._last_key_values) + + # Return the predicted action + if continuous_actions is not None: + self.last_continuous_action = outputs.pred_actions[0, -1].cpu().tolist() + return self.last_continuous_action + elif discrete_actions is not None: + logits = outputs.pred_actions[0, -1, : action_space.n] + if deterministic: + self.last_discrete_action = logits.argmax().cpu().item() + else: # sample + self.last_discrete_action = torch.multinomial(logits.softmax(dim=-1), num_samples=1)[0].item() + return self.last_discrete_action + + # Allows the model to be used with .generate() + def prepare_inputs_for_generation(self, input_ids, pixel_values=None, past_key_values=None, **kwargs): + # only keep the last token of input_ids if past_key_values is defined + if past_key_values is not None: + pixel_values = None + input_ids = input_ids[:, -1].unsqueeze(-1) + + model_inputs = { + "input_ids": input_ids, + "pixel_values": pixel_values, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + } + + return model_inputs + + +Gia2Model.register_for_auto_class("AutoModelForCausalLM") diff --git a/gia2/processing_gia2.py b/gia2/processing_gia2.py new file mode 100644 index 00000000..e90c40b6 --- /dev/null +++ b/gia2/processing_gia2.py @@ -0,0 +1,371 @@ +import copy +from typing import Any, Dict, List, Optional, Union + +import torch +from torchvision.transforms.functional import to_tensor +from transformers import BatchEncoding +from transformers.processing_utils import ProcessorMixin + + +def truncate( + encoding: Dict[str, List[List[Any]]], max_length: int, truncation_side: str = "right", preserve: bool = False +) -> Dict[str, List[List[Any]]]: + """ + Truncate the sequences in the encoding to the specified maximum length. + + This function is designed to process a batch of sequences represented in the encoding dictionary. + Depending on the chosen strategy, sequences are either truncated (discarding the residual data) or the + residual data is preserved and appended to the batch as new sequences. + + Args: + encoding (`Mapping`): + A dictionary where each key-value pair consists of a feature name and its corresponding batch of sequences. + The sequences are expected to be lists. + max_length (`int`): + The maximum allowable length for the sequences. + truncation_side (`str`, **optional**): + The strategy to use for truncation. Can be `"left"` or `"right"`. Defaults to `"right"`. + preserve (`bool`, **optional**): + Whether to preserve the residual data by adding them as new sequences in the batch. Defaults to `False`. + + Returns: + `Dict[str, List[List[Any]]]`: + A dictionary with the same keys as the input `encoding`, containing the truncated batch of sequences. + If `preserve` is set to `True`, the batch size may increase due to the addition of new sequences formed + from the residual data.
+ + Example: + + >>> encoding = {'feature1': [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]} + >>> truncate(encoding, 3, preserve=False) + {'feature1': [[1, 2, 3], [6, 7, 8]]} + + >>> truncate(encoding, 3, preserve=True) + {'feature1': [[1, 2, 3], [4, 5], [6, 7, 8], [9, 10]]} + + >>> truncate(encoding, 3, truncation_side="left", preserve=True) + {'feature1': [[3, 4, 5], [1, 2], [8, 9, 10], [6, 7]]} + """ + truncated_encoding = {} + + for key, sequences in encoding.items(): + if not all(isinstance(seq, list) for seq in sequences): + raise TypeError(f"All sequences under key {key} should be of type list.") + + truncated_sequences = [] + + for seq in sequences: + if len(seq) <= max_length: + truncated_sequences.append(seq) + continue + + if preserve: # truncate and append the residual as new sequences + if truncation_side == "right": + truncated_sequences.extend([seq[i : i + max_length] for i in range(0, len(seq), max_length)]) + elif truncation_side == "left": + n = len(seq) // max_length + int(len(seq) % max_length > 0) + low, high = len(seq) - n * max_length, len(seq) + truncated_sequences.extend( + [seq[max(0, i - max_length) : i] for i in range(high, low, -max_length)] + ) + else: + raise ValueError(f"Invalid truncation_side: {truncation_side}") + else: # simply truncate the sequence + if truncation_side == "right": + truncated_sequences.append(seq[:max_length]) + elif truncation_side == "left": + truncated_sequences.append(seq[-max_length:]) + else: + raise ValueError(f"Invalid truncation_side: {truncation_side}") + + truncated_encoding[key] = truncated_sequences + + return truncated_encoding + + +def pad(encoding: Dict[str, List[List[Any]]], target_length: int) -> Dict[str, List[List[Any]]]: + """ + Pad the sequences in the encoding to the specified maximum length. + + This function is designed to process a batch of sequences represented in the encoding dictionary. + The padding value is set to be the first element in the sequence. + + Args: + encoding (`Mapping`): + A dictionary where each key-value pair consists of a feature name and its corresponding batch of sequences. + The sequences are expected to be lists. + target_length (`int`): + The desired length for the sequences. + + Returns: + `Dict[str, List[List[Any]]]`: + A dictionary with the same keys as the input `encoding`, containing the padded batch of sequences. + An additional key `attention_mask` is added to the dictionary to indicate the positions of the non-padding + elements with 1s and the padding elements with 0s. If the input `encoding` already contains an + `attention_mask` key, the corresponding mask will be updated such that the original masking is preserved, + and the newly added padding elements will be masked with 0s. In other words, the resulting + `attention_mask` is a logical "AND" between the provided mask and the mask created due to padding, ensuring + that any element masked originally remains masked.
+ + Example: + + >>> encoding = {'feature1': [[1, 2], [3, 4, 5]]} + >>> pad(encoding, 4) + {'feature1': [[1, 2, 1, 1], [3, 4, 5, 3]], 'attention_mask': [[1, 1, 0, 0], [1, 1, 1, 0]]} + + >>> encoding = {'feature1': [[1, 2], [3, 4, 5]], "attention_mask": [[1, 0], [0, 1, 1]]} + >>> pad(encoding, 4) + {'feature1': [[1, 2, 1, 1], [3, 4, 5, 3]], 'attention_mask': [[1, 0, 0, 0], [0, 1, 1, 0]]} + """ + padded_encoding = {} + + for key, sequences in encoding.items(): + if not all(isinstance(seq, (list, torch.Tensor)) for seq in sequences): + raise TypeError(f"All sequences under key {key} should be of type list or tensor.") + if key == "attention_mask": # attention_mask is handled separately + continue + + padded_sequences = [] + pad_mask = [] + + for seq in sequences: + pad_len = target_length - len(seq) + padded_seq = list(seq) + [seq[0]] * max(0, pad_len) + mask = [1] * len(seq) + [0] * max(0, pad_len) + + padded_sequences.append(padded_seq) + pad_mask.append(mask) + + padded_encoding[key] = padded_sequences + + if "attention_mask" in encoding: + padded_encoding["attention_mask"] = [ + [a * (b[i] if i < len(b) else 0) for i, a in enumerate(row)] + for row, b in zip(pad_mask, encoding["attention_mask"]) + ] + else: + padded_encoding["attention_mask"] = pad_mask + + return padded_encoding + + +class Gia2Processor(ProcessorMixin): + r""" + Constructs a GIA2 processor which wraps a CLIP image processor and a BERT tokenizer into a single processor. + + [`Gia2Processor`] offers all the functionalities of [`CLIPImageProcessor`] and [`BertTokenizerFast`]. See the + [`~Gia2Processor.__call__`] and [`~Gia2Processor.decode`] for more information. + + Args: + image_processor ([`AutoImageProcessor`]): + The image processor is a required input. + tokenizer ([`AutoTokenizer`]): + The tokenizer is a required input. + """ + attributes = ["image_processor", "tokenizer"] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + DONT_TRUNCATE_OR_PAD = {"pixel_values"} # Keys whose values must never be truncated or padded + + def __init__(self, image_processor, tokenizer): + super().__init__(image_processor, tokenizer) + self.current_processor = self.image_processor + + def _truncate_and_pad( + self, + encoding: dict, + padding: Union[bool, str], + truncation: Union[bool, str], + truncation_side: str = "right", + max_length: Optional[int] = None, + ) -> dict: + # If max_length is not provided, use the maximum length accepted by the model. + if max_length is None: + max_length = self.tokenizer.model_max_length + + # Exclude keys that we don't want to truncate or pad.
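+ # (`pixel_values` comes out of the image processor as a fixed-size array per sample, so it has + # no timestep axis to truncate or pad.)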
+ excluded = {key: value for key, value in encoding.items() if key in self.DONT_TRUNCATE_OR_PAD} + encoding = {key: value for key, value in encoding.items() if key not in self.DONT_TRUNCATE_OR_PAD} + + # Apply Truncation + if truncation in [True, "lossy"]: + encoding = truncate(encoding, max_length, truncation_side, preserve=False) + elif truncation == "preserve": + encoding = truncate(encoding, max_length, truncation_side, preserve=True) + elif truncation in [False, "do_not_truncate"]: + pass + else: + raise ValueError("Invalid truncation strategy:" + str(truncation)) + + # Apply Padding + if padding in [True, "longest"]: + target_length = max(len(seq) for sequences in encoding.values() for seq in sequences) + encoding = pad(encoding, target_length) + elif padding == "max_length": + encoding = pad(encoding, max_length) + elif padding in [False, "do_not_pad"]: + pass + else: + raise ValueError("Invalid padding strategy:" + str(padding)) + + # Add back the excluded keys. + encoding.update(excluded) + + # Special case: we handle the conversion of image_observations to a tensor here, as the format used + # (list of tensors) is not properly handled by the BatchEncoding class: + if "image_observations" in encoding: + encoding["image_observations"] = torch.stack([torch.stack(ep) for ep in encoding["image_observations"]]) + + return encoding + + def __call__( + self, + text=None, + images=None, + continuous_observations=None, + discrete_observations=None, + text_observations=None, + image_observations=None, + continuous_actions=None, + discrete_actions=None, + rewards=None, + return_tensors=None, + **kwargs, + ): + """ + Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text` + and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to + CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring + of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, + `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + continuous_observations (`List[List[List[float]]]`): + The continuous observations or batch of continuous observations to be encoded. + discrete_observations (`List[List[List[int]]]`): + The discrete observations or batch of discrete observations to be encoded. + text_observations (`List[List[str]]`): + The text observations or batch of text observations to be encoded. + image_observations (`List[List[PIL.Image.Image]]`, `List[List[np.ndarray]]`, `List[List[torch.Tensor]]`): + The image observations or batch of image observations to be encoded. + continuous_actions (`List[List[List[float]]]`): + The continuous actions or batch of continuous actions to be encoded.
+ discrete_actions (`List[List[int]]`): + The discrete actions or batch of discrete actions to be encoded. + rewards (`List[List[float]]`): + The rewards or batch of rewards to be encoded. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + # we truncate and pad ourselves so we need to pass padding=False and truncation=False to the tokenizer + padding = kwargs.pop("padding", False) + truncation = kwargs.pop("truncation", False) + truncation_side = kwargs.pop("truncation_side", "right") + max_length = kwargs.pop("max_length", None) + + # Ensure that the input is batched + if text is not None and not isinstance(text, list): + text = [text] + + encoding = {} + if text is not None: + encoding["input_ids"] = self.tokenizer(text, **kwargs)["input_ids"] + if images is not None: + encoding["pixel_values"] = self.image_processor(images, **kwargs).pixel_values + if continuous_observations is not None: + encoding["continuous_observations"] = copy.deepcopy(continuous_observations) + if discrete_observations is not None: + encoding["discrete_observations"] = copy.deepcopy(discrete_observations) + if text_observations is not None: + if "discrete_observations" not in encoding: + raise ValueError("discrete_observations must be provided if text_observations is provided") + for batch_idx, sequence in enumerate(text_observations): + encoded_text = self.tokenizer(sequence, max_length=64, padding="max_length")["input_ids"] + for timestep, text_tokens in enumerate(encoded_text): + encoding["discrete_observations"][batch_idx][timestep].extend(text_tokens) + if image_observations is not None: + image_observations = [[(to_tensor(im) - 0.5) / 0.5 for im in ep] for ep in image_observations] + encoding["image_observations"] = image_observations + if continuous_actions is not None: + encoding["continuous_actions"] = copy.deepcopy(continuous_actions) + if discrete_actions is not None: + encoding["discrete_actions"] = copy.deepcopy(discrete_actions) + + if rewards is not None: + encoding["rewards"] = copy.deepcopy(rewards) + + # Handle the image+text case: max_length needs to be reduced, as the image and text will be concatenated + if text is not None and images is not None: + if max_length is None: + max_length = self.tokenizer.model_max_length + max_length -= (224 // 16) ** 2 # subtract the number of image tokens + elif ( + continuous_observations is not None + or discrete_observations is not None + or text_observations is not None + or image_observations is not None + ): + if max_length is None: + max_length = self.tokenizer.model_max_length + max_length //= 2 # observations and actions are interleaved + + encoding = self._truncate_and_pad(encoding, padding, truncation, truncation_side, max_length) + + return BatchEncoding(encoding,
tensor_type=return_tensors) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def pad(self, *args, **kwargs): + inputs = {key: [arg[key] for arg in args[0]] for key in args[0][0].keys()} + encoding = self._truncate_and_pad( + inputs, padding=kwargs.get("padding", False), truncation=False, max_length=kwargs.get("max_length") + ) + return BatchEncoding(encoding, tensor_type=kwargs.get("return_tensors")) + + @property + def model_input_names(self): + return [ + "input_ids", + "attention_mask", + "pixel_values", + "continuous_observations", + "discrete_observations", + "image_observations", + "continuous_actions", + "discrete_actions", + "rewards", + ] + + +Gia2Processor.register_for_auto_class("AutoProcessor") diff --git a/gia2/utils.py b/gia2/utils.py new file mode 100644 index 00000000..be114a8a --- /dev/null +++ b/gia2/utils.py @@ -0,0 +1,470 @@ +import json +import os +import random +import sys +import tempfile +from contextlib import contextmanager +from typing import Dict, List, Optional + +import cv2 +import numpy as np +from datasets import IterableDataset +from huggingface_hub import EvalResult, HfApi, ModelCard, ModelCardData +from transformers import PreTrainedModel, ProcessorMixin + + +PRETTY_TASK_NAMES = { + "atari-alien": "ALE/Alien-v5", + "atari-amidar": "ALE/Amidar-v5", + "atari-assault": "ALE/Assault-v5", + "atari-asterix": "ALE/Asterix-v5", + "atari-asteroids": "ALE/Asteroids-v5", + "atari-atlantis": "ALE/Atlantis-v5", + "atari-bankheist": "ALE/BankHeist-v5", + "atari-battlezone": "ALE/BattleZone-v5", + "atari-beamrider": "ALE/BeamRider-v5", + "atari-berzerk": "ALE/Berzerk-v5", + "atari-bowling": "ALE/Bowling-v5", + "atari-boxing": "ALE/Boxing-v5", + "atari-breakout": "ALE/Breakout-v5", + "atari-centipede": "ALE/Centipede-v5", + "atari-choppercommand": "ALE/ChopperCommand-v5", + "atari-crazyclimber": "ALE/CrazyClimber-v5", + "atari-defender": "ALE/Defender-v5", + "atari-demonattack": "ALE/DemonAttack-v5", + "atari-doubledunk": "ALE/DoubleDunk-v5", + "atari-enduro": "ALE/Enduro-v5", + "atari-fishingderby": "ALE/FishingDerby-v5", + "atari-freeway": "ALE/Freeway-v5", + "atari-frostbite": "ALE/Frostbite-v5", + "atari-gopher": "ALE/Gopher-v5", + "atari-gravitar": "ALE/Gravitar-v5", + "atari-hero": "ALE/Hero-v5", + "atari-icehockey": "ALE/IceHockey-v5", + "atari-jamesbond": "ALE/Jamesbond-v5", + "atari-kangaroo": "ALE/Kangaroo-v5", + "atari-krull": "ALE/Krull-v5", + "atari-kungfumaster": "ALE/KungFuMaster-v5", + "atari-montezumarevenge": "ALE/MontezumaRevenge-v5", + "atari-mspacman": "ALE/MsPacman-v5", + "atari-namethisgame": "ALE/NameThisGame-v5", + "atari-phoenix": "ALE/Phoenix-v5", + "atari-pitfall": "ALE/Pitfall-v5", + "atari-pong": "ALE/Pong-v5", + "atari-privateeye": "ALE/PrivateEye-v5", + "atari-qbert": "ALE/Qbert-v5", + "atari-riverraid": "ALE/Riverraid-v5", + "atari-roadrunner": "ALE/RoadRunner-v5", + "atari-robotank": "ALE/Robotank-v5", + "atari-seaquest": "ALE/Seaquest-v5", + "atari-skiing": "ALE/Skiing-v5", + "atari-solaris": "ALE/Solaris-v5", + "atari-spaceinvaders": 
"ALE/SpaceInvaders-v5", + "atari-stargunner": "ALE/StarGunner-v5", + "atari-surround": "ALE/Surround-v5", + "atari-tennis": "ALE/Tennis-v5", + "atari-timepilot": "ALE/TimePilot-v5", + "atari-tutankham": "ALE/Tutankham-v5", + "atari-upndown": "ALE/UpNDown-v5", + "atari-venture": "ALE/Venture-v5", + "atari-videopinball": "ALE/VideoPinball-v5", + "atari-wizardofwor": "ALE/WizardOfWor-v5", + "atari-yarsrevenge": "ALE/YarsRevenge-v5", + "atari-zaxxon": "ALE/Zaxxon-v5", + "babyai-action-obj-door": "BabyAI-ActionObjDoor-v0", + "babyai-blocked-unlock-pickup": "BabyAI-BlockedUnlockPickup-v0", + "babyai-boss-level-no-unlock": "BabyAI-BossLevelNoUnlock-v0", + "babyai-boss-level": "BabyAI-BossLevel-v0", + "babyai-find-obj-s5": "BabyAI-FindObjS5-v0", + "babyai-go-to-door": "BabyAI-GoToDoor-v0", + "babyai-go-to-imp-unlock": "BabyAI-GoToImpUnlock-v0", + "babyai-go-to-local": "BabyAI-GoToLocal-v0", + "babyai-go-to-obj-door": "BabyAI-GoToObjDoor-v0", + "babyai-go-to-obj": "BabyAI-GoToObj-v0", + "babyai-go-to-red-ball-grey": "BabyAI-GoToRedBallGrey-v0", + "babyai-go-to-red-ball-no-dists": "BabyAI-GoToRedBallNoDists-v0", + "babyai-go-to-red-ball": "BabyAI-GoToRedBall-v0", + "babyai-go-to-red-blue-ball": "BabyAI-GoToRedBlueBall-v0", + "babyai-go-to-seq": "BabyAI-GoToSeq-v0", + "babyai-go-to": "BabyAI-GoTo-v0", + "babyai-key-corridor": "BabyAI-KeyCorridor-v0", + "babyai-mini-boss-level": "BabyAI-MiniBossLevel-v0", + "babyai-move-two-across-s8n9": "BabyAI-MoveTwoAcrossS8N9-v0", + "babyai-one-room-s8": "BabyAI-OneRoomS8-v0", + "babyai-open-door": "BabyAI-OpenDoor-v0", + "babyai-open-doors-order-n4": "BabyAI-OpenDoorsOrderN4-v0", + "babyai-open-red-door": "BabyAI-OpenRedDoor-v0", + "babyai-open-two-doors": "BabyAI-OpenTwoDoors-v0", + "babyai-open": "BabyAI-Open-v0", + "babyai-pickup-above": "BabyAI-PickupAbove-v0", + "babyai-pickup-dist": "BabyAI-PickupDist-v0", + "babyai-pickup-loc": "BabyAI-PickupLoc-v0", + "babyai-pickup": "BabyAI-Pickup-v0", + "babyai-put-next-local": "BabyAI-PutNextLocal-v0", + "babyai-put-next": "BabyAI-PutNextS7N4-v0", + "babyai-synth-loc": "BabyAI-SynthLoc-v0", + "babyai-synth-seq": "BabyAI-SynthSeq-v0", + "babyai-synth": "BabyAI-Synth-v0", + "babyai-unblock-pickup": "BabyAI-UnblockPickup-v0", + "babyai-unlock-local": "BabyAI-UnlockLocal-v0", + "babyai-unlock-pickup": "BabyAI-UnlockPickup-v0", + "babyai-unlock-to-unlock": "BabyAI-UnlockToUnlock-v0", + "babyai-unlock": "BabyAI-Unlock-v0", + "conceptual-captions": "Conceptual Captions", + "metaworld-assembly": "assembly-v2", + "metaworld-basketball": "basketball-v2", + "metaworld-bin-picking": "bin-picking-v2", + "metaworld-box-close": "box-close-v2", + "metaworld-button-press-topdown-wall": "button-press-topdown-wall-v2", + "metaworld-button-press-topdown": "button-press-topdown-v2", + "metaworld-button-press-wall": "button-press-wall-v2", + "metaworld-button-press": "button-press-v2", + "metaworld-coffee-button": "coffee-button-v2", + "metaworld-coffee-pull": "coffee-pull-v2", + "metaworld-coffee-push": "coffee-push-v2", + "metaworld-dial-turn": "dial-turn-v2", + "metaworld-disassemble": "disassemble-v2", + "metaworld-door-close": "door-close-v2", + "metaworld-door-lock": "door-lock-v2", + "metaworld-door-open": "door-open-v2", + "metaworld-door-unlock": "door-unlock-v2", + "metaworld-drawer-close": "drawer-close-v2", + "metaworld-drawer-open": "drawer-open-v2", + "metaworld-faucet-close": "faucet-close-v2", + "metaworld-faucet-open": "faucet-open-v2", + "metaworld-hammer": "hammer-v2", + "metaworld-hand-insert": "hand-insert-v2", + 
"metaworld-handle-press-side": "handle-press-side-v2", + "metaworld-handle-press": "handle-press-v2", + "metaworld-handle-pull-side": "handle-pull-side-v2", + "metaworld-handle-pull": "handle-pull-v2", + "metaworld-lever-pull": "lever-pull-v2", + "metaworld-peg-insert-side": "peg-insert-side-v2", + "metaworld-peg-unplug-side": "peg-unplug-side-v2", + "metaworld-pick-out-of-hole": "pick-out-of-hole-v2", + "metaworld-pick-place-wall": "pick-place-wall-v2", + "metaworld-pick-place": "pick-place-v2", + "metaworld-plate-slide-back-side": "plate-slide-back-side-v2", + "metaworld-plate-slide-back": "plate-slide-back-v2", + "metaworld-plate-slide-side": "plate-slide-side-v2", + "metaworld-plate-slide": "plate-slide-v2", + "metaworld-push-back": "push-back-v2", + "metaworld-push-wall": "push-wall-v2", + "metaworld-push": "push-v2", + "metaworld-reach-wall": "reach-wall-v2", + "metaworld-reach": "reach-v2", + "metaworld-shelf-place": "shelf-place-v2", + "metaworld-soccer": "soccer-v2", + "metaworld-stick-pull": "stick-pull-v2", + "metaworld-stick-push": "stick-push-v2", + "metaworld-sweep-into": "sweep-into-v2", + "metaworld-sweep": "sweep-v2", + "metaworld-window-close": "window-close-v2", + "metaworld-window-open": "window-open-v2", + "mujoco-ant": "Ant-v4", + "mujoco-doublependulum": "InvertedDoublePendulum-v4", + "mujoco-halfcheetah": "HalfCheetah-v4", + "mujoco-hopper": "Hopper-v4", + "mujoco-humanoid": "Humanoid-v4", + "mujoco-pendulum": "InvertedPendulum-v4", + "mujoco-pusher": "Pusher-v4", + "mujoco-reacher": "Reacher-v4", + "mujoco-standup": "HumanoidStandup-v4", + "mujoco-swimmer": "Swimmer-v4", + "mujoco-walker": "Walker2d-v4", + "ok-vqa": "OK-VQA", + "oscar": "OSCAR", +} + + +@contextmanager +def suppress_stdout(): + class DummyFile(object): + def write(self, x): + pass + + # Save the current stdout + original_stdout = sys.stdout + sys.stdout = DummyFile() + try: + yield + finally: + sys.stdout = original_stdout + + +def no_print_decorator(func): + def wrapper(*args, **kwargs): + with suppress_stdout(): + return func(*args, **kwargs) + + return wrapper + + +def generate_rl_eval_results(scores_dict: Dict[str, List[float]]) -> List[EvalResult]: + """ + Generate a list of EvalResult objects. + + Args: + scores_dict (`Dict[str, List[float]]`): + Dictionary containing the scores for each task. + + Returns: + `List[EvalResult]`: + A list of EvalResult objects. 
+ """ + eval_results = [] + for task_name, scores in scores_dict.items(): + mean_reward = np.mean(scores) + std_reward = np.std(scores) + + eval_results.append( + EvalResult( + task_type="reinforcement-learning", + task_name="Reinforcement Learning", + dataset_type=task_name, + dataset_name=PRETTY_TASK_NAMES[task_name], + metric_type="total_reward", + metric_name="Total reward", + metric_value=f"{mean_reward:.2f} +/- {std_reward:.2f}", + ) + ) + + for task_name, scores in scores_dict.items(): + mean_reward = np.mean(scores) + std_reward = np.std(scores) + with open("gia/eval/rl/scores_dict.json", "r") as file: + scores_dict = json.load(file) + + expert_score = scores_dict[task_name]["expert"]["mean"] + random_score = scores_dict[task_name]["random"]["mean"] + norm_mean_reward = (mean_reward - random_score) / (expert_score - random_score) + norm_std_reward = std_reward / (expert_score - random_score) + + eval_results.append( + EvalResult( + task_type="reinforcement-learning", + task_name="Reinforcement Learning", + dataset_type=task_name, + dataset_name=PRETTY_TASK_NAMES[task_name], + metric_type="expert_normalized_total_reward", + metric_name="Expert normalized total reward", + metric_value=f"{norm_mean_reward:.2f} +/- {norm_std_reward:.2f}", + ) + ) + return eval_results + + +def generate_model_card(model_name: str, scores_dict: Optional[Dict[str, List[float]]] = None) -> ModelCard: + """ + Generate a ModelCard from a template. + + Args: + model_name (`str`): + Model name. + scores_dict (`Dict[str, List[float]]`): + Dictionary containing the scores for each task. + + Returns: + `ModelCard`: + A ModelCard object. + """ + tags = ["reinforcement-learning"] + if scores_dict is not None: + tags.extend(scores_dict.keys()) + card_data = ModelCardData( + tags=tags, + eval_results=generate_rl_eval_results(scores_dict) if scores_dict is not None else None, + model_name=model_name, + datasets="gia-project/gia-dataset-parquet", + pipeline_tag="reinforcement-learning", + ) + card = ModelCard.from_template( + card_data, + template_path="templates/model_card.md", + model_name=model_name, + model_id="Gia2", + tasks=[PRETTY_TASK_NAMES[task_name] for task_name in scores_dict.keys()] if scores_dict is not None else [], + ) + return card + + +def push_to_hub( + model: PreTrainedModel, + processor: ProcessorMixin, + repo_id: str, + scores_dict: Optional[Dict[str, List[float]]] = None, + replay_path: Optional[str] = None, +) -> None: + """ + Push a model to the Hugging Face Hub. + + Args: + model (`PreTrainedModel`): + Model to push. + processor (`ProcessorMixin`): + Processor to push. + repo_id (`str`): + Repository ID to push to. + scores_dict (`Dict[str, List[float]]` or `None`, **optional**): + Dictionary containing the scores for each task. + replay_path (`str` or `None`, **optional**): + Path to the replay video. 
+ """ + api = HfApi() + + # Create the repo + api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True) + + # Create a README.md using a template + model_card = generate_model_card(repo_id, scores_dict) + model_card.push_to_hub(repo_id, commit_message="Upload model card") + + # Push the model + model.push_to_hub(repo_id, commit_message="Upload model") + + # Push the processor + processor.push_to_hub(repo_id, commit_message="Upload processor") + + # Push the replay + if replay_path is not None: + api.upload_file( + path_or_fileobj=replay_path, + path_in_repo="replay.mp4", + repo_id=repo_id, + commit_message="Upload replay", + repo_type="model", + ) + + print(f"Pushed model to \033[34mhttps://huggingface.co/{repo_id}\033[0m") + + +def save_video_grid( + videos: List[List[np.ndarray]], + input_fps: List[int], + output_filename: str = "output.mp4", + width: int = 1920, + output_fps: int = 30, + max_length_seconds: Optional[int] = None, +) -> None: + """ + Save a grid video from a list of videos. + + Args: + videos (`List[List[np.ndarray]]`): + List of videos, where each video is a list of frames in RGB format. + input_fps (`List[int]`): + List of FPS values for each video. + output_filename (`str`, **optional**): + Output video filename including the extension. + output_fps (`int`, **optional**): + Frames per second for the output video. + max_length_seconds (`Optional[int]`, **optional**): + Maximum length of the output video in seconds. If None, the length of the longest video is used. + """ + # Check if there are any videos + if not videos: + raise ValueError("No videos provided") + + if len(videos) != len(input_fps): + raise ValueError("The number of videos must match the number of FPS values") + + # Determine grid size based on the number of videos + num_cols = int(np.ceil(np.sqrt(len(videos)))) + num_rows = int(np.ceil(len(videos) / num_cols)) + height = width * num_rows // num_cols + + # Define the codec and create a VideoWriter object + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + temp_filename = tempfile.mktemp(suffix=".mp4") + out = cv2.VideoWriter(temp_filename, fourcc, output_fps, (width, height)) + + # Number of frames in the longest video, if max_length_seconds is specified, adjust max_frames + max_frames = max(len(video) for video in videos) + if max_length_seconds is not None: + max_frames = min(max_frames, output_fps * max_length_seconds) + + for frame_idx in range(max_frames): + # Create an empty grid + grid = np.zeros((height, width, 3), dtype=np.uint8) + + for video_idx, video in enumerate(videos): + # Adjust for different FPS values + adjusted_frame_idx = int((frame_idx * input_fps[video_idx]) / output_fps) + looped_frame_idx = adjusted_frame_idx % len(video) + frame = video[looped_frame_idx] + row = video_idx // num_cols + col = video_idx % num_cols + # resize the frame to the grid size + w = width // num_cols + h = height // num_rows + frame = cv2.resize(frame, (w, h)) + grid[row * h : (row + 1) * h, col * w : (col + 1) * w] = frame + + grid = grid[..., [2, 1, 0]] # RGB to BGR + out.write(grid) + + out.release() + os.system(f"ffmpeg -y -i {temp_filename} -vcodec h264 {output_filename}") + + +def mix_iterable_datasets( + datasets: List[IterableDataset], + batch_size: int, + stopping_strategy: str = "all_exhausted", + weights: List[float] = None, +): + """ + Mixes multiple IterableDataset objects into a single IterableDataset. + + Args: + datasets (`List[IterableDataset]`): + List of IterableDataset objects. + batch_size (`int`): + Batch size. 
+ stopping_strategy (`str`, **optional**): + Stopping strategy. Can be either "first_exhausted" or "all_exhausted". + weights (`List[float]`, **optional**): + List of weights for each dataset. If None, uniform weights are used. + + Returns: + `IterableDataset`: + A mixed IterableDataset object. + """ + + def generator( + datasets: List[IterableDataset], + batch_size: int, + stopping_strategy: str = "first_exhausted", + weights: List[float] = None, + ): + assert stopping_strategy in ["first_exhausted", "all_exhausted"] + iterators = [iter(dataset) for dataset in datasets] + exhausted = [False] * len(datasets) # A list to keep track of which iterators are exhausted + weights = weights if weights is not None else [1.0] * len(datasets) + + while True: + dataset_idx = random.choices(range(len(datasets)), weights=weights, k=1)[0] # Choose a dataset randomly + iterator = iterators[dataset_idx] + for _ in range(batch_size): + try: + yield next(iterator) + except StopIteration: + if stopping_strategy == "first_exhausted": + return + else: + # Mark the iterator as exhausted + exhausted[dataset_idx] = True + # Check if all iterators are exhausted + if all(exhausted): + return + # Reinitialize the exhausted iterator + iterator = iterators[dataset_idx] = iter(datasets[dataset_idx]) + yield next(iterators[dataset_idx]) + + gen_kwargs = { + "datasets": datasets, + "batch_size": batch_size, + "stopping_strategy": stopping_strategy, + "weights": weights, + } + return IterableDataset.from_generator(generator=generator, gen_kwargs=gen_kwargs) diff --git a/scripts/download_all_datasets.py b/scripts/download_all_datasets.py index c09bdf45..295e1e5e 100755 --- a/scripts/download_all_datasets.py +++ b/scripts/download_all_datasets.py @@ -1,10 +1,26 @@ #!/usr/bin/env python3 """Load and generate batch for all datasets from the GIA dataset""" +import argparse +import os + from datasets import get_dataset_config_names, load_dataset +from datasets.config import HF_DATASETS_CACHE + + +parser = argparse.ArgumentParser() +parser.add_argument("--tasks", nargs="+", default=[]) +tasks = parser.parse_args().tasks +if tasks == ["all"]: + tasks = get_dataset_config_names("gia-project/gia-dataset-parquet") # get all task names from gia dataset -task_names = get_dataset_config_names("gia-project/gia-dataset-parquet") # get all task names from gia dataset -for task_name in task_names: - print(f"Loading {task_name}...") - load_dataset("gia-project/gia-dataset-parquet", task_name) +for task in tasks: + print(f"Loading {task}...") + cache_path = f"{HF_DATASETS_CACHE}/gia-project/gia-dataset-parquet/{task}" + if not os.path.exists(cache_path): + if task == "oscar": + dataset = load_dataset("ClementRomac/cleaned_deduplicated_oscar") + else: + dataset = load_dataset("gia-project/gia-dataset-parquet", task) + dataset.save_to_disk(cache_path) diff --git a/scripts/eval_gia2.py b/scripts/eval_gia2.py new file mode 100755 index 00000000..bd31c631 --- /dev/null +++ b/scripts/eval_gia2.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +"""Eval a GIA model on the GIA dataset""" +import json +import logging +import os +import sys +import warnings +from dataclasses import dataclass, field +from typing import List, Optional + +import numpy as np +import torch +from tqdm import tqdm +from transformers import HfArgumentParser + +from gia.eval.rl import make +from gia.eval.rl.envs.core import TASK_NAME_TO_ENV_ID +from gia2.modeling_gia2 import Gia2Model +from gia2.processing_gia2 import Gia2Processor +from gia2.utils import push_to_hub, 
save_video_grid, suppress_stdout + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config we are going to train from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + ) + + +@dataclass +class EvaluationArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + tasks: List[str] = field(default_factory=list, metadata={"help": "Tasks to train on."}) + use_cpu: bool = field(default=False, metadata={"help": "Use CPU instead of GPU."}) + save_video: bool = field(default=False, metadata={"help": "Save video of the evaluation."}) + num_episodes: int = field(default=2, metadata={"help": "Number of episodes to evaluate on."}) + push_to_hub: bool = field(default=False, metadata={"help": "Push the model to the hub."}) + repo_id: Optional[str] = field(default=None, metadata={"help": "Repository ID to push to."}) + + +def get_default_device() -> torch.device: + if torch.backends.mps.is_available() and torch.backends.mps.is_built(): + return torch.device("mps") + elif torch.cuda.is_available(): + return torch.device("cuda") + else: + return torch.device("cpu") + + +def eval_rl(model, processor, task, eval_args): + # Create the environment + env_kwargs = {} + if task.startswith("atari"): + env_kwargs["clip_reward"] = False + if eval_args.save_video: + env_kwargs["render_mode"] = "rgb_array" + with suppress_stdout(): # avoid printing the env info + env = make(task, **env_kwargs) + + scores = [] + frames = [] + for episode in tqdm(range(eval_args.num_episodes), desc=task, unit="episode", leave=False): + observation, _ = env.reset() + reward = None + rewards = [] + done = False + model.reset_rl() # clear the KV cache + while not done: + action = model.get_next_action(processor, **observation, rewards=reward, action_space=env.action_space) + observation, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + + # Handle "fake done" for atari + if done and task.startswith("atari"): + if "episode" not in info: + observation, info = env.reset() + done = False + else: + print("Episode done, score:", info["episode"]["r"], sum(rewards)) + + # Update the return + rewards.append(reward) + + # Render the environment + if eval_args.save_video: + frames.append(np.array(env.render(), dtype=np.uint8)) + + scores.append(sum(rewards)) + env.close() + + # Get the mean and std of the expert and random scores + with open("gia/eval/rl/scores_dict.json", "r") as file: + scores_dict = json.load(file) + + expert_mean = scores_dict[task]["expert"]["mean"] + random_mean = scores_dict[task]["random"]["mean"] + + # Normalize the scores + raw_mean = np.mean(scores) + raw_std = np.std(scores) + norm_mean = (raw_mean - random_mean) / (expert_mean - random_mean) + norm_std = raw_std / (expert_mean - random_mean) + + # Print the results + tqdm.write( + f"Task {task} Raw score: {raw_mean:.2f} ± {raw_std:.2f} " f"Normalized score: {norm_mean:.2f} ± {norm_std:.2f}" + ) + + return scores, frames, env.metadata["render_fps"] + + +def main(): + parser = HfArgumentParser((ModelArguments, EvaluationArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our
arguments. + model_args, eval_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, eval_args = parser.parse_args_into_dataclasses() + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + + # Set the tasks + tasks = eval_args.tasks + for domain in ["atari", "babyai", "metaworld", "mujoco"]: + if domain in tasks: + tasks.remove(domain) + tasks.extend([env_id for env_id in TASK_NAME_TO_ENV_ID.keys() if env_id.startswith(domain)]) + + device = torch.device("cpu") if eval_args.use_cpu else get_default_device() + model = Gia2Model.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir).to(device) + processor = Gia2Processor.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) + + scores_dict = {} + video_list = [] + input_fps = [] + + for task in tqdm(tasks, desc="Evaluation", unit="task", leave=True): + if task in TASK_NAME_TO_ENV_ID.keys(): + scores, frames, fps = eval_rl(model, processor, task, eval_args) + scores_dict[task] = scores + # Save the video + if eval_args.save_video: + video_list.append(frames) + input_fps.append(fps) + else: + warnings.warn(f"Task {task} is not supported.") + + # Extract mean and std, and save scores dict + to_save = {task: {"mean": np.mean(scores), "std": np.std(scores)} for task, scores in scores_dict.items()} + with open(f"{model_args.model_name_or_path}/scores_dict.json", "w") as file: + json.dump(to_save, file) + + # Save the video + if eval_args.save_video: + replay_path = f"{model_args.model_name_or_path}/replay.mp4" + save_video_grid(video_list, input_fps, replay_path, output_fps=30, max_length_seconds=180) + else: + replay_path = None + + # Push the model to the hub + if eval_args.push_to_hub: + assert eval_args.repo_id is not None, "You need to specify a repo_id to push to." 
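+ # push_to_hub uploads the generated model card (with the per-task EvalResults), the model + # weights, the processor and, when available, the replay.mp4 grid video to the target repo.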
+ push_to_hub(model, processor, eval_args.repo_id, scores_dict=scores_dict, replay_path=replay_path) + + +if __name__ == "__main__": + main() diff --git a/scripts/generate_config_gia2.py b/scripts/generate_config_gia2.py new file mode 100644 index 00000000..ea8447cb --- /dev/null +++ b/scripts/generate_config_gia2.py @@ -0,0 +1,68 @@ +from transformers import AutoTokenizer, CLIPImageProcessor + +from gia2.configuration_gia2 import Gia2Config +from gia2.processing_gia2 import Gia2Processor + + +# Small model +tokenizer = AutoTokenizer.from_pretrained("gpt2", model_input_names=["input_ids", "attention_mask"]) +config = Gia2Config( + vocab_size=tokenizer.vocab_size, + max_position_embeddings=512, + hidden_size=768, + num_layers=12, + attention_types=[[["global", "local"], 6]], + num_heads=12, + max_discrete_value=148 + 64, # 148 (discrete obs from BabyAI) + 64 (max size of BabyAI's text observation) + tokenizer_class=tokenizer.__class__.__name__, +) +image_processor = CLIPImageProcessor( + size={"shortest_edge": config.image_size}, crop_size={"height": config.image_size, "width": config.image_size} +) +tokenizer.model_max_length = config.max_position_embeddings +tokenizer.pad_token = tokenizer.eos_token +processor = Gia2Processor(tokenizer=tokenizer, image_processor=image_processor) +config.push_to_hub("gia-project/gia2-small") +processor.push_to_hub("gia-project/gia2-small") + +# Medium model +tokenizer = AutoTokenizer.from_pretrained("gpt2", model_input_names=["input_ids", "attention_mask"]) +config = Gia2Config( + vocab_size=tokenizer.vocab_size, + max_position_embeddings=1024, + hidden_size=2048, + num_layers=24, + attention_types=[[["global", "local"], 12]], + num_heads=16, + max_discrete_value=148 + 64, # 148 (discrete obs from BabyAI) + 64 (max size of BabyAI's text observation) + tokenizer_class=tokenizer.__class__.__name__, +) +image_processor = CLIPImageProcessor( + size={"shortest_edge": config.image_size}, crop_size={"height": config.image_size, "width": config.image_size} +) +tokenizer.model_max_length = config.max_position_embeddings +tokenizer.pad_token = tokenizer.eos_token +processor = Gia2Processor(tokenizer=tokenizer, image_processor=image_processor) +config.push_to_hub("gia-project/gia2-medium") +processor.push_to_hub("gia-project/gia2-medium") + +# Large model +tokenizer = AutoTokenizer.from_pretrained("gpt2", model_input_names=["input_ids", "attention_mask"]) +config = Gia2Config( + vocab_size=tokenizer.vocab_size, + max_position_embeddings=2048, + hidden_size=2560, + num_layers=32, + attention_types=[[["global", "local"], 16]], + num_heads=20, + max_discrete_value=148 + 64, # 148 (discrete obs from BabyAI) + 64 (max size of BabyAI's text observation) + tokenizer_class=tokenizer.__class__.__name__, +) +image_processor = CLIPImageProcessor( + size={"shortest_edge": config.image_size}, crop_size={"height": config.image_size, "width": config.image_size} +) +tokenizer.model_max_length = config.max_position_embeddings +tokenizer.pad_token = tokenizer.eos_token +processor = Gia2Processor(tokenizer=tokenizer, image_processor=image_processor) +config.push_to_hub("gia-project/gia2-large") +processor.push_to_hub("gia-project/gia2-large") diff --git a/scripts/train_gia2.py b/scripts/train_gia2.py new file mode 100755 index 00000000..dd99a411 --- /dev/null +++ b/scripts/train_gia2.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +"""Train a GIA model on the GIA dataset""" + + +import logging +import os +import sys +from dataclasses import dataclass, field +from typing import List, 
Optional + +import datasets.config +from datasets import load_dataset, load_from_disk +from datasets.config import HF_DATASETS_CACHE, HF_DATASETS_OFFLINE +from transformers import AutoConfig, AutoProcessor, HfArgumentParser, Trainer, TrainingArguments + +from gia.eval.rl.envs.core import TASK_NAME_TO_ENV_ID +from gia2.modeling_gia2 import Gia2Model +from gia2.utils import mix_iterable_datasets + + +# Sometimes the server is down; increasing the number of +# retries lets training wait longer instead of crashing +datasets.config.STREAMING_READ_MAX_RETRIES = 10000 + + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config we are going to train from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + ) + trust_remote_code: bool = field( + default=False, + metadata={ + "help": ( + "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option " + "should only be set to `True` for repositories you trust and in which you have read the code, as it " + "will execute code present on the Hub on your local machine." + ) + }, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + tasks: List[str] = field(default_factory=list, metadata={"help": "Tasks to train on."}) + preprocess_num_proc: int = field( + default=1, metadata={"help": "Number of processes to use for preprocessing the data."} + ) + eval_num_samples: int = field(default=1000, metadata={"help": "Number of samples to use for evaluation."}) + + +LOSS_WEIGHTS = { + "mujoco-pendulum": 20.0, + "mujoco-doublependulum": 10.0, +} +SAMPLE_WEIGHTS = { + # "oscar": 10.0, + # "conceptual_caption": 10.0, +} + +os.environ["WANDB_ENTITY"] = "gia-project" +os.environ["WANDB_PROJECT"] = "gia2" + + +def main(): + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments.
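+ # Hypothetical example: `python scripts/train_gia2.py config.json`, where config.json contains + # {"model_name_or_path": "gia-project/gia2-small", "tasks": ["mujoco"], "output_dir": "runs/gia2-small"}; + # each key maps to the matching field of the dataclasses above (output_dir comes from TrainingArguments).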
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, + ) + + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + trust_remote_code=model_args.trust_remote_code, + ) + model = Gia2Model(config) + processor = AutoProcessor.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + trust_remote_code=model_args.trust_remote_code, + ) + + # Set the tasks + tasks = data_args.tasks + for domain in ["atari", "babyai", "metaworld", "mujoco"]: + if domain in tasks: + tasks.remove(domain) + tasks.extend([env_id for env_id in TASK_NAME_TO_ENV_ID.keys() if env_id.startswith(domain)]) + + # Load the dataset + # Automatic cache is broken for parquet datasets + # The following is a fix from https://github.com/huggingface/datasets/issues/3547#issuecomment-1252503988 + dataset_dict = {} + if HF_DATASETS_OFFLINE: + for task in tasks: + if not os.path.exists(f"{HF_DATASETS_CACHE}/gia-project/gia-dataset-parquet/{task}"): + raise ValueError( + f"""Dataset {task} not found in {HF_DATASETS_CACHE}/gia-project/gia-dataset-parquet/ Make sure to download and save it first with ``` from datasets import load_dataset dataset = load_dataset('gia-project/gia-dataset-parquet', '{task}') dataset.save_to_disk('{HF_DATASETS_CACHE}/gia-project/gia-dataset-parquet/{task}') ```""" + ) + dataset = load_from_disk(f"{HF_DATASETS_CACHE}/gia-project/gia-dataset-parquet/{task}") + dataset_dict[task] = {s: d.to_iterable_dataset() for s, d in dataset.items()} + else: + for task in tasks: + if task == "oscar": + dataset_dict[task] = load_dataset("ClementRomac/cleaned_deduplicated_oscar", streaming=True) + else: + dataset_dict[task] = load_dataset("gia-project/gia-dataset-parquet", task, streaming=True) + + # Preprocess the dataset + for task in dataset_dict.keys(): + for split in dataset_dict[task].keys(): + dataset = dataset_dict[task][split] + column_names = set(dataset.column_names) # needs to be done here because this info is lost after the map + dataset = dataset.filter(lambda example: example.get("rewards") != []) + # We've shown that reducing the sequence length for atari doesn't impact performance but allows for a + # larger global batch size + max_length = 64 if task.startswith("atari") else None + + def preprocess(example_batch, max_length): + return processor(**example_batch, padding="max_length", truncation="preserve", max_length=max_length) + + dataset = dataset.map( + preprocess, + batched=True, + batch_size=1, # small to avoid OOM + remove_columns={"text", "images", "text_observations"}.intersection(column_names), + fn_kwargs={"max_length": max_length}, + ) + dataset = dataset.map( + lambda x: {"loss_weight": [LOSS_WEIGHTS.get(task, 1.0)] * len(next(iter(x.values())))} + ) + dataset_dict[task][split] = dataset + + train_dataset = {t: d["train"] for t, d in dataset_dict.items()} + eval_dataset = {t: d["test"] for t, d in dataset_dict.items()} + + for key in tasks: # Reduce the number of eval samples + eval_dataset[key] = eval_dataset[key].take(data_args.eval_num_samples) + + weights =
[SAMPLE_WEIGHTS.get(t, 1.0) for t in train_dataset.keys()] + train_dataset = mix_iterable_datasets( + list(train_dataset.values()), batch_size=training_args.per_device_train_batch_size, weights=weights + ) + # Due to the train dataset's structure, where every 'n' consecutive samples share the same modalities, we can't + # load all samples at once. Different sets of 'n' samples have different modalities. Therefore, we must load and + # process each set of 'n' samples separately. + if training_args.dispatch_batches: + raise ValueError("Make sure to pass `--dispatch_batches False`.") + + # Why does training continue after exhausting the dataset? https://github.com/huggingface/transformers/issues/26635 + trainer = Trainer( + model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, tokenizer=processor + ) + trainer.train() + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 6291ce27..986b282a 100644 --- a/setup.py +++ b/setup.py @@ -68,6 +68,7 @@ "numpy", "opencv-python", "torch==2.0.1", + "torchvision", "transformers==4.32.1", "wandb", ] diff --git a/templates/model_card.md b/templates/model_card.md new file mode 100644 index 00000000..49544eaa --- /dev/null +++ b/templates/model_card.md @@ -0,0 +1,38 @@ +--- +{{ card_data }} +--- + +# Model Card for {{ model_id | default("Model name", true) }} + +This is a multi-modal and multi-task model. + +## Model Details + +### Model Description + +- **Developed by:** The GIA Team +- **License:** Apache 2.0 + +### Model Sources + +- **Repository:** +- **Paper:** Coming soon +- **Demo:** Coming soon + +## Training + +The model was trained on the following tasks: + +{% for task in tasks -%} +- {{ task }} +{% endfor %} +## How to Get Started with the Model + +Use the code below to get started with the model.
+ + ```python +from transformers import AutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained("{{ model_name | default("[More Information Needed]", true)}}") +``` + diff --git a/tests/gia2/test_modeling_gia2.py b/tests/gia2/test_modeling_gia2.py new file mode 100644 index 00000000..20df7c00 --- /dev/null +++ b/tests/gia2/test_modeling_gia2.py @@ -0,0 +1,38 @@ +import torch + +from gia2.modeling_gia2 import compute_mse_loss + + +def test_basic(): + predicted = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]]) + true = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]]) + mask = torch.tensor([[True, True]]) + + loss = compute_mse_loss(predicted, true, mask) + expected_loss = torch.tensor(0.0) + + assert torch.isclose(loss, expected_loss, atol=1e-8) + + +def test_masking(): + predicted = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]]) + true = torch.tensor([[[1.0, 2.0], [10.0, 10.0]]]) # second time step is different + mask = torch.tensor([[True, False]]) # mask out the second time step + + loss = compute_mse_loss(predicted, true, mask) + expected_loss = torch.tensor(0.0) # masked entries should be ignored + + assert torch.isclose(loss, expected_loss, atol=1e-8) + + +def test_weighted(): + # batch size = 1, time steps = 3, features = 2 + predicted = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]) + true = torch.tensor([[[1.0, 2.0], [10.0, 10.0], [5.0, 6.0]]]) # second time step is different + mask = torch.tensor([[True, True, True]]) + weights = torch.tensor([[1.0, 0.0, 1.0]]) # zero out the second time step + + loss = compute_mse_loss(predicted, true, mask, weights=weights) + expected_loss = torch.tensor(0.0) # second time step should be ignored due to zero weight + assert torch.isclose(loss, expected_loss, atol=1e-8) diff --git a/tests/gia2/test_processing_gia2.py b/tests/gia2/test_processing_gia2.py new file mode 100644 index 00000000..1596fcfa --- /dev/null +++ b/tests/gia2/test_processing_gia2.py @@ -0,0 +1,128 @@ +import pytest +import torch +from PIL import Image +from transformers import AutoTokenizer, CLIPImageProcessor + +from gia2.processing_gia2 import Gia2Processor + + +@pytest.fixture +def processor(): + image_size = 224 + image_processor = CLIPImageProcessor( + size={"shortest_edge": image_size}, crop_size={"height": image_size, "width": image_size} + ) + tokenizer = AutoTokenizer.from_pretrained("gpt2", model_input_names=["input_ids", "attention_mask"]) + processor = Gia2Processor(tokenizer=tokenizer, image_processor=image_processor) + return processor + + +def test_unbatched_text_encoding(processor): + text = "The quick brown fox jumps over the lazy dog" + encoding = processor(text=text, return_tensors="pt") + assert "input_ids" in encoding + assert encoding["input_ids"].shape == torch.Size([1, 9]) + + +def test_unbatched_text_encoding_pad(processor): + text = "The quick brown fox jumps over the lazy dog" + encoding = processor(text=text, return_tensors="pt", padding="max_length", max_length=16) + assert "input_ids" in encoding + assert "attention_mask" in encoding + assert encoding["input_ids"].shape == torch.Size([1, 16]) + assert encoding["attention_mask"].shape == torch.Size([1, 16]) + assert torch.all(encoding["attention_mask"][:, :9] == 1) + assert torch.all(encoding["attention_mask"][:, 9:] == 0) + + +def test_unbatched_text_encoding_truncate(processor): + text = "The quick brown fox jumps over the lazy dog" + encoding = processor(text=text, return_tensors="pt", truncation=True, max_length=8) + assert "input_ids" in encoding + assert encoding["input_ids"].shape
diff --git a/tests/gia2/test_processing_gia2.py b/tests/gia2/test_processing_gia2.py
new file mode 100644
index 00000000..1596fcfa
--- /dev/null
+++ b/tests/gia2/test_processing_gia2.py
@@ -0,0 +1,128 @@
+import pytest
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, CLIPImageProcessor
+
+from gia2.processing_gia2 import Gia2Processor
+
+
+@pytest.fixture
+def processor():
+    image_size = 224
+    image_processor = CLIPImageProcessor(
+        size={"shortest_edge": image_size}, crop_size={"height": image_size, "width": image_size}
+    )
+    tokenizer = AutoTokenizer.from_pretrained("gpt2", model_input_names=["input_ids", "attention_mask"])
+    processor = Gia2Processor(tokenizer=tokenizer, image_processor=image_processor)
+    return processor
+
+
+def test_unbatched_text_encoding(processor):
+    text = "The quick brown fox jumps over the lazy dog"
+    encoding = processor(text=text, return_tensors="pt")
+    assert "input_ids" in encoding
+    assert encoding["input_ids"].shape == torch.Size([1, 9])
+
+
+def test_unbatched_text_encoding_pad(processor):
+    text = "The quick brown fox jumps over the lazy dog"
+    encoding = processor(text=text, return_tensors="pt", padding="max_length", max_length=16)
+    assert "input_ids" in encoding
+    assert "attention_mask" in encoding
+    assert encoding["input_ids"].shape == torch.Size([1, 16])
+    assert encoding["attention_mask"].shape == torch.Size([1, 16])
+    assert torch.all(encoding["attention_mask"][:, :9] == 1)
+    assert torch.all(encoding["attention_mask"][:, 9:] == 0)
+
+
+def test_unbatched_text_encoding_truncate(processor):
+    text = "The quick brown fox jumps over the lazy dog"
+    encoding = processor(text=text, return_tensors="pt", truncation=True, max_length=8)
+    assert "input_ids" in encoding
+    assert encoding["input_ids"].shape == torch.Size([1, 8])
+
+
+def test_unbatched_text_encoding_truncate_preserve(processor):
+    text = "The quick brown fox jumps over the lazy dog"
+    encoding = processor(text=text, return_tensors="pt", truncation="preserve", max_length=6, padding=True)
+    assert "input_ids" in encoding
+    assert "attention_mask" in encoding
+    assert encoding["input_ids"].shape == torch.Size([2, 6])
+    assert encoding["attention_mask"].shape == torch.Size([2, 6])
+    assert torch.all(encoding["attention_mask"] == torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]]))
+
+
+def test_image_encoding(processor):
+    image = Image.new("RGB", (224, 224))
+    encoding = processor(images=image, return_tensors="pt")
+    assert "pixel_values" in encoding
+    assert encoding["pixel_values"].shape == torch.Size([1, 3, 224, 224])
+
+
+def test_image_encoding_batched(processor):
+    images = [Image.new("RGB", (224, 224))]
+    encoding = processor(images=images, return_tensors="pt")
+    assert "pixel_values" in encoding
+    assert encoding["pixel_values"].shape == torch.Size([1, 3, 224, 224])
+
+
+def test_text_and_image_encoding(processor):
+    text = "The quick brown fox jumps over the lazy dog"
+    image = Image.new("RGB", (224, 224))
+    encoding = processor(text=text, images=image, return_tensors="pt")
+    assert "input_ids" in encoding
+    assert "pixel_values" in encoding
+    assert encoding["input_ids"].shape == torch.Size([1, 9])
+    assert encoding["pixel_values"].shape == torch.Size([1, 3, 224, 224])
+
+
+def test_batch_decode(processor):
+    texts = ["The quick brown fox", "jumps over the lazy dog"]
+    encoding = processor(text=texts)
+    decoded_texts = processor.batch_decode(encoding["input_ids"])
+    assert isinstance(decoded_texts, list)
+    assert len(decoded_texts) == 2
+    assert decoded_texts[0] == "The quick brown fox"
+    assert decoded_texts[1] == "jumps over the lazy dog"
+
+
+def test_decode(processor):
+    text = "The quick brown fox jumps over the lazy dog"
+    encoding = processor.tokenizer(text)
+    decoded_text = processor.decode(encoding["input_ids"])
+    assert decoded_text == text
+
+
+def test_continuous_observations_encoding(processor):
+    continuous_observations = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]]  # 1 episode, 4 steps, 2 features
+    encoding = processor(continuous_observations=continuous_observations, return_tensors="pt")
+    assert "continuous_observations" in encoding
+    assert encoding["continuous_observations"].shape == torch.Size([1, 4, 2])
+
+
+def test_discrete_observations_encoding(processor):
+    discrete_observations = [[[1, 2], [3, 4], [5, 6], [7, 8]]]  # 1 episode, 4 steps, 2 features
+    encoding = processor(discrete_observations=discrete_observations, return_tensors="pt")
+    assert "discrete_observations" in encoding
+    assert encoding["discrete_observations"].shape == torch.Size([1, 4, 2])
+
+
+def test_image_observations_encoding(processor):
+    image_observations = [[Image.new("RGBA", (84, 84))]]  # 1 episode, 1 step; RGBA -> 4 channels
+    encoding = processor(image_observations=image_observations, return_tensors="pt")
+    assert "image_observations" in encoding
+    assert encoding["image_observations"].shape == torch.Size([1, 1, 4, 84, 84])
+
+
+def test_continuous_actions_encoding(processor):
+    continuous_actions = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]]  # 1 episode, 4 steps, 2 features
+    encoding = processor(continuous_actions=continuous_actions, return_tensors="pt")
+    assert "continuous_actions" in encoding
+    assert encoding["continuous_actions"].shape == torch.Size([1, 4, 2])
+
+
+def test_discrete_actions_encoding(processor):
+    discrete_actions = [[1, 2, 3, 4]]  # 1 episode, 4 steps
+    encoding = processor(discrete_actions=discrete_actions, return_tensors="pt")
+    assert "discrete_actions" in encoding
+    assert encoding["discrete_actions"].shape == torch.Size([1, 4])
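Assuming a standard pytest setup, both new suites run with `pytest tests/gia2/`. Note in particular `truncation="preserve"`: per the test above, a 9-token sentence with `max_length=6` is split into two padded rows rather than having its tail dropped. The snippet below assembles the processor exactly as the fixture does and passes two modalities in one call; combining keywords this way is an assumption here, since the tests only exercise each keyword separately:

```python
from transformers import AutoTokenizer, CLIPImageProcessor

from gia2.processing_gia2 import Gia2Processor

# Build the processor as in the test fixture
image_processor = CLIPImageProcessor(size={"shortest_edge": 224}, crop_size={"height": 224, "width": 224})
tokenizer = AutoTokenizer.from_pretrained("gpt2", model_input_names=["input_ids", "attention_mask"])
processor = Gia2Processor(tokenizer=tokenizer, image_processor=image_processor)

# One episode of 2 steps: 2-feature observations, 1-feature actions.
# (Combining both keywords in a single call is assumed, not shown in the tests.)
batch = processor(
    continuous_observations=[[[0.1, 0.2], [0.3, 0.4]]],
    continuous_actions=[[[1.0], [0.0]]],
    return_tensors="pt",
)
print(batch["continuous_observations"].shape)  # torch.Size([1, 2, 2])
```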