update to latest SAM 2

rentainhe committed Aug 21, 2024
1 parent 35efb4a commit 6e0ddad
Showing 12 changed files with 140 additions and 87 deletions.
22 changes: 21 additions & 1 deletion sam2/automatic_mask_generator.py
@@ -53,6 +53,7 @@ def __init__(
output_mode: str = "binary_mask",
use_m2m: bool = False,
multimask_output: bool = True,
**kwargs,
) -> None:
"""
Using a SAM 2 model, generates masks for the entire image.
@@ -148,6 +149,23 @@ def __init__(
self.use_m2m = use_m2m
self.multimask_output = multimask_output

@classmethod
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
"""
Load a pretrained model from the Hugging Face hub.
Arguments:
model_id (str): The Hugging Face repository ID.
**kwargs: Additional arguments to pass to the model constructor.
Returns:
(SAM2AutomaticMaskGenerator): The loaded model.
"""
from sam2.build_sam import build_sam2_hf

sam_model = build_sam2_hf(model_id, **kwargs)
return cls(sam_model, **kwargs)

@torch.no_grad()
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
"""
@@ -284,7 +302,9 @@ def _process_batch(
orig_h, orig_w = orig_size

# Run model on this batch
points = torch.as_tensor(points, device=self.predictor.device)
points = torch.as_tensor(
points, dtype=torch.float32, device=self.predictor.device
)
in_points = self.predictor._transforms.transform_coords(
points, normalize=normalize, orig_hw=im_size
)
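Pinning `dtype=torch.float32` here guards against `points` arriving as a float64 numpy array; without it, `torch.as_tensor` would propagate double precision into the coordinate transforms. A minimal illustration of the behavior being avoided (standalone PyTorch, not part of this diff):

```python
import numpy as np
import torch

points = np.array([[0.5, 0.5]])  # numpy defaults to float64
print(torch.as_tensor(points).dtype)                       # torch.float64
print(torch.as_tensor(points, dtype=torch.float32).dtype)  # torch.float32
```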
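The `from_pretrained` classmethod added above mirrors the one on `SAM2ImagePredictor`. A minimal usage sketch, assuming the `facebook/sam2-hiera-large` checkpoint on the Hugging Face hub and a local image file (both illustrative, not part of this diff):

```python
import numpy as np
from PIL import Image

from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Extra kwargs are forwarded both to build_sam2_hf and to the constructor,
# which is why the builders below also grow **kwargs catch-alls in this commit.
mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
    "facebook/sam2-hiera-large",
    points_per_side=32,
)
image = np.array(Image.open("example.jpg").convert("RGB"))
masks = mask_generator.generate(image)  # list of dicts with "segmentation", "area", ...
```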
2 changes: 2 additions & 0 deletions sam2/build_sam.py
@@ -19,6 +19,7 @@ def build_sam2(
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
**kwargs,
):

if apply_postprocessing:
@@ -47,6 +48,7 @@ def build_sam2_video_predictor(
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
**kwargs,
):
hydra_overrides = [
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
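The `**kwargs` catch-alls let `from_pretrained` pass one kwargs dict to both the builder and the predictor constructor without raising `TypeError` on constructor-only options. A sketch under that assumption (config and checkpoint paths are illustrative):

```python
from sam2.build_sam import build_sam2

# mask_threshold is meant for SAM2ImagePredictor.__init__; the builder
# now silently ignores it instead of failing on an unexpected keyword.
model = build_sam2(
    config_file="sam2_hiera_l.yaml",
    ckpt_path="checkpoints/sam2_hiera_large.pt",
    mask_threshold=0.0,
)
```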
4 changes: 0 additions & 4 deletions sam2/modeling/backbones/hieradet.py
@@ -46,11 +46,7 @@ def __init__(

self.dim = dim
self.dim_out = dim_out

self.num_heads = num_heads
head_dim = dim_out // num_heads
self.scale = head_dim**-0.5

self.q_pool = q_pool
self.qkv = nn.Linear(dim, dim_out * 3)
self.proj = nn.Linear(dim_out, dim_out)
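Removing `head_dim` and `self.scale` is safe on the assumption that the attention forward uses PyTorch's fused kernel, which applies the 1/sqrt(head_dim) factor internally. A standalone sketch of that default:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# scale defaults to head_dim ** -0.5, so a stored self.scale is dead state
out = F.scaled_dot_product_attention(q, k, v)
```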
9 changes: 7 additions & 2 deletions sam2/modeling/position_encoding.py
@@ -16,7 +16,7 @@
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
used by the Attention Is All You Need paper, generalized to work on images.
"""

def __init__(
@@ -211,6 +211,11 @@ def apply_rotary_enc(
# repeat freqs along seq_len dim to match k seq_len
if repeat_freqs_k:
r = xk_.shape[-2] // xq_.shape[-2]
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
if freqs_cis.is_cuda:
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
else:
# torch.repeat on complex numbers may not be supported on non-CUDA devices
# (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
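The fallback works because `unsqueeze(2).expand(...).flatten(2, 3)` tiles whole blocks along dim 2 in the same order as `repeat`. A quick equivalence check with a real-valued stand-in (only the layout matters here):

```python
import torch

freqs = torch.randn(2, 3, 4, 5)  # stand-in for a 4-D freqs_cis
r = 3

tiled = freqs.repeat(1, 1, r, 1)
expanded = freqs.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
assert torch.equal(tiled, expanded)  # same values, shape (2, 3, 12, 5)
```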
6 changes: 3 additions & 3 deletions sam2/modeling/sam2_base.py
@@ -567,10 +567,10 @@ def _prepare_memory_conditioned_features(
continue # skip padding frames
# "maskmem_features" might have been offloaded to CPU in demo use cases,
# so we load it back to GPU (it's a no-op if it's already on GPU).
feats = prev["maskmem_features"].cuda(non_blocking=True)
feats = prev["maskmem_features"].to(device, non_blocking=True)
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
# Spatial positional encoding (it might have been offloaded to CPU in eval)
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
# Temporal positional encoding
maskmem_enc = (
@@ -642,7 +642,7 @@ def _prepare_memory_conditioned_features(
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
return pix_feat_with_mem

# Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder)
# Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]

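Swapping `.cuda()` for `.to(device)` makes the memory offload round trip device-agnostic (CPU, MPS, or CUDA). The pattern in isolation, with an illustrative device pick:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

maskmem_features = torch.randn(1, 64, 32, 32)  # offloaded copy living on CPU
maskmem_features = maskmem_features.to(device, non_blocking=True)  # no-op if already there
```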
11 changes: 7 additions & 4 deletions sam2/sam2_image_predictor.py
@@ -24,6 +24,7 @@ def __init__(
mask_threshold=0.0,
max_hole_area=0.0,
max_sprinkle_area=0.0,
**kwargs,
) -> None:
"""
Uses SAM-2 to calculate the image embedding for an image, and then
@@ -33,8 +34,10 @@
sam_model (Sam-2): The model to use for mask prediction.
mask_threshold (float): The threshold to use when converting mask logits
to binary masks. Masks are thresholded at 0 by default.
fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to
the maximum area of fill_hole_area in low_res_masks.
max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
the maximum area of max_hole_area in low_res_masks.
max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
the maximum area of max_sprinkle_area in low_res_masks.
"""
super().__init__()
self.model = sam_model
@@ -77,7 +80,7 @@ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
from sam2.build_sam import build_sam2_hf

sam_model = build_sam2_hf(model_id, **kwargs)
return cls(sam_model)
return cls(sam_model, **kwargs)

@torch.no_grad()
def set_image(
@@ -180,7 +183,7 @@ def predict_batch(
normalize_coords=True,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
"""This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
It returns a tupele of lists of masks, ious, and low_res_masks_logits.
It returns a tuple of lists of masks, ious, and low_res_masks_logits.
"""
assert self._is_batch, "This function should only be used when in batched mode"
if not self._is_image_set:
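With `return cls(sam_model, **kwargs)`, constructor options finally reach the predictor when loading from the hub. A usage sketch (the hub ID is assumed for illustration):

```python
from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor.from_pretrained(
    "facebook/sam2-hiera-small",
    mask_threshold=0.0,   # now forwarded to __init__ instead of being dropped
    max_hole_area=100.0,  # fill small holes up to this area in low_res_masks
)
```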
20 changes: 12 additions & 8 deletions sam2/sam2_video_predictor.py
@@ -44,12 +44,14 @@ def init_state(
offload_state_to_cpu=False,
async_loading_frames=False,
):
"""Initialize a inference state."""
"""Initialize an inference state."""
compute_device = self.device # device of the model
images, video_height, video_width = load_video_frames(
video_path=video_path,
image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
compute_device=compute_device,
)
inference_state = {}
inference_state["images"] = images
@@ -65,11 +67,11 @@
# the original video height and width, used for resizing final output scores
inference_state["video_height"] = video_height
inference_state["video_width"] = video_width
inference_state["device"] = torch.device("cuda")
inference_state["device"] = compute_device
if offload_state_to_cpu:
inference_state["storage_device"] = torch.device("cpu")
else:
inference_state["storage_device"] = torch.device("cuda")
inference_state["storage_device"] = compute_device
# inputs on each frame
inference_state["point_inputs_per_obj"] = {}
inference_state["mask_inputs_per_obj"] = {}
@@ -119,7 +121,7 @@ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
from sam2.build_sam import build_sam2_video_predictor_hf

sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
return cls(sam_model)
return sam_model

def _obj_id_to_idx(self, inference_state, obj_id):
"""Map client-side object id to model-side object index."""
@@ -270,7 +272,8 @@ def add_new_points_or_box(
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)

if prev_out is not None and prev_out["pred_masks"] is not None:
prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
device = inference_state["device"]
prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
current_out, _ = self._run_single_frame_inference(
@@ -586,7 +589,7 @@ def propagate_in_video_preflight(self, inference_state):
# to `propagate_in_video_preflight`).
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
for is_cond in [False, True]:
# Separately consolidate conditioning and non-conditioning temp outptus
# Separately consolidate conditioning and non-conditioning temp outputs
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Find all the frames that contain temporary outputs for any objects
# (these should be the frames that have just received clicks for mask inputs
@@ -595,7 +598,7 @@
for obj_temp_output_dict in temp_output_dict_per_obj.values():
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
consolidated_frame_inds[storage_key].update(temp_frame_inds)
# consolidate the temprary output across all objects on this frame
# consolidate the temporary output across all objects on this frame
for frame_idx in temp_frame_inds:
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
@@ -793,7 +796,8 @@ def _get_image_feature(self, inference_state, frame_idx, batch_size):
)
if backbone_out is None:
# Cache miss -- we will run inference on a single image
image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
device = inference_state["device"]
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
backbone_out = self.forward_image(image)
# Cache the most recent frame's feature (for repeated interactions with
# a frame; we can use an LRU cache for more frames in the future).
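`build_sam2_video_predictor_hf` already returns a fully built `SAM2VideoPredictor`, so wrapping it again in `cls(...)` mis-called the constructor; returning it directly fixes hub loading. Together with the device plumbing above, a session sketch (hub ID and frame directory are illustrative):

```python
from sam2.sam2_video_predictor import SAM2VideoPredictor

predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")

# init_state now derives compute and storage devices from the model
# instead of assuming torch.device("cuda")
inference_state = predictor.init_state(video_path="frames/")
```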
48 changes: 36 additions & 12 deletions sam2/utils/misc.py
@@ -68,7 +68,7 @@ def mask_to_box(masks: torch.Tensor):
compute bounding box given an input mask
Inputs:
- masks: [B, 1, H, W] boxes, dtype=torch.Tensor
- masks: [B, 1, H, W] masks, dtype=torch.Tensor
Returns:
- box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor
@@ -106,19 +106,28 @@ class AsyncVideoFrameLoader:
A list of video frames to be loaded asynchronously without blocking session start.
"""

def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
def __init__(
self,
img_paths,
image_size,
offload_video_to_cpu,
img_mean,
img_std,
compute_device,
):
self.img_paths = img_paths
self.image_size = image_size
self.offload_video_to_cpu = offload_video_to_cpu
self.img_mean = img_mean
self.img_std = img_std
# items in `self._images` will be loaded asynchronously
# items in `self.images` will be loaded asynchronously
self.images = [None] * len(img_paths)
# catch and raise any exceptions in the async loading thread
self.exception = None
# video_height and video_width will be filled when loading the first image
self.video_height = None
self.video_width = None
self.compute_device = compute_device

# load the first frame to fill video_height and video_width and also
# to cache it (since it's most likely where the user will click)
@@ -152,7 +161,7 @@ def __getitem__(self, index):
img -= self.img_mean
img /= self.img_std
if not self.offload_video_to_cpu:
img = img.cuda(non_blocking=True)
img = img.to(self.compute_device, non_blocking=True)
self.images[index] = img
return img

@@ -167,6 +176,7 @@ def load_video_frames(
img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225),
async_loading_frames=False,
compute_device=torch.device("cuda"),
):
"""
Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
@@ -179,12 +189,20 @@
if isinstance(video_path, str) and os.path.isdir(video_path):
jpg_folder = video_path
else:
raise NotImplementedError("Only JPEG frames are supported at this moment")
raise NotImplementedError(
"Only JPEG frames are supported at this moment. For video files, you may use "
"ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n"
"```\n"
"ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg'\n"
"```\n"
"where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks "
"ffmpeg to start the JPEG file from 00000.jpg."
)

frame_names = [
p
for p in os.listdir(jpg_folder)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
num_frames = len(frame_names)
@@ -196,17 +214,22 @@

if async_loading_frames:
lazy_images = AsyncVideoFrameLoader(
img_paths, image_size, offload_video_to_cpu, img_mean, img_std
img_paths,
image_size,
offload_video_to_cpu,
img_mean,
img_std,
compute_device,
)
return lazy_images, lazy_images.video_height, lazy_images.video_width

images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
if not offload_video_to_cpu:
images = images.cuda()
img_mean = img_mean.cuda()
img_std = img_std.cuda()
images = images.to(compute_device)
img_mean = img_mean.to(compute_device)
img_std = img_std.to(compute_device)
# normalize by mean and std
images -= img_mean
images /= img_std
@@ -230,8 +253,9 @@ def fill_holes_in_mask_scores(mask, max_area):
except Exception as e:
# Skip the post-processing step on removing small holes if the CUDA kernel fails
warnings.warn(
f"{e}\n\nSkipping the post-processing step due to the error above. "
"Consider building SAM 2 with CUDA extension to enable post-processing (see "
f"{e}\n\nSkipping the post-processing step due to the error above. You can "
"still use SAM 2 and it's OK to ignore the error above, although some post-processing "
"functionality may be limited (which doesn't affect the results in most cases; see "
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
category=UserWarning,
stacklevel=2,
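The new `compute_device` parameter threads through both the synchronous path and `AsyncVideoFrameLoader`, replacing the hardcoded `.cuda()` calls. A sketch, assuming a directory of frames named `00000.jpg`, `00001.jpg`, ...:

```python
import torch
from sam2.utils.misc import load_video_frames

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

images, video_height, video_width = load_video_frames(
    video_path="frames/",
    image_size=1024,
    offload_video_to_cpu=False,  # frames move to `device` rather than hardcoded CUDA
    async_loading_frames=False,
    compute_device=device,
)
```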
5 changes: 3 additions & 2 deletions sam2/utils/transforms.py
@@ -105,8 +105,9 @@ def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
except Exception as e:
# Skip the post-processing step if the CUDA kernel fails
warnings.warn(
f"{e}\n\nSkipping the post-processing step due to the error above. "
"Consider building SAM 2 with CUDA extension to enable post-processing (see "
f"{e}\n\nSkipping the post-processing step due to the error above. You can "
"still use SAM 2 and it's OK to ignore the error above, although some post-processing "
"functionality may be limited (which doesn't affect the results in most cases; see "
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
category=UserWarning,
stacklevel=2,
2 changes: 1 addition & 1 deletion sav_dataset/sav_evaluator.py
@@ -72,7 +72,7 @@
parser.add_argument(
"--do_not_skip_first_and_last_frame",
help="In SA-V val and test, we skip the first and the last annotated frames in evaluation. "
"Set this to true for evaluation on settings that doen't skip first and last frames",
"Set this to true for evaluation on settings that doesn't skip first and last frames",
action="store_true",
)

2 changes: 1 addition & 1 deletion sav_dataset/utils/sav_benchmark.py
@@ -183,7 +183,7 @@ def _seg2bmap(seg, width=None, height=None):

assert not (
width > w | height > h | abs(ar1 - ar2) > 0.01
), "Can" "t convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
), "Cannot convert %dx%d seg to %dx%d bmap." % (w, h, width, height)

e = np.zeros_like(seg)
s = np.zeros_like(seg)