-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sketch out conj method #46
base: dev
Are you sure you want to change the base?
Changes from 1 commit
eeba989
69d0341
0d74462
b0645a2
d7e24c2
1aafd05
3533289
ea2a43c
0c91a8f
dbeee3d
1d7271b
0026283
b48ea18
efbab98
1631afb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,7 +66,7 @@ def __init__(self, *args, **kwargs): | |
# slice not inplace | ||
def slice(self, slice_dict: dict) -> 'TensorNetwork': | ||
tn = self.copy() | ||
|
||
new_edge_list = [] | ||
for idx, slice_val in slice_dict.items(): | ||
# make sure idx is valid | ||
if idx >= len(tn._edges): | ||
|
@@ -75,7 +75,7 @@ def slice(self, slice_dict: dict) -> 'TensorNetwork': | |
edge_list = list(tn._edges) | ||
edge = edge_list.pop(idx) | ||
# now put the updated edges back on the class | ||
tn._edges = tuple(edge_list) | ||
tn._edges = tuple(edge_list) # TODO @dallon - is this the issue, that i'm getting rid of all of the edges? | ||
# get all tensors indexed by this edge | ||
tensors_to_slice = set(port.tensor_ref for port in edge) | ||
dallonasnes marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# store slice index and value for each tensor | ||
|
@@ -98,7 +98,17 @@ def slice(self, slice_dict: dict) -> 'TensorNetwork': | |
slice_bounds.append(slice_dict[idx]) | ||
except KeyError: | ||
slice_bounds.append(slice(None)) | ||
tn._tensors[current_tensor_ref] = tn._tensors[current_tensor_ref][tuple(slice_bounds)] | ||
sliced_tensor = tn._tensors[current_tensor_ref][tuple(slice_bounds)] | ||
tn._tensors[current_tensor_ref] = sliced_tensor | ||
|
||
# TODO: this is just a guess - but i am adding the ports from the popped edge back to the list of slices | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @danlkv where we left off last time is that calling einsum is failing on a sliced tn. it's likely because slicing the tn removes all edges - i don't think that behavior is intended, so my proposal for a fix is that in this section of logic (starting at line 76) when we pop an edge out, at some point we need to add it back. i've added some pdb loops so that when we err out, we can repro it in the same run here's an example error output
|
||
# for port in edge: | ||
# if port.ix in slice_dict and port.tensor_ref == current_tensor_ref: | ||
# new_edge_list.append(Port(tensor_ref=current_tensor_ref, ix=port.ix)) | ||
|
||
# # update the tensor network's edges with the new edges | ||
# tn._edges = tuple(new_edge_list) | ||
|
||
|
||
return tn | ||
|
||
|
@@ -154,11 +164,17 @@ def contract(self, contraction_info: ContractionInfo) -> np.ndarray: | |
print([t.shape for t in self._tensors]) | ||
print(self._edges) | ||
print(len(self._tensors)) | ||
import pdb; pdb.set_trace() | ||
try: | ||
return np.einsum(einsum_expr, *self._tensors) | ||
except: | ||
except Exception as e: | ||
print(e) | ||
import pdb; pdb.set_trace() | ||
keep_going = True | ||
while keep_going: | ||
einsum_expr = self._get_einsum_expr(contraction_info) | ||
res = np.einsum(einsum_expr, *self._tensors) | ||
import pdb; pdb.set_trace() | ||
|
||
|
||
# for reference, see qtensor/contraction_backends/numpy.py -> get_einsum_expr | ||
def _get_einsum_expr(self, contraction_info: ContractionInfo) -> str: | ||
|
@@ -266,21 +282,31 @@ def __repr__(self): | |
dim = 3 | ||
tn = TensorNetwork.new_random_cpu(2, dim, 4) | ||
slice_dict = {0: slice(0, 2), 1: slice(1, 3)} | ||
sliced_tn = tn.slice(slice_dict) # TODO: go through debugger here to make sure that certain edges of the same port aren't being skipped | ||
sliced_tn = tn.slice(slice_dict) | ||
print(len(sliced_tn._edges)) | ||
import pdb; pdb.set_trace() | ||
# TODO: go through slice method debugger here to make sure that certain edges of the same port aren't being skipped | ||
# TODO: st i can run contract on a sliced tn without it breaking | ||
|
||
# Where did I leave off? | ||
# Having trouble verifying tests, perhaps logic is incorrect but it makes sense to me | ||
""" | ||
# TODO: issue is that slicing results in no edges | ||
,-> #einsum expression | ||
[(2, 2, 2), (2, 2, 2)] #tensor shapes | ||
() #edges | ||
2 #tensors count | ||
""" | ||
|
||
# can also do "contract all except..." by knowing indices of edges in tn | ||
# generate random indices to contract | ||
|
||
random_indices_to_contract = tn._get_random_indices_to_contract(2) | ||
random_indices_to_contract = sliced_tn._get_random_indices_to_contract(2) | ||
|
||
contraction_info = ContractionInfo(tuple(random_indices_to_contract)) | ||
import pdb; pdb.set_trace() | ||
contracted_tensor = tn.contract(contraction_info) | ||
print("success") | ||
print("starting not sliced tensor") | ||
contracted_tensor_not_slice = tn.contract(contraction_info) | ||
print("finished not sliced tensor") | ||
print("starting sliced tensor") | ||
contracted_tensor = sliced_tn.contract(contraction_info) | ||
print("finished sliced tensor") | ||
import pdb; pdb.set_trace() | ||
|
||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@danlkv thanks for the previous comment, i believed i've addressed that in this update