Skip to content

Commit

Permalink
init data files
Browse files Browse the repository at this point in the history
  • Loading branch information
kahnchana committed Oct 3, 2024
1 parent 071e463 commit 26e868c
Show file tree
Hide file tree
Showing 4 changed files with 395 additions and 0 deletions.
169 changes: 169 additions & 0 deletions data/coco_spatial_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import json

from PIL import Image, ImageDraw, ImageFont


class CocoSpatialDataset:
    """
    Reader for a pre-filtered COCO annotation file used to build object and
    left/right spatial questions.

    The annotation JSON is expected to contain a 'categories' list and a 'data'
    mapping from image ID to a record with 'file_name', 'annotations' and,
    optionally, 'good_pairs'. Images are read from `<dataset_root>/val2014/`.
    """

    def __init__(self, dataset_root, annotation_file, autoload=True):
        """
        Initializes the dataset with the root directory of the images and the annotation file.
        :param dataset_root: The root directory where the dataset is stored.
        :param annotation_file: The path to the annotations file (JSON format).
        :param autoload: Whether to load the annotations immediately. Defaults to True.
        """
        self.dataset_root = dataset_root
        self.annotation_file = annotation_file
        self.image_id_list = None
        self.coco_data = None
        self.categories = None
        self.annotations = None

        if autoload:
            self.load_dataset()

    def __len__(self):
        """Returns the number of images; only valid after `load_dataset` has run."""
        return len(self.image_id_list)

    def __getitem__(self, index):
        """
        Returns the (image, data) tuple for the image at integer position `index`.
        See `get_image_annotations` for the layout of `data`.
        """
        image_id = self.image_id_list[index]
        return self.get_image_annotations(image_id)

    def load_dataset(self):
        """
        Loads the annotations from the specified JSON file and builds the
        category-name lookup and the ordered list of image IDs.
        """
        with open(self.annotation_file, 'r') as f:
            self.coco_data = json.load(f)
        # Create a mapping from category ID to category name
        self.categories = {cat['id']: cat['name'] for cat in self.coco_data['categories']}
        self.annotations = self.coco_data['data']
        self.image_id_list = list(self.annotations.keys())

    def get_image_annotations(self, image_id):
        """
        Retrieves the image and its annotations given an image ID.
        :param image_id: Key of the image in the 'data' section of the annotation file.
        :return: A tuple (image, data) where `image` is a PIL image object and `data`
            is a dict with keys 'annotation' (the image's annotation list) and
            'good_pairs' (the stored pairs, or None when the record has none).
        """
        datum = self.annotations[image_id]

        image_path = f'{self.dataset_root}/val2014/{datum["file_name"]}'
        image = Image.open(image_path)

        data = {
            'annotation': datum['annotations'],
            'good_pairs': datum.get('good_pairs'),
        }
        return image, data

    def visualize_image(self, image, annotations, font_path=None, font_size=25):
        """
        Visualizes the image by drawing bounding boxes and category labels on a copy of it.
        :param image: The PIL image object to visualize (left unmodified).
        :param annotations: A list of annotations with 'bbox' and 'category_id', or the
            dict returned by `get_image_annotations` (its 'annotation' entry is used).
        :param font_path: The path to the font file for rendering text. Falls back to
            "arial.ttf", then to PIL's default font if that cannot be loaded.
        :param font_size: The font size to use for the text.
        :return: A new PIL image with boxes and labels drawn on it.
        """
        vis_image = image.copy()
        draw = ImageDraw.Draw(vis_image)

        if isinstance(annotations, dict):
            annotations = annotations['annotation']

        # Load the font for text labels
        try:
            font = ImageFont.truetype(font_path or "arial.ttf", font_size)
        except OSError:
            font = ImageFont.load_default()

        for ann in annotations:
            x, y, width, height = ann['bbox']  # [x, y, width, height]
            category_name = self.categories[ann['category_id']]

            # Bounding box first, then the label anchored at its top-left corner
            draw.rectangle([x, y, x + width, y + height], outline='green', width=3)
            draw.text((x, y), category_name, fill="red", font=font)

        return vis_image

    def generate_spatial_questions(self, image, annotation):
        """
        Builds one left/right question per entry in `annotation['good_pairs']`.
        :param image: The PIL image the annotation belongs to.
        :param annotation: The dict returned by `get_image_annotations`.
        :return: A dict with 'image', 'image_flipped' (horizontally mirrored copy),
            'questions' and 'answers'. Each answers entry is [correct, wrong] for the
            un-flipped image.
        """
        flipped_image = image.transpose(Image.FLIP_LEFT_RIGHT)
        object_list = [self.categories[x['category_id']] for x in annotation['annotation']]
        # 'good_pairs' is None when the record carries no pairs; treat that as "no questions"
        # instead of failing on iteration.
        object_pairs = annotation.get('good_pairs') or []

        question_list = []
        answer_list = []
        # Each pair holds two indices into the annotation list.
        # NOTE(review): assumes the first index is the object on the left - confirm against
        # the script that generates 'good_pairs'.
        for (obj_left, obj_right) in object_pairs:
            name_left = object_list[obj_left]
            name_right = object_list[obj_right]
            question_list.append(f"Which side of the {name_left} is the {name_right}?")
            # correct and wrong answers respectively
            answer_list.append([
                f"The {name_right} is on the right side of the {name_left}.",
                f"The {name_right} is on the left side of the {name_left}.",
            ])

        return {
            'image': image,
            'image_flipped': flipped_image,
            'questions': question_list,
            'answers': answer_list
        }

    def generate_object_questions(self, annotation):
        """
        Builds one existence question per annotated object.
        :param annotation: The dict returned by `get_image_annotations`.
        :return: A dict with 'questions' and 'answers'; each answers entry is
            [correct, wrong].
        """
        object_list = [self.categories[x['category_id']] for x in annotation['annotation']]
        question_list = []
        answer_list = []
        for obj_name in object_list:
            # Pick the indefinite article once so the question and the affirmative
            # answer always agree ("an apple", "a dog").
            article = "an" if obj_name.lower().startswith(tuple("aeiou")) else "a"
            question_list.append(f"Is there {article} {obj_name} in the image?")
            # correct and wrong answers respectively
            answer_list.append([
                f"Yes, there is {article} {obj_name} in the image.",
                f"No, there is no {obj_name} in the image.",
            ])

        return {
            'questions': question_list,
            'answers': answer_list
        }


if __name__ == "__main__":
    # Demo: load the spatial dataset, take one sample and build every derived
    # artifact (visualization, spatial questions, object questions) for it.
    file_root = "/home/kanchana/data/mscoco/coco_2014"
    anno_file = "/home/kanchana/repo/locvlm/data/coco_spatial.json"

    dataset = CocoSpatialDataset(dataset_root=file_root, annotation_file=anno_file)

    # Indexing returns the PIL image together with its annotation dict.
    sample = dataset[5]
    image, annotation = sample

    vis_image = dataset.visualize_image(image, annotation)
    object_eval_data = dataset.generate_object_questions(annotation)
    spatial_eval_data = dataset.generate_spatial_questions(image, annotation)
124 changes: 124 additions & 0 deletions data/preprocessing/coco_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import json

from PIL import Image, ImageDraw, ImageFont


class CocoDataset:
    """
    Reader for a COCO-format annotation file (top-level 'images', 'annotations'
    and 'categories' lists). Images are read from `<dataset_root>/val2014/`.
    """

    def __init__(self, dataset_root, annotation_file, autoload=False):
        """
        Initializes the CocoDataset with the root directory of the dataset and the annotation file.
        :param dataset_root: The root directory where the dataset is stored.
        :param annotation_file: The path to the COCO annotations file (JSON format).
        :param autoload: Whether to load the dataset automatically. Defaults to False.
        """
        self.dataset_root = dataset_root
        self.annotation_file = annotation_file
        self.coco_data = None
        self.categories = None
        self.annotations = None
        # image_id -> image record; built on first lookup (see _get_image_info)
        self._image_index = None

        if autoload:
            self.load_dataset()

    def load_dataset(self):
        """
        Loads the COCO dataset annotations from the specified JSON file.
        """
        with open(self.annotation_file, 'r') as f:
            self.coco_data = json.load(f)
        # Create a mapping from category ID to category name
        self.categories = {cat['id']: cat['name'] for cat in self.coco_data['categories']}
        self.annotations = self.create_annotations_dict()
        # Drop any index built from previously loaded data
        self._image_index = None

    def create_annotations_dict(self):
        """
        Creates a dictionary that maps image_id to the list of annotations for that image.
        This will allow for faster lookups of annotations by image_id.
        :return: A plain dict where keys are image IDs and values are lists of annotations
            for that image, in file order. Images without annotations have no key.
        """
        annotations_dict = {}
        for ann in self.coco_data['annotations']:
            annotations_dict.setdefault(ann['image_id'], []).append(ann)
        return annotations_dict

    def _get_image_info(self, image_id):
        """Returns the 'images' record for `image_id`, or None if there is none."""
        index = getattr(self, '_image_index', None)
        if index is None:
            # Built once so repeated lookups do not rescan the whole 'images' list.
            # setdefault keeps the first record if an ID were ever duplicated.
            # The index is only reset by load_dataset(); replacing coco_data by hand
            # afterwards would leave it stale.
            index = {}
            for img in self.coco_data['images']:
                index.setdefault(img['id'], img)
            self._image_index = index
        return index.get(image_id)

    def get_image_annotations(self, image_id):
        """
        Retrieves the image and its annotations given an image ID.
        :param image_id: The ID of the image to retrieve.
        :return: A tuple (image, annotations) where `image` is a PIL image object and
            `annotations` is the list of annotations for the given image (empty when
            the image has none).
        :raises ValueError: If `image_id` is not present in the 'images' list.
        """
        image_info = self._get_image_info(image_id)
        if image_info is None:
            raise ValueError(f"Image ID {image_id} not found in the dataset.")

        image_path = f'{self.dataset_root}/val2014/{image_info["file_name"]}'
        image = Image.open(image_path)

        # An image listed in 'images' may have no annotations at all; report that as
        # an empty list rather than a KeyError.
        annotations = self.annotations.get(image_id, [])

        return image, annotations

    def visualize_image(self, image, annotations, font_path=None, font_size=25):
        """
        Visualizes the image by drawing bounding boxes and category labels on a copy of it.
        :param image: The PIL image object to visualize (left unmodified).
        :param annotations: A list of annotations with bounding boxes and category IDs.
        :param font_path: The path to the font file for rendering text. Falls back to
            "arial.ttf", then to PIL's default font if that cannot be loaded.
        :param font_size: The font size to use for the text.
        :return: A new PIL image with boxes and labels drawn on it.
        """
        vis_image = image.copy()
        draw = ImageDraw.Draw(vis_image)

        # Load the font for text labels
        try:
            font = ImageFont.truetype(font_path or "arial.ttf", font_size)
        except OSError:
            font = ImageFont.load_default()

        for ann in annotations:
            x, y, width, height = ann['bbox']  # [x, y, width, height]
            category_name = self.categories[ann['category_id']]

            # Bounding box first, then the label anchored at its top-left corner
            draw.rectangle([x, y, x + width, y + height], outline='green', width=3)
            draw.text((x, y), category_name, fill="red", font=font)

        return vis_image

# Example Usage:
if __name__ == "__main__":
    # Demo: load the COCO val2014 instance annotations and draw the boxes of
    # the first image listed in the annotation file.
    file_root = "/home/kanchana/data/mscoco/coco_2014"
    anno_file = f"{file_root}/annotations/instances_val2014.json"

    # autoload=True runs load_dataset() as part of construction
    dataset = CocoDataset(file_root, anno_file, autoload=True)

    first_image_record = dataset.coco_data['images'][0]
    image_id = first_image_record['id']

    image, annotation = dataset.get_image_annotations(image_id)
    vis_image = dataset.visualize_image(image, annotation)
58 changes: 58 additions & 0 deletions data/preprocessing/filter_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import json

import tqdm
from coco_dataset import CocoDataset

# Example Usage:
if __name__ == "__main__":
    # Builds the filtered annotation file: for every image, keep only the
    # annotations whose category occurs exactly once in that image, and keep
    # the image only if at least two such categories remain.
    from collections import Counter

    file_root = "/home/kanchana/data/mscoco/coco_2014"
    anno_file = f"{file_root}/annotations/instances_val2014.json"

    dataset = CocoDataset(file_root, anno_file)

    # Load the dataset
    dataset.load_dataset()

    # Index image records once instead of rescanning the list for every image.
    # setdefault keeps the first record for an ID, matching a first-match scan.
    image_info_by_id = {}
    for img in dataset.coco_data['images']:
        image_info_by_id.setdefault(img['id'], img)

    # Only annotations are needed here, so read them straight from the
    # image_id -> annotations mapping rather than opening every image file.
    per_image_unique_categories = {}
    for image_id, annotation in tqdm.tqdm(dataset.annotations.items()):
        counter = Counter(anno['category_id'] for anno in annotation)
        per_image_unique_categories[image_id] = [cat for cat, count in counter.items() if count == 1]

    filtered_image_category = {x: y for x, y in per_image_unique_categories.items() if len(y) >= 2}

    combined_dict = {}
    for image_id, category_list in tqdm.tqdm(filtered_image_category.items()):
        image_info = image_info_by_id[image_id]
        kept_categories = set(category_list)
        # Copy each kept annotation without its 'segmentation' entry so the
        # loaded dataset itself is left untouched.
        cur_annotation = [
            {key: value for key, value in anno.items() if key != 'segmentation'}
            for anno in dataset.annotations[image_id]
            if anno['category_id'] in kept_categories
        ]
        combined_dict[image_id] = {
            'annotations': cur_annotation,
            'file_name': image_info['file_name'],
            'height': image_info['height'],
            'width': image_info['width']
        }

    save_dict = {
        'categories': dataset.coco_data['categories'],
        'data': combined_dict
    }
    save_path = "/home/kanchana/repo/locvlm/data/coco_spatial.json"
    # Context manager guarantees the output is flushed and closed
    with open(save_path, "w") as f:
        json.dump(save_dict, f, indent=2)
Loading

0 comments on commit 26e868c

Please sign in to comment.