Skip to content

Commit

Permalink
fix backward error & v0.0.9
Browse files Browse the repository at this point in the history
  • Loading branch information
rogerwwww committed Nov 17, 2023
1 parent 4c03d50 commit b8f9833
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions LinSATNet/linsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def init_shape(mat, vec, num_constr):
no_warning)

if vector_input:
x.squeeze_(0)
x = x.squeeze(0)
return x


Expand Down Expand Up @@ -273,13 +273,14 @@ def linsat_kernel_v2(x, A, b, tau, max_iter, dummy_val,
)

# Test with LinSAT
prev_time = time.time()
linsat_outp = linsat_layer(w, E=E, f=f, tau=0.1, max_iter=10, dummy_val=0)
print(f'LinSAT forward time: {time.time() - prev_time:.4f}')
prev_time = time.time()
loss = ((linsat_outp - x_gt) ** 2).sum()
loss.backward()
print(f'LinSAT backward time: {time.time() - prev_time:.4f}')
with torch.autograd.set_detect_anomaly(True):
prev_time = time.time()
linsat_outp = linsat_layer(w, E=E, f=f, tau=0.1, max_iter=10, dummy_val=0)
print(f'LinSAT forward time: {time.time() - prev_time:.4f}')
prev_time = time.time()
loss = ((linsat_outp - x_gt) ** 2).sum()
loss.backward()
print(f'LinSAT backward time: {time.time() - prev_time:.4f}')

# Test gradient-based optimization
niters = 10
Expand Down

0 comments on commit b8f9833

Please sign in to comment.