Sketch out conj method #46
base: dev
```diff
@@ -2,7 +2,6 @@
 import math
 from dataclasses import dataclass
 from typing import TypeVar, Generic, Iterable
-from qtree import np_framework

 class Array(np.ndarray):
     shape: tuple
```
```diff
@@ -59,14 +58,22 @@ def __init__(self, *args, **kwargs):
         self._tensors = []
         self._edges = tuple()
         self.shape = tuple()
-        self.buckets = []
-        self.data_dict = {}

     # slice not inplace
     def slice(self, slice_dict: dict) -> 'TensorNetwork':
         tn = self.copy()
-        sliced_buckets = np_framework.get_sliced_np_buckets(self.buckets, self.data_dict, slice_dict)
-        tn.buckets = sliced_buckets
+        sliced_tns = []
```
**@dallonasnes:** @danlkv i think this implementation is better, wdyt?

**@danlkv:** Yes, this is much better. However, we need a way to refer to tensor indices globally. For example, what would happen if, using this code, we sliced a TN with 3 tensors of shape (2, 2) and slice dict {1: 0}? The answer is that all tensors would be sliced along the 2nd dimension, but we want a behavior that is global to the whole TN.

**@danlkv:** In addition, you have to consider how the edges change as well.

**@danlkv:** @dallonasnes Here's a sketch of what could be a solution. Let's assume that all indices in our TN are represented by a delta tensor and its corresponding tuple in the `edges` attribute (this may seem excessive, but we can overcome the overhead later). Now, a slice dict may refer to an element of the `edges` attribute by just the index of the edge, so the slice dict is `{index in edge dict: index value}`. The algorithm would be

Note that after slicing, the edge disappears.
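The edge-indexed scheme @danlkv describes could be sketched as follows. This is a hypothetical layout, assuming each entry of `edges` lists the `(tensor index, axis)` pairs that share that edge; it is not necessarily the PR's actual `_edges` format, and `slice_by_edge` is an illustrative name, not a function from the repo:

```python
import numpy as np

def slice_by_edge(tensors, edges, slice_dict):
    """Slice tensor axes by *global* edge index rather than per-tensor axis.

    tensors: list of np.ndarray
    edges:   edges[e] is a tuple of (tensor_index, axis) pairs naming every
             tensor axis that shares global edge e (assumed layout)
    slice_dict: {edge index: fixed value or slice}
    """
    # Translate the global edge dict into per-tensor {axis: value} maps.
    per_tensor = [dict() for _ in tensors]
    for edge_idx, val in slice_dict.items():
        for tensor_idx, axis in edges[edge_idx]:
            per_tensor[tensor_idx][axis] = val
    # Apply each tensor's map in a single indexing step, defaulting to
    # slice(None) for untouched axes (mirrors the PR's inner loop).
    return [
        t[tuple(ax_map.get(ax, slice(None)) for ax in range(t.ndim))]
        for t, ax_map in zip(tensors, per_tensor)
    ]

# @danlkv's example: three (2, 2) tensors. With edge 1 connecting only
# tensor 1 (axis 1) and tensor 2 (axis 0), {1: 0} fixes just that edge
# instead of slicing every tensor's second dimension.
tensors = [np.ones((2, 2)) for _ in range(3)]
edges = [((0, 0), (1, 0)), ((1, 1), (2, 0)), ((2, 1),)]
sliced = slice_by_edge(tensors, edges, {1: 0})
print([t.shape for t in sliced])  # → [(2, 2), (2,), (2,)]
```

Fixing an edge to an integer removes the corresponding axis from every tensor that touched it, which is one concrete sense in which "after slicing the edge disappears."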
||
for tensor in tn._tensors: | ||
slice_bounds = [] | ||
for idx in range(tensor.ndim): | ||
try: | ||
slice_bounds.append(slice_dict[idx]) | ||
except KeyError: | ||
slice_bounds.append(slice(None)) | ||
|
||
sliced_tns.append(tensor[tuple(slice_bounds)]) | ||
|
||
tn._tensors = sliced_tns | ||
return tn | ||
|
||
def copy(self): | ||
|
```diff
@@ -161,4 +168,6 @@ def __repr__(self):

 if __name__ == "__main__":
     tn = TensorNetwork.new_random_cpu(2, 3, 4)
+    slice_dict = {0: slice(0, 2), 1: slice(1, 3)}
+    sliced_tn = tn.slice(slice_dict)
+    import pdb; pdb.set_trace()
```
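Since `TensorNetwork` and `new_random_cpu` are not shown in this diff, the PR's per-tensor slicing logic can be reproduced standalone with plain NumPy. `slice_tensors` is a hypothetical helper, using `dict.get` in place of the diff's try/except; it illustrates the local-axis behavior @danlkv's review flags:

```python
import numpy as np

def slice_tensors(tensors, slice_dict):
    # slice_dict keys are *local* axis numbers, so the same key applies
    # to every tensor independently -- not globally across the network.
    out = []
    for tensor in tensors:
        bounds = tuple(slice_dict.get(idx, slice(None))
                       for idx in range(tensor.ndim))
        out.append(tensor[bounds])
    return out

# Three (2, 2) tensors with slice dict {1: 0}: every tensor loses its
# second axis, exactly the case raised in the review above.
tensors = [np.ones((2, 2)) for _ in range(3)]
sliced = slice_tensors(tensors, {1: 0})
print([t.shape for t in sliced])  # → [(2,), (2,), (2,)]
```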
**@dallonasnes:** @danlkv thanks for the previous comment, I believe I've addressed that in this update.