Skip to content

Commit

Permalink
fix torch slicing: ignore data_dict if data is present
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuri Alexeev committed Feb 16, 2024
1 parent e9b6a9f commit a3d228b
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion qtensor/contraction_backends/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def slice_torch_tensor(data:np.ndarray, indices_in, indices_out, slice_dict):
indices_sliced = [
i for sl, i in zip(slice_bounds, indices_in) if not isinstance(sl, int)
]
print(f'indicies_in {indices_in}, slice_dict {slice_dict}, bounds {slice_bounds}, slicedix {indices_sliced}, sshape {s_data.shape}')
indices_sized = [v.copy(size=size) for v, size in zip(indices_sliced, s_data.shape)]
indices_out = [v for v in indices_out if not isinstance(slice_dict.get(v, None), int)]
assert len(indices_sized) == len(s_data.shape)
Expand Down Expand Up @@ -181,7 +182,10 @@ def get_sliced_buckets(self, buckets, data_dict, slice_dict):
# get data
# sort tensor dimensions
out_indices = list(sorted(tensor.indices, key=int, reverse=True))
data = data_dict[tensor.data_key]
if tensor.data is None:
data = data_dict[tensor.data_key]
else:
data = tensor.data
# Works for torch tensors just fine
if not isinstance(data, torch.Tensor):
if self.device == 'gpu' and torch.cuda.is_available():
Expand Down

0 comments on commit a3d228b

Please sign in to comment.