diff --git a/scikit_tt/tensor_train.py b/scikit_tt/tensor_train.py index 88d75be..45e0fb0 100644 --- a/scikit_tt/tensor_train.py +++ b/scikit_tt/tensor_train.py @@ -1704,14 +1704,12 @@ def pinv(self, return p_inv - def diag(self, t: 'TT', diag_list: List) -> 'TT': + def diag(self, diag_list: List) -> 'TT': """ Construction of diagonal MPO from MPS. Parameters ---------- - t: TT - MPS decomposition diag_list: list cores to diagonalize @@ -1721,7 +1719,7 @@ def diag(self, t: 'TT', diag_list: List) -> 'TT': MPO decomposition """ - cores = t.cores.copy() + cores = self.cores.copy() for i in diag_list: r1, d, _, r2 = cores[i].shape