Skip to content

Commit

Permalink
fix caching ground truth for MGM
Browse files Browse the repository at this point in the history
  • Loading branch information
ziao-guo committed Jun 20, 2024
1 parent 2dc7d29 commit 4737ebb
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions pygmtools/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,11 @@ def get_data(self, ids, test=False, shuffle=True):
if self.sets == 'test':
for pair in id_combination:
id_pair = (ids[pair[0]], ids[pair[1]])
gt_path = os.path.join(self.gt_cache_path, str(id_pair) + '.npy')
np.save(gt_path, perm_mat_dict[pair])
gt = perm_mat_dict[pair].toarray()
gt_path = os.path.join(self.gt_cache_path, str(id_pair) + '_' + str(gt.shape[0]) + '_'
+ str(gt.shape[1]) + '.npy')
if not os.path.exists(gt_path):
np.save(gt_path, perm_mat_dict[pair])

if not test:
return data_list, perm_mat_dict, ids
Expand Down Expand Up @@ -422,7 +425,8 @@ def eval(self, prediction, classes, verbose=False, rm_gt_cache=True):
id_cache.append(ids)
pred_cls_dict[pair_dict['cls']] += 1
perm_mat = pair_dict['perm_mat']
gt_path = os.path.join(self.gt_cache_path, str(ids) + '.npy')
gt_path = os.path.join(self.gt_cache_path, str(ids) + '_' + str(perm_mat.shape[0]) + '_'
+ str(perm_mat.shape[1]) + '.npy')
gt = np.load(gt_path, allow_pickle=True).item()
gt_array = gt.toarray()
assert type(perm_mat) == type(gt_array)
Expand Down Expand Up @@ -534,7 +538,8 @@ class function ``rm_gt_cache`` to remove groud truth cache after evaluation.
id_cache.append(ids)
pred_cls_dict += 1
perm_mat = pair_dict['perm_mat']
gt_path = os.path.join(self.gt_cache_path, str(ids) + '.npy')
gt_path = os.path.join(self.gt_cache_path, str(ids) + '_' + str(perm_mat.shape[0]) + '_'
+ str(perm_mat.shape[1]) + '.npy')
gt = np.load(gt_path, allow_pickle=True).item()
gt_array = gt.toarray()
assert type(perm_mat) == type(gt_array)
Expand Down

0 comments on commit 4737ebb

Please sign in to comment.