add torch matm backend
Yuri Alexeev committed Mar 15, 2024
1 parent a3d228b commit de30241
Showing 2 changed files with 154 additions and 8 deletions.
160 changes: 153 additions & 7 deletions qtensor/contraction_backends/torch.py
@@ -1,10 +1,12 @@
from qtensor.tools.lazy_import import torch
import qtree
import numpy as np
from functools import reduce
from qtree import np_framework
from qtensor.contraction_backends import ContractionBackend
from .common import get_slice_bounds, get_einsum_expr, slice_numpy_tensor
import string
from loguru import logger
CHARS = string.ascii_lowercase + string.ascii_uppercase

def qtree2torch_tensor(tensor, data_dict):
@@ -63,7 +65,7 @@ def slice_torch_tensor(data:np.ndarray, indices_in, indices_out, slice_dict):
indices_sliced = [
i for sl, i in zip(slice_bounds, indices_in) if not isinstance(sl, int)
]
- print(f'indicies_in {indices_in}, slice_dict {slice_dict}, bounds {slice_bounds}, slicedix {indices_sliced}, sshape {s_data.shape}')
+ #print(f'indicies_in {indices_in}, slice_dict {slice_dict}, bounds {slice_bounds}, slicedix {indices_sliced}, sshape {s_data.shape}')
indices_sized = [v.copy(size=size) for v, size in zip(indices_sliced, s_data.shape)]
indices_out = [v for v in indices_out if not isinstance(slice_dict.get(v, None), int)]
assert len(indices_sized) == len(s_data.shape)
@@ -73,8 +75,22 @@ def slice_torch_tensor(data:np.ndarray, indices_in, indices_out, slice_dict):


class TorchBackend(ContractionBackend):

def __init__(self, device='cpu'):
- self.device = device
+ # alias of gpu -> cuda
+ if device=='gpu':
+ device='cuda'
+ # Check that CUDA is available if specified
+ if device=='cuda':
+ if not torch.cuda.is_available():
+ logger.warning("CUDA is not available. Falling back to CPU")
+ device = 'cpu'
+ if device=='xpu':
+ # importing intel_extension_for_pytorch registers the 'xpu' device with torch
+ import intel_extension_for_pytorch as ipex
+
+ self.device = torch.device(device)
+ logger.debug("Torch backend using device {}", self.device)
self.dtype = ['float', 'double', 'complex64', 'complex128']
self.width_dict = [set() for i in range(30)]
self.width_bc = [[0,0] for i in range(30)] #(#distinct_bc, #bc)
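
Review note: the constructor now normalizes the device string before creating a torch.device. A minimal sketch of the resulting behavior, assuming TorchBackend is imported from this module as laid out in the diff:

from qtensor.contraction_backends.torch import TorchBackend

backend = TorchBackend('gpu')  # 'gpu' is aliased to 'cuda'
print(backend.device)          # cuda if available, otherwise cpu (after the logged warning)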
@@ -91,7 +107,9 @@ def process_bucket(self, bucket, no_sum=False):
list(map(int, result_indices)), list(map(int, tensor.indices))
)

logger.trace('Before contract. Expr: {}, inputs: {}, {}', expr, result_data, tensor)
result_data = torch.einsum(expr, result_data, tensor.data)
logger.trace("expression {}. Data: {}, -> {}", expr, tensor.data, result_data)

# Merge and sort indices and shapes
result_indices = tuple(sorted(
@@ -114,7 +132,9 @@ def process_bucket(self, bucket, no_sum=False):
list(map(int, result_indices)), list(map(int, tensor.indices))
, contract = 1
)
logger.trace('Before contract. Expr: {}, inputs: {}, {}', expr, result_data, tensor)
result_data = torch.einsum(expr, result_data, tensor.data)
logger.trace("expression {}. Data: {}, -> {}", expr, tensor.data, result_data)
result_indices = tuple(sorted(
set(result_indices + tensor.indices),
key=int, reverse=True
@@ -188,11 +208,7 @@ def get_sliced_buckets(self, buckets, data_dict, slice_dict):
data = tensor.data
# Works for torch tensors just fine
if not isinstance(data, torch.Tensor):
- if self.device == 'gpu' and torch.cuda.is_available():
- cuda = torch.device('cuda')
- data = torch.from_numpy(data.astype(np.complex128)).to(cuda)
- else:
- data = torch.from_numpy(data.astype(np.complex128))
+ data = torch.from_numpy(data.astype(np.complex128)).to(self.device)
else:
data = data.type(torch.complex128)
# slice data
@@ -206,3 +222,133 @@ def get_sliced_buckets(self, buckets, data_dict, slice_dict):

def get_result_data(self, result):
return torch.permute(result.data, tuple(reversed(range(result.data.ndim))))

class TorchBackendMatm(TorchBackend):

def _get_index_sizes(self, *ixs):
try:
sizes = [ i.size for i in ixs ]
except AttributeError:
sizes = [2] * len(ixs)
return sizes

def _get_index_space_size(self, *ixs):
sizes = self._get_index_sizes(*ixs)
return reduce(np.multiply, sizes, 1)

def pairwise_sum_contract(self, ixa, a, ixb, b, ixout):
out = ixout
common = set(ixa).intersection(set(ixb))
# -- sum indices that are in one tensor only
all_ix = set(ixa+ixb)
sum_ix = all_ix - set(out)
a_sum = sum_ix.intersection(set(ixa) - common)
b_sum = sum_ix.intersection(set(ixb) - common)
#print('ab', ixa, ixb)
#print('all sum', sum_ix, 'a/b_sum', a_sum, b_sum)
if len(a_sum):
a = a.sum(axis=tuple(ixa.index(x) for x in a_sum))
ixa = [x for x in ixa if x not in a_sum]
if len(b_sum):
b = b.sum(axis=tuple(ixb.index(x) for x in b_sum))
ixb = [x for x in ixb if x not in b_sum]
tensors = a, b
# --

ixs = ixa, ixb
common = set(ixs[0]).intersection(set(ixs[1]))

# \sum_k A_{kfm} * B_{kfn} = C_{fmn}
mix = set(ixs[0]) - common
nix = set(ixs[1]) - common
kix = common - set(out)
fix = common - kix
common = list(kix) + list(fix)
# permute (not transpose) reorders all axes at once to (k..., f..., m...)
a = tensors[0].permute(*[
list(ixs[0]).index(x) for x in common + list(mix)
])

b = tensors[1].permute(*[
list(ixs[1]).index(x) for x in common + list(nix)
])

k, f, m, n = [self._get_index_space_size(*ix)
for ix in (kix, fix, mix, nix)
]
a = a.reshape(k, f, m)
b = b.reshape(k, f, n)
c = torch.einsum('kfm, kfn -> fmn', a, b)
if len(out):
#print('out ix', out, 'kfmnix', kix, fix, mix, nix)
c = c.reshape(*self._get_index_sizes(*out))
#print('outix', out, 'res', c.shape, 'kfmn',kix, fix, mix, nix)

current_ord_ = list(fix) + list(mix) + list(nix)
c = c.permute(*[current_ord_.index(i) for i in out])
return c

def process_bucket(self, bucket, no_sum=False):
bucket.sort(key = lambda x: len(x.indices))
result_indices = bucket[0].indices
result_data = bucket[0].data
width = len(set(bucket[0].indices))

for tensor in bucket[1:-1]:

ixr = list(map(int, result_indices))
ixt = list(map(int, tensor.indices))
result_indices = tuple(sorted(
set(result_indices + tensor.indices),
key=int, reverse=True
)
)
ixout = list(map(int, result_indices))

logger.trace('Before contract. expr: {}, {} -> {}', ixr, ixt, ixout)
result_data_new = self.pairwise_sum_contract(ixr, result_data, ixt, tensor.data, ixout)
#result_data = torch.einsum(expr, result_data, tensor.data)
logger.trace("Data: {}, -> {}", result_data, tensor.data, result_data_new)
result_data = result_data_new

# Merge and sort indices and shapes

size = len(set(tensor.indices))
if size > width:
width = size


if len(bucket)>1:
tensor = bucket[-1]

ixr = list(map(int, result_indices))
ixt = list(map(int, tensor.indices))
result_indices = tuple(sorted(
set(result_indices + tensor.indices),
key=int, reverse=True
))[:-1]
ixout = list(map(int, result_indices))

logger.trace('Before contract. expr: {}, {} -> {}', ixr, ixt, ixout)
result_data_new = self.pairwise_sum_contract(ixr, result_data, ixt, tensor.data, ixout)
#result_data = torch.einsum(expr, result_data, tensor.data)
logger.trace("Data: {}, -> {}", result_data, tensor.data, result_data_new)
result_data = result_data_new
else:
result_data = result_data.sum(axis=-1)



if len(result_indices) > 0:
first_index = result_indices[-1]
result_indices = result_indices[:-1]
tag = first_index.identity
else:
tag = 'f'
result_indices = []

# reduce
result = qtree.optimizer.Tensor(f'E{tag}', result_indices,
data=result_data)
return result
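
Review note: TorchBackendMatm reduces each pairwise contraction to one batched matmul of shape (k, f, m) x (k, f, n) -> (f, m, n), where k collects the summed indices, f the shared batch indices, and m/n the indices unique to each operand. The identity it relies on can be checked against stock torch ops; a minimal sketch with arbitrarily chosen sizes:

import torch

k, f, m, n = 2, 3, 4, 5
a = torch.randn(k, f, m, dtype=torch.float64)
b = torch.randn(k, f, n, dtype=torch.float64)

c_einsum = torch.einsum('kfm,kfn->fmn', a, b)
# Same contraction as an explicit batched matmul: batch over f, sum over k
c_bmm = torch.bmm(a.permute(1, 2, 0), b.permute(1, 0, 2))  # (f,m,k) @ (f,k,n) -> (f,m,n)
assert torch.allclose(c_einsum, c_bmm)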


2 changes: 1 addition & 1 deletion qtree
Submodule qtree updated from 7b038d to 16efbb
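
Review note: the new contraction method can also be exercised on its own. A small sketch, assuming the module path from this commit and all-size-2 (qubit) indices, which is what the reshape inside pairwise_sum_contract expects; the integer index labels and the explicit output order below are illustrative:

import torch
from qtensor.contraction_backends.torch import TorchBackendMatm

be = TorchBackendMatm()
a = torch.randn(2, 2, 2, dtype=torch.float64)  # carries indices (3, 2, 1)
b = torch.randn(2, 2, dtype=torch.float64)     # carries indices (2, 0)
# Sum the shared index 2, keep output order (3, 1, 0):
c = be.pairwise_sum_contract([3, 2, 1], a, [2, 0], b, [3, 1, 0])
ref = torch.einsum('xyz,yw->xzw', a, b)        # same contraction via einsum
assert torch.allclose(c, ref)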
