Concretize Ansatz type with "lattice"/connectivity information and refactor subtypes on top of it #204

Closed · wants to merge 75 commits
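For orientation, a rough sketch of the layering this PR introduces, using only names that appear in the diffs below; the `MPS(arrays)` convenience constructor and the tensor shapes are assumptions for illustration, not part of the PR itself. The idea seems to be that operations depending only on connectivity (neighbor queries, canonization paths, and the like) can dispatch on the concrete `Ansatz` and its lattice instead of each subtype reimplementing them.

```julia
using Tenet

# Layering after this PR (sketch):
#   Quantum            = TensorNetwork plus the site/index mapping
#   Ansatz             = Quantum plus a `lattice` graph encoding connectivity
#   MPS/MPO/PEPS/PEPO  = Ansatz plus a canonical `Form` trait
ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)])  # assumed convenience constructor

Ansatz(ψ)         # strip back to Quantum + lattice
Quantum(ψ)        # strip back to the site-indexed tensor network
Tenet.lattice(ψ)  # connectivity graph backing `Graphs.neighbors` / `isneighbor`
Tenet.form(ψ)     # canonical form trait carried by `MPS` / `MPO`
```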
Commits (75)
9d8f4ef
Prototype `MPS`, `MPO`
mofeing Aug 9, 2024
e48e2c6
Introduce Canonical `Form` trait
mofeing Sep 3, 2024
0382f57
Implement `rand`, `adjoint`, `defaultorder`, `boundary`, `form` for `…
mofeing Sep 12, 2024
d29ff7d
Implement conversion from `Product` to `MPS`, `MPO`
mofeing Sep 12, 2024
0f0e5dd
Implement `hassite` method and alias to `in`
mofeing Sep 14, 2024
555440a
Refactor `Ansatz` into a concrete type
mofeing Sep 15, 2024
425df9e
Fix typos in `Ansatz`
mofeing Sep 16, 2024
d2c8ce4
Use `Graphs.neighbors`
mofeing Sep 16, 2024
e492846
Refactor `Product` on top of new `Ansatz` type
mofeing Sep 16, 2024
d334d13
Implement `copy`, `similar`, `zero` for `Ansatz`
mofeing Sep 16, 2024
86fdd7a
Format code
mofeing Sep 16, 2024
1937f80
Refactor `Dense` on top of new `Ansatz` type
mofeing Sep 16, 2024
ba817a2
Refactor `MPS`, `MPO` on top of new `Ansatz` type
mofeing Sep 16, 2024
3370244
Relax `sites` condition on `Ansatz` construction for `lanes`
mofeing Sep 16, 2024
8c05563
Implement `PEPS`, `PEPO` types
mofeing Sep 16, 2024
c21db5e
Remove some exports
mofeing Sep 16, 2024
c9d351f
Move `Chain` code to `AbstractAnsatz` and `MPS`
mofeing Sep 17, 2024
c76fdb3
Force Graphs to be a strong dependency
mofeing Sep 17, 2024
636b03f
Implement `Graphs.neighbors`, `isneighbor` methods
mofeing Sep 18, 2024
ab497b5
Fix `sites` method for `MPS`
mofeing Sep 18, 2024
0079bfd
Fix `inds` method for `MPS`
mofeing Sep 18, 2024
c933053
Fix typo
mofeing Sep 18, 2024
538e698
Refactor `adapt_structure` method to support additional types
mofeing Sep 18, 2024
23fc8d4
Refactor `Reactant.make_tracer`, `Reactant.create_result` methods on …
mofeing Sep 18, 2024
203c87d
Refactor `ChainRules` methods on top of new types
mofeing Sep 18, 2024
7520a2e
Refactor `ProjectTo` for `Ansatz`
mofeing Sep 18, 2024
64b34b9
Refactor `rand` for `MPS`, `MPO`
mofeing Sep 18, 2024
0c79f0b
Export `canonize_site`, `canonize_site!` methods
mofeing Sep 18, 2024
076ee5b
Refactor `Chain` tests on top of `MPS`, `MPO`
mofeing Sep 18, 2024
5e823c3
Add `Graphs`, `MetaGraphsNext` as test dependencies
mofeing Sep 18, 2024
60ae718
Try using more `@site_str` instead of `Site` in MPS tests
mofeing Sep 18, 2024
cca0cff
Implement some `sites`, `inds` methods for `MPO`
mofeing Sep 18, 2024
24800b8
Try using more `@site_str` in MPO tests
mofeing Sep 18, 2024
6f8a0a0
Fix `tensors(; bond)`
mofeing Sep 18, 2024
955361c
Fix typo in `mixed_canonize!`
mofeing Sep 18, 2024
531002f
Export `isleftcanonical`, `isrightcanonical`
mofeing Sep 18, 2024
efc1c11
Fix `truncate!`
mofeing Sep 18, 2024
08b3d45
Fix `truncate` tests on `MPS`
mofeing Sep 18, 2024
c986ae2
Fix `Dense` constructors
mofeing Sep 18, 2024
f03bb18
Fix `truncate!` extension when using `threshold`
mofeing Sep 18, 2024
7db568d
Refactor some tests of `MPS` to simplify
mofeing Sep 18, 2024
fba30e9
Format code
mofeing Sep 18, 2024
edc0a16
Fix typo in `normalize!` on `MPS` method
mofeing Sep 18, 2024
8dc2795
Fix typo
mofeing Sep 18, 2024
c3e7ca0
Deprecate `isleftcanonical`, `isrightcanonical` in favor of `isisometry`
mofeing Sep 18, 2024
380d7d3
Comment `renormalize` kwarg of `evolve!`
mofeing Sep 18, 2024
3cb2db3
Fix `simple_update!` on single site gates
mofeing Sep 19, 2024
53122c1
Fix `isleftcanonical`, `isrightcanonical` tests on boundary sites
mofeing Sep 19, 2024
daec4d1
Fix `evolve!` calls in tests
mofeing Sep 19, 2024
c8fe5da
Fix indexing problems in `simple_update_1site!`
mofeing Sep 19, 2024
0a5408a
Refactor MPO tests
mofeing Sep 22, 2024
cab5fa4
Some fixes on `simple_update!`
mofeing Sep 22, 2024
2a6ebfd
Prototype tests for `Ansatz`
mofeing Sep 22, 2024
54312c5
Some fixes for `PEPS` constructor
mofeing Sep 26, 2024
9781717
Remove check in `PEPS` constructor
mofeing Sep 26, 2024
5b8da11
Fix reference to lattice in `adapt_structure` for `Ansatz`
mofeing Sep 26, 2024
63edfc2
Stop orthogonalization to index on `mixed_canonize!`
mofeing Sep 26, 2024
7921a49
Aesthetic name fix
mofeing Sep 26, 2024
eb95d07
Stop using `IdDict` on Reactant extension
mofeing Sep 28, 2024
22824d8
Fix `create_result` on `MPS`, `MPO`
mofeing Sep 30, 2024
dfb5c2e
Refactor lattice generation in constructors of `Dense`, `Product`, `M…
mofeing Sep 30, 2024
ef25de5
Fix `make_tracer`, `create_result` from Reactant on `Product`, `Dense`
mofeing Sep 30, 2024
9ffddf3
Implement `rand`, `normalize!`, `overlap` for `Dense` states
mofeing Sep 30, 2024
7cd5fae
Set temporarily a more concrete type of `lattice` in graph to circunv…
Oct 2, 2024
71964b0
Fix pairwise `contract` between `TracedRArray` and `Array`
mofeing Oct 6, 2024
20d7183
Small fix to `promote_to` of `Tensor`
mofeing Oct 6, 2024
4121133
Refactor by importing `Reactant.TracedRArray`
mofeing Oct 6, 2024
4aa2830
Try remove ambiguity on `contract` with `TracedRArray`
mofeing Oct 6, 2024
2579137
Dispatch `det`, `logdet`, `tr` methods to underlying array on matrix …
mofeing Oct 8, 2024
1c65921
Implement Eigendecomposition for `Tensor`
mofeing Oct 8, 2024
3670dfe
Implement `eigen!` for `TensorNetwork`
mofeing Oct 8, 2024
cd937af
small fix
mofeing Oct 8, 2024
fee3b07
Implement an MPS method initializing the tensors to identity (copy-te…
Todorbsc Oct 14, 2024
f39b6f8
Implement `traced_type` for `AbstractTensorNetwork`
mofeing Oct 18, 2024
f5aa926
Format code
mofeing Oct 18, 2024
3 changes: 2 additions & 1 deletion Project.toml
@@ -8,8 +8,10 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188"
EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KeywordDispatch = "5888135b-5456-5c80-a1b6-c91ef8180460"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
@@ -25,7 +27,6 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
ITensorNetworks = "2919e153-833c-4bdc-8836-1ea460a35fc7"
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
10 changes: 8 additions & 2 deletions ext/TenetAdaptExt.jl
@@ -7,7 +7,13 @@ Adapt.adapt_structure(to, x::Tensor) = Tensor(adapt(to, parent(x)), inds(x))
Adapt.adapt_structure(to, x::TensorNetwork) = TensorNetwork(adapt.(Ref(to), tensors(x)))

Adapt.adapt_structure(to, x::Quantum) = Quantum(adapt(to, TensorNetwork(x)), x.sites)
Adapt.adapt_structure(to, x::Product) = Product(adapt(to, Quantum(x)))
Adapt.adapt_structure(to, x::Chain) = Chain(adapt(to, Quantum(x)), boundary(x))
Adapt.adapt_structure(to, x::Ansatz) = Ansatz(adapt(to, Quantum(x)), Tenet.lattice(x))

Adapt.adapt_structure(to, x::Product) = Product(adapt(to, Ansatz(x)))
Adapt.adapt_structure(to, x::Dense) = Dense(adapt(to, Ansatz(x)))
Adapt.adapt_structure(to, x::MPS) = MPS(adapt(to, Ansatz(x)), form(x))
Adapt.adapt_structure(to, x::MPO) = MPO(adapt(to, Ansatz(x)), form(x))
Adapt.adapt_structure(to, x::PEPS) = PEPS(adapt(to, Ansatz(x)), form(x))
Adapt.adapt_structure(to, x::PEPO) = PEPO(adapt(to, Ansatz(x)), form(x))

end
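Not part of the diff, but a usage sketch of what the new `adapt_structure` methods enable (the `MPS(arrays)` constructor and the `Array{Float32}` adapt target are assumptions): `adapt` now recurses through all the wrappers, converting only the underlying array storage while sites, lattice and canonical form are carried along unchanged.

```julia
using Adapt, Tenet

ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)])  # assumed convenience constructor
ψ32 = adapt(Array{Float32}, ψ)                    # adapt only the tensor storage

eltype(first(tensors(ψ32))) === Float32           # storage was converted,
Tenet.form(ψ32) == Tenet.form(ψ)                  # while the wrapper metadata survives
```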
39 changes: 25 additions & 14 deletions ext/TenetChainRulesCoreExt/frules.jl
@@ -1,16 +1,39 @@
using Tenet: AbstractTensorNetwork, AbstractQuantum

# `Tensor` constructor
ChainRulesCore.frule((_, Δ, _), T::Type{<:Tensor}, data, inds) = T(data, inds), T(Δ, inds)

# `TensorNetwork` constructor
ChainRulesCore.frule((_, Δ), ::Type{TensorNetwork}, tensors) = TensorNetwork(tensors), TensorNetworkTangent(Δ)

# `Quantum` constructor
function ChainRulesCore.frule((_, ẋ, _), ::Type{Quantum}, x::TensorNetwork, sites)
return Quantum(x, sites), Tangent{Quantum}(; tn=ẋ, sites=NoTangent())
end

# `Ansatz` constructor
function ChainRulesCore.frule((_, ẋ), ::Type{Ansatz}, x::Quantum, lattice)
return Ansatz(x, lattice), Tangent{Ansatz}(; tn=ẋ, lattice=NoTangent())
end

# `AbstractAnsatz`-subtype constructors
ChainRulesCore.frule((_, ẋ), ::Type{Product}, x::Ansatz) = Product(x), Tangent{Product}(; tn=ẋ)
ChainRulesCore.frule((_, ẋ), ::Type{Dense}, x::Ansatz) = Dense(x, form), Tangent{Dense}(; tn=ẋ)
ChainRulesCore.frule((_, ẋ), ::Type{MPS}, x::Ansatz, form) = MPS(x, form), Tangent{MPS}(; tn=ẋ, lattice=NoTangent())
ChainRulesCore.frule((_, ẋ), ::Type{MPO}, x::Ansatz, form) = MPO(x, form), Tangent{MPO}(; tn=ẋ, lattice=NoTangent())
function ChainRulesCore.frule((_, ẋ), ::Type{PEPS}, x::Ansatz, form)
return PEPS(x, form), Tangent{PEPS}(; tn=ẋ, lattice=NoTangent())
end

# `Base.conj` methods
ChainRulesCore.frule((_, Δ), ::typeof(Base.conj), tn::Tensor) = conj(tn), conj(Δ)

ChainRulesCore.frule((_, Δ), ::typeof(Base.conj), tn::TensorNetwork) = conj(tn), conj(Δ)
ChainRulesCore.frule((_, Δ), ::typeof(Base.conj), tn::AbstractTensorNetwork) = conj(tn), conj(Δ)

# `Base.merge` methods
ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(Base.merge), a::TensorNetwork, b::TensorNetwork) = merge(a, b), merge(ȧ, ḃ)
function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(Base.merge), a::AbstractTensorNetwork, b::AbstractTensorNetwork)
return merge(a, b), merge(ȧ, ḃ)
end

# `contract` methods
function ChainRulesCore.frule((_, ẋ), ::typeof(contract), x::Tensor; kwargs...)
@@ -22,15 +45,3 @@ function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(contract), a::Tensor, b::T
ċ = contract(ȧ, b; kwargs...) + contract(a, ḃ; kwargs...)
return c, ċ
end

function ChainRulesCore.frule((_, ẋ, _), ::Type{Quantum}, x::TensorNetwork, sites)
y = Quantum(x, sites)
ẏ = Tangent{Quantum}(; tn=ẋ)
return y, ẏ
end

ChainRulesCore.frule((_, ẋ), ::Type{T}, x::Quantum) where {T<:Ansatz} = T(x), Tangent{T}(; super=ẋ)

function ChainRulesCore.frule((_, ẋ, _), ::Type{T}, x::Quantum, boundary) where {T<:Ansatz}
return T(x, boundary), Tangent{T}(; super=ẋ, boundary=NoTangent())
end
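Not in the diff, but a brief forward-mode sketch of the convention these rules establish (the `MPS(arrays)` convenience constructor is an assumption, used only to obtain a consistent `Quantum`/lattice pair): tangents only flow through the wrapped tensor network, while the structural arguments get `NoTangent()`.

```julia
using Tenet, ChainRulesCore

ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)])  # assumed convenience constructor
q, lat = Quantum(ψ), Tenet.lattice(ψ)

q̇ = Tangent{Quantum}(; tn=ZeroTangent(), sites=NoTangent())
y, ẏ = ChainRulesCore.frule((NoTangent(), q̇, NoTangent()), Ansatz, q, lat)
ẏ.tn === q̇ && ẏ.lattice == NoTangent()  # lattice is treated as non-differentiable
```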
7 changes: 2 additions & 5 deletions ext/TenetChainRulesCoreExt/projectors.jl
@@ -36,8 +36,5 @@ end
ChainRulesCore.ProjectTo(x::Quantum) = ProjectTo{Quantum}(; tn=ProjectTo(TensorNetwork(x)), sites=x.sites)
(projector::ProjectTo{Quantum})(Δ) = Quantum(projector.tn(Δ), projector.sites)

ChainRulesCore.ProjectTo(x::T) where {T<:Ansatz} = ProjectTo{T}(; super=ProjectTo(Quantum(x)))
(projector::ProjectTo{T})(Δ::Union{T,Tangent{T}}) where {T<:Ansatz} = T(projector.super(Δ.super), Δ.boundary)

# NOTE edge case: `Product` has no `boundary`. should it?
(projector::ProjectTo{T})(Δ::Union{T,Tangent{T}}) where {T<:Product} = T(projector.super(Δ.super))
ChainRulesCore.ProjectTo(x::Ansatz) = ProjectTo{Ansatz}(; tn=ProjectTo(Quantum(x)), lattice=x.lattice)
(projector::ProjectTo{Ansatz})(Δ) = Ansatz(projector.tn(Δ), Δ.lattice)
59 changes: 32 additions & 27 deletions ext/TenetChainRulesCoreExt/rrules.jl
@@ -9,6 +9,38 @@ TensorNetwork_pullback(Δ::TensorNetworkTangent) = (NoTangent(), tensors(Δ))
TensorNetwork_pullback(Δ::AbstractThunk) = TensorNetwork_pullback(unthunk(Δ))
ChainRulesCore.rrule(::Type{TensorNetwork}, tensors) = TensorNetwork(tensors), TensorNetwork_pullback

# `Quantum` constructor
Quantum_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
Quantum_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ, NoTangent())
Quantum_pullback(ȳ::AbstractThunk) = Quantum_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{Quantum}, x::TensorNetwork, sites) = Quantum(x, sites), Quantum_pullback

# `Ansatz` constructor
Ansatz_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
Ansatz_pullback(ȳ::AbstractThunk) = Ansatz_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{Ansatz}, x::Quantum, lattice) = Ansatz(x, lattice), Ansatz_pullback

# `AbstractAnsatz`-subtype constructors
Product_pullback(ȳ) = (NoTangent(), ȳ.tn)
Product_pullback(ȳ::AbstractThunk) = Product_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{Product}, x::Ansatz) = Product(x), Product_pullback

Dense_pullback(ȳ) = (NoTangent(), ȳ.tn)
Dense_pullback(ȳ::AbstractThunk) = Dense_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{Dense}, x::Ansatz) = Dense(x), Dense_pullback

MPS_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
MPS_pullback(ȳ::AbstractThunk) = MPS_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{MPS}, x::Ansatz, form) = MPS(x, form), MPS_pullback

MPO_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
MPO_pullback(ȳ::AbstractThunk) = MPO_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{MPO}, x::Ansatz, form) = MPO(x, form), MPO_pullback

PEPS_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
PEPS_pullback(ȳ::AbstractThunk) = PEPS_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{PEPS}, x::Ansatz, form) = PEPS(x, form), PEPS_pullback

# `Base.conj` methods
conj_pullback(Δ::Tensor) = (NoTangent(), conj(Δ))
conj_pullback(Δ::Tangent{Tensor}) = (NoTangent(), conj(Δ))
@@ -93,33 +125,6 @@ function ChainRulesCore.rrule(::typeof(contract), a::Tensor, b::Tensor; kwargs..
return c, contract_pullback
end

Quantum_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
Quantum_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ, NoTangent())
Quantum_pullback(ȳ::AbstractThunk) = Quantum_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{Quantum}, x::TensorNetwork, sites) = Quantum(x, sites), Quantum_pullback

Ansatz_pullback(ȳ) = (NoTangent(), ȳ.super)
Ansatz_pullback(ȳ::AbstractThunk) = Ansatz_pullback(unthunk(ȳ))
function ChainRulesCore.rrule(::Type{T}, x::Quantum) where {T<:Ansatz}
y = T(x)
return y, Ansatz_pullback
end

Ansatz_boundary_pullback(ȳ) = (NoTangent(), ȳ.super, NoTangent())
Ansatz_boundary_pullback(ȳ::AbstractThunk) = Ansatz_boundary_pullback(unthunk(ȳ))
function ChainRulesCore.rrule(::Type{T}, x::Quantum, boundary) where {T<:Ansatz}
return T(x, boundary), Ansatz_boundary_pullback
end

Ansatz_from_arrays_pullback(ȳ) = (NoTangent(), NoTangent(), NoTangent(), parent.(tensors(ȳ.super.tn)))
Ansatz_from_arrays_pullback(ȳ::AbstractThunk) = Ansatz_from_arrays_pullback(unthunk(ȳ))
function ChainRulesCore.rrule(
::Type{T}, socket::Tenet.Socket, boundary::Tenet.Boundary, arrays; kwargs...
) where {T<:Ansatz}
y = T(socket, boundary, arrays; kwargs...)
return y, Ansatz_from_arrays_pullback
end

copy_pullback(ȳ) = (NoTangent(), ȳ)
copy_pullback(ȳ::AbstractThunk) = unthunk(ȳ)
function ChainRulesCore.rrule(::typeof(copy), x::Quantum)
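Likewise, a hedged reverse-mode sketch (the `MPS(arrays)` constructor is again an assumption): each wrapper pullback simply routes the cotangent's `tn` field back to the inner object and reports `NoTangent()` for the structural arguments.

```julia
using Tenet, ChainRulesCore

ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)])  # assumed convenience constructor
y, back = ChainRulesCore.rrule(MPS, Ansatz(ψ), Tenet.form(ψ))

ȳ = Tangent{MPS}(; tn=ZeroTangent())                   # stand-in cotangent
back(ȳ) == (NoTangent(), ZeroTangent(), NoTangent())   # `form` is non-differentiable
```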
31 changes: 23 additions & 8 deletions ext/TenetChainRulesTestUtilsExt.jl
@@ -6,27 +6,42 @@ using Tenet
using ChainRulesCore
using ChainRulesTestUtils
using Random
using Graphs
using MetaGraphsNext

const TensorNetworkTangent = Base.get_extension(Tenet, :TenetChainRulesCoreExt).TensorNetworkTangent

function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::Vector{T}) where {T<:Tensor}
if isempty(x)
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, tn::Vector{T}) where {T<:Tensor}
if isempty(tn)
return Vector{T}()
else
@invoke rand_tangent(rng::AbstractRNG, x::AbstractArray)
@invoke rand_tangent(rng::AbstractRNG, tn::AbstractArray)
end
end

function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::TensorNetwork)
return TensorNetworkTangent(Tensor[ProjectTo(tensor)(rand_tangent.(Ref(rng), tensor)) for tensor in tensors(x)])
end

function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::Quantum)
return Tangent{Quantum}(; tn=rand_tangent(rng, TensorNetwork(x)), sites=NoTangent())
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, tn::Quantum)
return Tangent{Quantum}(; tn=rand_tangent(rng, TensorNetwork(tn)), sites=NoTangent())
end

# WARN type-piracy
# NOTE used in `Quantum` constructor
ChainRulesTestUtils.rand_tangent(::AbstractRNG, x::Dict{<:Site,Symbol}) = NoTangent()
# WARN type-piracy, used in `Quantum` constructor
ChainRulesTestUtils.rand_tangent(::AbstractRNG, tn::Dict{<:Site,Symbol}) = NoTangent()

function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, tn::Ansatz)
return Tangent{Ansatz}(; tn=rand_tangent(rng, Quantum(tn)), lattice=NoTangent())
end

# WARN not really type-piracy but almost, used in `Ansatz` constructor
ChainRulesTestUtils.rand_tangent(::AbstractRNG, tn::T) where {V,T<:MetaGraph{V,SimpleGraph{V},<:Site}} = NoTangent()

# WARN not really type-piracy but almost, used when testing `Ansatz`
function ChainRulesTestUtils.test_approx(
actual::G, expected::G, msg; kwargs...
) where {G<:MetaGraph{Int64,SimpleGraph{Int64},<:Site}}
return actual == expected
end

end
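A hedged sketch of how these helpers are meant to be used (the `MPS(arrays)` constructor is an assumption, and end-to-end coverage also depends on the matching `to_vec` definitions in the FiniteDifferences extension): with `rand_tangent` and `test_approx` available for the lattice `MetaGraph`, the stock ChainRulesTestUtils drivers can be pointed directly at the new constructor rules.

```julia
using Tenet, ChainRulesTestUtils

ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)])  # assumed convenience constructor

# Check the new constructor rules against finite differences; the lattice argument
# is picked up as `NoTangent()` through the `rand_tangent` method defined above.
test_frule(Ansatz, Quantum(ψ), Tenet.lattice(ψ))
test_rrule(Ansatz, Quantum(ψ), Tenet.lattice(ψ))
```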
14 changes: 14 additions & 0 deletions ext/TenetFiniteDifferencesExt.jl
@@ -20,4 +20,18 @@ function FiniteDifferences.to_vec(x::Dict{Vector{Symbol},Tensor})
return x_vec, Dict_from_vec
end

function FiniteDifferences.to_vec(x::Quantum)
x_vec, back = to_vec(TensorNetwork(x))
Quantum_from_vec(v) = Quantum(back(v), copy(x.sites))

return x_vec, Quantum_from_vec
end

function FiniteDifferences.to_vec(x::Ansatz)
x_vec, back = to_vec(Quantum(x))
Ansatz_from_vec(v) = Ansatz(back(v), copy(x.lattice))

return x_vec, Ansatz_from_vec
end

end
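A hedged sketch of what these `to_vec` methods are for (the `MPS(arrays)` and `Ansatz(::MPS)` constructors are assumptions): FiniteDifferences perturbs objects through a flat vector representation, so giving `Quantum` and `Ansatz` a `to_vec` lets the finite-difference reference used in the AD tests handle the new wrappers directly.

```julia
using Tenet, FiniteDifferences

ψ = Ansatz(MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]))  # assumed constructors

v, from_vec = FiniteDifferences.to_vec(ψ)  # flatten the wrapped tensors into one vector
from_vec(v) isa Ansatz                     # rebuilds an `Ansatz` carrying the same lattice
```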
4 changes: 2 additions & 2 deletions ext/TenetGraphMakieExt.jl
@@ -1,9 +1,9 @@
module TenetGraphMakieExt

using Tenet
using GraphMakie
using Graphs
using Makie
const Graphs = GraphMakie.Graphs
using Tenet
using Combinatorics: combinations

"""
73 changes: 51 additions & 22 deletions ext/TenetReactantExt.jl
@@ -3,43 +3,53 @@ module TenetReactantExt
using Tenet
using EinExprs
using Reactant
using Reactant: @reactant_override
using Reactant: @reactant_override, TracedRArray
const MLIR = Reactant.MLIR
const stablehlo = MLIR.Dialects.stablehlo

const Enzyme = Reactant.Enzyme

Reactant.traced_type(::Type{T}, _, _) where {T<:Tenet.AbstractTensorNetwork} = T
Reactant.traced_getfield(x::TensorNetwork, i::Int) = tensors(x)[i]

function Reactant.make_tracer(
seen::IdDict, @nospecialize(prev::RT), path::Tuple, mode::Reactant.TraceMode; kwargs...
seen, @nospecialize(prev::RT), path::Tuple, mode::Reactant.TraceMode; kwargs...
) where {RT<:Tensor}
tracedata = Reactant.make_tracer(seen, parent(prev), Reactant.append_path(path, :data), mode; kwargs...)
return Tensor(tracedata, inds(prev))
end

function Reactant.make_tracer(seen::IdDict, prev::TensorNetwork, path::Tuple, mode::Reactant.TraceMode; kwargs...)
function Reactant.make_tracer(seen, prev::TensorNetwork, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracetensors = Vector{Tensor}(undef, Tenet.ntensors(prev))
for (i, tensor) in enumerate(tensors(prev))
tracetensors[i] = Reactant.make_tracer(seen, tensor, Reactant.append_path(path, i), mode; kwargs...)
end
return TensorNetwork(tracetensors)
end

Reactant.traced_getfield(x::TensorNetwork, i::Int) = tensors(x)[i]

function Reactant.make_tracer(seen::IdDict, prev::Quantum, path::Tuple, mode::Reactant.TraceMode; kwargs...)
function Reactant.make_tracer(seen, prev::Quantum, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracetn = Reactant.make_tracer(seen, TensorNetwork(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return Quantum(tracetn, copy(prev.sites))
end

function Reactant.make_tracer(seen::IdDict, prev::Tenet.Product, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracequantum = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :super), mode; kwargs...)
return Tenet.Product(tracequantum)
function Reactant.make_tracer(seen, prev::Ansatz, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracetn = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return Ansatz(tracetn, copy(Tenet.lattice(prev)))
end

# TODO try rely on generic fallback for ansatzes -> do it when refactoring to MPS/MPO
function Reactant.make_tracer(seen::IdDict, prev::Tenet.Chain, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracequantum = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :super), mode; kwargs...)
return Tenet.Chain(tracequantum, boundary(prev))
# TODO try rely on generic fallback for ansatzes
for A in (Product, Dense)
@eval function Reactant.make_tracer(seen, prev::$A, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return $A(tracetn)
end
end

for A in (MPS, MPO)
@eval function Reactant.make_tracer(seen, prev::$A, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return $A(tracetn, form(prev))
end
end

function Reactant.create_result(@nospecialize(tocopy::Tensor), @nospecialize(path), result_stores)
@@ -59,10 +69,24 @@ function Reactant.create_result(tocopy::Quantum, @nospecialize(path), result_sto
return :($Quantum($tn, $(copy(tocopy.sites))))
end

# TODO try rely on generic fallback for ansatzes -> do it when refactoring to MPS/MPO
function Reactant.create_result(tocopy::Tenet.Chain, @nospecialize(path), result_stores)
qtn = Reactant.create_result(Quantum(tocopy), Reactant.append_path(path, :super), result_stores)
return :($(Tenet.Chain)($qtn, $(boundary(tocopy))))
function Reactant.create_result(tocopy::Ansatz, @nospecialize(path), result_stores)
tn = Reactant.create_result(Quantum(tocopy), Reactant.append_path(path, :tn), result_stores)
return :($Ansatz($tn, $(copy(Tenet.lattice(tocopy)))))
end

# TODO try rely on generic fallback for ansatzes
for A in (Product, Dense)
@eval function Reactant.create_result(tocopy::A, @nospecialize(path), result_stores) where {A<:$A}
tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores)
return :($A($tn))
end
end

for A in (MPS, MPO)
@eval function Reactant.create_result(tocopy::A, @nospecialize(path), result_stores) where {A<:$A}
tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores)
return :($A($tn, $(Tenet.form(tocopy))))
end
end

function Reactant.push_val!(ad_inputs, x::TensorNetwork, path)
@@ -124,8 +148,8 @@ end
end

function Tenet.contract(
a::Tensor{Ta,Na,Aa}, b::Tensor{Tb,Nb,Ab}; dims=(∩(inds(a), inds(b))), out=nothing
) where {Ta,Na,Aa<:Reactant.TracedRArray,Tb,Nb,Ab<:Reactant.TracedRArray}
a::Tensor{Ta,Na,TracedRArray{Ta,Na}}, b::Tensor{Tb,Nb,TracedRArray{Tb,Nb}}; dims=(∩(inds(a), inds(b))), out=nothing
) where {Ta,Na,Tb,Nb}
ia = collect(inds(a))
ib = collect(inds(b))
i = ∩(dims, ia, ib)
@@ -154,12 +178,12 @@

result = Reactant.MLIR.IR.result(stablehlo.einsum(op_a, op_b; result_0, einsum_config))

data = Reactant.TracedRArray{T,length(ic)}((), result, rsize)
data = TracedRArray{T,length(ic)}((), result, rsize)
_res = Tensor(data, ic)
return _res
end

function Tenet.contract(a::Tensor{T,N,A}; dims=nonunique(inds(a)), out=nothing) where {T,N,A<:Reactant.TracedRArray}
function Tenet.contract(a::Tensor{T,N,TracedRArray{T,N}}; dims=nonunique(inds(a)), out=nothing) where {T,N}
ia = inds(a)
i = ∩(dims, ia)

@@ -178,8 +202,13 @@ function Tenet.contract(a::Tensor{T,N,A}; dims=nonunique(inds(a)), out=nothing)

result = Reactant.MLIR.IR.result(stablehlo.unary_einsum(operand; result_0, einsum_config))

data = Reactant.TracedRArray{T,length(ic)}((), result, rsize)
data = TracedRArray{T,length(ic)}((), result, rsize)
return Tensor(data, ic)
end

Tenet.contract(a::Tensor, b::Tensor{T,N,TracedRArray{T,N}}; kwargs...) where {T,N} = contract(b, a; kwargs...)
function Tenet.contract(a::Tensor{Ta,Na,TracedRArray{Ta,Na}}, b::Tensor{Tb,Nb}; kwargs...) where {Ta,Na,Tb,Nb}
return contract(a, Tensor(Reactant.promote_to(TracedRArray{Tb,Nb}, parent(b)), inds(b)); kwargs...)
end

end