-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sketch out conj method #46
base: dev
Are you sure you want to change the base?
Changes from 1 commit
eeba989
69d0341
0d74462
b0645a2
d7e24c2
1aafd05
3533289
ea2a43c
0c91a8f
dbeee3d
1d7271b
0026283
b48ea18
efbab98
1631afb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,9 +2,10 @@ | |
import math | ||
from dataclasses import dataclass | ||
from typing import TypeVar, Generic, Iterable | ||
from qtree import np_framework | ||
|
||
class Array(np.ndarray): | ||
shape: tuple[int] | ||
shape: tuple | ||
|
||
D = TypeVar('D') # tensor data type (numpy, torch, etc.) | ||
|
||
|
@@ -52,16 +53,21 @@ class Port: | |
class TensorNetwork(TensorNetworkIFC[np.ndarray]): | ||
tensors: Iterable[np.ndarray] | ||
shape: tuple | ||
edges: tuple[tuple[Port]] | ||
edges: tuple | ||
|
||
def __init__(self, *args, **kwargs): | ||
self._tensors = [] | ||
self._edges = tuple() | ||
self.shape = tuple() | ||
self.buckets = [] | ||
self.data_dict = {} | ||
|
||
# slice not inplace | ||
def slice(self, slice_dict: dict) -> 'TensorNetwork': | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @danlkv thanks for the previous comment, i believed i've addressed that in this update |
||
... | ||
tn = self.copy() | ||
sliced_buckets = np_framework.get_sliced_np_buckets(self.buckets, self.data_dict, slice_dict) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i imagine |
||
tn.buckets = sliced_buckets | ||
return tn | ||
|
||
def copy(self): | ||
new = TensorNetwork() | ||
|
@@ -153,4 +159,6 @@ def __repr__(self): | |
|
||
|
||
|
||
|
||
if __name__ == "__main__": | ||
tn = TensorNetwork.new_random_cpu(2, 3, 4) | ||
import pdb; pdb.set_trace() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed all instances of
tuple[...
because was getting error that types can't have[