Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dallonasnes committed Oct 30, 2023
1 parent b0645a2 commit d7e24c2
Showing 1 changed file with 41 additions and 12 deletions.
53 changes: 41 additions & 12 deletions scratchpad/tn_api/tn.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,29 +58,58 @@ def __init__(self, *args, **kwargs):
self._tensors = []
self._edges = tuple()
self.shape = tuple()
self._indices = {}

# slice not inplace
def slice(self, slice_dict: dict) -> 'TensorNetwork':
tn = self.copy()
sliced_tns = []
for tensor in tn._tensors:
slice_bounds = []
for idx in range(tensor.ndim):
try:
slice_bounds.append(slice_dict[idx])
except KeyError:
slice_bounds.append(slice(None))

sliced_tns.append(tensor[tuple(slice_bounds)])

tn._tensors = sliced_tns

# hydrate the index map of edge index to tensor and dimension (via port object)
for idx, edge in enumerate(tn._edges):
for port in edge:
if idx in tn._indices:
tn._indices[idx].append(port)
else:
tn._indices[idx] = [port]

for idx, slice_val in slice_dict.items():
# don't care about indices that are not in TN
if idx not in tn._indices:
continue

edge = tn._edges.pop(idx)
# get all tensors indexed by this edge
tensors_to_slice = set(port.tensor_ref for port in edge)
# store slice index and value for each tensor
local_slices_dict = {}
for current_tensor_ref in tensors_to_slice:
slice_dict = {}
# get all ports for the current tensor
current_tensor_ref_ports = [port for port in edge where port.tensor_ref == current_tensor_ref]
for current_port in current_tensor_ref_ports:
slice_dict[current_port.ix] = slice_val
# store the slice params for this tensor in the local dict
local_slices_dict[current_tensor_ref] = slice_dict

# now use the local slice dict to slice for each tensor
for current_tensor_ref, slice_dict in local_slices_dict.items():
slice_bounds = []
current_tensor = tn._tensors[current_tensor_ref]
for idx in range(current_tensor.ndim):
try:
slice_bounds.append(slice_dict[idx])
except KeyError:
slice_bounds.append(slice(None))
tn._tensors[current_tensor_ref] = tn._tensors[current_tensor_ref][tuple(slice_bounds)]

return tn

def copy(self):
new = TensorNetwork()
new._tensors = self._tensors
new._edges = self._edges
new.shape = self.shape
new.indices = self.indices
return new

def add(self, other: "TensorNetwork | np.ndarray"):
Expand Down

0 comments on commit d7e24c2

Please sign in to comment.