-
Notifications
You must be signed in to change notification settings - Fork 0
/
embeddingLayer.m
60 lines (51 loc) · 2.11 KB
/
embeddingLayer.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
classdef embeddingLayer < nnet.layer.Layer & ...
        nnet.layer.Formattable
    % embeddingLayer   Custom layer that maps numeric indices to learnable
    % embedding vectors using the dlarray EMBED function.

    properties (Learnable)
        % Layer learnable parameters.

        % Weights - Embedding matrix of size
        % embeddingDimension-by-inputDimension (see the constructor).
        % Per MATLAB's EMBED documentation, column k is the vector
        % returned for index k.
        Weights
    end

    methods
        function layer = embeddingLayer(embeddingDimension, inputDimension, NameValueArgs)
            % layer = embeddingLayer(embeddingDimension,inputDimension)
            % creates an embeddingLayer object whose learnable Weights
            % matrix has embeddingDimension rows and inputDimension
            % columns. At prediction time each input index is replaced by
            % the corresponding embedding vector of length
            % embeddingDimension.
            %
            % layer = embeddingLayer(embeddingDimension,inputDimension,Name=name)
            % also specifies the layer name.

            % Parse input arguments.
            arguments
                embeddingDimension
                inputDimension
                NameValueArgs.Name = "";
            end

            % Set layer name.
            layer.Name = NameValueArgs.Name;

            % Set layer description.
            layer.Description = "Embedding layer with dimension " + embeddingDimension;

            % Initialize embedding weights. initializeGaussian is a helper
            % defined outside this file; presumably it draws values from a
            % normal distribution with mean mu and standard deviation sigma
            % (verify against its definition). Weights are assigned exactly
            % once, so no throwaway matrix is allocated and the global
            % random stream is not consumed by a discarded draw.
            sz = [embeddingDimension inputDimension];
            mu = 0;
            sigma = 0.01;
            layer.Weights = initializeGaussian(sz,mu,sigma);
        end

        function Z = predict(layer, X)
            % Forward input data through the layer at prediction time and
            % output the result.
            %
            % Inputs:
            %     layer - Layer to forward propagate through
            %     X     - Numeric indices, specified as a formatted
            %             dlarray with a "C" and optionally a "B"
            %             dimension.
            % Outputs:
            %     Z     - Output of layer forward function returned as
            %             a dlarray with format "CB".

            % Embedding: look up each index of X in the columns of Weights.
            Z = embed(X, layer.Weights);
        end
    end
end