diff --git a/src/Transformations.jl b/src/Transformations.jl index 779a7eef..f49c7fb8 100644 --- a/src/Transformations.jl +++ b/src/Transformations.jl @@ -179,10 +179,13 @@ function transform!(tn::TensorNetwork, config::Truncate) for (dim, index) in enumerate(inds(tensor)) index ∈ skip_inds && continue - zeroslices = iszero.(eachslice(tensor; dims=dim)) - any(zeroslices) || continue + # Use atol to determine small slices + small_slices = [maximum(abs.(s)) < config.atol for s in eachslice(tensor; dims=dim)] + any(small_slices) || continue - slice!(tn, index, count(!, zeroslices) == 1 ? findfirst(!, zeroslices) : findall(!, zeroslices)) + # Keep slices where the maximum absolute value is greater than or equal to atol + slices_to_keep = count(!, small_slices) == 1 ? findfirst(!, small_slices) : findall(!, small_slices) + slice!(tn, index, slices_to_keep) end end diff --git a/test/Transformations_test.jl b/test/Transformations_test.jl index 7e299ace..0864c418 100644 --- a/test/Transformations_test.jl +++ b/test/Transformations_test.jl @@ -232,33 +232,55 @@ using Tenet: Truncate @testset "range" begin + # Create tensor data with small values less than the default atol (1e-12) data = rand(3, 3, 3) - data[:, 1:2, :] .= 0 + data[:, 1:2, :] .= 1e-13 A = Tensor(data, (:i, :j, :k)) B = Tensor(rand(3, 3), (:j, :l)) C = Tensor(rand(3, 3), (:j, :m)) tn = TensorNetwork([A, B, C]) - reduced = transform(tn, Truncate) + # Apply Truncate transformation with default atol (1e-12) + reduced = transform(tn, Truncate()) + + # Test that index :j is removed because all its slices are below atol @test :j ∉ inds(reduced) @test contract(reduced) ≈ contract(tn) + + # Now, apply Truncate with a smaller atol (1e-14) so slices are not truncated + reduced_no_trunc = transform(tn, Truncate(; atol=1e-14)) + + # Test that index :j is still present + @test :j ∈ inds(reduced_no_trunc) + @test contract(reduced_no_trunc) ≈ contract(tn) end @testset "int" begin + # Create tensor data with one slice having small values less than default atol data = rand(3, 3, 3) - data[:, 2, :] .= 0 + data[:, 2, :] .= 1e-13 A = Tensor(data, (:i, :j, :k)) B = Tensor(rand(3, 3), (:j, :l)) C = Tensor(rand(3, 3), (:j, :m)) tn = TensorNetwork([A, B, C]) - reduced = transform(tn, Truncate) + # Apply Truncate transformation with default atol (1e-12) + reduced = transform(tn, Truncate()) + + # Test that size of index :j is reduced by 1 @test size(reduced, :j) == 2 @test contract(reduced) ≈ contract(tn) + + # Now, apply Truncate with a smaller atol (1e-14) so the slice is not truncated + reduced_no_trunc = transform(tn, Truncate(; atol=1e-14)) + + # Test that size of index :j remains the same + @test size(reduced_no_trunc, :j) == 3 + @test contract(reduced_no_trunc) ≈ contract(tn) end end