Merge pull request #99 from Thinklab-SJTU/fix-ipfp
Fix ipfp and other minor fixes
rogerwwww authored Apr 27, 2024
2 parents 15f3b46 + 083f79a commit fefb7c9
Showing 9 changed files with 87 additions and 55 deletions.
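
For context on the core fix: IPFP alternately discretizes the current continuous assignment v into a binary solution b (Hungarian algorithm on the linearized score Kv) and line-searches along d = b - v to maximize the quadratic objective f(x) = x^T K x. A sketch of the step-size derivation, assuming a symmetric affinity matrix K as in the standard IPFP setting, with the same alpha and beta as in the code below:

    f(v + t d) = f(v) + 2t\alpha + t^2\beta, \qquad \alpha = v^\top K d, \quad \beta = d^\top K d

The stationary point t* = -\alpha / \beta is a maximizer only if the parabola is concave, i.e. \beta < 0. Since the Hungarian step maximizes the linearization, \alpha >= 0, so whenever \beta >= 0 or t* >= 1 the best step in [0, 1] is the endpoint t = 1, i.e. jumping to b. The old code computed t0 = \alpha / \beta and tested \beta <= 0, flipping both the sign of the step and the concavity test. The diffs below apply the corrected t0 = -\alpha / \beta with the \beta >= 0 condition in every backend, return the best binary solution found across iterations (best_v / best_obj) rather than the one from the final iteration, and simplify the stopping test to compare consecutive objective values directly.
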
4 changes: 2 additions & 2 deletions .github/workflows/python-package.yml
@@ -54,7 +54,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: [ "3.8" ]
+        python-version: [ "3.9" ]
 
     steps:
     - uses: actions/checkout@v2
@@ -87,7 +87,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: [ "3.8" ]
+        python-version: [ "3.9" ]
 
     steps:
     - uses: actions/checkout@v2
22 changes: 15 additions & 7 deletions pygmtools/jittor_backend.py
@@ -283,6 +283,8 @@ def ipfp(K: Var, n1: Var, n2: Var, n1max, n2max, x0: Var,
     batch_num, n1, n2, n1max, n2max, n1n2, v0 = _check_and_init_gm(K, n1, n2, n1max, n2max, x0)
     v = v0
     last_v = v
+    best_v = v
+    best_obj = jt.full((batch_num, 1, 1), -1)
 
     def comp_obj_score(v1, K, v2):
         return jt.bmm(jt.bmm(v1.view(batch_num, 1, -1), K), v2)
@@ -293,19 +295,25 @@ def comp_obj_score(v1, K, v2):
         binary_v = binary_sol.transpose(1, 2).view(batch_num, -1, 1)
         alpha = comp_obj_score(v, K, binary_v - v)
         beta = comp_obj_score(binary_v - v, K, binary_v - v)
-        t0 = alpha / beta
-        cond = jt.logical_or(beta <= 0, t0 >= 1)
+        t0 = - alpha / beta
+        cond = jt.logical_or(beta >= 0, t0 >= 1)
         if cond.shape != binary_v.shape:
             cond = cond.expand(binary_v.shape)
         v = jt.where(cond, binary_v, v + t0 * (binary_v - v))
-        last_v_sol = comp_obj_score(last_v, K, last_v)
-        if jt.max(jt.abs(
-            last_v_sol - jt.bmm(cost.reshape(batch_num, 1, -1), binary_sol.reshape(batch_num, -1, 1))
-        ) / last_v_sol) < 1e-3:
+        last_v_obj = comp_obj_score(last_v, K, last_v)
+
+        current_obj = comp_obj_score(binary_v, K, binary_v)
+        cond = current_obj > best_obj
+        if cond.shape != binary_v.shape:
+            cond = cond.expand(binary_v.shape)
+        best_v = jt.where(cond, binary_v, best_v)  # current_obj > best_obj
+        best_obj = jt.where(current_obj > best_obj, current_obj, best_obj)
+
+        if jt.max(jt.abs(last_v_obj - current_obj) / last_v_obj) < 1e-3:
             break
         last_v = v
 
-    pred_x = binary_sol
+    pred_x = best_v.reshape((batch_num, int(n2max), int(n1max))).transpose(1, 2)
     return pred_x
22 changes: 13 additions & 9 deletions pygmtools/mindspore_backend.py
@@ -280,6 +280,8 @@ def ipfp(K: mindspore.Tensor, n1: mindspore.Tensor, n2: mindspore.Tensor, n1max,
     batch_num, n1, n2, n1max, n2max, n1n2, v0 = _check_and_init_gm(K, n1, n2, n1max, n2max, x0)
     v = v0
     last_v = v
+    best_v = v
+    best_obj = -1
 
     def comp_obj_score(v1, K, v2):
         return mindspore.ops.BatchMatMul()(mindspore.ops.BatchMatMul()(v1.view(batch_num, 1, -1), K), v2)
@@ -288,19 +290,21 @@ def comp_obj_score(v1, K, v2):
         cost = mindspore.ops.BatchMatMul()(K, v).reshape(batch_num, int(n2max), int(n1max)).swapaxes(1, 2)
         binary_sol = hungarian(cost, n1, n2)
         binary_v = binary_sol.swapaxes(1, 2).view(batch_num, -1, 1)
-        alpha = comp_obj_score(v, K, binary_v - v)  # + torch.mm(k_diag.view(1, -1), (binary_sol - v).view(-1, 1))
+        alpha = comp_obj_score(v, K, binary_v - v)
         beta = comp_obj_score(binary_v - v, K, binary_v - v)
-        t0 = alpha / beta
-        v = mindspore.numpy.where(mindspore.ops.logical_or(beta <= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
-        last_v_sol = comp_obj_score(last_v, K, last_v)
-        if (mindspore.ops.max(mindspore.ops.abs(
-            last_v_sol - mindspore.ops.BatchMatMul()(cost.reshape((batch_num, 1, -1)),
-                                                     binary_sol.reshape((batch_num, -1, 1)))
-        ) / last_v_sol)[1] < 1e-3).any():
+        t0 = - alpha / beta
+        v = mindspore.numpy.where(mindspore.ops.logical_or(beta >= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
+        last_v_obj = comp_obj_score(last_v, K, last_v)
+
+        current_obj = comp_obj_score(binary_v, K, binary_v)
+        best_v = mindspore.numpy.where(current_obj > best_obj, binary_v, best_v)
+        best_obj = mindspore.numpy.where(current_obj > best_obj, current_obj, best_obj)
+
+        if (mindspore.ops.max(mindspore.ops.abs(last_v_obj - current_obj) / last_v_obj)[1] < 1e-3).any():
             break
         last_v = v
 
-    pred_x = binary_sol
+    pred_x = best_v.reshape(batch_num, int(n2max), int(n1max)).swapaxes(1, 2)
     return pred_x
19 changes: 12 additions & 7 deletions pygmtools/numpy_backend.py
@@ -293,6 +293,8 @@ def ipfp(K: np.ndarray, n1: np.ndarray, n2: np.ndarray, n1max, n2max, x0: np.nda
     batch_num, n1, n2, n1max, n2max, n1n2, v0 = _check_and_init_gm(K, n1, n2, n1max, n2max, x0)
     v = v0
     last_v = v
+    best_v = v
+    best_obj = -1
 
     def comp_obj_score(v1, K, v2):
         return np.matmul(np.matmul(v1.reshape((batch_num, 1, -1)), K), v2)
@@ -303,16 +305,19 @@ def comp_obj_score(v1, K, v2):
         binary_v = binary_sol.transpose((0, 2, 1)).reshape((batch_num, -1, 1))
         alpha = comp_obj_score(v, K, binary_v - v)
         beta = comp_obj_score(binary_v - v, K, binary_v - v)
-        t0 = alpha / beta
-        v = np.where(np.logical_or(beta <= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
-        last_v_sol = comp_obj_score(last_v, K, last_v)
-        if np.max(np.abs(
-            last_v_sol - np.matmul(cost.reshape((batch_num, 1, -1)), binary_sol.reshape((batch_num, -1, 1)))
-        ) / last_v_sol) < 1e-3:
+        t0 = - alpha / beta
+        v = np.where(np.logical_or(beta >= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
+        last_v_obj = comp_obj_score(last_v, K, last_v)
+
+        current_obj = comp_obj_score(binary_v, K, binary_v)
+        best_v = np.where(current_obj > best_obj, binary_v, best_v)
+        best_obj = np.where(current_obj > best_obj, current_obj, best_obj)
+
+        if np.max(np.abs(last_v_obj - current_obj) / last_v_obj) < 1e-3:
            break
        last_v = v
 
-    pred_x = binary_sol
+    pred_x = best_v.reshape((batch_num, n2max, n1max)).transpose((0, 2, 1))
     return pred_x
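
To see the corrected update in isolation, here is a minimal, self-contained, single-instance sketch mirroring the numpy_backend logic above; it is an illustration, not the library code. The helper hungarian_1to1 is a hypothetical stand-in for pygmtools' hungarian(), built on scipy.optimize.linear_sum_assignment, and K follows the same vectorization layout as the backend.

import numpy as np
from scipy.optimize import linear_sum_assignment

def hungarian_1to1(score):
    """Project a score matrix onto a 0/1 permutation matrix (maximization)."""
    rows, cols = linear_sum_assignment(-score)  # negate to maximize
    x = np.zeros_like(score)
    x[rows, cols] = 1.0
    return x

def ipfp_single(K, n1, n2, max_iter=50):
    """Single-sample IPFP with the corrected line search and best-solution tracking."""
    v = np.full((n1 * n2, 1), 1.0 / (n1 * n2))    # uniform initialization
    last_v, best_v, best_obj = v, v, -1.0
    obj = lambda a, b: (a.reshape(1, -1) @ K @ b).item()
    for _ in range(max_iter):
        cost = (K @ v).reshape(n2, n1).T          # scores of the linearization
        b = hungarian_1to1(cost).T.reshape(-1, 1) # binary solution, vectorized
        alpha = obj(v, b - v)                     # v^T K (b - v)
        beta = obj(b - v, b - v)                  # (b - v)^T K (b - v)
        t0 = -alpha / beta if beta != 0 else 1.0  # corrected sign of the step
        v = b if (beta >= 0 or t0 >= 1) else v + t0 * (b - v)
        cur_obj = obj(b, b)
        if cur_obj > best_obj:                    # keep the best binary solution
            best_v, best_obj = b, cur_obj
        last_obj = obj(last_v, last_v)
        if last_obj > 0 and abs(last_obj - cur_obj) / last_obj < 1e-3:
            break
        last_v = v
    return best_v.reshape(n2, n1).T               # back to an (n1, n2) matching
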
19 changes: 12 additions & 7 deletions pygmtools/paddle_backend.py
@@ -275,6 +275,8 @@ def ipfp(K: paddle.Tensor, n1: paddle.Tensor, n2: paddle.Tensor, n1max, n2max, x
     batch_num, n1, n2, n1max, n2max, n1n2, v0 = _check_and_init_gm(K, n1, n2, n1max, n2max, x0)
     v = v0
     last_v = v
+    best_v = v
+    best_obj = paddle.to_tensor(paddle.full((batch_num, 1, 1), -1.), place=K.place)
 
     def comp_obj_score(v1, K, v2):
         return paddle.bmm(paddle.bmm(paddle.reshape(v1, (batch_num, 1, -1)), K), v2)
@@ -285,16 +287,19 @@ def comp_obj_score(v1, K, v2):
         binary_v = paddle.reshape(binary_sol.transpose((0, 2, 1)), (batch_num, -1, 1))
         alpha = comp_obj_score(v, K, binary_v - v)
         beta = comp_obj_score(binary_v - v, K, binary_v - v)
-        t0 = alpha / beta
-        v = paddle.where(paddle.logical_or(beta <= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
-        last_v_sol = comp_obj_score(last_v, K, last_v)
-        if paddle.max(paddle.abs(
-            last_v_sol - paddle.bmm(paddle.reshape(cost, (batch_num, 1, -1)), paddle.reshape(binary_sol, (batch_num, -1, 1)))
-        ) / last_v_sol) < 1e-3:
+        t0 = - alpha / beta
+        v = paddle.where(paddle.logical_or(beta >= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
+        last_v_obj = comp_obj_score(last_v, K, last_v)
+
+        current_obj = comp_obj_score(binary_v, K, binary_v)
+        best_v = paddle.where(current_obj > best_obj, binary_v, best_v)
+        best_obj = paddle.where(current_obj > best_obj, current_obj, best_obj)
+
+        if paddle.max(paddle.abs(last_v_obj - current_obj) / last_v_obj) < 1e-3:
             break
         last_v = v
 
-    pred_x = binary_sol
+    pred_x = paddle.reshape(best_v, (batch_num, n2max, n1max)).transpose((0, 2, 1))
     return pred_x
21 changes: 13 additions & 8 deletions pygmtools/pytorch_backend.py
@@ -282,6 +282,8 @@ def ipfp(K: Tensor, n1: Tensor, n2: Tensor, n1max, n2max, x0: Tensor,
     batch_num, n1, n2, n1max, n2max, n1n2, v0 = _check_and_init_gm(K, n1, n2, n1max, n2max, x0)
     v = v0
     last_v = v
+    best_v = v
+    best_obj = -1
 
     def comp_obj_score(v1, K, v2):
         return torch.bmm(torch.bmm(v1.view(batch_num, 1, -1), K), v2)
@@ -290,18 +292,21 @@ def comp_obj_score(v1, K, v2):
         cost = torch.bmm(K, v).reshape(batch_num, n2max, n1max).transpose(1, 2)
         binary_sol = hungarian(cost, n1, n2)
         binary_v = binary_sol.transpose(1, 2).view(batch_num, -1, 1)
-        alpha = comp_obj_score(v, K, binary_v - v)  # + torch.mm(k_diag.view(1, -1), (binary_sol - v).view(-1, 1))
+        alpha = comp_obj_score(v, K, binary_v - v)
        beta = comp_obj_score(binary_v - v, K, binary_v - v)
-        t0 = alpha / beta
-        v = torch.where(torch.logical_or(beta <= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
-        last_v_sol = comp_obj_score(last_v, K, last_v)
-        if torch.max(torch.abs(
-            last_v_sol - torch.bmm(cost.reshape(batch_num, 1, -1), binary_sol.reshape(batch_num, -1, 1))
-        ) / last_v_sol) < 1e-3:
+        t0 = - alpha / beta
+        v = torch.where(torch.logical_or(beta >= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
+        last_v_obj = comp_obj_score(last_v, K, last_v)
+
+        current_obj = comp_obj_score(binary_v, K, binary_v)
+        best_v = torch.where(current_obj > best_obj, binary_v, best_v)
+        best_obj = torch.where(current_obj > best_obj, current_obj, best_obj)
+
+        if torch.max(torch.abs(last_v_obj - current_obj) / last_v_obj) < 1e-3:
             break
         last_v = v
 
-    pred_x = binary_sol
+    pred_x = best_v.reshape(batch_num, n2max, n1max).transpose(1, 2)
     return pred_x
25 changes: 15 additions & 10 deletions pygmtools/tensorflow_backend.py
@@ -231,7 +231,7 @@ def sinkhorn(s: tf.Tensor, nrows: tf.Tensor=None, ncols: tf.Tensor=None,
 def rrwm(K: tf.Tensor, n1: tf.Tensor, n2: tf.Tensor, n1max, n2max, x0: tf.Tensor,
          max_iter: int, sk_iter: int, alpha: float, beta: float) -> tf.Tensor:
     """
-    Pytorch implementation of RRWM algorithm.
+    Tensorflow implementation of RRWM algorithm.
     """
     batch_num, n1, n2, n1max, n2max, n1n2, v0 = _check_and_init_gm(K, n1, n2, n1max, n2max, x0)
     # rescale the values in K
@@ -283,11 +283,13 @@ def sm(K: tf.Tensor, n1: tf.Tensor, n2: tf.Tensor, n1max, n2max, x0: tf.Tensor,
 def ipfp(K: tf.Tensor, n1: tf.Tensor, n2: tf.Tensor, n1max, n2max, x0: tf.Tensor,
          max_iter) -> tf.Tensor:
     """
-    Pytorch implementation of IPFP algorithm
+    Tensorflow implementation of IPFP algorithm
     """
     batch_num, n1, n2, n1max, n2max, n1n2, v0 = _check_and_init_gm(K, n1, n2, n1max, n2max, x0)
     v = v0
     last_v = v
+    best_v = v
+    best_obj = -1
 
     def comp_obj_score(v1, K, v2):
         return tf.matmul(tf.matmul(tf.reshape(v1, [batch_num, 1, -1]), K), v2)
@@ -296,18 +298,21 @@ def comp_obj_score(v1, K, v2):
         cost = tf.transpose(tf.reshape(tf.matmul(K, v), [batch_num, n2max, n1max]), [0, 2, 1])
         binary_sol = hungarian(cost, n1, n2)
         binary_v = tf.reshape(tf.transpose(binary_sol, [0, 2, 1]), [batch_num, -1, 1])
-        alpha = comp_obj_score(v, K, binary_v - v)  # + torch.mm(k_diag.view(1, -1), (binary_sol - v).view(-1, 1))
+        alpha = comp_obj_score(v, K, binary_v - v)
         beta = comp_obj_score(binary_v - v, K, binary_v - v)
-        t0 = alpha / beta
-        v = tf.where(tf.math.logical_or(beta <= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
-        last_v_sol = comp_obj_score(last_v, K, last_v)
-        if tf.reduce_max(tf.abs(
-            last_v_sol - tf.matmul(tf.reshape(cost, [batch_num, 1, -1]), tf.reshape(binary_sol, [batch_num, -1, 1]))
-        ) / last_v_sol) < 1e-3:
+        t0 = - alpha / beta
+        v = tf.where(tf.math.logical_or(beta >= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
+        last_v_obj = comp_obj_score(last_v, K, last_v)
+
+        current_obj = comp_obj_score(binary_v, K, binary_v)
+        best_v = tf.where(current_obj > best_obj, binary_v, best_v)
+        best_obj = tf.where(current_obj > best_obj, current_obj, best_obj)
+
+        if tf.reduce_max(tf.abs(last_v_obj - current_obj) / last_v_obj) < 1e-3:
             break
         last_v = v
 
-    pred_x = binary_sol
+    pred_x = tf.transpose(tf.reshape(best_v, [batch_num, n2max, n1max]), [0, 2, 1])
     return pred_x
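
With the per-backend fixes in place, users reach the solver through the top-level API. The sketch below follows the patterns in pygmtools' documentation (set_backend, dense_to_sparse, build_aff_mat, ipfp) rather than code in this diff, so treat the exact signatures as indicative; build_aff_mat's argument order matches the test file further down.

import numpy as np
import pygmtools as pygm

pygm.set_backend('numpy')  # the fix applies to every backend patched above
np.random.seed(0)

# A pair of isomorphic graphs related by a ground-truth permutation (batch of 1)
batch, n = 1, 5
X_gt = np.zeros((batch, n, n))
X_gt[:, np.arange(n), np.random.permutation(n)] = 1
A1 = np.random.rand(batch, n, n)
A1 = (A1 + A1.transpose(0, 2, 1)) / 2
A1[:, np.arange(n), np.arange(n)] = 0        # adjacency: zero diagonal
A2 = X_gt.transpose(0, 2, 1) @ A1 @ X_gt
n1 = n2 = np.array([n])

conn1, edge1, ne1 = pygm.utils.dense_to_sparse(A1)
conn2, edge2, ne2 = pygm.utils.dense_to_sparse(A2)
K = pygm.utils.build_aff_mat(None, edge1, conn1, None, edge2, conn2, n1, None, n2, None)

X = pygm.ipfp(K, n1, n2)                     # binary matching, shape (batch, n, n)
print('exact recovery:', bool((X * X_gt).sum() == n))
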
8 changes: 4 additions & 4 deletions pygmtools/utils.py
@@ -36,10 +36,10 @@
 import pygmtools
 
 NOT_IMPLEMENTED_MSG = \
-    'The backend function for {} is not implemented.\n' \
-    'If you are a user, this error message means the function you are calling is not available with this backend and ' \
-    'please use other backends as workarounds. Scroll up in the call stack and it will tell you which function is ' \
-    'causing this error.\n' \
+    'Import failed! It is likely that the backend function for {} is not implemented, OR the backend is not installed ' \
+    'correctly. Please Scroll up in the call stack and it will tell you who is causing this error.\n' \
+    'If you are a user, this error message usually means the function you are calling is not available with this backend ' \
+    'and please use other backends as workarounds. \n' \
     'If you are a developer, it will be truly appreciated if you could develop and share your ' \
     'implementation with the community! RP is welcomed via Github: https://github.com/Thinklab-SJTU/pygmtools'
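
The message keeps a single {} placeholder for the offending function name; a hypothetical call site (the real raise lives in pygmtools' backend-dispatch code) would look like:

from pygmtools.utils import NOT_IMPLEMENTED_MSG

# Raised when a solver has no implementation for the selected backend
raise NotImplementedError(NOT_IMPLEMENTED_MSG.format('ipfp'))
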
2 changes: 1 addition & 1 deletion tests/test_classic_solvers.py
@@ -95,7 +95,7 @@ def _test_classic_solver_on_isomorphic_graphs(graph_num_nodes, node_feat_dim, so
             _K = pygm.utils.build_aff_mat(_F1, _edge1, _conn1, _F2, _edge2, _conn2, _n1, None, _n2, None,
                                           **aff_param_dict)
             if last_K is not None:
-                assert np.abs(pygm.utils.to_numpy(_K) - last_K).sum() < 0.1, \
+                assert np.abs(pygm.utils.to_numpy(_K) - last_K).max() < 0.01, \
                     f"Incorrect affinity matrix for {working_backend}, " \
                     f"params: {';'.join([k + '=' + str(v) for k, v in aff_param_dict.items()])};" \
                     f"{';'.join([k + '=' + str(v) for k, v in solver_param_dict.items()])}"
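
The test change above swaps an L1 (sum) tolerance for a per-entry L-infinity (max) one: a sum bound grows with matrix size and can hide a single wrong entry in a large K. A small illustration (illustrative numbers only):

import numpy as np

K_ref = np.zeros((100, 100))
K_bad = K_ref.copy()
K_bad[0, 0] = 0.05                          # a single clearly wrong entry

print(np.abs(K_bad - K_ref).sum() < 0.1)    # True: the old sum-based check passes
print(np.abs(K_bad - K_ref).max() < 0.01)   # False: the new max-based check fails
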
