diff --git a/Project.toml b/Project.toml index c3f02be2d..b9ae892c3 100644 --- a/Project.toml +++ b/Project.toml @@ -8,8 +8,10 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188" EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" KeywordDispatch = "5888135b-5456-5c80-a1b6-c91ef8180460" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377" OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" @@ -25,7 +27,6 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" -Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" ITensorNetworks = "2919e153-833c-4bdc-8836-1ea460a35fc7" ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" diff --git a/ext/TenetAdaptExt.jl b/ext/TenetAdaptExt.jl index 6facff448..534355bb8 100644 --- a/ext/TenetAdaptExt.jl +++ b/ext/TenetAdaptExt.jl @@ -7,7 +7,13 @@ Adapt.adapt_structure(to, x::Tensor) = Tensor(adapt(to, parent(x)), inds(x)) Adapt.adapt_structure(to, x::TensorNetwork) = TensorNetwork(adapt.(Ref(to), tensors(x))) Adapt.adapt_structure(to, x::Quantum) = Quantum(adapt(to, TensorNetwork(x)), x.sites) -Adapt.adapt_structure(to, x::Product) = Product(adapt(to, Quantum(x))) -Adapt.adapt_structure(to, x::Chain) = Chain(adapt(to, Quantum(x)), boundary(x)) +Adapt.adapt_structure(to, x::Ansatz) = Ansatz(adapt(to, Quantum(x)), Tenet.lattice(x)) + +Adapt.adapt_structure(to, x::Product) = Product(adapt(to, Ansatz(x))) +Adapt.adapt_structure(to, x::Dense) = Dense(adapt(to, Ansatz(x))) +Adapt.adapt_structure(to, x::MPS) = 
MPS(adapt(to, Ansatz(x)), form(x))
+Adapt.adapt_structure(to, x::MPO) = MPO(adapt(to, Ansatz(x)), form(x))
+Adapt.adapt_structure(to, x::PEPS) = PEPS(adapt(to, Ansatz(x)), form(x))
+Adapt.adapt_structure(to, x::PEPO) = PEPO(adapt(to, Ansatz(x)), form(x))
 
 end
diff --git a/ext/TenetChainRulesCoreExt/frules.jl b/ext/TenetChainRulesCoreExt/frules.jl
index 941a723b6..f9a6bc76a 100644
--- a/ext/TenetChainRulesCoreExt/frules.jl
+++ b/ext/TenetChainRulesCoreExt/frules.jl
@@ -1,16 +1,39 @@
+using Tenet: AbstractTensorNetwork, AbstractQuantum
+
 # `Tensor` constructor
 ChainRulesCore.frule((_, Δ, _), T::Type{<:Tensor}, data, inds) = T(data, inds), T(Δ, inds)
 
 # `TensorNetwork` constructor
 ChainRulesCore.frule((_, Δ), ::Type{TensorNetwork}, tensors) = TensorNetwork(tensors), TensorNetworkTangent(Δ)
 
+# `Quantum` constructor
+function ChainRulesCore.frule((_, ẋ, _), ::Type{Quantum}, x::TensorNetwork, sites)
+    return Quantum(x, sites), Tangent{Quantum}(; tn=ẋ, sites=NoTangent())
+end
+
+# `Ansatz` constructor
+function ChainRulesCore.frule((_, ẋ), ::Type{Ansatz}, x::Quantum, lattice)
+    return Ansatz(x, lattice), Tangent{Ansatz}(; tn=ẋ, lattice=NoTangent())
+end
+
+# `AbstractAnsatz`-subtype constructors (`Dense` takes only the `Ansatz`, like its rrule: it has no `form` argument)
+ChainRulesCore.frule((_, ẋ), ::Type{Product}, x::Ansatz) = Product(x), Tangent{Product}(; tn=ẋ)
+ChainRulesCore.frule((_, ẋ), ::Type{Dense}, x::Ansatz) = Dense(x), Tangent{Dense}(; tn=ẋ)
+ChainRulesCore.frule((_, ẋ), ::Type{MPS}, x::Ansatz, form) = MPS(x, form), Tangent{MPS}(; tn=ẋ, lattice=NoTangent())
+ChainRulesCore.frule((_, ẋ), ::Type{MPO}, x::Ansatz, form) = MPO(x, form), Tangent{MPO}(; tn=ẋ, lattice=NoTangent())
+function ChainRulesCore.frule((_, ẋ), ::Type{PEPS}, x::Ansatz, form)
+    return PEPS(x, form), Tangent{PEPS}(; tn=ẋ, lattice=NoTangent())
+end
+
 # `Base.conj` methods
 ChainRulesCore.frule((_, Δ), ::typeof(Base.conj), tn::Tensor) = conj(tn), conj(Δ)
-ChainRulesCore.frule((_, Δ), ::typeof(Base.conj), tn::TensorNetwork) = conj(tn), conj(Δ)
+ChainRulesCore.frule((_, 
Δ), ::typeof(Base.conj), tn::AbstractTensorNetwork) = conj(tn), conj(Δ) # `Base.merge` methods -ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(Base.merge), a::TensorNetwork, b::TensorNetwork) = merge(a, b), merge(ȧ, ḃ) +function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(Base.merge), a::AbstractTensorNetwork, b::AbstractTensorNetwork) + return merge(a, b), merge(ȧ, ḃ) +end # `contract` methods function ChainRulesCore.frule((_, ẋ), ::typeof(contract), x::Tensor; kwargs...) @@ -22,15 +45,3 @@ function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(contract), a::Tensor, b::T ċ = contract(ȧ, b; kwargs...) + contract(a, ḃ; kwargs...) return c, ċ end - -function ChainRulesCore.frule((_, ẋ, _), ::Type{Quantum}, x::TensorNetwork, sites) - y = Quantum(x, sites) - ẏ = Tangent{Quantum}(; tn=ẋ) - return y, ẏ -end - -ChainRulesCore.frule((_, ẋ), ::Type{T}, x::Quantum) where {T<:Ansatz} = T(x), Tangent{T}(; super=ẋ) - -function ChainRulesCore.frule((_, ẋ, _), ::Type{T}, x::Quantum, boundary) where {T<:Ansatz} - return T(x, boundary), Tangent{T}(; super=ẋ, boundary=NoTangent()) -end diff --git a/ext/TenetChainRulesCoreExt/projectors.jl b/ext/TenetChainRulesCoreExt/projectors.jl index acd488fad..299a3aaab 100644 --- a/ext/TenetChainRulesCoreExt/projectors.jl +++ b/ext/TenetChainRulesCoreExt/projectors.jl @@ -36,8 +36,5 @@ end ChainRulesCore.ProjectTo(x::Quantum) = ProjectTo{Quantum}(; tn=ProjectTo(TensorNetwork(x)), sites=x.sites) (projector::ProjectTo{Quantum})(Δ) = Quantum(projector.tn(Δ), projector.sites) -ChainRulesCore.ProjectTo(x::T) where {T<:Ansatz} = ProjectTo{T}(; super=ProjectTo(Quantum(x))) -(projector::ProjectTo{T})(Δ::Union{T,Tangent{T}}) where {T<:Ansatz} = T(projector.super(Δ.super), Δ.boundary) - -# NOTE edge case: `Product` has no `boundary`. should it? 
-(projector::ProjectTo{T})(Δ::Union{T,Tangent{T}}) where {T<:Product} = T(projector.super(Δ.super))
+ChainRulesCore.ProjectTo(x::Ansatz) = ProjectTo{Ansatz}(; tn=ProjectTo(Quantum(x)), lattice=x.lattice)
+(projector::ProjectTo{Ansatz})(Δ) = Ansatz(projector.tn(Δ), projector.lattice) # stored lattice, as `ProjectTo{Quantum}` does with `sites`
diff --git a/ext/TenetChainRulesCoreExt/rrules.jl b/ext/TenetChainRulesCoreExt/rrules.jl
index 9992c8774..f6c89d055 100644
--- a/ext/TenetChainRulesCoreExt/rrules.jl
+++ b/ext/TenetChainRulesCoreExt/rrules.jl
@@ -9,6 +9,38 @@ TensorNetwork_pullback(Δ::TensorNetworkTangent) = (NoTangent(), tensors(Δ))
 TensorNetwork_pullback(Δ::AbstractThunk) = TensorNetwork_pullback(unthunk(Δ))
 ChainRulesCore.rrule(::Type{TensorNetwork}, tensors) = TensorNetwork(tensors), TensorNetwork_pullback
 
+# `Quantum` constructor
+Quantum_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
+Quantum_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ, NoTangent())
+Quantum_pullback(ȳ::AbstractThunk) = Quantum_pullback(unthunk(ȳ))
+ChainRulesCore.rrule(::Type{Quantum}, x::TensorNetwork, sites) = Quantum(x, sites), Quantum_pullback
+
+# `Ansatz` constructor
+Ansatz_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
+Ansatz_pullback(ȳ::AbstractThunk) = Ansatz_pullback(unthunk(ȳ))
+ChainRulesCore.rrule(::Type{Ansatz}, x::Quantum, lattice) = Ansatz(x, lattice), Ansatz_pullback
+
+# `AbstractAnsatz`-subtype constructors
+Product_pullback(ȳ) = (NoTangent(), ȳ.tn)
+Product_pullback(ȳ::AbstractThunk) = Product_pullback(unthunk(ȳ))
+ChainRulesCore.rrule(::Type{Product}, x::Ansatz) = Product(x), Product_pullback
+
+Dense_pullback(ȳ) = (NoTangent(), ȳ.tn)
+Dense_pullback(ȳ::AbstractThunk) = Dense_pullback(unthunk(ȳ))
+ChainRulesCore.rrule(::Type{Dense}, x::Ansatz) = Dense(x), Dense_pullback
+
+MPS_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
+MPS_pullback(ȳ::AbstractThunk) = MPS_pullback(unthunk(ȳ))
+ChainRulesCore.rrule(::Type{MPS}, x::Ansatz, form) = MPS(x, form), MPS_pullback
+
+MPO_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
+MPO_pullback(ȳ::AbstractThunk) = MPO_pullback(unthunk(ȳ)) +ChainRulesCore.rrule(::Type{MPO}, x::Ansatz, form) = MPO(x, form), MPO_pullback + +PEPS_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent()) +PEPS_pullback(ȳ::AbstractThunk) = PEPS_pullback(unthunk(ȳ)) +ChainRulesCore.rrule(::Type{PEPS}, x::Ansatz, form) = PEPS(x, form), PEPS_pullback + # `Base.conj` methods conj_pullback(Δ::Tensor) = (NoTangent(), conj(Δ)) conj_pullback(Δ::Tangent{Tensor}) = (NoTangent(), conj(Δ)) @@ -93,33 +125,6 @@ function ChainRulesCore.rrule(::typeof(contract), a::Tensor, b::Tensor; kwargs.. return c, contract_pullback end -Quantum_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent()) -Quantum_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ, NoTangent()) -Quantum_pullback(ȳ::AbstractThunk) = Quantum_pullback(unthunk(ȳ)) -ChainRulesCore.rrule(::Type{Quantum}, x::TensorNetwork, sites) = Quantum(x, sites), Quantum_pullback - -Ansatz_pullback(ȳ) = (NoTangent(), ȳ.super) -Ansatz_pullback(ȳ::AbstractThunk) = Ansatz_pullback(unthunk(ȳ)) -function ChainRulesCore.rrule(::Type{T}, x::Quantum) where {T<:Ansatz} - y = T(x) - return y, Ansatz_pullback -end - -Ansatz_boundary_pullback(ȳ) = (NoTangent(), ȳ.super, NoTangent()) -Ansatz_boundary_pullback(ȳ::AbstractThunk) = Ansatz_boundary_pullback(unthunk(ȳ)) -function ChainRulesCore.rrule(::Type{T}, x::Quantum, boundary) where {T<:Ansatz} - return T(x, boundary), Ansatz_boundary_pullback -end - -Ansatz_from_arrays_pullback(ȳ) = (NoTangent(), NoTangent(), NoTangent(), parent.(tensors(ȳ.super.tn))) -Ansatz_from_arrays_pullback(ȳ::AbstractThunk) = Ansatz_from_arrays_pullback(unthunk(ȳ)) -function ChainRulesCore.rrule( - ::Type{T}, socket::Tenet.Socket, boundary::Tenet.Boundary, arrays; kwargs... -) where {T<:Ansatz} - y = T(socket, boundary, arrays; kwargs...) 
- return y, Ansatz_from_arrays_pullback -end - copy_pullback(ȳ) = (NoTangent(), ȳ) copy_pullback(ȳ::AbstractThunk) = unthunk(ȳ) function ChainRulesCore.rrule(::typeof(copy), x::Quantum) diff --git a/ext/TenetChainRulesTestUtilsExt.jl b/ext/TenetChainRulesTestUtilsExt.jl index 1565387e0..ab3decd49 100644 --- a/ext/TenetChainRulesTestUtilsExt.jl +++ b/ext/TenetChainRulesTestUtilsExt.jl @@ -6,14 +6,16 @@ using Tenet using ChainRulesCore using ChainRulesTestUtils using Random +using Graphs +using MetaGraphsNext const TensorNetworkTangent = Base.get_extension(Tenet, :TenetChainRulesCoreExt).TensorNetworkTangent -function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::Vector{T}) where {T<:Tensor} - if isempty(x) +function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, tn::Vector{T}) where {T<:Tensor} + if isempty(tn) return Vector{T}() else - @invoke rand_tangent(rng::AbstractRNG, x::AbstractArray) + @invoke rand_tangent(rng::AbstractRNG, tn::AbstractArray) end end @@ -21,12 +23,25 @@ function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::TensorNetwork) return TensorNetworkTangent(Tensor[ProjectTo(tensor)(rand_tangent.(Ref(rng), tensor)) for tensor in tensors(x)]) end -function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::Quantum) - return Tangent{Quantum}(; tn=rand_tangent(rng, TensorNetwork(x)), sites=NoTangent()) +function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, tn::Quantum) + return Tangent{Quantum}(; tn=rand_tangent(rng, TensorNetwork(tn)), sites=NoTangent()) end -# WARN type-piracy -# NOTE used in `Quantum` constructor -ChainRulesTestUtils.rand_tangent(::AbstractRNG, x::Dict{<:Site,Symbol}) = NoTangent() +# WARN type-piracy, used in `Quantum` constructor +ChainRulesTestUtils.rand_tangent(::AbstractRNG, tn::Dict{<:Site,Symbol}) = NoTangent() + +function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, tn::Ansatz) + return Tangent{Ansatz}(; tn=rand_tangent(rng, Quantum(tn)), lattice=NoTangent()) +end + +# WARN not really 
type-piracy but almost, used in `Ansatz` constructor +ChainRulesTestUtils.rand_tangent(::AbstractRNG, tn::T) where {V,T<:MetaGraph{V,SimpleGraph{V},<:Site}} = NoTangent() + +# WARN not really type-piracy but almost, used when testing `Ansatz` +function ChainRulesTestUtils.test_approx( + actual::G, expected::G, msg; kwargs... +) where {G<:MetaGraph{Int64,SimpleGraph{Int64},<:Site}} + return actual == expected +end end diff --git a/ext/TenetFiniteDifferencesExt.jl b/ext/TenetFiniteDifferencesExt.jl index 7b3308428..a0355eed9 100644 --- a/ext/TenetFiniteDifferencesExt.jl +++ b/ext/TenetFiniteDifferencesExt.jl @@ -20,4 +20,18 @@ function FiniteDifferences.to_vec(x::Dict{Vector{Symbol},Tensor}) return x_vec, Dict_from_vec end +function FiniteDifferences.to_vec(x::Quantum) + x_vec, back = to_vec(TensorNetwork(x)) + Quantum_from_vec(v) = Quantum(back(v), copy(x.sites)) + + return x_vec, Quantum_from_vec +end + +function FiniteDifferences.to_vec(x::Ansatz) + x_vec, back = to_vec(Quantum(x)) + Ansatz_from_vec(v) = Ansatz(back(v), copy(x.lattice)) + + return x_vec, Ansatz_from_vec +end + end diff --git a/ext/TenetGraphMakieExt.jl b/ext/TenetGraphMakieExt.jl index e8e121d7e..c013d4df2 100644 --- a/ext/TenetGraphMakieExt.jl +++ b/ext/TenetGraphMakieExt.jl @@ -1,9 +1,9 @@ module TenetGraphMakieExt +using Tenet using GraphMakie +using Graphs using Makie -const Graphs = GraphMakie.Graphs -using Tenet using Combinatorics: combinations """ diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index 47d97beaa..62bbf336b 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -3,20 +3,23 @@ module TenetReactantExt using Tenet using EinExprs using Reactant -using Reactant: @reactant_override +using Reactant: @reactant_override, TracedRArray const MLIR = Reactant.MLIR const stablehlo = MLIR.Dialects.stablehlo const Enzyme = Reactant.Enzyme +Reactant.traced_type(::Type{T}, _, _) where {T<:Tenet.AbstractTensorNetwork} = T +Reactant.traced_getfield(x::TensorNetwork, 
i::Int) = tensors(x)[i] + function Reactant.make_tracer( - seen::IdDict, @nospecialize(prev::RT), path::Tuple, mode::Reactant.TraceMode; kwargs... + seen, @nospecialize(prev::RT), path::Tuple, mode::Reactant.TraceMode; kwargs... ) where {RT<:Tensor} tracedata = Reactant.make_tracer(seen, parent(prev), Reactant.append_path(path, :data), mode; kwargs...) return Tensor(tracedata, inds(prev)) end -function Reactant.make_tracer(seen::IdDict, prev::TensorNetwork, path::Tuple, mode::Reactant.TraceMode; kwargs...) +function Reactant.make_tracer(seen, prev::TensorNetwork, path::Tuple, mode::Reactant.TraceMode; kwargs...) tracetensors = Vector{Tensor}(undef, Tenet.ntensors(prev)) for (i, tensor) in enumerate(tensors(prev)) tracetensors[i] = Reactant.make_tracer(seen, tensor, Reactant.append_path(path, i), mode; kwargs...) @@ -24,22 +27,29 @@ function Reactant.make_tracer(seen::IdDict, prev::TensorNetwork, path::Tuple, mo return TensorNetwork(tracetensors) end -Reactant.traced_getfield(x::TensorNetwork, i::Int) = tensors(x)[i] - -function Reactant.make_tracer(seen::IdDict, prev::Quantum, path::Tuple, mode::Reactant.TraceMode; kwargs...) +function Reactant.make_tracer(seen, prev::Quantum, path::Tuple, mode::Reactant.TraceMode; kwargs...) tracetn = Reactant.make_tracer(seen, TensorNetwork(prev), Reactant.append_path(path, :tn), mode; kwargs...) return Quantum(tracetn, copy(prev.sites)) end -function Reactant.make_tracer(seen::IdDict, prev::Tenet.Product, path::Tuple, mode::Reactant.TraceMode; kwargs...) - tracequantum = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :super), mode; kwargs...) - return Tenet.Product(tracequantum) +function Reactant.make_tracer(seen, prev::Ansatz, path::Tuple, mode::Reactant.TraceMode; kwargs...) + tracetn = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :tn), mode; kwargs...) 
+ return Ansatz(tracetn, copy(Tenet.lattice(prev))) end -# TODO try rely on generic fallback for ansatzes -> do it when refactoring to MPS/MPO -function Reactant.make_tracer(seen::IdDict, prev::Tenet.Chain, path::Tuple, mode::Reactant.TraceMode; kwargs...) - tracequantum = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :super), mode; kwargs...) - return Tenet.Chain(tracequantum, boundary(prev)) +# TODO try rely on generic fallback for ansatzes +for A in (Product, Dense) + @eval function Reactant.make_tracer(seen, prev::$A, path::Tuple, mode::Reactant.TraceMode; kwargs...) + tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) + return $A(tracetn) + end +end + +for A in (MPS, MPO) + @eval function Reactant.make_tracer(seen, prev::$A, path::Tuple, mode::Reactant.TraceMode; kwargs...) + tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) + return $A(tracetn, form(prev)) + end end function Reactant.create_result(@nospecialize(tocopy::Tensor), @nospecialize(path), result_stores) @@ -59,10 +69,24 @@ function Reactant.create_result(tocopy::Quantum, @nospecialize(path), result_sto return :($Quantum($tn, $(copy(tocopy.sites)))) end -# TODO try rely on generic fallback for ansatzes -> do it when refactoring to MPS/MPO -function Reactant.create_result(tocopy::Tenet.Chain, @nospecialize(path), result_stores) - qtn = Reactant.create_result(Quantum(tocopy), Reactant.append_path(path, :super), result_stores) - return :($(Tenet.Chain)($qtn, $(boundary(tocopy)))) +function Reactant.create_result(tocopy::Ansatz, @nospecialize(path), result_stores) + tn = Reactant.create_result(Quantum(tocopy), Reactant.append_path(path, :tn), result_stores) + return :($Ansatz($tn, $(copy(Tenet.lattice(tocopy))))) +end + +# TODO try rely on generic fallback for ansatzes +for A in (Product, Dense) + @eval function Reactant.create_result(tocopy::A, @nospecialize(path), result_stores) 
where {A<:$A} + tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) + return :($A($tn)) + end +end + +for A in (MPS, MPO) + @eval function Reactant.create_result(tocopy::A, @nospecialize(path), result_stores) where {A<:$A} + tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) + return :($A($tn, $(Tenet.form(tocopy)))) + end end function Reactant.push_val!(ad_inputs, x::TensorNetwork, path) @@ -124,8 +148,8 @@ end end function Tenet.contract( - a::Tensor{Ta,Na,Aa}, b::Tensor{Tb,Nb,Ab}; dims=(∩(inds(a), inds(b))), out=nothing -) where {Ta,Na,Aa<:Reactant.TracedRArray,Tb,Nb,Ab<:Reactant.TracedRArray} + a::Tensor{Ta,Na,TracedRArray{Ta,Na}}, b::Tensor{Tb,Nb,TracedRArray{Tb,Nb}}; dims=(∩(inds(a), inds(b))), out=nothing +) where {Ta,Na,Tb,Nb} ia = collect(inds(a)) ib = collect(inds(b)) i = ∩(dims, ia, ib) @@ -154,12 +178,12 @@ function Tenet.contract( result = Reactant.MLIR.IR.result(stablehlo.einsum(op_a, op_b; result_0, einsum_config)) - data = Reactant.TracedRArray{T,length(ic)}((), result, rsize) + data = TracedRArray{T,length(ic)}((), result, rsize) _res = Tensor(data, ic) return _res end -function Tenet.contract(a::Tensor{T,N,A}; dims=nonunique(inds(a)), out=nothing) where {T,N,A<:Reactant.TracedRArray} +function Tenet.contract(a::Tensor{T,N,TracedRArray{T,N}}; dims=nonunique(inds(a)), out=nothing) where {T,N} ia = inds(a) i = ∩(dims, ia) @@ -178,8 +202,13 @@ function Tenet.contract(a::Tensor{T,N,A}; dims=nonunique(inds(a)), out=nothing) result = Reactant.MLIR.IR.result(stablehlo.unary_einsum(operand; result_0, einsum_config)) - data = Reactant.TracedRArray{T,length(ic)}((), result, rsize) + data = TracedRArray{T,length(ic)}((), result, rsize) return Tensor(data, ic) end +Tenet.contract(a::Tensor, b::Tensor{T,N,TracedRArray{T,N}}; kwargs...) where {T,N} = contract(b, a; kwargs...) +function Tenet.contract(a::Tensor{Ta,Na,TracedRArray{Ta,Na}}, b::Tensor{Tb,Nb}; kwargs...) 
where {Ta,Na,Tb,Nb} + return contract(a, Tensor(Reactant.promote_to(TracedRArray{Tb,Nb}, parent(b)), inds(b)); kwargs...) +end + end diff --git a/src/Ansatz/Ansatz.jl b/src/Ansatz/Ansatz.jl index 84fe9fffb..22ba076b5 100644 --- a/src/Ansatz/Ansatz.jl +++ b/src/Ansatz/Ansatz.jl @@ -1,33 +1,78 @@ +using KeywordDispatch using LinearAlgebra +using Graphs +using MetaGraphsNext + +# Traits +abstract type Boundary end +struct Open <: Boundary end +struct Periodic <: Boundary end + +function boundary end + +abstract type Form end +struct NonCanonical <: Form end +struct MixedCanonical <: Form + orthogonality_center::Union{Site,Vector{Site}} +end +struct Canonical <: Form end + +function form end + +struct MissingSchmidtCoefficientsException <: Base.Exception + bond::NTuple{2,Site} +end + +MissingSchmidtCoefficientsException(bond::Vector{<:Site}) = MissingSchmidtCoefficientsException(tuple(bond...)) + +function Base.showerror(io::IO, e::MissingSchmidtCoefficientsException) + return print(io, "Can't access the spectrum on bond $(e.bond)") +end + +abstract type AbstractAnsatz <: AbstractQuantum end """ Ansatz -[`AbstractQuantum`](@ref) Tensor Network with a predefined structure. +[`AbstractQuantum`](@ref) Tensor Network with a preserving structure. +""" +struct Ansatz <: AbstractAnsatz + tn::Quantum + lattice::MetaGraph{Int,G,Site{1},Nothing} where {G<:Graphs.AbstractGraph{Int}} + + function Ansatz(tn, lattice) + if !issetequal(lanes(tn), labels(lattice)) + throw(ArgumentError("Sites of the tensor network and the lattice must be equal")) + end + return new(tn, lattice) + end +end -# Notes +Ansatz(tn::Ansatz) = tn +Quantum(tn::AbstractAnsatz) = Ansatz(tn).tn - - Any subtype must define `super::Quantum` field or specialize the `Quantum` method. 
-"""
-abstract type Ansatz <: AbstractQuantum end
+Base.copy(tn::Ansatz) = Ansatz(copy(Quantum(tn)), copy(lattice(tn)))
+Base.similar(tn::Ansatz) = Ansatz(similar(Quantum(tn)), copy(lattice(tn)))
+Base.zero(tn::Ansatz) = Ansatz(zero(Quantum(tn)), copy(lattice(tn)))
 
-# TODO maybe we need to change this?
-Quantum(@nospecialize tn::Ansatz) = tn.super
+lattice(tn::AbstractAnsatz) = Ansatz(tn).lattice
 
-Base.:(==)(a::Ansatz, b::Ansatz) = Quantum(a) == Quantum(b)
-Base.isapprox(a::Ansatz, b::Ansatz; kwargs...) = isapprox(Quantum(a), Quantum(b); kwargs...)
+function Base.isapprox(a::AbstractAnsatz, b::AbstractAnsatz; kwargs...)
+    return lattice(a) == lattice(b) && isapprox(Quantum(a), Quantum(b); kwargs...)
+end
 
-alias(::A) where {A} = string(A)
-function Base.summary(io::IO, tn::A) where {A<:Ansatz}
-    return print(io, "$(alias(tn)) (inputs=$(nsites(tn; set=:inputs)), outputs=$(nsites(tn; set=:outputs)))")
+Graphs.neighbors(tn::AbstractAnsatz, site::Site) = neighbor_labels(lattice(tn), site)
+function isneighbor(tn::AbstractAnsatz, a::Site, b::Site)
+    lt = lattice(tn)
+    return has_edge(lt, MetaGraphsNext.code_for(lt, a), MetaGraphsNext.code_for(lt, b))
 end
-Base.show(io::IO, tn::A) where {A<:Ansatz} = summary(io, tn)
 
-@kwmethod function inds(tn::Ansatz; bond)
+@kwmethod function inds(tn::AbstractAnsatz; bond)
     (site1, site2) = bond
     @assert site1 ∈ sites(tn) "Site $site1 not found"
     @assert site2 ∈ sites(tn) "Site $site2 not found"
     @assert site1 != site2 "Sites must be different"
+    @assert isneighbor(tn, site1, site2) "Sites must be neighbors"
 
     tensor1 = tensors(tn; at=site1)
     tensor2 = tensors(tn; at=site2)
@@ -36,43 +81,280 @@ Base.show(io::IO, tn::A) where {A<:Ansatz} = summary(io, tn)
     return only(inds(tensor1) ∩ inds(tensor2))
 end
 
-@kwmethod function Tenet.tensors(tn::Ansatz; between)
-    (site1, site2) = between
-    @assert site1 ∈ sites(tn) "Site $site1 not found"
-    @assert site2 ∈ sites(tn) "Site $site2 not found"
-    @assert site1 != site2 "Sites must be different"
+@kwmethod 
function tensors(tn::AbstractAnsatz; bond)
+    vind = inds(tn; bond)
+    return only(
+        tensors(tn, [vind]) do vinds, indices
+            indices == vinds
+        end,
+    )
+end
 
-    tensor1 = tensors(tn; at=site1)
-    tensor2 = tensors(tn; at=site2)
+@kwmethod function tensors(tn::AbstractAnsatz; between)
+    Base.depwarn(
+        "`tensors(tn; between)` is deprecated, use `tensors(tn; bond)` instead.",
+        ((Base.Core).Typeof(tensors)).name.mt.name,
+    )
+    return tensors(tn; bond=between)
+end
 
-    isdisjoint(inds(tensor1), inds(tensor2)) && return nothing
+@kwmethod contract!(tn::AbstractAnsatz; bond) = contract!(tn, inds(tn; bond))
+
+canonize(tn::AbstractAnsatz, args...; kwargs...) = canonize!(deepcopy(tn), args...; kwargs...)
+canonize_site(tn::AbstractAnsatz, args...; kwargs...) = canonize_site!(deepcopy(tn), args...; kwargs...)
+
+"""
+    truncate(tn::AbstractAnsatz, bond; threshold = nothing, maxdim = nothing)
+
+Like [`truncate!`](@ref), but returns a new tensor network instead of modifying the original one.
+"""
+truncate(tn::AbstractAnsatz, args...; kwargs...) = truncate!(deepcopy(tn), args...; kwargs...)
+
+"""
+    truncate!(tn::AbstractAnsatz, bond; threshold = nothing, maxdim = nothing)
+
+Truncate the dimension of the virtual `bond` of an [`Ansatz`](@ref) Tensor Network by keeping only the `maxdim` largest Schmidt coefficients or those larger than `threshold`.
 
-    return tn[only(inds(tensor1) ∩ inds(tensor2))]
+# Notes
+
+  - At least one of `threshold` or `maxdim` must be provided. If both are provided, the `threshold` cut is searched within the first `maxdim` coefficients.
+  - The bond must contain the Schmidt coefficients, i.e. a site canonization must be performed before calling `truncate!`.
+"""
+function truncate!(tn::AbstractAnsatz, bond; threshold=nothing, maxdim=nothing)
+    @assert !(isnothing(maxdim) && isnothing(threshold)) "Either `threshold` or `maxdim` must be provided"
+
+    spectrum = parent(tensors(tn; bond))
+    vind = inds(tn; bond)
+
+    maxdim = isnothing(maxdim) ? size(tn, vind) : min(maxdim, size(tn, vind)) # never index past the bond dimension
+
+    extent = if isnothing(threshold)
+        1:maxdim
+    else
+        # no coefficient below `threshold` => `findfirst` gives `nothing`: keep all `maxdim`
+        cut = findfirst(i -> abs(spectrum[i]) < threshold, 1:maxdim)
+        1:(isnothing(cut) ? maxdim : cut - 1)
+    end
+
+    slice!(tn, vind, extent)
+
+    return tn
 end
 
-struct MissingSchmidtCoefficientsException <: Base.Exception
-    bond::NTuple{2,Site}
+function expect(ψ::AbstractAnsatz, observables; bra=copy(ψ))
+    ϕ = bra
+
+    # TODO is this ok?
+    for observable in observables
+        evolve!(ϕ, observable)
+    end
+
+    return overlap(ϕ, ψ)
 end
 
-MissingSchmidtCoefficientsException(bond::Vector{<:Site}) = MissingSchmidtCoefficientsException(tuple(bond...))
+overlap(a::AbstractAnsatz, b::AbstractAnsatz) = contract(merge(a, copy(b)'))
 
-function Base.showerror(io::IO, e::MissingSchmidtCoefficientsException)
-    return print(io, "Can't access the spectrum on bond $(e.bond)")
+function evolve!(ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, renormalize=false)
+    return simple_update!(ψ, gate; threshold, maxdim, renormalize)
 end
 
-function LinearAlgebra.norm(ψ::Ansatz, p::Real=2; kwargs...)
-    p == 2 || throw(ArgumentError("only L2-norm is implemented yet"))
+# by popular demand (Stefano, I'm looking at you), I aliased `apply!` to `evolve!`
+const apply! = evolve!
+
+function simple_update!(ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, kwargs...)
+    @assert issetequal(adjoint.(sites(gate; set=:inputs)), sites(gate; set=:outputs)) "Inputs of the gate must match outputs"
+
+    if nlanes(gate) == 1
+        return simple_update_1site!(ψ, gate)
+    end
+
+    @assert isneighbor(ψ, lanes(gate)...) "Gate must act on neighboring sites"
 
-    return LinearAlgebra.norm2(ψ; kwargs...)
+    return simple_update!(form(ψ), ψ, gate; threshold, maxdim, kwargs...) # forward the truncation keywords captured above
 end
 
-function LinearAlgebra.norm2(ψ::Ansatz; kwargs...)
-    return abs(sqrt(only(contract(merge(TensorNetwork(ψ), TensorNetwork(ψ')); kwargs...))))
+# TODO a lot of problems with merging... 
maybe we shouldn't merge manually
+function simple_update_1site!(ψ::AbstractAnsatz, gate)
+    @assert nlanes(gate) == 1 "Gate must act only on one lane"
+    @assert ninputs(gate) == 1 "Gate must have only one input"
+    @assert noutputs(gate) == 1 "Gate must have only one output"
+
+    # shallow copy to avoid problems if errors in mid execution
+    gate = copy(gate)
+    resetindex!(gate; init=ninds(ψ))
+
+    contracting_index = gensym(:tmp)
+    targetsite = only(sites(gate; set=:inputs))'
+
+    # reindex output of gate to match TN sitemap
+    replace!(gate, inds(gate; at=only(sites(gate; set=:outputs))) => inds(ψ; at=targetsite))
+
+    # reindex contracting index
+    replace!(ψ, inds(ψ; at=targetsite) => contracting_index)
+    replace!(gate, inds(gate; at=targetsite') => contracting_index)
+
+    # contract gate with TN
+    merge!(ψ, gate; reset=false)
+    return contract!(ψ, contracting_index)
 end
 
-# Traits
-abstract type Boundary end
-struct Open <: Boundary end
-struct Periodic <: Boundary end
+# TODO remove `renormalize` argument?
+function simple_update!(::NonCanonical, ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, renormalize=false)
+    @assert nlanes(gate) == 2 "Only 2-site gates are supported currently"
+    @assert isneighbor(ψ, lanes(gate)...) "Gate must act on neighboring sites"
 
-function boundary end
+    # shallow copy to avoid problems if errors in mid execution
+    gate = copy(gate)
+    resetindex!(gate; init=ninds(ψ))
+    @reindex! outputs(gate) => inputs(gate)
+
+    # contract involved sites (keep every index of each site tensor EXCEPT the bond, which `contract!` removes)
+    bond = (sitel, siter) = extrema(lanes(gate))
+    vind = inds(ψ; bond)
+    linds = filter(!=(vind), inds(tensors(ψ; at=sitel)))
+    rinds = filter(!=(vind), inds(tensors(ψ; at=siter)))
+    contract!(ψ; bond)
+
+    # contract physical inds with gate
+    merge!(ψ, gate; reset=false)
+    contract!(ψ, inds(gate; set=:inputs))
+
+    # decompose using SVD
+    svd!(ψ; left_inds=linds, right_inds=rinds, virtualind=vind)
+
+    # truncate virtual index
+    if any(!isnothing, (threshold, maxdim))
+        truncate!(ψ, bond; threshold, maxdim)
+        renormalize && normalize!(ψ, bond[1])
+    end
+
+    return ψ
+end
+
+# TODO remove `renormalize` argument?
+# TODO refactor code
+function simple_update!(::Canonical, ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, renormalize=false)
+    @assert nlanes(gate) == 2 "Only 2-site gates are supported currently"
+    @assert isneighbor(ψ, lanes(gate)...) "Gate must act on neighboring sites"
+
+    # shallow copy to avoid problems if errors in mid execution
+    gate = copy(gate)
+
+    bond = sitel, siter = minmax(sites(gate; set=:outputs)...)
+    left_inds::Vector{Symbol} = !isnothing(leftindex(ψ, sitel)) ? [leftindex(ψ, sitel)] : Symbol[]
+    right_inds::Vector{Symbol} = !isnothing(rightindex(ψ, siter)) ? 
[rightindex(ψ, siter)] : Symbol[] + + virtualind::Symbol = inds(ψ; bond=bond) + + contract_2sitewf!(ψ, bond) + + # reindex contracting index + contracting_inds = [gensym(:tmp) for _ in sites(gate; set=:inputs)] + replace!( + ψ, + map(zip(sites(gate; set=:inputs), contracting_inds)) do (site, contracting_index) + inds(ψ; at=site') => contracting_index + end, + ) + replace!( + gate, + map(zip(sites(gate; set=:inputs), contracting_inds)) do (site, contracting_index) + inds(gate; at=site) => contracting_index + end, + ) + + # replace output indices of the gate for gensym indices + output_inds = [gensym(:out) for _ in sites(gate; set=:outputs)] + replace!( + gate, + map(zip(sites(gate; set=:outputs), output_inds)) do (site, out) + inds(gate; at=site) => out + end, + ) + + # reindex output of gate to match TN sitemap + for site in sites(gate; set=:outputs) + if inds(ψ; at=site) != inds(gate; at=site) + replace!(gate, inds(gate; at=site) => inds(ψ; at=site)) + end + end + + # contract physical inds + merge!(ψ, gate) + contract!(ψ, contracting_inds) + + # decompose using SVD + push!(left_inds, inds(ψ; at=sitel)) + push!(right_inds, inds(ψ; at=siter)) + + unpack_2sitewf!(ψ, bond, left_inds, right_inds, virtualind) + + # truncate virtual index + if any(!isnothing, [threshold, maxdim]) + truncate!(ψ, bond; threshold, maxdim) + renormalize && normalize!(tensors(ψ; between=bond)) + end + + return ψ +end + +# TODO refactor code +""" + contract_2sitewf!(ψ::AbstractAnsatz, bond) + +For a given [`AbstractAnsatz`](@ref) in the canonical form, creates the two-site wave function θ with Λᵢ₋₁Γᵢ₋₁ΛᵢΓᵢΛᵢ₊₁, +where i is the `bond`, and replaces the Γᵢ₋₁ΛᵢΓᵢ tensors with θ. 
+"""
+function contract_2sitewf!(ψ::AbstractAnsatz, bond)
+    @assert form(ψ) == Canonical() "The tensor network must be in canonical form"
+
+    sitel, siter = bond # TODO Check if bond is valid
+    (0 < id(sitel) < nsites(ψ) || 0 < id(siter) < nsites(ψ)) ||
+        throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))"))
+
+    Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel))
+    Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1)))
+
+    !isnothing(Λᵢ₋₁) && contract!(ψ; between=(Site(id(sitel) - 1), sitel), direction=:right, delete_Λ=false)
+    !isnothing(Λᵢ₊₁) && contract!(ψ; between=(siter, Site(id(siter) + 1)), direction=:left, delete_Λ=false)
+
+    contract!(ψ, inds(ψ; bond=bond))
+
+    return ψ
+end
+
+# TODO refactor code
+"""
+    unpack_2sitewf!(ψ::AbstractAnsatz, bond, left_inds, right_inds, virtualind)
+
+For a given [`AbstractAnsatz`](@ref) that contains a two-site wave function θ in a bond, it decomposes θ into the canonical
+form: Γᵢ₋₁ΛᵢΓᵢ, where i is the `bond`.
+"""
+function unpack_2sitewf!(ψ::AbstractAnsatz, bond, left_inds, right_inds, virtualind)
+    @assert form(ψ) == Canonical() "The tensor network must be in canonical form"
+
+    sitel, siter = bond # TODO Check if bond is valid
+    (0 < id(sitel) < nsites(ψ) || 0 < id(siter) < nsites(ψ)) ||
+        throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))"))
+
+    Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel))
+    Λᵢ₊₁ = id(siter) == nsites(ψ) ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1)))
+
+    # do svd of the θ tensor
+    θ = tensors(ψ; at=sitel)
+    U, s, Vt = svd(θ; left_inds, right_inds, virtualind)
+
+    # contract with the inverse of Λᵢ₋₁ and Λᵢ₊₁
+    Γᵢ₋₁ =
+        isnothing(Λᵢ₋₁) ? U : contract(U, Tensor(diag(pinv(Diagonal(parent(Λᵢ₋₁)); atol=1e-32)), inds(Λᵢ₋₁)); dims=())
+    Γᵢ =
+        isnothing(Λᵢ₊₁) ? 
Vt : contract(Tensor(diag(pinv(Diagonal(parent(Λᵢ₊₁)); atol=1e-32)), inds(Λᵢ₊₁)), Vt; dims=()) + + delete!(ψ, θ) + + push!(ψ, Γᵢ₋₁) + push!(ψ, s) + push!(ψ, Γᵢ) + + return ψ +end diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl deleted file mode 100644 index bc7332728..000000000 --- a/src/Ansatz/Chain.jl +++ /dev/null @@ -1,749 +0,0 @@ -using LinearAlgebra -using Random - -struct Chain <: Ansatz - super::Quantum - boundary::Boundary -end - -Base.copy(tn::Chain) = Chain(copy(Quantum(tn)), boundary(tn)) - -Base.similar(tn::Chain) = Chain(similar(Quantum(tn)), boundary(tn)) -Base.zero(tn::Chain) = Chain(zero(Quantum(tn)), boundary(tn)) - -boundary(tn::Chain) = tn.boundary - -MPS(arrays) = Chain(State(), Open(), arrays) -pMPS(arrays) = Chain(State(), Periodic(), arrays) -MPO(arrays) = Chain(Operator(), Open(), arrays) -pMPO(arrays) = Chain(Operator(), Periodic(), arrays) - -alias(tn::Chain) = alias(socket(tn), boundary(tn), tn) -alias(::State, ::Open, ::Chain) = "MPS" -alias(::State, ::Periodic, ::Chain) = "pMPS" -alias(::Operator, ::Open, ::Chain) = "MPO" -alias(::Operator, ::Periodic, ::Chain) = "pMPO" - -function Chain(tn::TensorNetwork, sites, args...; kwargs...) - return Chain(Quantum(tn, sites), args...; kwargs...) 
-end - -defaultorder(::Type{Chain}, ::State) = (:o, :l, :r) -defaultorder(::Type{Chain}, ::Operator) = (:o, :i, :l, :r) - -function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; order=defaultorder(Chain, State())) - @assert all(==(3) ∘ ndims, arrays) "All arrays must have 3 dimensions" - issetequal(order, defaultorder(Chain, State())) || - throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, State())))")) - - n = length(arrays) - gen = IndexCounter() - symbols = [nextindex!(gen) for _ in 1:(2n)] - - _tensors = map(enumerate(arrays)) do (i, array) - inds = map(order) do dir - if dir == :o - symbols[i] - elseif dir == :r - symbols[n + mod1(i, n)] - elseif dir == :l - symbols[n + mod1(i - 1, n)] - else - throw(ArgumentError("Invalid direction: $dir")) - end - end - Tensor(array, inds) - end - - sitemap = Dict(Site(i) => symbols[i] for i in 1:n) - - return Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) -end - -function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order=defaultorder(Chain, State())) - @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" - @assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions" - @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" - issetequal(order, defaultorder(Chain, State())) || - throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, State())))")) - - n = length(arrays) - gen = IndexCounter() - symbols = [nextindex!(gen) for _ in 1:(2n)] - - _tensors = map(enumerate(arrays)) do (i, array) - _order = if i == 1 - filter(x -> x != :l, order) - elseif i == n - filter(x -> x != :r, order) - else - order - end - - inds = map(_order) do dir - if dir == :o - symbols[i] - elseif dir == :r - symbols[n + mod1(i, n)] - elseif dir == :l - symbols[n + mod1(i - 1, n)] - else - throw(ArgumentError("Invalid direction: $dir")) - end - end - Tensor(array, inds) - end - - sitemap = 
Dict(Site(i) => symbols[i] for i in 1:n) - - return Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) -end - -function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; order=defaultorder(Chain, Operator())) - @assert all(==(4) ∘ ndims, arrays) "All arrays must have 4 dimensions" - issetequal(order, defaultorder(Chain, Operator())) || - throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))")) - - n = length(arrays) - gen = IndexCounter() - symbols = [nextindex!(gen) for _ in 1:(3n)] - - _tensors = map(enumerate(arrays)) do (i, array) - inds = map(order) do dir - if dir == :o - symbols[i] - elseif dir == :i - symbols[i + n] - elseif dir == :l - symbols[2n + mod1(i - 1, n)] - elseif dir == :r - symbols[2n + mod1(i, n)] - else - throw(ArgumentError("Invalid direction: $dir")) - end - end - Tensor(array, inds) - end - - sitemap = Dict(Site(i) => symbols[i] for i in 1:n) - merge!(sitemap, Dict(Site(i; dual=true) => symbols[i + n] for i in 1:n)) - - return Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) -end - -function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; order=defaultorder(Chain, Operator())) - @assert ndims(arrays[1]) == 3 "First array must have 3 dimensions" - @assert all(==(4) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 4 dimensions" - @assert ndims(arrays[end]) == 3 "Last array must have 3 dimensions" - issetequal(order, defaultorder(Chain, Operator())) || - throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))")) - - n = length(arrays) - gen = IndexCounter() - symbols = [nextindex!(gen) for _ in 1:(3n - 1)] - - _tensors = map(enumerate(arrays)) do (i, array) - _order = if i == 1 - filter(x -> x != :l, order) - elseif i == n - filter(x -> x != :r, order) - else - order - end - - inds = map(_order) do dir - if dir == :o - symbols[i] - elseif dir == :i - symbols[i + n] - elseif dir == :l - symbols[2n + 
mod1(i - 1, n)] - elseif dir == :r - symbols[2n + mod1(i, n)] - else - throw(ArgumentError("Invalid direction: $dir")) - end - end - Tensor(array, inds) - end - - sitemap = Dict(Site(i) => symbols[i] for i in 1:n) - merge!(sitemap, Dict(Site(i; dual=true) => symbols[i + n] for i in 1:n)) - - return Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) -end - -function Base.convert(::Type{Chain}, qtn::Product) - arrs::Vector{Array} = arrays(qtn) - arrs[1] = reshape(arrs[1], size(arrs[1])..., 1) - arrs[end] = reshape(arrs[end], size(arrs[end])..., 1) - map!(@view(arrs[2:(end - 1)]), @view(arrs[2:(end - 1)])) do arr - reshape(arr, size(arr)..., 1, 1) - end - - return Chain(socket(qtn), Open(), arrs) -end - -leftsite(tn::Chain, site::Site) = leftsite(boundary(tn), tn, site) -function leftsite(::Open, tn::Chain, site::Site) - return id(site) ∈ range(2, nlanes(tn)) ? Site(id(site) - 1; dual=isdual(site)) : nothing -end -leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) - 1, nlanes(tn)); dual=isdual(site)) - -rightsite(tn::Chain, site::Site) = rightsite(boundary(tn), tn, site) -function rightsite(::Open, tn::Chain, site::Site) - return id(site) ∈ range(1, nlanes(tn) - 1) ? Site(id(site) + 1; dual=isdual(site)) : nothing -end -rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) + 1, nlanes(tn)); dual=isdual(site)) - -leftindex(tn::Chain, site::Site) = leftindex(boundary(tn), tn, site) -leftindex(::Open, tn::Chain, site::Site) = site == site"1" ? nothing : leftindex(Periodic(), tn, site) -leftindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond=(site, leftsite(tn, site))) - -rightindex(tn::Chain, site::Site) = rightindex(boundary(tn), tn, site) -function rightindex(::Open, tn::Chain, site::Site) - return site == Site(nlanes(tn); dual=isdual(site)) ? 
nothing : rightindex(Periodic(), tn, site) -end -rightindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond=(site, rightsite(tn, site))) - -Base.adjoint(chain::Chain) = Chain(adjoint(Quantum(chain)), boundary(chain)) - -struct ChainSampler{B<:Boundary,S<:Socket,NT<:NamedTuple} <: Random.Sampler{Chain} - parameters::NT - - ChainSampler{B,S}(; kwargs...) where {B,S} = new{B,S,typeof(values(kwargs))}(values(kwargs)) -end - -function Base.rand(A::Type{<:Chain}, B::Type{<:Boundary}, S::Type{<:Socket}; kwargs...) - return rand(Random.default_rng(), A, B, S; kwargs...) -end - -function Base.rand(rng::AbstractRNG, ::Type{A}, ::Type{B}, ::Type{S}; kwargs...) where {A<:Chain,B<:Boundary,S<:Socket} - return rand(rng, ChainSampler{B,S}(; kwargs...), B, S) -end - -# TODO let choose the orthogonality center -function Base.rand(rng::Random.AbstractRNG, sampler::ChainSampler, ::Type{Open}, ::Type{State}) - n = sampler.parameters.n - χ = sampler.parameters.χ - p = get(sampler.parameters, :p, 2) - T = get(sampler.parameters, :eltype, Float64) - - arrays::Vector{AbstractArray{T,N} where {N}} = map(1:n) do i - χl, χr = let after_mid = i > n ÷ 2, i = (n + 1 - abs(2i - n - 1)) ÷ 2 - χl = min(χ, p^(i - 1)) - χr = min(χ, p^i) - - # swap bond dims after mid and handle midpoint for odd-length MPS - (isodd(n) && i == n ÷ 2 + 1) ? (χl, χl) : (after_mid ? 
(χr, χl) : (χl, χr)) - end - - # orthogonalize by QR factorization - F = lq!(rand(rng, T, χl, p * χr)) - - reshape(Matrix(F.Q), χl, p, χr) - end - - # reshape boundary sites - arrays[1] = reshape(arrays[1], p, p) - arrays[n] = reshape(arrays[n], p, p) - - return Chain(State(), Open(), arrays; order=(:l, :o, :r)) -end - -# TODO different input/output physical dims -function Base.rand(rng::Random.AbstractRNG, sampler::ChainSampler, ::Type{Open}, ::Type{Operator}) - n = sampler.parameters.n - χ = sampler.parameters.χ - p = get(sampler.parameters, :p, 2) - T = get(sampler.parameters, :eltype, Float64) - - ip = op = p - - arrays::Vector{AbstractArray{T,N} where {N}} = map(1:n) do i - χl, χr = let after_mid = i > n ÷ 2, i = (n + 1 - abs(2i - n - 1)) ÷ 2 - χl = min(χ, ip^(i - 1) * op^(i - 1)) - χr = min(χ, ip^i * op^i) - - # swap bond dims after mid and handle midpoint for odd-length MPS - (isodd(n) && i == n ÷ 2 + 1) ? (χl, χl) : (after_mid ? (χr, χl) : (χl, χr)) - end - - # orthogonalize by QR factorization - F = lq!(rand(rng, T, χl, ip * op * χr)) - reshape(Matrix(F.Q), χl, ip, op, χr) - end - - # reshape boundary sites - arrays[1] = reshape(arrays[1], p, p, min(χ, ip * op)) - arrays[n] = reshape(arrays[n], min(χ, ip * op), p, p) - - # TODO order might not be the best for performance - return Chain(Operator(), Open(), arrays; order=(:l, :i, :o, :r)) -end - -# """ -# Tenet.contract!(tn::Chain; between=(site1, site2), direction::Symbol = :left, delete_Λ = true) - -# For a given [`Chain`](@ref) tensor network, contracts the singular values Λ between two sites `site1` and `site2`. -# The `direction` keyword argument specifies the direction of the contraction, and the `delete_Λ` keyword argument -# specifies whether to delete the singular values tensor after the contraction. 
-# """ -@kwmethod contract(tn::Chain; between, direction, delete_Λ) = contract!(copy(tn); between, direction, delete_Λ) -@kwmethod function contract!(tn::Chain; between, direction, delete_Λ) - site1, site2 = between - Λᵢ = tensors(tn; between) - Λᵢ === nothing && return tn - - if direction === :right - Γᵢ₊₁ = tensors(tn; at=site2) - replace!(tn, Γᵢ₊₁ => contract(Γᵢ₊₁, Λᵢ; dims=())) - elseif direction === :left - Γᵢ = tensors(tn; at=site1) - replace!(tn, Γᵢ => contract(Λᵢ, Γᵢ; dims=())) - else - throw(ArgumentError("Unknown direction=:$direction")) - end - - delete_Λ && delete!(TensorNetwork(tn), Λᵢ) - - return tn -end -@kwmethod contract(tn::Chain; between) = contract(tn; between, direction=:left, delete_Λ=true) -@kwmethod contract!(tn::Chain; between) = contract!(tn; between, direction=:left, delete_Λ=true) -@kwmethod contract(tn::Chain; between, direction) = contract(tn; between, direction, delete_Λ=true) -@kwmethod contract!(tn::Chain; between, direction) = contract!(tn; between, direction, delete_Λ=true) - -canonize_site(tn::Chain, args...; kwargs...) = canonize_site!(deepcopy(tn), args...; kwargs...) -canonize_site!(tn::Chain, args...; kwargs...) = canonize_site!(boundary(tn), tn, args...; kwargs...) - -# NOTE: in method == :svd the spectral weights are stored in a vector connected to the now virtual hyperindex! 
-function canonize_site!(::Open, tn::Chain, site::Site; direction::Symbol, method=:qr) - left_inds = Symbol[] - right_inds = Symbol[] - - virtualind = if direction === :left - site == Site(1) && throw(ArgumentError("Cannot right-canonize left-most tensor")) - push!(right_inds, leftindex(tn, site)) - - site == Site(nsites(tn)) || push!(left_inds, rightindex(tn, site)) - push!(left_inds, inds(tn; at=site)) - - only(right_inds) - elseif direction === :right - site == Site(nsites(tn)) && throw(ArgumentError("Cannot left-canonize right-most tensor")) - push!(right_inds, rightindex(tn, site)) - - site == Site(1) || push!(left_inds, leftindex(tn, site)) - push!(left_inds, inds(tn; at=site)) - - only(right_inds) - else - throw(ArgumentError("Unknown direction=:$direction")) - end - - tmpind = gensym(:tmp) - if method === :svd - svd!(TensorNetwork(tn); left_inds, right_inds, virtualind=tmpind) - elseif method === :qr - qr!(TensorNetwork(tn); left_inds, right_inds, virtualind=tmpind) - else - throw(ArgumentError("Unknown factorization method=:$method")) - end - - contract!(tn, virtualind) - replace!(tn, tmpind => virtualind) - - return tn -end - -truncate(tn::Chain, args...; kwargs...) = truncate!(deepcopy(tn), args...; kwargs...) - -""" - truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real} = nothing, maxdim::Union{Nothing,Int} = nothing) - -Truncate the dimension of the virtual `bond`` of the [`Chain`](@ref) Tensor Network by keeping only the `maxdim` largest Schmidt coefficients or those larger than`threshold`. - -# Notes - - - Either `threshold` or `maxdim` must be provided. If both are provided, `maxdim` is used. - - The bond must contain the Schmidt coefficients, i.e. a site canonization must be performed before calling `truncate!`. 
-""" -function truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real}=nothing, maxdim::Union{Nothing,Int}=nothing) - # TODO replace for tensors(; between) - vind = rightindex(qtn, bond[1]) - if vind != leftindex(qtn, bond[2]) - throw(ArgumentError("Invalid bond $bond")) - end - - if vind ∉ inds(qtn; set=:hyper) - throw(MissingSchmidtCoefficientsException(bond)) - end - - tensor = TensorNetwork(qtn)[vind] - spectrum = parent(tensor) - - extent = collect( - if !isnothing(maxdim) - 1:min(size(qtn, vind), maxdim) - else - 1:size(qtn, vind) - end, - ) - - # remove 0s from spectrum - if isnothing(threshold) - threshold = 1e-16 - end - - filter!(extent) do i - abs(spectrum[i]) > threshold - end - - slice!(qtn, vind, extent) - - return qtn -end - -function isleftcanonical(qtn::Chain, site; atol::Real=1e-12) - right_ind = rightindex(qtn, site) - tensor = tensors(qtn; at=site) - - # we are at right-most site, we need to add an extra dummy dimension to the tensor - if isnothing(right_ind) - right_ind = gensym(:dummy) - tensor = Tensor(reshape(parent(tensor), size(tensor)..., 1), (inds(tensor)..., right_ind)) - end - - # TODO is replace(conj(A)...) copying too much? - contracted = contract(tensor, replace(conj(tensor), right_ind => gensym(:new_ind))) - n = size(tensor, right_ind) - identity_matrix = Matrix(I, n, n) - - return isapprox(contracted, identity_matrix; atol) -end - -function isrightcanonical(qtn::Chain, site; atol::Real=1e-12) - left_ind = leftindex(qtn, site) - tensor = tensors(qtn; at=site) - - # we are at left-most site, we need to add an extra dummy dimension to the tensor - if isnothing(left_ind) - left_ind = gensym(:dummy) - tensor = Tensor(reshape(parent(tensor), 1, size(tensor)...), (left_ind, inds(tensor)...)) - end - - #TODO is replace(conj(A)...) copying too much? 
- contracted = contract(tensor, replace(conj(tensor), left_ind => gensym(:new_ind))) - n = size(tensor, left_ind) - identity_matrix = Matrix(I, n, n) - - return isapprox(contracted, identity_matrix; atol) -end - -canonize(tn::Chain, args...; kwargs...) = canonize!(copy(tn), args...; kwargs...) -canonize!(tn::Chain, args...; kwargs...) = canonize!(boundary(tn), tn, args...; kwargs...) - -""" -canonize(boundary::Boundary, tn::Chain) - -Transform a `Chain` tensor network into the canonical form (Vidal form), that is, -we have the singular values matrix Λᵢ between each tensor Γᵢ₋₁ and Γᵢ. -""" -function canonize!(::Open, tn::Chain) - Λ = Tensor[] - - # right-to-left QR sweep, get right-canonical tensors - for i in nsites(tn):-1:2 - canonize_site!(tn, Site(i); direction=:left, method=:qr) - end - - # left-to-right SVD sweep, get left-canonical tensors and singular values without reversing - for i in 1:(nsites(tn) - 1) - canonize_site!(tn, Site(i); direction=:right, method=:svd) - - # extract the singular values and contract them with the next tensor - Λᵢ = pop!(TensorNetwork(tn), tensors(tn; between=(Site(i), Site(i + 1)))) - Aᵢ₊₁ = tensors(tn; at=Site(i + 1)) - replace!(tn, Aᵢ₊₁ => contract(Aᵢ₊₁, Λᵢ; dims=())) - push!(Λ, Λᵢ) - end - - for i in 2:nsites(tn) # tensors at i in "A" form, need to contract (Λᵢ)⁻¹ with A to get Γᵢ - Λᵢ = Λ[i - 1] # singular values start between site 1 and 2 - A = tensors(tn; at=Site(i)) - Γᵢ = contract(A, Tensor(diag(pinv(Diagonal(parent(Λᵢ)); atol=1e-64)), inds(Λᵢ)); dims=()) - replace!(tn, A => Γᵢ) - push!(TensorNetwork(tn), Λᵢ) - end - - return tn -end - -mixed_canonize(tn::Chain, args...; kwargs...) = mixed_canonize!(deepcopy(tn), args...; kwargs...) -mixed_canonize!(tn::Chain, args...; kwargs...) = mixed_canonize!(boundary(tn), tn, args...; kwargs...) 
- -""" - mixed_canonize!(boundary::Boundary, tn::Chain, center::Site) - -Transform a `Chain` tensor network into the mixed-canonical form, that is, -for i < center the tensors are left-canonical and for i >= center the tensors are right-canonical, -and in the center there is a matrix with singular values. -""" -function mixed_canonize!(::Open, tn::Chain, center::Site) # TODO: center could be a range of sites - # left-to-right QR sweep (left-canonical tensors) - for i in 1:(id(center) - 1) - canonize_site!(tn, Site(i); direction=:right, method=:qr) - end - - # right-to-left QR sweep (right-canonical tensors) - for i in nsites(tn):-1:(id(center) + 1) - canonize_site!(tn, Site(i); direction=:left, method=:qr) - end - - # center SVD sweep to get singular values - canonize_site!(tn, center; direction=:left, method=:svd) - - return tn -end - -""" - LinearAlgebra.normalize!(tn::Chain, center::Site) - -Normalizes the input [`Chain`](@ref) tensor network by transforming it -to mixed-canonized form with the given center site. -""" -function LinearAlgebra.normalize!(tn::Chain, root::Site; p::Real=2) - mixed_canonize!(tn, root) - normalize!(tensors(tn; between=(Site(id(root) - 1), root)), p) - return tn -end - -""" - evolve!(qtn::Chain, gate) - -Applies a local operator `gate` to the [`Chain`](@ref) tensor network. -""" -function evolve!(qtn::Chain, gate::Dense; threshold=nothing, maxdim=nothing, iscanonical=false, renormalize=false) - # check gate is a valid operator - if !(socket(gate) isa Operator) - throw(ArgumentError("Gate must be an operator, but got $(socket(gate))")) - end - - # TODO refactor out to `islane`? - if !issetequal(adjoint.(sites(gate; set=:inputs)), sites(gate; set=:outputs)) - throw( - ArgumentError( - "Gate inputs ($(sites(gate; set=:inputs))) and outputs ($(sites(gate; set=:outputs))) must be the same" - ), - ) - end - - # TODO refactor out to `canconnect`? 
- if adjoint.(sites(gate; set=:inputs)) ⊈ sites(qtn; set=:outputs) - throw( - ArgumentError("Gate inputs ($(sites(gate; set=:inputs))) must be a subset of the TN sites ($(sites(qtn)))") - ) - end - - if nlanes(gate) == 1 - evolve_1site!(qtn, gate) - elseif nlanes(gate) == 2 - # check gate sites are contiguous - # TODO refactor this out? - gate_inputs = sort!(id.(sites(gate; set=:inputs))) - range = UnitRange(extrema(gate_inputs)...) - - range != gate_inputs && throw(ArgumentError("Gate lanes must be contiguous")) - - # TODO check correctly for periodic boundary conditions - evolve_2site!(qtn, gate; threshold, maxdim, iscanonical, renormalize) - else - # TODO generalize for more than 2 lanes - throw(ArgumentError("Invalid number of lanes $(nlanes(gate)), maximum is 2")) - end - - return qtn -end - -function evolve_1site!(qtn::Chain, gate::Dense) - # shallow copy to avoid problems if errors in mid execution - gate = copy(gate) - resetindex!(gate; init=ninds(qtn)) - - contracting_index = gensym(:tmp) - targetsite = only(sites(gate; set=:inputs))' - - # reindex output of gate to match TN sitemap - replace!(gate, inds(gate; at=only(sites(gate; set=:outputs))) => inds(qtn; at=targetsite)) - - # reindex contracting index - replace!(qtn, inds(qtn; at=targetsite) => contracting_index) - replace!(gate, inds(gate; at=targetsite') => contracting_index) - - # contract gate with TN - merge!(qtn, gate; reset=false) - return contract!(qtn, contracting_index) -end - -# TODO: Maybe rename iscanonical kwarg ? -function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim, iscanonical=false, renormalize=false) - # shallow copy to avoid problems if errors in mid execution - gate = copy(gate) - - bond = sitel, siter = minmax(sites(gate; set=:outputs)...) - left_inds::Vector{Symbol} = !isnothing(leftindex(qtn, sitel)) ? [leftindex(qtn, sitel)] : Symbol[] - right_inds::Vector{Symbol} = !isnothing(rightindex(qtn, siter)) ? 
[rightindex(qtn, siter)] : Symbol[] - - virtualind::Symbol = inds(qtn; bond=bond) - - iscanonical ? contract_2sitewf!(qtn, bond) : contract!(TensorNetwork(qtn), virtualind) - - # reindex contracting index - contracting_inds = [gensym(:tmp) for _ in sites(gate; set=:inputs)] - replace!( - TensorNetwork(qtn), - map(zip(sites(gate; set=:inputs), contracting_inds)) do (site, contracting_index) - inds(qtn; at=site') => contracting_index - end, - ) - replace!( - Quantum(gate), - map(zip(sites(gate; set=:inputs), contracting_inds)) do (site, contracting_index) - inds(gate; at=site) => contracting_index - end, - ) - - # replace output indices of the gate for gensym indices - output_inds = [gensym(:out) for _ in sites(gate; set=:outputs)] - replace!( - Quantum(gate), - map(zip(sites(gate; set=:outputs), output_inds)) do (site, out) - inds(gate; at=site) => out - end, - ) - - # reindex output of gate to match TN sitemap - for site in sites(gate; set=:outputs) - if inds(qtn; at=site) != inds(gate; at=site) - replace!(TensorNetwork(gate), inds(gate; at=site) => inds(qtn; at=site)) - end - end - - # contract physical inds - merge!(TensorNetwork(qtn), TensorNetwork(gate)) - contract!(qtn, contracting_inds) - - # decompose using SVD - push!(left_inds, inds(qtn; at=sitel)) - push!(right_inds, inds(qtn; at=siter)) - - if iscanonical - unpack_2sitewf!(qtn, bond, left_inds, right_inds, virtualind) - else - svd!(TensorNetwork(qtn); left_inds, right_inds, virtualind) - end - # truncate virtual index - if any(!isnothing, [threshold, maxdim]) - truncate!(qtn, bond; threshold, maxdim) - - # renormalize the bond - if renormalize && iscanonical - λ = tensors(qtn; between=bond) - replace!(qtn, λ => normalize(λ)) # TODO this can be replaced by `normalize!(λ)` - elseif renormalize && !iscanonical - normalize!(qtn, bond[1]) - end - end - - return qtn -end - -""" - contract_2sitewf!(ψ::Chain, bond) - -For a given [`Chain`](@ref) in the canonical form, creates the two-site wave function θ with 
Λᵢ₋₁Γᵢ₋₁ΛᵢΓᵢΛᵢ₊₁, -where i is the `bond`, and replaces the Γᵢ₋₁ΛᵢΓᵢ tensors with θ. -""" -function contract_2sitewf!(ψ::Chain, bond) - # TODO Check if ψ is in canonical form - - sitel, siter = bond # TODO Check if bond is valid - (0 < id(sitel) < nsites(ψ) || 0 < id(siter) < nsites(ψ)) || - throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))")) - - Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel)) - Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1))) - - !isnothing(Λᵢ₋₁) && contract!(ψ; between=(Site(id(sitel) - 1), sitel), direction=:right, delete_Λ=false) - !isnothing(Λᵢ₊₁) && contract!(ψ; between=(siter, Site(id(siter) + 1)), direction=:left, delete_Λ=false) - - contract!(ψ, inds(ψ; bond=bond)) - - return ψ -end - -""" - unpack_2sitewf!(ψ::Chain, bond) - -For a given [`Chain`](@ref) that contains a two-site wave function θ in a bond, it decomposes θ into the canonical -form: Γᵢ₋₁ΛᵢΓᵢ, where i is the `bond`. -""" -function unpack_2sitewf!(ψ::Chain, bond, left_inds, right_inds, virtualind) - # TODO Check if ψ is in canonical form - - sitel, siter = bond # TODO Check if bond is valid - (0 < id(sitel) < nsites(ψ) || 0 < id(site_r) < nsites(ψ)) || - throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))")) - - Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel)) - Λᵢ₊₁ = id(siter) == nsites(ψ) ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1))) - - # do svd of the θ tensor - θ = tensors(ψ; at=sitel) - U, s, Vt = svd(θ; left_inds, right_inds, virtualind) - - # contract with the inverse of Λᵢ and Λᵢ₊₂ - Γᵢ₋₁ = - isnothing(Λᵢ₋₁) ? U : contract(U, Tensor(diag(pinv(Diagonal(parent(Λᵢ₋₁)); atol=1e-32)), inds(Λᵢ₋₁)); dims=()) - Γᵢ = - isnothing(Λᵢ₊₁) ? 
Vt : contract(Tensor(diag(pinv(Diagonal(parent(Λᵢ₊₁)); atol=1e-32)), inds(Λᵢ₊₁)), Vt; dims=()) - - delete!(TensorNetwork(ψ), θ) - - push!(TensorNetwork(ψ), Γᵢ₋₁) - push!(TensorNetwork(ψ), s) - push!(TensorNetwork(ψ), Γᵢ) - - return ψ -end - -function expect(ψ::Chain, observables) - # contract observable with TN - ϕ = copy(ψ) - for observable in observables - evolve!(ϕ, observable) - end - - # contract evolved TN with adjoint of original TN - tn = merge!(TensorNetwork(ϕ), TensorNetwork(ψ')) - - return contract(tn) -end - -overlap(a::Chain, b::Chain) = overlap(socket(a), a, socket(b), b) - -# TODO fix optimal path -function overlap(::State, a::Chain, ::State, b::Chain) - @assert issetequal(sites(a), sites(b)) "Ansatzes must have the same sites" - - b = copy(b) - b = @reindex! outputs(a) => outputs(b) - - tn = merge(TensorNetwork(a), TensorNetwork(b')) - return contract(tn) -end - -# TODO optimize -overlap(a::Product, b::Chain) = contract(merge(Quantum(a), Quantum(b)')) -overlap(a::Chain, b::Product) = contract(merge(Quantum(a), Quantum(b)')) diff --git a/src/Ansatz/Dense.jl b/src/Ansatz/Dense.jl index a0fa5355f..8bc2b69f4 100644 --- a/src/Ansatz/Dense.jl +++ b/src/Ansatz/Dense.jl @@ -1,15 +1,24 @@ -struct Dense <: Ansatz - super::Quantum +using Combinatorics + +struct Dense <: AbstractAnsatz + tn::Ansatz end +Ansatz(tn::Dense) = tn.tn + +Base.copy(qtn::Dense) = Dense(copy(Ansatz(qtn))) +Base.similar(qtn::Dense) = Dense(similar(Ansatz(qtn))) +Base.zero(qtn::Dense) = Dense(zero(Ansatz(qtn))) + function Dense(::State, array::AbstractArray; sites=Site.(1:ndims(array))) - @assert ndims(array) > 0 + n = ndims(array) + @assert n > 0 @assert all(>(1), size(array)) gen = IndexCounter() - symbols = [nextindex!(gen) for _ in 1:ndims(array)] + symbols = [nextindex!(gen) for _ in 1:n] sitemap = Dict{Site,Symbol}( - map(sites, 1:ndims(array)) do site, i + map(sites, 1:n) do site, i site => symbols[i] end, ) @@ -18,23 +27,46 @@ function Dense(::State, array::AbstractArray; 
sites=Site.(1:ndims(array))) tn = TensorNetwork([tensor]) qtn = Quantum(tn, sitemap) - return Dense(qtn) + graph = complete_graph(nlanes(qtn)) + lattice = MetaGraph(graph, lanes(qtn) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) + ansatz = Ansatz(qtn, lattice) + return Dense(ansatz) end function Dense(::Operator, array::AbstractArray; sites) - @assert ndims(array) > 0 + n = ndims(array) + @assert n > 0 @assert all(>(1), size(array)) - @assert length(sites) == ndims(array) + @assert length(sites) == n gen = IndexCounter() - tensor_inds = [nextindex!(gen) for _ in 1:ndims(array)] + tensor_inds = [nextindex!(gen) for _ in 1:n] tensor = Tensor(array, tensor_inds) tn = TensorNetwork([tensor]) sitemap = Dict{Site,Symbol}(map(splat(Pair), zip(sites, tensor_inds))) qtn = Quantum(tn, sitemap) + graph = complete_graph(nlanes(qtn)) + lattice = MetaGraph(graph, lanes(qtn) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) + ansatz = Ansatz(qtn, lattice) + return Dense(ansatz) +end - return Dense(qtn) +function Base.rand(rng::Random.AbstractRNG, ::Type{Dense}, ::State; n, eltype=Float64, physdim=2) + array = rand(rng, eltype, fill(physdim, n)...) + normalize!(array) + return Dense(State(), array; sites=Site.(1:n)) end -Base.copy(qtn::Dense) = Dense(copy(Quantum(qtn))) +function LinearAlgebra.normalize!(ψ::Dense) + normalize!(only(arrays(ψ))) + return ψ +end + +function overlap(ϕ::Dense, ψ::Dense) + @assert lanes(ϕ) == lanes(ψ) + @assert socket(ϕ) == State() && socket(ψ) == State() + ψ = copy(ψ) + @reindex! 
outputs(ϕ) => outputs(ψ) + return contract(only(tensors(ϕ)), only(tensors(ψ))) +end diff --git a/src/Ansatz/Grid.jl b/src/Ansatz/Grid.jl deleted file mode 100644 index ef59f3e84..000000000 --- a/src/Ansatz/Grid.jl +++ /dev/null @@ -1,180 +0,0 @@ -struct Grid <: Ansatz - super::Quantum - boundary::Boundary -end - -Base.copy(tn::Grid) = Grid(copy(Quantum(tn)), boundary(tn)) - -boundary(tn::Grid) = tn.boundary - -PEPS(arrays) = Grid(State(), Open(), arrays) -pPEPS(arrays) = Grid(State(), Periodic(), arrays) -PEPO(arrays) = Grid(Operator(), Open(), arrays) -pPEPO(arrays) = Grid(Operator(), Periodic(), arrays) - -alias(tn::Grid) = alias(socket(tn), boundary(tn), tn) -alias(::State, ::Open, ::Grid) = "PEPS" -alias(::State, ::Periodic, ::Grid) = "pPEPS" -alias(::Operator, ::Open, ::Grid) = "PEPO" -alias(::Operator, ::Periodic, ::Grid) = "pPEPO" - -function Grid(::State, ::Periodic, arrays::Matrix{<:AbstractArray}) - @assert all(==(4) ∘ ndims, arrays) "All arrays must have 4 dimensions" - - m, n = size(arrays) - gen = IndexCounter() - pinds = map(_ -> nextindex!(gen), arrays) - hvinds = map(_ -> nextindex!(gen), arrays) - vvinds = map(_ -> nextindex!(gen), arrays) - - _tensors = map(eachindex(IndexCartesian(), arrays)) do I - i, j = Tuple(I) - - array = arrays[i, j] - pind = pinds[i, j] - up, down = hvinds[i, j], hvinds[mod1(i + 1, m), j] - left, right = vvinds[i, j], vvinds[i, mod1(j + 1, n)] - - # TODO customize order - Tensor(array, [pind, up, down, left, right]) - end - - sitemap = Dict(Site(i, j) => pinds[i, j] for i in 1:m, j in 1:n) - - return Grid(Quantum(TensorNetwork(_tensors), sitemap), Periodic()) -end - -function Grid(::State, ::Open, arrays::Matrix{<:AbstractArray}) - m, n = size(arrays) - - predicate = all(eachindex(arrays)) do I - i, j = Tuple(I) - array = arrays[i, j] - - N = ndims(array) - 1 - (i == 1 || i == m) && (N -= 1) - (j == 1 || j == n) && (N -= 1) - - N > 0 - end - - if !predicate - throw(DimensionMismatch()) - end - - gen = IndexCounter() - 
pinds = map(_ -> nextindex!(gen), arrays) - vvinds = [nextindex!(gen) for _ in 1:(m - 1), _ in 1:n] - hvinds = [nextindex!(gen) for _ in 1:m, _ in 1:(n - 1)] - - _tensors = map(eachindex(IndexCartesian(), arrays)) do I - i, j = Tuple(I) - - array = arrays[i, j] - pind = pinds[i, j] - up = i == 1 ? missing : vvinds[i - 1, j] - down = i == m ? missing : vvinds[i, j] - left = j == 1 ? missing : hvinds[i, j - 1] - right = j == n ? missing : hvinds[i, j] - - # TODO customize order - Tensor(array, collect(skipmissing([pind, up, down, left, right]))) - end - - sitemap = Dict(Site(i, j) => pinds[i, j] for i in 1:m, j in 1:n) - - return Grid(Quantum(TensorNetwork(_tensors), sitemap), Open()) -end - -function Grid(::Operator, ::Periodic, arrays::Matrix{<:AbstractArray}) - @assert all(==(4) ∘ ndims, arrays) "All arrays must have 4 dimensions" - - m, n = size(arrays) - gen = IndexCounter() - ipinds = map(_ -> nextindex!(gen), arrays) - opinds = map(_ -> nextindex!(gen), arrays) - hvinds = map(_ -> nextindex!(gen), arrays) - vvinds = map(_ -> nextindex!(gen), arrays) - - _tensors = map(eachindex(IndexCartesian(), arrays)) do I - i, j = Tuple(I) - - array = arrays[i, j] - ipind, opind = ipinds[i, j], opinds[i, j] - up, down = hvinds[i, j], hvinds[mod1(i + 1, m), j] - left, right = vvinds[i, j], vvinds[i, mod1(j + 1, n)] - - # TODO customize order - Tensor(array, [ipind, opind, up, down, left, right]) - end - - sitemap = Dict( - flatten([ - (Site(i, j; dual=true) => ipinds[i, j] for i in 1:m, j in 1:n), - (Site(i, j) => opinds[i, j] for i in 1:m, j in 1:n), - ]), - ) - - return Grid(Quantum(TensorNetwork(_tensors), sitemap), Periodic()) -end - -function Grid(::Operator, ::Open, arrays::Matrix{<:AbstractArray}) - m, n = size(arrays) - - predicate = all(eachindex(IndexCartesian(), arrays)) do I - i, j = Tuple(I) - array = arrays[i, j] - - N = ndims(array) - 2 - (i == 1 || i == m) && (N -= 1) - (j == 1 || j == n) && (N -= 1) - - N > 0 - end - - if !predicate - 
throw(DimensionMismatch()) - end - - gen = IndexCounter() - ipinds = map(_ -> nextindex!(gen), arrays) - opinds = map(_ -> nextindex!(gen), arrays) - vvinds = [nextindex!(gen) for _ in 1:(m - 1), _ in 1:n] - hvinds = [nextindex!(gen) for _ in 1:m, _ in 1:(n - 1)] - - _tensors = map(eachindex(IndexCartesian(), arrays)) do I - i, j = Tuple(I) - - array = arrays[i, j] - ipind = ipinds[i, j] - opind = opinds[i, j] - up = i == 1 ? missing : vvinds[i - 1, j] - down = i == m ? missing : vvinds[i, j] - left = j == 1 ? missing : hvinds[i, j - 1] - right = j == n ? missing : hvinds[i, j] - - # TODO customize order - Tensor(array, collect(skipmissing([ipind, opind, up, down, left, right]))) - end - - sitemap = Dict( - flatten([ - (Site(i, j; dual=true) => ipinds[i, j] for i in 1:m, j in 1:n), - (Site(i, j) => opinds[i, j] for i in 1:m, j in 1:n), - ]), - ) - - return Grid(Quantum(TensorNetwork(_tensors), sitemap), Open()) -end - -function LinearAlgebra.transpose!(qtn::Grid) - old = Quantum(qtn).sites - new = Dict(Site(reverse(id(site)); dual=isdual(site)) => ind for (site, ind) in old) - - empty!(old) - merge!(old, new) - - return qtn -end - -Base.transpose(qtn::Grid) = LinearAlgebra.transpose!(copy(qtn)) diff --git a/src/Ansatz/MPO.jl b/src/Ansatz/MPO.jl new file mode 100644 index 000000000..ddbe82a1d --- /dev/null +++ b/src/Ansatz/MPO.jl @@ -0,0 +1,135 @@ +using Random + +abstract type AbstractMPO <: AbstractAnsatz end + +struct MPO <: AbstractMPO + tn::Ansatz + form::Form +end + +Ansatz(tn::MPO) = tn.tn + +Base.copy(x::MPO) = MPO(copy(Ansatz(x)), form(x)) +Base.similar(x::MPO) = MPO(similar(Ansatz(x)), form(x)) +Base.zero(x::MPO) = MPO(zero(Ansatz(x)), form(x)) + +defaultorder(::Type{MPO}) = (:o, :i, :l, :r) +boundary(::MPO) = Open() +form(tn::MPO) = tn.form + +function MPO(arrays::Vector{<:AbstractArray}; order=defaultorder(MPO)) + @assert ndims(arrays[1]) == 3 "First array must have 3 dimensions" + @assert all(==(4) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have
4 dimensions" + @assert ndims(arrays[end]) == 3 "Last array must have 3 dimensions" + issetequal(order, defaultorder(MPO)) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(MPO)))")) + + n = length(arrays) + gen = IndexCounter() + symbols = [nextindex!(gen) for _ in 1:(3n - 1)] + + tn = TensorNetwork( + map(enumerate(arrays)) do (i, array) + _order = if i == 1 + filter(x -> x != :l, order) + elseif i == n + filter(x -> x != :r, order) + else + order + end + + inds = map(_order) do dir + if dir == :o + symbols[i] + elseif dir == :i + symbols[i + n] + elseif dir == :l + symbols[2n + mod1(i - 1, n)] + elseif dir == :r + symbols[2n + mod1(i, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + Tensor(array, inds) + end, + ) + + sitemap = Dict(Site(i) => symbols[i] for i in 1:n) + merge!(sitemap, Dict(Site(i; dual=true) => symbols[i + n] for i in 1:n)) + qtn = Quantum(tn, sitemap) + graph = path_graph(n) + lattice = MetaGraph(graph, lanes(qtn) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) + ansatz = Ansatz(qtn, lattice) + return MPO(ansatz, NonCanonical()) +end + +function Base.convert(::Type{MPO}, tn::Product) + @assert socket(tn) == Operator() + + arrs::Vector{Array} = arrays(tn) + arrs[1] = reshape(arrs[1], size(arrs[1])..., 1) + arrs[end] = reshape(arrs[end], size(arrs[end])..., 1) + map!(@view(arrs[2:(end - 1)]), @view(arrs[2:(end - 1)])) do arr + reshape(arr, size(arr)..., 1, 1) + end + + return MPO(arrs) +end + +Base.adjoint(tn::MPO) = MPO(adjoint(Ansatz(tn)), form(tn)) + +# TODO different input/output physical dims +# TODO let choose the orthogonality center +function Base.rand(rng::Random.AbstractRNG, ::Type{MPO}; n, maxdim, eltype=Float64, physdim=2) + T = eltype + ip = op = physdim + χ = maxdim + + arrays::Vector{AbstractArray{T,N} where {N}} = map(1:n) do i + χl, χr = let after_mid = i > n ÷ 2, i = (n + 1 - abs(2i - n - 1)) ÷ 2 + χl = min(χ, ip^(i - 1) * op^(i - 1)) + χr = min(χ, ip^i * 
op^i) + + # swap bond dims after mid and handle midpoint for odd-length MPS + (isodd(n) && i == n ÷ 2 + 1) ? (χl, χl) : (after_mid ? (χr, χl) : (χl, χr)) + end + + # orthogonalize by QR factorization + F = lq!(rand(rng, T, χl, ip * op * χr)) + reshape(Matrix(F.Q), χl, ip, op, χr) + end + + # reshape boundary sites + arrays[1] = reshape(arrays[1], ip, op, min(χ, ip * op)) + arrays[n] = reshape(arrays[n], min(χ, ip * op), ip, op) + + # TODO order might not be the best for performance + return MPO(arrays; order=(:l, :i, :o, :r)) +end + +# TODO change it to `lanes`? +# TODO refactor common code with `MPS` +function sites(ψ::MPO, site::Site; dir) + if dir === :left + return site <= site"1" ? nothing : Site(id(site) - 1) + elseif dir === :right + return site >= Site(nlanes(ψ)) ? nothing : Site(id(site) + 1) + else + throw(ArgumentError("Unknown direction for MPO = :$dir")) + end +end + +@kwmethod function inds(ψ::MPO; at, dir) + if dir === :left && at == site"1" + return nothing + elseif dir === :right && at == Site(nlanes(ψ); dual=isdual(at)) + return nothing + elseif dir ∈ (:left, :right) + return inds(ψ; bond=(at, sites(ψ, at; dir))) + else + throw(ArgumentError("Unknown direction for MPO = :$dir")) + end +end + +function evolve!(ψ::MPS, op::MPO; threshold=nothing, maxdim=nothing, renormalize=false) end diff --git a/src/Ansatz/MPS.jl b/src/Ansatz/MPS.jl new file mode 100644 index 000000000..28f83f7f6 --- /dev/null +++ b/src/Ansatz/MPS.jl @@ -0,0 +1,327 @@ +using Random +using LinearAlgebra + +abstract type AbstractMPS <: AbstractAnsatz end + +mutable struct MPS <: AbstractMPS + const tn::Ansatz + form::Form +end + +Ansatz(tn::MPS) = tn.tn + +Base.copy(x::MPS) = MPS(copy(Ansatz(x)), form(x)) +Base.similar(x::MPS) = MPS(similar(Ansatz(x)), form(x)) +Base.zero(x::MPS) = MPS(zero(Ansatz(x)), form(x)) + +defaultorder(::Type{MPS}) = (:o, :l, :r) +boundary(::MPS) = Open() +form(tn::MPS) = tn.form + +function MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS)) + 
@assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" + @assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions" + @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" + issetequal(order, defaultorder(MPS)) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(MPS)))")) + + n = length(arrays) + gen = IndexCounter() + symbols = [nextindex!(gen) for _ in 1:(2n)] + + tn = TensorNetwork( + map(enumerate(arrays)) do (i, array) + _order = if i == 1 + filter(x -> x != :l, order) + elseif i == n + filter(x -> x != :r, order) + else + order + end + + inds = map(_order) do dir + if dir == :o + symbols[i] + elseif dir == :r + symbols[n + mod1(i, n)] + elseif dir == :l + symbols[n + mod1(i - 1, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + Tensor(array, inds) + end, + ) + + sitemap = Dict(Site(i) => symbols[i] for i in 1:n) + qtn = Quantum(tn, sitemap) + graph = path_graph(n) + lattice = MetaGraph(graph, lanes(qtn) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) + ansatz = Ansatz(qtn, lattice) + return MPS(ansatz, NonCanonical()) +end + +""" + Base.identity(::Type{MPS}, n::Integer; physdim=2, maxdim=physdim^(n ÷ 2)) + +Returns an [`MPS`](@ref) of `n` sites whose tensors are initialized to COPY-tensors. + +# Keyword Arguments + + - `physdim` The physical or output dimension of each site. Defaults to 2. + - `maxdim` The maximum bond dimension. Defaults to `physdim^(n ÷ 2)`. +""" +function Base.identity(::Type{MPS}, n::Integer; physdim=2, maxdim=physdim^(n ÷ 2)) + # Create bond dimensions until the middle of the MPS considering maxdim + virtualdims = min.(maxdim, physdim .^ (1:(n ÷ 2))) + + # Complete the bond dimensions of the other half of the MPS + virtualdims = vcat(virtualdims, virtualdims[(isodd(n) ? 
end : end - 1):-1:1]) + + # Create each site dimensions in default order (:o, :l, :r) + arraysdims = [[physdim, virtualdims[1]]] + append!(arraysdims, [[physdim, virtualdims[i], virtualdims[i + 1]] for i in 1:(length(virtualdims) - 1)]) + push!(arraysdims, [physdim, virtualdims[end]]) + + # Create the MPS with copy-tensors according to the tensors dimensions + return MPS( + map(arraysdims) do arrdims + arr = zeros(ComplexF64, arrdims...) + deltas = [fill(i, length(arrdims)) for i in 1:physdim] + broadcast(delta -> arr[delta...] = 1.0, deltas) + arr + end, + ) +end + +function Base.convert(::Type{MPS}, tn::Product) + @assert socket(tn) == State() + + arrs::Vector{Array} = arrays(tn) + arrs[1] = reshape(arrs[1], size(arrs[1])..., 1) + arrs[end] = reshape(arrs[end], size(arrs[end])..., 1) + map!(@view(arrs[2:(end - 1)]), @view(arrs[2:(end - 1)])) do arr + reshape(arr, size(arr)..., 1, 1) + end + + return MPS(arrs) +end + +Base.adjoint(tn::MPS) = MPS(adjoint(Ansatz(tn)), form(tn)) + +# TODO different input/output physical dims +# TODO let choose the orthogonality center +function Base.rand(rng::Random.AbstractRNG, ::Type{MPS}; n, maxdim, eltype=Float64, physdim=2) + p = physdim + T = eltype + χ = maxdim + + arrays::Vector{AbstractArray{T,N} where {N}} = map(1:n) do i + χl, χr = let after_mid = i > n ÷ 2, i = (n + 1 - abs(2i - n - 1)) ÷ 2 + χl = min(χ, p^(i - 1)) + χr = min(χ, p^i) + + # swap bond dims after mid and handle midpoint for odd-length MPS + (isodd(n) && i == n ÷ 2 + 1) ? (χl, χl) : (after_mid ? 
(χr, χl) : (χl, χr)) + end + + # orthogonalize by QR factorization + F = lq!(rand(rng, T, χl, p * χr)) + + reshape(Matrix(F.Q), χl, p, χr) + end + + # reshape boundary sites + arrays[1] = reshape(arrays[1], p, p) + arrays[n] = reshape(arrays[n], p, p) + + return MPS(arrays; order=(:l, :o, :r)) +end + +# TODO deprecate contract(; between) and generalize it to AbstractAnsatz +""" + Tenet.contract!(tn::MPS; between=(site1, site2), direction::Symbol = :left, delete_Λ = true) + +For a given [`MPS`](@ref) tensor network, contracts the singular values Λ between two sites `site1` and `site2`. +The `direction` keyword argument specifies the direction of the contraction, and the `delete_Λ` keyword argument +specifies whether to delete the singular values tensor after the contraction. +""" +@kwmethod contract(tn::MPS; between, direction, delete_Λ) = contract!(copy(tn); between, direction, delete_Λ) +@kwmethod function contract!(tn::MPS; between, direction, delete_Λ) + site1, site2 = between + Λᵢ = tensors(tn; between) + Λᵢ === nothing && return tn + + if direction === :right + Γᵢ₊₁ = tensors(tn; at=site2) + replace!(tn, Γᵢ₊₁ => contract(Γᵢ₊₁, Λᵢ; dims=())) + elseif direction === :left + Γᵢ = tensors(tn; at=site1) + replace!(tn, Γᵢ => contract(Λᵢ, Γᵢ; dims=())) + else + throw(ArgumentError("Unknown direction=:$direction")) + end + + delete_Λ && delete!(TensorNetwork(tn), Λᵢ) + + return tn +end +@kwmethod contract(tn::MPS; between) = contract(tn; between, direction=:left, delete_Λ=true) +@kwmethod contract!(tn::MPS; between) = contract!(tn; between, direction=:left, delete_Λ=true) +@kwmethod contract(tn::MPS; between, direction) = contract(tn; between, direction, delete_Λ=true) +@kwmethod contract!(tn::MPS; between, direction) = contract!(tn; between, direction, delete_Λ=true) + +# TODO rename it to `lanes`? +function sites(ψ::MPS, site::Site; dir) + if dir === :left + return site == site"1" ? 
nothing : Site(id(site) - 1) + elseif dir === :right + return site == Site(nsites(ψ)) ? nothing : Site(id(site) + 1) + else + throw(ArgumentError("Unknown direction for MPS = :$dir")) + end +end + +@kwmethod function inds(ψ::MPS; at, dir) + if dir === :left && at == site"1" + return nothing + elseif dir === :right && at == Site(nlanes(ψ); dual=isdual(at)) + return nothing + elseif dir ∈ (:left, :right) + return inds(ψ; bond=(at, sites(ψ, at; dir))) + else + throw(ArgumentError("Unknown direction for MPS = :$dir")) + end +end + +function isisometry(ψ::MPS, site; dir, atol::Real=1e-12) + tensor = tensors(ψ; at=site) + dirind = inds(ψ; at=site, dir) + + if isnothing(dirind) + @show parent(contract(tensor, conj(tensor))) + return isapprox(parent(contract(tensor, conj(tensor))), fill(true); atol) + end + + inda, indb = gensym(:a), gensym(:b) + a = replace(tensor, dirind => inda) + b = replace(conj(tensor), dirind => indb) + + n = size(tensor, dirind) + contracted = contract(a, b; out=[inda, indb]) + + return isapprox(contracted, I(n); atol) +end + +@deprecate isleftcanonical(ψ::MPS, site; atol::Real=1e-12) isisometry(ψ, site; dir=:right, atol) +@deprecate isrightcanonical(ψ::MPS, site; atol::Real=1e-12) isisometry(ψ, site; dir=:left, atol) + +# NOTE: in method == :svd the spectral weights are stored in a vector connected to the now virtual hyperindex! 
+function canonize_site!(ψ::MPS, site::Site; direction::Symbol, method=:qr) + left_inds = Symbol[] + right_inds = Symbol[] + + virtualind = if direction === :left + site == Site(1) && throw(ArgumentError("Cannot right-canonize left-most tensor")) + push!(right_inds, inds(ψ; at=site, dir=:left)) + + site == Site(nsites(ψ)) || push!(left_inds, inds(ψ; at=site, dir=:right)) + push!(left_inds, inds(ψ; at=site)) + + only(right_inds) + elseif direction === :right + site == Site(nsites(ψ)) && throw(ArgumentError("Cannot left-canonize right-most tensor")) + push!(right_inds, inds(ψ; at=site, dir=:right)) + + site == Site(1) || push!(left_inds, inds(ψ; at=site, dir=:left)) + push!(left_inds, inds(ψ; at=site)) + + only(right_inds) + else + throw(ArgumentError("Unknown direction=:$direction")) + end + + tmpind = gensym(:tmp) + if method === :svd + svd!(ψ; left_inds, right_inds, virtualind=tmpind) + elseif method === :qr + qr!(ψ; left_inds, right_inds, virtualind=tmpind) + else + throw(ArgumentError("Unknown factorization method=:$method")) + end + + contract!(ψ, virtualind) + replace!(ψ, tmpind => virtualind) + + return ψ +end + +""" + canonize!(tn::MPS) + +Transform a [`MPS`](@ref) tensor network into the canonical form (Vidal form); i.e. the singular values matrix Λᵢ between each tensor Γᵢ₋₁ and Γᵢ. 
+""" +function canonize!(ψ::MPS) + Λ = Tensor[] + + # right-to-left QR sweep, get right-canonical tensors + for i in nsites(ψ):-1:2 + canonize_site!(ψ, Site(i); direction=:left, method=:qr) + end + + # left-to-right SVD sweep, get left-canonical tensors and singular values without reversing + for i in 1:(nsites(ψ) - 1) + canonize_site!(ψ, Site(i); direction=:right, method=:svd) + + # extract the singular values and contract them with the next tensor + Λᵢ = pop!(ψ, tensors(ψ; between=(Site(i), Site(i + 1)))) + Aᵢ₊₁ = tensors(ψ; at=Site(i + 1)) + replace!(ψ, Aᵢ₊₁ => contract(Aᵢ₊₁, Λᵢ; dims=())) + push!(Λ, Λᵢ) + end + + for i in 2:nsites(ψ) # tensors at i in "A" form, need to contract (Λᵢ)⁻¹ with A to get Γᵢ + Λᵢ = Λ[i - 1] # singular values start between site 1 and 2 + A = tensors(ψ; at=Site(i)) + Γᵢ = contract(A, Tensor(diag(pinv(Diagonal(parent(Λᵢ)); atol=1e-64)), inds(Λᵢ)); dims=()) + replace!(ψ, A => Γᵢ) + push!(ψ, Λᵢ) + end + + return ψ +end + +mixed_canonize(tn::MPS, args...; kwargs...) = mixed_canonize!(deepcopy(tn), args...; kwargs...) + +# TODO mixed_canonize! at bond +""" + mixed_canonize!(tn::MPS, orthog_center) + +Transform a [`MPS`](@ref) tensor network into the mixed-canonical form, that is, +for `i < orthog_center` the tensors are left-canonical and for `i >= orthog_center` the tensors are right-canonical, +and in the `orthog_center` there is a matrix with singular values. +""" +function mixed_canonize!(tn::MPS, orthog_center) + # left-to-right QR sweep (left-canonical tensors) + for i in 1:(id(orthog_center) - 1) + canonize_site!(tn, Site(i); direction=:right, method=:qr) + end + + # right-to-left QR sweep (right-canonical tensors) + for i in nsites(tn):-1:(id(orthog_center) + 1) + canonize_site!(tn, Site(i); direction=:left, method=:qr) + end + + # center SVD sweep to get singular values + # canonize_site!(tn, orthog_center; direction=:left, method=:svd) + + return tn +end + +# TODO normalize! 
methods +function LinearAlgebra.normalize!(ψ::MPS, orthog_center=site"1") + mixed_canonize!(ψ, orthog_center) + normalize!(tensors(ψ; at=orthog_center), 2) + return ψ +end diff --git a/src/Ansatz/PEPO.jl b/src/Ansatz/PEPO.jl new file mode 100644 index 000000000..313a1851c --- /dev/null +++ b/src/Ansatz/PEPO.jl @@ -0,0 +1,102 @@ +abstract type AbstractPEPO <: AbstractAnsatz end + +struct PEPO <: AbstractPEPO + tn::Ansatz + form::Form +end + +Ansatz(tn::PEPO) = tn.tn + +Base.copy(x::PEPO) = PEPO(copy(Ansatz(x)), form(x)) +Base.similar(x::PEPO) = PEPO(similar(Ansatz(x)), form(x)) +Base.zero(x::PEPO) = PEPO(zero(Ansatz(x)), form(x)) + +defaultorder(::Type{PEPO}) = (:o, :i, :l, :r, :u, :d) +boundary(::PEPO) = Open() +form(tn::PEPO) = tn.form + +# TODO periodic boundary conditions +# TODO non-square lattice +function PEPO(arrays::Matrix{<:AbstractArray}; order=defaultorder(PEPO)) + @assert ndims(arrays[1, 1]) == 4 "Array at (1,1) must have 4 dimensions" + @assert ndims(arrays[1, end]) == 4 "Array at (1,end) must have 4 dimensions" + @assert ndims(arrays[end, 1]) == 4 "Array at (end,1) must have 4 dimensions" + @assert ndims(arrays[end, end]) == 4 "Array at (end,end) must have 4 dimensions" + @assert all( + ==(5) ∘ ndims, + Iterators.flatten([ + arrays[1, 2:(end - 1)], arrays[end, 2:(end - 1)], arrays[2:(end - 1), 1], arrays[2:(end - 1), end] + ]), + ) "Arrays at boundaries must have 5 dimensions" + @assert all(==(6) ∘ ndims, arrays[2:(end - 1), 2:(end - 1)]) "Inner arrays must have 6 dimensions" + issetequal(order, defaultorder(PEPO)) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(PEPO)))")) + + m, n = size(arrays) + + predicate = all(eachindex(IndexCartesian(), arrays)) do I + i, j = Tuple(I) + array = arrays[i, j] + + N = ndims(array) - 2 + (i == 1 || i == m) && (N -= 1) + (j == 1 || j == n) && (N -= 1) + + N > 0 + end + + if !predicate + throw(DimensionMismatch()) + end + + gen = IndexCounter() + ipinds = map(_ -> nextindex!(gen), 
arrays) + opinds = map(_ -> nextindex!(gen), arrays) + vvinds = [nextindex!(gen) for _ in 1:(m - 1), _ in 1:n] + hvinds = [nextindex!(gen) for _ in 1:m, _ in 1:(n - 1)] + + _tensors = map(eachindex(IndexCartesian(), arrays)) do I + i, j = Tuple(I) + + array = arrays[i, j] + ipind = ipinds[i, j] + opind = opinds[i, j] + up = i == 1 ? missing : vvinds[i - 1, j] + down = i == m ? missing : vvinds[i, j] + left = j == 1 ? missing : hvinds[i, j - 1] + right = j == n ? missing : hvinds[i, j] + + # TODO customize order + Tensor(array, collect(skipmissing([ipind, opind, up, down, left, right]))) + end + + sitemap = Dict( + flatten([ + (Site(i, j; dual=true) => ipinds[i, j] for i in 1:m, j in 1:n), + (Site(i, j) => opinds[i, j] for i in 1:m, j in 1:n), + ]), + ) + + qtn = Quantum(TensorNetwork(_tensors), sitemap) + graph = grid((m, n)) + # TODO fix this + lattice = MetaGraph(graph, Site.(vertices(graph)) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) + ansatz = Ansatz(qtn, lattice) + return PEPO(ansatz, NonCanonical()) +end + +function Base.convert(::Type{PEPO}, tn::Product) + @assert socket(tn) == Operator() + + # TODO fix this + arrs::Matrix{<:AbstractArray} = arrays(tn) + arrs[1] = reshape(arrs[1], size(arrs[1])..., 1) + arrs[end] = reshape(arrs[end], size(arrs[end])..., 1) + map!(@view(arrs[2:(end - 1)]), @view(arrs[2:(end - 1)])) do arr + reshape(arr, size(arr)..., 1, 1) + end + + return PEPO(arrs) +end + +Base.adjoint(tn::PEPO) = PEPO(adjoint(Ansatz(tn)), form(tn)) diff --git a/src/Ansatz/PEPS.jl b/src/Ansatz/PEPS.jl new file mode 100644 index 000000000..ef262b1d3 --- /dev/null +++ b/src/Ansatz/PEPS.jl @@ -0,0 +1,96 @@ +abstract type AbstractPEPS <: AbstractAnsatz end + +struct PEPS <: AbstractPEPS + tn::Ansatz + form::Form +end + +Ansatz(tn::PEPS) = tn.tn + +Base.copy(x::PEPS) = PEPS(copy(Ansatz(x)), form(x)) +Base.similar(x::PEPS) = PEPS(similar(Ansatz(x)), form(x)) +Base.zero(x::PEPS) = PEPS(zero(Ansatz(x)), form(x)) + +defaultorder(::Type{PEPS}) = (:o, :l, :r, :u, :d)
+boundary(::PEPS) = Open() +form(tn::PEPS) = tn.form + +# TODO periodic boundary conditions +# TODO non-square lattice +function PEPS(arrays::Matrix{<:AbstractArray}; order=defaultorder(PEPS)) + @assert ndims(arrays[1, 1]) == 3 "Array at (1,1) must have 3 dimensions" + @assert ndims(arrays[1, end]) == 3 "Array at (1,end) must have 3 dimensions" + @assert ndims(arrays[end, 1]) == 3 "Array at (end,1) must have 3 dimensions" + @assert ndims(arrays[end, end]) == 3 "Array at (end,end) must have 3 dimensions" + @assert all( + ==(4) ∘ ndims, + Iterators.flatten([ + arrays[1, 2:(end - 1)], arrays[end, 2:(end - 1)], arrays[2:(end - 1), 1], arrays[2:(end - 1), end] + ]), + ) "Arrays at boundaries must have 4 dimensions" + @assert all(==(5) ∘ ndims, arrays[2:(end - 1), 2:(end - 1)]) "Inner arrays must have 5 dimensions" + issetequal(order, defaultorder(PEPS)) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(PEPS)))")) + + m, n = size(arrays) + + # predicate = all(eachindex(IndexCartesian(), arrays)) do I + # i, j = Tuple(I) + # array = arrays[i, j] + + # N = ndims(array) - 1 + # (i == 1 || i == m) && (N -= 1) + # (j == 1 || j == n) && (N -= 1) + + # N > 0 + # end + + # if !predicate + # throw(DimensionMismatch()) + # end + + gen = IndexCounter() + pinds = map(_ -> nextindex!(gen), arrays) + vvinds = [nextindex!(gen) for _ in 1:(m - 1), _ in 1:n] + hvinds = [nextindex!(gen) for _ in 1:m, _ in 1:(n - 1)] + + tn = TensorNetwork( + map(eachindex(IndexCartesian(), arrays)) do I + i, j = Tuple(I) + + array = arrays[i, j] + pind = pinds[i, j] + up = i == 1 ? missing : vvinds[i - 1, j] + down = i == m ? missing : vvinds[i, j] + left = j == 1 ? missing : hvinds[i, j - 1] + right = j == n ? 
missing : hvinds[i, j] + + # TODO customize order + Tensor(array, collect(skipmissing([pind, up, down, left, right]))) + end, + ) + + sitemap = Dict(Site(i, j) => pinds[i, j] for i in 1:m, j in 1:n) + qtn = Quantum(tn, sitemap) + graph = grid((m, n)) + # TODO fix this + lattice = MetaGraph(graph, lanes(qtn) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) + ansatz = Ansatz(qtn, lattice) + return PEPS(ansatz, NonCanonical()) +end + +function Base.convert(::Type{PEPS}, tn::Product) + @assert socket(tn) == State() + + # TODO fix this + arrs::Matrix{<:AbstractArray} = arrays(tn) + arrs[1] = reshape(arrs[1], size(arrs[1])..., 1) + arrs[end] = reshape(arrs[end], size(arrs[end])..., 1) + map!(@view(arrs[2:(end - 1)]), @view(arrs[2:(end - 1)])) do arr + reshape(arr, size(arr)..., 1, 1) + end + + return PEPS(arrs) +end + +Base.adjoint(tn::PEPS) = PEPS(adjoint(Ansatz(tn)), form(tn)) diff --git a/src/Ansatz/Product.jl b/src/Ansatz/Product.jl index ab77d3be6..6412006d6 100644 --- a/src/Ansatz/Product.jl +++ b/src/Ansatz/Product.jl @@ -1,35 +1,34 @@ using LinearAlgebra +using Graphs +using MetaGraphsNext -struct Product <: Ansatz - super::Quantum +struct Product <: AbstractAnsatz + tn::Ansatz end -Base.copy(x::Product) = Product(copy(Quantum(x))) +Ansatz(tn::Product) = tn.tn -Base.similar(x::Product) = Product(similar(Quantum(x))) -Base.zero(x::Product) = Product(zero(Quantum(x))) +Base.copy(x::Product) = Product(copy(Ansatz(x))) +Base.similar(x::Product) = Product(similar(Ansatz(x))) +Base.zero(x::Product) = Product(zero(Ansatz(x))) -function Product(tn::TensorNetwork, sites) - @assert isempty(inds(tn; set=:inner)) "Product ansatz must not have inner indices" - return Product(Quantum(tn, sites)) -end - -Product(arrays::Vector{<:AbstractVector}) = Product(State(), Open(), arrays) -Product(arrays::Vector{<:AbstractMatrix}) = Product(Operator(), Open(), arrays) - -function Product(::State, ::Open, arrays) +function Product(arrays::Vector{<:AbstractVector}) + n =
length(arrays) gen = IndexCounter() - symbols = [nextindex!(gen) for _ in 1:length(arrays)] + symbols = [nextindex!(gen) for _ in 1:n] _tensors = map(enumerate(arrays)) do (i, array) Tensor(array, [symbols[i]]) end - sitemap = Dict(Site(i) => symbols[i] for i in 1:length(arrays)) - - return Product(TensorNetwork(_tensors), sitemap) + sitemap = Dict(Site(i) => symbols[i] for i in 1:n) + qtn = Quantum(TensorNetwork(_tensors), sitemap) + graph = Graph(n) + lattice = MetaGraph(graph, lanes(qtn) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) + ansatz = Ansatz(qtn, lattice) + return Product(ansatz) end -function Product(::Operator, ::Open, arrays) +function Product(arrays::Vector{<:AbstractMatrix}) n = length(arrays) gen = IndexCounter() symbols = [nextindex!(gen) for _ in 1:(2 * length(arrays))] @@ -38,18 +37,19 @@ function Product(::Operator, ::Open, arrays) end sitemap = merge!(Dict(Site(i; dual=true) => symbols[i] for i in 1:n), Dict(Site(i) => symbols[i + n] for i in 1:n)) - - return Product(TensorNetwork(_tensors), sitemap) + qtn = Quantum(TensorNetwork(_tensors), sitemap) + graph = Graph(n) + lattice = MetaGraph(graph, lanes(qtn) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) + ansatz = Ansatz(qtn, lattice) + return Product(ansatz) end function Base.zeros(::Type{Product}, n::Integer; p::Int=2, eltype=Bool) - return Product(State(), Open(), fill(append!([one(eltype)], collect(Iterators.repeated(zero(eltype), p - 1))), n)) + return Product(fill(append!([one(eltype)], collect(Iterators.repeated(zero(eltype), p - 1))), n)) end function Base.ones(::Type{Product}, n::Integer; p::Int=2, eltype=Bool) - return Product( - State(), Open(), fill(append!([zero(eltype), one(eltype)], collect(Iterators.repeated(zero(eltype), p - 2))), n) - ) + return Product(fill(append!([zero(eltype), one(eltype)], collect(Iterators.repeated(zero(eltype), p - 2))), n)) end LinearAlgebra.norm(tn::Product, p::Real=2) = LinearAlgebra.norm(socket(tn), tn, p) diff 
--git a/src/Numerics.jl b/src/Numerics.jl index 0e9a7c02c..7cdd81a2a 100644 --- a/src/Numerics.jl +++ b/src/Numerics.jl @@ -119,6 +119,73 @@ function factorinds(tensor, left_inds, right_inds) return left_inds, right_inds end +# TODO is this an `AbstractTensorNetwork`? +# TODO add fancier `show` method +struct TensorEigen{T,V,Nᵣ,S<:AbstractVector{V},U<:AbstractArray{T,Nᵣ}} <: Factorization{T} + values::Tensor{V,1,S} + vectors::Tensor{T,Nᵣ,U} + right_inds::Vector{Symbol} +end + +function Base.getproperty(obj::TensorEigen, name::Symbol) + if name === :U + return obj.vectors + elseif name === :Λ + return obj.values + elseif name ∈ [:Uinv, :U⁻¹] + U = reshape(parent(obj.vectors), prod(size(obj.vectors)[1:(end - 1)]), size(obj.vectors)[end]) + Uinv = inv(U) + return Tensor(Uinv, [only(inds(obj.values)), obj.right_inds...]) + end + return getfield(obj, name) +end + +function Base.inv(F::TensorEigen) + U = reshape(parent(F.vectors), prod(size(F.vectors)[1:(end - 1)]), size(F.vectors)[end]) + left_inds = inds(F.vectors)[1:(end - 1)] + return Tensor(U * inv(Diagonal(F.values)) / U, [left_inds..., F.right_inds...]) +end +LinearAlgebra.det(x::TensorEigen) = prod(x.values) + +Base.iterate(x::TensorEigen) = (x.values, :vectors) +Base.iterate(x::TensorEigen, state) = state == :vectors ? (x.vectors, nothing) : nothing + +LinearAlgebra.eigen(t::Tensor{<:Any,2}; kwargs...) = Base.@invoke eigen(t::Tensor; left_inds=(first(inds(t)),), kwargs...) +function LinearAlgebra.eigen(tensor::Tensor; left_inds=(), right_inds=(), virtualind=Symbol(uuid4()), kwargs...)
+ left_inds, right_inds = factorinds(tensor, left_inds, right_inds) + + virtualind ∉ inds(tensor) || + throw(ArgumentError("new virtual bond name ($virtualind) cannot be already be present")) + + # permute array + left_sizes = map(Base.Fix1(size, tensor), left_inds) + right_sizes = map(Base.Fix1(size, tensor), right_inds) + tensor = permutedims(tensor, [left_inds..., right_inds...]) + data = reshape(parent(tensor), prod(left_sizes), prod(right_sizes)) + + # compute eigendecomposition + Λ, U = eigen(data; kwargs...) + + # tensorify results + Λ = Tensor(Λ, [virtualind]) + U = Tensor(reshape(U, left_sizes..., size(U, 2)), [left_inds..., virtualind]) + + return TensorEigen(Λ, U, right_inds) +end + +# TODO document when it returns a `Tensor` and when returns an `Array` +LinearAlgebra.eigvals(t::Tensor{<:Any,2}; kwargs...) = eigvals(parent(t); kwargs...) +function LinearAlgebra.eigvals(tensor::Tensor; left_inds=(), right_inds=(), kwargs...) + F = eigen(tensor; left_inds, right_inds, kwargs...) + return parent(F.values) +end + +LinearAlgebra.eigvecs(t::Tensor{<:Any,2}; kwargs...) = eigvecs(parent(t); kwargs...) +function LinearAlgebra.eigvecs(tensor::Tensor; left_inds=(), right_inds=(), kwargs...) + F = eigen(tensor; left_inds, right_inds, kwargs...) + return F.vectors +end + LinearAlgebra.svd(t::Tensor{<:Any,2}; kwargs...) = Base.@invoke svd(t::Tensor; left_inds=(first(inds(t)),), kwargs...) 
""" diff --git a/src/Quantum.jl b/src/Quantum.jl index c5298ef64..f3f459d65 100644 --- a/src/Quantum.jl +++ b/src/Quantum.jl @@ -291,6 +291,9 @@ function rmsite!(tn::AbstractQuantum, site) return delete!(tn.sites, site) end +hassite(tn::AbstractQuantum, site) = haskey(Quantum(tn).sites, site) +Base.in(site::Site, tn::AbstractQuantum) = hassite(tn, site) + @kwmethod function sites(tn::AbstractQuantum; set) tn = Quantum(tn) if set === :all diff --git a/src/Tenet.jl b/src/Tenet.jl index 46e13c920..7effdbf68 100644 --- a/src/Tenet.jl +++ b/src/Tenet.jl @@ -26,6 +26,7 @@ include("Ansatz/Ansatz.jl") export Ansatz export socket, Scalar, State, Operator export boundary, Open, Periodic +export form include("Ansatz/Product.jl") export Product @@ -33,18 +34,22 @@ export Product include("Ansatz/Dense.jl") export Dense -include("Ansatz/Chain.jl") -export Chain -export MPS, pMPS, MPO, pMPO -export leftindex, rightindex, isleftcanonical, isrightcanonical -export canonize_site, canonize_site!, truncate! -export canonize, canonize!, mixed_canonize, mixed_canonize! +include("Ansatz/MPS.jl") +export MPS -include("Ansatz/Grid.jl") -export Grid -export PEPS, pPEPS, PEPO, pPEPO +include("Ansatz/MPO.jl") +export MPO -export evolve!, expect, overlap +include("Ansatz/PEPS.jl") +export PEPS + +include("Ansatz/PEPO.jl") +export PEPO + +# `truncate` not exported because it clashes with `Base.truncate` +export canonize_site, canonize_site!, canonize, canonize!, mixed_canonize, mixed_canonize! +export isisometry, isleftcanonical, isrightcanonical +export evolve!, expect, overlap, truncate! 
# reexports from EinExprs export einexpr, inds diff --git a/src/Tensor.jl b/src/Tensor.jl index bc45fcef4..45b191520 100644 --- a/src/Tensor.jl +++ b/src/Tensor.jl @@ -243,3 +243,7 @@ function __expand_repeat(array, axis, size) end LinearAlgebra.opnorm(x::Tensor, p::Real) = opnorm(parent(x), p) + +LinearAlgebra.det(x::Tensor{T,2}) where {T} = det(parent(x)) +LinearAlgebra.logdet(x::Tensor{T,2}) where {T} = logdet(parent(x)) +LinearAlgebra.tr(x::Tensor{T,2}) where {T} = tr(parent(x)) diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index fbdddc81e..a7f47110f 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -6,6 +6,7 @@ using LinearAlgebra using ScopedValues using Serialization using KeywordDispatch +using Graphs mutable struct CachedField{T} isvalid::Bool @@ -181,7 +182,8 @@ end return tensors(!isdisjoint, TensorNetwork(tn), intersects) end -function tensors(selector, tn::TensorNetwork, is::AbstractVecOrTuple{Symbol}) +function tensors(selector, tn::AbstractTensorNetwork, is::AbstractVecOrTuple{Symbol}) + tn = TensorNetwork(tn) return filter(Base.Fix1(selector, is) ∘ inds, tn.indexmap[first(is)]) end @@ -475,7 +477,7 @@ Base.merge!(self::TensorNetwork, other::TensorNetwork) = append!(self, tensors(o Base.merge!(self::TensorNetwork, others::TensorNetwork...) = foldl(merge!, others; init=self) Base.merge(self::AbstractTensorNetwork, others::AbstractTensorNetwork...) = merge!(copy(self), others...) 
-function neighbors(tn::AbstractTensorNetwork, tensor::Tensor; open::Bool=true) +function Graphs.neighbors(tn::AbstractTensorNetwork, tensor::Tensor; open::Bool=true) @assert tensor ∈ tn "Tensor not found in TensorNetwork" tensors = mapreduce(∪, inds(tensor)) do index Tenet.tensors(tn; intersects=index) @@ -484,7 +486,7 @@ function neighbors(tn::AbstractTensorNetwork, tensor::Tensor; open::Bool=true) return tensors end -function neighbors(tn::AbstractTensorNetwork, i::Symbol; open::Bool=true) +function Graphs.neighbors(tn::AbstractTensorNetwork, i::Symbol; open::Bool=true) @assert i ∈ tn "Index $i not found in TensorNetwork" tensors = mapreduce(inds, ∪, Tenet.tensors(tn; intersects=i)) # open && filter!(x -> x !== i, tensors) @@ -641,6 +643,13 @@ contract(t::Tensor, tn::AbstractTensorNetwork; kwargs...) = contract(tn, t; kwar return contract(intermediates...; dims=suminds(path)) end +function LinearAlgebra.eigen!(tn::AbstractTensorNetwork; left_inds=Symbol[], right_inds=Symbol[], kwargs...) + tensor = tn[left_inds ∪ right_inds...] + (; U, Λ, U⁻¹) = eigen(tensor; left_inds, right_inds, kwargs...) + replace!(tn, tensor => TensorNetwork([U, Λ, U⁻¹])) + return tn +end + function LinearAlgebra.svd!(tn::AbstractTensorNetwork; left_inds=Symbol[], right_inds=Symbol[], kwargs...) tensor = tn[left_inds ∪ right_inds...] U, s, Vt = svd(tensor; left_inds, right_inds, kwargs...) @@ -726,6 +735,10 @@ function Base.rand(::Type{TensorNetwork}, n::Integer, regularity::Integer; kwarg return rand(Random.default_rng(), TensorNetwork, n, regularity; kwargs...) end +function Base.rand(::Type{T}, args...; kwargs...) where {T<:AbstractTensorNetwork} + return rand(Random.default_rng(), T, args...; kwargs...) 
+end + function Serialization.serialize(s::AbstractSerializer, obj::TensorNetwork) Serialization.writetag(s.io, Serialization.OBJECT_TAG) return serialize(s, tensors(obj)) diff --git a/test/Ansatz_test.jl b/test/Ansatz_test.jl new file mode 100644 index 000000000..7c9218851 --- /dev/null +++ b/test/Ansatz_test.jl @@ -0,0 +1 @@ +@testset "Ansatz" begin end diff --git a/test/Chain_test.jl b/test/Chain_test.jl deleted file mode 100644 index d2b946896..000000000 --- a/test/Chain_test.jl +++ /dev/null @@ -1,436 +0,0 @@ -@testset "Chain ansatz" begin - @testset "Periodic boundary" begin - @testset "State" begin - qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3]) - @test socket(qtn) == State() - @test nsites(qtn; set=:inputs) == 0 - @test nsites(qtn; set=:outputs) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3"]) - @test boundary(qtn) == Periodic() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing - - arrays = [rand(2, 1, 4), rand(2, 4, 3), rand(2, 3, 1)] - qtn = Chain(State(), Periodic(), arrays) # Default order (:o, :l, :r) - - @test size(tensors(qtn; at=Site(1))) == (2, 1, 4) - @test size(tensors(qtn; at=Site(2))) == (2, 4, 3) - @test size(tensors(qtn; at=Site(3))) == (2, 3, 1) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) - - arrays = [permutedims(array, (3, 1, 2)) for array in arrays] # now we have (:r, :o, :l) - qtn = Chain(State(), Periodic(), arrays; order=[:r, :o, :l]) - - @test size(tensors(qtn; at=Site(1))) == (4, 2, 1) - @test size(tensors(qtn; at=Site(2))) == (3, 2, 4) - @test size(tensors(qtn; at=Site(3))) == (1, 2, 3) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) - - for i in 1:nsites(qtn) - @test size(qtn, inds(qtn; at=Site(i))) == 2 - end - 
end - - @testset "Operator" begin - qtn = Chain(Operator(), Periodic(), [rand(2, 2, 4, 4) for _ in 1:3]) - @test socket(qtn) == Operator() - @test nsites(qtn; set=:inputs) == 3 - @test nsites(qtn; set=:outputs) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) - @test boundary(qtn) == Periodic() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing - - arrays = [rand(2, 4, 1, 3), rand(2, 4, 3, 6), rand(2, 4, 6, 1)] # Default order (:o, :i, :l, :r) - qtn = Chain(Operator(), Periodic(), arrays) - - @test size(tensors(qtn; at=Site(1))) == (2, 4, 1, 3) - @test size(tensors(qtn; at=Site(2))) == (2, 4, 3, 6) - @test size(tensors(qtn; at=Site(3))) == (2, 4, 6, 1) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) - - for i in 1:length(arrays) - @test size(qtn, inds(qtn; at=Site(i))) == 2 - @test size(qtn, inds(qtn; at=Site(i; dual=true))) == 4 - end - - arrays = [permutedims(array, (4, 1, 3, 2)) for array in arrays] # now we have (:r, :o, :l, :i) - qtn = Chain(Operator(), Periodic(), arrays; order=[:r, :o, :l, :i]) - - @test size(tensors(qtn; at=Site(1))) == (3, 2, 1, 4) - @test size(tensors(qtn; at=Site(2))) == (6, 2, 3, 4) - @test size(tensors(qtn; at=Site(3))) == (1, 2, 6, 4) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) !== nothing - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing - - for i in 1:length(arrays) - @test size(qtn, inds(qtn; at=Site(i))) == 2 - @test size(qtn, inds(qtn; at=Site(i; dual=true))) == 4 - end - end - end - - @testset "Open boundary" begin - @testset "State" begin - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) - @test socket(qtn) == State() - @test nsites(qtn; set=:inputs) == 0 - @test nsites(qtn; set=:outputs) == 3 - 
@test issetequal(sites(qtn), [site"1", site"2", site"3"]) - @test boundary(qtn) == Open() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing - - arrays = [rand(2, 1), rand(2, 1, 3), rand(2, 3)] - qtn = Chain(State(), Open(), arrays) # Default order (:o, :l, :r) - - @test size(tensors(qtn; at=Site(1))) == (2, 1) - @test size(tensors(qtn; at=Site(2))) == (2, 1, 3) - @test size(tensors(qtn; at=Site(3))) == (2, 3) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) - - arrays = [permutedims(arrays[1], (2, 1)), permutedims(arrays[2], (3, 1, 2)), permutedims(arrays[3], (1, 2))] # now we have (:r, :o, :l) - qtn = Chain(State(), Open(), arrays; order=[:r, :o, :l]) - - @test size(tensors(qtn; at=Site(1))) == (1, 2) - @test size(tensors(qtn; at=Site(2))) == (3, 2, 1) - @test size(tensors(qtn; at=Site(3))) == (2, 3) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing - - for i in 1:nsites(qtn) - @test size(qtn, inds(qtn; at=Site(i))) == 2 - end - end - @testset "Operator" begin - qtn = Chain(Operator(), Open(), [rand(2, 2, 4), rand(2, 2, 4, 4), rand(2, 2, 4)]) - @test socket(qtn) == Operator() - @test nsites(qtn; set=:inputs) == 3 - @test nsites(qtn; set=:outputs) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) - @test boundary(qtn) == Open() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing - - arrays = [rand(2, 4, 1), rand(2, 4, 1, 3), rand(2, 4, 3)] # Default order (:o :i, :l, :r) - qtn = Chain(Operator(), Open(), arrays) - - @test size(tensors(qtn; at=Site(1))) == (2, 4, 1) - @test size(tensors(qtn; at=Site(2))) == (2, 4, 1, 3) - @test size(tensors(qtn; at=Site(3))) == (2, 4, 3) - - 
@test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing - - for i in 1:length(arrays) - @test size(qtn, inds(qtn; at=Site(i))) == 2 - @test size(qtn, inds(qtn; at=Site(i; dual=true))) == 4 - end - - arrays = [ - permutedims(arrays[1], (3, 1, 2)), - permutedims(arrays[2], (4, 1, 3, 2)), - permutedims(arrays[3], (1, 3, 2)), - ] # now we have (:r, :o, :l, :i) - qtn = Chain(Operator(), Open(), arrays; order=[:r, :o, :l, :i]) - - @test size(tensors(qtn; at=Site(1))) == (1, 2, 4) - @test size(tensors(qtn; at=Site(2))) == (3, 2, 1, 4) - @test size(tensors(qtn; at=Site(3))) == (2, 3, 4) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing - - for i in 1:length(arrays) - @test size(qtn, inds(qtn; at=Site(i))) == 2 - @test size(qtn, inds(qtn; at=Site(i; dual=true))) == 4 - end - end - end - - @testset "Site" begin - using Tenet: leftsite, rightsite - qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3]) - - @test leftsite(qtn, Site(1)) == Site(3) - @test leftsite(qtn, Site(2)) == Site(1) - @test leftsite(qtn, Site(3)) == Site(2) - - @test rightsite(qtn, Site(1)) == Site(2) - @test rightsite(qtn, Site(2)) == Site(3) - @test rightsite(qtn, Site(3)) == Site(1) - - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) - - @test isnothing(leftsite(qtn, Site(1))) - @test isnothing(rightsite(qtn, Site(3))) - - @test leftsite(qtn, Site(2)) == Site(1) - @test leftsite(qtn, Site(3)) == Site(2) - - @test rightsite(qtn, Site(2)) == Site(3) - @test rightsite(qtn, Site(1)) == Site(2) - end - - @testset "truncate" begin - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) - canonize_site!(qtn, Site(2); direction=:right, 
method=:svd) - - @test_throws Tenet.MissingSchmidtCoefficientsException truncate!(qtn, [Site(1), Site(2)]; maxdim=1) - # @test_throws ArgumentError truncate!(qtn, [Site(2), Site(3)]) - - truncated = Tenet.truncate(qtn, [Site(2), Site(3)]; maxdim=1) - @test size(truncated, rightindex(truncated, Site(2))) == 1 - @test size(truncated, leftindex(truncated, Site(3))) == 1 - - singular_values = tensors(qtn; between=(Site(2), Site(3))) - truncated = Tenet.truncate(qtn, [Site(2), Site(3)]; threshold=singular_values[2] + 0.1) - @test size(truncated, rightindex(truncated, Site(2))) == 1 - @test size(truncated, leftindex(truncated, Site(3))) == 1 - end - - @testset "rand" begin - using LinearAlgebra: norm - - @testset "State" begin - n = 8 - χ = 10 - - qtn = rand(Chain, Open, State; n, p=2, χ) - @test socket(qtn) == State() - @test nsites(qtn; set=:inputs) == 0 - @test nsites(qtn; set=:outputs) == n - @test issetequal(sites(qtn), map(Site, 1:n)) - @test boundary(qtn) == Open() - @test isapprox(norm(qtn), 1.0) - @test maximum(last, size(qtn)) <= χ - end - - @testset "Operator" begin - n = 8 - χ = 10 - - qtn = rand(Chain, Open, Operator; n, p=2, χ) - @test socket(qtn) == Operator() - @test nsites(qtn; set=:inputs) == n - @test nsites(qtn; set=:outputs) == n - @test issetequal(sites(qtn), vcat(map(Site, 1:n), map(adjoint ∘ Site, 1:n))) - @test boundary(qtn) == Open() - @test isapprox(norm(qtn), 1.0) - @test maximum(last, size(qtn)) <= χ - end - end - - @testset "Canonization" begin - using Tenet - - @testset "contract" begin - qtn = rand(Chain, Open, State; n=5, p=2, χ=20) - let canonized = canonize(qtn) - @test_throws ArgumentError contract!(canonized; between=(Site(1), Site(2)), direction=:dummy) - end - - canonized = canonize(qtn) - - for i in 1:4 - contract_some = contract(canonized; between=(Site(i), Site(i + 1))) - Bᵢ = tensors(contract_some; at=Site(i)) - - @test isapprox(contract(contract_some), contract(qtn)) - @test_throws ArgumentError tensors(contract_some; 
between=(Site(i), Site(i + 1))) - - @test isrightcanonical(contract_some, Site(i)) - @test isleftcanonical( - contract(canonized; between=(Site(i), Site(i + 1)), direction=:right), Site(i + 1) - ) - - Γᵢ = tensors(canonized; at=Site(i)) - Λᵢ₊₁ = tensors(canonized; between=(Site(i), Site(i + 1))) - @test Bᵢ ≈ contract(Γᵢ, Λᵢ₊₁; dims=()) - end - end - - @testset "canonize_site" begin - qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4)]) - - @test_throws ArgumentError canonize_site!(qtn, Site(1); direction=:left) - @test_throws ArgumentError canonize_site!(qtn, Site(3); direction=:right) - - for method in [:qr, :svd] - canonized = canonize_site(qtn, site"1"; direction=:right, method=method) - @test isleftcanonical(canonized, site"1") - @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(qtn)) - - canonized = canonize_site(qtn, site"2"; direction=:right, method=method) - @test isleftcanonical(canonized, site"2") - @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(qtn)) - - canonized = canonize_site(qtn, site"2"; direction=:left, method=method) - @test isrightcanonical(canonized, site"2") - @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(qtn)) - - canonized = canonize_site(qtn, site"3"; direction=:left, method=method) - @test isrightcanonical(canonized, site"3") - @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(qtn)) - end - - # Ensure that svd creates a new tensor - @test length(tensors(canonize_site(qtn, Site(2); direction=:left, method=:svd))) == 4 - end - - @testset "canonize" begin - using Tenet: isleftcanonical, isrightcanonical - - qtn = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) - canonized = canonize(qtn) - - @test length(tensors(canonized)) == 9 # 5 tensors + 4 singular values vectors - @test 
isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(qtn)) - @test isapprox(norm(qtn), norm(canonized)) - - # Extract the singular values between each adjacent pair of sites in the canonized chain - Λ = [tensors(canonized; between=(Site(i), Site(i + 1))) for i in 1:4] - @test map(λ -> sum(abs2, λ), Λ) ≈ ones(length(Λ)) * norm(canonized)^2 - - for i in 1:5 - canonized = canonize(qtn) - - if i == 1 - @test isleftcanonical(canonized, Site(i)) - elseif i == 5 # in the limits of the chain, we get the norm of the state - contract!(canonized; between=(Site(i - 1), Site(i)), direction=:right) - tensor = tensors(canonized; at=Site(i)) - replace!(canonized, tensor => tensor / norm(canonized)) - @test isleftcanonical(canonized, Site(i)) - else - contract!(canonized; between=(Site(i - 1), Site(i)), direction=:right) - @test isleftcanonical(canonized, Site(i)) - end - end - - for i in 1:5 - canonized = canonize(qtn) - - if i == 1 # in the limits of the chain, we get the norm of the state - contract!(canonized; between=(Site(i), Site(i + 1)), direction=:left) - tensor = tensors(canonized; at=Site(i)) - replace!(canonized, tensor => tensor / norm(canonized)) - @test isrightcanonical(canonized, Site(i)) - elseif i == 5 - @test isrightcanonical(canonized, Site(i)) - else - contract!(canonized; between=(Site(i), Site(i + 1)), direction=:left) - @test isrightcanonical(canonized, Site(i)) - end - end - end - - @testset "mixed_canonize" begin - qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) - canonized = mixed_canonize(qtn, Site(3)) - - @test length(tensors(canonized)) == length(tensors(qtn)) + 1 - - @test isleftcanonical(canonized, Site(1)) - @test isleftcanonical(canonized, Site(2)) - @test isrightcanonical(canonized, Site(3)) - @test isrightcanonical(canonized, Site(4)) - @test isrightcanonical(canonized, Site(5)) - - @test isapprox(contract(transform(TensorNetwork(canonized), 
Tenet.HyperFlatten())), contract(qtn)) - end - end - - @test begin - qtn = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) - normalize!(qtn, Site(3)) - isapprox(norm(qtn), 1.0) - end - - @testset "adjoint" begin - qtn = rand(Chain, Open, State; n=5, p=2, χ=10) - adjoint_qtn = adjoint(qtn) - - for i in 1:nsites(qtn) - i < nsites(qtn) && - @test rightindex(adjoint_qtn, Site(i; dual=true)) == Symbol(String(rightindex(qtn, Site(i))) * "'") - i > 1 && @test leftindex(adjoint_qtn, Site(i; dual=true)) == Symbol(String(leftindex(qtn, Site(i))) * "'") - end - - @test isapprox(contract(qtn), contract(adjoint_qtn)) - end - - @testset "evolve!" begin - @testset "one site" begin - i = 2 - mat = reshape(LinearAlgebra.I(2), 2, 2) - gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(i; dual=true)]) - - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) - - @testset "canonical form" begin - canonized = canonize(qtn) - - evolved = evolve!(deepcopy(canonized), gate; threshold=1e-14) - @test isapprox(contract(evolved), contract(canonized)) - @test issetequal(size.(tensors(evolved)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)]) - @test isapprox(contract(evolved), contract(qtn)) - end - - @testset "arbitrary chain" begin - evolved = evolve!(deepcopy(qtn), gate; threshold=1e-14, iscanonical=false) - @test length(tensors(evolved)) == 5 - @test issetequal(size.(tensors(evolved)), [(2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2)]) - @test isapprox(contract(evolved), contract(qtn)) - end - end - - @testset "two sites" begin - i, j = 2, 3 - mat = reshape(kron(LinearAlgebra.I(2), LinearAlgebra.I(2)), 2, 2, 2, 2) - gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(j), Site(i; dual=true), Site(j; dual=true)]) - - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) - - @testset "canonical form" begin - canonized = canonize(qtn) - - evolved = 
evolve!(deepcopy(canonized), gate; threshold=1e-14) - @test isapprox(contract(evolved), contract(canonized)) - @test issetequal(size.(tensors(evolved)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)]) - @test isapprox(contract(evolved), contract(qtn)) - end - - @testset "arbitrary chain" begin - evolved = evolve!(deepcopy(qtn), gate; threshold=1e-14, iscanonical=false) - @test length(tensors(evolved)) == 5 - @test issetequal(size.(tensors(evolved)), [(2, 2), (2, 2, 2), (2,), (2, 2, 2), (2, 2, 2), (2, 2)]) - @test isapprox(contract(evolved), contract(qtn)) - end - end - end - - @testset "expect" begin - i, j = 2, 3 - mat = reshape(kron(LinearAlgebra.I(2), LinearAlgebra.I(2)), 2, 2, 2, 2) - gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(j), Site(i; dual=true), Site(j; dual=true)]) - - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) - - @test isapprox(expect(qtn, [gate]), norm(qtn)^2) - end -end diff --git a/test/MPO_test.jl b/test/MPO_test.jl new file mode 100644 index 000000000..1f7332aca --- /dev/null +++ b/test/MPO_test.jl @@ -0,0 +1,77 @@ +@testset "MPO" begin + H = MPO([rand(2, 2, 4), rand(2, 2, 4, 4), rand(2, 2, 4)]) + @test socket(H) == Operator() + @test nsites(H; set=:inputs) == 3 + @test nsites(H; set=:outputs) == 3 + @test issetequal(sites(H), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) + @test boundary(H) == Open() + @test inds(H; at=site"1", dir=:left) == inds(H; at=site"3", dir=:right) == nothing + + # Default order (:o :i, :l, :r) + H = MPO([rand(2, 4, 1), rand(2, 4, 1, 3), rand(2, 4, 3)]) + + @test size(tensors(H; at=site"1")) == (2, 4, 1) + @test size(tensors(H; at=site"2")) == (2, 4, 1, 3) + @test size(tensors(H; at=site"3")) == (2, 4, 3) + + @test inds(H; at=site"1", dir=:left) == inds(H; at=site"3", dir=:right) === nothing + @test inds(H; at=site"2", dir=:left) == inds(H; at=site"1", dir=:right) !== nothing + @test inds(H; at=site"3", dir=:left) == inds(H; 
at=site"2", dir=:right) !== nothing + + for i in 1:length(arrays) + @test size(H, inds(H; at=Site(i))) == 2 + @test size(H, inds(H; at=Site(i; dual=true))) == 4 + end + + # now we have (:r, :o, :l, :i) + H = MPO( + [ + permutedims(arrays(H)[1], (3, 1, 2)), + permutedims(arrays(H)[2], (4, 1, 3, 2)), + permutedims(arrays(H)[3], (1, 3, 2)), + ]; + order=[:r, :o, :l, :i], + ) + + @test size(tensors(H; at=site"1")) == (1, 2, 4) + @test size(tensors(H; at=site"2")) == (3, 2, 1, 4) + @test size(tensors(H; at=site"3")) == (2, 3, 4) + + @test inds(H; at=site"1", dir=:left) == inds(H; at=site"3", dir=:right) === nothing + @test inds(H; at=site"2", dir=:left) == inds(H; at=site"1", dir=:right) !== nothing + @test inds(H; at=site"3", dir=:left) == inds(H; at=site"2", dir=:right) !== nothing + + for i in 1:length(arrays) + @test size(H, inds(H; at=Site(i))) == 2 + @test size(H, inds(H; at=Site(i; dual=true))) == 4 + end + + @testset "Site" begin + H = MPO([rand(2, 2, 2), rand(2, 2, 2, 2), rand(2, 2, 2)]) + + @test isnothing(sites(H, site"1"; dir=:left)) + @test isnothing(sites(H, site"3"; dir=:right)) + + @test sites(H, site"2"; dir=:left) == site"1" + @test sites(H, site"3"; dir=:left) == site"2" + + @test sites(H, site"2"; dir=:right) == site"3" + @test sites(H, site"1"; dir=:right) == site"2" + end + + @testset "norm" begin + using LinearAlgebra: norm + + n = 8 + χ = 10 + H = rand(MPO; n, maxdim=χ) + + @test socket(H) == Operator() + @test nsites(H; set=:inputs) == n + @test nsites(H; set=:outputs) == n + @test issetequal(sites(H), vcat(map(Site, 1:n), map(adjoint ∘ Site, 1:n))) + @test boundary(H) == Open() + @test isapprox(norm(H), 1.0) + @test maximum(last, size(H)) <= χ + end +end diff --git a/test/MPS_test.jl b/test/MPS_test.jl new file mode 100644 index 000000000..9ac2f7350 --- /dev/null +++ b/test/MPS_test.jl @@ -0,0 +1,299 @@ +@testset "MPS" begin + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + @test socket(ψ) == State() + @test nsites(ψ; set=:inputs) == 0 + 
@test nsites(ψ; set=:outputs) == 3 + @test issetequal(sites(ψ), [site"1", site"2", site"3"]) + @test boundary(ψ) == Open() + @test inds(ψ; at=site"1", dir=:left) == inds(ψ; at=site"3", dir=:right) == nothing + + arrays = [rand(2, 1), rand(2, 1, 3), rand(2, 3)] + ψ = MPS(arrays) # Default order (:o, :l, :r) + @test size(tensors(ψ; at=site"1")) == (2, 1) + @test size(tensors(ψ; at=site"2")) == (2, 1, 3) + @test size(tensors(ψ; at=site"3")) == (2, 3) + @test inds(ψ; at=site"1", dir=:left) == inds(ψ; at=site"3", dir=:right) === nothing + @test inds(ψ; at=site"2", dir=:left) == inds(ψ; at=site"1", dir=:right) + @test inds(ψ; at=site"3", dir=:left) == inds(ψ; at=site"2", dir=:right) + + arrays = [permutedims(arrays[1], (2, 1)), permutedims(arrays[2], (3, 1, 2)), permutedims(arrays[3], (1, 2))] # now we have (:r, :o, :l) + ψ = MPS(arrays; order=[:r, :o, :l]) + @test size(tensors(ψ; at=site"1")) == (1, 2) + @test size(tensors(ψ; at=site"2")) == (3, 2, 1) + @test size(tensors(ψ; at=site"3")) == (2, 3) + @test inds(ψ; at=site"1", dir=:left) == inds(ψ; at=site"3", dir=:right) === nothing + @test inds(ψ; at=site"2", dir=:left) == inds(ψ; at=site"1", dir=:right) !== nothing + @test inds(ψ; at=site"3", dir=:left) == inds(ψ; at=site"2", dir=:right) !== nothing + @test all(i -> size(ψ, inds(ψ; at=Site(i))) == 2, 1:nsites(ψ)) + + @testset "Base.identity" begin + nsites_cases = [6, 7, 6, 7] + physdim_cases = [3, 2, 3, 2] + maxdim_cases = [nothing, nothing, 9, 4] # nothing means default + expected_tensorsizes_cases = [ + [(3, 3), (3, 3, 9), (3, 9, 27), (3, 27, 9), (3, 9, 3), (3, 3)], + [(2, 2), (2, 2, 4), (2, 4, 8), (2, 8, 8), (2, 8, 4), (2, 4, 2), (2, 2)], + [(3, 3), (3, 3, 9), (3, 9, 9), (3, 9, 9), (3, 9, 3), (3, 3)], + [(2, 2), (2, 2, 4), (2, 4, 4), (2, 4, 4), (2, 4, 4), (2, 4, 2), (2, 2)], + ] + + for (nsites, physdim, expected_tensorsizes, maxdim) in + zip(nsites_cases, physdim_cases, expected_tensorsizes_cases, maxdim_cases) + ψ = if isnothing(maxdim) + identity(MPS, nsites; 
physdim=physdim) + else + identity(MPS, nsites; physdim=physdim, maxdim=maxdim) + end + + # Test the tensor dimensions + obtained_tensorsizes = size.(tensors(ψ)) + @test obtained_tensorsizes == expected_tensorsizes + + # Test whether all tensors are the identity + alltns = tensors(ψ) + + # - Test extreme tensors (2D) equal identity + diagonal_2D = [fill(i, 2) for i in 1:physdim] + @test all(delta -> alltns[1][delta...] == 1, diagonal_2D) + @test sum(alltns[1]) == physdim + @test all(delta -> alltns[end][delta...] == 1, diagonal_2D) + @test sum(alltns[end]) == physdim + + # - Test bulk tensors (3D) equal identity + diagonal_3D = [fill(i, 3) for i in 1:physdim] + @test all(tns -> all(delta -> tns[delta...] == 1, diagonal_3D), alltns[2:(end - 1)]) + @test all(tns -> sum(tns) == physdim, alltns[2:(end - 1)]) + + # Test whether the contraction gives the identity + contracted_ψ = contract(ψ) + diagonal_nsitesD = [fill(i, nsites) for i in 1:physdim] + @test all(delta -> contracted_ψ[delta...] == 1, diagonal_nsitesD) + @test sum(contracted_ψ) == physdim + end + end + + @testset "Site" begin + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + + @test isnothing(sites(ψ, site"1"; dir=:left)) + @test isnothing(sites(ψ, site"3"; dir=:right)) + + @test sites(ψ, site"2"; dir=:left) == site"1" + @test sites(ψ, site"3"; dir=:left) == site"2" + + @test sites(ψ, site"2"; dir=:right) == site"3" + @test sites(ψ, site"1"; dir=:right) == site"2" + end + + @testset "adjoint" begin + ψ = rand(MPS; n=3, maxdim=2, eltype=ComplexF64) + @test socket(ψ') == State(; dual=true) + @test isapprox(contract(ψ), conj(contract(ψ'))) + end + + @testset "truncate" begin + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + canonize_site!(ψ, Site(2); direction=:right, method=:svd) + + # @test_throws Tenet.MissingSchmidtCoefficientsException truncate!(ψ, [site"1", site"2"]; maxdim=1) + @test_throws ArgumentError truncate!(ψ, [site"1", site"2"]; maxdim=1) + + truncated = Tenet.truncate(ψ, [site"2", 
site"3"]; maxdim=1) + @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 1 + + singular_values = tensors(ψ; between=(site"2", site"3")) + truncated = Tenet.truncate(ψ, [site"2", site"3"]; threshold=singular_values[2] + 0.1) + @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 1 + end + + @testset "norm" begin + using LinearAlgebra: norm + + n = 8 + χ = 10 + ψ = rand(MPS; n, maxdim=χ) + + @test socket(ψ) == State() + @test nsites(ψ; set=:inputs) == 0 + @test nsites(ψ; set=:outputs) == n + @test issetequal(sites(ψ), map(Site, 1:n)) + @test boundary(ψ) == Open() + @test isapprox(norm(ψ), 1.0) + @test maximum(last, size(ψ)) <= χ + end + + @testset "normalize!" begin + using LinearAlgebra: normalize! + + ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) + normalize!(ψ, Site(3)) + @test isapprox(norm(ψ), 1.0) + end + + @testset "canonize_site!" begin + ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4)]) + + @test_throws ArgumentError canonize_site!(ψ, Site(1); direction=:left) + @test_throws ArgumentError canonize_site!(ψ, Site(3); direction=:right) + + for method in [:qr, :svd] + canonized = canonize_site(ψ, site"1"; direction=:right, method=method) + @test isleftcanonical(canonized, site"1") + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(ψ)) + + canonized = canonize_site(ψ, site"2"; direction=:right, method=method) + @test isleftcanonical(canonized, site"2") + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(ψ)) + + canonized = canonize_site(ψ, site"2"; direction=:left, method=method) + @test isrightcanonical(canonized, site"2") + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(ψ)) + + canonized = canonize_site(ψ, site"3"; direction=:left, method=method) + @test isrightcanonical(canonized, site"3") + @test isapprox(contract(transform(TensorNetwork(canonized), 
Tenet.HyperFlatten())), contract(ψ)) + end + + # Ensure that svd creates a new tensor + @test length(tensors(canonize_site(ψ, Site(2); direction=:left, method=:svd))) == 4 + end + + @testset "canonize!" begin + using Tenet: isleftcanonical, isrightcanonical + + ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) + canonized = canonize(ψ) + + @test length(tensors(canonized)) == 9 # 5 tensors + 4 singular values vectors + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(ψ)) + @test isapprox(norm(ψ), norm(canonized)) + + # Extract the singular values between each adjacent pair of sites in the canonized chain + Λ = [tensors(canonized; between=(Site(i), Site(i + 1))) for i in 1:4] + @test map(λ -> sum(abs2, λ), Λ) ≈ ones(length(Λ)) * norm(canonized)^2 + + for i in 1:5 + canonized = canonize(ψ) + + if i == 1 + @test isleftcanonical(canonized, Site(i)) + elseif i == 5 # in the limits of the chain, we get the norm of the state + normalize!(tensors(canonized; bond=(Site(i - 1), Site(i)))) + contract!(canonized; between=(Site(i - 1), Site(i)), direction=:right) + @test isleftcanonical(canonized, Site(i)) + else + contract!(canonized; between=(Site(i - 1), Site(i)), direction=:right) + @test isleftcanonical(canonized, Site(i)) + end + end + + for i in 1:5 + canonized = canonize(ψ) + + if i == 1 # in the limits of the chain, we get the norm of the state + normalize!(tensors(canonized; bond=(Site(i), Site(i + 1)))) + contract!(canonized; between=(Site(i), Site(i + 1)), direction=:left) + @test isrightcanonical(canonized, Site(i)) + elseif i == 5 + @test isrightcanonical(canonized, Site(i)) + else + contract!(canonized; between=(Site(i), Site(i + 1)), direction=:left) + @test isrightcanonical(canonized, Site(i)) + end + end + end + + @testset "mixed_canonize!" 
begin + ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) + canonized = mixed_canonize(ψ, site"3") + + @test length(tensors(canonized)) == length(tensors(ψ)) + 1 + + @test isleftcanonical(canonized, site"1") + @test isleftcanonical(canonized, site"2") + @test isrightcanonical(canonized, site"3") + @test isrightcanonical(canonized, site"4") + @test isrightcanonical(canonized, site"5") + + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(ψ)) + end + + @testset "expect" begin + i, j = 2, 3 + mat = reshape(kron(LinearAlgebra.I(2), LinearAlgebra.I(2)), 2, 2, 2, 2) + gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(j), Site(i; dual=true), Site(j; dual=true)]) + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) + + @test isapprox(expect(ψ, [gate]), norm(ψ)^2) + end + + @testset "evolve!" begin + @testset "one site" begin + i = 2 + mat = reshape(LinearAlgebra.I(2), 2, 2) + gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(i; dual=true)]) + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) + + @testset "canonical form" begin + canonized = canonize(ψ) + evolved = evolve!(deepcopy(canonized), gate; threshold=1e-14) + @test isapprox(contract(evolved), contract(canonized)) + @test issetequal(size.(tensors(evolved)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)]) + @test isapprox(contract(evolved), contract(ψ)) + end + + @testset "arbitrary chain" begin + evolved = evolve!(deepcopy(ψ), gate; threshold=1e-14) + @test length(tensors(evolved)) == 5 + @test issetequal(size.(tensors(evolved)), [(2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2)]) + @test isapprox(contract(evolved), contract(ψ)) + end + end + + @testset "two sites" begin + mat = reshape(kron(LinearAlgebra.I(2), LinearAlgebra.I(2)), 2, 2, 2, 2) + gate = Dense(Tenet.Operator(), mat; sites=[site"2", site"3", site"2'", site"3'"]) + ψ = MPS([rand(2, 2), rand(2, 2, 2), 
rand(2, 2, 2), rand(2, 2)]) + + @testset "canonical form" begin + canonized = canonize(ψ) + evolved = evolve!(deepcopy(canonized), gate; threshold=1e-14) + @test isapprox(contract(evolved), contract(canonized)) + @test issetequal(size.(tensors(evolved)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)]) + @test isapprox(contract(evolved), contract(ψ)) + end + + @testset "arbitrary chain" begin + evolved = evolve!(deepcopy(ψ), gate; threshold=1e-14) + @test length(tensors(evolved)) == 5 + @test issetequal(size.(tensors(evolved)), [(2, 2), (2, 2, 2), (2,), (2, 2, 2), (2, 2, 2), (2, 2)]) + @test isapprox(contract(evolved), contract(ψ)) + end + end + end + + # TODO rename when method is renamed + @testset "contract between" begin + ψ = rand(MPS; n=5, maxdim=20) + let canonized = canonize(ψ) + @test_throws ArgumentError contract!(canonized; between=(site"1", site"2"), direction=:dummy) + end + + canonized = canonize(ψ) + + for i in 1:4 + contract_some = contract(canonized; between=(Site(i), Site(i + 1))) + Bᵢ = tensors(contract_some; at=Site(i)) + + @test isapprox(contract(contract_some), contract(ψ)) + @test_throws ArgumentError tensors(contract_some; between=(Site(i), Site(i + 1))) + + @test isrightcanonical(contract_some, Site(i)) + @test isleftcanonical(contract(canonized; between=(Site(i), Site(i + 1)), direction=:right), Site(i + 1)) + + Γᵢ = tensors(canonized; at=Site(i)) + Λᵢ₊₁ = tensors(canonized; between=(Site(i), Site(i + 1))) + @test Bᵢ ≈ contract(Γᵢ, Λᵢ₊₁; dims=()) + end + end +end diff --git a/test/Project.toml b/test/Project.toml index f8588675a..366561448 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,10 +8,12 @@ Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5" KrylovKit = 
"0b1a1467-8014-51b9-945f-bf0ae24f4b77" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" +MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377" NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a" OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922" Permutations = "2ae35dd2-176d-5d53-8349-f30d82d94d4f" diff --git a/test/integration/ChainRules_test.jl b/test/integration/ChainRules_test.jl index 79b763867..e6219005d 100644 --- a/test/integration/ChainRules_test.jl +++ b/test/integration/ChainRules_test.jl @@ -1,6 +1,8 @@ @testset "ChainRules" begin using Tenet: Tensor, contract using ChainRulesTestUtils + using Graphs + using MetaGraphsNext @testset "Tensor" begin test_frule(Tensor, ones(), Symbol[]) @@ -190,30 +192,36 @@ end @testset "Ansatz" begin - @testset "Product" begin - tn = TensorNetwork([Tensor(ones(2), [:i]), Tensor(ones(2), [:j]), Tensor(ones(2), [:k])]) - qtn = Quantum(tn, Dict([site"1" => :i, site"2" => :j, site"3" => :k])) + tn = Quantum(TensorNetwork([Tensor(ones(2), [:i])]), Dict{Site,Symbol}(site"1" => :i)) + lattice = MetaGraph(Graph(1), Pair{Site,Nothing}[site"1" => nothing], Pair{Tuple{Site,Site},Nothing}[]) + test_frule(Ansatz, tn, lattice) + test_rrule(Ansatz, tn, lattice) + end - test_frule(Product, qtn) - test_rrule(Product, qtn) - end + @testset "Product" begin + tn = Product([ones(2), ones(2), ones(2)]) - @testset "Chain" begin - tn = Chain(State(), Open(), [ones(2, 2), ones(2, 2, 2), ones(2, 2)]) - # test_frule(Chain, Quantum(tn), Open()) - test_rrule(Chain, Quantum(tn), Open()) + test_frule(Product, Ansatz(tn)) + test_rrule(Product, Ansatz(tn)) + end - tn = Chain(State(), Periodic(), [ones(2, 2, 2), ones(2, 2, 2), ones(2, 2, 2)]) - # test_frule(Chain, Quantum(tn), Periodic()) - test_rrule(Chain, Quantum(tn), Periodic()) + @testset "MPS" begin + tn = MPS([ones(2, 2), ones(2, 2, 2), ones(2, 2)]) + # test_frule(MPS, Ansatz(tn), form(tn)) + test_rrule(MPS, Ansatz(tn), form(tn)) - tn = 
Chain(Operator(), Open(), [ones(2, 2, 2), ones(2, 2, 2, 2), ones(2, 2, 2)]) - # test_frule(Chain, Quantum(tn), Open()) - test_rrule(Chain, Quantum(tn), Open()) + # TODO reenable periodic MPS + # tn = MPS([ones(2, 2, 2), ones(2, 2, 2), ones(2, 2, 2)]) + # test_frule(Chain, Quantum(tn), Periodic()) + # test_rrule(Chain, Quantum(tn), Periodic()) - tn = Chain(Operator(), Periodic(), [ones(2, 2, 2, 2), ones(2, 2, 2, 2), ones(2, 2, 2, 2)]) - # test_frule(Chain, Quantum(tn), Periodic()) - test_rrule(Chain, Quantum(tn), Periodic()) - end + tn = MPO([ones(2, 2, 2), ones(2, 2, 2, 2), ones(2, 2, 2)]) + # test_frule(MPO, Ansatz(tn), form(tn)) + test_rrule(MPO, Ansatz(tn), form(tn)) + + # TODO reenable periodic MPO + # tn = Chain(Operator(), Periodic(), [ones(2, 2, 2, 2), ones(2, 2, 2, 2), ones(2, 2, 2, 2)]) + # test_frule(Chain, Quantum(tn), Periodic()) + # test_rrule(Chain, Quantum(tn), Periodic()) end end diff --git a/test/runtests.jl b/test/runtests.jl index 6a9bc8e71..4183a75c9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,8 +10,10 @@ using OMEinsum include("Transformations_test.jl") include("Site_test.jl") include("Quantum_test.jl") + include("Ansatz_test.jl") include("Product_test.jl") - include("Chain_test.jl") + include("MPS_test.jl") + include("MPO_test.jl") end # CI hangs on these tests for some unknown reason on Julia 1.9