From 45836bab600ab0e28e2d55007a09427e12566f57 Mon Sep 17 00:00:00 2001 From: Logan Bishop-Van Horn Date: Thu, 21 Sep 2023 07:47:44 -0700 Subject: [PATCH] Update Parameter.__eq__ --- tdgl/parameter.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tdgl/parameter.py b/tdgl/parameter.py index db65daa..ee2061c 100644 --- a/tdgl/parameter.py +++ b/tdgl/parameter.py @@ -241,7 +241,25 @@ def __eq__(self, other) -> bool: if self.func.__code__ != other.func.__code__: return False - return self.kwargs == other.kwargs + if set(self.kwargs) != set(other.kwargs): + return False + + def array_safe_equals(a, b) -> bool: + """Check if a and b are equal, even if they are numpy arrays.""" + if a is b: + return True + if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + return a.shape == b.shape and np.allclose(a, b) + try: + return a == b + except TypeError: + return NotImplemented + + for key in self.kwargs: + if not array_safe_equals(self.kwargs[key], other.kwargs[key]): + return False + + return True class CompositeParameter(Parameter):