Input mask format #35

Open
TritiumR opened this issue Nov 17, 2024 · 1 comment
Comments

@TritiumR
Thanks for the great work.

I am trying to use custom data for testing, and I am not sure my setup is correct since the results look odd.

What is the correct format for the "support_masks"? Currently I have segmentation masks with values 0..n representing n+1 parts. How should I convert them into the input "support_masks"?

Thank you for any help.
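In case it helps, one common way to split a label map with values 0..n into n+1 binary masks (one float mask per part, which is what the snippet below also builds per `part_id`) is a one-hot style conversion. A minimal sketch with a hypothetical helper name, numpy only:

```python
import numpy as np

def label_map_to_binary_masks(label_map, num_parts):
    """Split an (H, W) label map with values 0..num_parts-1 into a
    (num_parts, H, W) stack of {0., 1.} float masks, one per part."""
    return np.stack(
        [(label_map == part_id).astype(np.float32) for part_id in range(num_parts)],
        axis=0,
    )

# Example: a 2x2 label map containing parts 0 and 1
demo = np.array([[0, 1],
                 [1, 1]], dtype=np.uint8)
masks = label_map_to_binary_masks(demo, num_parts=2)
# masks[0] selects the pixels labelled 0, masks[1] the pixels labelled 1
```

The resulting stack can then be wrapped with `torch.from_numpy` before being passed on as support masks.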

@TritiumR
Author

Part of my code:

import os

import cv2
import numpy as np
import torch

# prepare support images and masks
support_imgs = []
support_annos = [[] for _ in range(args.part_num)]  # one list of binary masks per part
for idx, image_name in enumerate(ref_images):
    support_img = cv2.imread(os.path.join(image_path, 'ref_images', image_name))
    support_img = cv2.resize(support_img, (518, 518), interpolation=cv2.INTER_NEAREST)
    support_img = torch.tensor(np.array(support_img)).float()
    support_img = support_img.transpose(0, 2).transpose(1, 2)  # HWC -> CHW

    support_anno = cv2.imread(os.path.join(image_path, 'ref_annotations', image_name.replace('.JPEG', '.png')), cv2.IMREAD_GRAYSCALE)
    support_anno = cv2.resize(support_anno, (518, 518), interpolation=cv2.INTER_NEAREST)
    support_anno = torch.tensor(np.array(support_anno))

    for part_id in range(args.part_num):
        support_mask = (support_anno == part_id).float()
        # print('support_mask', support_mask.min(), support_mask.max(), support_mask.sum())
        support_annos[part_id].append(support_mask)

    support_imgs.append(support_img)

support_imgs = torch.stack(support_imgs, dim=0)  # N x C x H x W

# Testing
for idx, image_name in enumerate(images):
    query_img = cv2.imread(os.path.join(image_path, 'images', image_name))
    query_img = cv2.resize(query_img, (518, 518), interpolation=cv2.INTER_NEAREST)
    query_img = torch.tensor(np.array(query_img)).float()
    query_img = query_img.transpose(0, 2).transpose(1, 2).unsqueeze(0)  # 1 x C x H x W

    query_anno = cv2.imread(os.path.join(image_path, 'annotations', image_name.replace('.JPEG', '.png')), cv2.IMREAD_GRAYSCALE)
    query_anno = cv2.resize(query_anno, (518, 518), interpolation=cv2.INTER_NEAREST)
    query_anno = torch.tensor(np.array(query_anno))

    for part_id in range(args.part_num):
        # 1. Matcher prepare references
        support_masks = torch.stack(support_annos[part_id], dim=0)  # N x H x W
        print('support_masks', support_masks.size(), support_masks.min(), support_masks.max(), support_masks.sum())
        matcher.set_reference(support_imgs[None].to(args.device), support_masks[None].to(args.device))

        # 2. Matcher prepare target
        matcher.set_target(query_img.to(args.device))

        # 3. Predict mask of target
        pred_mask = matcher.predict()
        matcher.clear()
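A quick sanity check before calling `set_reference` can catch format problems early: each support mask should contain only 0/1 float values and select at least one pixel for its part. A minimal sketch with a hypothetical `check_support_mask` helper, numpy only:

```python
import numpy as np

def check_support_mask(mask):
    """Verify a support mask is a binary 0/1 mask that selects at least one pixel."""
    vals = set(np.unique(mask).tolist())
    if not vals <= {0.0, 1.0}:
        raise ValueError(f"mask must be binary 0/1, got values {sorted(vals)}")
    if mask.sum() == 0:
        raise ValueError("mask selects no pixels for this part")

# A well-formed mask passes silently
check_support_mask(np.array([[0.0, 1.0],
                             [1.0, 0.0]], dtype=np.float32))
```

Running this per `part_id` on each stacked mask (converted via `.numpy()` if needed) would also reveal whether odd results come from empty or non-binary masks rather than from the model.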
