diff --git a/qtensor/contraction_backends/torch.py b/qtensor/contraction_backends/torch.py index c90c3165..0074c1e9 100644 --- a/qtensor/contraction_backends/torch.py +++ b/qtensor/contraction_backends/torch.py @@ -228,18 +228,20 @@ def get_result_data(self, result): class TorchBackendMatm(TorchBackend): - def _get_index_sizes(self, *ixs): + def _get_index_sizes(self, *ixs, size_dict = None): + if size_dict is not None: + return [size_dict[i] for i in ixs] try: sizes = [ i.size for i in ixs ] except AttributeError: sizes = [2] * len(ixs) return sizes - def _get_index_space_size(self, *ixs): - sizes = self._get_index_sizes(*ixs) + def _get_index_space_size(self, *ixs, size_dict = None): + sizes = self._get_index_sizes(*ixs, size_dict = size_dict) return reduce(np.multiply, sizes, 1) - def pairwise_sum_contract(self, ixa, a, ixb, b, ixout): + def pairwise_sum_contract(self, ixa, a, ixb, b, ixout, size_dict = None): out = ixout common = set(ixa).intersection(set(ixb)) # -- sum indices that are in one tensor only @@ -267,16 +269,18 @@ def pairwise_sum_contract(self, ixa, a, ixb, b, ixout): kix = common - set(out) fix = common - kix common = list(kix) + list(fix) - #print(f'{ixa=} {ixb=} {ixout=}; {common=} {mix=} {nix=}') - a = tensors[0].permute(*[ - list(ixs[0]).index(x) for x in common + list(mix) - ]) - - b = tensors[1].permute(*[ - list(ixs[1]).index(x) for x in common + list(nix) - ]) - - k, f, m, n = [self._get_index_space_size(*ix) + #print(f'{ixa=} {ixb=} {ixout=}; {common=} {mix=} {nix=}, {size_dict=}') + if tensors[0].numel() > 1: + a = tensors[0].permute(*[ + list(ixs[0]).index(x) for x in common + list(mix) + ]) + + if tensors[1].numel() > 1: + b = tensors[1].permute(*[ + list(ixs[1]).index(x) for x in common + list(nix) + ]) + + k, f, m, n = [self._get_index_space_size(*ix, size_dict=size_dict) for ix in (kix, fix, mix, nix) ] a = a.reshape(k, f, m) @@ -284,12 +288,15 @@ def pairwise_sum_contract(self, ixa, a, ixb, b, ixout): c = torch.einsum('kfm, kfn -> fmn', a, b) if len(out): #print('out ix', out, 'kfmnix', kix, fix, mix, nix) - c = c.reshape(*self._get_index_sizes(*out)) + c = c.reshape(*self._get_index_sizes(*out, size_dict=size_dict)) #print('outix', out, 'res', c.shape, 'kfmn',kix, fix, mix, nix) current_ord_ = list(fix) + list(mix) + list(nix) if len(out): c = c.permute(*[current_ord_.index(i) for i in out]) + else: + c = c.flatten() + #print(f'c shape {c.shape}') return c def process_bucket(self, bucket, no_sum=False): @@ -303,17 +310,24 @@ def process_bucket(self, bucket, no_sum=False): ixr = list(map(int, result_indices)) ixt = list(map(int, tensor.indices)) - result_indices = tuple(sorted( + out_indices = tuple(sorted( set(result_indices + tensor.indices), key=int, reverse=True ) ) - ixout = list(map(int, result_indices)) - - logger.trace('Before contract. expr: {}, {} ->', ixr, ixt, ixout) - result_data_new = self.pairwise_sum_contract(ixr, result_data, ixt, tensor.data, ixout) + ixout = list(map(int, out_indices)) + + logger.trace('Before contract. expr: {}, {} -> {}', ixr, ixt, ixout) + size_dict = {} + for i in result_indices: + size_dict[int(i)] = i.size + for i in tensor.indices: + size_dict[int(i)] = i.size + logger.debug("result_indices: {}", result_indices) + result_data_new = self.pairwise_sum_contract(ixr, result_data, ixt, tensor.data, ixout, size_dict = size_dict) + result_indices = out_indices #result_data = torch.einsum(expr, result_data, tensor.data) - logger.trace("Data: {}, -> {}", result_data, tensor.data, result_data_new) + logger.trace("Data: {}, {} -> {}", result_data.shape, tensor.data.shape, result_data_new.shape) result_data = result_data_new # Merge and sort indices and shapes @@ -334,10 +348,24 @@ def process_bucket(self, bucket, no_sum=False): ))[:-1] ixout = list(map(int, result_indices)) - logger.trace('Before contract. expr: {}, {} ->', ixr, ixt, ixout) - result_data_new = self.pairwise_sum_contract(ixr, result_data, ixt, tensor.data, ixout) + logger.trace('Before contract. expr: {}, {} -> {}', ixr, ixt, ixout) + size_dict = {} + for i in result_indices: + size_dict[int(i)] = i.size + for i in tensor.indices: + size_dict[int(i)] = i.size + #logger.debug("result_indices: {}", result_indices) + result_data_new = self.pairwise_sum_contract(ixr, result_data, ixt, tensor.data, ixout, size_dict = size_dict) #result_data = torch.einsum(expr, result_data, tensor.data) - logger.trace("Data: {}, -> {}", result_data, tensor.data, result_data_new) + logger.trace("Data: {}, {} -> {}", result_data.mean(), tensor.data.mean(), result_data_new.mean()) + #if result_data_new.mean() == 0: + # logger.warning("Result is zero") + # logger.debug("result_indices: {}", result_indices) + # logger.debug("result_data: {}", result_data) + # logger.debug("tensor: {}", tensor) + # logger.debug("tensor_data: {}", tensor.data) + # logger.debug("result_data_new: {}", result_data_new) + # raise ValueError("Result is zero") result_data = result_data_new else: result_data = result_data.sum(axis=-1)