Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Form-specific constructors for MPS #248

Merged
merged 9 commits into from
Nov 15, 2024
3 changes: 3 additions & 0 deletions src/Ansatz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ struct NonCanonical <: Form end
MixedCanonical

[`Form`](@ref) trait representing a [`AbstractAnsatz`](@ref) Tensor Network in mixed-canonical form.

- The orthogonality center is a [`Site`](@ref) or a vector of [`Site`](@ref)s. The tensors to the
left of the orthogonality center are left-canonical and the tensors to the right are right-canonical.
"""
struct MixedCanonical <: Form
orthog_center::Union{Site,Vector{Site}}
Expand Down
73 changes: 69 additions & 4 deletions src/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ Base.zero(x::T) where {T<:Union{MPS,MPO}} = T(zero(Ansatz(x)), form(x))
defaultorder(::Type{<:AbstractMPS}) = (:o, :l, :r)
defaultorder(::Type{<:AbstractMPO}) = (:o, :i, :l, :r)

MPS(arrays; form::Form=NonCanonical(), kwargs...) = MPS(form, arrays; kwargs...)
function MPS(arrays, λ; form::Form=Canonical(), kwargs...)
return MPS(form, arrays, λ; kwargs...)
end

"""
MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS))

Expand All @@ -47,7 +52,7 @@ Create a [`NonCanonical`](@ref) [`MPS`](@ref) from a vector of arrays.

- `order` The order of the indices in the arrays. Defaults to `(:o, :l, :r)`.
"""
function MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS))
function MPS(::NonCanonical, arrays; order=defaultorder(MPS), check=true)
@assert ndims(arrays[1]) == 2 "First array must have 2 dimensions"
@assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions"
@assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions"
Expand Down Expand Up @@ -92,6 +97,66 @@ function MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS))
return MPS(ansatz, NonCanonical())
end

function MPS(form::MixedCanonical, arrays; order=defaultorder(MPS), check=true)
mps = MPS(arrays; form=NonCanonical(), order, check)
mps.form = form

# Check mixed canonical form
check && check_form(mps)

return mps
end

function MPS(::Canonical, arrays, λ; order=defaultorder(MPS), check=true)
@assert length(λ) == length(arrays) - 1 "Number of λ tensors must be one less than the number of arrays"
@assert all(==(1) ∘ ndims, λ) "All λ tensors must be Vectors"

mps = MPS(arrays; form=NonCanonical(), order, check)
mps.form = Canonical()

# Create tensors from 'λ'
map(enumerate(λ)) do (i, array)
tensor = Tensor(array, (inds(mps; at=Site(i), dir=:right),))
push!(mps, tensor)
end

# Check canonical form by contracting Γ and λ tensors and checking their orthogonality
check && check_form(mps)

return mps
end

check_form(mps::AbstractMPO) = check_form(form(mps), mps)

function check_form(config::MixedCanonical, mps::AbstractMPO)
orthog_center = config.orthog_center
for i in 1:nsites(mps)
if i < id(orthog_center) # Check left-canonical tensors
isisometry(mps, Site(i); dir=:right) || throw(ArgumentError("Tensors are not left-canonical"))
elseif i > id(orthog_center) # Check right-canonical tensors
isisometry(mps, Site(i); dir=:left) || throw(ArgumentError("Tensors are not right-canonical"))
end
end

return true
end

function check_form(::Canonical, mps::AbstractMPO)
for i in 1:nsites(mps)
if i > 1
!isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right)
throw(ArgumentError("Can not form a left-canonical tensor in Site($i) from Γ and λ contraction."))
end

if i < nsites(mps) &&
!isisometry(contract(mps; between=(Site(i), Site(i + 1)), direction=:left), Site(i); dir=:left)
throw(ArgumentError("Can not form a right-canonical tensor in Site($i) from Γ and λ contraction."))
end
end

return true
end

"""
MPO(arrays::Vector{<:AbstractArray}; order=defaultorder(MPO))

Expand Down Expand Up @@ -204,8 +269,8 @@ Base.adjoint(tn::T) where {T<:AbstractMPO} = T(adjoint(Ansatz(tn)), form(tn))
"""
Base.rand(rng::Random.AbstractRNG, ::Type{MPS}; n, maxdim, eltype=Float64, physdim=2)

Create a random [`MPS`](@ref) Tensor Network.
In order to avoid norm explosion issues, the tensors are orthogonalized by QR factorization so its normalized and mixed canonized to the last site.
Create a random [`MPS`](@ref) Tensor Network in the MixedCanonical form where all tensors are right-canonical (ortogonality
center at the first site). In order to avoid norm explosion issues, the tensors are orthogonalized by LQ factorization.

# Keyword Arguments

Expand Down Expand Up @@ -238,7 +303,7 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{MPS}; n, maxdim, eltype=Float
arrays[1] = reshape(arrays[1], p, p)
arrays[n] = reshape(arrays[n], p, p)

return MPS(arrays; order=(:l, :o, :r))
return MPS(arrays; order=(:l, :o, :r), form=MixedCanonical(Site(1)))
end

# TODO different input/output physical dims
Expand Down
Loading