Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance simple_update! for MPS in the Canonical form #255

Merged
merged 26 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
24674da
First round of fixes on simple_update
jofrevalles Nov 19, 2024
5f158d2
Fix tests
jofrevalles Nov 20, 2024
99e6533
Add simple_update_2site! for MixedCanonical form
jofrevalles Nov 20, 2024
07a2af9
Format code
jofrevalles Nov 20, 2024
4c28ab9
Renormalize mps in truncation when recanonize kwarg is true
jofrevalles Nov 20, 2024
7612188
Enhance tests
jofrevalles Nov 20, 2024
967d711
Change default recanonize kwarg to false in truncate! function
jofrevalles Nov 20, 2024
cfd2a0f
Refactor normalize functions
jofrevalles Nov 21, 2024
4e6835d
Enhance normalize tests
jofrevalles Nov 21, 2024
b4d678e
Define LinearAlgebra.normalize for AbstractQuantum
jofrevalles Nov 21, 2024
64c77ef
Fix normalize functions for MPS
jofrevalles Nov 21, 2024
52b30c4
Enhance tests
jofrevalles Nov 21, 2024
4db9ccb
Fix normalize for Canonical MPS
jofrevalles Nov 21, 2024
bde48f7
Format code
jofrevalles Nov 21, 2024
c9cf7d7
Update normalization step on evolve
jofrevalles Nov 21, 2024
7232631
Change normalization to all lambdas for Canonical form
jofrevalles Nov 21, 2024
58f5e32
Format code
jofrevalles Nov 21, 2024
9fc52bd
Fix truncate by adding renormalize kwarg
jofrevalles Nov 21, 2024
6b76a71
Small enhancements on normalize! functions
jofrevalles Nov 21, 2024
2948ffd
Enhance tests
jofrevalles Nov 21, 2024
3ffd4e0
Change default kwargs in truncate
jofrevalles Nov 21, 2024
f844ac0
Fix evolve kwargs
jofrevalles Nov 21, 2024
a0b0267
Fix normalize! by putting replace! instead of inplace modification fo…
jofrevalles Nov 21, 2024
5f97d6b
Enhance tests
jofrevalles Nov 21, 2024
03e886b
Fix aesthetic suggestions, improve kwarg definition
jofrevalles Nov 22, 2024
2c46309
Update comment
jofrevalles Nov 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 43 additions & 14 deletions src/Ansatz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,10 @@ end
Truncate the dimension of the virtual `bond` of a [`Canonical`](@ref) Tensor Network by keeping the `maxdim` largest
**Schmidt coefficients** or those larger than `threshold`, and then recanonizes the Tensor Network if `recanonize` is `true`.
"""
function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, recanonize=true)
function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, recanonize=false)
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false)

recanonize && canonize!(tn)
recanonize && canonize!(tn; normalize=true)
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved

return tn
end
Expand Down Expand Up @@ -387,11 +387,11 @@ function simple_update!(ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=noth

if nlanes(gate) == 1
return simple_update_1site!(ψ, gate)
elseif nlanes(gate) == 2
return simple_update_2site!(form(ψ), ψ, gate; threshold, maxdim, kwargs...)
else
throw(ArgumentError("Only 1-site and 2-site gates are currently supported"))
end

@assert has_edge(ψ, lanes(gate)...) "Gate must act on neighboring sites"

return simple_update!(form(ψ), ψ, gate; threshold, maxdim, kwargs...)
end

# TODO a lot of problems with merging... maybe we shouldn't merge manually
Expand Down Expand Up @@ -419,9 +419,16 @@ function simple_update_1site!(ψ::AbstractAnsatz, gate)
return contract!(ψ, contracting_index)
end

function simple_update_2site!(
::MixedCanonical, ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, renormalize=false
)
return simple_update_2site!(NonCanonical(), ψ, gate; threshold, maxdim, renormalize)
end

# TODO remove `renormalize` argument?
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
function simple_update!(::NonCanonical, ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, renormalize=false)
@assert nlanes(gate) == 2 "Only 2-site gates are supported currently"
function simple_update_2site!(
::NonCanonical, ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, renormalize=false
)
@assert has_edge(ψ, lanes(gate)...) "Gate must act on neighboring sites"

# shallow copy to avoid problems if errors in mid execution
Expand Down Expand Up @@ -455,16 +462,38 @@ function simple_update!(::NonCanonical, ψ::AbstractAnsatz, gate; threshold=noth

# truncate virtual index
if any(!isnothing, (threshold, maxdim))
truncate!(ψ, bond; threshold, maxdim)
renormalize && normalize!(ψ, bond[1])
truncate!(ψ, collect(bond); threshold, maxdim)
renormalize && normalize!(ψ, bond)
end

return ψ
end

# TODO remove `renormalize` argument?
# TODO optimize correctly -> avoid recanonization + use lateral Λs
function simple_update!(::Canonical, ψ::AbstractAnsatz, gate; threshold, maxdim, renormalize=false)
simple_update!(NonCanonical(), ψ, gate; threshold, maxdim, renormalize)
return canonize!(ψ)
function simple_update_2site!(::Canonical, ψ::AbstractAnsatz, gate; threshold, maxdim, renormalize=false)
# Contract the exterior Λ tensors
sitel, siter = extrema(lanes(gate))
(0 < id(sitel) < nsites(ψ) || 0 < id(siter) < nsites(ψ)) ||
throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))"))

Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel))
Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1)))

!isnothing(Λᵢ₋₁) && contract!(ψ; between=(Site(id(sitel) - 1), sitel), direction=:right, delete_Λ=false)
!isnothing(Λᵢ₊₁) && contract!(ψ; between=(siter, Site(id(siter) + 1)), direction=:left, delete_Λ=false)

simple_update_2site!(NonCanonical(), ψ, gate; threshold, maxdim, renormalize)

# contract the updated tensors with the inverse of Λᵢ and Λᵢ₊₂, to get the new Γ tensors
U, Vt = tensors(ψ; at=sitel), tensors(ψ; at=siter)
Γᵢ₋₁ =
isnothing(Λᵢ₋₁) ? U : contract(U, Tensor(diag(pinv(Diagonal(parent(Λᵢ₋₁)); atol=1e-32)), inds(Λᵢ₋₁)); dims=())
Γᵢ =
isnothing(Λᵢ₊₁) ? Vt : contract(Tensor(diag(pinv(Diagonal(parent(Λᵢ₊₁)); atol=1e-32)), inds(Λᵢ₊₁)), Vt; dims=())

# Update the tensors in the tensor network
replace!(ψ, tensors(ψ; at=sitel) => Γᵢ₋₁)
replace!(ψ, tensors(ψ; at=siter) => Γᵢ)

return ψ
end
24 changes: 20 additions & 4 deletions src/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ function canonize_site!(ψ::MPS, site::Site; direction::Symbol, method=:qr)
return ψ
end

function canonize!(ψ::AbstractMPO)
function canonize!(ψ::AbstractMPO; normalize=false)
Λ = Tensor[]

# right-to-left QR sweep, get right-canonical tensors
Expand All @@ -495,6 +495,7 @@ function canonize!(ψ::AbstractMPO)

# extract the singular values and contract them with the next tensor
Λᵢ = pop!(ψ, tensors(ψ; between=(Site(i), Site(i + 1))))
normalize && (Λᵢ ./= norm(Λᵢ))
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
Aᵢ₊₁ = tensors(ψ; at=Site(i + 1))
replace!(ψ, Aᵢ₊₁ => contract(Aᵢ₊₁, Λᵢ; dims=()))
push!(Λ, Λᵢ)
Expand Down Expand Up @@ -540,20 +541,35 @@ function mixed_canonize!(tn::AbstractMPO, orthog_center)
return tn
end

LinearAlgebra.normalize(ψ::AbstractMPO, site) = normalize!(copy(ψ), site)
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved

LinearAlgebra.normalize!(ψ::AbstractMPO; kwargs...) = normalize!(form(ψ), ψ; kwargs...)
LinearAlgebra.normalize!(ψ::AbstractMPO, site) = normalize!(form(ψ), ψ; at=site)
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved

function LinearAlgebra.normalize!(::NonCanonical, ψ::AbstractMPO; at=Site(nsites(ψ) ÷ 2))
tensor = tensors(ψ; at)
tensor ./= norm(ψ)
return ψ
end

LinearAlgebra.normalize!(ψ::AbstractMPO, site::Site) = normalize!(mixed_canonize!(ψ, site); at=site)

function LinearAlgebra.normalize!(config::MixedCanonical, ψ::AbstractMPO; at=config.orthog_center)
mixed_canonize!(ψ, at)
normalize!(tensors(ψ; at), 2)
return ψ
end

# TODO function LinearAlgebra.normalize!(::Canonical, ψ::AbstractMPO) end
function LinearAlgebra.normalize!(config::Canonical, ψ::AbstractMPO; at=nothing)
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
if isnothing(at) # Normalize all λ tensors
normalizer = (norm(ψ))^(1 / (nsites(ψ) - 1))
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved

for i in 1:(nsites(ψ) - 1)
λ = tensors(ψ; between=(Site(i), Site(i + 1)))
replace!(ψ, λ => λ ./ normalizer)
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
end
else
λ = tensors(ψ; between=at)
replace!(ψ, λ => λ ./ norm(ψ))
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
end

return ψ
end
2 changes: 2 additions & 0 deletions src/Quantum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,8 @@ function Base.merge!(a::AbstractQuantum, b::AbstractQuantum; reset=true)
return a
end

LinearAlgebra.normalize(ψ::AbstractQuantum; kwargs...) = normalize!(copy(ψ); kwargs...)

function LinearAlgebra.norm(ψ::AbstractQuantum, p::Real=2; kwargs...)
p == 2 || throw(ArgumentError("only L2-norm is implemented yet"))
return LinearAlgebra.norm2(ψ; kwargs...)
Expand Down
55 changes: 46 additions & 9 deletions test/MPS_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,9 @@ using LinearAlgebra
ψ = rand(MPS; n=5, maxdim=16)
canonize!(ψ)

truncated = truncate(ψ, [site"2", site"3"]; maxdim=2)
truncated = truncate(ψ, [site"2", site"3"]; maxdim=2, recanonize=true)
@test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 2
@test Tenet.check_form(truncated)
end

@testset "MixedCanonical" begin
Expand Down Expand Up @@ -144,11 +145,42 @@ using LinearAlgebra
end

@testset "normalize!" begin
using LinearAlgebra: normalize!
using LinearAlgebra: normalize, normalize!

ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)])
normalize!(ψ, Site(3))
@test isapprox(norm(ψ), 1.0)
@testset "NonCanonical" begin
ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)])

normalized = normalize(ψ)
@test norm(normalized) ≈ 1.0

normalize!(ψ, Site(3))
@test norm(ψ) ≈ 1.0
end

@testset "MixedCanonical" begin
ψ = rand(MPS; n=5, maxdim=16)

# Perturb the state to make it non-normalized
t = tensors(ψ; at=site"3")
replace!(ψ, t => Tensor(rand(size(t)...), inds(t)))

normalized = normalize(ψ)
@test norm(normalized) ≈ 1.0

normalize!(ψ, Site(3))
@test norm(ψ) ≈ 1.0
end

@testset "Canonical" begin
ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)])
canonize!(ψ)

normalized = normalize(ψ)
@test norm(normalized) ≈ 1.0

normalize!(ψ, (Site(3), Site(4)))
@test norm(ψ) ≈ 1.0
end
end

@testset "canonize_site!" begin
Expand Down Expand Up @@ -306,11 +338,16 @@ using LinearAlgebra
end

@testset "Canonical" begin
ψ = deepcopy(ψ)
ψ = rand(MPS; n=5, maxdim=20)
ϕ = deepcopy(ψ)
canonize!(ψ)
evolved = evolve!(deepcopy(ψ), gate; threshold=1e-14)
@test isapprox(contract(evolved), contract(ψ))
@test issetequal(size.(tensors(evolved)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)])
evolved = evolve!(deepcopy(ψ), gate)

@test Tenet.check_form(evolved)
@test isapprox(contract(evolved), contract(ϕ)) # Identity gate should not change the state

# Ensure that the original MixedCanonical state evolves into the same state as the canonicalized one
@test contract(evolve!(ϕ, gate; threshold=1e-14)) ≈ contract(ψ)
end
end
end
Expand Down
Loading