diff --git a/Project.toml b/Project.toml index ace5f347..4a23524c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJFlux" uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845" authors = ["Anthony D. Blaom ", "Ayush Shridhar "] -version = "0.5.1" +version = "0.6.0" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" @@ -33,6 +33,7 @@ julia = "1.9" [extras] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -42,4 +43,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["CUDA", "cuDNN", "LinearAlgebra", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "Test"] +test = ["CUDA", "cuDNN", "LinearAlgebra", "Logging", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "Test"] diff --git a/README.md b/README.md index 8702c8e4..949f616b 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,12 @@ -# MLJFlux +
+[image]
+
+An interface to the Flux deep learning models for the [MLJ](https://github.com/alan-turing-institute/MLJ.jl) machine learning framework
+
+
+
+[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/MLJFlux.jl/dev/)
-An interface to the Flux deep learning models for the
-[MLJ](https://github.com/alan-turing-institute/MLJ.jl) machine
-learning framework.
 | Branch | Julia | CPU CI | GPU CI | Coverage |
 | -------- | ----- | ------ | ----- | -------- |
@@ -21,7 +25,6 @@ learning framework.
 [coveralls-img-dev]: https://coveralls.io/repos/github/alan-turing-institute/MLJFlux.jl/badge.svg?branch=dev "Code Coverage"
 [coveralls-url]: https://github.com/FluxML/MLJFlux.jl/actions/workflows/ci.yml
-[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/MLJFlux.jl/dev/)
 ## Code Snippet
@@ -56,10 +59,10 @@ Wrap in "iteration controls":
 ```julia
 stop_conditions = [
     Step(1), # Apply controls every epoch
-    NumberLimit(1000), # Don't train for more than 100 steps
-    Patience(4), # Stop after 5 iterations of deteriation in validation loss
-    NumberSinceBest(5), # Or if the best loss occurred 9 iterations ago
-    TimeLimit(30/60), # Or if 30 minutes passed
+    NumberLimit(1000), # Don't train for more than 1000 steps
+    Patience(4), # Stop after 4 iterations of deterioration in validation loss
+    NumberSinceBest(5), # Or if the best loss occurred 5 iterations ago
+    TimeLimit(30/60), # Or if 30 minutes have passed
 ]
 validation_losses = []
diff --git a/docs/src/extended_examples/MNIST/notebook.jl b/docs/src/extended_examples/MNIST/notebook.jl
index 448f50ee..810fae5f 100644
--- a/docs/src/extended_examples/MNIST/notebook.jl
+++ b/docs/src/extended_examples/MNIST/notebook.jl
@@ -3,12 +3,18 @@
 # This tutorial is available as a Jupyter notebook or julia script
 # [here](https://github.com/FluxML/MLJFlux.jl/tree/dev/docs/src/extended_examples/MNIST).
-using Pkg #!md
-const DIR = @__DIR__ #!md
-Pkg.activate(DIR) #!md
-Pkg.instantiate() #!md
+# The following code block assumes the current directory contains `Manifest.toml` and
+# `Project.toml` files tested for this demonstration, available
+# [here](https://github.com/FluxML/MLJFlux.jl/tree/dev/docs/src/extended_examples/MNIST).
+# Otherwise, you can try running `using Pkg; Pkg.activate(temp=true)` instead, and
+# manually add the relevant packages to the temporary environment created.
+
+using Pkg
+const DIR = @__DIR__
+Pkg.activate(DIR)
+Pkg.instantiate()
-# **Julia version** is assumed to be 1.10.*
+# **Julia version** is assumed to be ≥ 1.10
 using MLJ
 using Flux
diff --git a/docs/src/extended_examples/spam_detection/notebook.jl b/docs/src/extended_examples/spam_detection/notebook.jl
index 3d712ebf..e855a1a6 100644
--- a/docs/src/extended_examples/spam_detection/notebook.jl
+++ b/docs/src/extended_examples/spam_detection/notebook.jl
@@ -10,9 +10,15 @@
 # **Warning.** This demo includes some non-idiomatic use of MLJ to allow use of the
 # Flux.jl `Embedding` layer. It is not recommended for MLJ beginners.
-using Pkg #!md
-Pkg.activate(@__DIR__); #!md
-Pkg.instantiate(); #!md
+# The following code block assumes the current directory contains `Manifest.toml` and
+# `Project.toml` files tested for this demonstration, available
+# [here](https://github.com/FluxML/MLJFlux.jl/tree/dev/docs/src/extended_examples/spam_detection).
+# Otherwise, you can try running `using Pkg; Pkg.activate(temp=true)` instead, and
+# manually add the relevant packages to the temporary environment created.
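For readers following the temporary-environment route suggested in the comment above, a minimal sketch (the package list is illustrative; add whatever the demonstration actually imports):

```julia
using Pkg

# Create and activate a throw-away environment:
Pkg.activate(temp = true)

# Add the packages used by the demonstration (adjust as needed):
Pkg.add(["MLJ", "MLJFlux", "Flux"])
```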
+ +using Pkg +Pkg.activate(@__DIR__); +Pkg.instantiate(); # ### Basic Imports using MLJ diff --git a/src/MLJFlux.jl b/src/MLJFlux.jl index b90ac223..bbc0b669 100644 --- a/src/MLJFlux.jl +++ b/src/MLJFlux.jl @@ -1,7 +1,6 @@ module MLJFlux export CUDALibs, CPU1 - import Flux using MLJModelInterface using MLJModelInterface.ScientificTypesBase @@ -17,8 +16,10 @@ import Metalhead import Optimisers include("utilities.jl") -const MMI=MLJModelInterface +const MMI = MLJModelInterface +include("encoders.jl") +include("entity_embedding.jl") include("builders.jl") include("metalhead.jl") include("types.jl") @@ -26,6 +27,8 @@ include("core.jl") include("regressor.jl") include("classifier.jl") include("image.jl") +include("fit_utils.jl") +include("entity_embedding_utils.jl") include("mlj_model_interface.jl") export NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor @@ -33,6 +36,4 @@ export NeuralNetworkClassifier, NeuralNetworkBinaryClassifier, ImageClassifier export CUDALibs, CPU1 include("deprecated.jl") - - -end #module +end # module diff --git a/src/classifier.jl b/src/classifier.jl index 145eb019..036602cc 100644 --- a/src/classifier.jl +++ b/src/classifier.jl @@ -5,7 +5,6 @@ A private method that returns the shape of the input and output of the model for given data `X` and `y`. - """ function MLJFlux.shape(model::NeuralNetworkClassifier, X, y) X = X isa Matrix ? Tables.table(X) : X @@ -14,6 +13,7 @@ function MLJFlux.shape(model::NeuralNetworkClassifier, X, y) n_input = Tables.schema(X).names |> length return (n_input, n_output) end +is_embedding_enabled(::NeuralNetworkClassifier) = true # builds the end-to-end Flux chain needed, given the `model` and `shape`: MLJFlux.build( @@ -29,24 +29,28 @@ MLJFlux.fitresult( model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier}, chain, y, -) = (chain, MLJModelInterface.classes(y[1])) + ordinal_mappings = nothing, + embedding_matrices = nothing, +) = (chain, MLJModelInterface.classes(y[1]), ordinal_mappings, embedding_matrices) function MLJModelInterface.predict( model::NeuralNetworkClassifier, fitresult, Xnew, - ) - chain, levels = fitresult +) + chain, levels, ordinal_mappings, _ = fitresult + Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings) # what if Xnew is a matrix X = reformat(Xnew) probs = vcat([chain(tomat(X[:, i]))' for i in 1:size(X, 2)]...) 
return MLJModelInterface.UnivariateFinite(levels, probs) end + MLJModelInterface.metadata_model( NeuralNetworkClassifier, - input_scitype=Union{AbstractMatrix{Continuous},Table(Continuous)}, - target_scitype=AbstractVector{<:Finite}, - load_path="MLJFlux.NeuralNetworkClassifier", + input_scitype = Union{AbstractMatrix{Continuous}, Table(Continuous, Finite)}, + target_scitype = AbstractVector{<:Finite}, + load_path = "MLJFlux.NeuralNetworkClassifier", ) #### Binary Classifier @@ -56,13 +60,15 @@ function MLJFlux.shape(model::NeuralNetworkBinaryClassifier, X, y) n_input = Tables.schema(X).names |> length return (n_input, 1) # n_output is always 1 for a binary classifier end +is_embedding_enabled(::NeuralNetworkBinaryClassifier) = true function MLJModelInterface.predict( model::NeuralNetworkBinaryClassifier, fitresult, Xnew, - ) - chain, levels = fitresult +) + chain, levels, ordinal_mappings, _ = fitresult + Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings) X = reformat(Xnew) probs = vec(chain(X)) return MLJModelInterface.UnivariateFinite(levels, probs; augment = true) @@ -70,7 +76,7 @@ end MLJModelInterface.metadata_model( NeuralNetworkBinaryClassifier, - input_scitype=Union{AbstractMatrix{Continuous},Table(Continuous)}, - target_scitype=AbstractVector{<:Finite{2}}, - load_path="MLJFlux.NeuralNetworkBinaryClassifier", + input_scitype = Union{AbstractMatrix{Continuous}, Table(Continuous, Finite)}, + target_scitype = AbstractVector{<:Finite{2}}, + load_path = "MLJFlux.NeuralNetworkBinaryClassifier", ) diff --git a/src/core.jl b/src/core.jl index 938dea7f..f866ee51 100644 --- a/src/core.jl +++ b/src/core.jl @@ -24,6 +24,8 @@ end y, ) -> updated_chain, updated_optimiser_state, training_loss +**Private method.** + Update the parameters of a Flux `chain`, where: - `model` is typically an `MLJFluxModel` instance, but could be any object such that @@ -77,6 +79,8 @@ end y, ) -> (updated_chain, updated_optimiser_state, history) +**Private method.** + Optimize a Flux model `chain`, where `(yhat, y) -> loss(yhat, y)` is the loss function inferred from the `model`. Typically, `model` will be an `MLJFluxModel` instance, but it could be any object such that `model.loss` is a Flux.jl loss function. @@ -162,6 +166,8 @@ end """ gpu_isdead() +**Private method.** + Returns `true` if `acceleration=CUDALibs()` option is unavailable, and false otherwise. @@ -171,6 +177,8 @@ gpu_isdead() = Flux.gpu([1.0,]) isa Array """ nrows(X) +**Private method.** + Find the number of rows of `X`, where `X` is an `AbstractVector or Tables.jl table. """ @@ -268,15 +276,22 @@ input `X` and target `y` in the form required by by `model.batch_size`.) 
""" -function collate(model, X, y) +function collate(model, X, y, verbosity) row_batches = Base.Iterators.partition(1:nrows(y), model.batch_size) - Xmatrix = reformat(X) + Xmatrix = _f32(reformat(X), verbosity) ymatrix = reformat(y) return [_get(Xmatrix, b) for b in row_batches], [_get(ymatrix, b) for b in row_batches] end -function collate(model::NeuralNetworkBinaryClassifier, X, y) +function collate(model::NeuralNetworkBinaryClassifier, X, y, verbosity) row_batches = Base.Iterators.partition(1:nrows(y), model.batch_size) - Xmatrix = reformat(X) + Xmatrix = _f32(reformat(X), verbosity) yvec = (y .== classes(y)[2])' # convert to boolean return [_get(Xmatrix, b) for b in row_batches], [_get(yvec, b) for b in row_batches] end + +_f32(x::AbstractArray{Float32}, verbosity) = x +function _f32(x::AbstractArray, verbosity) + verbosity > 0 && @info "MLJFlux: converting input data to Float32" + return Float32.(x) +end + diff --git a/src/encoders.jl b/src/encoders.jl new file mode 100644 index 00000000..f3961428 --- /dev/null +++ b/src/encoders.jl @@ -0,0 +1,152 @@ +""" +File containing ordinal encoder and entity embedding encoder. Borrows code from the MLJTransforms package. +""" + +### Ordinal Encoder +""" +**Private Method** + +Fits an ordinal encoder to the table `X`, using only the columns with indices in `featinds`. + +Returns a dictionary mapping each column index to a dictionary mapping each level in that column to an integer. +""" +function ordinal_encoder_fit(X; featinds) + # 1. Define mapping per column per level dictionary + mapping_matrix = Dict() + + # 2. Use feature mapper to compute the mapping of each level in each column + for i in featinds + feat_col = Tables.getcolumn(Tables.columns(X), i) + feat_levels = levels(feat_col) + # Check if feat levels is already ordinal encoded in which case we skip + (Set([Float32(i) for i in 1:length(feat_levels)]) == Set(feat_levels)) && continue + # Compute the dict using the given feature_mapper function + mapping_matrix[i] = + Dict{eltype(feat_levels), Float32}( + value => Float32(index) for (index, value) in enumerate(feat_levels) + ) + end + return mapping_matrix +end + +""" +**Private Method** + +Checks that all levels in `test_levels` are also in `train_levels`. If not, throws an error. +""" +function check_unkown_levels(train_levels, test_levels) + # test levels must be a subset of train levels + if !issubset(test_levels, train_levels) + # get the levels in test that are not in train + lost_levels = setdiff(test_levels, train_levels) + error( + "While transforming, found novel levels for the column: $(lost_levels) that were not seen while training.", + ) + end +end + +""" +**Private Method** + +Transforms the table `X` using the ordinal encoder defined by `mapping_matrix`. + +Returns a new table with the same column names as `X`, but with categorical columns replaced by integer columns. +""" +function ordinal_encoder_transform(X, mapping_matrix) + isnothing(mapping_matrix) && return X + isempty(mapping_matrix) && return X + feat_names = Tables.schema(X).names + numfeats = length(feat_names) + new_feats = [] + for ind in 1:numfeats + col = Tables.getcolumn(Tables.columns(X), ind) + + # Create the transformation function for each column + if ind in keys(mapping_matrix) + train_levels = keys(mapping_matrix[ind]) + test_levels = levels(col) + check_unkown_levels(train_levels, test_levels) + level2scalar = mapping_matrix[ind] + new_col = recode(unwrap.(col), level2scalar...) 
+ push!(new_feats, new_col) + else + push!(new_feats, col) + end + end + + transformed_X = NamedTuple{tuple(feat_names...)}(tuple(new_feats)...) + # Attempt to preserve table type + transformed_X = Tables.materializer(X)(transformed_X) + return transformed_X +end + +""" +**Private Method** + +Combine ordinal_encoder_fit and ordinal_encoder_transform and return both X and ordinal_mappings +""" +function ordinal_encoder_fit_transform(X; featinds) + ordinal_mappings = ordinal_encoder_fit(X; featinds = featinds) + return ordinal_encoder_transform(X, ordinal_mappings), ordinal_mappings +end + + + +## Entity Embedding Encoder (assuming precomputed weights) +""" +**Private method.** + +Function to generate new feature names: feat_name_0, feat_name_1,..., feat_name_n +""" +function generate_new_feat_names(feat_name, num_inds, existing_names) + conflict = true # will be kept true as long as there is a conflict + count = 1 # number of conflicts+1 = number of underscores + + new_column_names = [] + while conflict + suffix = repeat("_", count) + new_column_names = [Symbol("$(feat_name)$(suffix)$i") for i in 1:num_inds] + conflict = any(name -> name in existing_names, new_column_names) + count += 1 + end + return new_column_names +end + + +""" +Given X and a dict of mapping_matrices that map each categorical column to a matrix, use the matrix to transform +each level in each categorical columns using the columns of the matrix. + +This is used with the embedding matrices of the entity embedding layer in entity enabled models to implement entity embeddings. +""" +function embedding_transform(X, mapping_matrices) + (isempty(mapping_matrices)) && return X + feat_names = Tables.schema(X).names + new_feat_names = Symbol[] + new_cols = [] + for feat_name in feat_names + col = Tables.getcolumn(Tables.columns(X), feat_name) + # Create the transformation function for each column + if feat_name in keys(mapping_matrices) + level2vector = mapping_matrices[feat_name] + new_multi_col = map(x -> level2vector[:, Int.(unwrap(x))], col) + new_multi_col = [col for col in eachrow(hcat(new_multi_col...))] + push!(new_cols, new_multi_col...) + feat_names_with_inds = generate_new_feat_names( + feat_name, + size(level2vector, 1), + feat_names, + ) + push!(new_feat_names, feat_names_with_inds...) + else + # Not to be transformed => left as is + push!(new_feat_names, feat_name) + push!(new_cols, col) + end + end + + transformed_X = NamedTuple{tuple(new_feat_names...)}(tuple(new_cols)...) + # Attempt to preserve table type + transformed_X = Tables.materializer(X)(transformed_X) + return transformed_X +end diff --git a/src/entity_embedding.jl b/src/entity_embedding.jl index b87a79ba..313e3e6d 100644 --- a/src/entity_embedding.jl +++ b/src/entity_embedding.jl @@ -1,50 +1,73 @@ -# This is just some experimental code -# to implement EntityEmbeddings for purely -# categorical features +""" +A layer that implements entity embedding layers as presented in 'Entity Embeddings of + Categorical Variables by Cheng Guo, Felix Berkhahn'. Expects a matrix of dimensions (numfeats, batchsize) + and applies entity embeddings to each specified categorical feature. Other features will be left as is. +# Arguments +- `entityprops`: a vector of named tuples each of the form `(index=..., levels=..., newdim=...)` to + specify the feature index, the number of levels and the desired embeddings dimensionality for selected features of the input. +- `numfeats`: the number of features in the input. 
-using Flux +# Example +```julia +# Prepare a batch of four features where the 2nd and the 4th are categorical +batch = [ + 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1; + 1 2 3 4 5 6 7 8 9 10; + 0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1 + 1 1 2 2 1 1 2 2 1 1; +] -mutable struct EmbeddingMatrix - e - levels - - function EmbeddingMatrix(levels; dim=4) - if dim <= 0 - dimension = div(length(levels), 2) - else - dimension = min(length(levels), dim) # Dummy function for now - end - return new(Dense(length(levels), dimension), levels), dimension - end +entityprops = [ + (index=2, levels=10, newdim=2), + (index=4, levels=2, newdim=1) +] +numfeats = 4 +# Run it through the categorical embedding layer +embedder = EntityEmbedder(entityprops, 4) +julia> output = embedder(batch) +5×10 Matrix{Float64}: + 0.2 0.3 0.4 0.5 … 0.8 0.9 1.0 1.1 + -1.27129 -0.417667 -1.40326 -0.695701 0.371741 1.69952 -1.40034 -2.04078 + -0.166796 0.657619 -0.659249 -0.337757 -0.717179 -0.0176273 -1.2817 -0.0372752 + 0.9 0.1 0.4 0.5 0.8 0.9 1.0 1.1 + -0.847354 -0.847354 -1.66261 -1.66261 -1.66261 -1.66261 -0.847354 -0.847354 +``` +""" # 1. Define layer struct to hold parameters +struct EntityEmbedder{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer} + embedders::A1 + modifiers::A2 # applied on the input before passing it to the embedder + numfeats::I end -Flux.@treelike EmbeddingMatrix - -function (embed::EmbeddingMatrix)(ip) - return embed.e(Flux.onehot(ip, embed.levels)) -end +# 2. Define the forward pass (i.e., calling an instance of the layer) +(m::EntityEmbedder)(x) = + (vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...)) -mutable struct EntityEmbedding - embeddingmatrix +# 3. Define the constructor which initializes the parameters and returns the instance +function EntityEmbedder(entityprops, numfeats; init = Flux.randn32) + embedders = [] + modifiers = [] + # Setup entityprops + cat_inds = [entityprop.index for entityprop in entityprops] + levels_per_feat = [entityprop.levels for entityprop in entityprops] + newdims = [entityprop.newdim for entityprop in entityprops] - function EntityEmbedding(a...) - return new(a) + c = 1 + for i in 1:numfeats + if i in cat_inds + push!(embedders, Flux.Embedding(levels_per_feat[c] => newdims[c], init = init)) + push!(modifiers, (x, i) -> Int.(x[i, :])) + c += 1 + else + push!(embedders, feat -> feat) + push!(modifiers, (x, i) -> x[i:i, :]) + end end -end -Flux.@treelike EntityEmbedding - - -# ip is an array of tuples -function (embed::EntityEmbedding)(ip) - return hcat((vcat((embed.embeddingmatrix[i](ip[idx][i]) for i=1:length(ip[idx]))...) for idx =1:length(ip))...) + EntityEmbedder(embedders, modifiers, numfeats) end - -# Q1. How should this be called in the API? -# nn = NeuralNetworkClassifier(builder=builder, optimiser = .., embeddingdimension = 5) -# -# -# +# 4. 
Register it as layer with Flux +Flux.@layer EntityEmbedder \ No newline at end of file diff --git a/src/entity_embedding_utils.jl b/src/entity_embedding_utils.jl new file mode 100644 index 00000000..21e77eda --- /dev/null +++ b/src/entity_embedding_utils.jl @@ -0,0 +1,126 @@ +""" +A file containing functions or constants used in the `fit` and `update` methods in `mlj_model_interface.jl` for setups supporting entity embeddings +""" +is_embedding_enabled(model) = false + +# function to set default new embedding dimension +function set_default_new_embedding_dim(numlevels) + # Set default to the minimum of num_levels-1 and 10 + return min(numlevels - 1, 10) +end + +MISMATCH_INDS(wrong_feats) = + "Features $(join(wrong_feats, ", ")) were specified in embedding_dims hyperparameter but were not recognized as categorical variables because their scitypes are not `Multiclass` or `OrderedFactor`." +function check_mismatch_in_cat_feats(featnames, cat_inds, specified_featinds) + wrong_feats = [featnames[i] for i in specified_featinds if !(i in cat_inds)] + length(wrong_feats) > 0 && throw(ArgumentError(MISMATCH_INDS(wrong_feats))) +end + +# function to set new embedding dimensions +function set_new_embedding_dims(featnames, cat_inds, num_levels, embedding_dims) + specified_featnames = keys(embedding_dims) + specified_featinds = + [i for i in 1:length(featnames) if featnames[i] in specified_featnames] + check_mismatch_in_cat_feats(featnames, cat_inds, specified_featinds) + catind2numlevels = Dict(zip(cat_inds, num_levels)) + # for each value of embedding dim if float then multiply it by the number of levels + for featname in specified_featnames + if embedding_dims[featname] isa AbstractFloat + embedding_dims[featname] = ceil( + Int, + embedding_dims[featname] * + catind2numlevels[findfirst(x -> x == featname, featnames)], + ) + end + end + newdims = [ + (cat_ind in specified_featinds) ? embedding_dims[featnames[cat_ind]] : + set_default_new_embedding_dim(num_levels[i]) for + (i, cat_ind) in enumerate(cat_inds) + ] + return newdims +end + + +""" +**Private Method** + +Returns the indices of the categorical columns in the table `X`. +""" +function get_cat_inds(X) + # if input is a matrix; conclude no categorical columns + Tables.istable(X) || return Int[] + Xcol = Tables.columns(X) + types = [ + scitype(Tables.getcolumn(Xcol, name)[1]) for + name in Tables.schema(Xcol).names + ] + cat_inds = findall(x -> x <: Finite, types) + return cat_inds +end + +""" +**Private Method** + +Returns the number of levels in each categorical column in the table `X`. +""" +function get_num_levels(X, cat_inds) + num_levels = [] + for i in cat_inds + num_levels = + push!(num_levels, length(levels(Tables.getcolumn(Tables.columns(X), i)))) + end + return num_levels +end + +# A function to prepare the inputs for entity embeddings layer +function prepare_entityembs(X, featnames, cat_inds, embedding_dims) + # 1. Construct entityprops + numfeats = length(featnames) + num_levels = get_num_levels(X, cat_inds) + newdims = set_new_embedding_dims(featnames, cat_inds, num_levels, embedding_dims) + entityprops = [ + (index = cat_inds[i], levels = num_levels[i], newdim = newdims[i]) for + i in eachindex(cat_inds) + ] + # 2. Compute entityemb_output_dim + sum_newdims = length(newdims) == 0 ? 
0 : sum(newdims) + entityemb_output_dim = sum_newdims + numfeats - length(cat_inds) + return entityprops, entityemb_output_dim +end + +# A function to construct model chain including entity embeddings as the first layer +function construct_model_chain_with_entityembs( + model, + rng, + shape, + move, + entityprops, + entityemb_output_dim, +) + chain = try + Flux.Chain( + EntityEmbedder(entityprops, shape[1]; init = Flux.glorot_uniform(rng)), + build(model, rng, (entityemb_output_dim, shape[2])), + ) |> move + catch ex + @error ERR_BUILDER + rethrow() + end + return chain +end + + +# A function that given a model chain, returns a dictionary of embedding matrices +function get_embedding_matrices(chain, cat_inds, featnames) + embedder_layer = chain.layers[1] + embedding_matrices = Dict{Symbol, Matrix{Float32}}() + for cat_ind in cat_inds + featname = featnames[cat_ind] + matrix = Flux.params(embedder_layer.embedders[cat_ind])[1] + embedding_matrices[featname] = matrix + end + return embedding_matrices +end + + diff --git a/src/fit_utils.jl b/src/fit_utils.jl new file mode 100644 index 00000000..b2062791 --- /dev/null +++ b/src/fit_utils.jl @@ -0,0 +1,68 @@ +""" +A file containing functions used in the `fit` and `update` methods in `mlj_model_interface.jl` +""" + +# Converts input to table if it's a matrix +convert_to_table(X) = X isa Matrix ? Tables.table(X) : X + + +# Construct model chain and throws error if it fails +function construct_model_chain(model, rng, shape, move) + chain = try + build(model, rng, shape) |> move + catch ex + @error ERR_BUILDER + rethrow() + end + return chain +end + +# Test whether constructed chain works else throws error +function test_chain_works(x, chain) + try + chain(x) + catch ex + @error ERR_BUILDER + throw(ex) + end +end + +# Models implement L1/L2 regularization by chaining the chosen optimiser with weight/sign +# decay. Note that the weight/sign decay must be scaled down by the number of batches to +# ensure penalization over an epoch does not scale with the choice of batch size; see +# https://github.com/FluxML/MLJFlux.jl/issues/213. 
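To make the scaling described in the comment above concrete, here is a minimal sketch of the optimiser chain that `regularized_optimiser` (defined just below) returns in the mixed case `0 < alpha < 1`, for illustrative values `lambda = 0.1`, `alpha = 0.5`, `nbatches = 10` and an `Optimisers.Adam()` base optimiser:

```julia
using Optimisers

lambda, alpha, nbatches = 0.1, 0.5, 10          # illustrative values only
λ_sign   = alpha * lambda / nbatches            # per-batch L1 strength
λ_weight = 2 * (1 - alpha) * lambda / nbatches  # per-batch L2 strength

# Components of an optimiser chain are applied from left to right:
opt = Optimisers.OptimiserChain(
    Optimisers.SignDecay(λ_sign),     # implements the L1 penalty
    Optimisers.WeightDecay(λ_weight), # implements the L2 penalty
    Optimisers.Adam(),
)
```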
+ +function regularized_optimiser(model, nbatches) + model.lambda == 0 && return model.optimiser + λ_L1 = model.alpha * model.lambda + λ_L2 = (1 - model.alpha) * model.lambda + λ_sign = λ_L1 / nbatches + λ_weight = 2 * λ_L2 / nbatches + + # recall components in an optimiser chain are executed from left to right: + if model.alpha == 0 + return Optimisers.OptimiserChain( + Optimisers.WeightDecay(λ_weight), + model.optimiser, + ) + elseif model.alpha == 1 + return Optimisers.OptimiserChain( + Optimisers.SignDecay(λ_sign), + model.optimiser, + ) + else + return Optimisers.OptimiserChain( + Optimisers.SignDecay(λ_sign), + Optimisers.WeightDecay(λ_weight), + model.optimiser, + ) + end +end + +# Prepares optimiser for training +function prepare_optimiser(data, model, chain) + nbatches = length(data[2]) + regularized_optimiser = MLJFlux.regularized_optimiser(model, nbatches) + optimiser_state = Optimisers.setup(regularized_optimiser, chain) + return regularized_optimiser, optimiser_state +end \ No newline at end of file diff --git a/src/image.jl b/src/image.jl index dc8d5637..5af4a033 100644 --- a/src/image.jl +++ b/src/image.jl @@ -10,12 +10,14 @@ function shape(model::ImageClassifier, X, y) end return (n_input, n_output, n_channels) end +is_embedding_enabled(::ImageClassifier) = false + build(model::ImageClassifier, rng, shape) = Flux.Chain(build(model.builder, rng, shape...), model.finaliser) -fitresult(model::ImageClassifier, chain, y) = +fitresult(model::ImageClassifier, chain, y, ::Any, ::Any) = (chain, MLJModelInterface.classes(y[1])) function MLJModelInterface.predict(model::ImageClassifier, fitresult, Xnew) diff --git a/src/mlj_model_interface.jl b/src/mlj_model_interface.jl index aa9850c4..f3d2c9c6 100644 --- a/src/mlj_model_interface.jl +++ b/src/mlj_model_interface.jl @@ -8,8 +8,8 @@ MLJModelInterface.deep_properties(::Type{<:MLJFluxModel}) = # # CLEAN METHOD const ERR_BAD_OPTIMISER = ArgumentError( - "Flux.jl optimiser detected. Only optimisers from Optimisers.jl are supported. "* - "For example, use `optimiser=Optimisers.Momentum()` after `import Optimisers`. " + "Flux.jl optimiser detected. Only optimisers from Optimisers.jl are supported. " * + "For example, use `optimiser=Optimisers.Momentum()` after `import Optimisers`. ", ) function MLJModelInterface.clean!(model::MLJFluxModel) @@ -19,8 +19,8 @@ function MLJModelInterface.clean!(model::MLJFluxModel) model.lambda = 0 end if model.alpha < 0 || model.alpha > 1 - warning *= "Need alpha in the interval `[0, 1]`. "* - "Resetting `alpha = 0`. " + warning *= "Need alpha in the interval `[0, 1]`. " * + "Resetting `alpha = 0`. " model.alpha = 0 end if model.epochs < 0 @@ -32,7 +32,8 @@ function MLJModelInterface.clean!(model::MLJFluxModel) model.batch_size = 1 end if model.acceleration isa CUDALibs && gpu_isdead() - warning *= "`acceleration isa CUDALibs` "* + warning *= + "`acceleration isa CUDALibs` " * "but no CUDA device (GPU) currently live. " end if !(model.acceleration isa CUDALibs || model.acceleration isa CPU1) @@ -40,9 +41,10 @@ function MLJModelInterface.clean!(model::MLJFluxModel) model.acceleration = CPU1() end if model.acceleration isa CUDALibs && model.rng isa Integer - warning *= "Specifying an RNG seed when "* - "`acceleration isa CUDALibs()` may fail for layers depending "* - "on an RNG during training, such as `Dropout`. Consider using "* + warning *= + "Specifying an RNG seed when " * + "`acceleration isa CUDALibs()` may fail for layers depending " * + "on an RNG during training, such as `Dropout`. 
Consider using " * " `Random.default_rng()` instead. `" end # TODO: This could be removed in next breaking release (0.6.0): @@ -54,73 +56,63 @@ end # # FIT AND UPDATE -const ERR_BUILDER = - "Builder does not appear to build an architecture compatible with supplied data. " +const ERR_BUILDER = "Builder does not appear to build an architecture compatible with supplied data. " true_rng(model) = model.rng isa Integer ? Random.Xoshiro(model.rng) : model.rng -# Models implement L1/L2 regularization by chaining the chosen optimiser with weight/sign -# decay. Note that the weight/sign decay must be scaled down by the number of batches to -# ensure penalization over an epoch does not scale with the choice of batch size; see -# https://github.com/FluxML/MLJFlux.jl/issues/213. - -function regularized_optimiser(model, nbatches) - model.lambda == 0 && return model.optimiser - λ_L1 = model.alpha*model.lambda - λ_L2 = (1 - model.alpha)*model.lambda - λ_sign = λ_L1/nbatches - λ_weight = 2*λ_L2/nbatches - - # recall components in an optimiser chain are executed from left to right: - if model.alpha == 0 - return Optimisers.OptimiserChain( - Optimisers.WeightDecay(λ_weight), - model.optimiser, - ) - elseif model.alpha == 1 - return Optimisers.OptimiserChain( - Optimisers.SignDecay(λ_sign), - model.optimiser, - ) - else return Optimisers.OptimiserChain( - Optimisers.SignDecay(λ_sign), - Optimisers.WeightDecay(λ_weight), - model.optimiser, - ) - end -end function MLJModelInterface.fit(model::MLJFluxModel, - verbosity, - X, - y) - - move = Mover(model.acceleration) - + verbosity, + X, + y) + # GPU and rng related variables + move = MLJFlux.Mover(model.acceleration) rng = true_rng(model) + + # Get input properties shape = MLJFlux.shape(model, X, y) + cat_inds = MLJFlux.get_cat_inds(X) + pure_continuous_input = isempty(cat_inds) + + # Decide whether to enable entity embeddings (e.g., ImageClassifier won't) + enable_entity_embs = MLJFlux.is_embedding_enabled(model) && !pure_continuous_input - chain = try - build(model, rng, shape) |> move - catch ex - @error ERR_BUILDER - rethrow() + # Prepare entity embeddings inputs and encode X if entity embeddings enabled + featnames = [] + if enable_entity_embs + X = MLJFlux.convert_to_table(X) + featnames = Tables.schema(X).names end - data = move.(collate(model, X, y)) - x = data[1][1] + # entityprops is (index = cat_inds[i], levels = num_levels[i], newdim = newdims[i]) + # for each categorical feature + default_embedding_dims = enable_entity_embs ? model.embedding_dims : Dict{Symbol, Real}() + entityprops, entityemb_output_dim = + MLJFlux.prepare_entityembs(X, featnames, cat_inds, default_embedding_dims) + X, ordinal_mappings = MLJFlux.ordinal_encoder_fit_transform(X; featinds = cat_inds) + + ## Construct model chain + chain = + (!enable_entity_embs) ? 
construct_model_chain(model, rng, shape, move) : + MLJFlux.construct_model_chain_with_entityembs( + model, + rng, + shape, + move, + entityprops, + entityemb_output_dim, + ) - try - chain(x) - catch ex - @error ERR_BUILDER - throw(ex) - end + # Format data as needed by Flux and move to GPU + data = move.(MLJFlux.collate(model, X, y, verbosity)) - nbatches = length(data[2]) - regularized_optimiser = MLJFlux.regularized_optimiser(model, nbatches) - optimiser_state = Optimisers.setup(regularized_optimiser, chain) + # Test chain works (as it may be custom) + x = data[1][1] + test_chain_works(x, chain) + # Train model with Flux + regularized_optimiser, optimiser_state = + prepare_optimiser(data, model, chain) chain, optimiser_state, history = train( model, chain, @@ -132,6 +124,10 @@ function MLJModelInterface.fit(model::MLJFluxModel, data[2], ) + # Extract embedding matrices + embedding_matrices = get_embedding_matrices(chain, cat_inds, featnames) + + # Prepare cache for potential warm restarts cache = ( deepcopy(model), data, @@ -141,31 +137,62 @@ function MLJModelInterface.fit(model::MLJFluxModel, optimiser_state, deepcopy(rng), move, + entityprops, + entityemb_output_dim, + ordinal_mappings, + featnames, ) - fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) - report = (training_losses=history, ) + # Prepare fitresult + fitresult = + MLJFlux.fitresult(model, Flux.cpu(chain), y, ordinal_mappings, embedding_matrices) + + # Prepare report + report = (training_losses = history,) return fitresult, cache, report end function MLJModelInterface.update(model::MLJFluxModel, - verbosity, - old_fitresult, - old_cache, - X, - y) - - old_model, data, old_history, shape, regularized_optimiser, - optimiser_state, rng, move = old_cache + verbosity, + old_fitresult, + old_cache, + X, + y) + # Decide whether to enable entity embeddings (e.g., ImageClassifier won't) + cat_inds = get_cat_inds(X) + pure_continuous_input = (length(cat_inds) == 0) + enable_entity_embs = is_embedding_enabled(model) && !pure_continuous_input + + # Unpack cache from previous fit + old_model, + data, + old_history, + shape, + regularized_optimiser, + optimiser_state, + rng, + move, + entityprops, + entityemb_output_dim, + ordinal_mappings, + featnames = old_cache + cat_inds = [prop.index for prop in entityprops] + + # Extract chain old_chain = old_fitresult[1] - optimiser_flag = model.optimiser_changes_trigger_retraining && + # Decide whether optimiser should trigger retraining from scratch + optimiser_flag = + model.optimiser_changes_trigger_retraining && model.optimiser != old_model.optimiser - keep_chain = !optimiser_flag && model.epochs >= old_model.epochs && + # Decide whether to retrain from scratch + keep_chain = + !optimiser_flag && model.epochs >= old_model.epochs && MLJModelInterface.is_same_except(model, old_model, :optimiser, :epochs) + # Use old chain if not retraining from scratch or reconstruct and prepare to retrain if keep_chain chain = move(old_chain) epochs = model.epochs - old_model.epochs @@ -173,15 +200,29 @@ function MLJModelInterface.update(model::MLJFluxModel, else move = Mover(model.acceleration) rng = true_rng(model) - chain = build(model, rng, shape) |> move + X = convert_to_table(X) + X = ordinal_encoder_transform(X, ordinal_mappings) + if enable_entity_embs + chain = + construct_model_chain_with_entityembs( + model, + rng, + shape, + move, + entityprops, + entityemb_output_dim, + ) + else + chain = construct_model_chain(model, rng, shape, move) + end # reset `optimiser_state`: - data = 
move.(collate(model, X, y)) - nbatches = length(data[2]) - regularized_optimiser = MLJFlux.regularized_optimiser(model, nbatches) - optimiser_state = Optimisers.setup(regularized_optimiser, chain) + data = move.(collate(model, X, y, verbosity)) + regularized_optimiser, optimiser_state = + prepare_optimiser(data, model, chain) epochs = model.epochs end + # Train model with Flux chain, optimiser_state, history = train( model, chain, @@ -192,12 +233,17 @@ function MLJModelInterface.update(model::MLJFluxModel, data[1], data[2], ) + + # Properly set history if keep_chain # note: history[1] = old_history[end] history = vcat(old_history[1:end-1], history) end - fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) + # Extract embedding matrices + embedding_matrices = get_embedding_matrices(chain, cat_inds, featnames) + + # Prepare cache, fitresult, and report cache = ( deepcopy(model), data, @@ -207,15 +253,33 @@ function MLJModelInterface.update(model::MLJFluxModel, optimiser_state, deepcopy(rng), move, + entityprops, entityemb_output_dim, ordinal_mappings, featnames, ) - report = (training_losses=history, ) + fitresult = + MLJFlux.fitresult(model, Flux.cpu(chain), y, ordinal_mappings, embedding_matrices) + report = (training_losses = history,) return fitresult, cache, report end + +# Transformer for entity-enabled models +function MLJModelInterface.transform( + transformer::MLJFluxModel, + fitresult, + Xnew, +) + # if it doesn't have the property its not an entity-enabled model + is_embedding_enabled(transformer) || return Xnew + ordinal_mappings, embedding_matrices = fitresult[3:4] + Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings) + Xnew_transf = embedding_transform(Xnew, embedding_matrices) + return Xnew_transf +end + MLJModelInterface.fitted_params(::MLJFluxModel, fitresult) = - (chain=fitresult[1],) + (chain = fitresult[1],) # # SUPPORT FOR MLJ ITERATION API diff --git a/src/regressor.jl b/src/regressor.jl index 222560b7..1fd9bcbc 100644 --- a/src/regressor.jl +++ b/src/regressor.jl @@ -11,25 +11,31 @@ function shape(model::NeuralNetworkRegressor, X, y) n_ouput = 1 return (n_input, 1) end +is_embedding_enabled(::NeuralNetworkRegressor) = true + build(model::NeuralNetworkRegressor, rng, shape) = build(model.builder, rng, shape...) -fitresult(model::NeuralNetworkRegressor, chain, y) = (chain, nothing) +fitresult(model::NeuralNetworkRegressor, chain, y, ordinal_mappings=nothing, embedding_matrices=nothing) = + (chain, nothing, ordinal_mappings, embedding_matrices) + + function MLJModelInterface.predict(model::NeuralNetworkRegressor, fitresult, Xnew) - chain = fitresult[1] + chain, ordinal_mappings = fitresult[1], fitresult[3] + Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings) Xnew_ = reformat(Xnew) return [chain(values.(tomat(Xnew_[:, i])))[1] for i in 1:size(Xnew_, 2)] end MLJModelInterface.metadata_model(NeuralNetworkRegressor, - input=Union{AbstractMatrix{Continuous},Table(Continuous)}, - target=AbstractVector{<:Continuous}, - path="MLJFlux.NeuralNetworkRegressor") + input = Union{AbstractMatrix{Continuous}, Table(Continuous,Finite)}, + target = AbstractVector{<:Continuous}, + path = "MLJFlux.NeuralNetworkRegressor") # # MULTITARGET NEURAL NETWORK REGRESSOR @@ -42,34 +48,42 @@ ncols(X) = Tables.columns(X) |> Tables.columnnames |> length A private method that returns the shape of the input and output of the model for given data `X` and `y`. 
- """ shape(model::MultitargetNeuralNetworkRegressor, X, y) = (ncols(X), ncols(y)) +is_embedding_enabled(::MultitargetNeuralNetworkRegressor) = true build(model::MultitargetNeuralNetworkRegressor, rng, shape) = build(model.builder, rng, shape...) -function fitresult(model::MultitargetNeuralNetworkRegressor, chain, y) +function fitresult( + model::MultitargetNeuralNetworkRegressor, + chain, + y, + ordinal_mappings=nothing, + embedding_matrices=nothing, +) if y isa Matrix target_column_names = nothing else target_column_names = Tables.schema(y).names end - return (chain, target_column_names) + return (chain, target_column_names, ordinal_mappings, embedding_matrices) end function MLJModelInterface.predict(model::MultitargetNeuralNetworkRegressor, fitresult, Xnew) - chain, target_column_names = fitresult + chain, target_column_names, ordinal_mappings, _ = fitresult + Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings) X = reformat(Xnew) ypred = [chain(values.(tomat(X[:, i]))) for i in 1:size(X, 2)] - output = isnothing(target_column_names) ? permutedims(reduce(hcat, ypred)) : - MLJModelInterface.table(reduce(hcat, ypred)', names=target_column_names) + output = + isnothing(target_column_names) ? permutedims(reduce(hcat, ypred)) : + MLJModelInterface.table(reduce(hcat, ypred)', names = target_column_names) return output end MLJModelInterface.metadata_model(MultitargetNeuralNetworkRegressor, - input=Union{AbstractMatrix{Continuous},Table(Continuous)}, - target=Union{AbstractMatrix{Continuous}, Table(Continuous)}, - path="MLJFlux.MultitargetNeuralNetworkRegressor") + input = Union{AbstractMatrix{Continuous}, Table(Continuous,Finite)}, + target = Union{AbstractMatrix{Continuous}, Table(Continuous)}, + path = "MLJFlux.MultitargetNeuralNetworkRegressor") diff --git a/src/types.jl b/src/types.jl index e7bb880d..bf641451 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1,142 +1,230 @@ abstract type MLJFluxProbabilistic <: MLJModelInterface.Probabilistic end abstract type MLJFluxDeterministic <: MLJModelInterface.Deterministic end -const MLJFluxModel = Union{MLJFluxProbabilistic,MLJFluxDeterministic} - -for Model in [:NeuralNetworkClassifier, :NeuralNetworkBinaryClassifier, :ImageClassifier] - - # default settings that are not equal across models - default_builder_ex = - Model == :ImageClassifier ? :(image_builder(VGGHack)) : Short() - default_finaliser = - Model == :NeuralNetworkBinaryClassifier ? Flux.σ : Flux.softmax - default_loss = - Model == :NeuralNetworkBinaryClassifier ? 
Flux.binarycrossentropy : Flux.crossentropy - - quote - mutable struct $Model{B,F,O,L} <: MLJFluxProbabilistic - builder::B - finaliser::F - optimiser::O # mutable struct from Flux/src/optimise/optimisers.jl - loss::L # can be called as in `loss(yhat, y)` - epochs::Int # number of epochs - batch_size::Int # size of a batch - lambda::Float64 # regularization strength - alpha::Float64 # regularizaton mix (0 for all l2, 1 for all l1) - rng::Union{AbstractRNG,Int64} - optimiser_changes_trigger_retraining::Bool - acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` - end - - function $Model( - ;builder::B=$default_builder_ex, - finaliser::F=$default_finaliser, - optimiser::O=Optimisers.Adam(), - loss::L=$default_loss, - epochs=10, - batch_size=1, - lambda=0, - alpha=0, - rng=Random.default_rng(), - optimiser_changes_trigger_retraining=false, - acceleration=CPU1(), - ) where {B,F,O,L} - - model = $Model{B,F,O,L}( - builder, - finaliser, - optimiser, - loss, - epochs, - batch_size, - lambda, - alpha, - rng, - optimiser_changes_trigger_retraining, - acceleration, - ) - - message = clean!(model) - isempty(message) || @warn message - - return model - end - - end |> eval +const MLJFluxModel = Union{MLJFluxProbabilistic, MLJFluxDeterministic} + +for Model in [:NeuralNetworkClassifier, :NeuralNetworkBinaryClassifier] + + # default settings that are not equal across models + default_finaliser = + Model == :NeuralNetworkBinaryClassifier ? Flux.σ : Flux.softmax + default_loss = + Model == :NeuralNetworkBinaryClassifier ? Flux.binarycrossentropy : + Flux.crossentropy + + quote + mutable struct $Model{B, F, O, L} <: MLJFluxProbabilistic + builder::B + finaliser::F + optimiser::O # mutable struct from Flux/src/optimise/optimisers.jl + loss::L # can be called as in `loss(yhat, y)` + epochs::Int # number of epochs + batch_size::Int # size of a batch + lambda::Float64 # regularization strength + alpha::Float64 # regularizaton mix (0 for all l2, 1 for all l1) + rng::Union{AbstractRNG, Int64} + optimiser_changes_trigger_retraining::Bool + acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` + embedding_dims::Dict{Symbol, Real} + end + + function $Model( + ; builder::B = Short(), + finaliser::F = $default_finaliser, + optimiser::O = Optimisers.Adam(), + loss::L = $default_loss, + epochs = 10, + batch_size = 1, + lambda = 0, + alpha = 0, + rng = Random.default_rng(), + optimiser_changes_trigger_retraining = false, + acceleration = CPU1(), + embedding_dims = Dict{Symbol, Real}(), + ) where {B, F, O, L} + + model = $Model{B, F, O, L}( + builder, + finaliser, + optimiser, + loss, + epochs, + batch_size, + lambda, + alpha, + rng, + optimiser_changes_trigger_retraining, + acceleration, + embedding_dims, + ) + + message = clean!(model) + isempty(message) || @warn message + + return model + end + + end |> eval end - for Model in [:NeuralNetworkRegressor, :MultitargetNeuralNetworkRegressor] - quote - mutable struct $Model{B,O,L} <: MLJFluxDeterministic - builder::B - optimiser::O # mutable struct from Flux/src/optimise/optimisers.jl - loss::L # can be called as in `loss(yhat, y)` - epochs::Int # number of epochs - batch_size::Int # size of a batch - lambda::Float64 # regularization strength - alpha::Float64 # regularizaton mix (0 for all l2, 1 for all l1) - rng::Union{AbstractRNG,Integer} - optimiser_changes_trigger_retraining::Bool - acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` - end - - function $Model( - ; builder::B=Linear(), - optimiser::O=Optimisers.Adam(), - loss::L=Flux.mse, - 
epochs=10, - batch_size=1, - lambda=0, - alpha=0, - rng=Random.default_rng(), - optimiser_changes_trigger_retraining=false, - acceleration=CPU1(), - ) where {B,O,L} - - model = $Model{B,O,L}( - builder, - optimiser, - loss, - epochs, - batch_size, - lambda, - alpha, - rng, - optimiser_changes_trigger_retraining, - acceleration, - ) - - message = clean!(model) - isempty(message) || @warn message - - return model - end - - end |> eval + quote + mutable struct $Model{B, O, L} <: MLJFluxDeterministic + builder::B + optimiser::O # mutable struct from Flux/src/optimise/optimisers.jl + loss::L # can be called as in `loss(yhat, y)` + epochs::Int # number of epochs + batch_size::Int # size of a batch + lambda::Float64 # regularization strength + alpha::Float64 # regularizaton mix (0 for all l2, 1 for all l1) + rng::Union{AbstractRNG, Integer} + optimiser_changes_trigger_retraining::Bool + acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` + embedding_dims::Dict{Symbol, Real} + end + + function $Model( + ; builder::B = Linear(), + optimiser::O = Optimisers.Adam(), + loss::L = Flux.mse, + epochs = 10, + batch_size = 1, + lambda = 0, + alpha = 0, + rng = Random.default_rng(), + optimiser_changes_trigger_retraining = false, + acceleration = CPU1(), + embedding_dims = Dict{Symbol, Real}(), + ) where {B, O, L} + + model = $Model{B, O, L}( + builder, + optimiser, + loss, + epochs, + batch_size, + lambda, + alpha, + rng, + optimiser_changes_trigger_retraining, + acceleration, + embedding_dims, + ) + + message = clean!(model) + isempty(message) || @warn message + + return model + end + + end |> eval end const Regressor = - Union{NeuralNetworkRegressor,MultitargetNeuralNetworkRegressor} + Union{NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor} + + +# Separately define the ImageClassifier +mutable struct ImageClassifier{B, F, O, L} <: MLJFluxProbabilistic + builder::B + finaliser::F + optimiser::O + loss::L + epochs::Int + batch_size::Int + lambda::Float64 + alpha::Float64 + rng::Union{AbstractRNG, Int64} + optimiser_changes_trigger_retraining::Bool + acceleration::AbstractResource +end + +function ImageClassifier( + ; builder::B = image_builder(VGGHack), + finaliser::F = Flux.softmax, + optimiser::O = Optimisers.Adam(), + loss::L = Flux.crossentropy, + epochs = 10, + batch_size = 1, + lambda = 0, + alpha = 0, + rng = Random.default_rng(), + optimiser_changes_trigger_retraining = false, + acceleration = CPU1(), +) where {B, F, O, L} + model = ImageClassifier{B, F, O, L}( + builder, + finaliser, + optimiser, + loss, + epochs, + batch_size, + lambda, + alpha, + rng, + optimiser_changes_trigger_retraining, + acceleration, + ) + + message = clean!(model) + isempty(message) || @warn message + + return model +end + MMI.metadata_pkg.( - ( - NeuralNetworkRegressor, - MultitargetNeuralNetworkRegressor, - NeuralNetworkClassifier, - ImageClassifier, - NeuralNetworkBinaryClassifier, - ), - name="MLJFlux", - uuid="094fc8d1-fd35-5302-93ea-dabda2abf845", - url="https://github.com/alan-turing-institute/MLJFlux.jl", - julia=true, - license="MIT", + ( + NeuralNetworkRegressor, + MultitargetNeuralNetworkRegressor, + NeuralNetworkClassifier, + ImageClassifier, + NeuralNetworkBinaryClassifier, + ), + name = "MLJFlux", + uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845", + url = "https://github.com/alan-turing-institute/MLJFlux.jl", + julia = true, + license = "MIT", ) +const MODELSUPPORTDOC = """ +In addition to features with `Continuous` scientific element type, this model supports +categorical features in the 
input table. If present, such features are embedded into dense +vectors by the use of an additional `EntityEmbedder` layer after the input, as described in +Entity Embeddings of Categorical Variables by Cheng Guo, Felix Berkhahn arXiv, 2016. +""" + +const XDOC = """ + +- `X` provides input features and is either: (i) a `Matrix` with `Continuous` element + scitype (typically `Float32`); or (ii) a table of input features (eg, a `DataFrame`) + whose columns have `Continuous`, `Multiclass` or `OrderedFactor` element scitype; check + column scitypes with `schema(X)`. If any `Multiclass` or `OrderedFactor` features + appear, the constructed network will use an `EntityEmbedder` layer to transform + them into dense vectors. If `X` is a `Matrix`, it is assumed that columns correspond to + features and rows corresponding to observations. + +""" + +const EMBDOC = """ +- `embedding_dims`: a `Dict` whose keys are names of categorical features, given as + symbols, and whose values are numbers representing the desired dimensionality of the + entity embeddings of such features: an integer value of `7`, say, sets the embedding + dimensionality to `7`; a float value of `0.5`, say, sets the embedding dimensionality to + `ceil(0.5 * c)`, where `c` is the number of feature levels. Unspecified feature + dimensionality defaults to `min(c - 1, 10)`. +""" + +const TRANSFORMDOC = """ +- `transform(mach, Xnew)`: Assuming `Xnew` has the same schema as `X`, transform the + categorical features of `Xnew` into dense `Continuous` vectors using the + `MLJFlux.EntityEmbedder` layer present in the network. Does nothing in case the model + was trained on an input `X` that lacks categorical features. +""" # # DOCSTRINGS @@ -149,6 +237,8 @@ given a table of `Continuous` features. Users provide a recipe for constructing the network, based on properties of the data that is encountered, by specifying an appropriate `builder`. See MLJFlux documentation for more on builders. +$MODELSUPPORTDOC + # Training data In MLJ or MLJBase, bind an instance `model` to data with @@ -157,8 +247,7 @@ In MLJ or MLJBase, bind an instance `model` to data with Here: -- `X` is either a `Matrix` or any table of input features (eg, a `DataFrame`) whose columns are of scitype - `Continuous`; check column scitypes with `schema(X)`. If `X` is a `Matrix`, it is assumed to have columns corresponding to features and rows corresponding to observations. +$XDOC - `y` is the target, which can be any `AbstractVector` whose element scitype is `Multiclass` or `OrderedFactor`; check the scitype with `scitype(y)` @@ -227,6 +316,7 @@ Train the machine with `fit!(mach, rows=...)`. - `finaliser=Flux.softmax`: The final activation function of the neural network (applied after the network defined by `builder`). Defaults to `Flux.softmax`. +$EMBDOC # Operations @@ -236,6 +326,7 @@ Train the machine with `fit!(mach, rows=...)`. - `predict_mode(mach, Xnew)`: Return the modes of the probabilistic predictions returned above. +$TRANSFORMDOC # Fitted parameters @@ -338,6 +429,8 @@ given a table of `Continuous` features. Users provide a recipe for constructing the network, based on properties of the data that is encountered, by specifying an appropriate `builder`. See MLJFlux documentation for more on builders. 
+$MODELSUPPORTDOC + # Training data In MLJ or MLJBase, bind an instance `model` to data with @@ -346,9 +439,7 @@ In MLJ or MLJBase, bind an instance `model` to data with Here: -- `X` is either a `Matrix` or any table of input features (eg, a `DataFrame`) whose columns are of scitype - `Continuous`; check column scitypes with `schema(X)`. If `X` is a `Matrix`, - it is assumed to have columns corresponding to features and rows corresponding to observations. +$XDOC - `y` is the target, which can be any `AbstractVector` whose element scitype is `Multiclass{2}` or `OrderedFactor{2}`; check the scitype with `scitype(y)` @@ -418,6 +509,7 @@ Train the machine with `fit!(mach, rows=...)`. - `finaliser=Flux.σ`: The final activation function of the neural network (applied after the network defined by `builder`). Defaults to `Flux.σ`. +$EMBDOC # Operations @@ -427,6 +519,8 @@ Train the machine with `fit!(mach, rows=...)`. - `predict_mode(mach, Xnew)`: Return the modes of the probabilistic predictions returned above. +$TRANSFORMDOC + # Fitted parameters @@ -791,6 +885,7 @@ predict a `Continuous` target, given a table of `Continuous` features. Users pro recipe for constructing the network, based on properties of the data that is encountered, by specifying an appropriate `builder`. See MLJFlux documentation for more on builders. +$MODELSUPPORTDOC # Training data @@ -800,8 +895,8 @@ In MLJ or MLJBase, bind an instance `model` to data with Here: -- `X` is either a `Matrix` or any table of input features (eg, a `DataFrame`) whose columns are of scitype - `Continuous`; check column scitypes with `schema(X)`. If `X` is a `Matrix`, it is assumed to have columns corresponding to features and rows corresponding to observations. +$XDOC + - `y` is the target, which can be any `AbstractVector` whose element scitype is `Continuous`; check the scitype with `scitype(y)` @@ -859,12 +954,14 @@ Train the machine with `fit!(mach, rows=...)`. - `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CUDALibs()`. +$EMBDOC # Operations - `predict(mach, Xnew)`: return predictions of the target given new features `Xnew`, which should have the same scitype as `X` above. +$TRANSFORMDOC # Fitted parameters @@ -1020,6 +1117,8 @@ table of `Continuous` features. Users provide a recipe for constructing the netw on properties of the data that is encountered, by specifying an appropriate `builder`. See MLJFlux documentation for more on builders. +$MODELSUPPORTDOC + # Training data In MLJ or MLJBase, bind an instance `model` to data with @@ -1028,10 +1127,7 @@ In MLJ or MLJBase, bind an instance `model` to data with Here: -- `X` is either a `Matrix` or any table of input features (eg, a `DataFrame`) whose - columns are of scitype `Continuous`; check column scitypes with `schema(X)`. If `X` is a - `Matrix`, it is assumed to have columns corresponding to features and rows corresponding - to observations. +$XDOC - `y` is the target, which can be any table or matrix of output targets whose element scitype is `Continuous`; check column scitypes with `schema(y)`. If `y` is a `Matrix`, @@ -1090,6 +1186,7 @@ Here: - `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CUDALibs()`. +$EMBDOC # Operations @@ -1097,6 +1194,7 @@ Here: features `Xnew` having the same scitype as `X` above. Predictions are deterministic. 
+$TRANSFORMDOC # Fitted parameters @@ -1190,16 +1288,14 @@ With the learning rate fixed, we can now compute a CV estimate of the performanc all data bound to `mach`) and compare this with performance on the test set: ```julia -# custom MLJ loss: -multi_loss(yhat, y) = l2(MLJ.matrix(yhat), MLJ.matrix(y)) # CV estimate, based on `(X, y)`: -evaluate!(mach, resampling=CV(nfolds=5), measure=multi_loss) +evaluate!(mach, resampling=CV(nfolds=5), measure=multitarget_l2) # loss for `(Xtest, test)`: fit!(mach) # trains on all data `(X, y)` yhat = predict(mach, Xtest) -multi_loss(yhat, ytest) +multitarget_l2(yhat, ytest) ``` See also diff --git a/test/classifier.jl b/test/classifier.jl index ce167b39..2dbafef5 100644 --- a/test/classifier.jl +++ b/test/classifier.jl @@ -2,14 +2,25 @@ seed!(1234) N = 300 -X = MLJBase.table(rand(Float32, N, 4)); -ycont = 2*X.x1 - X.x3 + 0.1*rand(N) +Xm = MLJBase.table(randn(Float32, N, 5)); # purely numeric +X = (; Tables.columntable(Xm)..., + Column1 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)), + Column2 = categorical(repeat(['a', 'b', 'c', 'd', 'e'], Int(N / 5))), + Column3 = categorical(repeat(["b", "c", "d", "f", "f"], Int(N / 5)), ordered = true), + Column4 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)), + Column5 = randn(Float32, N), + Column6 = categorical( + repeat(["group1", "group1", "group2", "group2", "group3"], Int(N / 5)), + ), +) + +ycont = 2 * X.x1 - X.x3 + 0.1 * rand(Float32, N) m, M = minimum(ycont), maximum(ycont) -_, a, b, _ = range(m, stop=M, length=4) |> collect +_, a, b, _ = range(m, stop = M, length = 4) |> collect y = map(ycont) do η - if η < 0.9*a + if η < 0.9 * a 'a' - elseif η < 1.1*b + elseif η < 1.1 * b 'b' else 'c' @@ -20,7 +31,7 @@ end |> categorical; # builer instead of the default `Short()` because `Dropout()` in `Short()` does not appear # to behave the same on GPU as on a CPU, even when we use `default_rng()` for both. 
-builder = MLJFlux.MLP(hidden=(8,)) +builder = MLJFlux.MLP(hidden = (8,)) optimiser = Optimisers.Adam(0.03) losses = [] @@ -30,32 +41,42 @@ losses = [] # Table input: @testset "Table input" begin basictest(MLJFlux.NeuralNetworkClassifier, - X, - y, - builder, - optimiser, - 0.85, - accel) + X, + y, + builder, + optimiser, + 0.85, + accel) + end + + @testset "Table input numerical" begin + basictest(MLJFlux.NeuralNetworkClassifier, + Xm, + y, + builder, + optimiser, + 0.85, + accel) end # Matrix input: @testset "Matrix input" begin basictest(MLJFlux.NeuralNetworkClassifier, - matrix(X), - y, - builder, - optimiser, - 0.85, - accel) + matrix(Xm), + y, + builder, + optimiser, + 0.85, + accel) end train, test = MLJBase.partition(1:N, 0.7) # baseline loss (predict constant probability distribution): dict = StatsBase.countmap(y[train]) - prob_given_class = Dict{CategoricalArrays.CategoricalValue,Float64}() + prob_given_class = Dict{CategoricalArrays.CategoricalValue, Float64}() for (k, v) in dict - prob_given_class[k] = dict[k]/length(train) + prob_given_class[k] = dict[k] / length(train) end dist = MLJBase.UnivariateFinite(prob_given_class) loss_baseline = @@ -66,36 +87,37 @@ losses = [] # (GPUs only support `default_rng`): rng = Random.default_rng() seed!(rng, 123) - model = MLJFlux.NeuralNetworkClassifier(epochs=50, - builder=builder, - optimiser=optimiser, - acceleration=accel, - batch_size=10, - rng=rng) - @time mach = fit!(machine(model, X, y), rows=train, verbosity=0) + model = MLJFlux.NeuralNetworkClassifier(epochs = 50, + builder = builder, + optimiser = optimiser, + acceleration = accel, + batch_size = 10, + rng = rng) + @time mach = fit!(machine(model, X, y), rows = train, verbosity = 0) first_last_training_loss = MLJBase.report(mach)[1][[1, end]] push!(losses, first_last_training_loss[2]) - yhat = MLJBase.predict(mach, rows=test); - @test StatisticalMeasures.cross_entropy(yhat, y[test]) < 0.95*loss_baseline + yhat = MLJBase.predict(mach, rows = test) + @test StatisticalMeasures.cross_entropy(yhat, y[test]) < 0.95 * loss_baseline optimisertest(MLJFlux.NeuralNetworkClassifier, - X, - y, - builder, - optimiser, - accel) + X, + y, + builder, + optimiser, + accel) end # check different resources (CPU1, CUDALibs, etc)) give about the same loss: reference = losses[1] -@test all(x->abs(x - reference)/reference < 1e-5, losses[2:end]) +println("losses for each resource: $losses") +@test all(x -> abs(x - reference) / reference < 0.03, losses[2:end]) # # NEURAL NETWORK BINARY CLASSIFIER @testset "NeuralNetworkBinaryClassifier constructor" begin - model = NeuralNetworkBinaryClassifier() + model = MLJFlux.NeuralNetworkBinaryClassifier() @test model.loss == Flux.binarycrossentropy @test model.builder isa MLJFlux.Short @test model.finaliser == Flux.σ @@ -104,18 +126,18 @@ end seed!(1234) N = 300 X = MLJBase.table(rand(Float32, N, 4)); -ycont = 2*X.x1 - X.x3 + 0.1*rand(N) +ycont = Float32.(2 * X.x1 - X.x3 + 0.1 * rand(N)) m, M = minimum(ycont), maximum(ycont) -_, a, _ = range(m, stop=M, length=3) |> collect +_, a, _ = range(m, stop = M, length = 3) |> collect y = map(ycont) do η - if η < 0.9*a + if η < 0.9 * a 'a' else 'b' end end |> categorical; -builder = MLJFlux.MLP(hidden=(8,)) +builder = MLJFlux.MLP(hidden = (8,)) optimiser = Optimisers.Adam(0.03) @testset_accelerated "NeuralNetworkBinaryClassifier" accel begin @@ -150,9 +172,9 @@ optimiser = Optimisers.Adam(0.03) # baseline loss (predict constant probability distribution): dict = StatsBase.countmap(y[train]) - prob_given_class = 
Dict{CategoricalArrays.CategoricalValue,Float64}() + prob_given_class = Dict{CategoricalArrays.CategoricalValue, Float64}() for (k, v) in dict - prob_given_class[k] = dict[k]/length(train) + prob_given_class[k] = dict[k] / length(train) end dist = MLJBase.UnivariateFinite(prob_given_class) loss_baseline = @@ -164,17 +186,17 @@ optimiser = Optimisers.Adam(0.03) rng = Random.default_rng() seed!(rng, 123) model = MLJFlux.NeuralNetworkBinaryClassifier( - epochs=50, - builder=builder, - optimiser=optimiser, - acceleration=accel, - batch_size=10, - rng=rng, + epochs = 50, + builder = builder, + optimiser = optimiser, + acceleration = accel, + batch_size = 10, + rng = rng, ) - @time mach = fit!(machine(model, X, y), rows=train, verbosity=0) + @time mach = fit!(machine(model, X, y), rows = train, verbosity = 0) first_last_training_loss = MLJBase.report(mach)[1][[1, end]] - yhat = MLJBase.predict(mach, rows=test); - @test StatisticalMeasures.cross_entropy(yhat, y[test]) < 0.95*loss_baseline + yhat = MLJBase.predict(mach, rows = test) + @test StatisticalMeasures.cross_entropy(yhat, y[test]) < 0.95 * loss_baseline end diff --git a/test/core.jl b/test/core.jl index 4bfc400b..94eb1e1f 100644 --- a/test/core.jl +++ b/test/core.jl @@ -14,23 +14,28 @@ rowvec(y::Vector) = reshape(y, 1, length(y)) end @testset "collate" begin - # NeuralNetworRegressor: - Xmatrix = broadcast(x->round(x, sigdigits=2), rand(stable_rng, 10, 3)) + Xmatrix = broadcast(x->round(x, sigdigits=2), rand(stable_rng, Float32, 10, 3)) + Xmat_f64 = Float64.(Xmatrix) # convert to a column table: X = MLJBase.table(Xmatrix) + X_64 = MLJBase.table(Xmat_f64) + # NeuralNetworRegressor: y = rand(stable_rng, Float32, 10) model = MLJFlux.NeuralNetworkRegressor() model.batch_size= 3 - @test MLJFlux.collate(model, X, y) == + @test MLJFlux.collate(model, X, y, 1) == MLJFlux.collate(model, X_64, y, 1) == ([Xmatrix'[:,1:3], Xmatrix'[:,4:6], Xmatrix'[:,7:9], Xmatrix'[:,10:10]], rowvec.([y[1:3], y[4:6], y[7:9], y[10:10]])) + @test_logs (:info,) MLJFlux.collate(model, X_64, y, 1) + @test_logs min_level=Logging.Info MLJFlux.collate(model, X, y, 1) + @test_logs min_level=Logging.Info MLJFlux.collate(model, X, y, 0) # NeuralNetworClassifier: y = categorical(['a', 'b', 'a', 'a', 'b', 'a', 'a', 'a', 'b', 'a']) model = MLJFlux.NeuralNetworkClassifier() model.batch_size = 3 - data = MLJFlux.collate(model, X, y) + data = MLJFlux.collate(model, X, y, 1) @test data == ([Xmatrix'[:,1:3], Xmatrix'[:,4:6], Xmatrix'[:,7:9], Xmatrix'[:,10:10]], @@ -42,13 +47,13 @@ end y = MLJBase.table(ymatrix) # a rowaccess table model = MLJFlux.NeuralNetworkRegressor() model.batch_size= 3 - @test MLJFlux.collate(model, X, y) == + @test MLJFlux.collate(model, X, y, 1) == ([Xmatrix'[:,1:3], Xmatrix'[:,4:6], Xmatrix'[:,7:9], Xmatrix'[:,10:10]], rowvec.([ymatrix'[:,1:3], ymatrix'[:,4:6], ymatrix'[:,7:9], ymatrix'[:,10:10]])) y = Tables.columntable(y) # try a columnaccess table - @test MLJFlux.collate(model, X, y) == + @test MLJFlux.collate(model, X, y, 1) == ([Xmatrix'[:,1:3], Xmatrix'[:,4:6], Xmatrix'[:,7:9], Xmatrix'[:,10:10]], rowvec.([ymatrix'[:,1:3], ymatrix'[:,4:6], ymatrix'[:,7:9], ymatrix'[:,10:10]])) @@ -58,7 +63,7 @@ end y = categorical(['a', 'b', 'a', 'a', 'b', 'a', 'a', 'a', 'b', 'a']) model = MLJFlux.ImageClassifier(batch_size=2) - data = MLJFlux.collate(model, Xmatrix, y) + data = MLJFlux.collate(model, Xmatrix, y, 1) @test first.(data) == (Float32.(cat(Xmatrix[1], Xmatrix[2], dims=4)), rowvec.([1 0;0 1])) diff --git a/test/encoders.jl b/test/encoders.jl new file mode 100644 
index 00000000..50b93b60 --- /dev/null +++ b/test/encoders.jl @@ -0,0 +1,72 @@ + + +@testset "ordinal encoder" begin + X = ( + Column1 = [1.0, 2.0, 3.0, 4.0, 5.0], + Column2 = categorical(['a', 'b', 'c', 'd', 'e']), + Column3 = categorical(["b", "c", "d"]), + Column4 = [1.0, 2.0, 3.0, 4.0, 5.0], + ) + map = MLJFlux.ordinal_encoder_fit(X; featinds = [2, 3]) + Xenc = MLJFlux.ordinal_encoder_transform(X, map) + @test map[2] == Dict('a' => 1, 'b' => 2, 'c' => 3, 'd' => 4, 'e' => 5) + @test map[3] == Dict("b" => 1, "c" => 2, "d" => 3) + @test Xenc.Column1 == [1.0, 2.0, 3.0, 4.0, 5.0] + @test Xenc.Column2 == Float32.([1.0, 2.0, 3.0, 4.0, 5.0]) + @test Xenc.Column3 == Float32.([1, 2, 3]) + @test Xenc.Column4 == [1.0, 2.0, 3.0, 4.0, 5.0] + + X = coerce(X, :Column1 => Multiclass) + map = MLJFlux.ordinal_encoder_fit(X; featinds = [1, 2, 3]) + @test !haskey(map, 1) # already encoded + + @test Xenc == MLJFlux.ordinal_encoder_fit_transform(X; featinds = [2, 3])[1] +end + +@testset "Generate New feature names Function Tests" begin + # Test 1: No initial conflicts + @testset "No Initial Conflicts" begin + existing_names = [] + names = MLJFlux.generate_new_feat_names("feat", 3, existing_names) + @test names == [Symbol("feat_1"), Symbol("feat_2"), Symbol("feat_3")] + end + + # Test 2: Handle initial conflict by adding underscores + @testset "Initial Conflict Resolution" begin + existing_names = [Symbol("feat_1"), Symbol("feat_2"), Symbol("feat_3")] + names = MLJFlux.generate_new_feat_names("feat", 3, existing_names) + @test names == [Symbol("feat__1"), Symbol("feat__2"), Symbol("feat__3")] + end +end + + +@testset "embedding_transform works" begin + X = ( + Column1 = [1.0, 2.0, 3.0, 4.0, 5.0], + Column2 = categorical(['a', 'b', 'c', 'd', 'e']), + Column3 = categorical(["b", "c", "d", "f", "f"]), + Column4 = [1.0, 2.0, 3.0, 4.0, 5.0], + ) + mapping_matrices = Dict( + :Column2 => [ + 1 0.5 0.7 4 5 + 0.4 2 3 0.9 0.2 + 0.1 0.6 0.8 0.3 0.4 + ], + :Column3 => [ + 1 0.5 0.7 4 + 0.4 2 3 0.9 + ], + ) + X, _ = MLJFlux.ordinal_encoder_fit_transform(X; featinds = [2, 3]) + Xenc = MLJFlux.embedding_transform(X, mapping_matrices) + @test Xenc == ( + Column1 = [1.0, 2.0, 3.0, 4.0, 5.0], + Column2_1 = [1.0, 0.5, 0.7, 4.0, 5.0], + Column2_2 = [0.4, 2.0, 3.0, 0.9, 0.2], + Column2_3 = [0.1, 0.6, 0.8, 0.3, 0.4], + Column3_1 = [1.0, 0.5, 0.7, 4.0, 4.0], + Column3_2 = [0.4, 2.0, 3.0, 0.9, 0.9], + Column4 = [1.0, 2.0, 3.0, 4.0, 5.0], + ) +end diff --git a/test/entity_embedding.jl b/test/entity_embedding.jl new file mode 100644 index 00000000..da0c89ba --- /dev/null +++ b/test/entity_embedding.jl @@ -0,0 +1,223 @@ +""" +See more functional tests in entity_embedding_utils.jl and mlj_model_interface.jl +""" +batch = Float32.([ + 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1; + 1 2 3 4 5 6 7 8 9 10; + 0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1; + 1 1 2 2 1 1 2 2 1 1 +]) + + +entityprops = [ + (index = 2, levels = 10, newdim = 2), + (index = 4, levels = 2, newdim = 1), +] + + +@testset "Feedforward with Entity Embedder Works" begin + ### Option 1: Use EntityEmbedder + entityprops = [ + (index = 2, levels = 10, newdim = 5), + (index = 4, levels = 2, newdim = 2), + ] + + embedder = MLJFlux.EntityEmbedder(entityprops, 4) + + output = embedder(batch) + + ### Option 2: Manual feedforward + x1 = batch[1:1, :] + z2 = Int.(batch[2, :]) + x3 = batch[3:3, :] + z4 = Int.(batch[4, :]) + + # extract matrices from categorical embedder + EE1 = Flux.params(embedder.embedders[2])[1] # (newdim, levels) = (5, 10) + EE2 = Flux.params(embedder.embedders[4])[1] # 
(newdim, levels) = (2, 2) + + ## One-hot encoding + z2_hot = Flux.onehotbatch(z2, levels(z2)) + z4_hot = Flux.onehotbatch(z4, levels(z4)) + + function feedforward(x1, z2_hot, x3, z4_hot) + f_z2 = EE1 * z2_hot + f_z4 = EE2 * z4_hot + return vcat([x1, f_z2, x3, f_z4]...) + end + + real_output = feedforward(x1, z2_hot, x3, z4_hot) + @test output ≈ real_output +end + + +@testset "Feedforward and Backward Pass with Entity Embedder Works" begin + y_batch_reg = [0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0] # Regression + y_batch_cls = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # Classification + y_batch_cls_o = Flux.onehotbatch(y_batch_cls, 1:10) + + losses = [Flux.crossentropy, Flux.mse] + targets = [y_batch_cls_o, y_batch_reg] + finalizer = [softmax, relu] + + for ind in 1:2 + ### Option 1: Feedforward with EntityEmbedder in the network + entityprops = [ + (index = 2, levels = 10, newdim = 5), + (index = 4, levels = 2, newdim = 2), + ] + + cat_model = Chain( + MLJFlux.EntityEmbedder(entityprops, 4), + Dense(9 => (ind == 1) ? 10 : 1), + finalizer[ind], + ) + + EE1_before = Flux.params(cat_model.layers[1].embedders[2])[1] + EE2_before = Flux.params(cat_model.layers[1].embedders[4])[1] + W_before = Flux.params(cat_model.layers[2])[1] + + ### Test with obvious equivalent feedforward + x1 = batch[1:1, :] + z2 = Int.(batch[2, :]) + x3 = batch[3:3, :] + z4 = Int.(batch[4, :]) + + z2_hot = Flux.onehotbatch(z2, levels(z2)) + z4_hot = Flux.onehotbatch(z4, levels(z4)) + + ### Option 2: Manual feedforward + function feedforward(x1, z2_hot, x3, z4_hot, W, EE1, EE2) + f_z2 = EE1 * z2_hot + f_z4 = EE2 * z4_hot + return finalizer[ind](W * vcat([x1, f_z2, x3, f_z4]...)) + end + + struct ObviousNetwork + W::Any + EE1::Any + EE2::Any + end + + (m::ObviousNetwork)(x1, z2_hot, x3, z4_hot) = + feedforward(x1, z2_hot, x3, z4_hot, m.W, m.EE1, m.EE2) + Flux.@layer ObviousNetwork + + W_before_cp, EE1_before_cp, EE2_before_cp = + deepcopy(W_before), deepcopy(EE1_before), deepcopy(EE2_before) + net = ObviousNetwork(W_before_cp, EE1_before_cp, EE2_before_cp) + + @test feedforward(x1, z2_hot, x3, z4_hot, W_before, EE1_before, EE2_before) ≈ + cat_model(batch) + + ## Option 1: Backward with EntityEmbedder in the network + loss, grads = Flux.withgradient(cat_model) do m + y_pred_cls = m(batch) + losses[ind](y_pred_cls, targets[ind]) + end + optim = Flux.setup(Flux.Adam(10), cat_model) + new_params = Flux.update!(optim, cat_model, grads[1]) + + EE1_after = Flux.params(new_params[1].layers[1].embedders[2].weight)[1] + EE2_after = Flux.params(new_params[1].layers[1].embedders[4].weight)[1] + W_after = Flux.params(new_params[1].layers[2].weight)[1] + + ## Option 2: Backward with ObviousNetwork + loss, grads = Flux.withgradient(net) do m + y_pred_cls = m(x1, z2_hot, x3, z4_hot) + losses[ind](y_pred_cls, targets[ind]) + end + + optim = Flux.setup(Flux.Adam(10), net) + z = Flux.update!(optim, net, grads[1]) + EE1_after_cp = Flux.params(z[1].EE1)[1] + EE2_after_cp = Flux.params(z[1].EE2)[1] + W_after_cp = Flux.params(z[1].W)[1] + @test EE1_after_cp ≈ EE1_after + @test EE2_after_cp ≈ EE2_after + @test W_after_cp ≈ W_after + end +end + + +@testset "Transparent when no categorical variables" begin + entityprops = [] + numfeats = 4 + embedder = MLJFlux.EntityEmbedder(entityprops, 4) + output = embedder(batch) + @test output ≈ batch + @test eltype(output) == Float32 +end + + +@testset "get_embedding_matrices works and has the right dimensions" begin + models = [ + MLJFlux.NeuralNetworkBinaryClassifier, + MLJFlux.NeuralNetworkClassifier, + 
MLJFlux.NeuralNetworkRegressor, + MLJFlux.MultitargetNeuralNetworkRegressor, + ] + + X = ( + Column1 = [1.0, 2.0, 3.0, 4.0, 5.0], + Column2 = categorical(['a', 'b', 'c', 'd', 'e']), + Column3 = categorical(["b", "c", "d", "f", "f"], ordered = true), + Column4 = [1.0, 2.0, 3.0, 4.0, 5.0], + Column5 = randn(5), + Column6 = categorical(["group1", "group1", "group2", "group2", "group3"]), + ) + + y = categorical([0, 1, 0, 1, 1]) + yreg = [0.1, -0.3, 0.2, 0.8, 0.9] + ys = [y, y, yreg, yreg] + + embedding_dims = [ + Dict(:Column2 => 0.5, :Column3 => 2, :Column6 => 0.1), + Dict(:Column2 => 1, :Column3 => 4), + Dict(), + ] + expected_dims = [ + [(3, 5), (2, 4), (1, 3)], + [(1, 5), (4, 4), (2, 3)], + [(4, 5), (3, 4), (2, 3)], + ] + + for j in eachindex(embedding_dims) + for i in eachindex(models) + clf = models[1]( + builder = MLJFlux.Short(n_hidden = 5, dropout = 0.2), + optimiser = Optimisers.Adam(0.01), + batch_size = 8, + epochs = 100, + acceleration = CUDALibs(), + optimiser_changes_trigger_retraining = true, + embedding_dims = embedding_dims[j], + ) + + mach = machine(clf, X, ys[1]) + + fit!(mach, verbosity = 0) + + mapping_matrices = MLJFlux.get_embedding_matrices( + fitted_params(mach).chain, + [2, 3, 6], + [:Column1, :Column2, :Column3, :Column4, :Column5, :Column6], + ) + + embedder_layer = fitted_params(mach).chain.layers[1] + # get_embedding_matrices works + @test mapping_matrices[:Column2] == Flux.params(embedder_layer.embedders[2])[1] + @test mapping_matrices[:Column3] == Flux.params(embedder_layer.embedders[3])[1] + @test mapping_matrices[:Column6] == Flux.params(embedder_layer.embedders[6])[1] + # dimensionalities are correct + @test size(mapping_matrices[:Column2]) == expected_dims[j][1] + @test size(mapping_matrices[:Column3]) == expected_dims[j][2] + @test size(mapping_matrices[:Column6]) == expected_dims[j][3] + end + end +end diff --git a/test/entity_embedding_utils.jl b/test/entity_embedding_utils.jl new file mode 100644 index 00000000..56983a0f --- /dev/null +++ b/test/entity_embedding_utils.jl @@ -0,0 +1,114 @@ + + + +@testset "set_default_new_embedding_dim" begin + @test MLJFlux.set_default_new_embedding_dim(15) == 10 + @test MLJFlux.set_default_new_embedding_dim(9) == 8 +end + +@testset "check_mismatch_in_cat_feats" begin + # Test with no mismatch + featnames = [:a, :b, :c] + cat_inds = [1, 3] + specified_featinds = [1, 3] + @test !MLJFlux.check_mismatch_in_cat_feats(featnames, cat_inds, specified_featinds) + + # Test with mismatch + featnames = [:a, :b, :c] + cat_inds = [1, 3] + specified_featinds = [1, 2, 3] + @test_throws ArgumentError MLJFlux.check_mismatch_in_cat_feats( + featnames, + cat_inds, + specified_featinds, + ) + + # Test with empty specified_featinds + featnames = [:a, :b, :c] + cat_inds = [1, 3] + specified_featinds = [] + @test !MLJFlux.check_mismatch_in_cat_feats(featnames, cat_inds, specified_featinds) + + # Test with empty cat_inds + featnames = [:a, :b, :c] + cat_inds = [] + specified_featinds = [1, 2] + @test_throws ArgumentError MLJFlux.check_mismatch_in_cat_feats( + featnames, + cat_inds, + specified_featinds, + ) +end + +@testset "Testing set_new_embedding_dims" begin + # Test case 1: Correct calculation of embedding dimensions when specified as floats + featnames = ["color", "size", "type"] + cat_inds = [1, 2] + num_levels = [3, 5] + embedding_dims = Dict("color" => 0.5, "size" => 2) + + result = MLJFlux.set_new_embedding_dims(featnames, cat_inds, num_levels, embedding_dims) + @test result == [2, 2] # Expected to be
ceil(1.5) = 2 for "color", and exact 2 for "size" + + # Test case 2: Handling of unspecified dimensions with defaults + embedding_dims = Dict("color" => 0.5) # "size" is not specified + result = MLJFlux.set_new_embedding_dims(featnames, cat_inds, num_levels, embedding_dims) + @test result == [2, MLJFlux.set_default_new_embedding_dim(5)] + + # Test case 3: All embedding dimensions are unspecified, default for all + embedding_dims = Dict() + result = MLJFlux.set_new_embedding_dims(featnames, cat_inds, num_levels, embedding_dims) + @test result == [ + MLJFlux.set_default_new_embedding_dim(3), + MLJFlux.set_default_new_embedding_dim(5), + ] # Default dimensions for both +end + +@testset "test get_cat_inds" begin + X = ( + C1 = [1.0, 2.0, 3.0, 4.0, 5.0], + C2 = ['a', 'b', 'c', 'd', 'e'], + C3 = ["b", "c", "d", "e", "f"], + C4 = [1.0, 2.0, 3.0, 4.0, 5.0], + ) + X = coerce(X, :C1 => OrderedFactor, :C2 => Multiclass, :C3 => Multiclass) + @test MLJFlux.get_cat_inds(X) == [1, 2, 3] +end + +@testset "Number of levels" begin + X = ( + C1 = [1.0, 2.0, 3.0, 4.0, 5.0], + C2 = ['a', 'b', 'c', 'd', 'e'], + C3 = ["b", "c", "d", "f", "f"], + C4 = [1.0, 2.0, 3.0, 4.0, 5.0], + ) + @test MLJFlux.get_num_levels(X, [2]) == [5] + @test MLJFlux.get_num_levels(X, [2, 3]) == [5, 4] +end + + +@testset "Testing prepare_entityembs" begin + X = ( + Column1 = [1.0, 2.0, 3.0, 4.0, 5.0], + Column2 = categorical(['a', 'b', 'c', 'd', 'e']), + Column3 = categorical(["b", "c", "d"]), + Column4 = [1.0, 2.0, 3.0, 4.0, 5.0], + ) + + featnames = [:Column1, :Column2, :Column3, :Column4] + cat_inds = [2, 3] # Assuming categorical columns are 2 and 3 + embedding_dims = Dict(:Column2 => 3, :Column3 => 2) + + entityprops_expected = [ + (index = 2, levels = 5, newdim = 3), + (index = 3, levels = 3, newdim = 2), + ] + output_dim_expected = 3 + 2 + 4 - 2 # Total embedding dims + non-categorical features + + entityprops, entityemb_output_dim = + MLJFlux.prepare_entityembs(X, featnames, cat_inds, embedding_dims) + + @test entityprops == entityprops_expected + @test entityemb_output_dim == output_dim_expected +end + diff --git a/test/mlj_model_interface.jl b/test/mlj_model_interface.jl index 522b059e..79849908 100644 --- a/test/mlj_model_interface.jl +++ b/test/mlj_model_interface.jl @@ -76,8 +76,8 @@ end nobservations = 12 Xuser = rand(Float32, nobservations, 3) yuser = rand(Float32, nobservations) - alpha = rand(rng) - lambda = rand(rng) + alpha = rand(rng, Float32) + lambda = rand(rng, Float32) optimiser = Optimisers.Momentum() builder = MLJFlux.Linear() epochs = 1 # don't change this @@ -94,7 +94,7 @@ end # (2) manually train for one epoch explicitly adding a loss penalty: chain = MLJFlux.build(builder, StableRNG(123), 3, 1); penalty = Penalizer(lambda, alpha); # defined in test_utils.jl - X, y = MLJFlux.collate(model, Xuser, yuser); + X, y = MLJFlux.collate(model, Xuser, yuser, 0); loss = model.loss; n_batches = div(nobservations, batch_size) optimiser_state = Optimisers.setup(optimiser, chain); @@ -121,6 +121,7 @@ end # integration test: X, y = MLJBase.make_regression(10) X = Float32.(MLJBase.Tables.matrix(X)) |> MLJBase.Tables.table + y = Float32.(y) mach = MLJBase.machine(model, X, y) MLJBase.fit!(mach, verbosity=0) losses = MLJBase.training_losses(mach) @@ -148,11 +149,178 @@ end builder = LisasBuilder(10), ) - X, y = @load_boston + X = Tables.table(rand(Float32, 75, 2)) + y = rand(Float32, 75) @test_logs( (:error, MLJFlux.ERR_BUILDER), @test_throws UndefVarError(:Chains) MLJBase.fit(model, 0, X, y) ) end + +@testset "layer does not 
exist for continuous input and transform does nothing" begin + models = [ + MLJFlux.NeuralNetworkBinaryClassifier, + MLJFlux.NeuralNetworkClassifier, + MLJFlux.NeuralNetworkRegressor, + MLJFlux.MultitargetNeuralNetworkRegressor, + ] + # table case + X1 = ( + Column1 = Float32[1.0, 2.0, 3.0, 4.0, 5.0], + Column4 = Float32[1.0, 2.0, 3.0, 4.0, 5.0], + Column5 = randn(Float32, 5), + ) + # matrix case + X2 = rand(Float32, 5, 5) + Xs = [X1, X2] + + y = categorical([0, 1, 0, 1, 1]) + yreg = Float32[0.1, -0.3, 0.2, 0.8, 0.9] + ys = [y, y, yreg, yreg] + for j in eachindex(Xs) + for i in eachindex(models) + clf = models[1]( + builder = MLJFlux.Short(n_hidden = 5, dropout = 0.2, σ = relu), + optimiser = Optimisers.Adam(0.01), + batch_size = 8, + epochs = 100, + acceleration = CUDALibs(), + optimiser_changes_trigger_retraining = true, + ) + + mach = machine(clf, Xs[j], ys[1]) + + fit!(mach, verbosity = 0) + + @test typeof(fitted_params(mach).chain.layers[1][1]) == + typeof(Dense(3 => 5, relu)) + + @test transform(mach, Xs[j]) == Xs[j] + end + end +end + +@testset "transform works properly" begin + # In this test we assumed that get_embedding_weights works + # properly which has been tested. + models = [ + MLJFlux.NeuralNetworkBinaryClassifier, + MLJFlux.NeuralNetworkClassifier, + MLJFlux.NeuralNetworkRegressor, + MLJFlux.MultitargetNeuralNetworkRegressor, + ] + + X = ( + Column1 = Float32[1.0, 2.0, 3.0, 4.0, 5.0], + Column2 = categorical(['a', 'b', 'c', 'd', 'e']), + Column3 = Float32[1.0, 2.0, 3.0, 4.0, 5.0], + Column4 = randn(Float32, 5), + Column5 = categorical(["group1", "group1", "group2", "group2", "group3"]), + ) + + y = categorical([0, 1, 0, 1, 1]) + yreg = Float32[0.1, -0.3, 0.2, 0.8, 0.9] + ys = [y, y, yreg, yreg] + + for i in eachindex(models) + clf = models[1]( + builder = MLJFlux.Short(n_hidden = 5, dropout = 0.2), + optimiser = Optimisers.Adam(0.01), + batch_size = 8, + epochs = 100, + acceleration = CUDALibs(), + optimiser_changes_trigger_retraining = true, + embedding_dims = Dict(:Column2 => 4, :Column5 => 2), + ) + + mach = machine(clf, X, ys[1]) + fit!(mach, verbosity = 0) + Xenc = transform(mach, X) + mat_col2 = + hcat( + [ + collect(Xenc.Column2_1), + collect(Xenc.Column2_2), + collect(Xenc.Column2_3), + collect(Xenc.Column2_4), + ]..., + )' + mat_col5 = hcat( + [ + collect(Xenc.Column5_1), + collect(Xenc.Column5_2), + ]..., + )'[:, [1, 3, 5]] + + mapping_matrices = MLJFlux.get_embedding_matrices( + fitted_params(mach).chain, + [2, 5], + [:Column1, :Column2, :Column3, :Column4, :Column5], + ) + mat_col2_golden = mapping_matrices[:Column2] + mat_col5_golden = mapping_matrices[:Column5] + @test mat_col2 == mat_col2_golden + @test mat_col5 == mat_col5_golden + end +end + +@testset "fit, refit and predict work tests" begin + models = [ + MLJFlux.NeuralNetworkBinaryClassifier, + MLJFlux.NeuralNetworkClassifier, + MLJFlux.NeuralNetworkRegressor, + MLJFlux.MultitargetNeuralNetworkRegressor, + ] + + X = ( + Column1 = Float32[1.0, 2.0, 3.0, 4.0, 5.0], + Column2 = categorical(['a', 'b', 'c', 'd', 'e']), + Column3 = Float32[1.0, 2.0, 3.0, 4.0, 5.0], + Column4 = randn(Float32, 5), + Column5 = categorical(["group1", "group1", "group2", "group2", "group3"]), + ) + + y = categorical([0, 1, 0, 1, 1]) + yreg = Float32[0.1, -0.3, 0.2, 0.8, 0.9] + ys = [y, y, yreg, yreg] + + for i in eachindex(models) + clf = models[1]( + builder = MLJFlux.Short(n_hidden = 5, dropout = 0.2), + optimiser = Optimisers.Adam(0.01), + batch_size = 8, + epochs = 2, + acceleration = CUDALibs(), + 
optimiser_changes_trigger_retraining = true, + embedding_dims = Dict(:Column2 => 4, :Column5 => 2), + ) + + mach = machine(clf, X, ys[1]) + @test_throws MLJBase.NotTrainedError mapping_matrices = + MLJFlux.get_embedding_matrices( + fitted_params(mach).chain, + [2, 5], + [:Column1, :Column2, :Column3, :Column4, :Column5], + ) + fit!(mach, verbosity = 0) + mapping_matrices_fit = MLJFlux.get_embedding_matrices( + fitted_params(mach).chain, + [2, 5], + [:Column1, :Column2, :Column3, :Column4, :Column5], + ) + clf.epochs = clf.epochs + 3 + clf.optimiser = Optimisers.Adam(clf.optimiser.eta / 2) + fit!(mach, verbosity = 0) + mapping_matrices_double_fit = MLJFlux.get_embedding_matrices( + fitted_params(mach).chain, + [2, 5], + [:Column1, :Column2, :Column3, :Column4, :Column5], + ) + @test mapping_matrices_fit != mapping_matrices_double_fit + # Try model prediction + Xpred = predict(mach, X) + end +end + true diff --git a/test/regressor.jl b/test/regressor.jl index d17a3760..0b9c5852 100644 --- a/test/regressor.jl +++ b/test/regressor.jl @@ -1,13 +1,23 @@ Random.seed!(123) N = 200 -X = MLJBase.table(randn(Float32, N, 5)); +Xm = MLJBase.table(randn(Float32, N, 5)); # purely numeric +X = (; Tables.columntable(Xm)..., + Column1 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)), + Column2 = categorical(repeat(['a', 'b', 'c', 'd', 'e'], Int(N / 5))), + Column3 = categorical(repeat(["b", "c", "d", "f", "f"], Int(N / 5)), ordered = true), + Column4 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)), + Column5 = randn(Float32, N), + Column6 = categorical( + repeat(["group1", "group1", "group2", "group2", "group3"], Int(N / 5)), + ), +) -builder = MLJFlux.Short(σ=identity) +builder = MLJFlux.Short(σ = identity) optimiser = Optimisers.Adam() Random.seed!(123) -y = 1 .+ X.x1 - X.x2 .- 2X.x4 + X.x5 +y = Float32(1) .+ X.x1 - X.x2 .- 2X.x4 + X.x5 train, test = MLJBase.partition(1:N, 0.7) @testset_accelerated "NeuralNetworkRegressor" accel begin @@ -25,11 +35,12 @@ train, test = MLJBase.partition(1:N, 0.7) ) end + # Matrix input: @testset "Matrix input" begin @test basictest( MLJFlux.NeuralNetworkRegressor, - matrix(X), + matrix(Xm), y, builder, optimiser, @@ -42,16 +53,16 @@ train, test = MLJBase.partition(1:N, 0.7) # (GPUs only support `default_rng` when there's `Dropout`): rng = Random.default_rng() seed!(rng, 123) - model = MLJFlux.NeuralNetworkRegressor(builder=builder, - acceleration=accel, - rng=rng) + model = MLJFlux.NeuralNetworkRegressor(builder = builder, + acceleration = accel, + rng = rng) @time fitresult, _, rpt = fit(model, 0, MLJBase.selectrows(X, train), y[train]) first_last_training_loss = rpt[1][[1, end]] -# @show first_last_training_loss + # @show first_last_training_loss yhat = predict(model, fitresult, selectrows(X, test)) truth = y[test] - goal = 0.9*model.loss(truth .- mean(truth), 0) + goal = 0.9 * model.loss(truth .- mean(truth), 0) @test model.loss(yhat, truth) < goal end @@ -73,11 +84,23 @@ y = MLJBase.table(ymatrix); accel, ) end + + @testset "Table input numerical" begin + @test basictest( + MLJFlux.MultitargetNeuralNetworkRegressor, + Xm, + y, + builder, + optimiser, + 1.0, + accel, + ) + end # Matrix input: @testset "Matrix input" begin @test basictest( MLJFlux.MultitargetNeuralNetworkRegressor, - matrix(X), + matrix(Xm), ymatrix, builder, optimiser, @@ -91,16 +114,16 @@ y = MLJBase.table(ymatrix); rng = Random.default_rng() seed!(rng, 123) model = MLJFlux.MultitargetNeuralNetworkRegressor( - acceleration=accel, - builder=builder, - rng=rng, + acceleration = accel, + 
builder = builder, + rng = rng, ) @time fitresult, _, rpt = fit(model, 0, MLJBase.selectrows(X, train), selectrows(y, train)) first_last_training_loss = rpt[1][[1, end]] yhat = predict(model, fitresult, selectrows(X, test)) - truth = ymatrix[test,:] - goal = 0.85*model.loss(truth .- mean(truth), 0) + truth = ymatrix[test, :] + goal = 0.85 * model.loss(truth .- mean(truth), 0) @test model.loss(Tables.matrix(yhat), truth) < goal end diff --git a/test/runtests.jl b/test/runtests.jl index b7b11d66..a6ed78ea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using Tables using MLJBase import MLJFlux using CategoricalArrays +using MLJBase using ColorTypes using Flux import Random @@ -13,6 +14,7 @@ using StableRNGs using CUDA, cuDNN import StatisticalMeasures import Optimisers +import Logging using ComputationalResources using ComputationalResources: CPU1, CUDALibs @@ -74,3 +76,9 @@ end @conditional_testset "integration" begin include("integration.jl") end + +@conditional_testset "entity embedding" begin + include("entity_embedding.jl") + include("entity_embedding_utils.jl") + include("encoders.jl") +end \ No newline at end of file
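
The new test files above exercise the entity-embedding workflow end to end: categorical input features are ordinal-encoded, passed through an `EntityEmbedder` layer, and surfaced to users via the `embedding_dims` hyperparameter and the `transform` method. Below is a minimal usage sketch of that workflow, based only on the behaviour asserted in the tests; the data, column names, and hyperparameter values are illustrative and not taken from the test suite.

```julia
# Minimal sketch of the entity-embedding workflow exercised by the tests above
# (assumes the `embedding_dims` keyword and `transform` behaviour shown there).
using MLJBase, MLJFlux, CategoricalArrays
import Optimisers

# A table mixing continuous and categorical features:
X = (
    age    = Float32[23, 45, 34, 25, 67],
    colour = categorical(['r', 'g', 'b', 'r', 'g']),  # 3 levels
)
y = categorical([0, 1, 0, 1, 1])

clf = MLJFlux.NeuralNetworkClassifier(
    builder = MLJFlux.Short(n_hidden = 5, dropout = 0.2),
    optimiser = Optimisers.Adam(0.01),
    epochs = 20,
    embedding_dims = Dict(:colour => 2),  # embed the 3 levels into 2 dimensions
)

mach = machine(clf, X, y)
fit!(mach, verbosity = 0)

# `transform` replaces `colour` with the learned embedding columns
# (e.g. `colour_1`, `colour_2`) and leaves numeric columns untouched:
Xembedded = transform(mach, X)
```
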