Skip to content

Commit

Permalink
update in slicing history shape
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan Lykov committed Dec 1, 2023
1 parent 957ddb5 commit 5f5a15c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
5 changes: 4 additions & 1 deletion qtensor/contraction_backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ def slice_numpy_tensor(data:np.ndarray, indices_in, indices_out, slice_dict):
i for sl, i in zip(slice_bounds, indices_in) if not isinstance(sl, int)
]
indices_sized = [v.copy(size=size) for v, size in zip(indices_sliced, s_data.shape)]
#print("indices_sized", indices_sized)
#print("Slice bounds", slice_bounds)
#print("Slice dict", slice_dict)
#print("data shape, sliced data shape", data.shape, s_data.shape)
indices_out = [v for v in indices_out if not isinstance(slice_dict.get(v, None), int)]
assert len(indices_sized) == len(s_data.shape)
assert len(indices_sliced) == len(s_data.shape)
st_data = permute_np_tensor_data(s_data, indices_sliced, indices_out)
return st_data, indices_out
Expand Down
4 changes: 2 additions & 2 deletions qtensor/optimisation/Optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def _split_graph(self, p_graph, max_tw):
peo_ints = self.peo_ints
tw = self.treewidth
self._slice_hist = []
self._slice_hist.append([0, tw])
self._slice_hist.append([0, tw, peo_ints])
log.info('Treewidth: {}', tw)
log.info('Target treewidth: {}', max_tw)
result = []
Expand Down Expand Up @@ -354,7 +354,7 @@ def _split_graph(self, p_graph, max_tw):

peo_ints, path = self._update_peo_after_slice(p_graph, result)
tw = max(path)
self._slice_hist.append([pv_cnt, tw])
self._slice_hist.append([pv_cnt, tw, peo_ints])
delta = tw - max_tw

return peo_ints, result
Expand Down

0 comments on commit 5f5a15c

Please sign in to comment.