Skip to content

Commit

Permalink
Add evolve! for evolution of an MPS with an MPO (#264)
Browse files Browse the repository at this point in the history
* Add evolve! function for an MPS with an MPO

* Add tests for MPS-MPO evolution

* Remove unnecessary @show

* Add comment

* Fix kwarg handling on truncate! function

* Fix normalize! for Canonical form and small fixes on evolve! with MPO

* Extend evolve!(mps, mpo) tests

* Format code

* Remove unnecessary kwarg

* Refactor code so it is easier to extend for other canonical forms

* Fix comment

* Enhance tests

* Add reset_index kwarg

* Update tests

* Format code

* Remove unnecessary Quantum functions, add docstring for truncate_sweep!

* Format code

* Remove stale function
  • Loading branch information
jofrevalles committed Dec 2, 2024
1 parent 1cad26e commit 1bde3fc
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 14 deletions.
10 changes: 5 additions & 5 deletions src/Ansatz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,21 +325,21 @@ function truncate!(::NonCanonical, tn::AbstractAnsatz, bond; threshold, maxdim,
return tn
end

function truncate!(::MixedCanonical, tn::AbstractAnsatz, bond; threshold, maxdim, normalize=false)
function truncate!(::MixedCanonical, tn::AbstractAnsatz, bond; kwargs...)
# move orthogonality center to bond
mixed_canonize!(tn, bond)

return truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=true, normalize)
return truncate!(NonCanonical(), tn, bond; compute_local_svd=true, kwargs...)
end

"""
truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, canonize=true)
truncate!(::Canonical, tn::AbstractAnsatz, bond; canonize=true, kwargs...)
Truncate the dimension of the virtual `bond` of a [`Canonical`](@ref) Tensor Network by keeping the `maxdim` largest
**Schmidt coefficients** or those larger than `threshold`, and then canonizes the Tensor Network if `canonize` is `true`.
"""
function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, canonize=false, normalize=false)
truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false, normalize)
function truncate!(::Canonical, tn::AbstractAnsatz, bond; canonize=true, kwargs...)
truncate!(NonCanonical(), tn, bond; compute_local_svd=false, kwargs...)

canonize && canonize!(tn)

Expand Down
147 changes: 138 additions & 9 deletions src/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ end
Check if the tensors in the mps are in the proper [`Form`](@ref).
"""
check_form(mps::AbstractMPO) = check_form(form(mps), mps)
check_form(mps::AbstractMPO; kwargs...) = check_form(form(mps), mps; kwargs...)

function check_form(config::MixedCanonical, mps::AbstractMPO)
function check_form(config::MixedCanonical, mps::AbstractMPO; atol=1e-12)
orthog_center = config.orthog_center

left, right = if orthog_center isa Site
Expand All @@ -144,23 +144,24 @@ function check_form(config::MixedCanonical, mps::AbstractMPO)

for i in 1:nsites(mps)
if i < left # Check left-canonical tensors
isisometry(mps, Site(i); dir=:right) || throw(ArgumentError("Tensors are not left-canonical"))
isisometry(mps, Site(i); dir=:right, atol) || throw(ArgumentError("Tensors are not left-canonical"))
elseif i > right # Check right-canonical tensors
isisometry(mps, Site(i); dir=:left) || throw(ArgumentError("Tensors are not right-canonical"))
isisometry(mps, Site(i); dir=:left, atol) || throw(ArgumentError("Tensors are not right-canonical"))
end
end

return true
end

function check_form(::Canonical, mps::AbstractMPO)
function check_form(::Canonical, mps::AbstractMPO; atol=1e-12)
for i in 1:nsites(mps)
if i > 1 && !isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right)
if i > 1 &&
!isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right, atol)
throw(ArgumentError("Can not form a left-canonical tensor in Site($i) from Γ and λ contraction."))
end

if i < nsites(mps) &&
!isisometry(contract(mps; between=(Site(i), Site(i + 1)), direction=:left), Site(i); dir=:left)
!isisometry(contract(mps; between=(Site(i), Site(i + 1)), direction=:left), Site(i); dir=:left, atol)
throw(ArgumentError("Can not form a right-canonical tensor in Site($i) from Γ and λ contraction."))
end
end
Expand Down Expand Up @@ -541,6 +542,133 @@ function mixed_canonize!(tn::AbstractMPO, orthog_center)
return tn
end

"""
evolve!(ψ::AbstractAnsatz, mpo::AbstractMPO; threshold=nothing, maxdim=nothing, normalize=true, reset_index=true)
Evolve the [`AbstractAnsatz`](@ref) `ψ` with the [`AbstractMPO`](@ref) `mpo` along the output indices of `ψ`.
If `threshold` or `maxdim` are not `nothing`, the tensors are truncated after each sweep at the proper value, and the
bond is normalized if `normalize=true`. If `reset_index=true`, the indices of the `ψ` are reset to the original ones.
"""
function evolve!(
ψ::AbstractAnsatz, mpo::AbstractMPO; threshold=nothing, maxdim=nothing, normalize=true, reset_index=true
)
original_sites = copy(Quantum(ψ).sites)
evolve!(form(ψ), ψ, mpo; threshold, maxdim, normalize)

if reset_index
resetindex!(ψ; init=ninds(TensorNetwork(ψ)) + 1)

replacements = [inds(ψ; at=site) => original_sites[site] for site in keys(original_sites)]
replace!(ψ, replacements)
end

return ψ
end

function evolve!(::NonCanonical, ψ::AbstractAnsatz, mpo::AbstractMPO; threshold, maxdim, normalize, kwargs...)
L = nsites(ψ)
Tenet.@reindex! outputs(ψ) => inputs(mpo)

right_inds = [inds(ψ; at=Site(i), dir=:right) for i in 1:(L - 1)]

for i in 1:L
contract_ind = inds(ψ; at=Site(i))
push!(ψ, tensors(mpo; at=Site(i)))
contract!(ψ, contract_ind)
merge!(Quantum(ψ).sites, Dict(Site(i) => inds(mpo; at=Site(i))))
end

# Group the parallel bond indices
for i in 1:(L - 1)
groupinds!(ψ, right_inds[i])
end

if !isnothing(threshold) || !isnothing(maxdim)
truncate_sweep!(form(ψ), ψ; threshold, maxdim, normalize)
else
normalize && normalize!(ψ)
end

return ψ
end

function evolve!(::MixedCanonical, ψ::AbstractAnsatz, mpo::AbstractMPO; normalize, kwargs...)
initial_form = form(ψ)
mixed_canonize!(ψ, Site(nsites(ψ))) # We convert all the tensors to left-canonical form

evolve!(NonCanonical(), ψ, mpo; normalize, kwargs...)

mixed_canonize!(ψ, initial_form.orthog_center)

return ψ
end

function evolve!(::Canonical, ψ::AbstractAnsatz, mpo::AbstractMPO; threshold, maxdim, normalize, kwargs...)
# We first join the λs to the Γs to get MixedCanonical(Site(1)) form
for i in 1:(nsites(ψ) - 1)
contract!(ψ; between=(Site(i), Site(i + 1)), direction=:right)
end

evolve!(NonCanonical(), ψ, mpo; threshold=nothing, maxdim=nothing, normalize=false, kwargs...) # set maxdim and threshold to nothing so we truncate from Canonical form

if !isnothing(threshold) || !isnothing(maxdim)
truncate_sweep!(Canonical(), ψ; threshold, maxdim, normalize)
else
normalize && canonize!(ψ; normalize)
end

return ψ
end

"""
truncate_sweep!
Do a right-to-left QR sweep on the [`AbstractMPO`](@ref) `ψ` and then left-to-right SVD sweep and truncate the tensors
according to the `threshold` or `maxdim` values. The bond is normalized if `normalize=true`.
"""
function truncate_sweep! end

function truncate_sweep!(::NonCanonical, ψ::AbstractMPO; threshold, maxdim, normalize)
for i in nsites(ψ):-1:2
canonize_site!(ψ, Site(i); direction=:left, method=:qr)
end

# left-to-right SVD sweep, get left-canonical tensors and singular values and truncate
for i in 1:(nsites(ψ) - 1)
canonize_site!(ψ, Site(i); direction=:right, method=:svd)

(!isnothing(threshold) || !isnothing(maxdim)) &&
truncate!(ψ, [Site(i), Site(i + 1)]; threshold, maxdim, normalize, compute_local_svd=false)

contract!(ψ; between=(Site(i), Site(i + 1)), direction=:right)
end

ψ.form = MixedCanonical(Site(nsites(ψ)))

return ψ
end

function truncate_sweep!(::MixedCanonical, ψ::AbstractMPO; threshold, maxdim, normalize)
truncate_sweep!(NonCanonical(), ψ; threshold, maxdim, normalize)
end

function truncate_sweep!(::Canonical, ψ::AbstractMPO; threshold, maxdim, normalize)
for i in nsites(ψ):-1:2
canonize_site!(ψ, Site(i); direction=:left, method=:qr)
end

# left-to-right SVD sweep, get left-canonical tensors and singular values and truncate
for i in 1:(nsites(ψ) - 1)
canonize_site!(ψ, Site(i); direction=:right, method=:svd)
(!isnothing(threshold) || !isnothing(maxdim)) &&
truncate!(ψ, [Site(i), Site(i + 1)]; threshold, maxdim, normalize, compute_local_svd=false)
end

canonize!(ψ)

return ψ
end

LinearAlgebra.normalize!::AbstractMPO; kwargs...) = normalize!(form(ψ), ψ; kwargs...)
LinearAlgebra.normalize!::AbstractMPO, at::Site) = normalize!(form(ψ), ψ; at)
LinearAlgebra.normalize!::AbstractMPO, bond::Base.AbstractVecOrTuple{Site}) = normalize!(form(ψ), ψ; bond)
Expand All @@ -564,14 +692,15 @@ function LinearAlgebra.normalize!(config::MixedCanonical, ψ::AbstractMPO; at=co
end

function LinearAlgebra.normalize!(config::Canonical, ψ::AbstractMPO; bond=nothing)
old_norm = norm(ψ)
if isnothing(bond) # Normalize all λ tensors
for i in 1:(nsites(ψ) - 1)
λ = tensors(ψ; between=(Site(i), Site(i + 1)))
replace!(ψ, λ => λ ./ norm(λ)^(1 / (nsites(ψ) - 1)))
replace!(ψ, λ => λ ./ old_norm^(1 / (nsites(ψ) - 1)))
end
else
λ = tensors(ψ; between=bond)
replace!(ψ, λ => λ ./ norm(λ))
replace!(ψ, λ => λ ./ old_norm)
end

return ψ
Expand Down
55 changes: 55 additions & 0 deletions test/MPS_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,61 @@ using LinearAlgebra
@test_throws ArgumentError Tenet.check_form(evolved)
end
end

@testset "MPO evolution" begin
ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)])
normalize!(ψ)
mpo = rand(MPO; n=5, maxdim=8)

ϕ_1 = deepcopy(ψ)
ϕ_2 = deepcopy(ψ)
ϕ_3 = deepcopy(ψ)

@testset "NonCanonical" begin
evolve!(ϕ_1, mpo)
@test length(tensors(ϕ_1)) == 5
@test norm(ϕ_1) 1.0

evolved = evolve!(deepcopy(ψ), mpo; maxdim=3)
@test all(x -> x 3, vcat([collect(t) for t in vec(size.(tensors(evolved)))]...))
@test norm(evolved) 1.0
end

@testset "Canonical" begin
canonize!(ϕ_2)
evolve!(ϕ_2, mpo)
@test length(tensors(ϕ_2)) == 5 + 4
@test form(ϕ_2) == Canonical()
@test Tenet.check_form(ϕ_2)

evolved = evolve!(deepcopy(canonize!(ψ)), mpo; maxdim=3)
@test all(x -> x 3, vcat([collect(t) for t in vec(size.(tensors(evolved)))]...))
@test form(evolved) == Canonical()
@test Tenet.check_form(evolved)
end

@testset "MixedCanonical" begin
mixed_canonize!(ϕ_3, site"3")
evolve!(ϕ_3, mpo)
@test length(tensors(ϕ_3)) == 5
@test form(ϕ_3) == MixedCanonical(Site(3))
@test norm(ϕ_3) 1.0
@test Tenet.check_form(ϕ_3)

evolved = evolve!(deepcopy(mixed_canonize!(ψ, site"3")), mpo; maxdim=3)
@test all(x -> x 3, vcat([collect(t) for t in vec(size.(tensors(evolved)))]...))
@test form(evolved) == MixedCanonical(Site(3))
@test norm(evolved) 1.0
@test Tenet.check_form(evolved)
end

t1 = contract(ϕ_1)
t2 = contract(ϕ_2)
t3 = contract(ϕ_3)

@test t1 t2 t3
@test only(overlap(ϕ_1, ϕ_2)) only(overlap(ϕ_1, ϕ_3)) only(overlap(ϕ_2, ϕ_3)) 1.0
end
end

# TODO rename when method is renamed
Expand Down

0 comments on commit 1bde3fc

Please sign in to comment.