From 45862140e2ab6fd08f53b0b1084266420550101a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Poisot?= Date: Sat, 14 Oct 2023 22:13:46 -0400 Subject: [PATCH 1/7] feat(sdt)?: MultivariateStats + StatsAPI --- Project.toml | 4 ++ src/SpeciesDistributionToolkit.jl | 8 ++- .../makie.jl => external/Makie.jl} | 0 src/external/MultivariateStats.jl | 49 +++++++++++++++++++ 4 files changed, 60 insertions(+), 1 deletion(-) rename src/{integrations/makie.jl => external/Makie.jl} (100%) create mode 100644 src/external/MultivariateStats.jl diff --git a/Project.toml b/Project.toml index dc8664de3..6de3d746c 100644 --- a/Project.toml +++ b/Project.toml @@ -5,16 +5,20 @@ version = "0.0.10" [deps] ArchGDAL = "c9ce4bd3-c3d5-55b8-8973-c0e20141b8c3" +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Fauxcurrences = "a2d61402-033a-4ca9-aef4-652d70cf7c9c" GBIF = "ee291a33-5a6c-5552-a3c8-0f29a1181037" GDAL = "add2ef01-049f-52c4-9ee2-e494f65e021a" MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b" +MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Phylopic = "c889285c-44aa-4473-b1e1-56f5d4e3ccf5" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" SimpleSDMDatasets = "2c7d61d0-5c73-410d-85b2-d2e7fbbdcefa" SimpleSDMLayers = "2c645270-77db-11e9-22c3-0f302a89c64c" +StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] diff --git a/src/SpeciesDistributionToolkit.jl b/src/SpeciesDistributionToolkit.jl index d2efb1161..f80441195 100644 --- a/src/SpeciesDistributionToolkit.jl +++ b/src/SpeciesDistributionToolkit.jl @@ -12,6 +12,9 @@ using MakieCore import StatsBase import OffsetArrays +import MultivariateStats +import StatsAPI + # We make ample use of re-export using Reexport @@ -32,7 +35,10 @@ include("integrations/gbif_layers.jl") include("integrations/gbif_phylopic.jl") # Plotting -include("integrations/makie.jl") +include("external/Makie.jl") + +# Plotting +include("external/MultivariateStats.jl") # Functions for IO include("io/geotiff.jl") diff --git a/src/integrations/makie.jl b/src/external/Makie.jl similarity index 100% rename from src/integrations/makie.jl rename to src/external/Makie.jl diff --git a/src/external/MultivariateStats.jl b/src/external/MultivariateStats.jl new file mode 100644 index 000000000..744dcf9c6 --- /dev/null +++ b/src/external/MultivariateStats.jl @@ -0,0 +1,49 @@ +function _layers_to_matrix(X) + Y = zeros(SimpleSDMLayers._inner_type(X[1]), (length(X), length(X[1]))) + for i in axes(X, 1) + Y[i,:] .= values(X[i]) + end + return Y +end + +# PCA + +function StatsAPI.fit(::Type{MultivariateStats.PCA}, X::Vector{T}; kwargs...) where {T <: SimpleSDMLayers.SimpleSDMLayer} + @assert SimpleSDMLayers._layers_are_compatible(X) + Y = _layers_to_matrix(X) + return StatsAPI.fit(MultivariateStats.PCA, Y; kwargs...) +end + +function StatsAPI.predict(M::MultivariateStats.PCA, X::Vector{T}) where {T <: SimpleSDMLayers.SimpleSDMLayer} + @assert SimpleSDMLayers._layers_are_compatible(X) + Y = _layers_to_matrix(X) + D = StatsAPI.predict(M, Y) + O = [similar(X[1]) for i in 1:MultivariateStats.outdim(M)] + for i in axes(O, 1) + for (j,k) in enumerate(keys(O[i])) + O[i][k] = D[i,j] + end + end + return O +end + +# Whitening + +function StatsAPI.fit(::Type{MultivariateStats.Whitening}, X::Vector{T}; n::Int=1_000, kwargs...) where {T <: SimpleSDMLayers.SimpleSDMLayer} + @assert SimpleSDMLayers._layers_are_compatible(X) + Y = _layers_to_matrix(X) + return StatsAPI.fit(MultivariateStats.Whitening, Y; kwargs...) +end + +function MultivariateStats.transform(W::MultivariateStats.Whitening, X::Vector{T}) where {T <: SimpleSDMLayers.SimpleSDMLayer} + @assert SimpleSDMLayers._layers_are_compatible(X) + Y = _layers_to_matrix(X) + D = MultivariateStats.transform(W, Y) + O = [similar(X[1]) for i in 1:length(X)] + for i in axes(O, 1) + for (j,k) in enumerate(keys(O[i])) + O[i][k] = D[i,j] + end + end + return O +end From eb60b16c27c85dc4ad05da9929c6219a4926b093 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Poisot?= Date: Sun, 15 Oct 2023 09:45:52 -0400 Subject: [PATCH 2/7] dependencies(sdt)!: move to an extension model for multivariate stats --- ext/MultivariateExtension.jl | 72 +++++++++++++++++++++++++++++++ src/SpeciesDistributionToolkit.jl | 6 --- src/external/MultivariateStats.jl | 49 --------------------- 3 files changed, 72 insertions(+), 55 deletions(-) create mode 100644 ext/MultivariateExtension.jl delete mode 100644 src/external/MultivariateStats.jl diff --git a/ext/MultivariateExtension.jl b/ext/MultivariateExtension.jl new file mode 100644 index 000000000..f4b53e753 --- /dev/null +++ b/ext/MultivariateExtension.jl @@ -0,0 +1,72 @@ +module MultivariateExtension + +using SpeciesDistributionToolkit +using MultivariateStats +using StatsAPI + +function _layers_to_matrix(X) + Y = zeros(SimpleSDMLayers._inner_type(X[1]), (length(X), length(X[1]))) + for i in axes(X, 1) + Y[i, :] .= values(X[i]) + end + return Y +end + +# PCA + +function StatsAPI.fit( + ::Type{MultivariateStats.PCA}, + X::Vector{T}; + kwargs..., +) where {T <: SimpleSDMLayers.SimpleSDMLayer} + @assert SimpleSDMLayers._layers_are_compatible(X) + Y = _layers_to_matrix(X) + return StatsAPI.fit(MultivariateStats.PCA, Y; kwargs...) +end + +function StatsAPI.predict( + M::MultivariateStats.PCA, + X::Vector{T}, +) where {T <: SimpleSDMLayers.SimpleSDMLayer} + @assert SimpleSDMLayers._layers_are_compatible(X) + Y = _layers_to_matrix(X) + D = StatsAPI.predict(M, Y) + O = [similar(X[1], eltype(D)) for i in 1:MultivariateStats.outdim(M)] + for i in axes(O, 1) + for (j, k) in enumerate(keys(O[i])) + O[i][k] = D[i, j] + end + end + return O +end + +# Whitening + +function StatsAPI.fit( + ::Type{MultivariateStats.Whitening}, + X::Vector{T}; + n::Int = 1_000, + kwargs..., +) where {T <: SimpleSDMLayers.SimpleSDMLayer} + @assert SimpleSDMLayers._layers_are_compatible(X) + Y = _layers_to_matrix(X) + return StatsAPI.fit(MultivariateStats.Whitening, Y; kwargs...) +end + +function MultivariateStats.transform( + W::MultivariateStats.Whitening, + X::Vector{T}, +) where {T <: SimpleSDMLayers.SimpleSDMLayer} + @assert SimpleSDMLayers._layers_are_compatible(X) + Y = _layers_to_matrix(X) + D = MultivariateStats.transform(W, Y) + O = [similar(X[1], eltype(D)) for i in 1:length(X)] + for i in axes(O, 1) + for (j, k) in enumerate(keys(O[i])) + O[i][k] = D[i, j] + end + end + return O +end + +end \ No newline at end of file diff --git a/src/SpeciesDistributionToolkit.jl b/src/SpeciesDistributionToolkit.jl index f80441195..13b7e70f9 100644 --- a/src/SpeciesDistributionToolkit.jl +++ b/src/SpeciesDistributionToolkit.jl @@ -12,9 +12,6 @@ using MakieCore import StatsBase import OffsetArrays -import MultivariateStats -import StatsAPI - # We make ample use of re-export using Reexport @@ -37,9 +34,6 @@ include("integrations/gbif_phylopic.jl") # Plotting include("external/Makie.jl") -# Plotting -include("external/MultivariateStats.jl") - # Functions for IO include("io/geotiff.jl") include("io/ascii.jl") diff --git a/src/external/MultivariateStats.jl b/src/external/MultivariateStats.jl deleted file mode 100644 index 744dcf9c6..000000000 --- a/src/external/MultivariateStats.jl +++ /dev/null @@ -1,49 +0,0 @@ -function _layers_to_matrix(X) - Y = zeros(SimpleSDMLayers._inner_type(X[1]), (length(X), length(X[1]))) - for i in axes(X, 1) - Y[i,:] .= values(X[i]) - end - return Y -end - -# PCA - -function StatsAPI.fit(::Type{MultivariateStats.PCA}, X::Vector{T}; kwargs...) where {T <: SimpleSDMLayers.SimpleSDMLayer} - @assert SimpleSDMLayers._layers_are_compatible(X) - Y = _layers_to_matrix(X) - return StatsAPI.fit(MultivariateStats.PCA, Y; kwargs...) -end - -function StatsAPI.predict(M::MultivariateStats.PCA, X::Vector{T}) where {T <: SimpleSDMLayers.SimpleSDMLayer} - @assert SimpleSDMLayers._layers_are_compatible(X) - Y = _layers_to_matrix(X) - D = StatsAPI.predict(M, Y) - O = [similar(X[1]) for i in 1:MultivariateStats.outdim(M)] - for i in axes(O, 1) - for (j,k) in enumerate(keys(O[i])) - O[i][k] = D[i,j] - end - end - return O -end - -# Whitening - -function StatsAPI.fit(::Type{MultivariateStats.Whitening}, X::Vector{T}; n::Int=1_000, kwargs...) where {T <: SimpleSDMLayers.SimpleSDMLayer} - @assert SimpleSDMLayers._layers_are_compatible(X) - Y = _layers_to_matrix(X) - return StatsAPI.fit(MultivariateStats.Whitening, Y; kwargs...) -end - -function MultivariateStats.transform(W::MultivariateStats.Whitening, X::Vector{T}) where {T <: SimpleSDMLayers.SimpleSDMLayer} - @assert SimpleSDMLayers._layers_are_compatible(X) - Y = _layers_to_matrix(X) - D = MultivariateStats.transform(W, Y) - O = [similar(X[1]) for i in 1:length(X)] - for i in axes(O, 1) - for (j,k) in enumerate(keys(O[i])) - O[i][k] = D[i,j] - end - end - return O -end From cff61e83f1f987c7500beb30a8d67a6f6b0b06ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Poisot?= Date: Sun, 15 Oct 2023 09:55:34 -0400 Subject: [PATCH 3/7] dependencies(sdt): MakieCore --- Project.toml | 13 ++++++++----- src/SpeciesDistributionToolkit.jl | 4 ++-- src/{external/Makie.jl => integrations/makie.jl} | 1 - 3 files changed, 10 insertions(+), 8 deletions(-) rename src/{external/Makie.jl => integrations/makie.jl} (94%) diff --git a/Project.toml b/Project.toml index 6de3d746c..8d93ddc5b 100644 --- a/Project.toml +++ b/Project.toml @@ -5,22 +5,25 @@ version = "0.0.10" [deps] ArchGDAL = "c9ce4bd3-c3d5-55b8-8973-c0e20141b8c3" -CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Fauxcurrences = "a2d61402-033a-4ca9-aef4-652d70cf7c9c" GBIF = "ee291a33-5a6c-5552-a3c8-0f29a1181037" GDAL = "add2ef01-049f-52c4-9ee2-e494f65e021a" MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b" -MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Phylopic = "c889285c-44aa-4473-b1e1-56f5d4e3ccf5" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" SimpleSDMDatasets = "2c7d61d0-5c73-410d-85b2-d2e7fbbdcefa" SimpleSDMLayers = "2c645270-77db-11e9-22c3-0f302a89c64c" -StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +[weakdeps] +MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" +StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" + +[extensions] +MultivariateExtension = ["MultivariateStats", "StatsAPI"] + [compat] ArchGDAL = "0.9, 0.10" Distances = "0.10" @@ -34,4 +37,4 @@ Reexport = "1.2" SimpleSDMDatasets = "0.1" SimpleSDMLayers = "0.9" StatsBase = "0.33, 0.34" -julia = "1.8" +julia = "1.9" diff --git a/src/SpeciesDistributionToolkit.jl b/src/SpeciesDistributionToolkit.jl index 13b7e70f9..8a0dad7ee 100644 --- a/src/SpeciesDistributionToolkit.jl +++ b/src/SpeciesDistributionToolkit.jl @@ -31,8 +31,8 @@ include("integrations/gbif_layers.jl") # GBIF and Phylopic integration include("integrations/gbif_phylopic.jl") -# Plotting -include("external/Makie.jl") +# Makie recipes +include("integrations/makie.jl") # Functions for IO include("io/geotiff.jl") diff --git a/src/external/Makie.jl b/src/integrations/makie.jl similarity index 94% rename from src/external/Makie.jl rename to src/integrations/makie.jl index 7e0920f0c..ddc28c4b3 100644 --- a/src/external/Makie.jl +++ b/src/integrations/makie.jl @@ -1,4 +1,3 @@ -# Function to turn a layer into something (Geo)Makie can use function sprinkle(layer::T) where {T <: SimpleSDMLayer} return ( longitudes(layer), From df19a8a9538035d418be05c21663222806fe6a16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Poisot?= Date: Sun, 15 Oct 2023 11:09:05 -0400 Subject: [PATCH 4/7] semver(sdt)!: v0.1.0 NEW MINOR RELEASE --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8d93ddc5b..44481d8f0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SpeciesDistributionToolkit" uuid = "72b53823-5c0b-4575-ad0e-8e97227ad13b" authors = ["Timothée Poisot "] -version = "0.0.10" +version = "0.1.0" [deps] ArchGDAL = "c9ce4bd3-c3d5-55b8-8973-c0e20141b8c3" From fa43ee639fc89af19e1dd22e6fd996d273da8d42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Poisot?= Date: Sun, 15 Oct 2023 11:33:36 -0400 Subject: [PATCH 5/7] feat(sdt): Array for vector of layers --- ext/MultivariateExtension.jl | 25 ++++--------------------- src/SpeciesDistributionToolkit.jl | 3 +++ src/quality_of_life.jl | 13 +++++++++++++ 3 files changed, 20 insertions(+), 21 deletions(-) create mode 100644 src/quality_of_life.jl diff --git a/ext/MultivariateExtension.jl b/ext/MultivariateExtension.jl index f4b53e753..bea8c5f58 100644 --- a/ext/MultivariateExtension.jl +++ b/ext/MultivariateExtension.jl @@ -4,24 +4,13 @@ using SpeciesDistributionToolkit using MultivariateStats using StatsAPI -function _layers_to_matrix(X) - Y = zeros(SimpleSDMLayers._inner_type(X[1]), (length(X), length(X[1]))) - for i in axes(X, 1) - Y[i, :] .= values(X[i]) - end - return Y -end - -# PCA - function StatsAPI.fit( ::Type{MultivariateStats.PCA}, X::Vector{T}; kwargs..., ) where {T <: SimpleSDMLayers.SimpleSDMLayer} @assert SimpleSDMLayers._layers_are_compatible(X) - Y = _layers_to_matrix(X) - return StatsAPI.fit(MultivariateStats.PCA, Y; kwargs...) + return StatsAPI.fit(MultivariateStats.PCA, Array(X); kwargs...) end function StatsAPI.predict( @@ -29,8 +18,7 @@ function StatsAPI.predict( X::Vector{T}, ) where {T <: SimpleSDMLayers.SimpleSDMLayer} @assert SimpleSDMLayers._layers_are_compatible(X) - Y = _layers_to_matrix(X) - D = StatsAPI.predict(M, Y) + D = StatsAPI.predict(M, Array(X)) O = [similar(X[1], eltype(D)) for i in 1:MultivariateStats.outdim(M)] for i in axes(O, 1) for (j, k) in enumerate(keys(O[i])) @@ -40,17 +28,13 @@ function StatsAPI.predict( return O end -# Whitening - function StatsAPI.fit( ::Type{MultivariateStats.Whitening}, X::Vector{T}; - n::Int = 1_000, kwargs..., ) where {T <: SimpleSDMLayers.SimpleSDMLayer} @assert SimpleSDMLayers._layers_are_compatible(X) - Y = _layers_to_matrix(X) - return StatsAPI.fit(MultivariateStats.Whitening, Y; kwargs...) + return StatsAPI.fit(MultivariateStats.Whitening, Array(X); kwargs...) end function MultivariateStats.transform( @@ -58,8 +42,7 @@ function MultivariateStats.transform( X::Vector{T}, ) where {T <: SimpleSDMLayers.SimpleSDMLayer} @assert SimpleSDMLayers._layers_are_compatible(X) - Y = _layers_to_matrix(X) - D = MultivariateStats.transform(W, Y) + D = MultivariateStats.transform(W, Array(X)) O = [similar(X[1], eltype(D)) for i in 1:length(X)] for i in axes(O, 1) for (j, k) in enumerate(keys(O[i])) diff --git a/src/SpeciesDistributionToolkit.jl b/src/SpeciesDistributionToolkit.jl index 8a0dad7ee..a5851e8af 100644 --- a/src/SpeciesDistributionToolkit.jl +++ b/src/SpeciesDistributionToolkit.jl @@ -22,6 +22,9 @@ using Reexport @reexport using Fauxcurrences @reexport using Phylopic +# Quality of life functions +include("quality_of_life.jl") + # SimpleSDMLayers to wrap everything together include("integrations/datasets_layers.jl") diff --git a/src/quality_of_life.jl b/src/quality_of_life.jl new file mode 100644 index 000000000..77403259a --- /dev/null +++ b/src/quality_of_life.jl @@ -0,0 +1,13 @@ +Base.keys(layers::Vector{T}) where {T <: SimpleSDMLayers.SimpleSDMLayer} = reduce(intersect, keys.(layers)) + +function Base.Array(layers::Vector{T}) where {T <: SimpleSDMLayers.SimpleSDMLayer} + k = keys(layers) # List of keys commons to all layers + R = SimpleSDMLayers._inner_type(first(layers)) + X = Matrix{R}(undef, length(layers), length(k)) + for i in eachindex(k) + for j in eachindex(layers) + X[j,i] = layers[j][k[i]] + end + end + return X +end \ No newline at end of file From e6b88cda6c2b06bd9598bae7a198755d420433d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Poisot?= Date: Sun, 15 Oct 2023 11:49:39 -0400 Subject: [PATCH 6/7] feat(sdt): auto-generate the code for most MultivariateStats --- ext/MultivariateExtension.jl | 86 ++++++++++++++++++++---------------- 1 file changed, 47 insertions(+), 39 deletions(-) diff --git a/ext/MultivariateExtension.jl b/ext/MultivariateExtension.jl index bea8c5f58..b52e10d38 100644 --- a/ext/MultivariateExtension.jl +++ b/ext/MultivariateExtension.jl @@ -4,52 +4,60 @@ using SpeciesDistributionToolkit using MultivariateStats using StatsAPI -function StatsAPI.fit( - ::Type{MultivariateStats.PCA}, - X::Vector{T}; - kwargs..., -) where {T <: SimpleSDMLayers.SimpleSDMLayer} - @assert SimpleSDMLayers._layers_are_compatible(X) - return StatsAPI.fit(MultivariateStats.PCA, Array(X); kwargs...) -end +# These types have a fit method +types_to_fit = [:PCA, :PPCA, :KernelPCA, :Whitening, :MDS, :MetricMDS] + +# These types have a transform method +types_to_transform = [:Whitening] -function StatsAPI.predict( - M::MultivariateStats.PCA, - X::Vector{T}, -) where {T <: SimpleSDMLayers.SimpleSDMLayer} - @assert SimpleSDMLayers._layers_are_compatible(X) - D = StatsAPI.predict(M, Array(X)) - O = [similar(X[1], eltype(D)) for i in 1:MultivariateStats.outdim(M)] - for i in axes(O, 1) - for (j, k) in enumerate(keys(O[i])) - O[i][k] = D[i, j] +# These types have a predict method +types_to_predict = [:PCA, :PPCA, :KernelPCA, :MDS, :MetricMDS] + +for tf in types_to_fit + eval( + quote + function StatsAPI.fit(::Type{MultivariateStats.$tf}, X::Vector{T}; kwargs...) where {T <: SimpleSDMLayers.SimpleSDMLayer} + @assert SimpleSDMLayers._layers_are_compatible(X) + return StatsAPI.fit(MultivariateStats.$tf, Array(X); kwargs...) + end end - end - return O + ) end -function StatsAPI.fit( - ::Type{MultivariateStats.Whitening}, - X::Vector{T}; - kwargs..., -) where {T <: SimpleSDMLayers.SimpleSDMLayer} - @assert SimpleSDMLayers._layers_are_compatible(X) - return StatsAPI.fit(MultivariateStats.Whitening, Array(X); kwargs...) +for tf in types_to_transform + eval( + quote + function MultivariateStats.transform(f::MultivariateStats.$tf, X::Vector{T}; kwargs...) where {T <: SimpleSDMLayers.SimpleSDMLayer} + @assert SimpleSDMLayers._layers_are_compatible(X) + D = MultivariateStats.transform(f, Array(X); kwargs...) + O = [similar(X[1], eltype(D)) for i in 1:length(X)] + for i in axes(O, 1) + for (j, k) in enumerate(keys(O[i])) + O[i][k] = D[i, j] + end + end + return O + end + end + ) end -function MultivariateStats.transform( - W::MultivariateStats.Whitening, - X::Vector{T}, -) where {T <: SimpleSDMLayers.SimpleSDMLayer} - @assert SimpleSDMLayers._layers_are_compatible(X) - D = MultivariateStats.transform(W, Array(X)) - O = [similar(X[1], eltype(D)) for i in 1:length(X)] - for i in axes(O, 1) - for (j, k) in enumerate(keys(O[i])) - O[i][k] = D[i, j] +for tf in types_to_predict + eval( + quote + function StatsAPI.predict(f::MultivariateStats.$tf, X::Vector{T}; kwargs...) where {T <: SimpleSDMLayers.SimpleSDMLayer} + @assert SimpleSDMLayers._layers_are_compatible(X) + D = StatsAPI.predict(f, Array(X); kwargs...) + O = [similar(X[1], eltype(D)) for i in 1:MultivariateStats.outdim(M)] + for i in axes(O, 1) + for (j, k) in enumerate(keys(O[i])) + O[i][k] = D[i, j] + end + end + return O + end end - end - return O + ) end end \ No newline at end of file From 7ae5b8ffbe786e62d3f43a21a5c217cb439b4e32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Poisot?= Date: Sun, 15 Oct 2023 11:55:30 -0400 Subject: [PATCH 7/7] perf(sdt): Array for a vector of layes Closes #212 --- src/quality_of_life.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/quality_of_life.jl b/src/quality_of_life.jl index 77403259a..e779c19c0 100644 --- a/src/quality_of_life.jl +++ b/src/quality_of_life.jl @@ -1,7 +1,6 @@ -Base.keys(layers::Vector{T}) where {T <: SimpleSDMLayers.SimpleSDMLayer} = reduce(intersect, keys.(layers)) - function Base.Array(layers::Vector{T}) where {T <: SimpleSDMLayers.SimpleSDMLayer} - k = keys(layers) # List of keys commons to all layers + @assert SimpleSDMLayers._layers_are_compatible(layers) + k = reduce(intersect, [findall(!isnothing, grid(layer)) for layer in layers]) R = SimpleSDMLayers._inner_type(first(layers)) X = Matrix{R}(undef, length(layers), length(k)) for i in eachindex(k)