Skip to content

Commit

Permalink
use octree, add color_refine
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxiaomeng030 committed Apr 19, 2024
1 parent 6835dff commit a3372f1
Show file tree
Hide file tree
Showing 20 changed files with 767 additions and 677 deletions.
43 changes: 43 additions & 0 deletions slam/algorithms/nice_slam.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class NiceSLAMConfig(AlgorithmConfig):
mapping_lr_factor: float = 1.0
mapping_lr_first_factor: float = 5.0

mapping_color_refine: bool = True


class NiceSLAM(Algorithm):

Expand All @@ -67,6 +69,47 @@ def __init__(self, config: NiceSLAMConfig, camera: Camera,

self.cur_mesh = None

def do_mapping(self, cur_frame):
if not self.is_initialized():
mapping_n_iters = self.config.mapping_first_n_iters
else:
mapping_n_iters = self.config.mapping_n_iters

# here provides a color refinement postprocess
if cur_frame.is_final_frame and self.config.mapping_color_refine:
outer_joint_iters = 5
self.config.mapping_window_size *= 2
self.config.mapping_middle_iter_ratio = 0.0
self.config.mapping_fine_iter_ratio = 0.0
self.model.config.mapping_fix_color = True
self.model.config.mapping_frustum_feature_selection = False
else:
outer_joint_iters = 1

for _ in range(outer_joint_iters):
# select optimize frames
with torch.no_grad():
optimize_frames = self.select_optimize_frames(
cur_frame,
keyframe_selection_method=self.config.
keyframe_selection_method)
# optimize keyframes_pose, model_params, update model params
self.optimize_update(mapping_n_iters,
optimize_frames,
is_mapping=True,
coarse=False)

# do coarse_mapper
optimize_frames = self.select_optimize_frames(
cur_frame, keyframe_selection_method='random')
self.optimize_update(mapping_n_iters,
optimize_frames,
is_mapping=True,
coarse=True)

if not self.is_initialized():
self.set_initialized()

def optimizer_config_update(self, max_iters, coarse=False):
if len(self.keyframe_graph) > 4 and not coarse:
self.bundle_adjust = True
Expand Down
2 changes: 1 addition & 1 deletion slam/algorithms/voxfusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def get_mesh(self):
@torch.no_grad()
def extract_mesh(self, res=8, clean_mesh=False, require_color=False):
# get map states
voxels, _, features, leaf_num = self.model.octree.get_all()
voxels, _, features = self.model.svo.get_centres_and_children()
index = features.eq(-1).any(-1)
voxels = voxels[~index, :]
features = features[~index, :]
Expand Down
1 change: 1 addition & 0 deletions slam/common/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(self,
self.gt_pose = gt_pose
self.separate_LR = separate_LR
self.rot_rep = rot_rep
self.is_final_frame = False

if init_pose is not None:
pose = torch.tensor(init_pose,
Expand Down
6 changes: 3 additions & 3 deletions slam/configs/input_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,12 @@
use_relative_pose=True,
save_debug_result=False,
init_pose_offset=10),
mapper=MapperConfig(keyframe_every=10, ),
mapper=MapperConfig(keyframe_every=50, ),
algorithm=VoxFusionConfig(
# keyframe_selection_algorithm='random',
keyframe_selection_method='random',
tracking_n_iters=30,
mapping_n_iters=15, # 30
mapping_first_n_iters=100,
mapping_first_n_iters=30,
mapping_window_size=5,
mapping_sample=1024,
tracking_sample=1024,
Expand Down
6 changes: 3 additions & 3 deletions slam/models/conv_onet.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,12 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
decoders_para_list += list(self.decoder.fine_decoder.parameters())
if not self.config.mapping_fix_color:
decoders_para_list += list(self.decoder.color_decoder.parameters())
param_groups['decoder'] = decoders_para_list
if len(decoders_para_list) > 0:
param_groups['decoder'] = decoders_para_list
# grid_params
for key, grid in self.grid_c.items():
grid = grid.to(self.device)
if (self.config.mapping_frustum_feature_selection
and not self.config.coarse):
if self.config.mapping_frustum_feature_selection:
mask = self.grid_opti_mask[key]
grid.set_mask(mask)
param_groups[key] = list(grid.parameters())
Expand Down
14 changes: 5 additions & 9 deletions slam/models/sparse_voxel.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def find_so_files(directory):


search_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'../../third_party/sparse_octforest/build/')
'../../third_party/sparse_octree/build/')
so_files = find_so_files(search_directory)
for so_file in so_files:
torch.classes.load_library(so_file)
Expand Down Expand Up @@ -304,7 +304,8 @@ def sdf2weights(self, sdf, z_vals, valid_mask):
1e-8), z_min

def get_octree(self):
self.octree = torch.classes.forest.Octree(self.config.voxels_each_dim)
self.svo = torch.classes.svo.Octree()
self.svo.init(256, self.config.embed_dim, self.config.voxel_size)
self.embeddings = torch.nn.Parameter(torch.zeros(
(self.config.num_embeddings, self.config.embed_dim),
dtype=torch.float32,
Expand All @@ -327,18 +328,13 @@ def insert_points(self, points):
voxels = torch.div(points,
self.config.voxel_size,
rounding_mode='floor')
voxels = torch.unique(voxels.cpu().int(), sorted=False, dim=0)
# here, voxels.cpu().int() and (voxels.cpu().int()[:, None]).view(-1,3)
# has the same shape: [N_points, 3]
# i think we can remove repeated voxel ids to reduce insert time and
# torch loading time for svo.
self.octree.insert(voxels)
self.svo.insert(voxels.cpu().int())
self.update_map_states()

@torch.enable_grad()
def update_map_states(self):
"""This function is modified from voxfusion."""
voxels, children, features, leaf_num = self.octree.get_all()
voxels, children, features = self.svo.get_centres_and_children()
centres = (voxels[:, :3] + voxels[:, -1:] / 2) * self.config.voxel_size
children = torch.cat([children, voxels[:, -1:]], -1)

Expand Down
5 changes: 4 additions & 1 deletion slam/pipeline/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,10 @@ def check_mapframe(self, check_frame, map_buffer):
else:
map_every = self.config.map_every
# send to mapper
if check_frame.fid % map_every == 0:
if check_frame.fid % map_every == 0 or check_frame.fid == len(
self.dataset) - 1:
check_frame.is_final_frame = (
check_frame.fid == len(self.dataset) - 1)
map_buffer.put(check_frame, block=True)
return True
return False
Expand Down
2 changes: 1 addition & 1 deletion third_party/install.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash

cd ./sparse_octforest/
cd ./sparse_octree/
python setup.py install

cd ../sparse_voxels/
Expand Down
41 changes: 0 additions & 41 deletions third_party/sparse_octforest/include/cuda_utils.h

This file was deleted.

87 changes: 0 additions & 87 deletions third_party/sparse_octforest/include/octree.h

This file was deleted.

28 changes: 0 additions & 28 deletions third_party/sparse_octforest/setup.py

This file was deleted.

24 changes: 0 additions & 24 deletions third_party/sparse_octforest/src/bindings.cpp

This file was deleted.

Loading

0 comments on commit a3372f1

Please sign in to comment.