Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
dallonasnes committed Nov 13, 2023
1 parent 1d7271b commit 0026283
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions scratchpad/tn_api/tn.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,25 +134,20 @@ def add(self, other: "TensorNetwork | np.ndarray"):
# contract to produce a new tensor
def contract(self, contraction_info: ContractionInfo) -> np.ndarray:
einsum_expr = self._get_einsum_expr(contraction_info)
print(einsum_expr)
print(len(self._tensors))
import pdb; pdb.set_trace()
return np.einsum(einsum_expr, self._tensors)


# what can I understand next?
# if i setup this contract method to work for QTensorNet, it would be pretty similar becuase it also uses np.einsum
# other types can also be supported, such as QTensor and Quimb
# those would require setting up the optimize method to build the buckets
# TODO is read through and understand what it would take to set each of those up


def _get_random_indices_to_contract(self, count=2):
import random
tn_copy = self.copy()
indices_to_contract = []
counter = 0
edges_with_indices = [idx for idx, port in enumerate(list(tn_copy._edges))]

while counter < count:
while counter < count and len(edges_with_indices) > 0:
random_element = random.choice(edges_with_indices)
edges_with_indices.remove(random_element)
indices_to_contract.append(random_element)
Expand All @@ -161,8 +156,7 @@ def _get_random_indices_to_contract(self, count=2):
return sorted(indices_to_contract)


# based on implementation in
# qtensor/contraction_backends/numpy.py -> get_einsum_expr
# for reference, see qtensor/contraction_backends/numpy.py -> get_einsum_expr
def _get_einsum_expr(self, contraction_info: ContractionInfo) -> str:
# mapping from tensor index to a tuple of edges that preserves ordering
# st can lookup tix -> tuple(idx of edges) # this iterable needs to be sorted by of port.ix
Expand Down Expand Up @@ -199,6 +193,9 @@ def _get_einsum_expr(self, contraction_info: ContractionInfo) -> str:
substr += edge_to_char[edge_idx]
substrs_to_join.append(substr)

for ix in contraction_info.result_indices:
if ix not in edge_to_char:
raise ValueError("result expects invalid indices")
expr = ','.join(substrs_to_join) + '->' + ''.join(edge_to_char[ix] for ix in contraction_info.result_indices)
return expr

Expand Down Expand Up @@ -270,7 +267,7 @@ def __repr__(self):
# generate random indices to contract

random_indices_to_contract = tn._get_random_indices_to_contract(2)
# random_indices_to_contract = (np.random.randint(0, len(tn._edges)), np.random.randint(0, len(tn._edges)), np.random.randint(0, len(tn._edges)),)

contraction_info = ContractionInfo(tuple(random_indices_to_contract))

contracted_tensor = tn.contract(contraction_info)
Expand Down

0 comments on commit 0026283

Please sign in to comment.