-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
395 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
import json | ||
|
||
from PIL import Image, ImageDraw, ImageFont | ||
|
||
|
||
class CocoSpatialDataset: | ||
def __init__(self, dataset_root, annotation_file, autoload=True): | ||
""" | ||
Initializes the CocoDataset with the root directory of the dataset and the annotation file. | ||
dataset_root: The root directory where the dataset is stored. | ||
annotation_file: The path to the COCO annotations file (JSON format). | ||
autoload: Whether to load the dataset automatically. Defaults to False. | ||
""" | ||
self.dataset_root = dataset_root | ||
self.annotation_file = annotation_file | ||
self.image_id_list = None | ||
self.coco_data = None | ||
self.categories = None | ||
self.annotations = None | ||
|
||
if autoload: | ||
self.load_dataset() | ||
|
||
def __len__(self): | ||
return len(self.image_id_list) | ||
|
||
def __getitem__(self, index): | ||
image_id = self.image_id_list[index] | ||
image, annotations = self.get_image_annotations(image_id) | ||
return image, annotations | ||
|
||
def load_dataset(self): | ||
""" | ||
Loads the COCO dataset annotations from the specified JSON file. | ||
""" | ||
with open(self.annotation_file, 'r') as f: | ||
self.coco_data = json.load(f) | ||
# Create a mapping from category ID to category name | ||
self.categories = {cat['id']: cat['name'] for cat in self.coco_data['categories']} | ||
self.annotations = self.coco_data['data'] | ||
self.image_id_list = list(self.annotations.keys()) | ||
|
||
def get_image_annotations(self, image_id): | ||
""" | ||
Retrieves the image and its annotations given an image ID. | ||
image_id: The ID of the image to retrieve. | ||
Return: | ||
A tuple (image, annotations) where `image` is a PIL image object, | ||
`annotations` is a list of bounding boxes and category IDs for the given image. | ||
""" | ||
# Find the image information by image ID | ||
datum = self.annotations[image_id] | ||
|
||
image_path = f'{self.dataset_root}/val2014/{datum["file_name"]}' | ||
image = Image.open(image_path) | ||
|
||
# Get the annotations for the given image ID | ||
annotations = datum['annotations'] | ||
good_pairs = datum['good_pairs'] if 'good_pairs' in datum else None | ||
data = {'annotation': annotations, 'good_pairs': good_pairs} | ||
|
||
return image, data | ||
|
||
def visualize_image(self, image, annotations, font_path=None, font_size=25): | ||
""" | ||
Visualizes the image by drawing bounding boxes and category labels on it. | ||
:param image: The PIL image object to visualize. | ||
:param annotations: A list of annotations with bounding boxes and category IDs. | ||
:param font_path: The path to the font file for rendering text. Defaults to "arial.ttf" if available. | ||
:param font_size: The font size to use for the text. | ||
""" | ||
vis_image = image.copy() | ||
draw = ImageDraw.Draw(vis_image) | ||
|
||
if isinstance(annotations, dict): | ||
annotations = annotations['annotation'] | ||
|
||
# Load the font for text labels | ||
try: | ||
font = ImageFont.truetype(font_path or "arial.ttf", font_size) | ||
except IOError: | ||
font = ImageFont.load_default() | ||
|
||
# Loop through each annotation and draw the bounding box and label | ||
for ann in annotations: | ||
bbox = ann['bbox'] # [x, y, width, height] | ||
x, y, width, height = bbox | ||
category_id = ann['category_id'] | ||
category_name = self.categories[category_id] | ||
|
||
# Draw the bounding box | ||
draw.rectangle([x, y, x + width, y + height], outline='green', width=3) | ||
|
||
# Draw the category label | ||
text_position = (x, y) | ||
draw.text(text_position, category_name, fill="red", font=font) | ||
|
||
return vis_image | ||
|
||
def generate_spatial_questions(self, image, annotation): | ||
flipped_image = image.transpose(Image.FLIP_LEFT_RIGHT) | ||
object_list = [self.categories[x['category_id']] for x in annotation['annotation']] | ||
object_pairs = annotation['good_pairs'] | ||
|
||
question_list = [] | ||
answer_list = [] | ||
for (obj_left, obj_right) in object_pairs: | ||
name_left = object_list[obj_left] | ||
name_right = object_list[obj_right] | ||
question = f"Which side of the {name_left} is the {name_right}?" | ||
# correct and wrong answers respectively | ||
answers = [ | ||
f"The {name_right} is on the right side of the {name_left}.", | ||
f"The {name_right} is on the left side of the {name_left}.", | ||
] | ||
question_list.append(question) | ||
answer_list.append(answers) | ||
|
||
return { | ||
'image': image, | ||
'image_flipped': flipped_image, | ||
'questions': question_list, | ||
'answers': answer_list | ||
} | ||
|
||
def generate_object_questions(self, annotation): | ||
object_list = [self.categories[x['category_id']] for x in annotation['annotation']] | ||
question_list = [] | ||
answer_list = [] | ||
for obj_name in object_list: | ||
if obj_name.startswith(tuple("aeiou")): | ||
question = f"Is there an {obj_name} in the image?" | ||
# correct and wrong answers respectively | ||
answer = [ | ||
f"Yes, there is a {obj_name} in the image.", | ||
f"No, there is no {obj_name} in the image.", | ||
] | ||
else: | ||
question = f"Is there a {obj_name} in the image?" | ||
# correct and wrong answers respectively | ||
answer = [ | ||
f"Yes, there is a {obj_name} in the image.", | ||
f"No, there is no {obj_name} in the image.", | ||
] | ||
|
||
question_list.append(question) | ||
answer_list.append(answer) | ||
|
||
return { | ||
'questions': question_list, | ||
'answers': answer_list | ||
} | ||
|
||
|
||
if __name__ == "__main__": | ||
# Sample usage code. | ||
file_root = "/home/kanchana/data/mscoco/coco_2014" | ||
anno_file = "/home/kanchana/repo/locvlm/data/coco_spatial.json" | ||
|
||
dataset = CocoSpatialDataset(file_root, anno_file) | ||
|
||
image, annotation = dataset[5] | ||
vis_image = dataset.visualize_image(image, annotation) | ||
object_eval_data = dataset.generate_object_questions(annotation) | ||
spatial_eval_data = dataset.generate_spatial_questions(image, annotation) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import json | ||
|
||
from PIL import Image, ImageDraw, ImageFont | ||
|
||
|
||
class CocoDataset: | ||
def __init__(self, dataset_root, annotation_file, autoload=False): | ||
""" | ||
Initializes the CocoDataset with the root directory of the dataset and the annotation file. | ||
dataset_root: The root directory where the dataset is stored. | ||
annotation_file: The path to the COCO annotations file (JSON format). | ||
autoload: Whether to load the dataset automatically. Defaults to False. | ||
""" | ||
self.dataset_root = dataset_root | ||
self.annotation_file = annotation_file | ||
self.coco_data = None | ||
self.categories = None | ||
self.annotations = None | ||
|
||
if autoload: | ||
self.load_dataset() | ||
|
||
def load_dataset(self): | ||
""" | ||
Loads the COCO dataset annotations from the specified JSON file. | ||
""" | ||
with open(self.annotation_file, 'r') as f: | ||
self.coco_data = json.load(f) | ||
# Create a mapping from category ID to category name | ||
self.categories = {cat['id']: cat['name'] for cat in self.coco_data['categories']} | ||
self.annotations = self.create_annotations_dict() | ||
|
||
def create_annotations_dict(self): | ||
""" | ||
Creates a dictionary that maps image_id to the list of annotations for that image. | ||
This will allow for faster lookups of annotations by image_id. | ||
Return: | ||
A dictionary where keys are image IDs and values are lists of annotations for that image. | ||
""" | ||
annotations_dict = {} | ||
|
||
for ann in self.coco_data['annotations']: | ||
image_id = ann['image_id'] | ||
if image_id not in annotations_dict: | ||
annotations_dict[image_id] = [] | ||
annotations_dict[image_id].append(ann) | ||
|
||
return annotations_dict | ||
|
||
def get_image_annotations(self, image_id): | ||
""" | ||
Retrieves the image and its annotations given an image ID. | ||
image_id: The ID of the image to retrieve. | ||
Return: | ||
A tuple (image, annotations) where `image` is a PIL image object, | ||
`annotations` is a list of bounding boxes and category IDs for the given image. | ||
""" | ||
# Find the image information by image ID | ||
image_info = next((img for img in self.coco_data['images'] if img['id'] == image_id), None) | ||
if image_info is None: | ||
raise ValueError(f"Image ID {image_id} not found in the dataset.") | ||
|
||
image_path = f'{self.dataset_root}/val2014/{image_info["file_name"]}' | ||
image = Image.open(image_path) | ||
|
||
# Get the annotations for the given image ID | ||
annotations = self.annotations[image_id] | ||
|
||
return image, annotations | ||
|
||
def visualize_image(self, image, annotations, font_path=None, font_size=25): | ||
""" | ||
Visualizes the image by drawing bounding boxes and category labels on it. | ||
:param image: The PIL image object to visualize. | ||
:param annotations: A list of annotations with bounding boxes and category IDs. | ||
:param font_path: The path to the font file for rendering text. Defaults to "arial.ttf" if available. | ||
:param font_size: The font size to use for the text. | ||
""" | ||
vis_image = image.copy() | ||
draw = ImageDraw.Draw(vis_image) | ||
|
||
# Load the font for text labels | ||
try: | ||
font = ImageFont.truetype(font_path or "arial.ttf", font_size) | ||
except IOError: | ||
font = ImageFont.load_default() | ||
|
||
# Loop through each annotation and draw the bounding box and label | ||
for ann in annotations: | ||
bbox = ann['bbox'] # [x, y, width, height] | ||
x, y, width, height = bbox | ||
category_id = ann['category_id'] | ||
category_name = self.categories[category_id] | ||
|
||
# Draw the bounding box | ||
draw.rectangle([x, y, x + width, y + height], outline='green', width=3) | ||
|
||
# Draw the category label | ||
text_position = (x, y) | ||
draw.text(text_position, category_name, fill="red", font=font) | ||
|
||
return vis_image | ||
|
||
# Example Usage: | ||
if __name__ == "__main__": | ||
# Initialize the visualizer | ||
file_root = "/home/kanchana/data/mscoco/coco_2014" | ||
anno_file = f"{file_root}/annotations/instances_val2014.json" | ||
|
||
dataset = CocoDataset(file_root, anno_file) | ||
|
||
# Load the dataset | ||
dataset.load_dataset() | ||
|
||
# Get image and annotations for a specific image ID | ||
image_id = dataset.coco_data['images'][0]['id'] # Using the first image in the dataset | ||
|
||
image, annotation = dataset.get_image_annotations(image_id) | ||
vis_image = dataset.visualize_image(image, annotation) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import json | ||
|
||
import tqdm | ||
from coco_dataset import CocoDataset | ||
|
||
# Example Usage: | ||
if __name__ == "__main__": | ||
|
||
file_root = "/home/kanchana/data/mscoco/coco_2014" | ||
anno_file = f"{file_root}/annotations/instances_val2014.json" | ||
|
||
dataset = CocoDataset(file_root, anno_file) | ||
|
||
# Load the dataset | ||
dataset.load_dataset() | ||
|
||
per_image_unique_categories = {} | ||
|
||
for image_id in tqdm.tqdm(dataset.annotations.keys()): | ||
image, annotation = dataset.get_image_annotations(image_id) | ||
|
||
# Filter | ||
categories = [anno['category_id'] for anno in annotation] | ||
counter = {} | ||
for cat in categories: | ||
counter[cat] = counter.get(cat, 0) + 1 | ||
per_image_unique_categories[image_id] = [x for x,y in counter.items() if y == 1] | ||
|
||
filtered_image_category = {x:y for x,y in per_image_unique_categories.items() if len(y) >= 2} | ||
|
||
new_annotation_dict = {} | ||
new_image_info_dict = {} | ||
for image_id, category_list in tqdm.tqdm(filtered_image_category.items()): | ||
image, annotation = dataset.get_image_annotations(image_id) | ||
image_info = [img for img in dataset.coco_data['images'] if img['id'] == image_id][0] | ||
cur_annotation = [] | ||
for anno in annotation: | ||
if anno['category_id'] in category_list: | ||
if 'segmentation' in anno: | ||
anno.pop('segmentation') | ||
cur_annotation.append(anno) | ||
new_annotation_dict[image_id] = cur_annotation | ||
new_image_info_dict[image_id] = image_info | ||
|
||
combined_dict = {} | ||
for key in new_annotation_dict.keys(): | ||
combined_dict[key] = { | ||
'annotations': new_annotation_dict[key], | ||
'file_name': new_image_info_dict[key]['file_name'], | ||
'height': new_image_info_dict[key]['height'], | ||
'width': new_image_info_dict[key]['width'] | ||
} | ||
save_dict = { | ||
'categories': dataset.coco_data['categories'], | ||
'data': combined_dict | ||
} | ||
save_path = "/home/kanchana/repo/locvlm/data/coco_spatial.json" | ||
json.dump(save_dict, open(save_path, "w"), indent=2) |
Oops, something went wrong.