From 278e757a8db16eb0ee80dc6a8076820f4c10dd54 Mon Sep 17 00:00:00 2001 From: jofrevalles Date: Wed, 13 Nov 2024 15:07:56 +0100 Subject: [PATCH 1/9] Add Form-specific constructors in MPS --- src/MPS.jl | 138 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 136 insertions(+), 2 deletions(-) diff --git a/src/MPS.jl b/src/MPS.jl index 2a6b0799..935ebb4c 100644 --- a/src/MPS.jl +++ b/src/MPS.jl @@ -38,6 +38,9 @@ Base.zero(x::T) where {T<:Union{MPS,MPO}} = T(zero(Ansatz(x)), form(x)) defaultorder(::Type{<:AbstractMPS}) = (:o, :l, :r) defaultorder(::Type{<:AbstractMPO}) = (:o, :i, :l, :r) +MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS), form::Form=NonCanonical(), check_canonical_form = true) = MPS(form, arrays; order=order, check_canonical_form=check_canonical_form) +MPS(arrays::Vector{<:AbstractArray}, λ::Vector{<:AbstractArray}; order=defaultorder(MPS), form::Form=Canonical(), check_canonical_form = true) = MPS(form, arrays, λ; order=order, check_canonical_form=check_canonical_form) + """ MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS)) @@ -47,7 +50,7 @@ Create a [`NonCanonical`](@ref) [`MPS`](@ref) from a vector of arrays. - `order` The order of the indices in the arrays. Defaults to `(:o, :l, :r)`. """ -function MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS)) +function MPS(::NonCanonical, arrays; order=defaultorder(MPS), check_canonical_form = true) @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" @assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions" @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" @@ -92,6 +95,136 @@ function MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS)) return MPS(ansatz, NonCanonical()) end +function MPS(form::MixedCanonical, arrays; order=defaultorder(MPS), check_canonical_form = true) + @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" + @assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions" + @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" + issetequal(order, defaultorder(MPS)) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(MPS)))")) + + n = length(arrays) + gen = IndexCounter() + symbols = [nextindex!(gen) for _ in 1:(2n)] + + tn = TensorNetwork( + map(enumerate(arrays)) do (i, array) + _order = if i == 1 + filter(x -> x != :l, order) + elseif i == n + filter(x -> x != :r, order) + else + order + end + + inds = map(_order) do dir + if dir == :o + symbols[i] + elseif dir == :r + symbols[n + mod1(i, n)] + elseif dir == :l + symbols[n + mod1(i - 1, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + Tensor(array, inds) + end, + ) + + sitemap = Dict(Site(i) => symbols[i] for i in 1:n) + qtn = Quantum(tn, sitemap) + graph = path_graph(n) + mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for (i, site) in enumerate(lanes(qtn))]) + lattice = Lattice(mapping, graph) + ansatz = Ansatz(qtn, lattice) + mps = MPS(ansatz, form) + + # Check that for site start to orthog_center-1 the tensors are left-canonical + if check_canonical_form + for i in 1:id(form.orthog_center) - 1 + isisometry(mps, Site(i); dir = :right) || throw(ArgumentError("Tensors are not left-canonical")) + end + + # Check that for site orthog_center+1 to end the tensors are right-canonical + for i in id(form.orthog_center) + 1:nsites(mps) + isisometry(mps, Site(i); dir = :left) || throw(ArgumentError("Tensors are not right-canonical")) + end + end + + return mps +end + +function MPS(::Canonical, arrays, λ; order=defaultorder(MPS), check_canonical_form = true) + @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" + @assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions" + @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" + + @assert length(λ) == length(arrays) - 1 "Number of λ tensors must be one less than the number of arrays" + @assert all(==(1) ∘ ndims, λ) "All λ tensors must be Vectors" + + issetequal(order, defaultorder(MPS)) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(MPS)))")) + + n = length(arrays) + gen = IndexCounter() + symbols = [nextindex!(gen) for _ in 1:(2n)] + + # Create tensors from 'arrays' + tensor_list = map(enumerate(arrays)) do (i, array) + _order = if i == 1 + filter(x -> x != :l, order) + elseif i == n + filter(x -> x != :r, order) + else + order + end + + inds = map(_order) do dir + if dir == :o + symbols[i] + elseif dir == :r + symbols[n + mod1(i, n)] + elseif dir == :l + symbols[n + mod1(i - 1, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + Tensor(array, inds) + end + + # Create tensors from 'λ' + lambda_tensors = map(enumerate(λ)) do (i, array) + Tensor(array, [symbols[n + mod1(i, n)]]) + end + + tn = TensorNetwork(vcat(tensor_list, lambda_tensors)) + sitemap = Dict(Site(i) => symbols[i] for i in 1:n) + qtn = Quantum(tn, sitemap) + graph = path_graph(n) + mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for (i, site) in enumerate(lanes(qtn))]) + lattice = Lattice(mapping, graph) + ansatz = Ansatz(qtn, lattice) + mps = MPS(ansatz, Canonical()) + + # Check canonical form by contracting Γ and λ tensors and checking their orthogonality + if check_canonical_form + for i in 1:nsites(mps) + if i > 1 + isisometry(contract(mps; between=(Site(i-1), Site(i)), direction=:right), Site(i); dir=:right) || + throw(ArgumentError("Can not form a left-canonical tensor in Site($i) from Γ and λ contraction.")) + end + + if i < nsites(mps) + isisometry(contract(mps; between=(Site(i), Site(i + 1)), direction=:left), Site(i); dir=:left) || + throw(ArgumentError("Can not form a right-canonical tensor in Site($i) from Γ and λ contraction.")) + end + end + end + + return mps +end + """ MPO(arrays::Vector{<:AbstractArray}; order=defaultorder(MPO)) @@ -238,7 +371,8 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{MPS}; n, maxdim, eltype=Float arrays[1] = reshape(arrays[1], p, p) arrays[n] = reshape(arrays[n], p, p) - return MPS(arrays; order=(:l, :o, :r)) + return MPS(arrays; order=(:l, :o, :r), form=MixedCanonical(Site(0))) + # return MPS(arrays; order=(:l, :o, :r)) end # TODO different input/output physical dims From 731a57047bec589bc1db63b6278950d1067b8f34 Mon Sep 17 00:00:00 2001 From: jofrevalles Date: Wed, 13 Nov 2024 15:08:17 +0100 Subject: [PATCH 2/9] Enhance description of MixedCanonical Form --- src/Ansatz.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Ansatz.jl b/src/Ansatz.jl index 91354a1a..6c7a1ae7 100644 --- a/src/Ansatz.jl +++ b/src/Ansatz.jl @@ -44,6 +44,9 @@ struct NonCanonical <: Form end MixedCanonical [`Form`](@ref) trait representing a [`AbstractAnsatz`](@ref) Tensor Network in mixed-canonical form. + + - The orthogonality center is a [`Site`](@ref) or a vector of [`Site`](@ref)s. The tensors to the +left of the orthogonality center are left-canonical and the tensors to the right are right-canonical. """ struct MixedCanonical <: Form orthog_center::Union{Site,Vector{Site}} From 4300ba0fa365404538a495925c492146e2c258ac Mon Sep 17 00:00:00 2001 From: jofrevalles Date: Wed, 13 Nov 2024 15:29:22 +0100 Subject: [PATCH 3/9] Format code --- src/Ansatz.jl | 2 +- src/MPS.jl | 32 ++++++++++++++++++++++---------- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/Ansatz.jl b/src/Ansatz.jl index 6c7a1ae7..c7a96904 100644 --- a/src/Ansatz.jl +++ b/src/Ansatz.jl @@ -46,7 +46,7 @@ struct NonCanonical <: Form end [`Form`](@ref) trait representing a [`AbstractAnsatz`](@ref) Tensor Network in mixed-canonical form. - The orthogonality center is a [`Site`](@ref) or a vector of [`Site`](@ref)s. The tensors to the -left of the orthogonality center are left-canonical and the tensors to the right are right-canonical. + left of the orthogonality center are left-canonical and the tensors to the right are right-canonical. """ struct MixedCanonical <: Form orthog_center::Union{Site,Vector{Site}} diff --git a/src/MPS.jl b/src/MPS.jl index 935ebb4c..583991b3 100644 --- a/src/MPS.jl +++ b/src/MPS.jl @@ -38,8 +38,20 @@ Base.zero(x::T) where {T<:Union{MPS,MPO}} = T(zero(Ansatz(x)), form(x)) defaultorder(::Type{<:AbstractMPS}) = (:o, :l, :r) defaultorder(::Type{<:AbstractMPO}) = (:o, :i, :l, :r) -MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS), form::Form=NonCanonical(), check_canonical_form = true) = MPS(form, arrays; order=order, check_canonical_form=check_canonical_form) -MPS(arrays::Vector{<:AbstractArray}, λ::Vector{<:AbstractArray}; order=defaultorder(MPS), form::Form=Canonical(), check_canonical_form = true) = MPS(form, arrays, λ; order=order, check_canonical_form=check_canonical_form) +function MPS( + arrays::Vector{<:AbstractArray}; order=defaultorder(MPS), form::Form=NonCanonical(), check_canonical_form=true +) + return MPS(form, arrays; order=order, check_canonical_form=check_canonical_form) +end +function MPS( + arrays::Vector{<:AbstractArray}, + λ::Vector{<:AbstractArray}; + order=defaultorder(MPS), + form::Form=Canonical(), + check_canonical_form=true, +) + return MPS(form, arrays, λ; order=order, check_canonical_form=check_canonical_form) +end """ MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS)) @@ -50,7 +62,7 @@ Create a [`NonCanonical`](@ref) [`MPS`](@ref) from a vector of arrays. - `order` The order of the indices in the arrays. Defaults to `(:o, :l, :r)`. """ -function MPS(::NonCanonical, arrays; order=defaultorder(MPS), check_canonical_form = true) +function MPS(::NonCanonical, arrays; order=defaultorder(MPS), check_canonical_form=true) @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" @assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions" @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" @@ -95,7 +107,7 @@ function MPS(::NonCanonical, arrays; order=defaultorder(MPS), check_canonical_fo return MPS(ansatz, NonCanonical()) end -function MPS(form::MixedCanonical, arrays; order=defaultorder(MPS), check_canonical_form = true) +function MPS(form::MixedCanonical, arrays; order=defaultorder(MPS), check_canonical_form=true) @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" @assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions" @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" @@ -141,20 +153,20 @@ function MPS(form::MixedCanonical, arrays; order=defaultorder(MPS), check_canoni # Check that for site start to orthog_center-1 the tensors are left-canonical if check_canonical_form - for i in 1:id(form.orthog_center) - 1 - isisometry(mps, Site(i); dir = :right) || throw(ArgumentError("Tensors are not left-canonical")) + for i in 1:(id(form.orthog_center) - 1) + isisometry(mps, Site(i); dir=:right) || throw(ArgumentError("Tensors are not left-canonical")) end # Check that for site orthog_center+1 to end the tensors are right-canonical - for i in id(form.orthog_center) + 1:nsites(mps) - isisometry(mps, Site(i); dir = :left) || throw(ArgumentError("Tensors are not right-canonical")) + for i in (id(form.orthog_center) + 1):nsites(mps) + isisometry(mps, Site(i); dir=:left) || throw(ArgumentError("Tensors are not right-canonical")) end end return mps end -function MPS(::Canonical, arrays, λ; order=defaultorder(MPS), check_canonical_form = true) +function MPS(::Canonical, arrays, λ; order=defaultorder(MPS), check_canonical_form=true) @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" @assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions" @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" @@ -211,7 +223,7 @@ function MPS(::Canonical, arrays, λ; order=defaultorder(MPS), check_canonical_f if check_canonical_form for i in 1:nsites(mps) if i > 1 - isisometry(contract(mps; between=(Site(i-1), Site(i)), direction=:right), Site(i); dir=:right) || + isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right) || throw(ArgumentError("Can not form a left-canonical tensor in Site($i) from Γ and λ contraction.")) end From c9d8421962020a82e47de6c5c32bf727edc11221 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s=20Muns?= <61060572+jofrevalles@users.noreply.github.com> Date: Thu, 14 Nov 2024 09:37:09 +0100 Subject: [PATCH 4/9] Apply @mofeing suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> --- src/MPS.jl | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/MPS.jl b/src/MPS.jl index 583991b3..abbedde9 100644 --- a/src/MPS.jl +++ b/src/MPS.jl @@ -38,19 +38,9 @@ Base.zero(x::T) where {T<:Union{MPS,MPO}} = T(zero(Ansatz(x)), form(x)) defaultorder(::Type{<:AbstractMPS}) = (:o, :l, :r) defaultorder(::Type{<:AbstractMPO}) = (:o, :i, :l, :r) -function MPS( - arrays::Vector{<:AbstractArray}; order=defaultorder(MPS), form::Form=NonCanonical(), check_canonical_form=true -) - return MPS(form, arrays; order=order, check_canonical_form=check_canonical_form) -end -function MPS( - arrays::Vector{<:AbstractArray}, - λ::Vector{<:AbstractArray}; - order=defaultorder(MPS), - form::Form=Canonical(), - check_canonical_form=true, -) - return MPS(form, arrays, λ; order=order, check_canonical_form=check_canonical_form) +MPS(arrays; form::Form=NonCanonical(), kwargs...) = MPS(form, arrays; kwargs...) +function MPS(arrays, λ; form::Form=Canonical(), kwargs...) + return MPS(form, arrays, λ; kwargs...) end """ From 54ad374136b38f69b60d9f5439c60a6ff35fdc93 Mon Sep 17 00:00:00 2001 From: jofrevalles Date: Thu, 14 Nov 2024 09:41:09 +0100 Subject: [PATCH 5/9] Update code to match Sergio's suggestions --- src/MPS.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/MPS.jl b/src/MPS.jl index abbedde9..377dfb67 100644 --- a/src/MPS.jl +++ b/src/MPS.jl @@ -52,7 +52,7 @@ Create a [`NonCanonical`](@ref) [`MPS`](@ref) from a vector of arrays. - `order` The order of the indices in the arrays. Defaults to `(:o, :l, :r)`. """ -function MPS(::NonCanonical, arrays; order=defaultorder(MPS), check_canonical_form=true) +function MPS(::NonCanonical, arrays; order=defaultorder(MPS), check=true) @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" @assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions" @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" @@ -97,7 +97,7 @@ function MPS(::NonCanonical, arrays; order=defaultorder(MPS), check_canonical_fo return MPS(ansatz, NonCanonical()) end -function MPS(form::MixedCanonical, arrays; order=defaultorder(MPS), check_canonical_form=true) +function MPS(form::MixedCanonical, arrays; order=defaultorder(MPS), check=true) @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" @assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions" @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" @@ -142,7 +142,7 @@ function MPS(form::MixedCanonical, arrays; order=defaultorder(MPS), check_canoni mps = MPS(ansatz, form) # Check that for site start to orthog_center-1 the tensors are left-canonical - if check_canonical_form + if check for i in 1:(id(form.orthog_center) - 1) isisometry(mps, Site(i); dir=:right) || throw(ArgumentError("Tensors are not left-canonical")) end @@ -156,7 +156,7 @@ function MPS(form::MixedCanonical, arrays; order=defaultorder(MPS), check_canoni return mps end -function MPS(::Canonical, arrays, λ; order=defaultorder(MPS), check_canonical_form=true) +function MPS(::Canonical, arrays, λ; order=defaultorder(MPS), check=true) @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" @assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions" @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" @@ -210,7 +210,7 @@ function MPS(::Canonical, arrays, λ; order=defaultorder(MPS), check_canonical_f mps = MPS(ansatz, Canonical()) # Check canonical form by contracting Γ and λ tensors and checking their orthogonality - if check_canonical_form + if check for i in 1:nsites(mps) if i > 1 isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right) || From 5d83e8f447ebb3745da7979217afa57d2f6f746b Mon Sep 17 00:00:00 2001 From: jofrevalles Date: Thu, 14 Nov 2024 09:49:23 +0100 Subject: [PATCH 6/9] Update rand MPS constructor docstring --- src/MPS.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MPS.jl b/src/MPS.jl index 377dfb67..7eec23da 100644 --- a/src/MPS.jl +++ b/src/MPS.jl @@ -339,7 +339,7 @@ Base.adjoint(tn::T) where {T<:AbstractMPO} = T(adjoint(Ansatz(tn)), form(tn)) """ Base.rand(rng::Random.AbstractRNG, ::Type{MPS}; n, maxdim, eltype=Float64, physdim=2) -Create a random [`MPS`](@ref) Tensor Network. +Create a random [`MPS`](@ref) Tensor Network where all tensors are left-canonical (MixedCanonical(Site(0))). In order to avoid norm explosion issues, the tensors are orthogonalized by QR factorization so its normalized and mixed canonized to the last site. # Keyword Arguments From 21146df07d07ef2597903331ac57f816c0ebe37a Mon Sep 17 00:00:00 2001 From: jofrevalles Date: Thu, 14 Nov 2024 10:08:00 +0100 Subject: [PATCH 7/9] Separate check of canonical forms in check_form functions --- src/MPS.jl | 142 +++++++++++++---------------------------------------- 1 file changed, 35 insertions(+), 107 deletions(-) diff --git a/src/MPS.jl b/src/MPS.jl index 7eec23da..a2788f45 100644 --- a/src/MPS.jl +++ b/src/MPS.jl @@ -98,133 +98,61 @@ function MPS(::NonCanonical, arrays; order=defaultorder(MPS), check=true) end function MPS(form::MixedCanonical, arrays; order=defaultorder(MPS), check=true) - @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" - @assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions" - @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" - issetequal(order, defaultorder(MPS)) || - throw(ArgumentError("order must be a permutation of $(String.(defaultorder(MPS)))")) - - n = length(arrays) - gen = IndexCounter() - symbols = [nextindex!(gen) for _ in 1:(2n)] - - tn = TensorNetwork( - map(enumerate(arrays)) do (i, array) - _order = if i == 1 - filter(x -> x != :l, order) - elseif i == n - filter(x -> x != :r, order) - else - order - end - - inds = map(_order) do dir - if dir == :o - symbols[i] - elseif dir == :r - symbols[n + mod1(i, n)] - elseif dir == :l - symbols[n + mod1(i - 1, n)] - else - throw(ArgumentError("Invalid direction: $dir")) - end - end - Tensor(array, inds) - end, - ) - - sitemap = Dict(Site(i) => symbols[i] for i in 1:n) - qtn = Quantum(tn, sitemap) - graph = path_graph(n) - mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for (i, site) in enumerate(lanes(qtn))]) - lattice = Lattice(mapping, graph) - ansatz = Ansatz(qtn, lattice) - mps = MPS(ansatz, form) + mps = MPS(arrays; form=NonCanonical(), order, check) + mps.form = form - # Check that for site start to orthog_center-1 the tensors are left-canonical - if check - for i in 1:(id(form.orthog_center) - 1) - isisometry(mps, Site(i); dir=:right) || throw(ArgumentError("Tensors are not left-canonical")) - end - - # Check that for site orthog_center+1 to end the tensors are right-canonical - for i in (id(form.orthog_center) + 1):nsites(mps) - isisometry(mps, Site(i); dir=:left) || throw(ArgumentError("Tensors are not right-canonical")) - end - end + # Check mixed canonical form + check && check_form(form, mps) return mps end function MPS(::Canonical, arrays, λ; order=defaultorder(MPS), check=true) - @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" - @assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions" - @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" - @assert length(λ) == length(arrays) - 1 "Number of λ tensors must be one less than the number of arrays" @assert all(==(1) ∘ ndims, λ) "All λ tensors must be Vectors" - issetequal(order, defaultorder(MPS)) || - throw(ArgumentError("order must be a permutation of $(String.(defaultorder(MPS)))")) + mps = MPS(arrays; form=NonCanonical(), order, check) + mps.form = Canonical() - n = length(arrays) - gen = IndexCounter() - symbols = [nextindex!(gen) for _ in 1:(2n)] + # Create tensors from 'λ' + map(enumerate(λ)) do (i, array) + tensor = Tensor(array, (inds(mps; at=Site(i), dir=:right),)) + push!(mps, tensor) + end - # Create tensors from 'arrays' - tensor_list = map(enumerate(arrays)) do (i, array) - _order = if i == 1 - filter(x -> x != :l, order) - elseif i == n - filter(x -> x != :r, order) - else - order - end + # Check canonical form by contracting Γ and λ tensors and checking their orthogonality + check && check_form(Canonical(), mps) - inds = map(_order) do dir - if dir == :o - symbols[i] - elseif dir == :r - symbols[n + mod1(i, n)] - elseif dir == :l - symbols[n + mod1(i - 1, n)] - else - throw(ArgumentError("Invalid direction: $dir")) - end - end - Tensor(array, inds) - end + return mps +end - # Create tensors from 'λ' - lambda_tensors = map(enumerate(λ)) do (i, array) - Tensor(array, [symbols[n + mod1(i, n)]]) +function check_form(::MixedCanonical, mps::AbstractMPO) + orthog_center = form(mps).orthog_center + for i in 1:nsites(mps) + if i < id(orthog_center) # Check left-canonical tensors + isisometry(mps, Site(i); dir=:right) || throw(ArgumentError("Tensors are not left-canonical")) + elseif i > id(orthog_center) # Check right-canonical tensors + isisometry(mps, Site(i); dir=:left) || throw(ArgumentError("Tensors are not right-canonical")) + end end - tn = TensorNetwork(vcat(tensor_list, lambda_tensors)) - sitemap = Dict(Site(i) => symbols[i] for i in 1:n) - qtn = Quantum(tn, sitemap) - graph = path_graph(n) - mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for (i, site) in enumerate(lanes(qtn))]) - lattice = Lattice(mapping, graph) - ansatz = Ansatz(qtn, lattice) - mps = MPS(ansatz, Canonical()) + return true +end - # Check canonical form by contracting Γ and λ tensors and checking their orthogonality - if check - for i in 1:nsites(mps) - if i > 1 - isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right) || - throw(ArgumentError("Can not form a left-canonical tensor in Site($i) from Γ and λ contraction.")) - end +function check_form(::Canonical, mps::AbstractMPO) + for i in 1:nsites(mps) + if i > 1 + isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right) || + throw(ArgumentError("Can not form a left-canonical tensor in Site($i) from Γ and λ contraction.")) + end - if i < nsites(mps) - isisometry(contract(mps; between=(Site(i), Site(i + 1)), direction=:left), Site(i); dir=:left) || - throw(ArgumentError("Can not form a right-canonical tensor in Site($i) from Γ and λ contraction.")) - end + if i < nsites(mps) + isisometry(contract(mps; between=(Site(i), Site(i + 1)), direction=:left), Site(i); dir=:left) || + throw(ArgumentError("Can not form a right-canonical tensor in Site($i) from Γ and λ contraction.")) end end - return mps + return true end """ From a92aca2ca8216276e986be757c6421853dde6516 Mon Sep 17 00:00:00 2001 From: jofrevalles Date: Fri, 15 Nov 2024 09:03:06 +0100 Subject: [PATCH 8/9] Small changes in constructor functions --- src/MPS.jl | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/MPS.jl b/src/MPS.jl index a2788f45..6d4c3119 100644 --- a/src/MPS.jl +++ b/src/MPS.jl @@ -102,7 +102,7 @@ function MPS(form::MixedCanonical, arrays; order=defaultorder(MPS), check=true) mps.form = form # Check mixed canonical form - check && check_form(form, mps) + check && check_form(mps) return mps end @@ -121,13 +121,15 @@ function MPS(::Canonical, arrays, λ; order=defaultorder(MPS), check=true) end # Check canonical form by contracting Γ and λ tensors and checking their orthogonality - check && check_form(Canonical(), mps) + check && check_form(mps) return mps end -function check_form(::MixedCanonical, mps::AbstractMPO) - orthog_center = form(mps).orthog_center +check_form(mps::AbstractMPO) = check_form(form(mps), mps) + +function check_form(config::MixedCanonical, mps::AbstractMPO) + orthog_center = config.orthog_center for i in 1:nsites(mps) if i < id(orthog_center) # Check left-canonical tensors isisometry(mps, Site(i); dir=:right) || throw(ArgumentError("Tensors are not left-canonical")) @@ -142,13 +144,13 @@ end function check_form(::Canonical, mps::AbstractMPO) for i in 1:nsites(mps) if i > 1 - isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right) || - throw(ArgumentError("Can not form a left-canonical tensor in Site($i) from Γ and λ contraction.")) + !isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right) + throw(ArgumentError("Can not form a left-canonical tensor in Site($i) from Γ and λ contraction.")) end - if i < nsites(mps) - isisometry(contract(mps; between=(Site(i), Site(i + 1)), direction=:left), Site(i); dir=:left) || - throw(ArgumentError("Can not form a right-canonical tensor in Site($i) from Γ and λ contraction.")) + if i < nsites(mps) && + !isisometry(contract(mps; between=(Site(i), Site(i + 1)), direction=:left), Site(i); dir=:left) + throw(ArgumentError("Can not form a right-canonical tensor in Site($i) from Γ and λ contraction.")) end end @@ -267,8 +269,8 @@ Base.adjoint(tn::T) where {T<:AbstractMPO} = T(adjoint(Ansatz(tn)), form(tn)) """ Base.rand(rng::Random.AbstractRNG, ::Type{MPS}; n, maxdim, eltype=Float64, physdim=2) -Create a random [`MPS`](@ref) Tensor Network where all tensors are left-canonical (MixedCanonical(Site(0))). -In order to avoid norm explosion issues, the tensors are orthogonalized by QR factorization so its normalized and mixed canonized to the last site. +Create a random [`MPS`](@ref) Tensor Network in the MixedCanonical form where all tensors are right-canonical (ortogonality +center at the first site). In order to avoid norm explosion issues, the tensors are orthogonalized by LQ factorization. # Keyword Arguments From b1e73c6e55b1df2307bd09127bc9bae0534cac73 Mon Sep 17 00:00:00 2001 From: jofrevalles Date: Fri, 15 Nov 2024 09:03:56 +0100 Subject: [PATCH 9/9] Fix typo --- src/MPS.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/MPS.jl b/src/MPS.jl index 6d4c3119..4b3b6ed2 100644 --- a/src/MPS.jl +++ b/src/MPS.jl @@ -303,8 +303,7 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{MPS}; n, maxdim, eltype=Float arrays[1] = reshape(arrays[1], p, p) arrays[n] = reshape(arrays[n], p, p) - return MPS(arrays; order=(:l, :o, :r), form=MixedCanonical(Site(0))) - # return MPS(arrays; order=(:l, :o, :r)) + return MPS(arrays; order=(:l, :o, :r), form=MixedCanonical(Site(1))) end # TODO different input/output physical dims