From d7e999b685c9318c432d092303d669dcc2d3f449 Mon Sep 17 00:00:00 2001 From: Jonathan King Date: Fri, 15 Nov 2024 14:15:46 -0700 Subject: [PATCH] fix nodata type checking following numpy 2+ --- pysheds/sview.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/pysheds/sview.py b/pysheds/sview.py index 2b410c4..c6f51ec 100644 --- a/pysheds/sview.py +++ b/pysheds/sview.py @@ -81,10 +81,8 @@ def __new__(cls, input_array, viewfinder=None, metadata={}): assert not np.issubdtype(obj.dtype, np.flexible) except: raise TypeError('`object` and `flexible` dtypes not allowed.') - try: - assert np.can_cast(viewfinder.nodata, obj.dtype, casting='safe') - except: - raise TypeError('`nodata` value not representable in dtype of array.') + cls._validate_nodata(viewfinder.nodata, obj.dtype) + # Don't allow original viewfinder and metadata to be modified viewfinder = viewfinder.copy() metadata = metadata.copy() @@ -117,6 +115,16 @@ def _handle_raster_input(cls, input_array, viewfinder, metadata): new_metadata=metadata) return input_array, viewfinder, metadata + @staticmethod + def _validate_nodata(nodata, dtype): + "Checks the NoData value is preserved when cast to the raster dtype" + nodata = np.array(nodata) + casted = nodata.astype(dtype, casting='unsafe') + try: + assert (nodata == casted) or np.can_cast(nodata, dtype, casting='safe') + except: + raise TypeError('`nodata` value not representable in dtype of array.') + @property def viewfinder(self): return self._viewfinder @@ -287,10 +295,7 @@ def __new__(cls, input_array, viewfinder=None, metadata={}): assert not np.issubdtype(obj.dtype, np.flexible) except: raise TypeError('`object` and `flexible` dtypes not allowed.') - try: - assert np.can_cast(viewfinder.nodata, obj.dtype, casting='safe') - except: - raise TypeError('`nodata` value not representable in dtype of array.') + cls._validate_nodata(viewfinder.nodata, obj.dtype) # Don't allow original viewfinder and metadata to be modified viewfinder = viewfinder.copy() metadata = metadata.copy()