From 5f7e8ef3707e245dbd948b347453861f7d503ab3 Mon Sep 17 00:00:00 2001 From: Victor Wang Date: Fri, 2 Sep 2022 13:00:25 +1200 Subject: [PATCH] Fix incorrect calculation of TP and FP counts for given '-thresh' value --- src/detector.c | 103 +++++++++++++++++++++++++------------------------ 1 file changed, 53 insertions(+), 50 deletions(-) diff --git a/src/detector.c b/src/detector.c index 0b947b69089..89d3140ac14 100644 --- a/src/detector.c +++ b/src/detector.c @@ -1106,13 +1106,16 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa const int checkpoint_detections_count = detections_count; - int i; - for (i = 0; i < nboxes; ++i) { - - int class_id; - for (class_id = 0; class_id < classes; ++class_id) { + // For each class, match detections in decreasing order of confidence + int class_id; + for (class_id = 0; class_id < classes; ++class_id) { + // Create detections for this class then sort them by confidence + int num_class_dets = 0; + int i; + for (i = 0; i < nboxes; ++i) { float prob = dets[i].prob[class_id]; if (prob > 0) { + num_class_dets++; detections_count++; detections = (box_prob*)xrealloc(detections, detections_count * sizeof(box_prob)); detections[detections_count - 1].b = dets[i].bbox; @@ -1121,59 +1124,59 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa detections[detections_count - 1].class_id = class_id; detections[detections_count - 1].truth_flag = 0; detections[detections_count - 1].unique_truth_index = -1; - - int truth_index = -1; - float max_iou = 0; - for (j = 0; j < num_labels; ++j) - { - box t = { truth[j].x, truth[j].y, truth[j].w, truth[j].h }; - //printf(" IoU = %f, prob = %f, class_id = %d, truth[j].id = %d \n", - // box_iou(dets[i].bbox, t), prob, class_id, truth[j].id); - float current_iou = box_iou(dets[i].bbox, t); - if (current_iou > iou_thresh && class_id == truth[j].id) { - if (current_iou > max_iou) { - max_iou = current_iou; - truth_index = unique_truth_count + j; - } + } + } + qsort(detections + detections_count - num_class_dets, num_class_dets, sizeof(box_prob), detections_comparator); + + for (i = detections_count - num_class_dets; i < detections_count; ++i) { + int truth_index = -1; + float max_iou = 0; + for (j = 0; j < num_labels; ++j) { + box t = { truth[j].x, truth[j].y, truth[j].w, truth[j].h }; + float current_iou = box_iou(detections[i].b, t); + if (current_iou > iou_thresh && class_id == truth[j].id) { + if (current_iou > max_iou) { + max_iou = current_iou; + truth_index = unique_truth_count + j; } } + } - // best IoU - if (truth_index > -1) { - detections[detections_count - 1].truth_flag = 1; - detections[detections_count - 1].unique_truth_index = truth_index; - } - else { - // if object is difficult then remove detection - for (j = 0; j < num_labels_dif; ++j) { - box t = { truth_dif[j].x, truth_dif[j].y, truth_dif[j].w, truth_dif[j].h }; - float current_iou = box_iou(dets[i].bbox, t); - if (current_iou > iou_thresh && class_id == truth_dif[j].id) { - --detections_count; - break; - } + // best IoU + if (truth_index > -1) { + detections[i].truth_flag = 1; + detections[i].unique_truth_index = truth_index; + } + else { + // if object is difficult then remove detection + for (j = 0; j < num_labels_dif; ++j) { + box t = { truth_dif[j].x, truth_dif[j].y, truth_dif[j].w, truth_dif[j].h }; + float current_iou = box_iou(detections[i].b, t); + if (current_iou > iou_thresh && class_id == truth_dif[j].id) { + --detections_count; + break; } } + } - // calc avg IoU, true-positives, false-positives for required Threshold - if (prob > thresh_calc_avg_iou) { - int z, found = 0; - for (z = checkpoint_detections_count; z < detections_count - 1; ++z) { - if (detections[z].unique_truth_index == truth_index) { - found = 1; break; - } + // calc avg IoU, true-positives, false-positives at the required threshold + if (detections[i].p > thresh_calc_avg_iou) { + int z, found = 0; + for (z = detections_count - num_class_dets; z < i - 1; ++z) { + if (detections[z].unique_truth_index == truth_index) { + found = 1; break; } + } - if (truth_index > -1 && found == 0) { - avg_iou += max_iou; - ++tp_for_thresh; - avg_iou_per_class[class_id] += max_iou; - tp_for_thresh_per_class[class_id]++; - } - else{ - fp_for_thresh++; - fp_for_thresh_per_class[class_id]++; - } + if (truth_index > -1 && found == 0) { + avg_iou += max_iou; + ++tp_for_thresh; + avg_iou_per_class[class_id] += max_iou; + tp_for_thresh_per_class[class_id]++; + } + else { + fp_for_thresh++; + fp_for_thresh_per_class[class_id]++; } } }