Skip to content

Commit

Permalink
Separate check of canonical forms in check_form functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jofrevalles committed Nov 14, 2024
1 parent 5d83e8f commit 21146df
Showing 1 changed file with 35 additions and 107 deletions.
142 changes: 35 additions & 107 deletions src/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,133 +98,61 @@ function MPS(::NonCanonical, arrays; order=defaultorder(MPS), check=true)
end

function MPS(form::MixedCanonical, arrays; order=defaultorder(MPS), check=true)
@assert ndims(arrays[1]) == 2 "First array must have 2 dimensions"
@assert all(==(3) ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions"
@assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions"
issetequal(order, defaultorder(MPS)) ||
throw(ArgumentError("order must be a permutation of $(String.(defaultorder(MPS)))"))

n = length(arrays)
gen = IndexCounter()
symbols = [nextindex!(gen) for _ in 1:(2n)]

tn = TensorNetwork(
map(enumerate(arrays)) do (i, array)
_order = if i == 1
filter(x -> x != :l, order)
elseif i == n
filter(x -> x != :r, order)
else
order
end

inds = map(_order) do dir
if dir == :o
symbols[i]
elseif dir == :r
symbols[n + mod1(i, n)]
elseif dir == :l
symbols[n + mod1(i - 1, n)]
else
throw(ArgumentError("Invalid direction: $dir"))
end
end
Tensor(array, inds)
end,
)

sitemap = Dict(Site(i) => symbols[i] for i in 1:n)
qtn = Quantum(tn, sitemap)
graph = path_graph(n)
mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for (i, site) in enumerate(lanes(qtn))])
lattice = Lattice(mapping, graph)
ansatz = Ansatz(qtn, lattice)
mps = MPS(ansatz, form)
mps = MPS(arrays; form=NonCanonical(), order, check)
mps.form = form

# Check that for site start to orthog_center-1 the tensors are left-canonical
if check
for i in 1:(id(form.orthog_center) - 1)
isisometry(mps, Site(i); dir=:right) || throw(ArgumentError("Tensors are not left-canonical"))
end

# Check that for site orthog_center+1 to end the tensors are right-canonical
for i in (id(form.orthog_center) + 1):nsites(mps)
isisometry(mps, Site(i); dir=:left) || throw(ArgumentError("Tensors are not right-canonical"))
end
end
# Check mixed canonical form
check && check_form(form, mps)

return mps
end

function MPS(::Canonical, arrays, λ; order=defaultorder(MPS), check=true)
@assert ndims(arrays[1]) == 2 "First array must have 2 dimensions"
@assert all(==(3) ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions"
@assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions"

@assert length(λ) == length(arrays) - 1 "Number of λ tensors must be one less than the number of arrays"
@assert all(==(1) ndims, λ) "All λ tensors must be Vectors"

issetequal(order, defaultorder(MPS)) ||
throw(ArgumentError("order must be a permutation of $(String.(defaultorder(MPS)))"))
mps = MPS(arrays; form=NonCanonical(), order, check)
mps.form = Canonical()

n = length(arrays)
gen = IndexCounter()
symbols = [nextindex!(gen) for _ in 1:(2n)]
# Create tensors from 'λ'
map(enumerate(λ)) do (i, array)
tensor = Tensor(array, (inds(mps; at=Site(i), dir=:right),))
push!(mps, tensor)
end

# Create tensors from 'arrays'
tensor_list = map(enumerate(arrays)) do (i, array)
_order = if i == 1
filter(x -> x != :l, order)
elseif i == n
filter(x -> x != :r, order)
else
order
end
# Check canonical form by contracting Γ and λ tensors and checking their orthogonality
check && check_form(Canonical(), mps)

inds = map(_order) do dir
if dir == :o
symbols[i]
elseif dir == :r
symbols[n + mod1(i, n)]
elseif dir == :l
symbols[n + mod1(i - 1, n)]
else
throw(ArgumentError("Invalid direction: $dir"))
end
end
Tensor(array, inds)
end
return mps
end

# Create tensors from 'λ'
lambda_tensors = map(enumerate(λ)) do (i, array)
Tensor(array, [symbols[n + mod1(i, n)]])
function check_form(::MixedCanonical, mps::AbstractMPO)
orthog_center = form(mps).orthog_center
for i in 1:nsites(mps)
if i < id(orthog_center) # Check left-canonical tensors
isisometry(mps, Site(i); dir=:right) || throw(ArgumentError("Tensors are not left-canonical"))
elseif i > id(orthog_center) # Check right-canonical tensors
isisometry(mps, Site(i); dir=:left) || throw(ArgumentError("Tensors are not right-canonical"))
end
end

tn = TensorNetwork(vcat(tensor_list, lambda_tensors))
sitemap = Dict(Site(i) => symbols[i] for i in 1:n)
qtn = Quantum(tn, sitemap)
graph = path_graph(n)
mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for (i, site) in enumerate(lanes(qtn))])
lattice = Lattice(mapping, graph)
ansatz = Ansatz(qtn, lattice)
mps = MPS(ansatz, Canonical())
return true
end

# Check canonical form by contracting Γ and λ tensors and checking their orthogonality
if check
for i in 1:nsites(mps)
if i > 1
isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right) ||
throw(ArgumentError("Can not form a left-canonical tensor in Site($i) from Γ and λ contraction."))
end
function check_form(::Canonical, mps::AbstractMPO)
for i in 1:nsites(mps)
if i > 1
isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right) ||
throw(ArgumentError("Can not form a left-canonical tensor in Site($i) from Γ and λ contraction."))
end

if i < nsites(mps)
isisometry(contract(mps; between=(Site(i), Site(i + 1)), direction=:left), Site(i); dir=:left) ||
throw(ArgumentError("Can not form a right-canonical tensor in Site($i) from Γ and λ contraction."))
end
if i < nsites(mps)
isisometry(contract(mps; between=(Site(i), Site(i + 1)), direction=:left), Site(i); dir=:left) ||
throw(ArgumentError("Can not form a right-canonical tensor in Site($i) from Γ and λ contraction."))
end
end

return mps
return true
end

"""
Expand Down

0 comments on commit 21146df

Please sign in to comment.