Skip to content

Commit

Permalink
fix Array-TracedRArray contraction and force dense representation…
Browse files Browse the repository at this point in the history
… of Yao gates (#258)

* try fix reordering of sites

* try fix layout of multi-qubit gates in Yao to Tenet conversion

* Remove permutedims & add collect for array building

* minor fixes on Reactant pkg extension

* implement `copy` for some structs

* fix column- to row-major layout conversion on `Reactant.promote_to` call

* remove dead code

---------

Co-authored-by: Todorbsc <[email protected]>
  • Loading branch information
2 people authored and jofrevalles committed Dec 2, 2024
1 parent 8d53a34 commit 1cad26e
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 28 deletions.
44 changes: 18 additions & 26 deletions ext/TenetReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function Reactant.make_tracer(
seen, @nospecialize(prev::RT), path::Tuple, mode::Reactant.TraceMode; kwargs...
) where {RT<:Tensor}
tracedata = Reactant.make_tracer(seen, parent(prev), Reactant.append_path(path, :data), mode; kwargs...)
return Tensor(tracedata, inds(prev))
return Tensor(tracedata, copy(inds(prev)))
end

function Reactant.make_tracer(seen, prev::TensorNetwork, path::Tuple, mode::Reactant.TraceMode; kwargs...)
Expand Down Expand Up @@ -42,16 +42,16 @@ function Reactant.make_tracer(seen, prev::Tenet.Product, path::Tuple, mode::Reac
return Tenet.Product(tracetn)
end

for A in (MPS, MPO)
@eval function Reactant.make_tracer(seen, prev::$A, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return $A(tracetn, form(prev))
end
function Reactant.make_tracer(
seen, prev::A, path::Tuple, mode::Reactant.TraceMode; kwargs...
) where {A<:Tenet.AbstractMPO}
tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return A(tracetn, copy(form(prev)))
end

function Reactant.create_result(@nospecialize(tocopy::Tensor), @nospecialize(path), result_stores)
data = Reactant.create_result(parent(tocopy), Reactant.append_path(path, :data), result_stores)
return :($Tensor($data, $(inds(tocopy))))
return :($Tensor($data, $(copy(inds(tocopy)))))
end

function Reactant.create_result(tocopy::TensorNetwork, @nospecialize(path), result_stores)
Expand All @@ -77,26 +77,11 @@ function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), resu
return :($(Tenet.Product)($tn))
end

for A in (MPS, MPO)
@eval function Reactant.create_result(tocopy::A, @nospecialize(path), result_stores) where {A<:$A}
tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores)
return :($A($tn, $(Tenet.form(tocopy))))
end
function Reactant.create_result(tocopy::A, @nospecialize(path), result_stores) where {A<:Tenet.AbstractMPO}
tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores)
return :($A($tn, $(Tenet.form(tocopy))))
end

# TODO try rely on generic fallback for ansatzes
# function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), result_stores)
# tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores)
# return :($(Tenet.Product)($tn))
# end

# for A in (MPS, MPO)
# @eval function Reactant.create_result(tocopy::$A, @nospecialize(path), result_stores)
# tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores)
# return :($A($tn, form(tocopy)))
# end
# end

function Reactant.push_val!(ad_inputs, x::TensorNetwork, path)
@assert length(path) == 2
@assert path[2] === :data
Expand Down Expand Up @@ -216,7 +201,14 @@ end

Tenet.contract(a::Tensor, b::Tensor{T,N,TracedRArray{T,N}}; kwargs...) where {T,N} = contract(b, a; kwargs...)
function Tenet.contract(a::Tensor{Ta,Na,TracedRArray{Ta,Na}}, b::Tensor{Tb,Nb}; kwargs...) where {Ta,Na,Tb,Nb}
return contract(a, Tensor(Reactant.promote_to(TracedRArray{Tb,Nb}, parent(b)), inds(b)); kwargs...)
# TODO change to `Ops.constant` when Ops PR lands in Reactant
# apparently `promote_to` doesn't do the transpostion for converting from column-major (Julia) to row-major layout (MLIR)
# currently, we call permutedims manually
return contract(
a,
Tensor(Reactant.promote_to(TracedRArray{Tb,Nb}, permutedims(parent(b), collect(Nb:-1:1))), inds(b));
kwargs...,
)
end

end
4 changes: 2 additions & 2 deletions ext/TenetYaoBlocksExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ function Tenet.Quantum(circuit::AbstractBlock)
end

# NOTE `YaoBlocks.mat` on m-site qubits still returns the operator on the full Hilbert space
m = length(occupied_locs(gate))
operator = if gate isa YaoBlocks.ControlBlock
m = length(occupied_locs(gate))
control((1:(m - 1))..., m => content(gate))(m)
else
content(gate)
end
array = reshape(mat(operator), fill(nlevel(operator), 2 * nqubits(operator))...)
array = reshape(collect(mat(operator)), fill(nlevel(operator), 2 * nqubits(operator))...)

inds = (x -> collect(Iterators.flatten(zip(x...))))(
map(occupied_locs(gate)) do l
Expand Down
4 changes: 4 additions & 0 deletions src/Ansatz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ Abstract type representing the canonical form trait of a [`AbstractAnsatz`](@ref
"""
abstract type Form end

Base.copy(x::Form) = x

"""
NonCanonical
Expand All @@ -52,6 +54,8 @@ struct MixedCanonical <: Form
orthog_center::Union{Site,Vector{<:Site}}
end

Base.copy(x::MixedCanonical) = MixedCanonical(copy(x.orthog_center))

"""
Canonical
Expand Down
2 changes: 2 additions & 0 deletions src/Site.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ end
Site(id::Int; kwargs...) = Site((id,); kwargs...)
Site(id::Vararg{Int,N}; kwargs...) where {N} = Site(id; kwargs...)

Base.copy(x::Site) = x

id(site::Site{1}) = only(site.id)
id(site::Site) = site.id

Expand Down

0 comments on commit 1cad26e

Please sign in to comment.