Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GIA2 #122

Merged
merged 186 commits into from
Nov 6, 2023
Merged

GIA2 #122

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
186 commits
Select commit Hold shift + click to select a range
da38bd3
init
qgallouedec Aug 21, 2023
e24c234
liltle improvements
qgallouedec Aug 21, 2023
15be4e5
drop token simple (mujoco)
qgallouedec Aug 30, 2023
10ab8c9
all mujoco
qgallouedec Sep 3, 2023
30df306
clean
qgallouedec Sep 3, 2023
c932c32
try videowriter
qgallouedec Sep 3, 2023
e5063e6
rm eval
qgallouedec Sep 3, 2023
77d623e
try video wrtier
qgallouedec Sep 3, 2023
e58ff61
eval on cpu
qgallouedec Sep 3, 2023
41f2d8f
revert
qgallouedec Sep 3, 2023
5223825
fix all return
qgallouedec Sep 3, 2023
271811a
little improvements and custom sampler
qgallouedec Sep 4, 2023
91690e4
Merge branch 'main' into qgallouedec-idea-2
qgallouedec Sep 5, 2023
a39d568
atari
qgallouedec Sep 6, 2023
3c6d61c
eval atari
qgallouedec Sep 6, 2023
04e2f2e
improved decoder and drop observation loss
qgallouedec Sep 6, 2023
2b167b2
cleaning
qgallouedec Sep 7, 2023
9b69465
Merge branch 'main' into qgallouedec-idea-2
qgallouedec Sep 7, 2023
16c5c47
Merge branch 'main' into qgallouedec-idea-2
qgallouedec Sep 7, 2023
e07fa53
instance norm, padding, cosine lr, no episodic life
qgallouedec Sep 8, 2023
1ba1cd0
rm drop_token
qgallouedec Sep 8, 2023
8cfe3b6
rename to `new_network.py`
qgallouedec Sep 8, 2023
9905b77
friday night
qgallouedec Sep 8, 2023
66f7fe7
update new network
qgallouedec Sep 12, 2023
8e091c4
New create atari script
qgallouedec Sep 13, 2023
08039f0
drop_token_simple
qgallouedec Sep 14, 2023
6e3af07
Merge branch 'qgallouedec-idea-2' of https://github.com/huggingface/g…
qgallouedec Sep 14, 2023
6e0fce0
ant and double pendulum
qgallouedec Sep 14, 2023
d0b69ef
parquet
qgallouedec Sep 14, 2023
11b0a3e
set seed
qgallouedec Sep 14, 2023
d43fed9
hard seed
qgallouedec Sep 14, 2023
00ad90f
no parquet
qgallouedec Sep 14, 2023
dfde488
import numpy
qgallouedec Sep 14, 2023
2926eee
parquet now
qgallouedec Sep 14, 2023
da8f033
compare datasets
qgallouedec Sep 14, 2023
ec9a97c
speedup
qgallouedec Sep 14, 2023
e8419b5
remove features
qgallouedec Sep 14, 2023
94683a9
readd features
qgallouedec Sep 14, 2023
ffec54c
rm old dataset and jsut action loss
qgallouedec Sep 14, 2023
f989be9
new name
qgallouedec Sep 14, 2023
7bca363
all mujoco
qgallouedec Sep 14, 2023
c488da9
fix name
qgallouedec Sep 14, 2023
1483690
datacolator and ant-double-pendulum
qgallouedec Sep 15, 2023
ce6d386
exp name
qgallouedec Sep 15, 2023
9e85a61
fix data collator
qgallouedec Sep 15, 2023
c9e8d09
fix data collator
qgallouedec Sep 15, 2023
de55c52
max_position_embeddings=512
qgallouedec Sep 15, 2023
2004d98
bs2
qgallouedec Sep 15, 2023
a124d99
pad data collator
qgallouedec Sep 15, 2023
bd010ea
clean
qgallouedec Sep 15, 2023
967f953
print shapes
qgallouedec Sep 15, 2023
17fada7
pad 256
qgallouedec Sep 15, 2023
2c541a5
don't print
qgallouedec Sep 15, 2023
0a274e8
bs3
qgallouedec Sep 15, 2023
427c7ed
fix auto sizes
qgallouedec Sep 15, 2023
1c8daaf
fix attention
qgallouedec Sep 15, 2023
4c833f7
ant and 10k
qgallouedec Sep 15, 2023
6f5e2a6
fix return when no return loss
qgallouedec Sep 15, 2023
324ae51
allow mask = None in filter
qgallouedec Sep 15, 2023
522209a
fix pad
qgallouedec Sep 15, 2023
8776f3a
fix eval
qgallouedec Sep 15, 2023
0a22fb0
fix window size
qgallouedec Sep 15, 2023
a5f80c2
fix window
qgallouedec Sep 15, 2023
b5a51c9
all mujoco afbs
qgallouedec Sep 15, 2023
863f5dd
move everything to gia2
qgallouedec Sep 15, 2023
cc75804
cleaning
qgallouedec Sep 15, 2023
c3f7894
fix collate
qgallouedec Sep 16, 2023
1622cd0
my trainer
qgallouedec Sep 16, 2023
ff911e4
weighted sampling
qgallouedec Sep 16, 2023
26f20c3
python3.9
qgallouedec Sep 16, 2023
13977ba
same
qgallouedec Sep 16, 2023
9db0e01
grad accumulation and cyclic pad
qgallouedec Sep 17, 2023
6638510
change eval and save steps
qgallouedec Sep 17, 2023
d0ea4ab
remove weights
qgallouedec Sep 17, 2023
e4a1a90
fix eval and action pred
qgallouedec Sep 17, 2023
38846ca
compute_ce_loss and write_video
qgallouedec Sep 17, 2023
178514b
rename collator
qgallouedec Sep 17, 2023
2f813d0
proper config
qgallouedec Sep 17, 2023
f92a7b9
move collator to utils
qgallouedec Sep 17, 2023
dbb57be
preprocess_function in utils
qgallouedec Sep 17, 2023
412ba15
remove data_collator
qgallouedec Sep 17, 2023
3014919
modeling
qgallouedec Sep 17, 2023
5bcc016
train gia script
qgallouedec Sep 17, 2023
f541040
try biais=True
qgallouedec Sep 17, 2023
649f884
generate card, push to hub, and save videos
qgallouedec Sep 18, 2023
8cb35a1
update config model type
qgallouedec Sep 18, 2023
fc2b24d
loss_weight
qgallouedec Sep 18, 2023
abbfee0
loss weight mujoco
qgallouedec Sep 18, 2023
81c2964
eval script
qgallouedec Sep 18, 2023
70c50c7
model card template
qgallouedec Sep 18, 2023
d3f39ed
sept20
qgallouedec Sep 20, 2023
76289a8
new resnet
qgallouedec Sep 21, 2023
571402e
fix channels
qgallouedec Sep 21, 2023
698374d
litle cleaning
qgallouedec Sep 21, 2023
0d09b9a
new image model
qgallouedec Sep 21, 2023
f11e14a
clean
qgallouedec Sep 21, 2023
1c8f044
rm old file
qgallouedec Sep 22, 2023
7b0cfc6
delete unused
qgallouedec Sep 22, 2023
604802d
automodel and GIA2 -> Gia2
qgallouedec Sep 22, 2023
c7e6553
rm old comment
qgallouedec Sep 22, 2023
cd59b3f
style
qgallouedec Sep 22, 2023
38f5655
style
qgallouedec Sep 22, 2023
71068b2
update model and config
qgallouedec Sep 22, 2023
a0a7e3d
fix __init__
qgallouedec Sep 22, 2023
ed43f0a
rename
qgallouedec Sep 22, 2023
ca4b67d
rm old
qgallouedec Sep 22, 2023
5706ed5
update train script
qgallouedec Sep 22, 2023
187d0b7
update eval script
qgallouedec Sep 22, 2023
d65a2f7
fix max
qgallouedec Sep 23, 2023
81eb703
fix attention mask is None
qgallouedec Sep 23, 2023
6a50319
fix domain as task
qgallouedec Sep 23, 2023
76b9ffa
try register in eval
qgallouedec Sep 24, 2023
a560deb
reduce batch size
qgallouedec Sep 24, 2023
4cea860
fix tokenzier push-to-hub
qgallouedec Sep 24, 2023
0cce879
revert registration; change push_to_hub
qgallouedec Sep 24, 2023
2f48eb7
create repo
qgallouedec Sep 24, 2023
3de2ab3
pipeline tag
qgallouedec Sep 24, 2023
917d125
fix code in template
qgallouedec Sep 24, 2023
6d54476
cleanup modeling
qgallouedec Sep 27, 2023
301045c
add reward
qgallouedec Sep 27, 2023
397ca92
increase max continuous size
qgallouedec Sep 27, 2023
aeae01c
update eval
qgallouedec Sep 27, 2023
5220f91
GIA2 image captioning compatibility (#127)
qgallouedec Oct 5, 2023
3261319
Merge branch 'main' into qgallouedec-idea-2
qgallouedec Oct 5, 2023
05cba03
save to disk in download script
qgallouedec Oct 5, 2023
cb91747
check if path exists
qgallouedec Oct 5, 2023
9674884
fix download for oscar
qgallouedec Oct 5, 2023
2baa0ce
profiler
qgallouedec Oct 5, 2023
b9d7cba
remove profiler
qgallouedec Oct 5, 2023
3be0809
sampleweiht support
qgallouedec Oct 6, 2023
bb14794
sample weights
qgallouedec Oct 6, 2023
cfbe174
fix loss computation
qgallouedec Oct 7, 2023
7d214fd
fix loss weight offset
qgallouedec Oct 7, 2023
adde82e
Remove loss_function and cyclic_expand_dim from utils
qgallouedec Oct 7, 2023
f450da3
fix loss function
qgallouedec Oct 7, 2023
0c8af7b
rename test
qgallouedec Oct 7, 2023
ee942ce
default=all_exhausted
qgallouedec Oct 8, 2023
83098c9
style
qgallouedec Oct 8, 2023
3cef2d4
fix mix_iter
qgallouedec Oct 8, 2023
5e78e98
fix mix_iterable_dataset calling
qgallouedec Oct 10, 2023
994634e
Handling BabyAI's observations (to test)
ClementRomac Oct 10, 2023
164082f
split discrete encoder and add a layer
qgallouedec Oct 10, 2023
19036f2
format
qgallouedec Oct 10, 2023
c42bda8
oscar and cc sample weights
qgallouedec Oct 10, 2023
633711a
format
qgallouedec Oct 10, 2023
acb4c5c
sample weights
qgallouedec Oct 10, 2023
b06dfdd
speedup get_next_actions with kv_cache; single/multi_discrete net
qgallouedec Oct 13, 2023
5bb461d
set model max len in tokenizer
qgallouedec Oct 13, 2023
12cfcb4
truncation side and fix data alteration
qgallouedec Oct 13, 2023
36ee454
eval with kv cache
qgallouedec Oct 13, 2023
6ec8a28
fix bare except
qgallouedec Oct 13, 2023
3242877
sort imports
qgallouedec Oct 13, 2023
566cf4d
add torchvision to dep
qgallouedec Oct 13, 2023
5773e39
fix setup
qgallouedec Oct 13, 2023
2471e13
rm is_the_network_ok
qgallouedec Oct 13, 2023
f6a1f48
Merge branch 'main' into qgallouedec-idea-2
qgallouedec Oct 16, 2023
f0dd026
truncate past key values and don't predict multidescrete
qgallouedec Oct 16, 2023
baf9bc2
fix truncate for rl
qgallouedec Oct 16, 2023
9288371
update config generation script
qgallouedec Oct 16, 2023
d1c4bc9
fix train script for babyai
qgallouedec Oct 16, 2023
701f634
fix embed_textual
qgallouedec Oct 17, 2023
c451864
fix babyai names
qgallouedec Oct 19, 2023
5beffbb
fix last_continuous_action name
qgallouedec Oct 19, 2023
75d0370
fix babyai names
qgallouedec Oct 19, 2023
01f3266
task cluster in eval
qgallouedec Oct 19, 2023
02fe087
remove empty ep in train, wandb project, increse stream max retries
qgallouedec Oct 19, 2023
7b4fe51
reduce batch size to preserve memory
qgallouedec Oct 19, 2023
bf3a3ef
set the number of eval samples
qgallouedec Oct 19, 2023
8382f58
increase mix batch size to prepare for cluster training
qgallouedec Oct 19, 2023
a1ec295
print url in blue
qgallouedec Oct 19, 2023
3312a08
dispatch_batches=False
qgallouedec Oct 21, 2023
2fcebc0
ensure dispatch_batches is False
qgallouedec Oct 21, 2023
90adb99
liltle tricks to speed up: don't predict obs and relu inplace
qgallouedec Oct 20, 2023
263cc42
64 max len for atari
qgallouedec Oct 23, 2023
781fae0
fix vocab size in config
qgallouedec Oct 24, 2023
d40c38b
rm tokenizer_class entry
qgallouedec Oct 24, 2023
d0f932f
use gpt2 tokenizer instead of bert
qgallouedec Oct 24, 2023
27e4f27
save scores dict in eval
qgallouedec Oct 24, 2023
d6fb0f3
handle empty attention mask when image and text
qgallouedec Oct 24, 2023
df04d89
human score for atari
qgallouedec Oct 31, 2023
a3b2f71
Merge branch 'main' into qgallouedec-idea-2
qgallouedec Nov 2, 2023
8232951
Merge branch 'main' into qgallouedec-idea-2
qgallouedec Nov 2, 2023
91a4f53
new babyai names
qgallouedec Nov 3, 2023
751234b
tokenizer.pad_token = tokenizer.eos_token
qgallouedec Nov 3, 2023
4fc8766
Fix ambiguous variable name
qgallouedec Nov 6, 2023
255e189
update test for new tokenizer
qgallouedec Nov 6, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
.PHONY: quality style test

# Define directories variable
DIRS = data examples gia scripts tests
DIRS = data examples gia gia2 scripts tests

# Check that source code meets quality standards
quality:
Expand Down
144 changes: 88 additions & 56 deletions data/conceptual_captions/generate_conceptual_caption.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,43 @@
import concurrent.futures
import io
import os
import urllib
import multiprocessing
from typing import Dict, List, Union
from urllib.request import Request, urlopen

import PIL.Image
from datasets import load_dataset
from datasets.utils.file_utils import get_datasets_user_agent


USER_AGENT = get_datasets_user_agent()
PATH = "data/test" # or "data/train"

MAX_WORKERS = 10 # adjust to your needs
MAX_QUEUE_SIZE = 2 * MAX_WORKERS # adjust to your needs

def fetch_image(image_url: str, timeout: float = 0.5) -> PIL.Image.Image:
"""
Fetches a single image from a given URL and returns it as a PIL Image object.

def fetch_single_image(image_url, timeout=1):
print(image_url)
try:
request = urllib.request.Request(
image_url,
data=None,
headers={"user-agent": USER_AGENT},
)
with urllib.request.urlopen(request, timeout=timeout) as req:
image = PIL.Image.open(io.BytesIO(req.read()))
except Exception:
image = None
Args:
image_url (str): The URL of the image to fetch.
timeout (float): The timeout value for the request (in seconds).

Returns:
A PIL Image object representing the fetched image, or None if the image could not be fetched.
"""
request = Request(image_url, data=None, headers={"user-agent": USER_AGENT})
with urlopen(request, timeout=timeout) as req:
image = PIL.Image.open(io.BytesIO(req.read()))
return image


def resize_single_image(image: PIL.Image):
def resize_image(image: PIL.Image) -> PIL.Image:
"""
Resize a single image to have the bigger size at most 352 pixels while maintaining aspect ratio.
Remove metadata from the image.

Args:
image (PIL.Image): The image to be resized.

Returns:
PIL.Image: The resized image without metadata.
"""
# Resize so that the bigger size is at most 352
width, height = image.size
if width > height:
Expand All @@ -40,43 +47,68 @@ def resize_single_image(image: PIL.Image):
new_height = 352
new_width = int(width * 352 / height)
image = image.resize((new_width, new_height), PIL.Image.BILINEAR)
image = image.convert("RGB")
image = image.convert("RGB") # Make sure the image is RGB
data = list(image.getdata()) # Get only the image data, and place it in a new image to remove metadata
image_without_exif = PIL.Image.new(image.mode, image.size)
image_without_exif.putdata(data)
return image_without_exif


def fetch_and_resize(img_url: str) -> Union[PIL.Image.Image, None]:
"""
Fetches an image from a given URL and resizes it.

Args:
img_url (str): The URL of the image to fetch.

Returns:
numpy.ndarray: The resized image as a NumPy array, or None if an error occurred.
"""
try:
image = fetch_image(img_url)
image = resize_image(image)
except Exception:
image = None
return image


dataset = load_dataset("conceptual_captions", split="validation") # or "train"
if not os.path.exists(f"{PATH}/metadata.csv"):
with open(f"{PATH}/metadata.csv", "w") as f:
f.write("file_name,caption,idx\n")
dataset_idx = 0
image_idx = 0
else: # get the lastest index
with open(f"{PATH}/metadata.csv", "r") as f:
lines = f.readlines()
image_idx = len(lines) - 1
dataset_idx = int(lines[-1].split(",")[-1]) + 1
print(image_idx, dataset_idx)

with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
future_to_idx = {executor.submit(fetch_single_image, dataset[dataset_idx]["image_url"]): dataset_idx}
dataset_idx += 1
while dataset_idx < len(dataset):
done, _ = concurrent.futures.wait(future_to_idx, return_when=concurrent.futures.FIRST_COMPLETED)
for future in done:
idx = future_to_idx.pop(future)
def process(example: Dict[str, List[str]]) -> Dict[str, List[Union[str, PIL.Image.Image]]]:
output = {"images": [], "text": []}

with multiprocessing.Pool() as pool:
images = pool.starmap(fetch_and_resize, [(url,) for url in example["image_url"]])

for idx, image in enumerate(images):
if image is not None:
output["images"].append(image)
output["text"].append(example["caption"][idx])

return output


if __name__ == "__main__":
from datasets import Dataset, features, load_dataset

for split in ["train", "test"]:
dataset = load_dataset("conceptual_captions", split="train" if split == "train" else "validation")
num_cpu = multiprocessing.cpu_count() // 2
dataset = dataset.map(
process,
batched=True,
batch_size=200,
remove_columns=["caption", "image_url"],
num_proc=num_cpu,
load_from_cache_file=True,
features=features.Features({"images": features.Image(decode=True), "text": features.Value("string")}),
)
dataset.save_to_disk(f"conceptual-captions-{split}")
dataset = Dataset.load_from_disk(f"conceptual-captions-{split}")

retry = 500

for i in range(retry):
try:
image = future.result()
if image is not None:
image = resize_single_image(image)
sample = dataset[idx]
caption = sample["caption"].replace(",", "").replace(";", "").replace("\n", "").replace("\t", "")
image.save(f"{PATH}/{image_idx:07d}.png", "PNG")
with open(f"{PATH}/metadata.csv", "a") as f:
f.write(f"{image_idx:07d}.png,{caption},{idx}\n")
image_idx += 1
except Exception as exc:
print(f"Generated an exception: {exc}")

while len(future_to_idx) < MAX_QUEUE_SIZE and dataset_idx < len(dataset):
future_to_idx[executor.submit(fetch_single_image, dataset[dataset_idx]["image_url"])] = dataset_idx
dataset_idx += 1
dataset.push_to_hub("gia-project/gia-dataset-parquet", "conceptual-captions", split=split)
break
except Exception:
print(f"Retry {i+1}/{retry}")
Loading
Loading