diff --git a/qtensor/contraction_backends/tests/test_torch.py b/qtensor/contraction_backends/tests/test_torch.py
index 32a4757d..36a1eacc 100644
--- a/qtensor/contraction_backends/tests/test_torch.py
+++ b/qtensor/contraction_backends/tests/test_torch.py
@@ -1,7 +1,8 @@
 import qtensor
 import pytest
 import numpy as np
-from qtensor.contraction_backends import TorchBackend, NumpyBackend
+from qtensor.contraction_backends import NumpyBackend
+from qtensor.contraction_backends.torch import TorchBackend, TorchBackendMatm, permute_flattened
 from qtensor import QtreeSimulator
 from qtensor.tests import get_test_qaoa_ansatz_circ
 torch = pytest.importorskip('torch')
@@ -61,6 +62,22 @@ def contract_tn(backend, search_len=1, test_problem_kwargs={}):
     assert restr.shape == resnp.shape
     assert np.allclose(restr, resnp)

+# -- Testing low-level functions for torch matm backend
+
+def test_torch_matm_permute():
+    K = 5
+    d = 2
+    shape = [5] + [d]*(K-1)
+    x = torch.randn(shape)
+    for i in range(20):
+        perm = list(np.random.permutation(K))
+        y = permute_flattened(x.flatten(), perm, shape)
+        assert y.ndim == 1
+        assert y.numel() == x.numel()
+        print('perm', perm)
+        assert torch.allclose(y, x.permute(perm).flatten())
+
+# -- Testing get_sliced_buckets

 def test_torch_get_sliced__slice():
     backend = TorchBackend()
diff --git a/qtensor/contraction_backends/torch.py b/qtensor/contraction_backends/torch.py
index 0074c1e9..0f314834 100644
--- a/qtensor/contraction_backends/torch.py
+++ b/qtensor/contraction_backends/torch.py
@@ -7,10 +7,12 @@
 from .common import get_slice_bounds, get_einsum_expr, slice_numpy_tensor
 import string
 from loguru import logger
+

 CHARS = string.ascii_lowercase + string.ascii_uppercase
+
 def qtree2torch_tensor(tensor, data_dict):
-    """ Converts qtree tensor to pytorch tensor using data dict"""
+    """Converts qtree tensor to pytorch tensor using data dict"""
     if isinstance(tensor.data, torch.Tensor):
         return tensor
     if tensor.data is not None:
@@ -21,25 +23,25 @@ def qtree2torch_tensor(tensor, data_dict):
     data_dict[tensor.data_key] = torch_t
     return tensor.copy(data=torch_t)

+
 def get_einsum_expr_bucket(bucket, all_indices_list, result_indices):
-    # converting elements to int will make stuff faster, 
+    # converting elements to int will make stuff faster,
     # but will drop support for char indices
     # all_indices_list = [int(x) for x in all_indices]
     # to_small_int = lambda x: all_indices_list.index(int(x))
     to_small_int = lambda x: all_indices_list.index(x)
-    expr = ','.join(
-        ''.join(CHARS[to_small_int(i)] for i in t.indices)
-        for t in bucket) +\
-    '->'+''.join(CHARS[to_small_int(i)] for i in result_indices)
+    expr = (
+        ",".join("".join(CHARS[to_small_int(i)] for i in t.indices) for t in bucket)
+        + "->"
+        + "".join(CHARS[to_small_int(i)] for i in result_indices)
+    )
     return expr
-
-
-def permute_torch_tensor_data(data:np.ndarray, indices_in, indices_out):
+def permute_torch_tensor_data(data: np.ndarray, indices_in, indices_out):
     """
     Permute the data of a numpy tensor to the given indices_out.
-
+
     Returns:
         permuted data
     """
@@ -49,7 +51,8 @@ def permute_torch_tensor_data(data:np.ndarray, indices_in, indices_out):
     # permute tensor
     return torch.permute(data, perm)

-def slice_torch_tensor(data:np.ndarray, indices_in, indices_out, slice_dict):
+
+def slice_torch_tensor(data: np.ndarray, indices_in, indices_out, slice_dict):
     """
     Args:
         data : np.ndarray
@@ -65,9 +68,11 @@ def slice_torch_tensor(data:np.ndarray, indices_in, indices_out, slice_dict):
     indices_sliced = [
         i for sl, i in zip(slice_bounds, indices_in) if not isinstance(sl, int)
     ]
-    #print(f'{indices_in=}, {indices_sliced=} {slice_dict=}, {slice_bounds=}, slicedix {indices_sliced}, sshape {s_data.shape}')
+    # print(f'{indices_in=}, {indices_sliced=} {slice_dict=}, {slice_bounds=}, slicedix {indices_sliced}, sshape {s_data.shape}')
     indices_sized = [v.copy(size=size) for v, size in zip(indices_sliced, s_data.shape)]
-    indices_out = [v for v in indices_out if not isinstance(slice_dict.get(v, None), int)]
+    indices_out = [
+        v for v in indices_out if not isinstance(slice_dict.get(v, None), int)
+    ]
     assert len(indices_sized) == len(s_data.shape)
     assert len(indices_sliced) == len(s_data.shape)
     st_data = permute_torch_tensor_data(s_data, indices_sliced, indices_out)
@@ -75,49 +80,48 @@ class TorchBackend(ContractionBackend):
-
-    def __init__(self, device='cpu'):
+    def __init__(self, device="cpu"):
         # alias of gpu -> cuda
-        if device=='gpu':
-            device='cuda'
+        if device == "gpu":
+            device = "cuda"
         # Check that CUDA is available if specified
-        if device=='cuda':
+        if device == "cuda":
             if not torch.cuda.is_available():
                 logger.warning("Cuda is not available. Falling back to CPU")
-                device = 'cpu'
-        if device=='xpu':
+                device = "cpu"
+        if device == "xpu":
             import intel_extension_for_pytorch as ipex
-
         self.device = torch.device(device)
         logger.debug("Torch backend using device {}", self.device)
-        self.dtype = ['float', 'double', 'complex64', 'complex128']
+        self.dtype = ["float", "double", "complex64", "complex128"]
         self.width_dict = [set() for i in range(30)]
-        self.width_bc = [[0,0] for i in range(30)] #(#distinct_bc, #bc)
+        self.width_bc = [[0, 0] for i in range(30)]  # (#distinct_bc, #bc)

     def process_bucket(self, bucket, no_sum=False):
-        bucket.sort(key = lambda x: len(x.indices))
+        bucket.sort(key=lambda x: len(x.indices))
         result_indices = bucket[0].indices
         result_data = bucket[0].data
         width = len(set(bucket[0].indices))

         for tensor in bucket[1:-1]:
-
             expr = get_einsum_expr(
                 list(map(int, result_indices)), list(map(int, tensor.indices))
             )
-            logger.trace('Before contract. Expr: {}, inputs: {}, {}', expr, result_data, tensor)
+            logger.trace(
+                "Before contract. Expr: {}, inputs: {}, {}", expr, result_data, tensor
+            )
             result_data = torch.einsum(expr, result_data, tensor.data)
-            logger.trace("expression {}. Data: {}, -> {}", expr, tensor.data, result_data)
+            logger.trace(
+                "expression {}. Data: {}, -> {}", expr, tensor.data, result_data
+            )
             # Merge and sort indices and shapes
-            result_indices = tuple(sorted(
-                set(result_indices + tensor.indices),
-                key=int, reverse=True
-                )
+            result_indices = tuple(
+                sorted(set(result_indices + tensor.indices), key=int, reverse=True)
             )
-
+
             size = len(set(tensor.indices))
             if size > width:
                 width = size
@@ -126,38 +130,40 @@ def process_bucket(self, bucket, no_sum=False):
             self.width_bc[width][0] = len(self.width_dict[width])
             self.width_bc[width][1] += 1

-        if len(bucket)>1:
+        if len(bucket) > 1:
             tensor = bucket[-1]
             expr = get_einsum_expr(
-                list(map(int, result_indices)), list(map(int, tensor.indices))
-                , contract = 0 if no_sum else 1
+                list(map(int, result_indices)),
+                list(map(int, tensor.indices)),
+                contract=0 if no_sum else 1,
+            )
+            logger.trace(
+                "Before contract. Expr: {}, inputs: {}, {}", expr, result_data, tensor
             )
-            logger.trace('Before contract. Expr: {}, inputs: {}, {}', expr, result_data, tensor)
             result_data = torch.einsum(expr, result_data, tensor.data)
-            logger.trace("expression {}. Data: {}, -> {}", expr, tensor.data, result_data)
-            result_indices = tuple(sorted(
-                set(result_indices + tensor.indices),
-                key=int, reverse=True
-            ))
+            logger.trace(
+                "expression {}. Data: {}, -> {}", expr, tensor.data, result_data
+            )
+            result_indices = tuple(
+                sorted(set(result_indices + tensor.indices), key=int, reverse=True)
+            )
         else:
             if not no_sum:
                 result_data = result_data.sum(axis=-1)
             else:
                 result_data = result_data
-
         if len(result_indices) > 0:
             first_index = result_indices[-1]
             if not no_sum:
                 result_indices = result_indices[:-1]
             tag = first_index.identity
         else:
-            tag = 'f'
+            tag = "f"
             result_indices = []
         # reduce
-        result = qtree.optimizer.Tensor(f'E{tag}', result_indices,
-                        data=result_data)
+        result = qtree.optimizer.Tensor(f"E{tag}", result_indices, data=result_data)
         return result

     def process_bucket_merged(self, ixs, bucket, no_sum=False):
@@ -177,11 +183,11 @@ def process_bucket_merged(self, ixs, bucket, no_sum=False):
             tensors.append(tensor.data)
             if tensor.data.dtype == torch.complex128:
                 is128 = True
-
+
         if is128:
             for i in range(len(tensors)):
                 tensors[i] = tensors[i].type(torch.complex128)
-
+
         expr = get_einsum_expr_bucket(bucket, all_indices_list, result_indices)
         expect = len(result_indices)
         result_data = torch.einsum(expr, *tensors)
@@ -190,11 +196,10 @@ def process_bucket_merged(self, ixs, bucket, no_sum=False):
             first_index, *_ = result_indices
             tag = str(first_index)
         else:
-            tag = 'f'
+            tag = "f"
+
+        result = qtree.optimizer.Tensor(f"E{tag}", result_indices, data=result_data)

-        result = qtree.optimizer.Tensor(f'E{tag}', result_indices,
-                        data=result_data)
-
         return result

     def get_sliced_buckets(self, buckets, data_dict, slice_dict):
@@ -210,15 +215,16 @@ def get_sliced_buckets(self, buckets, data_dict, slice_dict):
                 else:
                     data = tensor.data
                 # Works for torch tensors just fine
-                if not isinstance(data, torch.Tensor): 
+                if not isinstance(data, torch.Tensor):
                     data = torch.from_numpy(data.astype(np.complex128)).to(self.device)
                 else:
                     data = data.type(torch.complex128)
                 # slice data
-                data, new_indices = slice_torch_tensor(data, tensor.indices, out_indices, slice_dict)
+                data, new_indices = slice_torch_tensor(
+                    data, tensor.indices, out_indices, slice_dict
+                )

-                sliced_bucket.append(
-                    tensor.copy(indices=new_indices, data=data))
+                sliced_bucket.append(tensor.copy(indices=new_indices, data=data))
             sliced_buckets.append(sliced_bucket)
         return sliced_buckets

@@ -226,36 +232,113 @@ def get_sliced_buckets(self, buckets, data_dict, slice_dict):
     def get_result_data(self, result):
         return torch.permute(result.data, tuple(reversed(range(result.data.ndim))))

-class TorchBackendMatm(TorchBackend):
-    def _get_index_sizes(self, *ixs, size_dict = None):
+def _swap_flattened(data, a: int, b: int, sprod, different_dims=False):
+    """
+    Swap two dimensions in a flattened tensor.
+
+    Args:
+        data: flattened tensor
+        a, b: dimensions to swap
+        sprod (iterable of ints): ith element is the product of dimensions 0 to i
+            Last element should be 1
+    """
+    if a == b:
+        return data
+    ndim = len(sprod) - 1
+    assert ndim >= max(a, b)
+    a, b = min(a, b), max(a, b)
+    d5 = data.reshape(
+        (
+            sprod[a - 1],
+            sprod[a] // sprod[a - 1],
+            sprod[b - 1] // sprod[a],
+            sprod[b] // sprod[b - 1],
+            sprod[ndim - 1] // sprod[b],
+        )
+    )
+    # -- modify sprod accordingly
+    if different_dims:
+        adim = sprod[a] // sprod[a - 1]
+        bdim = sprod[b] // sprod[b - 1]
+        for i in range(a, b):
+            sprod[i] *= bdim
+            sprod[i] //= adim
+    return d5.transpose(1, 3).flatten()
+
+
+def permute_flattened(data, perm, shape):
+    """
+    Permute the data of a many-dimensional tensor stored as a flattened array.
+    This is a workaround for the limitation of 12 dimensions in intel extension
+    for pytorch.
+
+    While permuting, tensor is reshaped to maximum of 5 dimensions:
+
+    for each dimension swap a-b:
+        1. Reshape to 5-dimensional tensor ... a ... b ...
+        2. Swap a and b.
+        3. Flatten to 1-dimensional tensor.
+
+    Args:
+        data: flattened data
+        perm (iterable of ints): permutation, as in torch.permute
+        shape (iterable of ints): shape of the original tensor
+
+    Returns:
+        permuted data, equivalent to torch.permute(data.reshape(shape), perm).flatten()
+    """
+    sprod = []
+    k = 1
+    different_dims = False
+    for i in shape:
+        if i != shape[0]:
+            different_dims = True
+        k *= i
+        sprod.append(k)
+    sprod.append(1)
+    # print(f'different_dims {different_dims}')
+    # Is there a way to use only one dict?
+    d2l = {i: i for i in range(len(shape))}
+
+    l2d = {i: i for i in range(len(shape))}
+    for t, s in enumerate(perm):
+        s = d2l[s]
+        data = _swap_flattened(data, s, t, sprod, different_dims)
+        l2d[s], l2d[t] = l2d[t], l2d[s]
+        d2l[l2d[s]], d2l[l2d[t]] = s, t
+        # print(f'{s=}, {t=}, {d2l=}, {l2d=}')
+    return data
+
+
+class TorchBackendMatm(TorchBackend):
+    def _get_index_sizes(self, *ixs, size_dict=None):
         if size_dict is not None:
             return [size_dict[i] for i in ixs]
         try:
-            sizes = [ i.size for i in ixs ]
+            sizes = [i.size for i in ixs]
         except AttributeError:
             sizes = [2] * len(ixs)
         return sizes

-    def _get_index_space_size(self, *ixs, size_dict = None):
-        sizes = self._get_index_sizes(*ixs, size_dict = size_dict)
+    def _get_index_space_size(self, *ixs, size_dict=None):
+        sizes = self._get_index_sizes(*ixs, size_dict=size_dict)
         return reduce(np.multiply, sizes, 1)

-    def pairwise_sum_contract(self, ixa, a, ixb, b, ixout, size_dict = None):
+    def pairwise_sum_contract(self, ixa, a, ixb, b, ixout, size_dict=None):
         out = ixout
         common = set(ixa).intersection(set(ixb))
         # -- sum indices that are in one tensor only
-        all_ix = set(ixa+ixb)
+        all_ix = set(ixa + ixb)
         sum_ix = all_ix - set(out)
         a_sum = sum_ix.intersection(set(ixa) - common)
         b_sum = sum_ix.intersection(set(ixb) - common)
-        #print('ab', ixa, ixb)
-        #print('all sum', sum_ix, 'a/b_sum', a_sum, b_sum)
+        # print('ab', ixa, ixb)
+        # print('all sum', sum_ix, 'a/b_sum', a_sum, b_sum)
         if len(a_sum):
-            a = a.sum(axis=tuple(ixa.index(x) for x in a_sum))
+            #a = a.sum(axis=tuple(ixa.index(x) for x in a_sum))
             ixa = [x for x in ixa if x not in a_sum]
         if len(b_sum):
-            b = b.sum(axis=tuple(ixb.index(x) for x in b_sum))
+            #b = b.sum(axis=tuple(ixb.index(x) for x in b_sum))
             ixb = [x for x in ixb if x not in b_sum]
         tensors = a, b
         # --
@@ -269,96 +352,120 @@ def pairwise_sum_contract(self, ixa, a, ixb, b, ixout, size_dict = None):
         kix = common - set(out)
         fix = common - kix
         common = list(kix) + list(fix)
-        #print(f'{ixa=} {ixb=} {ixout=}; {common=} {mix=} {nix=}, {size_dict=}')
+        # print(f'{ixa=} {ixb=} {ixout=}; {common=} {mix=} {nix=}, {size_dict=}')
         if tensors[0].numel() > 1:
-            a = tensors[0].permute(*[
-                list(ixs[0]).index(x) for x in common + list(mix)
-            ])
+            # a = tensors[0].permute(*[
+            #     list(ixs[0]).index(x) for x in common + list(mix)
+            # ])
+            a = permute_flattened(
+                tensors[0],
+                [list(ixs[0]).index(x) for x in common + list(mix)],
+                self._get_index_sizes(*ixa, size_dict=size_dict),
+            )
         if tensors[1].numel() > 1:
-            b = tensors[1].permute(*[
-                list(ixs[1]).index(x) for x in common + list(nix)
-            ])
+            # b = tensors[1].permute(*[
+            #     list(ixs[1]).index(x) for x in common + list(nix)
+            # ])
+            b = permute_flattened(
+                tensors[1],
+                [list(ixs[1]).index(x) for x in common + list(nix)],
+                self._get_index_sizes(*ixb, size_dict=size_dict),
+            )

-        k, f, m, n = [self._get_index_space_size(*ix, size_dict=size_dict)
-                      for ix in (kix, fix, mix, nix)
-                     ]
+        k, f, m, n = [
+            self._get_index_space_size(*ix, size_dict=size_dict)
+            for ix in (kix, fix, mix, nix)
+        ]
         a = a.reshape(k, f, m)
         b = b.reshape(k, f, n)
-        c = torch.einsum('kfm, kfn -> fmn', a, b)
-        if len(out):
-            #print('out ix', out, 'kfmnix', kix, fix, mix, nix)
-            c = c.reshape(*self._get_index_sizes(*out, size_dict=size_dict))
-            #print('outix', out, 'res', c.shape, 'kfmn',kix, fix, mix, nix)
+        c = torch.einsum("kfm, kfn -> fmn", a, b)
+        #if len(out):
+        # print('out ix', out, 'kfmnix', kix, fix, mix, nix)
+        #c = c.reshape(*self._get_index_sizes(*out, size_dict=size_dict))
+        # print('outix', out, 'res', c.shape, 'kfmn',kix, fix, mix, nix)

         current_ord_ = list(fix) + list(mix) + list(nix)
+        c = c.flatten()
         if len(out):
-            c = c.permute(*[current_ord_.index(i) for i in out])
-        else:
-            c = c.flatten()
-        #print(f'c shape {c.shape}')
+            #c = c.permute(*[current_ord_.index(i) for i in out])
+            c = permute_flattened(
+                c,
+                [current_ord_.index(i) for i in out],
+                self._get_index_sizes(*out, size_dict=size_dict),
+            )
+        # print(f'c shape {c.shape}')
         return c

     def process_bucket(self, bucket, no_sum=False):
-        bucket.sort(key = lambda x: len(x.indices))
+        bucket.sort(key=lambda x: len(x.indices))
         result_indices = bucket[0].indices
         result_data = bucket[0].data
         width = len(set(bucket[0].indices))
-
+        print("bucket", bucket)
         for tensor in bucket[1:-1]:
-
             ixr = list(map(int, result_indices))
             ixt = list(map(int, tensor.indices))
-            out_indices = tuple(sorted(
-                set(result_indices + tensor.indices),
-                key=int, reverse=True
-                )
+            out_indices = tuple(
+                sorted(set(result_indices + tensor.indices), key=int, reverse=True)
             )
             ixout = list(map(int, out_indices))
-            logger.trace('Before contract. expr: {}, {} -> {}', ixr, ixt, ixout)
+            logger.trace("Before contract. expr: {}, {} -> {}", ixr, ixt, ixout)
             size_dict = {}
             for i in result_indices:
                 size_dict[int(i)] = i.size
             for i in tensor.indices:
                 size_dict[int(i)] = i.size
             logger.debug("result_indices: {}", result_indices)
-            result_data_new = self.pairwise_sum_contract(ixr, result_data, ixt, tensor.data, ixout, size_dict = size_dict)
+            result_data_new = self.pairwise_sum_contract(
+                ixr, result_data, ixt, tensor.data, ixout, size_dict=size_dict
+            )
             result_indices = out_indices
-            #result_data = torch.einsum(expr, result_data, tensor.data)
-            logger.trace("Data: {}, {} -> {}", result_data.shape, tensor.data.shape, result_data_new.shape)
+            # result_data = torch.einsum(expr, result_data, tensor.data)
+            logger.trace(
+                "Data: {}, {} -> {}",
+                result_data.shape,
+                tensor.data.shape,
+                result_data_new.shape,
+            )
             result_data = result_data_new
             # Merge and sort indices and shapes
-
+
             size = len(set(tensor.indices))
             if size > width:
                 width = size
-
-        if len(bucket)>1:
+        if len(bucket) > 1:
             tensor = bucket[-1]
             ixr = list(map(int, result_indices))
             ixt = list(map(int, tensor.indices))
-            result_indices = tuple(sorted(
-                set(result_indices + tensor.indices),
-                key=int, reverse=True
-            ))[:-1]
+            result_indices = tuple(
+                sorted(set(result_indices + tensor.indices), key=int, reverse=True)
+            )[:-1]
             ixout = list(map(int, result_indices))
-            logger.trace('Before contract. expr: {}, {} -> {}', ixr, ixt, ixout)
+            logger.trace("Before contract. expr: {}, {} -> {}", ixr, ixt, ixout)
             size_dict = {}
             for i in result_indices:
                 size_dict[int(i)] = i.size
             for i in tensor.indices:
                 size_dict[int(i)] = i.size
-            #logger.debug("result_indices: {}", result_indices)
-            result_data_new = self.pairwise_sum_contract(ixr, result_data, ixt, tensor.data, ixout, size_dict = size_dict)
-            #result_data = torch.einsum(expr, result_data, tensor.data)
-            logger.trace("Data: {}, {} -> {}", result_data.mean(), tensor.data.mean(), result_data_new.mean())
-            #if result_data_new.mean() == 0:
+            # logger.debug("result_indices: {}", result_indices)
+            result_data_new = self.pairwise_sum_contract(
+                ixr, result_data, ixt, tensor.data, ixout, size_dict=size_dict
+            )
+            # result_data = torch.einsum(expr, result_data, tensor.data)
+            logger.trace(
+                "Data: {}, {} -> {}",
+                result_data.mean(),
+                tensor.data.mean(),
+                result_data_new.mean(),
+            )
+            # if result_data_new.mean() == 0:
             #    logger.warning("Result is zero")
             #    logger.debug("result_indices: {}", result_indices)
             #    logger.debug("result_data: {}", result_data)
@@ -368,19 +475,33 @@ def process_bucket(self, bucket, no_sum=False):
            #    raise ValueError("Result is zero")
             result_data = result_data_new
         else:
-            result_data = result_data.sum(axis=-1)
+            # Sum the last index
+            print("result_data", result_data.shape)
+            #shape = self._get_index_sizes(*result_indices)
+            if result_data.numel() > 2:
+                result_data = result_data.reshape(-1, 2).sum(axis=-1)
+            else:
+                result_data = result_data.reshape(2, 1).sum(axis=-1)
+            #result_data = result_data.sum(axis=-1)
             result_indices = result_indices[:-1]

         if len(result_indices) > 0:
             first_index = result_indices[-1]
             tag = first_index.identity
         else:
-            tag = 'f'
+            tag = "f"
             result_indices = []
         # reduce
-        result = qtree.optimizer.Tensor(f'E{tag}', result_indices,
-                        data=result_data)
+        result = qtree.optimizer.Tensor(f"E{tag}", result_indices, data=result_data)
+        print("result", result)
+        print("result_data", result_data.shape)
         return result

+    def get_result_data(self, result):
+        if len(result.indices):
+            d = result.data.reshape(self._get_index_sizes(*result.indices))
+        else:
+            d = result.data
+        return torch.permute(d, tuple(reversed(range(d.ndim))))
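
Note (illustration, not part of the patch): the core idea behind _swap_flattened above is that a single axis swap of an N-dimensional tensor can be done while the data stays flat, by viewing the buffer as five groups (prefix, axis a, middle, axis b, suffix) and transposing the second and fourth groups; permute_flattened composes such swaps to realize an arbitrary permutation without ever materializing more than 5 dimensions. A minimal standalone sketch of that equivalence, with illustrative shapes chosen here rather than taken from the patch:

import torch

# Hypothetical 5-d tensor; axes 1 and 3 will be swapped while the data stays flat.
shape = (2, 3, 4, 5, 6)
a, b = 1, 3
x = torch.randn(shape)
flat = x.flatten()

# View the flat buffer as (prefix, dim_a, middle, dim_b, suffix), where prefix, middle
# and suffix are the products of the sizes before a, between a and b, and after b.
prefix, middle, suffix = 2, 4, 6
swapped = flat.reshape(prefix, shape[a], middle, shape[b], suffix).transpose(1, 3).flatten()

# Reference: a regular axis swap on the unflattened tensor.
reference = x.transpose(a, b).contiguous().flatten()
assert torch.equal(swapped, reference)

The new test test_torch_matm_permute checks the same property for full permutations through permute_flattened itself.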