diff --git a/Project.toml b/Project.toml index 43556fce..9b32422b 100644 --- a/Project.toml +++ b/Project.toml @@ -8,9 +8,11 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" UnsafeAtomics = "013be700-e6cd-48c3-b4a1-df204f14c38f" diff --git a/ext/LinearAlgebraExt.jl b/ext/LinearAlgebraExt.jl index 17f3e5e6..adff3179 100644 --- a/ext/LinearAlgebraExt.jl +++ b/ext/LinearAlgebraExt.jl @@ -1,11 +1,7 @@ module LinearAlgebraExt using KernelAbstractions: KernelAbstractions -if isdefined(Base, :get_extension) - using LinearAlgebra: Tridiagonal, Diagonal -else - using ..LinearAlgebra: Tridiagonal, Diagonal -end +using LinearAlgebra: Tridiagonal, Diagonal KernelAbstractions.get_backend(A::Diagonal) = KernelAbstractions.get_backend(A.diag) KernelAbstractions.get_backend(A::Tridiagonal) = KernelAbstractions.get_backend(A.d) diff --git a/ext/SparseArraysExt.jl b/ext/SparseArraysExt.jl index 4d8979b1..64a74f07 100644 --- a/ext/SparseArraysExt.jl +++ b/ext/SparseArraysExt.jl @@ -1,11 +1,7 @@ module SparseArraysExt using KernelAbstractions: KernelAbstractions -if isdefined(Base, :get_extension) - using SparseArrays: AbstractSparseArray, rowvals -else - using ..SparseArrays: AbstractSparseArray, rowvals -end +using SparseArrays: AbstractSparseArray, rowvals function KernelAbstractions.get_backend(A::AbstractSparseArray) return KernelAbstractions.get_backend(rowvals(A)) diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl index 954730fd..741ac12f 100644 --- a/src/KernelAbstractions.jl +++ b/src/KernelAbstractions.jl @@ -769,4 +769,9 @@ end end end +if !isdefined(Base, :get_extension) + include("../ext/LinearAlgebraExt.jl") + include("../ext/SparseArraysExt.jl") +end + end #module