Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add extension to convert ITensorMPS.MPS to Tenet.MPS and viceversa #251

Merged
merged 8 commits into from
Nov 18, 2024
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
ITensorNetworks = "2919e153-833c-4bdc-8836-1ea460a35fc7"
ITensorMPS = "0d1a4710-d33b-49a5-8f18-73bdf49b47e2"
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
Expand All @@ -44,6 +45,7 @@ TenetChainRulesTestUtilsExt = ["ChainRulesCore", "ChainRulesTestUtils"]
TenetDaggerExt = "Dagger"
TenetFiniteDifferencesExt = "FiniteDifferences"
TenetGraphMakieExt = ["GraphMakie", "Makie"]
TenetITensorMPSExt = ["ITensors, "ITensorMPS"]
TenetITensorNetworksExt = "ITensorNetworks"
TenetITensorsExt = "ITensors"
TenetKrylovKitExt = ["KrylovKit"]
Expand All @@ -66,6 +68,7 @@ EinExprs = "0.5, 0.6"
FiniteDifferences = "0.12"
GraphMakie = "0.4,0.5"
Graphs = "1.7"
ITensorMPS = "0.2.6"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mmm if we set the patch version, will we have problems when v0.2.7 lands?

like, will it force it to use v0.2.6 or will it be compatible (but not included) up to v0.3.0?

in principle we should only mark major and minor versions, not patch version (unless a patch is completely required for it to work)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mmmm I don't know, I just put the version since we also had that for ITensors and ITensorNetworks. Do you really think this is a problem?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should try to just set the major.minor version. If you need a higher patch version, you should just call Pkg.update() then.

ITensorNetworks = "0.11"
ITensors = "0.6"
KeywordDispatch = "0.3"
Expand Down
104 changes: 104 additions & 0 deletions ext/TenetITensorMPSExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
module TenetITensorMPSExt

using Tenet
using ITensors: ITensor, Index, dim
using ITensorMPS
using Tenet: MPS, tensors, form, inds

# Convert an AbstractMPS to an ITensor MPS
function Base.convert(::Type{ITensorMPS.MPS}, mps::Tenet.AbstractMPS)
@assert form(mps) isa MixedCanonical "Currently only MixedCanonical MPS conversion is supported"

ortho_center = form(mps).orthog_center

itensors = ITensor[]
for (i, t) in enumerate(tensors(mps))
t = Tenet.permutedims(
t,
Vector{Symbol}(
filter!(
!isnothing,
[inds(mps; at=Site(i)), inds(mps; at=Site(i), dir=:left), inds(mps; at=Site(i), dir=:right)],
),
),
)

site_index = Index(size(mps, inds(mps; at=Site(i))), "Site,n=$i")
if i == 1
link_size = size(mps, inds(mps; at=Site(1), dir=:right))
link_indices = [Index(link_size, "Link,l=1")]
else
# Take index from previous tensor as the left link index
prev_ind = ITensors.inds(itensors[end])[end]

if i < length(tensors(mps))
next_link_size = size(mps, inds(mps; at=Site(i), dir=:right))
next_ind = Index(next_link_size, "Link,l=$(i)")
link_indices = [prev_ind, next_ind]
else
link_indices = [prev_ind]
end
end
all_indices = (site_index, link_indices...)

it = ITensor(parent(t), all_indices...)
push!(itensors, it)
end

itensors_mps = ITensorMPS.MPS(itensors)

# Set llim and rlim based on the orthogonality center
if isa(ortho_center, Site)
n = Tenet.id(ortho_center)

itensors_mps.llim = n - 1
itensors_mps.rlim = n + 1
elseif isa(ortho_center, Vector{Site})
ids = Tenet.id.(ortho_center)

# For multiple orthogonality centers, set llim and rlim accordingly
itensors_mps.llim = minimum(ids) - 1
itensors_mps.rlim = maximum(ids) + 1
end

return itensors_mps
end

# Convert an ITensor MPS to an MPS
function Base.convert(::Type{MPS}, itensors_mps::ITensorMPS.MPS)
llim = itensors_mps.llim
rlim = itensors_mps.rlim

# Extract site and link indices
sites = siteinds(itensors_mps)
links = linkinds(itensors_mps)

tensors_vec = []
first_ten = array(itensors_mps[1], sites[1], links[1])
push!(tensors_vec, first_ten)

# Extract the bulk tensors
for j in 2:(length(itensors_mps) - 1)
ten = array(itensors_mps[j], sites[j], links[j - 1], links[j]) # Indices are ordered as (site index, left link, right link)
push!(tensors_vec, ten)
end
last_ten = array(itensors_mps[end], sites[end], links[end])
push!(tensors_vec, last_ten)

mps = Tenet.MPS(tensors_vec)

# Map llim and rlim to your MPS's orthogonality center(s)
mps_form = if llim + 1 == rlim - 1
Tenet.MixedCanonical(Tenet.Site(; n=llim + 1))
elseif llim + 1 < rlim - 1
Tenet.MixedCanonical([Tenet.Site(j) for j in (llim + 1):(rlim - 1)])
else
Tenet.NonCanonical()
end

mps.form = mps_form

return mps
end

end
40 changes: 40 additions & 0 deletions test/integration/ITensorMPS_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
@testset "ITensorMPS" begin
using ITensorMPS
using ITensors: ITensor, Index, dim, dims
using Tenet: MPS, tensors, form

# Tenet to ITensorMPS conversion
tenet_mps = rand(MPS; n=5, maxdim=30)
itensor_mps = convert(ITensorMPS.MPS, tenet_mps)

@test length(tensors(tenet_mps)) == length(ITensors.tensors(itensor_mps))

for (t1, t2) in zip(tensors(tenet_mps), ITensors.tensors(itensor_mps))
@test issetequal(size(t1), dims(t2))
end

@test itensor_mps.llim == Tenet.id(form(tenet_mps).orthog_center) - 1
@test itensor_mps.rlim == Tenet.id(form(tenet_mps).orthog_center) + 1

contracted = Tenet.contract(tenet_mps)
permuted = permutedims(contracted, [inds(tenet_mps; at=Site(i)) for i in 1:length(tensors(tenet_mps))])
@test isapprox(parent(permuted), Array(ITensorMPS.contract(itensor_mps).tensor))

# ITensorMPS to Tenet conversion
itensor_mps = ITensorMPS.random_mps(siteinds(4, 5); linkdims=7)
tenet_mps = convert(MPS, itensor_mps)

@test length(ITensors.tensors(itensor_mps)) == length(tensors(tenet_mps))

for (t1, t2) in zip(ITensors.tensors(itensor_mps), tensors(tenet_mps))
@test issetequal(dims(t1), size(t2))
end

@test form(tenet_mps) isa MixedCanonical
@test form(tenet_mps).orthog_center == Site(itensor_mps.llim + 1)
@test form(tenet_mps).orthog_center == Site(itensor_mps.rlim - 1)

contracted = Tenet.contract(tenet_mps)
permuted = permutedims(contracted, [inds(tenet_mps; at=Site(i)) for i in 1:length(tensors(tenet_mps))])
@test isapprox(parent(permuted), Array(ITensorMPS.contract(itensor_mps).tensor))
end
Loading