From 9e7d938b31be8daa76f9c349a0b872b4836f2edd Mon Sep 17 00:00:00 2001
From: "binxu.wang"
Date: Fri, 22 Apr 2022 21:33:08 -0400
Subject: [PATCH] Add another forward function to compute a distance matrix
 with one or two stacks of images.

Rationale: it reuses the features computed for each image, which greatly
accelerates computing a large distance matrix among a set of images.

Memory usage: the matrix is computed row by row, so the memory footprint
stays linear in the number of images instead of quadratic.

TODO: this could be extended to compute a few rows per batch; the batch size
needs to be chosen w.r.t. available memory.
---
 lpips/lpips.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 48 insertions(+)

diff --git a/lpips/lpips.py b/lpips/lpips.py
index 93595476..44096e39 100755
--- a/lpips/lpips.py
+++ b/lpips/lpips.py
@@ -143,6 +143,54 @@ def forward(self, in0, in1, retPerLayer=False, normalize=False):
         else:
             return val
 
+    def forward_distmat(self, in0, in1=None, retPerLayer=False, normalize=False, batch_size=64):
+        """Compute a distance matrix between one or two stacks of images, row by row."""
+        if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
+            in0 = 2 * in0 - 1
+            if in1 is not None:
+                in1 = 2 * in1 - 1
+
+        if in1 is None:  # save feature computation time if only one image stack
+            in0_input = (self.scaling_layer(in0)) if self.version=='0.1' else (in0)
+            outs0 = self.net.forward(in0_input)
+        else:
+            # v0.0 - original release had a bug, where input was not scaled
+            in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else (
+                in0, in1)
+            outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
+        feats0, feats1, diffs = {}, {}, {}
+
+        res = []
+        for kk in range(self.L):
+            if in1 is None:
+                feats0[kk] = lpips.normalize_tensor(outs0[kk])
+                feats1[kk] = feats0[kk]
+            else:
+                feats0[kk], feats1[kk] = lpips.normalize_tensor(outs0[kk]), lpips.normalize_tensor(outs1[kk])
+            res.append([])
+            for imi in range(feats0[kk].shape[0]):
+                diffs[kk] = (feats0[kk][imi:imi+1] - feats1[kk]) ** 2
+                if (self.lpips):
+                    if (self.spatial):
+                        res[kk].append(upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]))
+                    else:
+                        res[kk].append(spatial_average(self.lins[kk](diffs[kk]), keepdim=True))
+                else:
+                    if (self.spatial):
+                        res[kk].append(upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:]))
+                    else:
+                        res[kk].append(spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True))
+            res[kk] = torch.stack(res[kk], dim=0)
+
+        val = 0
+        for l in range(self.L):
+            val += res[l]
+
+        if (retPerLayer):
+            return (val, res)
+        else:
+            return val
+
 
 class ScalingLayer(nn.Module):
     def __init__(self):
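
For reference, a minimal usage sketch of forward_distmat (not part of the diff
above; the image tensor, its shape, and the net='alex' backbone choice are
illustrative assumptions):

    import torch
    import lpips

    # Build a standard LPIPS metric; 'alex' is one of the supported backbones.
    loss_fn = lpips.LPIPS(net='alex')

    # Illustrative stack of 16 images in [0, 1]; any (N, 3, H, W) tensor works.
    imgs = torch.rand(16, 3, 64, 64)

    with torch.no_grad():
        # in1=None reuses the features of imgs for both rows and columns, so the
        # backbone features are computed only once for the whole stack.
        # normalize=True rescales [0, 1] inputs to the [-1, +1] range LPIPS expects.
        dists = loss_fn.forward_distmat(imgs, normalize=True)

    # In the default non-spatial mode the result has shape (N, N, 1, 1, 1);
    # squeeze it down to an N x N pairwise distance matrix.
    distmat = dists.squeeze()

Passing a second stack as in1 instead gives a rectangular N0 x N1 matrix of
distances between the two stacks.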