Skip to content

Commit

Permalink
Restrict the inputs of ops.pad() (#20348)
Browse files Browse the repository at this point in the history
  • Loading branch information
Grvzard authored Oct 13, 2024
1 parent 95ca6a6 commit e0533f8
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4348,6 +4348,13 @@ def _process_pad_width(self, pad_width):
return pad_width

def call(self, x, constant_values=None):
if len(self.pad_width) > 1 and len(self.pad_width) != len(x.shape):
raise ValueError(
"`pad_width` must have the same length as `x.shape`. "
f"Received: pad_width={self.pad_width} "
f"(of length {len(self.pad_width)}) and x.shape={x.shape} "
f"(of length {len(x.shape)})"
)
return backend.numpy.pad(
x,
pad_width=self.pad_width,
Expand Down

0 comments on commit e0533f8

Please sign in to comment.