
Sketch out conj method #46

Open
wants to merge 15 commits into dev
19 changes: 14 additions & 5 deletions scratchpad/tn_api/tn.py
@@ -2,7 +2,6 @@
import math
from dataclasses import dataclass
from typing import TypeVar, Generic, Iterable
from qtree import np_framework

class Array(np.ndarray):
shape: tuple
@@ -59,14 +58,22 @@ def __init__(self, *args, **kwargs):
    self._tensors = []
    self._edges = tuple()
    self.shape = tuple()
    self.buckets = []
    self.data_dict = {}

# slice not inplace
def slice(self, slice_dict: dict) -> 'TensorNetwork':
Collaborator Author:
@danlkv Thanks for the previous comment; I believe I've addressed it in this update.

    tn = self.copy()
    sliced_buckets = np_framework.get_sliced_np_buckets(self.buckets, self.data_dict, slice_dict)
    tn.buckets = sliced_buckets
    sliced_tns = []
Collaborator Author:
@danlkv I think this implementation is better, wdyt?

Owner:
Yes, this is much better. However, we need a way to refer to tensor indices globally. For example, what would happen if we used this code to slice a TN with 3 tensors of shape (2, 2) and slice dict {1: 0}? The answer is that every tensor would be sliced along its 2nd dimension, but we want behavior that is global to the whole TN.
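A minimal numpy sketch (hypothetical, not the repo's code) of the pitfall described here: with a per-tensor slice dict, the key 1 means "axis 1 of every tensor" rather than one global TN index.

```python
import numpy as np

# Three tensors of shape (2, 2), as in the example above.
tensors = [np.ones((2, 2)) for _ in range(3)]
slice_dict = {1: 0}

sliced = []
for t in tensors:
    # Keys of slice_dict are interpreted as local axes of each tensor.
    bounds = tuple(slice_dict.get(ax, slice(None)) for ax in range(t.ndim))
    sliced.append(t[bounds])

# All three tensors lose their 2nd axis, even if only some of them
# actually carry the global index we meant to fix.
shapes = [t.shape for t in sliced]  # [(2,), (2,), (2,)]
```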

Owner:

In addition, you have to consider how the edges change as well.

Owner:
@dallonasnes Here's a sketch of what a solution could look like.
First, let me clarify something about the edges structure. Edges in the edge array represent delta-tensors, which force all indices connected to them to be the same. For example, for the expression T_{abc}\delta_{abd}B_{edg}C_{mp} the equivalent expression is T_{xxc}B_{exg}C_{mp}. Note that there is a one-to-one correspondence between each delta and an index (in the example, the index x).
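The delta-tensor identity can be checked numerically. This is an illustrative sketch (the bond dimension is an arbitrary assumption, and the spectator C_{mp} is dropped since it does not share indices with the delta):

```python
import numpy as np

# delta_{abd} is 1 exactly when a == b == d, else 0.
d = 2
T = np.random.rand(d, d, d)   # T_{abc}
B = np.random.rand(d, d, d)   # B_{edg}
delta = np.zeros((d, d, d))
for x in range(d):
    delta[x, x, x] = 1.0

# T_{abc} delta_{abd} B_{edg}  ==  T_{xxc} B_{exg}  (summed over x)
lhs = np.einsum('abc,abd,edg->ceg', T, delta, B)
rhs = np.einsum('xxc,exg->ceg', T, B)
assert np.allclose(lhs, rhs)
```

The repeated label in 'xxc' takes the diagonal of T, which is exactly the effect of contracting against the delta-tensor.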

Let's assume that every index in our TN is represented by a delta-tensor and its corresponding tuple in the edges attribute (this may seem excessive, but we can deal with the overhead later). A slice dict can then refer to an element of the edges attribute by the index of the edge alone.

So the slice dict has the form {index into the edges attribute: index value}.

The algorithm would be: for each (index, value) pair of the slice dict,

  1. Pop the edge from the edges attribute using the index.
  2. Get all tensors that are indexed by this edge (in the example above, that would be T and B). Use Port.tensor_ix for that.
  3. Using the information in the ports of the edge, construct the local (per-tensor) slicing dict for each tensor.
  4. Slice and update each tensor.

Note that after slicing, the edge disappears.
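The steps above can be sketched as follows. This is a standalone sketch, not the repo's code: the Port fields and the edges layout (each edge a tuple of ports) are assumptions inferred from the discussion, with only the name Port.tensor_ix taken from step 2.

```python
from dataclasses import dataclass

import numpy as np

@dataclass(frozen=True)
class Port:
    tensor_ix: int  # which tensor in the TN (named in step 2)
    ix: int         # which axis of that tensor (assumed field name)

def slice_global(tensors, edges, slice_dict):
    """Slice along global TN indices: {edge index: index value}."""
    tensors = list(tensors)
    edges = dict(enumerate(edges))
    for edge_ix, value in slice_dict.items():
        ports = edges.pop(edge_ix)           # 1. pop the edge
        local = {}                           # 2-3. per-tensor slice dicts
        for port in ports:
            local.setdefault(port.tensor_ix, {})[port.ix] = value
        for t_ix, axes in local.items():     # 4. slice and update
            t = tensors[t_ix]
            bounds = tuple(axes.get(ax, slice(None)) for ax in range(t.ndim))
            tensors[t_ix] = t[bounds]
    return tensors, tuple(edges.values())    # sliced edges disappear
```

With three (2, 2) tensors and one edge joining axis 1 of the first two, slicing {0: 0} reduces only those two tensors to shape (2,) and leaves the third untouched, which is the global behavior requested above.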

    for tensor in tn._tensors:
        slice_bounds = []
        for idx in range(tensor.ndim):
            try:
                slice_bounds.append(slice_dict[idx])
            except KeyError:
                slice_bounds.append(slice(None))

        sliced_tns.append(tensor[tuple(slice_bounds)])

    tn._tensors = sliced_tns
    return tn

def copy(self):
@@ -161,4 +168,6 @@ def __repr__(self):

if __name__ == "__main__":
    tn = TensorNetwork.new_random_cpu(2, 3, 4)
    slice_dict = {0: slice(0, 2), 1: slice(1, 3)}
    sliced_tn = tn.slice(slice_dict)
    import pdb; pdb.set_trace()