From b8f9833c352b8b6c5f9edc7adb237f3d2901de3b Mon Sep 17 00:00:00 2001 From: roger <18309862+rogerwwww@users.noreply.github.com> Date: Fri, 17 Nov 2023 09:37:29 -0500 Subject: [PATCH] fix backward error & v0.0.9 --- LinSATNet/linsat.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/LinSATNet/linsat.py b/LinSATNet/linsat.py index 09edb68..4693051 100644 --- a/LinSATNet/linsat.py +++ b/LinSATNet/linsat.py @@ -100,7 +100,7 @@ def init_shape(mat, vec, num_constr): no_warning) if vector_input: - x.squeeze_(0) + x = x.squeeze(0) return x @@ -273,13 +273,14 @@ def linsat_kernel_v2(x, A, b, tau, max_iter, dummy_val, ) # Test with LinSAT - prev_time = time.time() - linsat_outp = linsat_layer(w, E=E, f=f, tau=0.1, max_iter=10, dummy_val=0) - print(f'LinSAT forward time: {time.time() - prev_time:.4f}') - prev_time = time.time() - loss = ((linsat_outp - x_gt) ** 2).sum() - loss.backward() - print(f'LinSAT backward time: {time.time() - prev_time:.4f}') + with torch.autograd.set_detect_anomaly(True): + prev_time = time.time() + linsat_outp = linsat_layer(w, E=E, f=f, tau=0.1, max_iter=10, dummy_val=0) + print(f'LinSAT forward time: {time.time() - prev_time:.4f}') + prev_time = time.time() + loss = ((linsat_outp - x_gt) ** 2).sum() + loss.backward() + print(f'LinSAT backward time: {time.time() - prev_time:.4f}') # Test gradient-based optimization niters = 10