Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend KrylovKit.eigsolve by allowing the initial guess x₀ to be a Tensor #171

Merged
merged 6 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 43 additions & 17 deletions ext/TenetKrylovKitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,28 @@ function eigsolve_prehook_tensor_reshape(A::Tensor, left_inds, right_inds)
return Amat, left_sizes, right_sizes
end

function KrylovKit.eigselector(A::Tensor, T::Type; left_inds=Symbol[], right_inds=Symbol[], kwargs...)
Amat, _, _ = eigsolve_prehook_tensor_reshape(A, left_inds, right_inds)
return KrylovKit.eigselector(Amat, T; kwargs...)
function eigsolve_prehook_tensor_reshape(A::Tensor, x₀::Tensor, left_inds, right_inds)
left_inds, right_inds = Tenet.factorinds(A, left_inds, right_inds)

Amat, left_sizes, right_sizes = eigsolve_prehook_tensor_reshape(A, left_inds, right_inds)
prod_left_sizes = prod(left_sizes)

inds(x₀) != left_inds && throw(
ArgumentError(
"The initial guess must have the same left indices as the tensor, but got $(inds(x₀)) and $left_inds."
),
)
prod(size.((x₀,), left_inds)) != prod_left_sizes && throw(
ArgumentError(
"The initial guess must have the same size as the left indices, but got sizes $prod_x₀_sizes and $prod_left_sizes.",
),
)

# Permute and reshape the tensor
x₀ = permutedims(x₀, left_inds)
x₀vec = reshape(parent(x₀), prod_left_sizes)

return Amat, left_sizes, right_sizes, x₀vec
end

function KrylovKit.eigsolve(
Expand All @@ -50,20 +69,6 @@ function KrylovKit.eigsolve(
return vals, Avecs, info
end

function KrylovKit.eigsolve(
f::Tensor, x₀, howmany::Int=1, which::KrylovKit.Selector=:LM; left_inds=Symbol[], right_inds=Symbol[], kwargs...
)
Amat, left_sizes, right_sizes = eigsolve_prehook_tensor_reshape(A, left_inds, right_inds)

# Compute eigenvalues and eigenvectors
vals, vecs, info = KrylovKit.eigsolve(Amat, x₀, howmany, which; kwargs...)

# Tensorify the eigenvectors
Avecs = [Tensor(reshape(vec, left_sizes...), left_inds) for vec in vecs]

return vals, Avecs, info
end

Comment on lines -53 to -66
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why have you removed this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See the comment from above.

"""
KrylovKit.eigsolve(tensor::Tensor; left_inds, right_inds, kwargs...)

Expand Down Expand Up @@ -95,4 +100,25 @@ function KrylovKit.eigsolve(
return vals, Avecs, info
end

function KrylovKit.eigsolve(
A::Tensor,
x₀::Tensor,
howmany::Int,
which::KrylovKit.Selector,
alg::Algorithm;
left_inds=inds(x₀),
right_inds=Symbol[],
kwargs...,
) where {Algorithm<:KrylovKit.Lanczos} # KrylovKit.KrylovAlgorithm}
Amat, left_sizes, right_sizes, x₀vec = eigsolve_prehook_tensor_reshape(A, x₀, left_inds, right_inds)

# Compute eigenvalues and eigenvectors
vals, vecs, info = KrylovKit.eigsolve(Amat, x₀vec, howmany, which, alg; kwargs...)

# Tensorify the eigenvectors
Avecs = [Tensor(reshape(vec, left_sizes...), left_inds) for vec in vecs]

return vals, Avecs, info
end

end
20 changes: 18 additions & 2 deletions test/integration/KrylovKit_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,24 @@
@test parent(reconstructed_tensor) ≈ parent(transpose(reconstructed_tensor_perm))

@testset "Lanczos" begin
vals_lanczos, vecs_lanczos = eigsolve(
tensor, rand(ComplexF64, 4), 1, :SR, Lanczos(; krylovdim=2, tol=1e-16); left_inds=[:i], right_inds=[:j]
@test_throws ArgumentError eigsolve(
tensor,
Tensor(rand(ComplexF64, 4), (:j,)),
1,
:SR,
Lanczos(; krylovdim=2, tol=1e-16);
left_inds=[:i],
right_inds=[:j],
)

vals_lanczos, vecs_lanczos, info = eigsolve(
tensor,
Tensor(rand(ComplexF64, 4), (:i,)),
1,
:SR,
Lanczos(; krylovdim=2, tol=1e-16);
left_inds=[:i],
right_inds=[:j],
)

@test length(vals_lanczos) == 1
Expand Down
Loading