Skip to content

Commit

Permalink
fix merge err
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxiaomeng030 committed Apr 20, 2024
1 parent ada03ef commit 2d066a1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 16 deletions.
2 changes: 2 additions & 0 deletions slam/algorithms/nice_slam.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ def get_loss(self,
n_iters,
coarse=False):
self.set_stage(is_mapping, step, n_iters, coarse=coarse)
if is_mapping:
self.model.grid_processing(coarse=coarse)
model_input = self.get_model_input(optimize_frames, is_mapping)
model_outputs = self.model(model_input)
loss_dict = self.model.get_loss_dict(model_outputs, model_input,
Expand Down
18 changes: 2 additions & 16 deletions slam/model_components/feature_grid_nice.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import numpy as np
import torch
import torch.nn as nn


class FeatureGrid(nn.Module):
class FeatureGrid():
def __init__(self, xyz_len, grid_len, c_dim, std=0.01):
super(FeatureGrid, self).__init__()

Expand All @@ -12,16 +10,4 @@ def __init__(self, xyz_len, grid_len, c_dim, std=0.01):
val_shape = [1, c_dim, *val_shape]
val = torch.zeros(val_shape).normal_(mean=0, std=std)

mask = np.ones((val_shape[2:])[::-1]).astype(bool)
mask = torch.from_numpy(mask).permute(
2, 1, 0).unsqueeze(0).unsqueeze(0).repeat(1, val_shape[1], 1, 1, 1)

self.val = nn.Parameter(val, requires_grad=True)
self.mask = nn.Parameter(mask, requires_grad=False)

def set_mask(self, new_mask):
self.mask.data.copy_(new_mask)

def val_mask(self):
masked_val = self.val * self.mask
return masked_val
self.val = val

0 comments on commit 2d066a1

Please sign in to comment.