Commit
add code for training with opacity regularization
klnavaneet committed Jul 31, 2024
1 parent 890bd1a commit adecb94
Showing 2 changed files with 33 additions and 1 deletion.
5 changes: 5 additions & 0 deletions run.sh
@@ -6,6 +6,8 @@ ncls_dc=4096
 kmeans_iters=10
 st_iter=15000
 max_iters=30000
+max_prune_iter=20000
+lambda_reg=1e-7
 output_base=output/exp_001
 dset=tandt
 scene=train
@@ -28,4 +30,7 @@ CUDA_VISIBLE_DEVICES=$cuda_device python train_kmeans.py \
 --total_iterations "$max_iters" \
 --quant_params sh dc rot scale\
 --kmeans_freq 100 \
+--opacity_reg \
+--lambda_reg "$lambda_reg" \
+--max_prune_iter "$max_prune_iter" \
 --eval
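
As a rough sanity check on the new weight (illustrative numbers only; the Gaussian count and mean opacity below are assumptions, not values from this repository), lambda_reg=1e-7 keeps the opacity penalty small relative to the photometric loss while still nudging opacities toward zero:

# Illustrative magnitude check for lambda_reg (assumed scene statistics)
n_gaussians = 1_000_000   # hypothetical number of Gaussians after densification
mean_opacity = 0.5        # hypothetical mean opacity
lambda_reg = 1e-7
print(lambda_reg * n_gaussians * mean_opacity)  # ~0.05, added on top of the L1/SSIM loss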
29 changes: 28 additions & 1 deletion train_kmeans.py
@@ -154,7 +154,18 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
 # Loss
 gt_image = viewpoint_cam.original_image.cuda()
 Ll1 = l1_loss(image, gt_image)
-loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
+
+# Optionally, use opacity regularization - from iter 15000 to max_prune_iter
+if args.opacity_reg:
+    if iteration > args.max_prune_iter or iteration < 15000:
+        lambda_reg = 0.
+    else:
+        lambda_reg = args.lambda_reg
+    L_reg_op = gaussians.get_opacity.sum()
+    loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) + (
+        lambda_reg * L_reg_op)
+else:
+    loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
 loss.backward()
 
 iter_end.record()
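
For readability, the schedule implemented by the new branch can be restated as a standalone helper (an illustrative sketch, not code from this commit; reg_start_iter names the hard-coded 15000):

def opacity_reg_weight(iteration, lambda_reg, max_prune_iter, reg_start_iter=15000):
    """Weight on the opacity-sum penalty; zero outside [reg_start_iter, max_prune_iter]."""
    if iteration < reg_start_iter or iteration > max_prune_iter:
        return 0.0
    return lambda_reg

print(opacity_reg_weight(16000, 1e-7, 20000))  # 1e-07
print(opacity_reg_weight(25000, 1e-7, 20000))  # 0.0

With args.opacity_reg set, the total loss is the usual (1 - lambda_dssim) * L1 + lambda_dssim * (1 - SSIM) plus this weight times the sum of all Gaussian opacities.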
@@ -206,6 +217,14 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
     if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
         gaussians.reset_opacity()
 
+# Prune Gaussians every 1000 iterations from iter 15000 to max_prune_iter if using opacity regularization
+if args.opacity_reg and iteration > 15000:
+    if iteration <= args.max_prune_iter and iteration % 1000 == 0:
+        print('Num Gaussians: ', gaussians._xyz.shape[0])
+        size_threshold = None
+        gaussians.prune(0.005, scene.cameras_extent, size_threshold)
+        print('Num Gaussians after prune: ', gaussians._xyz.shape[0])
+
 # Optimizer step
 if iteration < opt.iterations:
     gaussians.optimizer.step()
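
The prune call above belongs to the repository's GaussianModel; as an illustration only, opacity-threshold pruning of this kind amounts to building a boolean mask over the opacities and dropping the flagged points (a minimal sketch with a hypothetical low_opacity_mask helper, not the repo's implementation):

import torch

def low_opacity_mask(opacities: torch.Tensor, min_opacity: float = 0.005) -> torch.Tensor:
    # True for Gaussians whose opacity falls below the pruning threshold
    return (opacities < min_opacity).squeeze(-1)

opacities = torch.tensor([[0.001], [0.2], [0.004], [0.9]])
mask = low_opacity_mask(opacities)
print(mask)                # tensor([ True, False,  True, False])
print(int((~mask).sum()))  # 2 Gaussians would survive this threshold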
@@ -361,6 +380,14 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i
                     help='threshold on xyz gradients for densification')
 parser.add_argument("--quant_params", nargs="+", type=str, default=['sh', 'dc', 'scale', 'rot'])
 
+# Opacity regularization parameters
+parser.add_argument('--max_prune_iter', type=int, default=20000,
+                    help='Iteration till which pruning is done')
+parser.add_argument('--opacity_reg', action='store_true', default=False,
+                    help='use opacity regularization during training')
+parser.add_argument('--lambda_reg', type=float, default=0.,
+                    help='Weight for opacity regularization in loss')
+
 args = parser.parse_args(sys.argv[1:])
 
 args.save_iterations.append(args.iterations)
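
A quick way to see how the new flags behave, using a standalone parser that mirrors only the three added arguments (illustration, not the full train_kmeans.py parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max_prune_iter', type=int, default=20000)
parser.add_argument('--opacity_reg', action='store_true', default=False)
parser.add_argument('--lambda_reg', type=float, default=0.)

args = parser.parse_args(['--opacity_reg', '--lambda_reg', '1e-7'])
print(args.opacity_reg, args.lambda_reg, args.max_prune_iter)  # True 1e-07 20000

Note that lambda_reg defaults to 0, so passing --opacity_reg without --lambda_reg leaves the penalty term at zero, while the pruning step every 1000 iterations still runs.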