Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Truncate transformation to correclty use atol argument #250

Merged
merged 3 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/Transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,13 @@ function transform!(tn::TensorNetwork, config::Truncate)
for (dim, index) in enumerate(inds(tensor))
index ∈ skip_inds && continue

zeroslices = iszero.(eachslice(tensor; dims=dim))
any(zeroslices) || continue
# Use atol to determine small slices
small_slices = [maximum(abs.(s)) < config.atol for s in eachslice(tensor; dims=dim)]
any(small_slices) || continue

slice!(tn, index, count(!, zeroslices) == 1 ? findfirst(!, zeroslices) : findall(!, zeroslices))
# Keep slices where the maximum absolute value is greater than or equal to atol
slices_to_keep = count(!, small_slices) == 1 ? findfirst(!, small_slices) : findall(!, small_slices)
slice!(tn, index, slices_to_keep)
end
end

Expand Down
30 changes: 26 additions & 4 deletions test/Transformations_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,33 +232,55 @@
using Tenet: Truncate

@testset "range" begin
# Create tensor data with small values less than the default atol (1e-12)
data = rand(3, 3, 3)
data[:, 1:2, :] .= 0
data[:, 1:2, :] .= 1e-13

A = Tensor(data, (:i, :j, :k))
B = Tensor(rand(3, 3), (:j, :l))
C = Tensor(rand(3, 3), (:j, :m))

tn = TensorNetwork([A, B, C])
reduced = transform(tn, Truncate)

# Apply Truncate transformation with default atol (1e-12)
reduced = transform(tn, Truncate())

# Test that index :j is removed because all its slices are below atol
@test :j ∉ inds(reduced)
@test contract(reduced) ≈ contract(tn)

# Now, apply Truncate with a smaller atol (1e-14) so slices are not truncated
reduced_no_trunc = transform(tn, Truncate(; atol=1e-14))

# Test that index :j is still present
@test :j ∈ inds(reduced_no_trunc)
@test contract(reduced_no_trunc) ≈ contract(tn)
end

@testset "int" begin
# Create tensor data with one slice having small values less than default atol
data = rand(3, 3, 3)
data[:, 2, :] .= 0
data[:, 2, :] .= 1e-13

A = Tensor(data, (:i, :j, :k))
B = Tensor(rand(3, 3), (:j, :l))
C = Tensor(rand(3, 3), (:j, :m))

tn = TensorNetwork([A, B, C])
reduced = transform(tn, Truncate)

# Apply Truncate transformation with default atol (1e-12)
reduced = transform(tn, Truncate())

# Test that size of index :j is reduced by 1
@test size(reduced, :j) == 2
@test contract(reduced) ≈ contract(tn)

# Now, apply Truncate with a smaller atol (1e-14) so the slice is not truncated
reduced_no_trunc = transform(tn, Truncate(; atol=1e-14))

# Test that size of index :j remains the same
@test size(reduced_no_trunc, :j) == 3
@test contract(reduced_no_trunc) ≈ contract(tn)
end
end

Expand Down
Loading