Skip to content

Commit

Permalink
Fix shape of input and mask in the Mask layer, and adjust test.
Browse files Browse the repository at this point in the history
  • Loading branch information
APJansen committed Feb 20, 2024
1 parent b337bf7 commit c2e4935
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
27 changes: 18 additions & 9 deletions n3fit/src/n3fit/layers/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class Mask(MetaLayer):
"""
This layers applies a boolean mask to a rank-1 input tensor.
This layers applies a boolean mask to an input tensor.
The mask admit a multiplier for all outputs which will be internally
saved as a weight so it can be updated during trainig.
Expand All @@ -16,7 +16,7 @@ class Mask(MetaLayer):
Parameters
----------
bool_mask: np.array
bool_mask: np.array of shape (n_replicas, n_features)
numpy array with the boolean mask to be applied
c: float
constant multiplier for every output
Expand All @@ -28,10 +28,7 @@ def __init__(self, bool_mask=None, c=None, **kwargs):
self.last_dim = -1
else:
self.mask = op.numpy_to_tensor(bool_mask, dtype=bool)
if len(bool_mask.shape) == 1:
self.last_dim = count_nonzero(bool_mask)
else:
self.last_dim = count_nonzero(bool_mask[0, ...])
self.last_dim = count_nonzero(bool_mask[0, ...])
self.c = c
self.masked_output_shape = None
super().__init__(**kwargs)
Expand All @@ -42,12 +39,24 @@ def build(self, input_shape):
self.kernel = self.builder_helper("mask", (1,), initializer, trainable=False)
# Make sure reshape will succeed: set the last dimension to the unmasked data length and before-last to
# the number of replicas
self.masked_output_shape = [-1 if d is None else d for d in input_shape]
self.masked_output_shape[-1] = self.last_dim
self.masked_output_shape[-2] = self.mask.shape[-2]
if self.mask is not None:
self.masked_output_shape = [-1 if d is None else d for d in input_shape]
self.masked_output_shape[-1] = self.last_dim
self.masked_output_shape[-2] = self.mask.shape[-2]
super(Mask, self).build(input_shape)

def call(self, ret):
"""
Apply the mask to the input tensor, and multiply by the constant if present.
Parameters
----------
ret: Tensor of shape (batch_size, n_replicas, n_features)
Returns
-------
Tensor of shape (batch_size, n_replicas, n_features)
"""
if self.mask is not None:
flat_res = op.boolean_mask(ret, self.mask, axis=1)
ret = op.reshape(flat_res, shape=self.masked_output_shape)
Expand Down
9 changes: 5 additions & 4 deletions n3fit/src/n3fit/tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,19 +240,20 @@ def test_rotation_evol():

def test_mask():
"""Test the mask layer"""
SIZE = 100
fi = np.random.rand(SIZE)
batch_size, replicas, points = 1, 1, 100
shape = (batch_size, replicas, points)
fi = np.random.rand(*shape)
# Check that the multiplier works
vals = [0.0, 2.0, np.random.rand()]
for val in vals:
masker = layers.Mask(c=val)
ret = masker(fi)
np.testing.assert_allclose(ret, val * fi, rtol=1e-5)
# Check that the boolean works
np_mask = np.random.randint(0, 2, size=SIZE, dtype=bool)
np_mask = np.random.randint(0, 2, size=shape[1:], dtype=bool)
masker = layers.Mask(bool_mask=np_mask)
ret = masker(fi)
masked_fi = fi[np_mask]
masked_fi = fi[np.newaxis, :, np_mask]
np.testing.assert_allclose(ret, masked_fi, rtol=1e-5)
# Check that the combination works!
rn_val = vals[-1]
Expand Down

0 comments on commit c2e4935

Please sign in to comment.