-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathget_crop_unlabeled.py
125 lines (110 loc) · 4.94 KB
/
get_crop_unlabeled.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from croptrain.modeling.meta_arch.rcnn import TwoStagePseudoLabGeneralizedRCNN
from croptrain.modeling.roi_heads.roi_heads import StandardROIHeadsPseudoLab
from croptrain.modeling.proposal_generator.rpn import PseudoLabRPN
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from croptrain import add_croptrainer_config, add_ubteacher_config
from detectron2.data import DatasetCatalog, MetadataCatalog
import os
from croptrain.data.datasets.visdrone import register_visdrone
from croptrain.engine.trainer import UBTeacherTrainer, BaselineTrainer
import numpy as np
import torch
import datetime
import time
import copy
import cv2
import json
from utils.crop_utils import get_dict_from_crops
from contextlib import ExitStack, contextmanager
from detectron2.structures.instances import Instances
from detectron2.structures.boxes import Boxes
import matplotlib.pyplot as plt
import logging
from croptrain.modeling.meta_arch.ts_ensemble import EnsembleTSModel
from detectron2.data.build import get_detection_dataset_dicts
from detectron2.utils.logger import log_every_n_seconds
from croptrain.data.datasets.visdrone import compute_crops
logging.basicConfig(level = logging.INFO)
@contextmanager
def inference_context(model):
    """
    A context where the model is temporarily changed to eval mode,
    and restored to previous mode afterwards.

    Args:
        model: a torch Module
    """
    training_mode = model.training
    model.eval()
    try:
        yield
    finally:
        # BUG FIX: restore the original mode even when the body raises,
        # otherwise an exception leaves the model permanently in eval mode.
        model.train(training_mode)
def shift_crop_boxes(data_dict, cluster_boxes):
    """Translate crop-local boxes back into full-image coordinates.

    Args:
        data_dict: dataset dict whose "crop_area" entry starts with the
            (x1, y1) origin of the crop inside the full image.
        cluster_boxes: (N, 4) array of boxes in crop-local coordinates.

    Returns:
        (N, 4) array of boxes shifted into full-image coordinates.
    """
    origin_x = data_dict["crop_area"][0]
    origin_y = data_dict["crop_area"][1]
    offset = np.array([origin_x, origin_y, origin_x, origin_y])
    return cluster_boxes + offset
def inference_crops(model, data_loader, cfg):
    """Run inference over the loader, gather high-confidence "cluster" boxes,
    and dump them per image file name into a JSON crop-seed file.

    Args:
        model: detection model; called as ``model(inputs)`` and expected to
            return detectron2-style outputs carrying an "instances" field
            with ``pred_classes``, ``scores`` and ``pred_boxes``.
        data_loader: inference loader yielding one-image batches — dicts with
            "file_name", "full_image" and, for crop inputs, "crop_area".
        cfg: detectron2 config node; supplies the dataset name, the labeled
            percentage (for the output file name) and the class count.

    Side effects:
        Writes ``dataseed/<dataset>_crops_<percent>.txt`` (JSON mapping
        file name -> list of [x1, y1, x2, y2] boxes) and prints a summary.
    """
    dataset_name = cfg.DATASETS.TRAIN[0].split("_")[0]
    crop_file = os.path.join("dataseed", dataset_name + "_crops_{}.txt".format(cfg.DATALOADER.SUP_PERCENT))
    crop_storage = {}
    # The synthetic "cluster" category is appended after the real object
    # classes, so it is always the last class index.
    cluster_class = cfg.MODEL.ROI_HEADS.NUM_CLASSES - 1
    with ExitStack() as stack:
        if isinstance(model, torch.nn.Module):
            stack.enter_context(inference_context(model))
        stack.enter_context(torch.no_grad())
        count = 0
        n_crops = 0
        for idx, inputs in enumerate(data_loader):
            outputs = model(inputs)
            cluster_class_indices = (outputs[0]["instances"].pred_classes == cluster_class)
            cluster_boxes = outputs[0]["instances"][cluster_class_indices]
            # Keep only confident cluster predictions.
            cluster_boxes = cluster_boxes[cluster_boxes.scores > 0.35]
            file_name = inputs[0]["file_name"].split('/')[-1]
            if file_name not in crop_storage:
                crop_storage[file_name] = []
            if idx % 100 == 0:
                print("processing {}th image".format(idx))
            if len(cluster_boxes) > 0:
                cluster_boxes = cluster_boxes.pred_boxes.tensor.cpu().numpy().astype(np.int32)
                if not inputs[0]["full_image"]:
                    # Crop inputs carry boxes in crop-local coordinates;
                    # shift them back into the full image before storing.
                    cluster_boxes = shift_crop_boxes(inputs[0], cluster_boxes)
                crop_storage[file_name] += cluster_boxes.tolist()
                count += 1
                n_crops += len(cluster_boxes)
    with open(crop_file, "w") as f:
        json.dump(crop_storage, f)
    print("crops present in {}/{} images".format(count, len(data_loader)))
    print("number of crops is {} ".format(n_crops))
def main():
    """Load a trained (semi-)supervised detector and generate crop seeds
    for the unlabeled VisDrone training set.

    Side effects: registers the dataset, loads checkpoint weights, and
    writes a crop-seed JSON file via ``inference_crops``.
    """
    cfg = get_cfg()
    add_croptrainer_config(cfg)
    add_ubteacher_config(cfg)
    cfg.merge_from_file(os.path.join(os.getcwd(), 'configs', 'visdrone', 'Semi-Sup-RCNN-FPN-CROP.yaml'))
    if cfg.CROPTRAIN.USE_CROPS:
        # One extra class slot for the synthetic "cluster"/crop category.
        cfg.MODEL.ROI_HEADS.NUM_CLASSES += 1
        cfg.MODEL.RETINANET.NUM_CLASSES += 1
    data_dir = os.path.join(os.environ['SLURM_TMPDIR'], "VisDrone")
    dataset_name = cfg.DATASETS.TRAIN[0]
    # NOTE(review): hard-coded user-specific paths; consider promoting these
    # to command-line arguments or config entries.
    cfg.OUTPUT_DIR = "/home/akhil135/scratch/DroneSSOD/FPN_CROP_SS_10_06"
    cfg.MODEL.WEIGHTS = "/home/akhil135/scratch/DroneSSOD/FPN_CROP_SS_10_06/model_0069999.pth"
    if dataset_name not in DatasetCatalog:
        register_visdrone(dataset_name, data_dir, cfg, False)
    if cfg.SEMISUPNET.USE_SEMISUP:
        Trainer = UBTeacherTrainer
    else:
        Trainer = BaselineTrainer
    model = Trainer.build_model(cfg)
    if cfg.SEMISUPNET.USE_SEMISUP:
        # Semi-supervised checkpoints hold a student/teacher pair; inference
        # uses the teacher model.
        model_teacher = Trainer.build_model(cfg)
        ensem_ts_model = EnsembleTSModel(model_teacher, model)
        DetectionCheckpointer(
            ensem_ts_model, save_dir=cfg.OUTPUT_DIR
        ).resume_or_load(cfg.MODEL.WEIGHTS, resume=False)
        inference_model = ensem_ts_model.modelTeacher
    else:
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(cfg.MODEL.WEIGHTS, resume=False)
        # BUG FIX: the original unconditionally passed
        # ``ensem_ts_model.modelTeacher`` to inference_crops, which raises
        # NameError on this branch because ensem_ts_model is only created
        # when SEMISUPNET.USE_SEMISUP is set. Use the baseline model here.
        inference_model = model
    data_loader = Trainer.build_test_loader(cfg, dataset_name)
    inference_crops(inference_model, data_loader, cfg)


if __name__ == "__main__":
    main()