Sketch out conj method #46

Open · wants to merge 15 commits into dev
52 changes: 39 additions & 13 deletions scratchpad/tn_api/tn.py
@@ -66,7 +66,7 @@ def __init__(self, *args, **kwargs):
    # slice not inplace
    def slice(self, slice_dict: dict) -> 'TensorNetwork':
@danlkv thanks for the previous comment, I believe I've addressed that in this update.

        tn = self.copy()

        new_edge_list = []
        for idx, slice_val in slice_dict.items():
            # make sure idx is valid
            if idx >= len(tn._edges):
@@ -75,7 +75,7 @@ def slice(self, slice_dict: dict) -> 'TensorNetwork':
            edge_list = list(tn._edges)
            edge = edge_list.pop(idx)
            # now put the updated edges back on the class
            tn._edges = tuple(edge_list)
            tn._edges = tuple(edge_list)  # TODO @dallon - is this the issue, that I'm getting rid of all of the edges?
            # get all tensors indexed by this edge
            tensors_to_slice = set(port.tensor_ref for port in edge)
            # store slice index and value for each tensor
@@ -98,7 +98,17 @@ def slice(self, slice_dict: dict) -> 'TensorNetwork':
                        slice_bounds.append(slice_dict[idx])
                    except KeyError:
                        slice_bounds.append(slice(None))
                tn._tensors[current_tensor_ref] = tn._tensors[current_tensor_ref][tuple(slice_bounds)]
                sliced_tensor = tn._tensors[current_tensor_ref][tuple(slice_bounds)]
                tn._tensors[current_tensor_ref] = sliced_tensor

            # TODO: this is just a guess - but I am adding the ports from the popped edge back to the edge list
@danlkv where we left off last time, calling einsum was failing on a sliced tn. It's likely because slicing the tn removes all edges - I don't think that behavior is intended, so my proposed fix is that in this section of logic (starting at line 76), when we pop an edge out, at some point we need to add it back.

I've added some pdb loops so that when we err out, we can repro it in the same run.

Here's an example error output:

aa,ba->ab
[(4, 4, 2), (4, 4, 2)]
((Port(tensor_ref=0, ix=0), Port(tensor_ref=0, ix=1), Port(tensor_ref=1, ix=1)), (Port(tensor_ref=1, ix=0),))
2
operand has more dimensions than subscripts given in einstein sum, but no '...' ellipsis provided to broadcast the extra dimensions.
[1] > /app/scratchpad/tn_api/tn.py(172)contract()
-> keep_going = True
(Pdb++) c
Traceback (most recent call last):
  File "tn_api/tn.py", line 168, in contract
    return np.einsum(einsum_expr, *self._tensors)
  File "<__array_function__ internals>", line 200, in einsum
  File "/usr/local/lib/python3.8/dist-packages/numpy/core/einsumfunc.py", line 1371, in einsum
    return c_einsum(*operands, **kwargs)
ValueError: operand has more dimensions than subscripts given in einstein sum, but no '...' ellipsis provided to broadcast the extra dimensions.
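
To see the failure outside the TN machinery, here's a minimal standalone repro (my own illustration, not code from this PR): the sliced operands are still rank-3, but once the edge is popped the generated expression only carries two subscripts per operand, which is exactly what np.einsum rejects.

import numpy as np

# minimal repro of the ValueError above: rank-3 operands, but only two
# subscripts per operand in the expression (as happens once the edge is
# dropped from tn._edges)
a = np.random.rand(4, 4, 2)
b = np.random.rand(4, 4, 2)
try:
    np.einsum("aa,ba->ab", a, b)
except ValueError as err:
    print(err)
    # operand has more dimensions than subscripts given in einstein sum,
    # but no '...' ellipsis provided to broadcast the extra dimensions.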

            # for port in edge:
            #     if port.ix in slice_dict and port.tensor_ref == current_tensor_ref:
            #         new_edge_list.append(Port(tensor_ref=current_tensor_ref, ix=port.ix))

            # # update the tensor network's edges with the new edges
            # tn._edges = tuple(new_edge_list)


        return tn

@@ -154,11 +164,17 @@ def contract(self, contraction_info: ContractionInfo) -> np.ndarray:
        print([t.shape for t in self._tensors])
        print(self._edges)
        print(len(self._tensors))
        import pdb; pdb.set_trace()
        try:
            return np.einsum(einsum_expr, *self._tensors)
        except:
        except Exception as e:
            print(e)
            import pdb; pdb.set_trace()
            keep_going = True
            while keep_going:
                einsum_expr = self._get_einsum_expr(contraction_info)
                res = np.einsum(einsum_expr, *self._tensors)
                import pdb; pdb.set_trace()


    # for reference, see qtensor/contraction_backends/numpy.py -> get_einsum_expr
    def _get_einsum_expr(self, contraction_info: ContractionInfo) -> str:
@@ -266,21 +282,31 @@ def __repr__(self):
dim = 3
tn = TensorNetwork.new_random_cpu(2, dim, 4)
slice_dict = {0: slice(0, 2), 1: slice(1, 3)}
sliced_tn = tn.slice(slice_dict) # TODO: go through debugger here to make sure that certain edges of the same port aren't being skipped
sliced_tn = tn.slice(slice_dict)
print(len(sliced_tn._edges))
import pdb; pdb.set_trace()
# TODO: go through slice method debugger here to make sure that certain edges of the same port aren't being skipped
# TODO: so that I can run contract on a sliced tn without it breaking

# Where did I leave off?
# Having trouble verifying tests; perhaps the logic is incorrect, but it makes sense to me
"""
# TODO: issue is that slicing results in no edges
,-> #einsum expression
[(2, 2, 2), (2, 2, 2)] #tensor shapes
() #edges
2 #tensors count
"""

# can also do "contract all except..." by knowing indices of edges in tn
# generate random indices to contract

random_indices_to_contract = tn._get_random_indices_to_contract(2)
random_indices_to_contract = sliced_tn._get_random_indices_to_contract(2)

contraction_info = ContractionInfo(tuple(random_indices_to_contract))
import pdb; pdb.set_trace()
contracted_tensor = tn.contract(contraction_info)
print("success")
print("starting not sliced tensor")
contracted_tensor_not_slice = tn.contract(contraction_info)
print("finished not sliced tensor")
print("starting sliced tensor")
contracted_tensor = sliced_tn.contract(contraction_info)
print("finished sliced tensor")
import pdb; pdb.set_trace()

"""