Skip to content

Commit

Permalink
Removed incorrect handling of ignored faces
Browse files Browse the repository at this point in the history
  • Loading branch information
mike9251 committed Jan 29, 2023
1 parent c01a7e0 commit c4def24
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions src/simswap.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,16 +284,16 @@ def __call__(self, att_image: np.ndarray) -> np.ndarray:
)
align_att_img_batch = align_att_img_batch.to(self.device, non_blocking=True)

n, c, h, w = align_att_img_batch.shape
img_white = torch.zeros((n, 1, h, w), dtype=align_att_img_batch.dtype, device=self.device) + 255.0

inv_att_transforms: torch.Tensor = inverse_transform_batch(att_transforms)

# Get face masks for the attribute image
face_mask, ignore_mask_ids = self.bise_net.get_mask(
align_att_img_batch_for_parsing_model, self.crop_size
)

n, c, h, w = align_att_img_batch.shape
img_white = torch.zeros((n, 1, h, w), dtype=align_att_img_batch.dtype, device=self.device) + 255.0

inv_att_transforms: torch.Tensor = inverse_transform_batch(att_transforms)

soft_face_mask, _ = self.smooth_mask(face_mask)

# Only take face area from the swapped image
Expand All @@ -306,10 +306,6 @@ def __call__(self, att_image: np.ndarray) -> np.ndarray:

att_image = self.to_tensor(att_image).to(self.device, non_blocking=True)

if torch.sum(ignore_mask_ids.int()) > 0:
img_white = img_white[ignore_mask_ids, ...]
inv_att_transforms = inv_att_transforms[ignore_mask_ids, ...]

# to avoid OOM apply erosion on low res masks
img_white = F.pad(img_white, (self.erode_mask_value, self.erode_mask_value, self.erode_mask_value, self.erode_mask_value))

Expand Down

0 comments on commit c4def24

Please sign in to comment.