diff --git a/qtensor/compression/Compressor.py b/qtensor/compression/Compressor.py index 407e8410..39e2aef4 100644 --- a/qtensor/compression/Compressor.py +++ b/qtensor/compression/Compressor.py @@ -176,9 +176,8 @@ def free_compressed(self, ptr): #return import ctypes, cupy #cmp_bytes, num_elements_eff, shape, dtype, _ = ptr - cmp_t_real, cmp_t_imag, shape, dtype = ptr - del cmp_t_real - del cmp_t_imag + cmp_t, shape, dtype = ptr + del cmp_t torch.cuda.empty_cache() return print(f"Freeing compressed data {num_elements_eff}") @@ -194,31 +193,34 @@ def free_compressed(self, ptr): def compress(self, data): isCupy, num_elements_eff = _get_data_info(data) dtype = data.dtype + shape = data.shape # convert cupy to torch - data_imag = torch.as_tensor(data.imag, device='cuda').contiguous() - data_real = torch.as_tensor(data.real, device='cuda').contiguous() - print(f"cuszp Compressing {type(data)}") + # TODO: cast to one array of double the number of elements + torch_data = torch.tensor(data, device='cuda') + data_view = torch.view_as_real(torch_data) + #print(f"cuszp Compressing {type(data)}") #cmp_bytes, outSize_ptr = cuszp_device_compress(data, self.r2r_error, num_elements_eff, self.r2r_threshold) - cmp_t_real = cuszp.compress(data_real, self.r2r_error, 'rel') - cmp_t_imag = cuszp.compress(data_imag, self.r2r_error, 'rel') - return (cmp_t_real, cmp_t_imag, data.shape, dtype) + cmp_t = cuszp.compress(data_view, self.r2r_error, 'rel') + return (cmp_t, shape, dtype) # return (cmp_bytes, num_elements_eff, isCuPy, data.shape, dtype, outSize_ptr.contents.value) def compress_size(self, ptr): #return ptr[4] - return ptr[0].nbytes + ptr[1].nbytes + return ptr[0].nbytes def decompress(self, obj): import cupy #cmp_bytes, num_elements_eff, shape, dtype, cmpsize = obj #decompressed_ptr = cuszp_device_decompress(num_elements_eff, cmp_bytes, cmpsize, self, dtype) - cmp_t_real, cmp_t_imag, shape, dtype = obj + cmp_t, shape, dtype = obj num_elements_decompressed = 1 for s in shape: num_elements_decompressed *= s - decomp_t_real = cuszp.decompress(cmp_t_real, num_elements_decompressed, cmp_t_real.nbytes, self.r2r_error, 'rel') - decomp_t_imag = cuszp.decompress(cmp_t_imag, num_elements_decompressed, cmp_t_imag.nbytes, self.r2r_error, 'rel') - decomp_t = decomp_t_real + 1j * decomp_t_imag + # Number of elements is twice because the shape is for complex numbers + num_elements_decompressed *= 2 + decomp_t_float = cuszp.decompress(cmp_t, num_elements_decompressed, cmp_t.nbytes, self.r2r_error, 'rel') + decomp_t_float = decomp_t_float.view(decomp_t_float.shape[0]//2, 2) + decomp_t = torch.view_as_complex(decomp_t_float) arr_cp = cupy.asarray(decomp_t) arr = cupy.reshape(arr_cp, shape) return arr