From bb6949f45a1253b8f05824aa1e602a96906557b8 Mon Sep 17 00:00:00 2001 From: Dan Lykov Date: Fri, 4 Oct 2024 16:47:13 +0000 Subject: [PATCH] fix bug when bucket contains 2+ scalar tensors only --- qtensor/compression/Compressor.py | 4 ++-- qtensor/contraction_backends/compression.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/qtensor/compression/Compressor.py b/qtensor/compression/Compressor.py index 26d42a94..9195a63d 100644 --- a/qtensor/compression/Compressor.py +++ b/qtensor/compression/Compressor.py @@ -150,7 +150,7 @@ def __init__(self, r2r_error=1e-3, r2r_threshold=1e-3): def free_decompressed(self): import cupy - print("Cleanup", len(self.decompressed_own)) + #print("Cleanup", len(self.decompressed_own)) for x in self.decompressed_own: del x cupy.get_default_memory_pool().free_all_blocks() @@ -192,7 +192,7 @@ def __init__(self, r2r_error=1e-3, r2r_threshold=1e-3): def free_decompressed(self): import cupy - print("Cleanup", len(self.decompressed_own)) + #print("Cleanup", len(self.decompressed_own)) for x in self.decompressed_own: #print(x) #if x == None: diff --git a/qtensor/contraction_backends/compression.py b/qtensor/contraction_backends/compression.py index 03059371..9a0789e4 100644 --- a/qtensor/contraction_backends/compression.py +++ b/qtensor/contraction_backends/compression.py @@ -55,6 +55,10 @@ def process_bucket(self, bucket, no_sum=False): ) if len(bucket)>1: t = bucket[-1] + if len(t.indices) == 0: + print(f"Scalar tensor {t}, {t.data}") + accum = accum * t + return accum total_ixs = sorted( set().union(*[t.indices, accum.indices]) , key=int, reverse=True