From de30241d09d0cc265af8f2a5c8c9392f4194c242 Mon Sep 17 00:00:00 2001
From: Yuri Alexeev
Date: Fri, 15 Mar 2024 13:25:37 +0000
Subject: [PATCH] add torch matmul backend

---
 qtensor/contraction_backends/torch.py | 160 ++++++++++++++++++++++++--
 qtree                                 |   2 +-
 2 files changed, 154 insertions(+), 8 deletions(-)

diff --git a/qtensor/contraction_backends/torch.py b/qtensor/contraction_backends/torch.py
index 5398f19e..ff7c8d25 100644
--- a/qtensor/contraction_backends/torch.py
+++ b/qtensor/contraction_backends/torch.py
@@ -1,10 +1,12 @@
 from qtensor.tools.lazy_import import torch
 import qtree
 import numpy as np
+from functools import reduce
 from qtree import np_framework
 from qtensor.contraction_backends import ContractionBackend
 from .common import get_slice_bounds, get_einsum_expr, slice_numpy_tensor
 import string
+from loguru import logger

 CHARS = string.ascii_lowercase + string.ascii_uppercase

 def qtree2torch_tensor(tensor, data_dict):
@@ -63,7 +65,7 @@ def slice_torch_tensor(data:np.ndarray, indices_in, indices_out, slice_dict):
     indices_sliced = [
         i for sl, i in zip(slice_bounds, indices_in) if not isinstance(sl, int)
     ]
-    print(f'indicies_in {indices_in}, slice_dict {slice_dict}, bounds {slice_bounds}, slicedix {indices_sliced}, sshape {s_data.shape}')
+    #print(f'indices_in {indices_in}, slice_dict {slice_dict}, bounds {slice_bounds}, slicedix {indices_sliced}, sshape {s_data.shape}')
     indices_sized = [v.copy(size=size) for v, size in zip(indices_sliced, s_data.shape)]
     indices_out = [v for v in indices_out if not isinstance(slice_dict.get(v, None), int)]
     assert len(indices_sized) == len(s_data.shape)
@@ -73,8 +75,22 @@

 class TorchBackend(ContractionBackend):
+
     def __init__(self, device='cpu'):
-        self.device = device
+        # alias of gpu -> cuda
+        if device == 'gpu':
+            device = 'cuda'
+        # Check that CUDA is available if specified
+        if device == 'cuda':
+            if not torch.cuda.is_available():
+                logger.warning("CUDA is not available. Falling back to CPU")
+                device = 'cpu'
+        if device == 'xpu':
+            # Intel GPU support requires the IPEX extension
+            import intel_extension_for_pytorch as ipex
+
+        self.device = torch.device(device)
+        logger.debug("Torch backend using device {}", self.device)
         self.dtype = ['float', 'double', 'complex64', 'complex128']
         self.width_dict = [set() for i in range(30)]
         self.width_bc = [[0,0] for i in range(30)] #(#distinct_bc, #bc)
@@ -91,7 +107,9 @@ def process_bucket(self, bucket, no_sum=False):
                 list(map(int, result_indices)),
                 list(map(int, tensor.indices))
             )
+            logger.trace('Before contract. Expr: {}, inputs: {}, {}', expr, result_data, tensor)
             result_data = torch.einsum(expr, result_data, tensor.data)
+            logger.trace("expression {}. Data: {}, -> {}", expr, tensor.data, result_data)

             # Merge and sort indices and shapes
             result_indices = tuple(sorted(
@@ -114,7 +132,9 @@ def process_bucket(self, bucket, no_sum=False):
                 list(map(int, result_indices)),
                 list(map(int, tensor.indices))
                 , contract = 1
             )
+            logger.trace('Before contract. Expr: {}, inputs: {}, {}', expr, result_data, tensor)
             result_data = torch.einsum(expr, result_data, tensor.data)
+            logger.trace("expression {}. Data: {}, -> {}", expr, tensor.data, result_data)
Data: {}, -> {}", expr, tensor.data, result_data) result_indices = tuple(sorted( set(result_indices + tensor.indices), key=int, reverse=True @@ -188,11 +208,7 @@ def get_sliced_buckets(self, buckets, data_dict, slice_dict): data = tensor.data # Works for torch tensors just fine if not isinstance(data, torch.Tensor): - if self.device == 'gpu' and torch.cuda.is_available(): - cuda = torch.device('cuda') - data = torch.from_numpy(data.astype(np.complex128)).to(cuda) - else: - data = torch.from_numpy(data.astype(np.complex128)) + data = torch.from_numpy(data.astype(np.complex128)).to(self.device) else: data = data.type(torch.complex128) # slice data @@ -206,3 +222,133 @@ def get_sliced_buckets(self, buckets, data_dict, slice_dict): def get_result_data(self, result): return torch.permute(result.data, tuple(reversed(range(result.data.ndim)))) + +class TorchBackendMatm(TorchBackend): + + def _get_index_sizes(self, *ixs): + try: + sizes = [ i.size for i in ixs ] + except AttributeError: + sizes = [2] * len(ixs) + return sizes + + def _get_index_space_size(self, *ixs): + sizes = self._get_index_sizes(*ixs) + return reduce(np.multiply, sizes, 1) + + def pairwise_sum_contract(self, ixa, a, ixb, b, ixout): + out = ixout + common = set(ixa).intersection(set(ixb)) + # -- sum indices that are in one tensor only + all_ix = set(ixa+ixb) + sum_ix = all_ix - set(out) + a_sum = sum_ix.intersection(set(ixa) - common) + b_sum = sum_ix.intersection(set(ixb) - common) + #print('ab', ixa, ixb) + #print('all sum', sum_ix, 'a/b_sum', a_sum, b_sum) + if len(a_sum): + a = a.sum(axis=tuple(ixa.index(x) for x in a_sum)) + ixa = [x for x in ixa if x not in a_sum] + if len(b_sum): + b = b.sum(axis=tuple(ixb.index(x) for x in b_sum)) + ixb = [x for x in ixb if x not in b_sum] + tensors = a, b + # -- + + ixs = ixa, ixb + common = set(ixs[0]).intersection(set(ixs[1])) + + # \sum_k A_{kfm} * B_{kfn} = C_{fmn} + mix = set(ixs[0]) - common + nix = set(ixs[1]) - common + kix = common - set(out) + fix = common - kix + common = list(kix) + list(fix) + a = tensors[0].transpose(*[ + list(ixs[0]).index(x) for x in common + list(mix) + ]) + + b = tensors[1].transpose(*[ + list(ixs[1]).index(x) for x in common + list(nix) + ]) + + k, f, m, n = [self._get_index_space_size(*ix) + for ix in (kix, fix, mix, nix) + ] + a = a.reshape(k, f, m) + b = b.reshape(k, f, n) + c = torch.einsum('kfm, kfn -> fmn', a, b) + if len(out): + #print('out ix', out, 'kfmnix', kix, fix, mix, nix) + c = c.reshape(*self._get_index_sizes(*out)) + #print('outix', out, 'res', c.shape, 'kfmn',kix, fix, mix, nix) + + current_ord_ = list(fix) + list(mix) + list(nix) + c = c.transpose(*[current_ord_.index(i) for i in out]) + return c + + def process_bucket(self, bucket, no_sum=False): + bucket.sort(key = lambda x: len(x.indices)) + result_indices = bucket[0].indices + result_data = bucket[0].data + width = len(set(bucket[0].indices)) + + for tensor in bucket[1:-1]: + + ixr = list(map(int, result_indices)) + ixt = list(map(int, tensor.indices)) + result_indices = tuple(sorted( + set(result_indices + tensor.indices), + key=int, reverse=True + ) + ) + ixout = list(map(int, result_indices)) + + logger.trace('Before contract. 
+            result_data_new = self.pairwise_sum_contract(ixr, result_data, ixt, tensor.data, ixout)
+            logger.trace("Data: {}, {} -> {}", result_data, tensor.data, result_data_new)
+            result_data = result_data_new
+
+            # track the largest contraction width seen in this bucket
+            size = len(set(tensor.indices))
+            if size > width:
+                width = size
+
+        if len(bucket) > 1:
+            tensor = bucket[-1]
+            ixr = list(map(int, result_indices))
+            ixt = list(map(int, tensor.indices))
+            result_indices = tuple(sorted(
+                set(result_indices + tensor.indices),
+                key=int, reverse=True
+            ))
+            # contract out the smallest index: drop it from the output expression
+            # but keep it in result_indices; it is popped below to build the tag
+            ixout = list(map(int, result_indices))[:-1]
+
+            logger.trace('Before contract. expr: {}, {} -> {}', ixr, ixt, ixout)
+            result_data_new = self.pairwise_sum_contract(ixr, result_data, ixt, tensor.data, ixout)
+            logger.trace("Data: {}, {} -> {}", result_data, tensor.data, result_data_new)
+            result_data = result_data_new
+        else:
+            result_data = result_data.sum(axis=-1)
+
+        if len(result_indices) > 0:
+            first_index = result_indices[-1]
+            result_indices = result_indices[:-1]
+            tag = first_index.identity
+        else:
+            tag = 'f'
+            result_indices = []
+
+        result = qtree.optimizer.Tensor(f'E{tag}', result_indices,
+                                        data=result_data)
+        return result
diff --git a/qtree b/qtree
index 7b038d5a..16efbba2 160000
--- a/qtree
+++ b/qtree
@@ -1 +1 @@
-Subproject commit 7b038d5a4cc1f9b5e0ede4b0e5740bff4b22153e
+Subproject commit 16efbba2566e65a37bb7927f06a80c9f88ac57ff
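
Reviewer note: a minimal sanity check of the new pairwise_sum_contract against a
plain torch.einsum reference (assumes this patch is applied; the integer indices
are illustrative and default to size 2 via the _get_index_sizes fallback):

    import torch
    from qtensor.contraction_backends.torch import TorchBackendMatm

    backend = TorchBackendMatm(device='cpu')

    # A_{kfm} carries indices (0, 1, 2); B_{kfn} carries indices (0, 1, 3).
    # Index 0 is absent from the output, so it is summed over (the 'k' role);
    # index 1 is shared and kept (the 'f' role).
    a = torch.rand(2, 2, 2, dtype=torch.float64)
    b = torch.rand(2, 2, 2, dtype=torch.float64)

    c = backend.pairwise_sum_contract([0, 1, 2], a, [0, 1, 3], b, [1, 2, 3])
    reference = torch.einsum('kfm,kfn->fmn', a, b)
    assert torch.allclose(c, reference)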
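
Reviewer note: for context, a hypothetical end-to-end use of the new backend.
QtreeQAOAComposer, QtreeSimulator, and its backend keyword are taken from the
surrounding QTensor codebase, not from this patch, so treat the names here as
assumptions rather than a confirmed API:

    import networkx as nx
    from qtensor import QtreeQAOAComposer, QtreeSimulator
    from qtensor.contraction_backends.torch import TorchBackendMatm

    # small QAOA ansatz circuit on a random 3-regular graph
    G = nx.random_regular_graph(3, 8, seed=42)
    composer = QtreeQAOAComposer(graph=G, gamma=[0.3], beta=[0.2])
    composer.ansatz_state()

    # contract the resulting tensor network with the matmul-based torch backend
    sim = QtreeSimulator(backend=TorchBackendMatm(device='cpu'))
    amplitude = sim.simulate(composer.circuit)
    print(amplitude)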