Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add __array_ufunc__ and __array__function__ to ArrayIndex #96

Merged
10 changes: 10 additions & 0 deletions ndindex/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ def _typecheck(self, idx, shape=None, _copy=True):
return (a,)
raise TypeError(f"{self.__class__.__name__} must be created with an array with dtype {self.dtype.__name__}")

# These will allow array == ArrayIndex to give True or False instead of
# returning an array.
__array_ufunc__ = None
def __array_function__(self, func, types, args, kwargs):
return NotImplemented

def __array__(self):
raise TypeError(f"Cannot convert {self.__class__.__name__} to an array. Use .array instead.")


@property
def raw(self):
return self.args[0]
Expand Down
6 changes: 6 additions & 0 deletions ndindex/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ def test_attributes():
assert idx.shape == a.shape == (2, 2)
assert idx.size == a.size == 4

def test_cast_raises():
with raises(TypeError):
a = array([[0, 1], [1, 0]])
idx = IntegerArray(a)
assert array(idx) == idx

def test_copy():
idx = IntegerArray([1, 2])
idx2 = IntegerArray(idx.raw)
Expand Down
42 changes: 27 additions & 15 deletions ndindex/tests/test_ndindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,28 @@
def test_eq(idx):
index = ndindex(idx)
new = type(index)(*index.args)
assert (new == index) is True
try:
if isinstance(new.raw, np.ndarray):
raise ValueError
assert (new.raw == index.raw) is True
assert (index.raw == index) is True
except ValueError:
np.testing.assert_equal(new.raw, index.raw)
# Sadly, there is now way to bypass array.__eq__ from producing an
# array.

if isinstance(new.raw, np.ndarray):
# trying to get a single value out of comparing two arrays requires all sorts of special handling, just let numpy do it
assert np.array_equal(new.raw, index.raw)
else:
(new.raw == index.raw)

assert (new == index)
assert (new.raw == index)
assert (new == index.raw)
assert (index == new)
assert (index.raw == new)
assert (index == new.raw)

assert (index.raw == index)
assert hash(new) == hash(index)
assert (index == index.raw)
assert not (index == 'a')
assert not ('a' == index)
assert (index != 'a')
assert ('a' != index)

try:
h = hash(idx)
except TypeError:
Expand All @@ -42,11 +53,12 @@ def test_eq(idx):
else:
assert hash(index) == h

assert (index == index.raw) is True
assert (index == 'a') is False
assert ('a' == index) is False
assert (index != 'a') is True
assert ('a' != index) is True
def test_eq_array_raises():
index = ndindex([1, 2, 3])
with raises(TypeError):
np.equal(index.raw, index)
with raises(TypeError):
np.array_equal(index.raw, index)

def test_eq_explicit():
assert Integer(0) != False
Expand Down