Contraction problems in OMEinsum for MPS evolved with different eltype Arrays #257

Open
jofrevalles opened this issue Nov 21, 2024 · 2 comments
Summary

After using the evolve! function on ψ (and I suspect this can happen in other operations too), a MethodError occurs when we call norm(ψ): OMEinsum ends up calling batched_gemm! with mismatched array element types (ComplexF64 vs. Float64).

julia> ψ = rand(MPS; n=6, maxdim=20); canonize!(ψ)
MPS (inputs=0, outputs=6)

julia> function id_gate(i, j)
               mat = Array{ComplexF64}(reshape(LinearAlgebra.I(4), 2, 2, 2, 2))
               Quantum(mat, [Site(i), Site(j), Site(i, dual=true), Site(j, dual=true)])
end
id_gate (generic function with 1 method)

julia> evolved = evolve!(deepcopy(ψ), id_gate(1, 2); maxdim=2, renormalize=false)
MPS (inputs=0, outputs=6)

julia> norm(evolved)
ERROR: MethodError: no method matching batched_gemm!(::Char, ::Char, ::ComplexF64, ::Array{ComplexF64, 3}, ::Array{Float64, 3}, ::ComplexF64, ::Array{ComplexF64, 3})
The function `batched_gemm!` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  batched_gemm!(::AbstractChar, ::AbstractChar, ::ComplexF64, ::AbstractArray{ComplexF64, 3}, ::AbstractArray{ComplexF64, 3}, ::ComplexF64, ::AbstractArray{ComplexF64, 3})
   @ BatchedRoutines ~/.julia/packages/BatchedRoutines/nbUQX/src/blas.jl:112
  batched_gemm!(::AbstractChar, ::AbstractChar, ::Float64, ::AbstractArray{Float64, 3}, ::AbstractArray{Float64, 3}, ::Float64, ::AbstractArray{Float64, 3})
   @ BatchedRoutines ~/.julia/packages/BatchedRoutines/nbUQX/src/blas.jl:112
  batched_gemm!(::AbstractChar, ::AbstractChar, ::ComplexF32, ::AbstractArray{ComplexF32, 3}, ::AbstractArray{ComplexF32, 3}, ::ComplexF32, ::AbstractArray{ComplexF32, 3})
   @ BatchedRoutines ~/.julia/packages/BatchedRoutines/nbUQX/src/blas.jl:112
  ...

Stacktrace:
  [1] _batched_gemm!(C1::Char, C2::Char, alpha::Bool, A::Array{ComplexF64, 3}, B::Array{Float64, 3}, beta::Bool, C::Array{ComplexF64, 3})
    @ OMEinsum ~/.julia/packages/OMEinsum/FCR12/src/utils.jl:165
  [2] binary_einsum!(::OMEinsum.SimpleBinaryRule{('i', 'j', 'l'), ('j', 'k', 'l'), ('i', 'k', 'l')}, x1::Array{ComplexF64, 3}, x2::Array{Float64, 3}, y::Array{ComplexF64, 3}, sx::Bool, sy::Bool)
    @ OMEinsum ~/.julia/packages/OMEinsum/FCR12/src/binaryrules.jl:138
  [3] einsum!(ixs::Tuple{Vector{Symbol}, Vector{Symbol}}, iy::Vector{Symbol}, xs::Tuple{Any, Any}, y::Any, sx::Bool, sy::Bool, size_dict::Dict{Symbol, Int64})
    @ OMEinsum ~/.julia/packages/OMEinsum/FCR12/src/einsum.jl:118
  [4] contract!(c::Tensor{ComplexF64, 5, Array{ComplexF64, 5}}, a::Tensor{ComplexF64, 5, Array{ComplexF64, 5}}, b::Tensor{Float64, 3, Array{Float64, 3}})
    @ Tenet ~/git/Tenet.jl/src/Numerics.jl:82
...
 [76] norm
    @ ~/git/Tenet.jl/src/Quantum.jl:413 [inlined]
 [77] norm(ψ::MPS)
    @ Tenet ~/git/Tenet.jl/src/Quantum.jl:413

This error can be easily circumvented by changing the eltype of ψ to ComplexF64, but it is still an error we should consider. Also, I don't know if this is our fault or a bug in OMEinsum, but as far as I know contraction between arrays of different eltypes should be supported.
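For reference, a minimal sketch of that workaround using only the API calls that appear later in this issue (tensors, replace!, Tensor, parent, inds); the helper name promote_to_complex! is made up, and in canonical form the bond (lambda) tensors would likely need the same treatment:

# Hypothetical helper: copy every site tensor of ψ into a ComplexF64 array.
# Only the tensors addressable via Site(i) are touched here.
function promote_to_complex!(ψ, n)
    for i in 1:n
        t = tensors(ψ, at=Site(i))
        replace!(ψ, t => Tensor(Array{ComplexF64}(parent(t)), inds(t)))
    end
    return ψ
end

promote_to_complex!(ψ, 6)  # after this, norm(ψ) should avoid the mixed-eltype batched_gemm! path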

jofrevalles added the bug label on Nov 21, 2024
jofrevalles commented Nov 21, 2024

Okay @mofeing, see that this also errors:

julia> ψ = rand(MPS; n=6, maxdim=20); canonize!(ψ)
MPS (inputs=0, outputs=6)

julia> t = tensors(ψ, at=Site(2))
2×2×4 Tensor{Float64, 3, Array{Float64, 3}}: ...

julia> replace!(ψ, t => Tensor(Array{ComplexF64}(parent(t)), inds(t)))
TensorNetwork (#tensors=11, #inds=11)

julia> norm(ψ)
ERROR: MethodError: no method matching batched_gemm!(::Char, ::Char, ::ComplexF64, ::Array{ComplexF64, 3}, ::Array{Float64, 3}, ::ComplexF64, ::Array{ComplexF64, 3})
The function `batched_gemm!` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  batched_gemm!(::AbstractChar, ::AbstractChar, ::ComplexF64, ::AbstractArray{ComplexF64, 3}, ::AbstractArray{ComplexF64, 3}, ::ComplexF64, ::AbstractArray{ComplexF64, 3})
   @ BatchedRoutines ~/.julia/packages/BatchedRoutines/nbUQX/src/blas.jl:112
  batched_gemm!(::AbstractChar, ::AbstractChar, ::Float64, ::AbstractArray{Float64, 3}, ::AbstractArray{Float64, 3}, ::Float64, ::AbstractArray{Float64, 3})
   @ BatchedRoutines ~/.julia/packages/BatchedRoutines/nbUQX/src/blas.jl:112
  batched_gemm!(::AbstractChar, ::AbstractChar, ::ComplexF32, ::AbstractArray{ComplexF32, 3}, ::AbstractArray{ComplexF32, 3}, ::ComplexF32, ::AbstractArray{ComplexF32, 3})
   @ BatchedRoutines ~/.julia/packages/BatchedRoutines/nbUQX/src/blas.jl:112
  ...

Stacktrace:
  [1] _batched_gemm!(C1::Char, C2::Char, alpha::Bool, A::Array{ComplexF64, 3}, B::Array{Float64, 3}, beta::Bool, C::Array{ComplexF64, 3})
    @ OMEinsum ~/.julia/packages/OMEinsum/FCR12/src/utils.jl:165
  [2] binary_einsum!(::OMEinsum.SimpleBinaryRule{('i', 'j', 'l'), ('j', 'k', 'l'), ('i', 'k', 'l')}, x1::Array{ComplexF64, 3}, x2::Array{Float64, 3}, y::Array{ComplexF64, 3}, sx::Bool, sy::Bool)
    @ OMEinsum ~/.julia/packages/OMEinsum/FCR12/src/binaryrules.jl:138
  [3] einsum!(ixs::Tuple{Vector{Symbol}, Vector{Symbol}}, iy::Vector{Symbol}, xs::Tuple{Any, Any}, y::Any, sx::Bool, sy::Bool, size_dict::Dict{Symbol, Int64})
    @ OMEinsum ~/.julia/packages/OMEinsum/FCR12/src/einsum.jl:118
  [4] contract!(c::Tensor{ComplexF64, 5, Array{ComplexF64, 5}}, a::Tensor{ComplexF64, 5, Array{ComplexF64, 5}}, b::Tensor{Float64, 3, Array{Float64, 3}})
    @ Tenet ~/git/Tenet.jl/src/Numerics.jl:82
  [5] contract(a::Tensor{ComplexF64, 5, Array{ComplexF64, 5}}, b::Tensor{Float64, 3, Array{Float64, 3}}; dims::Vector{Symbol}, out::Nothing)
    @ Tenet ~/git/Tenet.jl/src/Numerics.jl:51
  [6] kwcall(::@NamedTuple{path::EinExprs.EinExpr{Symbol}}, ::typeof(contract), tn::MPS)
    @ Tenet ~/git/Tenet.jl/src/TensorNetwork.jl:702
  [7] contract(tn::MPS; kwargs::@Kwargs{path::EinExprs.EinExpr{Symbol}})
    @ Tenet ~/.julia/packages/KeywordDispatch/HY9PO/src/KeywordDispatch.jl:177
  [8] contract
    @ ~/.julia/packages/KeywordDispatch/HY9PO/src/KeywordDispatch.jl:175 [inlined]
  [9] #133
    @ ~/git/Tenet.jl/src/TensorNetwork.jl:701 [inlined]
 [10] iterate
    @ ./generator.jl:48 [inlined]
 [11] _collect(c::Vector{EinExprs.EinExpr{Symbol}}, itr::Base.Generator{Vector{EinExprs.EinExpr{Symbol}}, Tenet.var"#133#134"{MPS}}, ::Base.EltypeUnknown, isz::Base.HasShape{1})
    @ Base ./array.jl:800
 [12] collect_similar
    @ ./array.jl:709 [inlined]
 [13] map
    @ ./abstractarray.jl:3371 [inlined]
 [14] kwcall(::@NamedTuple{path::EinExprs.EinExpr{Symbol}}, ::typeof(contract), tn::MPS)
    @ Tenet ~/git/Tenet.jl/src/TensorNetwork.jl:701
--- the above 8 lines are repeated 3 more times ---
 [39] contract(tn::MPS; kwargs::@Kwargs{path::EinExprs.EinExpr{Symbol}})
    @ Tenet ~/.julia/packages/KeywordDispatch/HY9PO/src/KeywordDispatch.jl:177
 [40] contract
    @ ~/.julia/packages/KeywordDispatch/HY9PO/src/KeywordDispatch.jl:175 [inlined]
 [41] #133
    @ ~/git/Tenet.jl/src/TensorNetwork.jl:701 [inlined]
 [42] iterate
    @ ./generator.jl:48 [inlined]
 [43] collect_to!(dest::Vector{Tensor{Float64, 2, Matrix{Float64}}}, itr::Base.Generator{Vector{EinExprs.EinExpr{Symbol}}, Tenet.var"#133#134"{MPS}}, offs::Int64, st::Int64)
    @ Base ./array.jl:838
 [44] collect_to_with_first!(dest::Vector{Tensor{Float64, 2, Matrix{Float64}}}, v1::Tensor{Float64, 2, Matrix{Float64}}, itr::Base.Generator{Vector{EinExprs.EinExpr{Symbol}}, Tenet.var"#133#134"{MPS}}, st::Int64)
    @ Base ./array.jl:816
 [45] _collect(c::Vector{EinExprs.EinExpr{Symbol}}, itr::Base.Generator{Vector{EinExprs.EinExpr{Symbol}}, Tenet.var"#133#134"{MPS}}, ::Base.EltypeUnknown, isz::Base.HasShape{1})
    @ Base ./array.jl:810
 [46] collect_similar
    @ ./array.jl:709 [inlined]
 [47] map
    @ ./abstractarray.jl:3371 [inlined]
 [48] kwcall(::@NamedTuple{path::EinExprs.EinExpr{Symbol}}, ::typeof(contract), tn::MPS)
    @ Tenet ~/git/Tenet.jl/src/TensorNetwork.jl:701
 [49] contract(tn::MPS; kwargs::@Kwargs{path::EinExprs.EinExpr{Symbol}})
    @ Tenet ~/.julia/packages/KeywordDispatch/HY9PO/src/KeywordDispatch.jl:177
 [50] contract
    @ ~/.julia/packages/KeywordDispatch/HY9PO/src/KeywordDispatch.jl:175 [inlined]
 [51] #133
    @ ~/git/Tenet.jl/src/TensorNetwork.jl:701 [inlined]
 [52] iterate
    @ ./generator.jl:48 [inlined]
 [53] _collect(c::Vector{EinExprs.EinExpr{Symbol}}, itr::Base.Generator{Vector{EinExprs.EinExpr{Symbol}}, Tenet.var"#133#134"{MPS}}, ::Base.EltypeUnknown, isz::Base.HasShape{1})
    @ Base ./array.jl:800
 [54] collect_similar
    @ ./array.jl:709 [inlined]
 [55] map
    @ ./abstractarray.jl:3371 [inlined]
 [56] kwcall(::@NamedTuple{path::EinExprs.EinExpr{Symbol}}, ::typeof(contract), tn::MPS)
    @ Tenet ~/git/Tenet.jl/src/TensorNetwork.jl:701
 [57] contract(tn::MPS; kwargs::@Kwargs{path::EinExprs.EinExpr{Symbol}})
    @ Tenet ~/.julia/packages/KeywordDispatch/HY9PO/src/KeywordDispatch.jl:177
 [58] contract
    @ ~/.julia/packages/KeywordDispatch/HY9PO/src/KeywordDispatch.jl:175 [inlined]
 [59] #133
    @ ~/git/Tenet.jl/src/TensorNetwork.jl:701 [inlined]
 [60] iterate
    @ ./generator.jl:48 [inlined]
 [61] collect_to!(dest::Vector{Tensor{Float64, 1, Vector{Float64}}}, itr::Base.Generator{Vector{EinExprs.EinExpr{Symbol}}, Tenet.var"#133#134"{MPS}}, offs::Int64, st::Int64)
    @ Base ./array.jl:838
 [62] collect_to_with_first!(dest::Vector{Tensor{Float64, 1, Vector{Float64}}}, v1::Tensor{Float64, 1, Vector{Float64}}, itr::Base.Generator{Vector{EinExprs.EinExpr{Symbol}}, Tenet.var"#133#134"{MPS}}, st::Int64)
    @ Base ./array.jl:816
 [63] _collect(c::Vector{EinExprs.EinExpr{Symbol}}, itr::Base.Generator{Vector{EinExprs.EinExpr{Symbol}}, Tenet.var"#133#134"{MPS}}, ::Base.EltypeUnknown, isz::Base.HasShape{1})
    @ Base ./array.jl:810
 [64] collect_similar
    @ ./array.jl:709 [inlined]
 [65] map
    @ ./abstractarray.jl:3371 [inlined]
 [66] kwcall(::@NamedTuple{path::EinExprs.SizedEinExpr{Symbol}}, ::typeof(contract), tn::MPS)
    @ Tenet ~/git/Tenet.jl/src/TensorNetwork.jl:701
 [67] contract(tn::MPS; kwargs::@Kwargs{path::EinExprs.SizedEinExpr{Symbol}})
    @ Tenet ~/.julia/packages/KeywordDispatch/HY9PO/src/KeywordDispatch.jl:177
 [68] contract
    @ ~/.julia/packages/KeywordDispatch/HY9PO/src/KeywordDispatch.jl:175 [inlined]
 [69] #contract#123
    @ ~/.julia/packages/KeywordDispatch/HY9PO/src/KeywordDispatch.jl:177 [inlined]
 [70] contract
    @ ~/.julia/packages/KeywordDispatch/HY9PO/src/KeywordDispatch.jl:175 [inlined]
 [71] norm2(::State, ψ::MPS; kwargs::@Kwargs{})
    @ Tenet ~/git/Tenet.jl/src/Quantum.jl:421
 [72] norm2(::State, ψ::MPS)
    @ Tenet ~/git/Tenet.jl/src/Quantum.jl:420
 [73] norm2(ψ::MPS; kwargs::@Kwargs{})
    @ Tenet ~/git/Tenet.jl/src/Quantum.jl:418
 [74] norm2
    @ ~/git/Tenet.jl/src/Quantum.jl:418 [inlined]
 [75] #norm#243
    @ ~/git/Tenet.jl/src/Quantum.jl:415 [inlined]
 [76] norm
    @ ~/git/Tenet.jl/src/Quantum.jl:413 [inlined]
 [77] norm(ψ::MPS)
    @ Tenet ~/git/Tenet.jl/src/Quantum.jl:413

But not this:

julia> ψ = rand(MPS; n=6, maxdim=20)
MPS (inputs=0, outputs=6)

julia> t = tensors(ψ, at=Site(2))
2×2×4 Tensor{Float64, 3, Array{Float64, 3}}: ....

julia> replace!(ψ, t => Tensor(Array{ComplexF64}(parent(t)), inds(t)))
TensorNetwork (#tensors=6, #inds=11)

julia> norm(ψ)
0.9999999999999999

So it has to do with the canonical form.

mofeing commented Nov 21, 2024

This error can be easily circumvented by changing the eltype of ψ to ComplexF64, but it is still an error we should consider.

You can promote the eltypes of both tensors on contract, but I thought we already did that. Check out the "Conversion and Promotion" section in the Julia documentation.
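A minimal sketch of what that promotion could look like in a contraction wrapper (names are illustrative, not Tenet's actual implementation):

# Illustrative only: promote both operands to a common eltype before handing
# them to the eltype-homogeneous batched_gemm!/einsum! kernels.
function promote_operands(a::AbstractArray, b::AbstractArray)
    T = promote_type(eltype(a), eltype(b))  # Float64 + ComplexF64 -> ComplexF64
    a2 = eltype(a) === T ? a : T.(a)        # broadcasted conversion copies only when needed
    b2 = eltype(b) === T ? b : T.(b)
    return a2, b2
end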

So it has to do with the Canonical form

Mmm, maybe it's because the lambdas are real and the gammas are complex?
