Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Added same_focals option to PointCloudOptimizer and demo #94

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,11 @@ def get_args_parser():

def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
cam_color=None, as_pointcloud=False,
transparent_cams=False, silent=False):
assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
transparent_cams=False, silent=False, same_focals=False):

assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world)
if not same_focals:
assert(len(cams2world) == len(focals))
pts3d = to_numpy(pts3d)
imgs = to_numpy(imgs)
focals = to_numpy(focals)
Expand All @@ -85,8 +88,12 @@ def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world,
camera_edge_color = cam_color[i]
else:
camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
if same_focals:
focal = focals[0]
else:
focal = focals[i]
add_scene_cam(scene, pose_c2w, camera_edge_color,
None if transparent_cams else imgs[i], focals[i],
None if transparent_cams else imgs[i], focal,
imsize=imgs[i].shape[1::-1], screen_width=cam_size)

rot = np.eye(4)
Expand All @@ -100,7 +107,7 @@ def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world,


def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
clean_depth=False, transparent_cams=False, cam_size=0.05):
clean_depth=False, transparent_cams=False, cam_size=0.05, same_focals=False):
"""
extract 3D_model (glb file) from a reconstructed scene
"""
Expand All @@ -121,12 +128,12 @@ def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud
scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
msk = to_numpy(scene.get_masks())
return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
transparent_cams=transparent_cams, cam_size=cam_size, silent=silent, same_focals=same_focals)


def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist, schedule, niter, min_conf_thr,
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
scenegraph_type, winsize, refid):
scenegraph_type, winsize, refid, same_focals):
"""
from a list of images, run dust3r inference, global aligner.
then run get_3D_model_from_scene
Expand All @@ -144,14 +151,14 @@ def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist,
output = inference(pairs, model, device, batch_size=batch_size, verbose=not silent)

mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
scene = global_aligner(output, device=device, mode=mode, verbose=not silent)
scene = global_aligner(output, device=device, mode=mode, verbose=not silent, same_focals=same_focals)
lr = 0.01

if mode == GlobalAlignerMode.PointCloudOptimizer:
loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)

outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
clean_depth, transparent_cams, cam_size)
clean_depth, transparent_cams, cam_size, same_focals=same_focals)

# also return rgb, depth and confidence imgs
# depth is normalized with the max value for all images
Expand Down Expand Up @@ -213,6 +220,7 @@ def main_demo(tmpdirname, model, device, image_size, server_name, server_port, s
value='complete', label="Scenegraph",
info="Define how to make pairs",
interactive=True)
same_focals = gradio.Checkbox(value=False, label="Focal", info="Use the same focal for all cameras")
winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
minimum=1, maximum=1, step=1, visible=False)
refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
Expand Down Expand Up @@ -244,33 +252,34 @@ def main_demo(tmpdirname, model, device, image_size, server_name, server_port, s
run_btn.click(fn=recon_fun,
inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud,
mask_sky, clean_depth, transparent_cams, cam_size,
scenegraph_type, winsize, refid],
scenegraph_type, winsize, refid, same_focals],
outputs=[scene, outmodel, outgallery])
min_conf_thr.release(fn=model_from_scene_fun,
inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
clean_depth, transparent_cams, cam_size],
clean_depth, transparent_cams, cam_size, same_focals],
outputs=outmodel)
cam_size.change(fn=model_from_scene_fun,
inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
clean_depth, transparent_cams, cam_size],
clean_depth, transparent_cams, cam_size, same_focals],
outputs=outmodel)
as_pointcloud.change(fn=model_from_scene_fun,
inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
clean_depth, transparent_cams, cam_size],
clean_depth, transparent_cams, cam_size, same_focals],
outputs=outmodel)
mask_sky.change(fn=model_from_scene_fun,
inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
clean_depth, transparent_cams, cam_size],
clean_depth, transparent_cams, cam_size, same_focals],
outputs=outmodel)
clean_depth.change(fn=model_from_scene_fun,
inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
clean_depth, transparent_cams, cam_size],
clean_depth, transparent_cams, cam_size, same_focals],
outputs=outmodel)
transparent_cams.change(model_from_scene_fun,
inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
clean_depth, transparent_cams, cam_size],
clean_depth, transparent_cams, cam_size, same_focals],
outputs=outmodel)
demo.launch(share=False, server_name=server_name, server_port=server_port)

demo.launch(share=True, server_name=server_name, server_port=server_port)


if __name__ == '__main__':
Expand Down
12 changes: 10 additions & 2 deletions dust3r/cloud_opt/base_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def _init_from_views(self, view1, view2, pred1, pred2,
pw_break=20,
rand_pose=torch.randn,
iterationsCount=None,
same_focals=False,
verbose=True):
super().__init__()
if not isinstance(view1['idx'], list):
Expand All @@ -60,6 +61,7 @@ def _init_from_views(self, view1, view2, pred1, pred2,
self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
self.dist = ALL_DISTS[dist]
self.verbose = verbose
self.same_focals = same_focals

self.n_imgs = self._check_edges()

Expand Down Expand Up @@ -237,12 +239,18 @@ def clean_pointcloud(self, tol=0.001, max_bad_conf=0):
"""
assert 0 <= tol < 1
cams = inv(self.get_im_poses())
K = self.get_intrinsics()
Ks = self.get_intrinsics()
depthmaps = self.get_depthmaps()
res = deepcopy(self)

for i, pts3d in enumerate(self.depth_to_pts3d()):
for j in range(self.n_imgs):

if self.same_focals:
K = Ks[0]
else:
K = Ks[j]

if i == j:
continue

Expand All @@ -251,7 +259,7 @@ def clean_pointcloud(self, tol=0.001, max_bad_conf=0):
Hj, Wj = self.imshapes[j]
proj = geotrf(cams[j], pts3d[:Hi*Wi]).reshape(Hi, Wi, 3)
proj_depth = proj[:, :, 2]
u, v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1)
u, v = geotrf(K, proj, norm=1, ncol=2).round().long().unbind(-1)

# check which points are actually in the visible cone
msk_i = (proj_depth > 0) & (0 <= u) & (u < Wj) & (0 <= v) & (v < Hj)
Expand Down
4 changes: 3 additions & 1 deletion dust3r/cloud_opt/init_im_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,10 @@ def init_from_pts3d(self, pts3d, im_focals, im_poses):
depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
self._set_depthmap(i, depth)
self._set_pose(self.im_poses, i, cam2world)
if im_focals[i] is not None:
if im_focals[i] is not None and not self.same_focals:
self._set_focal(i, im_focals[i])
if self.same_focals:
self._set_focal(0, torch.tensor(im_focals).mean()) # initialize with mean focal

if self.verbose:
print(' init loss =', float(self()))
Expand Down
14 changes: 9 additions & 5 deletions dust3r/cloud_opt/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@ def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):
# adding thing to optimize
self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes) # log(depth)
self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs)) # camera poses
self.im_focals = nn.ParameterList(torch.FloatTensor(
[self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes) # camera intrinsics
if self.same_focals:
self.im_focals = nn.Parameter(torch.FloatTensor([[torch.tensor(self.focal_break)*np.log(max(self.imshapes[0]))]])) # initialize with H x W of first image
else:
self.im_focals = nn.ParameterList(torch.FloatTensor(
[self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes) # camera intrinsics
self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs)) # camera intrinsics
self.im_pp.requires_grad_(optimize_pp)

Expand Down Expand Up @@ -175,7 +178,7 @@ def depth_to_pts3d(self):
depth = self.get_depthmaps(raw=True)

# get pointmaps in camera frame
rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp)
rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp, same_focals=self.same_focals)
# project to world frame
return geotrf(im_poses, rel_ptmaps)

Expand All @@ -201,10 +204,11 @@ def forward(self):
return li + lj


def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp, same_focals=False):
pp = pp.unsqueeze(1)
focal = focal.unsqueeze(1)
assert focal.shape == (len(depth), 1, 1)
if not same_focals:
assert focal.shape == (len(depth), 1, 1)
assert pp.shape == (len(depth), 1, 2)
assert pixel_grid.shape == depth.shape + (2,)
depth = depth.unsqueeze(-1)
Expand Down
15 changes: 12 additions & 3 deletions dust3r/cloud_opt/pair_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def __init__(self, *args, **kwargs):
self.depth = [geotrf(inv(rel_poses[0]), self.pred_j['1_0'])[..., 2], self.pred_i['1_0'][..., 2]]

self.im_poses = nn.Parameter(torch.stack(self.im_poses, dim=0), requires_grad=False)
self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False)
if self.same_focals:
self.focals = nn.Parameter(torch.tensor([torch.tensor(self.focals).mean()]), requires_grad = False)
else:
self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False)
self.pp = nn.Parameter(torch.stack(self.pp, dim=0), requires_grad=False)
self.depth = nn.ParameterList(self.depth)
for p in self.parameters():
Expand Down Expand Up @@ -116,9 +119,15 @@ def get_im_poses(self):

def depth_to_pts3d(self):
pts3d = []
for d, intrinsics, im_pose in zip(self.depth, self.get_intrinsics(), self.get_im_poses()):

for i, (d, im_pose) in enumerate(zip(self.depth, self.get_im_poses())):

if self.same_focals:
intrinsic = self.get_intrinsics()[0]
else:
intrinsic = self.get_intrinsics()[i]
pts, _ = depthmap_to_absolute_camera_coordinates(d.cpu().numpy(),
intrinsics.cpu().numpy(),
intrinsic.cpu().numpy(),
im_pose.cpu().numpy())
pts3d.append(torch.from_numpy(pts).to(device=self.device))
return pts3d
Expand Down