Skip to content

Commit

Permalink
Update Parameter.__eq__
Browse files Browse the repository at this point in the history
  • Loading branch information
loganbvh committed Sep 21, 2023
1 parent 996112d commit 45836ba
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion tdgl/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,25 @@ def __eq__(self, other) -> bool:
if self.func.__code__ != other.func.__code__:
return False

return self.kwargs == other.kwargs
if set(self.kwargs) != set(other.kwargs):
return False

def array_safe_equals(a, b) -> bool:
"""Check if a and b are equal, even if they are numpy arrays."""
if a is b:
return True
if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
return a.shape == b.shape and np.allclose(a, b)
try:
return a == b
except TypeError:
return NotImplemented

for key in self.kwargs:
if not array_safe_equals(self.kwargs[key], other.kwargs[key]):
return False

return True


class CompositeParameter(Parameter):
Expand Down

0 comments on commit 45836ba

Please sign in to comment.