diff --git a/lpips/lpips.py b/lpips/lpips.py index 93595476..44096e39 100755 --- a/lpips/lpips.py +++ b/lpips/lpips.py @@ -143,6 +143,54 @@ def forward(self, in0, in1, retPerLayer=False, normalize=False): else: return val + def forward_distmat(self, in0, in1=None, retPerLayer=False, normalize=False, batch_size=64): + """Compute distance matrix with batch processing.""" + if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] + in0 = 2 * in0 - 1 + if in1 is not None: + in1 = 2 * in1 - 1 + + if in1 is None: # save feature computation time if only one image stack + in0_input = (self.scaling_layer(in0)) if self.version=='0.1' else (in0) + outs0 = self.net.forward(in0_input) + else: + # v0.0 - original release had a bug, where input was not scaled + in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else ( + in0, in1) + outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) + feats0, feats1, diffs = {}, {}, {} + + res = [] + for kk in range(self.L): + if in1 is None: + feats0[kk] = lpips.normalize_tensor(outs0[kk]) + feats1[kk] = feats0[kk] + else: + feats0[kk], feats1[kk] = lpips.normalize_tensor(outs0[kk]), lpips.normalize_tensor(outs1[kk]) + res.append([]) + for imi in range(feats0[kk].shape[0]): + diffs[kk] = (feats0[kk][imi:imi+1] - feats1[kk]) ** 2 + if (self.lpips): + if (self.spatial): + res[kk].append(upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:])) + else: + res[kk].append(spatial_average(self.lins[kk](diffs[kk]), keepdim=True)) + else: + if (self.spatial): + res[kk].append(upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:])) + else: + res[kk].append(spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True)) + res[kk] = torch.stack(res[kk], dim=0) + + val = 0 + for l in range(self.L): + val += res[l] + + if (retPerLayer): + return (val, res) + else: + return val + class ScalingLayer(nn.Module): def __init__(self):