Skip to content

Commit

Permalink
fix in ortho routines
Browse files Browse the repository at this point in the history
  • Loading branch information
PGelss authored Jul 17, 2024
1 parent 11e3d42 commit a2c9d58
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions scikit_tt/tensor_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,15 +1141,16 @@ def ortho_left(self, start_index: int=0,

# check for correct max_rank argument and set max_ranks
max_rank_tf = True
if not isinstance(max_rank, list) and ((isinstance(max_rank, (int, np.int32, np.int64)) and max_rank > 0) or max_rank == np.infty):
if (isinstance(max_rank, (int, np.int32, np.int64)) and max_rank > 0) or max_rank == np.infty:
max_ranks = [1] + [max_rank for _ in range(self.order-1)] + [1]
else:
if len(max_rank) == self.order+1:
for i in range(self.order+1):
if not ((isinstance(max_rank[i], (int, np.int32, np.int64)) and max_rank[i] > 0) or max_rank[i] == np.infty):
max_rank_tf = False
if max_rank_tf:
max_ranks = max_rank
if isinstance(max_rank, list):
if len(max_rank) == self.order+1:
for i in range(self.order+1):
if not ((isinstance(max_rank[i], (int, np.int32, np.int64)) and max_rank[i] > 0) or max_rank[i] == np.infty):
max_rank_tf = False
if max_rank_tf:
max_ranks = max_rank

if max_rank_tf:

Expand Down Expand Up @@ -1193,7 +1194,7 @@ def ortho_left(self, start_index: int=0,
return self

else:
raise ValueError('Maximum rank must be a positive integer.')
raise ValueError('Maximum rank(s) must be positive integers.')

else:
raise ValueError('Threshold must be greater or equal 0.')
Expand Down Expand Up @@ -1244,15 +1245,16 @@ def ortho_right(self, start_index: Optional[int]=None,

# check for correct max_rank argument and set max_ranks
max_rank_tf = True
if not isinstance(max_rank, list) and ((isinstance(max_rank, (int, np.int32, np.int64)) and max_rank > 0) or max_rank == np.infty):
if (isinstance(max_rank, (int, np.int32, np.int64)) and max_rank > 0) or max_rank == np.infty:
max_ranks = [1] + [max_rank for _ in range(self.order-1)] + [1]
else:
if len(max_rank) == self.order+1:
for i in range(self.order+1):
if not ((isinstance(max_rank[i], (int, np.int32, np.int64)) and max_rank[i] > 0) or max_rank[i] == np.infty):
max_rank_tf = False
if max_rank_tf:
max_ranks = max_rank
if isinstance(max_rank, list):
if len(max_rank) == self.order+1:
for i in range(self.order+1):
if not ((isinstance(max_rank[i], (int, np.int32, np.int64)) and max_rank[i] > 0) or max_rank[i] == np.infty):
max_rank_tf = False
if max_rank_tf:
max_ranks = max_rank

if max_rank_tf:

Expand Down Expand Up @@ -1298,7 +1300,7 @@ def ortho_right(self, start_index: Optional[int]=None,
return self

else:
raise ValueError('Maximum rank must be a positive integer.')
raise ValueError('Maximum rank(s) must be positive integers.')

else:
raise ValueError('Threshold must be greater or equal 0.')
Expand Down

0 comments on commit a2c9d58

Please sign in to comment.