diff --git a/spectral/__init__.py b/spectral/__init__.py index a8f5046..1452032 100644 --- a/spectral/__init__.py +++ b/spectral/__init__.py @@ -2,6 +2,8 @@ Basic package setup and global imports. ''' +# flake8: noqa + from __future__ import absolute_import, division, print_function, unicode_literals __version__ = '0.23.1' diff --git a/spectral/algorithms/__init__.py b/spectral/algorithms/__init__.py index f90d58d..796d6dd 100644 --- a/spectral/algorithms/__init__.py +++ b/spectral/algorithms/__init__.py @@ -1,10 +1,12 @@ +# flake8: noqa + from __future__ import absolute_import, division, print_function, unicode_literals from .algorithms import (mean_cov, covariance, principal_components, bdist, - linear_discriminant, create_training_classes, ndvi, - orthogonalize, transform_image, unmix, spectral_angles, - calc_stats, cov_avg, msam, noise_from_diffs, mnf, - GaussianStats, ppi, smacc) + linear_discriminant, create_training_classes, ndvi, + orthogonalize, transform_image, unmix, spectral_angles, + calc_stats, cov_avg, msam, noise_from_diffs, mnf, + GaussianStats, ppi, smacc) from .classifiers import * from .clustering import L1, L2, kmeans from .resampling import BandResampler diff --git a/spectral/algorithms/algorithms.py b/spectral/algorithms/algorithms.py index 567a5c5..1f2852d 100644 --- a/spectral/algorithms/algorithms.py +++ b/spectral/algorithms/algorithms.py @@ -15,6 +15,7 @@ from .spymath import matrix_sqrt from .transforms import LinearTransform + class Iterator: ''' Base class for iterators over pixels (spectra). @@ -50,7 +51,6 @@ def get_num_bands(self): def __iter__(self): (M, N) = self.image.shape[:2] - count = 0 for i in range(M): self.row = i for j in range(N): @@ -86,6 +86,7 @@ def __iter__(self): (self.row, self.col) = (i, j) yield self.image[i, j].astype(self.image.dtype).squeeze() + def iterator(image, mask=None, index=None): ''' Returns an iterator over pixels in the image. @@ -292,6 +293,7 @@ def cov_avg(image, mask, weighted=True): else: return np.mean([c.cov for c in classes], axis=0, dtype=np.float64) + def covariance(*args): ''' Returns the covariance of the set of vectors. @@ -641,7 +643,6 @@ class covariances, mean vector, and a callable transform to convert data to Richards, J.A. & Jia, X. Remote Sensing Digital Image Analysis: An Introduction. (Springer: Berlin, 1999). ''' - C = len(classes) # Number of training sets rank = len(classes) - 1 classes.calc_stats() @@ -677,13 +678,13 @@ class covariances, mean vector, and a callable transform to convert data to return FisherLinearDiscriminant(vals.real, vecs.real, mean, cov_b, cov_w) + # Alias for Linear Discriminant Analysis (LDA) lda = linear_discriminant def log_det(x): - return sum(np.log([eigv for eigv in np.linalg.eigvals(x) - if eigv > 0])) + return sum(np.log([eigv for eigv in np.linalg.eigvals(x) if eigv > 0])) class GaussianStats(object): @@ -933,7 +934,6 @@ def size(self): else: return np.sum(np.not_equal(self.mask, 0).ravel()) - def calc_stats(self): ''' Calculates statistics for the class. @@ -1050,13 +1050,13 @@ def calc_stats(self): def save(self, filename, calc_stats=False): for c in list(self.classes.values()): if c.stats is None: - if calc_stats == False: + if calc_stats is False: msg = 'Class statistics are missing from at least one ' \ 'class and are required to save the training class ' \ 'data. Call the `save` method with keyword ' \ '`calc_stats=True` if you want to compute them and ' \ 'then save the class data.' - raise Exception (msg) + raise Exception(msg) else: c.calc_stats() f = open(filename, 'wb') @@ -1195,6 +1195,7 @@ def bdist(class1, class2): terms = bdist_terms(class1, class2) return terms[0] + terms[1] + bDistance = bdist @@ -1364,6 +1365,7 @@ def spectral_angles(data, members): dots = np.clip(dots / norms[:, :, np.newaxis], -1, 1) return np.arccos(dots) + def msam(data, members): '''Modified SAM scores according to Oshigami, et al [1]. Endmembers are mean-subtracted prior to spectral angle calculation. Results are @@ -1417,7 +1419,7 @@ def msam(data, members): for i in range(M): for j in range(N): - #Fisher z trafo type operation + # Fisher z trafo type operation v = data[i, j] - np.mean(data[i, j]) v /= np.sqrt(v.dot(v)) v = np.clip(v, -1, 1) @@ -1425,9 +1427,10 @@ def msam(data, members): # Calculate Mineral Index according to Oshigami et al. # (Intnl. J. of Remote Sens. 2013) a = np.clip(v.dot(m[k]), -1, 1) - angles[i,j,k]= 1.0 - np.arccos(a) / (math.pi / 2) + angles[i, j, k] = 1.0 - np.arccos(a) / (math.pi / 2) return angles + def noise_from_diffs(X, direction='lowerright'): '''Estimates noise statistcs by taking differences of adjacent pixels. @@ -1469,6 +1472,7 @@ def noise_from_diffs(X, direction='lowerright'): stats.cov /= 2.0 return stats + class MNFResult(object): '''Result object returned by :func:`~spectral.algorithms.algorithms.mnf`. @@ -1504,7 +1508,7 @@ def _num_from_kwargs(self, **kwargs): raise Exception('Keyword not recognized.') num = kwargs.get('num', None) snr = kwargs.get('snr', None) - if num == snr == None: + if num == snr is None: raise Exception('Must specify either `num` or `snr` keyword.') if None not in (num, snr): raise Exception('Can not specify both `num` and `snr` keywords.') @@ -1563,8 +1567,8 @@ def get_denoising_transform(self, **kwargs): V = self.napc.eigenvectors Vr = np.array(V) Vr[:, N:] = 0. - f = LinearTransform(self.noise.sqrt_cov.dot(Vr).dot(V.T) \ - .dot(self.noise.sqrt_inv_cov), + f = LinearTransform(self.noise.sqrt_cov.dot(Vr).dot(V.T) + .dot(self.noise.sqrt_inv_cov), pre=-self.signal.mean, post=self.signal.mean) return f @@ -1626,6 +1630,7 @@ def num_with_snr(self, snr): '''Returns the number of components with SNR >= `snr`.''' return np.sum(self.napc.eigenvalues >= (snr + 1)) + def mnf(signal, noise): '''Computes Minimum Noise Fraction / Noise-Adjusted Principal Components. @@ -1686,6 +1691,7 @@ def mnf(signal, noise): napc = PrincipalComponents(L, V, wstats) return MNFResult(signal, noise, napc) + def ppi(X, niters, threshold=0, centered=False, start=None, display=0, **imshow_kwargs): '''Returns pixel purity indices for an image. @@ -1759,7 +1765,7 @@ def ppi(X, niters, threshold=0, centered=False, start=None, display=0, ''' if display is not None: if not isinstance(display, Integral) or isinstance(display, bool) or \ - display < 0: + display < 0: msg = '`display` argument must be a non-negative integer.' raise ValueError(msg) @@ -1941,7 +1947,7 @@ def smacc(spectra, min_endmembers=None, max_residual_norm=float('Inf')): residual_norms[:] = np.sqrt(np.einsum('ij,ij->i', R, R)) current_max_residual_norm = np.max(residual_norms) print('Found {0} endmembers, current max residual norm is {1:.4f}\r' - .format(len(q), current_max_residual_norm), end='') + .format(len(q), current_max_residual_norm), end='') # Correction as suggested in the SMACC paper. for k, s in enumerate(q): diff --git a/spectral/algorithms/classifiers.py b/spectral/algorithms/classifiers.py index 1e51f7e..f0aa781 100644 --- a/spectral/algorithms/classifiers.py +++ b/spectral/algorithms/classifiers.py @@ -8,8 +8,6 @@ import math import numpy as np -from warnings import warn - import spectral as spy from .algorithms import GaussianStats, ImageIterator from .detectors import RX @@ -18,6 +16,7 @@ __all__ = ('GaussianClassifier', 'MahalanobisDistanceClassifier', 'PerceptronClassifier') + class Classifier(object): ''' Base class for Classifiers. Child classes must implement the @@ -149,9 +148,9 @@ def classify_spectrum(self, x): for (i, cl) in enumerate(self.classes): delta = (x - cl.stats.mean) scores[i] = math.log(cl.class_prob) - 0.5 * cl.stats.log_det_cov \ - - 0.5 * delta.dot(cl.stats.inv_cov).dot(delta) + - 0.5 * delta.dot(cl.stats.inv_cov).dot(delta) return self.classes[np.argmax(scores)].index - + def classify_image(self, image): '''Classifies an entire image, returning a classification map. @@ -301,7 +300,7 @@ class PerceptronClassifier(Perceptron, SupervisedClassifier): >>> classes = create_training_classes(xdata, gt) >>> nfeatures = xdata.shape[-1] >>> nclasses = len(classes) - >>> + >>> >>> p = PerceptronClassifier([nfeatures, 20, 8, nclasses]) >>> p.train(classes, 20, clip=0., accuracy=100., batch=1, >>> momentum=0.3, rate=0.3) @@ -407,8 +406,8 @@ class in `training_data`. If this argument is not provided, if class_data[i].shape[0] > samples_per_class: class_data[i] = class_data[i][:samples_per_class] X = np.vstack(class_data) - y = np.hstack([np.ones(c.shape[0], dtype=np.int16) * i for \ - (i, c) in enumerate(class_data)]) + y = np.hstack([np.ones(c.shape[0], dtype=np.int16) * i for + (i, c) in enumerate(class_data)]) Y = np.eye(np.max(y) + 1, dtype=np.int16)[y] if 'stdout' in kwargs: @@ -441,4 +440,3 @@ def classify_spectrum(self, x): def classify(self, X, **kwargs): return Classifier.classify(self, X, **kwargs) - diff --git a/spectral/algorithms/clustering.py b/spectral/algorithms/clustering.py index a24b88c..0e574d1 100644 --- a/spectral/algorithms/clustering.py +++ b/spectral/algorithms/clustering.py @@ -8,9 +8,9 @@ import numpy as np import spectral as spy -from .classifiers import Classifier from ..utilities.errors import has_nan, NaNValueError + def L1(v1, v2): 'Returns L1 distance between 2 rank-1 arrays.' return np.sum(abs((v1 - v2))) @@ -360,4 +360,3 @@ def kmeans_ndarray(image, nclusters=10, max_iterations=20, **kwargs): logger.info('kmeans terminated with %d clusters after %d iterations.', len(set(old_clusters.ravel())), itnum - 1) return (old_clusters.reshape(nrows, ncols), centers) - diff --git a/spectral/algorithms/continuum.py b/spectral/algorithms/continuum.py index 17e6624..8e9808b 100644 --- a/spectral/algorithms/continuum.py +++ b/spectral/algorithms/continuum.py @@ -24,12 +24,8 @@ from __future__ import absolute_import, division, print_function, unicode_literals -import logging import numpy as np -import spectral as spy -from ..utilities.errors import has_nan, NaNValueError - def _segment_concave_region(spectrum, bands, indices, ind_fill, ibegin, iend): # Here we don't search for local maxima w.r.t. line that connects ends of this region. diff --git a/spectral/algorithms/detectors.py b/spectral/algorithms/detectors.py index 8cc9470..d6e1427 100644 --- a/spectral/algorithms/detectors.py +++ b/spectral/algorithms/detectors.py @@ -12,7 +12,6 @@ from .algorithms import calc_stats from .transforms import LinearTransform from .spatial import map_outer_window_stats -from .spymath import matrix_sqrt class MatchedFilter(LinearTransform): @@ -77,6 +76,7 @@ def whiten(self, X): self._whitening_transform = LinearTransform(A, pre=-self.u_b) return self._whitening_transform(X) + def matched_filter(X, target, background=None, window=None, cov=None): r'''Computes a linear matched filter target detector score. @@ -85,7 +85,7 @@ def matched_filter(X, target, background=None, window=None, cov=None): y = matched_filter(X, target, background) y = matched_filter(X, target, window= [, cov=]) - + Given target/background means and a common covariance matrix, the matched filter response is given by: @@ -144,7 +144,7 @@ def matched_filter(X, target, background=None, window=None, cov=None): window will cover `image[:21, :21]`. For the pixel at (50, 1), the inner window will cover `image[48:53, :4]` and the outer window will cover `image[40:51, :21]`. - + `cov` (ndarray): An optional covariance to use. If this parameter is given, `cov` @@ -192,7 +192,7 @@ class RX(): pattern with unknown spectral distribution," IEEE Trans. Acoust., Speech, Signal Processing, vol. 38, pp. 1760-1770, Oct. 1990. ''' - dim_out=1 + dim_out = 1 def __init__(self, background=None): '''Creates the detector, given optional background/target stats. @@ -272,6 +272,7 @@ def __call__(self, X): # raise Exception('Unexpected number of dimensions.') # + def rx(X, background=None, window=None, cov=None): r'''Computes RX anomaly detector scores. @@ -281,7 +282,7 @@ def rx(X, background=None, window=None, cov=None): y = rx(X, window=(inner, outer) [, cov=C]) - The RX anomaly detector produces a detection statistic equal to the + The RX anomaly detector produces a detection statistic equal to the squared Mahalanobis distance of a spectrum from a background distribution according to @@ -337,7 +338,7 @@ def rx(X, background=None, window=None, cov=None): window will cover `image[:21, :21]`. For the pixel at (50, 1), the inner window will cover `image[48:53, :4]` and the outer window will cover `image[40:51, :21]`. - + `cov` (ndarray): An optional covariance to use. If this parameter is given, `cov` @@ -350,7 +351,7 @@ def rx(X, background=None, window=None, cov=None): The return value will be the RX detector score (squared Mahalanobis distance) for each pixel given. If `X` has shape (R, C, B), the returned ndarray will have shape (R, C).. - + References: Reed, I.S. and Yu, X., "Adaptive multiple-band CFAR detection of an optical @@ -362,14 +363,17 @@ def rx(X, background=None, window=None, cov=None): 'exclusive.') if window is not None: rx = RX() + def rx_wrapper(bg, x): rx.set_background(bg) return rx(x) + return map_outer_window_stats(rx_wrapper, X, window[0], window[1], dim_out=1, cov=cov) else: return RX(background)(X) + class ACE(): r'''Adaptive Coherence/Cosine Estimator (ACE). ''' @@ -412,7 +416,7 @@ def __init__(self, target, background=None, **kwargs): self.vectorize = kwargs.get('vectorize', True) self._target = None self._background = None - + self.set_target(target) if background is not None: self.set_background(background) @@ -446,7 +450,7 @@ def set_background(self, stats): '''Sets background statistics to be used when applying the detector. Arguments: - + `stats` (`GaussianStats`): The Gaussian statistics for the background (e.g., the result @@ -469,7 +473,7 @@ def _update_constants(self): else: self._C = None self._P = None - + def __call__(self, X): '''Compute ACE detector scores for X. @@ -489,7 +493,7 @@ def __call__(self, X): of floats with one less dimension than the input. ''' if not isinstance(X, np.ndarray): - raise TypeError('Expected a numpy.ndarray.') + raise TypeError('Expected a numpy.ndarray.') shape = X.shape @@ -505,7 +509,7 @@ def __call__(self, X): if self.vectorize: # Compute all scores at once - + if self._background.mean is not None: X = X - self._background.mean @@ -532,7 +536,7 @@ def ace(X, target, background=None, window=None, cov=None, **kwargs): y = ace(X, target, background) y = ace(X, target, window= [, cov=]) - + Arguments: `X` (numpy.ndarray): @@ -558,14 +562,14 @@ def ace(X, target, background=None, window=None, cov=None, **kwargs): An ndarray with shape (D, B). In this case, `target` contains `D` length-B targets that define a subspace for the detector. The return value will be an ndarray with shape (R, C). - + A length-D sequence (e.g., list or tuple) of length-B ndarrays. In this case, the detector will be applied seperately to each of the `D` targets. This is equivalent to calling the function sequentially for each target and stacking the results but is much faster. The return value will be an ndarray with shape (R, C, D). - + `background` (`GaussianStats`): The Gaussian statistics for the background (e.g., the result @@ -596,7 +600,7 @@ def ace(X, target, background=None, window=None, cov=None, **kwargs): window will cover `image[:21, :21]`. For the pixel at (50, 1), the inner window will cover `image[48:53, :4]` and the outer window will cover `image[40:51, :21]`. - + `cov` (ndarray): An optional covariance to use. If this parameter is given, `cov` @@ -640,9 +644,11 @@ def ace(X, target, background=None, window=None, cov=None, **kwargs): # Separate score arrays for each target in target list if background is None: detector.set_background(calc_stats(X)) + def apply_to_target(t): detector.set_target(t) return detector(X) + result = np.array([apply_to_target(t) for t in target]) if result.ndim == 3: result = result.transpose(1, 2, 0) @@ -657,12 +663,15 @@ def ace_wrapper(bg, x): dim_out=1, cov=cov) else: # Separate score arrays for each target in target list + def apply_to_target(t, x): detector.set_target(t) return detector(x) + def ace_wrapper(bg, x): detector.set_background(bg) return [apply_to_target(t, x) for t in target] + result = map_outer_window_stats(ace_wrapper, X, window[0], window[1], dim_out=len(target), cov=cov) if result.ndim == 3: @@ -674,4 +683,3 @@ def ace_wrapper(bg, x): return np.clip(result, 0, 1, out=result) else: return np.clip(result, 0, 1) - diff --git a/spectral/algorithms/perceptron.py b/spectral/algorithms/perceptron.py index 2f67800..9b2c8a6 100644 --- a/spectral/algorithms/perceptron.py +++ b/spectral/algorithms/perceptron.py @@ -4,12 +4,12 @@ from __future__ import absolute_import, division, print_function, unicode_literals -import itertools import math import numpy as np import os import sys + class PerceptronLayer: '''A multilayer perceptron layer with sigmoid activation function.''' def __init__(self, shape, k=1.0, weights=None): @@ -129,7 +129,6 @@ def __init__(self, layers, k=1.0): # training (with CTRL-C) self.cache_weights = True - def input(self, x, clip=0.0): '''Sets Perceptron input, activates neurons and sets & returns output. @@ -341,13 +340,13 @@ def _adjust_weights(self, rate, momentum, num_summed, stdout): except KeyboardInterrupt: if self.cache_weights: stdout.write('Interrupt during weight adjustment. Restoring ' \ - 'previous weights.\n') + 'previous weights.\n') for i in range(len(weights)): self.layers[i].weights = weights[i] else: stdout.write('Interrupt during weight adjustment. Weight ' \ - 'cacheing was disabled so current weights may' \ - 'be corrupt.\n') + 'cacheing was disabled so current weights may' \ + 'be corrupt.\n') raise finally: self._reset_corrections() @@ -365,11 +364,11 @@ def _set_scaling(self, X): maxes = x else: mins = np.min([mins, x], axis=0) - maxes = np.max([maxes, x], axis = 0) + maxes = np.max([maxes, x], axis=0) self._offset = mins r = maxes - mins self._scale = 1. / np.where(r < self.min_input_diff, 1, r) - + # Sample data @@ -394,32 +393,38 @@ def _set_scaling(self, X): [[1, 1], [1]], ] + def test_case(XY, shape, *args, **kwargs): (X, Y) = list(zip(*XY)) p = Perceptron(shape) trained = p.train(X, Y, *args, **kwargs) return (trained, p) - + + def test_xor(*args, **kwargs): XY = xor_data shape = [2, 2, 1] return test_case(XY, shape, *args, **kwargs) + def test_xor222(*args, **kwargs): XY = xor_data2 shape = [2, 2, 2] return test_case(XY, shape, *args, **kwargs) + def test_xor231(*args, **kwargs): XY = xor_data shape = [2, 3, 1] return test_case(XY, shape, *args, **kwargs) + def test_and(*args, **kwargs): XY = and_data shape = [2, 1] return test_case(XY, shape, *args, **kwargs) + if __name__ == '__main__': tests = [('AND (2x1)', test_and), ('XOR (2x2x1)', test_xor), @@ -431,7 +436,7 @@ def test_and(*args, **kwargs): print('Training results for 5000 iterations') print('------------------------------------') for (name, result) in nr: - s = [ 'FAILED', 'PASSED'][result] + s = ['FAILED', 'PASSED'][result] print('{0:<20}: {1}'.format(name, s)) if False in results: print('\nNote: XOR convergence for these small network sizes is') diff --git a/spectral/algorithms/resampling.py b/spectral/algorithms/resampling.py index 6ae71e4..9a89815 100644 --- a/spectral/algorithms/resampling.py +++ b/spectral/algorithms/resampling.py @@ -10,6 +10,7 @@ from ..spectral import BandInfo + def erf_local(x): # save the sign of x sign = 1 if x >= 0 else -1 @@ -24,9 +25,10 @@ def erf_local(x): p = 0.3275911 # A&S formula 7.1.26 - t = 1.0/(1.0 + p*x) - y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*math.exp(-x*x) - return sign*y # erf(-x) = -erf(x) + t = 1.0/(1.0 + p * x) + y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * math.exp(-x * x) + return sign * y # erf(-x) = -erf(x) + try: from math import erf @@ -36,19 +38,23 @@ def erf_local(x): except: erf = erf_local + def erfc(z): '''Complement of the error function.''' return 1.0 - erf(z) + def normal_cdf(x): '''CDF of the normal distribution.''' sqrt2 = 1.4142135623730951 return 0.5 * erfc(-x / sqrt2) + def normal_integral(a, b): '''Integral of the normal distribution from a to b.''' return normal_cdf(b) - normal_cdf(a) + def ranges_overlap(R1, R2): '''Returns True if there is overlap between ranges of pairs R1 and R2.''' if (R1[0] < R2[0] and R1[1] < R2[0]) or \ @@ -56,14 +62,17 @@ def ranges_overlap(R1, R2): return False return True + def overlap(R1, R2): '''Returns (min, max) of overlap between the ranges of pairs R1 and R2.''' return (max(R1[0], R2[0]), min(R1[1], R2[1])) + def normal(mean, stdev, x): sqrt_2pi = 2.5066282746310002 return math.exp(-((x - mean) / stdev)**2 / 2.0) / (sqrt_2pi * stdev) + def build_fwhm(centers): '''Returns FWHM list, assuming FWHM is midway between adjacent bands. ''' @@ -74,6 +83,7 @@ def build_fwhm(centers): fwhm[i] = (centers[i + 1] - centers[i - 1]) / 2.0 return fwhm + def create_resampling_matrix(centers1, fwhm1, centers2, fwhm2): ''' Returns a resampling matrix to convert spectra from one band discretization @@ -121,7 +131,7 @@ def create_resampling_matrix(centers1, fwhm1, centers2, fwhm2): # the new schema. if len(matches) == 0: logger.info('No overlap for target band %d (%f / %f)', - i, centers2[i], fwhm2[i]) + i, centers2[i], fwhm2[i]) M[i, 0] = nan continue @@ -134,8 +144,8 @@ def create_resampling_matrix(centers1, fwhm1, centers2, fwhm2): contribs = np.zeros(len(matches)) A = 0. for k in range(len(matches)): - #endNorms = [normal(centers2[i], stdev, x) for x in overlaps[k]] - #dA = (overlaps[k][1] - overlaps[k][0]) * sum(endNorms) / 2.0 + # endNorms = [normal(centers2[i], stdev, x) for x in overlaps[k]] + # dA = (overlaps[k][1] - overlaps[k][0]) * sum(endNorms) / 2.0 (a, b) = [(x - centers2[i]) / stdev for x in overlaps[k]] dA = normal_integral(a, b) contribs[k] = dA @@ -145,6 +155,7 @@ def create_resampling_matrix(centers1, fwhm1, centers2, fwhm2): M[i, matches[k]] = contribs[k] return M + class BandResampler: '''A callable object for resampling spectra between band discretizations. diff --git a/spectral/algorithms/spatial.py b/spectral/algorithms/spatial.py index 6664e54..cba35f4 100644 --- a/spectral/algorithms/spatial.py +++ b/spectral/algorithms/spatial.py @@ -13,6 +13,7 @@ import spectral as spy from .algorithms import GaussianStats, iterator_ij + def get_window_bounds(nrows, ncols, height, width, i, j): '''Returns boundaries of an image window centered on a specified pixel. @@ -47,7 +48,7 @@ def get_window_bounds(nrows, ncols, height, width, i, j): A 4-tuple of ints of the form (row_start, row_stop, col_start, col_stop). - + The dimensions of the window will always be (`height`, `width`). For pixels near the border of the image where there are insufficient pixels between the specified pixel and image border, the window will be flush @@ -80,6 +81,7 @@ def get_window_bounds(nrows, ncols, height, width, i, j): return (rmin, rmax, cmin, cmax) + def get_window_bounds_clipped(nrows, ncols, height, width, i, j): '''Returns boundaries of an image window centered on a specified pixel. @@ -114,7 +116,7 @@ def get_window_bounds_clipped(nrows, ncols, height, width, i, j): A 4-tuple of ints of the form (row_start, row_stop, col_start, col_stop). - + Near the boder of the image where there are insufficient pixels between the specified pixel and the image border, the window will be clipped. @@ -140,10 +142,11 @@ def get_window_bounds_clipped(nrows, ncols, height, width, i, j): return (rmin, rmax, cmin, cmax) + def map_window(func, image, window, rslice=(None,), cslice=(None,), border='shift', dtype=None): '''Applies a function over a rolling spatial window. - + Arguments: `func` (callable): @@ -164,7 +167,7 @@ def map_window(func, image, window, rslice=(None,), cslice=(None,), window. For `window` with even dimensions or for pixels near the image border, this may not correspond to the center pixel in the window. - + `image` (`SpyFile` or np.ndarray): The image on which the apply `func` with the specified window. @@ -174,7 +177,7 @@ def map_window(func, image, window, rslice=(None,), cslice=(None,), The size of the window, in pixels. If this value is an integer, the height and width of the window will both be set to the value. Otherwise, `window` should be a tuple of the form (height, width). - + `rslice` (tuple): Tuple of `slice` parameters specifying at which rows the function @@ -247,10 +250,11 @@ def get_val(i, j): return np.array([[get_val(r, c) for c in cvals] for r in rvals]).astype(dtype) + def map_outer_window_stats(func, image, inner, outer, dim_out=1, cov=None, dtype=None, rslice=(None,), cslice=(None,)): '''Maps a function accepting `GaussianStats` over a rolling spatial window. - + Arguments: `func` (callable): @@ -280,7 +284,7 @@ def map_outer_window_stats(func, image, inner, outer, dim_out=1, cov=None, Otherwise, `inner` should be a tuple of the form (height, width). All pixels within the inner window are excluded from statistics computed for the associated pixel. - + `outer` (int or 2-tuple of ints): The size of the outer window, in pixels. If this value is an integer, @@ -288,7 +292,7 @@ def map_outer_window_stats(func, image, inner, outer, dim_out=1, cov=None, Otherwise, `outer` should be a tuple of the form (height, width). All pixels in the outer window (but not in the inner window) are used to compute statistics for the associated pixel. - + `rslice` (tuple): Tuple of `slice` parameters specifying at which rows the function @@ -325,6 +329,7 @@ def map_outer_window_stats(func, image, inner, outer, dim_out=1, cov=None, dtype) return mapper(image, rslice, cslice) + class WindowedGaussianBackgroundMapper(object): '''A class for procucing window statistics with an inner exclusion window. ''' @@ -491,14 +496,14 @@ def __call__(self, image, rslice=(None,), cslice=(None,)): for j in range(ncols_out): jj = cvals[j] - w // 2 if i_interior_start <= rvals[i] < i_interior_stop and \ - j_interior_start <= cvals[j] < j_interior_stop: - X = image[ii : ii + h, jj : jj + w, :] + j_interior_start <= cvals[j] < j_interior_stop: + X = image[ii: ii + h, jj: jj + w, :] indices = interior_indices else: (inner, (i0, i1, j0, j1), mask) = \ create_mask(rvals[i], cvals[j], True) indices = np.argwhere(mask.ravel() == 0).squeeze() - X = image[i0 : i1, j0 : j1, :] + X = image[i0: i1, j0: j1, :] X = np.take(X.reshape((-1, B)), indices, axis=0) mean = np.mean(X, axis=0) cov = np.cov(X, rowvar=False) @@ -511,6 +516,7 @@ def __call__(self, image, rslice=(None,), cslice=(None,)): status.end_percentage() return x + def inner_outer_window_mask_creator(image_shape, inner, outer): '''Returns a function to give inner/outer windows. @@ -536,7 +542,7 @@ def inner_outer_window_mask_creator(image_shape, inner, outer): `i` (int): Row index of pixel for which to generate the mask - + `j` (int): Row index of pixel for which to generate the mask @@ -567,7 +573,7 @@ def inner_outer_window_mask_creator(image_shape, inner, outer): if wi > wo or hi > ho: raise ValueError('Inner window dimensions must be smaller than outer.') - + (ai, bi) = (hi // 2, wi // 2) (ao, bo) = (ho // 2, wo // 2) @@ -582,7 +588,7 @@ def create_mask(i, j, gen_mask=False): elif inner_imax > R: inner_imax = R inner_imin = R - hi - + inner_jmin = j - bi inner_jmax = inner_jmin + wi if inner_jmin < 0: @@ -591,7 +597,7 @@ def create_mask(i, j, gen_mask=False): elif inner_jmax > C: inner_jmax = C inner_jmin = C - wi - + # Outer window outer_imin = i - ao outer_imax = outer_imin + ho @@ -601,7 +607,7 @@ def create_mask(i, j, gen_mask=False): elif outer_imax > R: outer_imax = R outer_imin = R - ho - + outer_jmin = j - bo outer_jmax = outer_jmin + wo if outer_jmin < 0: @@ -610,17 +616,18 @@ def create_mask(i, j, gen_mask=False): elif outer_jmax > C: outer_jmax = C outer_jmin = C - wo - + inner = (inner_imin, inner_imax, inner_jmin, inner_jmax) outer = (outer_imin, outer_imax, outer_jmin, outer_jmax) if not gen_mask: return (inner, outer) mask = np.zeros((ho, wo), dtype=bool) - mask[inner_imin - outer_imin : inner_imax - outer_imin, - inner_jmin - outer_jmin : inner_jmax - outer_jmin] = True + mask[inner_imin - outer_imin: inner_imax - outer_imin, + inner_jmin - outer_jmin: inner_jmax - outer_jmin] = True return (inner, outer, mask) return create_mask + def map_class_ids(src_class_image, dest_class_image, unlabeled=None): '''Create a mapping between class labels in two classification images. @@ -637,7 +644,7 @@ def map_class_ids(src_class_image, dest_class_image, unlabeled=None): An MxN integer array of class indices. The indices in this array will be mapped to indices in `dest_class_image`. - + `dest_class_image` (ndarray): An MxN integer array of class indices. @@ -710,6 +717,7 @@ def next_id(): matches[:, j] = 0 return cmap + def map_classes(class_image, class_id_map, allow_unmapped=False): '''Modifies class indices according to a class index mapping. @@ -754,6 +762,7 @@ def map_classes(class_image, class_id_map, allow_unmapped=False): mapped[class_image == i] = j return mapped + def expand_binary_mask_for_window(mask, height, width): '''Returns a new mask including window around each pixel in source mask. diff --git a/spectral/algorithms/spymath.py b/spectral/algorithms/spymath.py index bc20c0d..d007fb0 100644 --- a/spectral/algorithms/spymath.py +++ b/spectral/algorithms/spymath.py @@ -38,7 +38,6 @@ def matrix_sqrt(X=None, symmetric=False, inverse=False, eigs=None): (vals, V) = eigs else: (vals, V) = np.linalg.eig(X) - k = len(vals) if inverse is False: SRV = np.diag(np.sqrt(vals)) else: @@ -88,5 +87,3 @@ def get_histogram_cdf_points(data, cdf_vals, ignore=None, mask=None): isort = np.argsort(data) N = len(data) return [data[isort[int(x * (N - 1))]] for x in cdf_vals] - - diff --git a/spectral/config.py b/spectral/config.py index 853887b..8809054 100644 --- a/spectral/config.py +++ b/spectral/config.py @@ -6,16 +6,17 @@ import numpy as np + class SpySettings: '''Run-time settings for the `spectral` module. After importing `spectral`, the settings object is referenced as `spectral.settings`. - + Noteworthy members: `WX_GL_DEPTH_SIZE` (integer, default 24): - + Sets the depth (in number of bits) for the OpenGL depth buffer. If calls to `view_cube` or `view_nd` result in windows with blank canvases, try reducing this value. @@ -27,7 +28,7 @@ class SpySettings: be read with original capitalization retained. `show_progress` (bool, default True): - + Indicates whether long-running algorithms should display progress to sys.stdout. It can be useful to set this value to False when SPy is embedded in another application (e.g., IPython Notebook). @@ -36,7 +37,7 @@ class SpySettings: Width and height (in inches) of windows opened with `imshow`. If this value is `None`, matplotlib's default size is used. - + `imshow_background_color` (3-tuple of integers, default (0,0,0)): Default color to use for masked pixels in `imshow` displays. @@ -54,13 +55,13 @@ class SpySettings: `imshow_stretch_all`: If True, each color channel limits are determined independently. - + `imshow_zoom_figure_width` (int, default `None`): - + Width of zoom windows opened from an imshow window. Since zoom windows are always square, this is also the window height. If this value is `None`, matplotlib's default window size is used. - + `imshow_zoom_pixel_width` (int, default 50): Number of source image pixel rows and columns to display in a @@ -117,7 +118,8 @@ class SpySettings: imshow_class_alpha = 0.5 imshow_enable_rectangle_selector = True imshow_disable_mpl_callbacks = True - + + # Default color table spy_colors = np.array([[0, 0, 0], [255, 0, 0], diff --git a/spectral/database/__init__.py b/spectral/database/__init__.py index 26e0499..ffbb276 100644 --- a/spectral/database/__init__.py +++ b/spectral/database/__init__.py @@ -1,3 +1,5 @@ +# flake8: noqa + from __future__ import absolute_import, division, print_function, unicode_literals from .aster import AsterDatabase diff --git a/spectral/database/aster.py b/spectral/database/aster.py index a33d5d0..2157530 100644 --- a/spectral/database/aster.py +++ b/spectral/database/aster.py @@ -4,16 +4,18 @@ from __future__ import absolute_import, division, print_function, unicode_literals +import numpy as np + from spectral.utilities.python23 import IS_PYTHON3, tobytes, frombytes from .spectral_database import SpectralDatabase if IS_PYTHON3: - readline = lambda fin: fin.readline() - open_file = lambda filename: open(filename, encoding='iso-8859-1') + def readline(fin): return fin.readline() + def open_file(filename): return open(filename, encoding='iso-8859-1') else: - readline = lambda fin: fin.readline().decode('iso-8859-1') - open_file = lambda filename: open(filename) + def readline(fin): return fin.readline().decode('iso-8859-1') + def open_file(filename): return open(filename) table_schemas = [ 'CREATE TABLE Samples (SampleID INTEGER PRIMARY KEY, Name TEXT, Type TEXT, Class TEXT, SubClass TEXT, ' @@ -65,7 +67,6 @@ def read_aster_file(filename): line = readline(fin).strip() if line.find('Collected by:') >= 0: haveCollectedBy = True - collectedByLineNum = i if line.startswith('Description:'): descriptionLineNum = i if line.startswith('Measurement:'): @@ -99,8 +100,8 @@ def read_aster_file(filename): # Try to handle invalid values on signature lines if nItems == 1: -# print 'single item (%s) on signature line, %s' \ -# % (pair[0], filename) + # print('single item (%s) on signature line, %s' \ + # % (pair[0], filename)) continue elif nItems > 2: print('more than 2 values on signature line,', filename) @@ -110,7 +111,7 @@ def read_aster_file(filename): except: print('corrupt signature line,', filename) if x == 0: -# print 'Zero wavelength value', filename + # print('Zero wavelength value', filename) continue elif x < 0: print('Negative wavelength value,', filename) @@ -241,7 +242,6 @@ def read_file(self, filename): def _import_files(self, data_dir, ignore=bad_files): '''Read each file in the ASTER library and convert to AVIRIS bands.''' from glob import glob - import numpy import os if not os.path.isdir(data_dir): @@ -254,8 +254,6 @@ def _import_files(self, data_dir, ignore=bad_files): numFiles = 0 numIgnored = 0 - sigID = 1 - class Sig: pass sigs = [] @@ -298,8 +296,7 @@ class Sig: m['x units'], yUnit, m['first x value'], m['last x value'], sig.x, sig.y) if numFiles == 0: - print('No data files were found in directory "%s".' \ - % data_dir) + print('No data files were found in directory "%s".' % data_dir) else: print('Processed %d files.' % numFiles) if numIgnored > 0: @@ -429,9 +426,8 @@ def create_envi_spectral_library(self, spectrumIDs, bandInfo): ''' from spectral.algorithms.resampling import BandResampler from spectral.io.envi import SpectralLibrary - import numpy import unicodedata - spectra = numpy.empty((len(spectrumIDs), len(bandInfo.centers))) + spectra = np.empty((len(spectrumIDs), len(bandInfo.centers))) names = [] for i in range(len(spectrumIDs)): sig = self.get_signature(spectrumIDs[i]) diff --git a/spectral/database/ecostress.py b/spectral/database/ecostress.py index 6c2c08b..bb041b1 100644 --- a/spectral/database/ecostress.py +++ b/spectral/database/ecostress.py @@ -6,16 +6,19 @@ import itertools import logging +import os + +from glob import glob from spectral.utilities.python23 import IS_PYTHON3 from .aster import AsterDatabase, Signature if IS_PYTHON3: - readline = lambda fin: fin.readline() - open_file = lambda filename: open(filename, encoding='iso-8859-1') + def readline(fin): return fin.readline() + def open_file(filename): return open(filename, encoding='iso-8859-1') else: - readline = lambda fin: fin.readline().decode('iso-8859-1') - open_file = lambda filename: open(filename) + def readline(fin): return fin.readline().decode('iso-8859-1') + def open_file(filename): return open(filename) def read_ecostress_file(filename): @@ -26,7 +29,7 @@ def read_ecostress_file(filename): if not IS_PYTHON3: lines = [line.decode('iso-8859-1') for line in lines] - metaline_to_pair = lambda line: [x.strip() for x in line.split(':', 1)] + def metaline_to_pair(line): return [x.strip() for x in line.split(':', 1)] s = Signature() @@ -60,7 +63,7 @@ def read_ecostress_file(filename): # Try to handle invalid values on signature lines if nItems == 1: logger.info('Skipping single item (%s) on signature line for %s', - pair[0], filename) + pair[0], filename) continue elif nItems > 2: logger.info('Skipping more than 2 values on signature line for %s', @@ -94,6 +97,7 @@ def read_ecostress_file(filename): return s + class EcostressDatabase(AsterDatabase): '''A relational database to manage ECOSTRESS spectral library data.''' @@ -146,10 +150,6 @@ def read_file(self, filename): def _import_files(self, data_dir, ignore=None): '''Import each file from the ECOSTRESS library into the database.''' - from glob import glob - import numpy - import os - logger = logging.getLogger('spectral') if not os.path.isdir(data_dir): raise Exception('Error: Invalid directory name specified.') @@ -161,8 +161,6 @@ def _import_files(self, data_dir, ignore=None): numFiles = 0 numIgnored = 0 - sigID = 1 - class Sig: pass sigs = [] diff --git a/spectral/database/spectral_database.py b/spectral/database/spectral_database.py index 9c646f5..941c992 100644 --- a/spectral/database/spectral_database.py +++ b/spectral/database/spectral_database.py @@ -76,4 +76,3 @@ def print_query(self, sql, args=None): ret = self.query(sql, args) for row in ret: print("|".join([str(x) for x in row])) - diff --git a/spectral/database/usgs.py b/spectral/database/usgs.py index 20ab92b..1249912 100644 --- a/spectral/database/usgs.py +++ b/spectral/database/usgs.py @@ -54,8 +54,14 @@ def array_to_blob(arr): 'ASTER': ['ASTER'], 'AVIRIS': ['AVIRIS', 'aviris'], 'BECK': ['BECK'], - 'CRISM JOINED MTR3': ['CRISM Bandpass(FWHM) JOINED MTR3', 'CRISM Waves JOINED MTR3', 'CRISM Bandpass JOINED MTR3', 'CRISM JOINED MTR3'], - 'CRISM GLOBAL': ['CRISM Bandpass(FWHM) GLOBAL', 'CRISM Wavelengths GLOBAL', 'CRISM Waves GLOBAL', 'CRISM GLOBAL'], + 'CRISM JOINED MTR3': ['CRISM Bandpass(FWHM) JOINED MTR3', + 'CRISM Waves JOINED MTR3', + 'CRISM Bandpass JOINED MTR3', + 'CRISM JOINED MTR3'], + 'CRISM GLOBAL': ['CRISM Bandpass(FWHM) GLOBAL', + 'CRISM Wavelengths GLOBAL', + 'CRISM Waves GLOBAL', + 'CRISM GLOBAL'], 'Hyperion': ['Hyperion'], 'HyMap2': ['HyMap2'], 'Landsat8': ['Landsat8'], @@ -66,6 +72,7 @@ def array_to_blob(arr): 'WorldView3': ['WorldView3'] } + class SpectrometerData: ''' Holds data for spectrometer, from USGS spectral library. @@ -87,8 +94,11 @@ def header(self): Returns: String representation of basic meta data. ''' - return '{0} Record={1}: {2} {3} {4}'.format(self.libname, self.record, - self.measurement, self.description) + return '{0} Record={1}: {2} {3} {4}'.format(self.libname, + self.record, + self.measurement_type, + self.unit, + self.description) @ classmethod def read_from_file(cls, filename): @@ -216,7 +226,7 @@ def header(self): Returns: String representation of basic meta data. ''' - return '{0} Record={1}: {2} {3}{4} {5}'.format(self.libname, self.record, + return '{0} Record={1}: {2} {3} {4} {5}'.format(self.libname, self.record, self.description, self.spectrometer, self.purity, self.measurement_type) @@ -441,7 +451,6 @@ def __init__(self, sqlite_filename=None): def _import_files(self, data_dir): from glob import glob - import numpy import os logger = logging.getLogger('spectral') @@ -524,7 +533,6 @@ def get_spectrum(self, sampleID): Returns a pair of vectors containing the wavelengths and measured values values of a measurment. ''' - import array query = '''SELECT ValuesArray, AssumedWLSpmeterDataID FROM Samples WHERE SampleID = ?''' result = self.cursor.execute(query, (sampleID,)) rows = result.fetchall() @@ -567,7 +575,7 @@ def create_envi_spectral_library(self, spectrumIDs, bandInfo): resampled to the same discretization specified by the bandInfo parameter. See :class:`spectral.BandResampler` for details on the resampling method used. - Note that expected units for bands are micrometers. + Note that expected units for bands are micrometers. ''' from spectral.algorithms.resampling import BandResampler from spectral.io.envi import SpectralLibrary diff --git a/spectral/graphics/__init__.py b/spectral/graphics/__init__.py index b36e31b..2fa5f4d 100644 --- a/spectral/graphics/__init__.py +++ b/spectral/graphics/__init__.py @@ -1,3 +1,5 @@ +# flake8: noqa + from __future__ import absolute_import, division, print_function, unicode_literals from .graphics import (save_rgb, view, view_indexed, view_cube, view_nd, diff --git a/spectral/graphics/colorscale.py b/spectral/graphics/colorscale.py index 5f1e0f5..e26d399 100644 --- a/spectral/graphics/colorscale.py +++ b/spectral/graphics/colorscale.py @@ -6,6 +6,7 @@ import numpy as np + class ColorScale: ''' A color scale class to map scalar values to rgb colors. The class allows @@ -131,4 +132,5 @@ def create_default_color_scale(ntics=0): scale = ColorScale(levels, mycolors, ntics) return scale + default_color_scale = create_default_color_scale() diff --git a/spectral/graphics/graphics.py b/spectral/graphics/graphics.py index 4c82d8b..8da4c6c 100644 --- a/spectral/graphics/graphics.py +++ b/spectral/graphics/graphics.py @@ -4,11 +4,9 @@ from __future__ import absolute_import, division, print_function, unicode_literals -import io from numbers import Number import numpy as np import sys -import time import warnings from ..algorithms.spymath import get_histogram_cdf_points @@ -16,6 +14,7 @@ from ..image import Image from ..spectral import settings + class WindowProxy(object): '''Base class for proxy objects to access data from display windows.''' def __init__(self, window): @@ -276,6 +275,7 @@ def view_indexed(*args, **kwargs): return view(*args, **kwargs) + def imshow(data, bands=None, **kwargs): '''A wrapper around matplotlib's imshow for multi-band images. @@ -323,7 +323,7 @@ def imshow(data, bands=None, **kwargs): for k in ['stretch', 'stretch_all', 'bounds']: if k in kwargs: rgb_kwargs[k] = kwargs.pop(k) - + imshow_kwargs = {'cmap': 'gray'} imshow_kwargs.update(kwargs) @@ -334,12 +334,13 @@ def imshow(data, bands=None, **kwargs): rgb = rgb[:, :, 0] ax = plt.imshow(rgb, **imshow_kwargs) - if show_xaxis == False: + if show_xaxis is False: plt.gca().xaxis.set_visible(False) - if show_yaxis == False: + if show_yaxis is False: plt.gca().yaxis.set_visible(False) return ax + def make_pil_image(*args, **kwargs): '''Creates a PIL Image object. @@ -349,10 +350,9 @@ def make_pil_image(*args, **kwargs): See `get_rgb` for description of arguments. ''' try: - from PIL import Image, ImageDraw + from PIL import Image except ImportError: import Image - import ImageDraw rgb = get_rgb(*args, **kwargs) rgb = (rgb * 255).astype(np.ubyte) @@ -474,7 +474,7 @@ def get_rgb(source, bands=None, **kwargs): If neither `stretch` nor `bounds` are specified, then the default value of `stretch` defined by `spectral.settings.imshow_stretch` will be used. - + `bounds` (tuple): This keyword functions similarly to the `stretch` keyword, except @@ -482,7 +482,7 @@ def get_rgb(source, bands=None, **kwargs): histogram values. The form of this keyword is the same as the first two forms for the `stretch` keyword (i.e., either a 2-tuple of numbers or a 3-tuple of 2-tuples of numbers). - + `stretch_all` (bool): If this keyword is True, each color channel will be scaled @@ -520,12 +520,14 @@ def get_rgb(source, bands=None, **kwargs): ''' return get_rgb_meta(source, bands, **kwargs)[0] + def _fill_mask(arr, mask, fill_value): if mask is None: return arr arr[mask == 0] = np.array(fill_value) / 255. return arr + def get_rgb_meta(source, bands=None, **kwargs): '''Same as get_rgb but also returns some metadata. @@ -560,6 +562,7 @@ def get_rgb_meta(source, bands=None, **kwargs): msg = 'Unable to interpret "default bands" in image ' \ 'metadata. Defaulting to first, middle, & last band.' warnings.warn(msg) + pass elif source.shape[-1] == 1: bands = [0] if len(bands) == 0: @@ -603,7 +606,7 @@ def get_rgb_meta(source, bands=None, **kwargs): meta['mode'] = 'indexed' rgb = rgb.astype(int) pal = kwargs["colors"] - rgb = pal[rgb[:,:,0]] / 255. + rgb = pal[rgb[:, :, 0]] / 255. return (_fill_mask(rgb, mask, bg), meta) elif color_scale is not None: # Colors should be generated from the supplied color scale @@ -629,7 +632,7 @@ def get_rgb_meta(source, bands=None, **kwargs): stretch_all = kwargs.get('stretch_all', settings.imshow_stretch_all) bounds = kwargs.get('bounds', None) - if bounds is not None: + if bounds is not None: # Data limits for the color stretch are set explicitly bounds = np.array(bounds) if bounds.shape not in ((2,), (3, 2)): @@ -668,21 +671,21 @@ def get_rgb_meta(source, bands=None, **kwargs): else: # Use a common lower/upper limit for each band by taking # the lowest lower limit and greatest upper limit. - lims = np.array([get_histogram_cdf_points(rgb[:,:,i], stretch, - ignore=nondata) \ - for i in range(3)]) - minmax = np.array([lims[:,0].min(), lims[:,1].max()]) + lims = np.array([get_histogram_cdf_points(rgb[:, :, i], stretch, + ignore=nondata) + for i in range(3)]) + minmax = np.array([lims[:, 0].min(), lims[:, 1].max()]) rgb_lims = minmax[np.newaxis, :].repeat(3, axis=0) else: if monochrome: # Not sure why anyone would want separate RGB stretches for # a gray-scale image but we'll let them. - rgb_lims = [get_histogram_cdf_points(rgb[:,:,0], stretch[i], - ignore=nondata) \ + rgb_lims = [get_histogram_cdf_points(rgb[:, :, 0], stretch[i], + ignore=nondata) for i in range(3)] elif stretch_all: - rgb_lims = [get_histogram_cdf_points(rgb[:,:,i], stretch[i], - ignore=nondata) \ + rgb_lims = [get_histogram_cdf_points(rgb[:, :, i], stretch[i], + ignore=nondata) for i in range(3)] else: msg = 'Can not use common stretch if different stretch ' \ @@ -701,10 +704,12 @@ def get_rgb_meta(source, bands=None, **kwargs): rgb[:, :, i] = np.clip((rgb[:, :, i] - lower) / span, 0, 1) return (_fill_mask(rgb, mask, bg), meta) + # For checking if valid keywords were supplied _get_rgb_kwargs = ('stretch', 'stretch_all', 'bounds', 'colors', 'color_scale', 'auto_scale', 'ignore', 'mask', 'bg') + def running_ipython(): '''Returns True if ipython is running.''' try: @@ -753,8 +758,7 @@ def check_wx_app(): ''' import spectral import wx - if wx.GetApp() is None and spectral.settings.START_WX_APP == True: + if wx.GetApp() is None and spectral.settings.START_WX_APP is True: warnings.warn('\nThere is no current wx.App object - creating one now.', UserWarning) spectral.app = wx.App() - diff --git a/spectral/graphics/hypercube.py b/spectral/graphics/hypercube.py index a798a70..9593a30 100644 --- a/spectral/graphics/hypercube.py +++ b/spectral/graphics/hypercube.py @@ -44,7 +44,6 @@ raise ImportError("Required dependency wx.glcanvas not present") from .. import settings -from ..image import Image from ..io.spyfile import SpyFile from .colorscale import create_default_color_scale from .graphics import make_pil_image, SpyWindow @@ -72,6 +71,7 @@ def xyz_to_rtp(x, y, z): theta = math.acos(z / r) * 180. / math.pi return [r, theta, phi] + (DOWN, UP) = (1, 0) @@ -102,7 +102,7 @@ def motion(self, event): '''Handles panning & zooming for mouse click+drag events.''' if DOWN not in (self.left, self.right): return - #print 'Mouse movement:', x, y + # print('Mouse movement:', x, y) (w, h) = self.window.size dx = event.X - self.position[0] dy = event.Y - self.position[1] @@ -113,7 +113,6 @@ def motion(self, event): self.window.camera_pos_rtp[0] *= (float(w - dx) / w) elif wx.GetKeyState(wx.WXK_SHIFT): # Mouse movement pans target position in plane of the window - camera_pos = np.array(rtp_to_xyz(*self.window.camera_pos_rtp)) view_vec = -np.array(rtp_to_xyz(*self.window.camera_pos_rtp)) zhat = np.array([0.0, 0.0, 1.0]) right = -np.cross(zhat, view_vec) @@ -134,6 +133,7 @@ def motion(self, event): self.window.Refresh() event.Skip() + class HypercubeWindow(wx.Frame, SpyWindow): """A simple class for using OpenGL with wxPython.""" @@ -164,7 +164,7 @@ def __init__(self, data, parent, id, *args, **kwargs): # These members can be modified before calling the show method. self.clear_color = tuple(kwargs.get('background', (0., 0., 0.))) \ - + (1.,) + + (1.,) self.win_pos = (100, 100) self.fovy = 60. self.znear = 0.1 @@ -239,7 +239,6 @@ def load_textures(self): self.textures = gl.glGenTextures(6) texImages = [] (a, b, c) = data.shape - texSizes = [(b, a), (b, c), (a, c), (b, c), (a, c), (b, a)] for i in range(len(images)): try: # API change for Pillow @@ -247,6 +246,7 @@ def load_textures(self): except: # Fall back to old PIL API img = images[i].tostring("raw", "RGBX", 0, -1) + pass (dim_x, dim_y) = images[i].size texImages.append(img) diff --git a/spectral/graphics/ndwindow.py b/spectral/graphics/ndwindow.py index f9c7a7b..c6b1123 100644 --- a/spectral/graphics/ndwindow.py +++ b/spectral/graphics/ndwindow.py @@ -6,10 +6,8 @@ import math import numpy as np -import os from pprint import pprint import random -import time try: import wx @@ -19,8 +17,7 @@ from .. import settings from ..config import spy_colors -from .colorscale import ColorScale -from .spypylab import ImageView, MplCallback, SpyMplEvent +from .spypylab import ImageView, SpyMplEvent from .graphics import WindowProxy DEFAULT_WIN_SIZE = (500, 500) # Default dimensions of image frame @@ -44,6 +41,7 @@ def xyz_to_rtp(x, y, z): theta = math.acos(z / r) * 180. / math.pi return [r, theta, phi] + (DOWN, UP) = (1, 0) @@ -67,8 +65,8 @@ def left_down(self, event): if wx.GetKeyState(wx.WXK_CONTROL) and wx.GetKeyState(wx.WXK_SHIFT): # Display the row/col and class of the selected pixel. (x, y) = self.position - cmd = lambda: self.window.get_pixel_info( - x, self.window.size[1] - y) + def cmd(): + return self.window.get_pixel_info(x, self.window.size[1] - y) self.window.add_display_command(cmd) self.window.canvas.SetCurrent(self.window.canvas.context) self.window.canvas.Refresh() @@ -106,7 +104,6 @@ def motion(self, event): '''Handles panning & zooming for mouse click+drag events.''' if DOWN not in (self.left, self.right): return - #print 'Mouse movement:', x, y (w, h) = self.window.size dx = event.X - self.position[0] dy = event.Y - self.position[1] @@ -120,7 +117,6 @@ def motion(self, event): self.window.camera_pos_rtp[2] -= xangle elif self.left == DOWN: # Mouse movement pans target position in the plane of window - camera_pos = np.array(rtp_to_xyz(*self.window.camera_pos_rtp)) view_vec = -np.array(rtp_to_xyz(*self.window.camera_pos_rtp)) zhat = np.array([0.0, 0.0, 1.0]) right = -np.cross(zhat, view_vec) @@ -148,15 +144,17 @@ def update_box_coordinates(self): R = self.window.size[1] self.window._selection_box = (xmin, R - ymax, xmax, R - ymin) + class MouseMenu(wx.Menu): '''Right-click menu for reassigning points to different classes.''' ids = [] + def __init__(self, window): super(MouseMenu, self).__init__(title='Assign to class') self.window = window self.id_classes = {} while len(self.ids) < self.window.max_menu_class + 1: - self.ids.append( wx.NewId()) + self.ids.append(wx.NewId()) for i in range(self.window.max_menu_class + 1): id = self.ids[i] self.id_classes[id] = i @@ -169,6 +167,7 @@ def reassign_points(self, event): i = self.id_classes[event.GetId()] self.window.post_reassign_selection(i) + # Multipliers for projecting data into each 3D octant octant_coeffs = np.array([ [1, 1, 1], @@ -287,6 +286,7 @@ def view_class_image(self, *args, **kwargs): ''' return self._window.view_class_image(*args, **kwargs) + class NDWindow(wx.Frame): '''A widow class for displaying N-dimensional data points.''' @@ -356,7 +356,6 @@ def __init__(self, data, parent, id, *args, **kwargs): from matplotlib.cbook import CallbackRegistry self.callbacks = CallbackRegistry() - def on_event_close(self, event=None): pass @@ -395,8 +394,10 @@ def set_data(self, data, **kwargs): from OpenGL.GL import glGetIntegerv except: from OpenGL.GL.glget import glGetIntegerv + pass - classes = kwargs.get('classes', None) +# if 'classes' in kwargs: +# self.classes = kwargs.get('classes', None) features = kwargs.get('features', list(range(6))) if self.data.shape[2] < 6: features = features[:3] @@ -698,7 +699,7 @@ def reassign_selection(self, new_class): # % (i, nreassigned, new_class) print('.', end=' ') i += 1 - print('\n%d points were reasssigned to class %d.' \ + print('\n%d points were reasssigned to class %d.' % (nreassigned_tot, new_class)) self._selection_box = None if nreassigned_tot > 0 and new_class == self.max_menu_class: @@ -829,6 +830,7 @@ def label_axis(x, y, z, label): gl.glRasterPos3f(x, y, z) glut.glutBitmapString(glut.GLUT_BITMAP_HELVETICA_18, str(label)) + def label_axis_for_feature(x, y, z, feature_ind): feature = self.octant_features[feature_ind[0]][feature_ind[1]] label_axis(x, y, z, self.labels[feature]) @@ -961,7 +963,7 @@ def on_char(self, event): self._show_unassigned = not self._show_unassigned print('SHOW UNASSIGNED =', self._show_unassigned) self._refresh_display_lists = True - + self.canvas.Refresh() def update_window_title(self): diff --git a/spectral/graphics/rasterwindow.py b/spectral/graphics/rasterwindow.py index f23ce50..368c4e3 100644 --- a/spectral/graphics/rasterwindow.py +++ b/spectral/graphics/rasterwindow.py @@ -12,6 +12,7 @@ logger = logging.getLogger('spectral') + class RasterWindow(wx.Frame, SpyWindow): ''' RasterWindow is the primary wxWindows object for displaying SPy diff --git a/spectral/graphics/spypylab.py b/spectral/graphics/spypylab.py index 3bdf2c1..a1a6266 100644 --- a/spectral/graphics/spypylab.py +++ b/spectral/graphics/spypylab.py @@ -11,6 +11,7 @@ _mpl_callbacks_checked = False + def check_disable_mpl_callbacks(): '''Disables matplotlib key event handlers, if appropriate.''' import matplotlib as mpl @@ -25,14 +26,17 @@ def check_disable_mpl_callbacks(): mpl.rcParams['keymap.yscale'] = '' mpl.rcParams['keymap.home'] = 'r' + def xy_to_rowcol(x, y): '''Converts image (x, y) coordinate to pixel (row, col).''' return (int(y + 0.5), int(x + 0.5)) + def rowcol_to_xy(r, c): '''Converts pixel (row, col) coordinate to (x, y) of pixel center.''' return (float(c), float(r)) + class MplCallback(object): '''Base class for callbacks using matplotlib's CallbackRegistry. @@ -46,7 +50,7 @@ class MplCallback(object): # associated exception will be rethrown. raise_event_exceptions = False show_events = False - + def __init__(self, registry=None, event=None, callback=None): ''' Arguments: @@ -95,10 +99,10 @@ def set_registry(self, registry=None): self.registry = registry.axes.figure.canvas else: self.registry = registry - + def connect(self, registry=None, event=None, callback=None): '''Binds the callback to the registry and begins receiving events. - + Arguments: registry (ImageView, CallbackRegistry, or FigureCanvas): @@ -129,10 +133,6 @@ def connect(self, registry=None, event=None, callback=None): self.event = event if callback is not None: self.callback = callback - if self.callback is None: - cb = self - else: - cb = self.callback if isinstance(self.registry, CallbackRegistry): self.cid = self.registry.connect(self.event, self) elif isinstance(self.registry, ImageView): @@ -173,12 +173,14 @@ def __call__(self, *args, **kwargs): if self.raise_event_exceptions: raise e + class ImageViewCallback(MplCallback): '''Base class for callbacks that operate on ImageView objects.''' def __init__(self, view, *args, **kwargs): super(ImageViewCallback, self).__init__(*args, **kwargs) self.view = view + class ParentViewPanCallback(ImageViewCallback): '''A callback to pan an image based on a click in another image.''' def __init__(self, child, parent, *args, **kwargs): @@ -214,7 +216,8 @@ def handle_event(self, event): def connect(self): super(ParentViewPanCallback, self).connect(registry=self.view, event='button_press_event') - + + class ImageViewKeyboardHandler(ImageViewCallback): '''Default handler for keyboard events in an ImageView.''' def __init__(self, view, *args, **kwargs): @@ -250,11 +253,11 @@ def handle_event(self, event): print('key = %s' % event.key) kp = KeyParser(event.key) key = kp.key - + #----------------------------------------------------------- # Handling for keyboard input related to class ID assignment #----------------------------------------------------------- - + if key is None and kp.mods_are('shift') and \ self.view.selector is not None: # Rectangle selector is active while shift key is pressed @@ -299,7 +302,7 @@ def handle_event(self, event): if len(self.idstr) > 0: self.idstr = '' print('Cancelled class ID assignment.') - + #----------------------------------------------------------- # General keybinds #----------------------------------------------------------- @@ -354,12 +357,14 @@ def print_help(self): print('See matplotlib imshow documentation for addition key binds.') print() + class KeyParser(object): '''Class to handle ambiguities in matplotlib event key values.''' aliases = {'ctrl': ['ctrl', 'control'], 'alt': ['alt'], 'shift': ['shift'], 'super': ['super']} + def __init__(self, key_str=None): self.reset() if key_str is not None: @@ -368,7 +373,7 @@ def __init__(self, key_str=None): def reset(self): self.key = None self.modifiers = set() - + def parse(self, key_str): '''Extracts the key value and modifiers from a string.''' self.reset() @@ -388,7 +393,7 @@ def parse(self, key_str): self.key = tokens[-1] def has_mod(self, m): - '''Returns True if `m` is one of the modifiers.''' + '''Returns True if `m` is one of the modifiers.''' return m in self.modifiers def mods_are(self, *args): @@ -397,7 +402,7 @@ def mods_are(self, *args): if a not in self.modifiers: return False return True - + def get_token_modifiers(self, token): mods = set() for (modifier, aliases) in list(self.aliases.items()): @@ -405,6 +410,7 @@ def get_token_modifiers(self, token): mods.add(modifier) return mods + class ImageViewMouseHandler(ImageViewCallback): def __init__(self, view, *args, **kwargs): super(ImageViewMouseHandler, self).__init__(view, @@ -448,6 +454,7 @@ class SpyMplEvent(object): def __init__(self, name): self.name = name + class ImageView(object): '''Class to manage events and data associated with image raster views. @@ -459,10 +466,12 @@ class ImageView(object): :meth:`show` method to display the image. The underlying image display functionality is implemented via :func:`matplotlib.pyplot.imshow`. ''' - selector_rectprops = dict(facecolor='red', edgecolor = 'black', + + selector_rectprops = dict(facecolor='red', edgecolor='black', alpha=0.5, fill=True) selector_lineprops = dict(color='black', linestyle='-', - linewidth = 2, alpha=0.5) + linewidth=2, alpha=0.5) + def __init__(self, data=None, bands=None, classes=None, source=None, **kwargs): ''' @@ -523,7 +532,7 @@ def __init__(self, data=None, bands=None, classes=None, source=None, self.selection = None self.interpolation = kwargs.get('interpolation', settings.imshow_interpolation) - + if data is not None: self.set_data(data, bands, **kwargs) if classes is not None: @@ -532,7 +541,7 @@ def __init__(self, data=None, bands=None, classes=None, source=None, self.set_source(source) self.class_colors = spectral.spy_colors - + self.spectrum_plot_fig_id = None self.parent = None self.selector = None @@ -541,7 +550,7 @@ def __init__(self, data=None, bands=None, classes=None, source=None, # Callbacks for events associated specifically with this window. self.callbacks = None - + # A sharable callback registry for related windows. If this # CallbackRegistry is set prior to calling ImageView.show (e.g., by # setting it equal to the `callbacks_common` member of another @@ -553,7 +562,7 @@ def __init__(self, data=None, bands=None, classes=None, source=None, def set_data(self, data, bands=None, **kwargs): '''Sets the data to be shown in the RGB channels. - + Arguments: `data` (ndarray or SpyImage): @@ -615,7 +624,7 @@ def set_rgb_options(self, **kwargs): if self.is_shown: self._update_data_rgb() self.refresh() - + def _update_data_rgb(self): '''Regenerates the RGB values for display.''' from .graphics import get_rgb_meta @@ -686,7 +695,7 @@ def set_source(self, source): The source for spectral data associated with the view. ''' self.source = source - + def show(self, mode=None, fignum=None): '''Renders the image data. @@ -726,7 +735,7 @@ def show(self, mode=None, fignum=None): if settings.imshow_figure_size is not None: kwargs['figsize'] = settings.imshow_figure_size plt.figure(**kwargs) - + if self.data_rgb is not None: self.show_data() if self.classes is not None: @@ -746,7 +755,7 @@ def init_callbacks(self): '''Creates the object's callback registry and default callbacks.''' from spectral import settings from matplotlib.cbook import CallbackRegistry - + self.callbacks = CallbackRegistry() # callbacks_common may have been set to a shared external registry @@ -772,7 +781,6 @@ def updater(*args, **kwargs): callback.connect() self.cb_classes_modified = callback - if settings.imshow_enable_rectangle_selector is False: return try: @@ -782,7 +790,7 @@ def updater(*args, **kwargs): button=1, useblit=True, spancoords='data', - props= \ + props=\ self.selector_rectprops, state_modifier_keys=\ {'square': None, @@ -794,6 +802,7 @@ def updater(*args, **kwargs): msg = 'Failed to create RectangleSelector object. Interactive ' \ 'pixel class labeling will be unavailable.' warnings.warn(msg) + pass def label_region(self, rectangle, class_id): '''Assigns all pixels in the rectangle to the specified class. @@ -844,8 +853,8 @@ def _select_rectangle(self, event1, event2): (c1, c2) = sorted([c1, c2]) if (r2 < 0) or (r1 >= self._image_shape[0]) or \ (c2 < 0) or (c1 >= self._image_shape[1]): - self.selection = None - return + self.selection = None + return r1 = max(r1, 0) r2 = min(r2, self._image_shape[0] - 1) c1 = max(c1, 0) @@ -856,7 +865,7 @@ def _select_rectangle(self, event1, event2): # Make the rectangle display until at least the next event self.selector.set_visible(True) self.selector.update() - + def _guess_mode(self): '''Select an appropriate display mode, based on current data.''' if self.data_rgb is not None: @@ -887,7 +896,6 @@ def show_classes(self): '''Show the class values.''' import matplotlib.pyplot as plt from matplotlib.colors import ListedColormap, NoNorm - from spectral import get_rgb if self.class_axes is not None: msg = 'ImageView.show_classes should only be called once.' @@ -899,7 +907,7 @@ def show_classes(self): cm = ListedColormap(np.array(self.class_colors) / 255.) self._update_class_rgb() kwargs = self.imshow_class_kwargs.copy() - kwargs.update({'cmap': cm, 'norm':NoNorm(), + kwargs.update({'cmap': cm, 'norm': NoNorm(), 'interpolation': self._interpolation}) if self.axes is not None: # A figure has already been created for the view. Make it current. @@ -912,7 +920,7 @@ def show_classes(self): self.class_axes.set_alpha(self._class_alpha) else: self.class_axes.set_alpha(1) - #self.class_axes.axes.set_axis_bgcolor('black') + # self.class_axes.axes.set_axis_bgcolor('black') def refresh(self): '''Updates the displayed data (if it has been shown).''' @@ -932,10 +940,10 @@ def refresh(self): def _update_class_rgb(self): if self.display_mode == 'overlay': - self.class_rgb = np.ma.array(self.classes, mask=(self.classes==0)) + self.class_rgb = np.ma.array(self.classes, mask=(self.classes == 0)) else: self.class_rgb = np.array(self.classes) - + def set_display_mode(self, mode): '''`mode` must be one of ("data", "classes", "overlay").''' if mode not in ('data', 'classes', 'overlay'): @@ -1072,7 +1080,6 @@ def zoom(self, scale): self.axes.set_ylim(y - dy, y + dy) self.refresh() - def format_coord(self, x, y): '''Formats pixel coordinate string displayed in the window.''' (nrows, ncols) = self._image_shape @@ -1092,7 +1099,7 @@ def __str__(self): s = 'ImageView object:\n' if 'bands' in meta: s += ' {0:<20}: {1}\n'.format("Display bands", meta['bands']) - if self.interpolation == None: + if self.interpolation is None: interp = "" else: interp = self.interpolation @@ -1105,7 +1112,7 @@ def __str__(self): def __repr__(self): return str(self) - + def imshow(data=None, bands=None, classes=None, source=None, colors=None, figsize=None, fignum=None, title=None, **kwargs): @@ -1189,7 +1196,7 @@ def imshow(data=None, bands=None, classes=None, source=None, colors=None, >>> classes = open_image('92AV3GT.GIS').read_band(0) >>> cview = imshow(classes=classes) - + Overlay ground truth data on the data display: >>> view.set_classes(classes) @@ -1211,8 +1218,6 @@ def imshow(data=None, bands=None, classes=None, source=None, colors=None, keyboard commands accepted by the display. ''' import matplotlib.pyplot as plt - from spectral import settings - from .graphics import get_rgb set_mpl_interactive() @@ -1234,7 +1239,7 @@ def imshow(data=None, bands=None, classes=None, source=None, colors=None, if title is not None: view.set_title(title) return view - + def plot(data, source=None): ''' @@ -1282,10 +1287,10 @@ def plot(data, source=None): plt.xlabel(xlabel) return p + def set_mpl_interactive(): '''Ensure matplotlib is in interactive mode.''' import matplotlib.pyplot as plt if not plt.isinteractive(): plt.interactive(True) - diff --git a/spectral/graphics/spywxpython.py b/spectral/graphics/spywxpython.py index cb2cb61..e01e5bb 100644 --- a/spectral/graphics/spywxpython.py +++ b/spectral/graphics/spywxpython.py @@ -16,7 +16,6 @@ class SpyWxPythonThreadStarter: def start(self): '''Starts the GUI thread.''' import _thread - import time _thread.start_new_thread(self.run, ()) def run(self): diff --git a/spectral/image.py b/spectral/image.py index be3449d..311332e 100644 --- a/spectral/image.py +++ b/spectral/image.py @@ -9,6 +9,7 @@ from .spectral import BandInfo + class Image(object): '''spectral.Image is the common base class for spectral image objects.''' @@ -17,18 +18,15 @@ def __init__(self, params, metadata=None): self.set_params(params, metadata) def set_params(self, params, metadata): - try: - self.nbands = params.nbands - self.nrows = params.nrows - self.ncols = params.ncols - self.dtype = params.dtype + self.nbands = params.nbands + self.nrows = params.nrows + self.ncols = params.ncols + self.dtype = params.dtype - if not metadata: - self.metadata = {} - else: - self.metadata = metadata - except: - raise + if not metadata: + self.metadata = {} + else: + self.metadata = metadata def params(self): '''Return an object containing the SpyFile parameters.''' @@ -72,7 +70,7 @@ def __init__(self, data, spyfile): Image.__init__(self, params, spyfile.metadata) self.bands = spyfile.bands self.filename = spyfile.filename - self.interleave = 2 # bip + self.interleave = 2 # bip def __repr__(self): lst = np.array2string(np.asarray(self), prefix="ImageArray(") @@ -208,8 +206,7 @@ def __array_wrap__(self, out_arr, context=None): def __getattribute__(self, name): if ((name in np.ndarray.__dict__) and - (name not in ImageArray.__dict__)): + (name not in ImageArray.__dict__)): return getattr(np.asarray(self), name) return super(ImageArray, self).__getattribute__(name) - diff --git a/spectral/io/__init__.py b/spectral/io/__init__.py index 6857054..d632fe2 100644 --- a/spectral/io/__init__.py +++ b/spectral/io/__init__.py @@ -1,3 +1,5 @@ +# flake8: noqa + from __future__ import absolute_import, division, print_function, unicode_literals from .spyfile import SpyFile diff --git a/spectral/io/aviris.py b/spectral/io/aviris.py index e9e4a75..1f1693f 100644 --- a/spectral/io/aviris.py +++ b/spectral/io/aviris.py @@ -5,7 +5,6 @@ from __future__ import absolute_import, division, print_function, unicode_literals import numpy as np -import glob import os import spectral as spy @@ -19,6 +18,7 @@ else: import __builtin__ as builtins + def open(file, band_file=None): ''' Returns a SpyFile object for an AVIRIS image file. @@ -96,4 +96,3 @@ def read_aviris_bands(cal_filename): bands.bandwidth_stdevs = columns[3] bands.band_unit = 'nm' return bands - diff --git a/spectral/io/bilfile.py b/spectral/io/bilfile.py index ec3f641..4c138a8 100644 --- a/spectral/io/bilfile.py +++ b/spectral/io/bilfile.py @@ -11,7 +11,7 @@ import sys import spectral as spy -from ..utilities.python23 import typecode, tobytes, frombytes +from ..utilities.python23 import typecode, tobytes from .spyfile import SpyFile, MemmapFile byte_typecode = typecode('b') @@ -58,7 +58,7 @@ def read_band(self, band, use_memmap=True): Specifies whether the file's memmap interface should be used to read the data. Setting this arg to True only has an effect if a memmap is being used (i.e., if `img.using_memmap` is True). - + Returns: :class:`numpy.ndarray` @@ -103,7 +103,7 @@ def read_bands(self, bands, use_memmap=True): Specifies whether the file's memmap interface should be used to read the data. Setting this arg to True only has an effect if a memmap is being used (i.e., if `img.using_memmap` is True). - + Returns: :class:`numpy.ndarray` @@ -154,7 +154,7 @@ def read_pixel(self, row, col, use_memmap=True): Specifies whether the file's memmap interface should be used to read the data. Setting this arg to True only has an effect if a memmap is being used (i.e., if `img.using_memmap` is True). - + Returns: :class:`numpy.ndarray` @@ -168,7 +168,6 @@ def read_pixel(self, row, col, use_memmap=True): return data vals = array.array(byte_typecode) - delta = self.sample_size * (self.nbands - 1) offset = self.offset + row * self.nbands * self.ncols \ * self.sample_size + col * self.sample_size f = self.fid @@ -211,7 +210,7 @@ def read_subregion(self, row_bounds, col_bounds, bands=None, Specifies whether the file's memmap interface should be used to read the data. Setting this arg to True only has an effect if a memmap is being used (i.e., if `img.using_memmap` is True). - + Returns: :class:`numpy.ndarray` @@ -291,7 +290,7 @@ def read_subimage(self, rows, cols, bands=None, use_memmap=False): Specifies whether the file's memmap interface should be used to read the data. Setting this arg to True only has an effect if a memmap is being used (i.e., if `img.using_memmap` is True). - + Returns: :class:`numpy.ndarray` @@ -325,8 +324,6 @@ def read_subimage(self, rows, cols, bands=None, use_memmap=False): bands = list(range(self.nbands)) nSubBands = len(bands) - arr = np.empty((nSubRows, nSubCols, nSubBands), self.dtype) - offset = self.offset vals = array.array(byte_typecode) sample_size = self.sample_size @@ -361,7 +358,7 @@ def read_datum(self, i, j, k, use_memmap=True): Specifies whether the file's memmap interface should be used to read the data. Setting this arg to True only has an effect if a memmap is being used (i.e., if `img.using_memmap` is True). - + Using this function is not an efficient way to iterate over bands or pixels. For such cases, use readBands or readPixel instead. ''' diff --git a/spectral/io/bipfile.py b/spectral/io/bipfile.py index 9ee92e1..fa4f4b3 100644 --- a/spectral/io/bipfile.py +++ b/spectral/io/bipfile.py @@ -12,10 +12,11 @@ import spectral as spy from .spyfile import SpyFile, MemmapFile -from spectral.utilities.python23 import typecode, tobytes, frombytes +from spectral.utilities.python23 import typecode, tobytes byte_typecode = typecode('b') + class BipFile(SpyFile, MemmapFile): ''' A class to interface image files stored with bands interleaved by pixel. @@ -55,7 +56,7 @@ def read_band(self, band, use_memmap=True): Specifies whether the file's memmap interface should be used to read the data. Setting this arg to True only has an effect if a memmap is being used (i.e., if `img.using_memmap` is True). - + Returns: :class:`numpy.ndarray` @@ -104,7 +105,7 @@ def read_bands(self, bands, use_memmap=True): Specifies whether the file's memmap interface should be used to read the data. Setting this arg to True only has an effect if a memmap is being used (i.e., if `img.using_memmap` is True). - + Returns: :class:`numpy.ndarray` @@ -159,7 +160,7 @@ def read_pixel(self, row, col, use_memmap=True): Specifies whether the file's memmap interface should be used to read the data. Setting this arg to True only has an effect if a memmap is being used (i.e., if `img.using_memmap` is True). - + Returns: :class:`numpy.ndarray` @@ -211,7 +212,7 @@ def read_subregion(self, row_bounds, col_bounds, bands=None, Specifies whether the file's memmap interface should be used to read the data. Setting this arg to True only has an effect if a memmap is being used (i.e., if `img.using_memmap` is True). - + Returns: :class:`numpy.ndarray` @@ -237,7 +238,6 @@ def read_subregion(self, row_bounds, col_bounds, bands=None, d_row = self.sample_size * self.ncols * self.nbands colStartPos = col_bounds[0] * self.sample_size * self.nbands vals = array.array(byte_typecode) - nVals = self.nrows * self.ncols sample_size = self.sample_size # Increments between bands @@ -300,7 +300,7 @@ def read_subimage(self, rows, cols, bands=None, use_memmap=False): Specifies whether the file's memmap interface should be used to read the data. Setting this arg to True only has an effect if a memmap is being used (i.e., if `img.using_memmap` is True). - + Returns: :class:`numpy.ndarray` @@ -319,14 +319,12 @@ def read_subimage(self, rows, cols, bands=None, use_memmap=False): return data offset = self.offset - nbands = self.nbands nSubRows = len(rows) # Rows in sub-image nSubCols = len(cols) # Cols in sub-image d_band = self.sample_size d_col = d_band * self.nbands d_row = d_col * self.ncols vals = array.array(byte_typecode) - nVals = self.nrows * self.ncols sample_size = self.sample_size # Increments between bands @@ -375,7 +373,7 @@ def read_datum(self, i, j, k, use_memmap=True): Specifies whether the file's memmap interface should be used to read the data. Setting this arg to True only has an effect if a memmap is being used (i.e., if `img.using_memmap` is True). - + Using this function is not an efficient way to iterate over bands or pixels. For such cases, use readBands or readPixel instead. ''' diff --git a/spectral/io/bsqfile.py b/spectral/io/bsqfile.py index 8b1a021..dc8ae29 100644 --- a/spectral/io/bsqfile.py +++ b/spectral/io/bsqfile.py @@ -11,7 +11,7 @@ import sys import spectral as spy -from ..utilities.python23 import typecode, tobytes, frombytes +from ..utilities.python23 import typecode, tobytes from .spyfile import SpyFile, MemmapFile byte_typecode = typecode('b') @@ -43,7 +43,6 @@ def _open_memmap(self, mode): else: return None - def read_band(self, band, use_memmap=True): '''Reads a single band from the image. @@ -58,7 +57,7 @@ def read_band(self, band, use_memmap=True): Specifies whether the file's memmap interface should be used to read the data. Setting this arg to True only has an effect if a memmap is being used (i.e., if `img.using_memmap` is True). - + Returns: :class:`numpy.ndarray` @@ -102,7 +101,7 @@ def read_bands(self, bands, use_memmap=False): Specifies whether the file's memmap interface should be used to read the data. Setting this arg to True only has an effect if a memmap is being used (i.e., if `img.using_memmap` is True). - + Returns: :class:`numpy.ndarray` @@ -152,7 +151,7 @@ def read_pixel(self, row, col, use_memmap=True): Specifies whether the file's memmap interface should be used to read the data. Setting this arg to True only has an effect if a memmap is being used (i.e., if `img.using_memmap` is True). - + Returns: :class:`numpy.ndarray` @@ -166,14 +165,10 @@ def read_pixel(self, row, col, use_memmap=True): return data vals = array.array(byte_typecode) - delta = self.sample_size * (self.nbands - 1) - offset = self.offset + row * self.nbands * self.ncols \ - * self.sample_size + col * self.sample_size f = self.fid nPixels = self.nrows * self.ncols - ncols = self.ncols sampleSize = self.sample_size bandSize = sampleSize * nPixels rowSize = sampleSize * self.ncols @@ -216,7 +211,7 @@ def read_subregion(self, row_bounds, col_bounds, bands=None, Specifies whether the file's memmap interface should be used to read the data. Setting this arg to True only has an effect if a memmap is being used (i.e., if `img.using_memmap` is True). - + Returns: :class:`numpy.ndarray` @@ -255,8 +250,6 @@ def read_subregion(self, row_bounds, col_bounds, bands=None, bandSize = nrows * ncols * sampleSize colStartOffset = col_bounds[0] * sampleSize rowSize = ncols * sampleSize - rowStartOffset = row_bounds[0] * rowSize - nSubBands = len(bands) # Pixel format is BSQ for i in bands: @@ -301,7 +294,7 @@ def read_subimage(self, rows, cols, bands=None, use_memmap=False): Specifies whether the file's memmap interface should be used to read the data. Setting this arg to True only has an effect if a memmap is being used (i.e., if `img.using_memmap` is True). - + Returns: :class:`numpy.ndarray` @@ -322,9 +315,6 @@ def read_subimage(self, rows, cols, bands=None, use_memmap=False): nSubRows = len(rows) # Rows in sub-image nSubCols = len(cols) # Cols in sub-image - d_col = self.sample_size - d_band = d_col * self.ncols - d_row = d_band * self.nbands f = self.fid f.seek(self.offset, 0) @@ -379,7 +369,7 @@ def read_datum(self, i, j, k, use_memmap=True): Specifies whether the file's memmap interface should be used to read the data. Setting this arg to True only has an effect if a memmap is being used (i.e., if `img.using_memmap` is True). - + Using this function is not an efficient way to iterate over bands or pixels. For such cases, use readBands or readPixel instead. ''' diff --git a/spectral/io/envi.py b/spectral/io/envi.py index fc00d2b..4fa4b47 100644 --- a/spectral/io/envi.py +++ b/spectral/io/envi.py @@ -32,7 +32,6 @@ InvalidFileError, SpyFile) - if IS_PYTHON3: import builtins else: @@ -58,10 +57,12 @@ envi_to_dtype = dict((k, np.dtype(v).char) for (k, v) in dtype_map) dtype_to_envi = dict(tuple(reversed(item)) for item in list(envi_to_dtype.items())) + class EnviException(SpyException): '''Base class for ENVI file-related exceptions.''' pass + class EnviDataTypeError(EnviException, TypeError): '''Raised when saving invalid image data type to ENVI format. ''' @@ -71,41 +72,49 @@ def __init__(self, dtype): 'data type names.'.format(np.dtype(dtype).name) super(EnviDataTypeError, self).__init__(msg) + class EnviFeatureNotSupported(EnviException, NotImplementedError): '''A specified ENVI capability is not supported by the spectral module.''' pass + class FileNotAnEnviHeader(EnviException, InvalidFileError): '''Raised when "ENVI" does not appear on the first line of the file.''' def __init__(self, msg): super(FileNotAnEnviHeader, self).__init__(msg) + class MissingEnviHeaderParameter(EnviException): '''Raised when a mandatory header parameter is missing.''' def __init__(self, param): msg = 'Mandatory parameter "%s" missing from header file.' % param super(MissingEnviHeaderParameter, self).__init__(msg) + class EnviHeaderParsingError(EnviException, InvalidFileError): '''Raised upon failure to parse parameter/value pairs from a file.''' def __init__(self): msg = 'Failed to parse ENVI header file.' super(EnviHeaderParsingError, self).__init__(msg) + class EnviDataFileNotFoundError(EnviException, FileNotFoundError): '''Raised when data file associated with a header is not found.''' pass + def _validate_dtype(dtype): '''Raises EnviDataTypeError if dtype can not be written to ENVI file.''' typename = np.dtype(dtype).name if typename not in [np.dtype(t).name for t in list(dtype_to_envi.keys())]: raise EnviDataTypeError(dtype) + def get_supported_dtypes(): '''Returns list of names of image data types supported by ENVI format.''' return [np.dtype(t).name for t in list(dtype_to_envi.keys())] + def read_envi_header(file): ''' USAGE: hdr = read_envi_header(file) @@ -139,8 +148,10 @@ def read_envi_header(file): try: while lines: line = lines.pop(0) - if line.find('=') == -1: continue - if line[0] == ';': continue + if line.find('=') == -1: + continue + if line[0] == ';': + continue (key, sep, val) = line.partition('=') key = key.strip() @@ -153,7 +164,8 @@ def read_envi_header(file): str = val.strip() while str[-1] != '}': line = lines.pop(0) - if line[0] == ';': continue + if line[0] == ';': + continue str += '\n' + line.strip() if key == 'description': @@ -209,6 +221,7 @@ class Params: p.filename = None return p + def _has_frame_offset(params): ''' Returns True if header params indicate non-zero frame offsets. @@ -237,6 +250,7 @@ def _has_frame_offset(params): return True return False + def check_compatibility(header): ''' Verifies that all features of an ENVI header are supported. @@ -254,6 +268,7 @@ def check_compatibility(header): raise EnviFeatureNotSupported( 'ENVI image frame offsets are not supported.') + def open(file, image=None): ''' Opens an image or spectral library with an associated ENVI HDR header file. @@ -341,12 +356,12 @@ def open(file, image=None): try: img.bands.centers = [float(b) for b in h['wavelength']] except: - pass + logger.warning('Unable to parse "wavelength" field from header') if 'fwhm' in h: try: img.bands.bandwidths = [float(f) for f in h['fwhm']] except: - pass + logger.warning('Unable to parse "fwhm" field from header') img.bands.band_unit = h.get('wavelength units', None) if 'bbl' in h: @@ -512,13 +527,13 @@ def save_classification(hdr_file, image, **kwargs): For classification results, specifies the names to assign each integer in the class map being written. If not given, default - class names are created. + class names are created. `class_colors` (array of RGB-tuples): For classification results, specifies colors to assign each integer in the class map being written. If not given, default - colors are automatically generated. + colors are automatically generated. If the source image being saved was already in ENVI format, then the SpyFile object for that image will contain a `metadata` dict that can be @@ -536,14 +551,14 @@ class names are created. # guess the number of classes and create default class names n_classes = int(np.max(data) + 1) metadata['classes'] = str(n_classes) - metadata['class names'] = (['Unclassified'] + + metadata['class names'] = (['Unclassified'] + ['Class ' + str(i) for i in range(1, n_classes)]) # if keyword is given, override whatever is in the metadata dict else: n_classes = int(max(np.max(data) + 1, len(class_names))) metadata['class names'] = class_names metadata['classes'] = str(n_classes) - + # the resulting value for 'class lookup' needs to be a flattened array. colors = [] if class_colors is not None: @@ -562,10 +577,9 @@ class names are created. _write_image(hdr_file, data, metadata, **kwargs) + def _prepared_data_and_metadata(hdr_file, image, **kwargs): - ''' - Return data array and metadata dict representing `image`. - ''' + '''Return data array and metadata dict representing `image`.''' endian_out = str(kwargs.get('byteorder', sys.byteorder)).lower() if endian_out in ('0', 'little'): endian_out = 'little' @@ -579,23 +593,19 @@ def _prepared_data_and_metadata(hdr_file, image, **kwargs): src_interleave = 'bip' if len(data.shape) == 2: data = data[:, :, np.newaxis] - swap = False - metadata = {} + metadata = {} elif isinstance(image, SpyFile): if image.using_memmap is True: data = image._memmap src_interleave = {spy.BSQ: 'bsq', spy.BIL: 'bil', spy.BIP: 'bip'}[image.interleave] - swap = image.swap else: data = image.load(dtype=image.dtype, scale=False) src_interleave = 'bip' - swap = False metadata = image.metadata.copy() else: data = image.load() src_interleave = 'bip' - swap = False if hasattr(image, 'metadata'): metadata = image.metadata.copy() else: @@ -665,7 +675,7 @@ def add_band_info_to_metadata(bands, metadata, overwrite=False): if bands.band_unit is not None and (overwrite is True or 'wavelength units' not in metadata): metadata['wavelength units'] = bands.band_unit - + def _write_image(hdr_file, data, header, **kwargs): ''' @@ -674,7 +684,7 @@ def _write_image(hdr_file, data, header, **kwargs): check_compatibility(header) force = kwargs.get('force', False) img_ext = kwargs.get('ext', '.img') - + (hdr_file, img_file) = check_new_filename(hdr_file, img_ext, force) write_envi_header(hdr_file, header, is_library=False) logger.debug('Saving', img_file) @@ -784,7 +794,7 @@ def create_image(hdr_file, metadata=None, **kwargs): (hdr_file, img_file) = check_new_filename(hdr_file, img_ext, force) default_metadata = {'header offset': 0, 'interleave': 'bip'} - + if metadata is None: metadata = default_metadata else: @@ -825,7 +835,7 @@ def create_image(hdr_file, metadata=None, **kwargs): dt = np.dtype(params.dtype).char _validate_dtype(dt) params.filename = img_file - + is_library = False if metadata.get('file type') == 'ENVI Spectral Library': is_library = True @@ -968,6 +978,7 @@ def save(self, file_basename, description=None): self.spectra.astype('f').tofile(fout) fout.close() + def _write_header_param(fout, paramName, paramVal): if paramName.lower() == 'description': valStr = '{\n%s}' % '\n'.join([' ' + line for line @@ -1000,4 +1011,3 @@ def write_envi_header(fileName, header_dict, is_library=False): if k not in std_params: _write_header_param(fout, k, d[k]) fout.close() - diff --git a/spectral/io/erdas.py b/spectral/io/erdas.py index 423eb87..d95cfb2 100644 --- a/spectral/io/erdas.py +++ b/spectral/io/erdas.py @@ -61,13 +61,11 @@ import array import numpy as np -import sys import spectral as spy from ..utilities.python23 import IS_PYTHON3, typecode from .bilfile import BilFile from .spyfile import find_file_path, InvalidFileError -from .spyfile import InvalidFileError if IS_PYTHON3: import builtins @@ -75,7 +73,6 @@ import __builtin__ as builtins - def open(file): ''' Returns a SpyFile object for an ERDAS/Lan image file. @@ -205,4 +202,3 @@ def read_erdas_lan_header(fileName, byte_order=0): f.close() return h - diff --git a/spectral/io/spyfile.py b/spectral/io/spyfile.py index 8e435de..ab14871 100644 --- a/spectral/io/spyfile.py +++ b/spectral/io/spyfile.py @@ -90,16 +90,18 @@ from .. import SpyException from ..image import Image, ImageArray from ..utilities.errors import has_nan, NaNValueWarning -from ..utilities.python23 import typecode, tobytes, frombytes +from ..utilities.python23 import typecode, tobytes class FileNotFoundError(SpyException): pass + class InvalidFileError(SpyException): '''Raised when file contents are invalid for the exepected file type.''' pass + def find_file_path(filename): ''' Search cwd and SPECTRAL_DATA directories for the given file. @@ -116,7 +118,7 @@ def find_file_path(filename): if not pathname: msg = 'Unable to locate file "%s". If the file exists, ' \ 'use its full path or place its directory in the ' \ - 'SPECTRAL_DATA environment variable.' % filename + 'SPECTRAL_DATA environment variable.' % filename raise FileNotFoundError(msg) return pathname @@ -132,23 +134,19 @@ def __init__(self, params, metadata=None): def set_params(self, params, metadata): Image.set_params(self, params, metadata) - try: - self.filename = params.filename - self.offset = params.offset - self.byte_order = params.byte_order - if spy.byte_order != self.byte_order: - self.swap = 1 - else: - self.swap = 0 - self.sample_size = np.dtype(params.dtype).itemsize - - self.fid = open(find_file_path(self.filename), "rb") + self.filename = params.filename + self.offset = params.offset + self.byte_order = params.byte_order + if spy.byte_order != self.byte_order: + self.swap = 1 + else: + self.swap = 0 + self.sample_size = np.dtype(params.dtype).itemsize - # So that we can use this more like a Numeric array - self.shape = (self.nrows, self.ncols, self.nbands) + self.fid = open(find_file_path(self.filename), "rb") - except: - raise + # So that we can use this more like a Numeric array + self.shape = (self.nrows, self.ncols, self.nbands) def transform(self, xform): '''Returns a SpyFile image with the linear transform applied.''' @@ -224,7 +222,7 @@ def load(self, **kwargs): imarray = ImageArray(npArray, self) if has_nan(imarray): warnings.warn('Image data contains NaN values.', NaNValueWarning) - return imarray + return imarray def __getitem__(self, args): '''Subscripting operator that provides a numpy-like interface. @@ -310,7 +308,7 @@ def __getitem__(self, args): elif atypes[2] == slice: (zstart, zstop, zstep) = (args[2].start, args[2].stop, args[2].step) - if zstart == zstop == zstep == None: + if zstart == zstop == zstep is None: bands = None else: if zstart is None: @@ -572,6 +570,7 @@ def tile_image(im, nrows, ncols): tiles.append(row) return tiles + def transform_image(transform, img): '''Applies a linear transform to an image. @@ -746,6 +745,7 @@ def read_bands(self, bands): data[i, j] = self.read_pixel(i, j)[bands] return data + class MemmapFile(object): '''Interface class for SpyFile subclasses using `numpy.memmap` objects.''' @@ -788,7 +788,7 @@ def open_memmap(self, **kwargs): If `writable` is True, modifying values in the returned memmap will result in corresponding modification to the image data file. - ''' + ''' src_inter = {spy.BIL: 'bil', spy.BIP: 'bip', spy.BSQ: 'bsq'}[self.interleave] @@ -826,6 +826,7 @@ def asarray(self, writable=False): ''' return self.open_memmap(writable=writable) + def interleave_transpose(int1, int2): '''Returns the 3-tuple of indices to transpose between interleaves. diff --git a/spectral/spectral.py b/spectral/spectral.py index cf82cd8..bf9eb5d 100644 --- a/spectral/spectral.py +++ b/spectral/spectral.py @@ -5,18 +5,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals import logging -import numbers -import numpy as np -import pickle -import os from warnings import warn -#from .algorithms.algorithms import TrainingClassSet -#from . import io -#from .io import aviris, envi, erdas, spyfile -#from .io.spyfile import find_file_path, SpyFile +from . import settings # noqa: F401 -from . import settings def _init(): '''Basic configuration of the spectral package.''' @@ -36,6 +28,7 @@ def _init(): spectral = __import__(__name__.split('.')[0]) spectral._status = status.StatusDisplay() + def _setup_logger(): logger = logging.getLogger('spectral') logger.setLevel(logging.INFO) @@ -44,6 +37,7 @@ def _setup_logger(): ch.setFormatter(formatter) logger.addHandler(ch) + class BandInfo: '''A BandInfo object characterizes the spectral bands associated with an image. All BandInfo member variables are optional. For *N* bands, all @@ -129,4 +123,3 @@ def load_training_sets(file, image=None): ts = TrainingClassSet() ts.load(file, image) return ts - diff --git a/spectral/tests/__init__.py b/spectral/tests/__init__.py index 94d2a41..dc79b2b 100644 --- a/spectral/tests/__init__.py +++ b/spectral/tests/__init__.py @@ -6,6 +6,8 @@ # python -m spectral.tests.run ''' +# flake8: noqa + from __future__ import absolute_import, division, print_function, unicode_literals # If abort_on_fail is True, an AssertionError will be raised when a unit test diff --git a/spectral/tests/classifiers.py b/spectral/tests/classifiers.py index d71f54c..09640cc 100644 --- a/spectral/tests/classifiers.py +++ b/spectral/tests/classifiers.py @@ -12,9 +12,10 @@ import numpy as np import spectral as spy from numpy.testing import assert_allclose -from .spytest import SpyTest, test_method +from .spytest import SpyTest from spectral.tests import testdir + class ClassifierTest(SpyTest): '''Tests various classfication functions.''' @@ -42,7 +43,7 @@ def test_load_training_sets(self): for id in ids: s1 = ts[id] s2 = ts2[id] - assert(s1.index == s2.index) + assert (s1.index == s2.index) np.testing.assert_almost_equal(s1.class_prob, s2.class_prob) assert_allclose(s1.stats.mean, s2.stats.mean) assert_allclose(s1.stats.cov, s2.stats.cov) @@ -52,33 +53,33 @@ def test_gmlc_spectrum_image_equal(self): '''Tests that classification of spectrum is same as from image.''' gmlc = spy.GaussianClassifier(self.ts, min_samples=600) data = self.data[20: 30, 30: 40, :] - assert(gmlc.classify_spectrum(data[2, 2]) == \ - gmlc.classify_image(data)[2, 2]) + assert (gmlc.classify_spectrum(data[2, 2]) == + gmlc.classify_image(data)[2, 2]) def test_gmlc_classify_spyfile_runs(self): '''Tests that GaussianClassifier classifies a SpyFile object.''' gmlc = spy.GaussianClassifier(self.ts, min_samples=600) - ret = gmlc.classify_image(self.image) + gmlc.classify_image(self.image) def test_gmlc_classify_transformedimage_runs(self): '''Tests that GaussianClassifier classifies a TransformedImage object.''' pc = spy.principal_components(self.data).reduce(num=3) ximg = pc.transform(self.image) ts = spy.create_training_classes(pc.transform(self.data), self.gt, - calc_stats=True) + calc_stats=True) gmlc = spy.GaussianClassifier(ts) - ret = gmlc.classify_image(ximg) + gmlc.classify_image(ximg) def test_gmlc_classify_ndarray_transformedimage_equal(self): '''Gaussian classification of an ndarray and TransformedImage are equal''' pc = spy.principal_components(self.data).reduce(num=3) ximg = pc.transform(self.image) ts = spy.create_training_classes(pc.transform(self.data), self.gt, - calc_stats=True) + calc_stats=True) gmlc = spy.GaussianClassifier(ts) cl_ximg = gmlc.classify_image(ximg) cl_ndarray = gmlc.classify_image(pc.transform(self.data)) - assert(np.all(cl_ximg == cl_ndarray)) + assert (np.all(cl_ximg == cl_ndarray)) def test_mahalanobis_class_mean(self): '''Test that a class's mean spectrum is classified as that class. @@ -86,38 +87,38 @@ def test_mahalanobis_class_mean(self): ''' mdc = spy.MahalanobisDistanceClassifier(self.ts) cl = mdc.classes[0] - assert(mdc.classify(cl.stats.mean) == cl.index) + assert (mdc.classify(cl.stats.mean) == cl.index) def test_mahalanobis_classify_spyfile_runs(self): '''Mahalanobis classifier works with a SpyFile object.''' mdc = spy.MahalanobisDistanceClassifier(self.ts) - ret = mdc.classify_image(self.image) + mdc.classify_image(self.image) def test_mahalanobis_classify_transformedimage_runs(self): '''Mahalanobis classifier works with a TransformedImage object.''' pc = spy.principal_components(self.data).reduce(num=3) ximg = pc.transform(self.image) ts = spy.create_training_classes(pc.transform(self.data), self.gt, - calc_stats=True) + calc_stats=True) gmlc = spy.MahalanobisDistanceClassifier(ts) - ret = gmlc.classify_image(ximg) + gmlc.classify_image(ximg) def test_mahalanobis_classify_ndarray_transformedimage_equal(self): '''Mahalanobis classification of ndarray and TransformedImage are equal''' pc = spy.principal_components(self.data).reduce(num=3) ximg = pc.transform(self.image) ts = spy.create_training_classes(pc.transform(self.data), self.gt, - calc_stats=True) + calc_stats=True) mdc = spy.GaussianClassifier(ts) cl_ximg = mdc.classify_image(ximg) cl_ndarray = mdc.classify_image(pc.transform(self.data)) - assert(np.all(cl_ximg == cl_ndarray)) + assert (np.all(cl_ximg == cl_ndarray)) def test_perceptron_learns_and(self): '''Test that 2x1 network can learn the logical AND function.''' from spectral.algorithms.perceptron import test_and (success, p) = test_and(stdout=None) - assert(success) + assert (success) def test_perceptron_learns_xor(self): '''Test that 2x2x1 network can learn the logical XOR function.''' @@ -127,7 +128,7 @@ def test_perceptron_learns_xor(self): (success, p) = test_xor231(3000, stdout=None) if success is True: return - assert(False) + assert (False) def test_perceptron_learns_xor_222(self): '''Test that 2x2x2 network can learn the logical XOR function.''' @@ -137,7 +138,7 @@ def test_perceptron_learns_xor_222(self): (success, p) = test_xor222(3000, stdout=None) if success is True: return - assert(False) + assert (False) def test_perceptron_learns_image_classes(self): '''Test that perceptron can learn image class means.''' @@ -152,14 +153,14 @@ def test_perceptron_learns_image_classes(self): rate=0.3) if success is True: return - assert(False) + assert (False) def test_mahalanobis_spectrum_image_equal(self): '''Tests that classification of spectrum is same as from image.''' mdc = spy.MahalanobisDistanceClassifier(self.ts) data = self.data[20: 30, 30: 40, :] - assert(mdc.classify_spectrum(data[2, 2]) == \ - mdc.classify_image(data)[2, 2]) + assert (mdc.classify_spectrum(data[2, 2]) == + mdc.classify_image(data)[2, 2]) def run(): @@ -169,6 +170,7 @@ def run(): test = ClassifierTest() test.run() + if __name__ == '__main__': from spectral.tests.run import parse_args, reset_stats, print_summary parse_args() diff --git a/spectral/tests/continuum.py b/spectral/tests/continuum.py index af93a92..2c70922 100644 --- a/spectral/tests/continuum.py +++ b/spectral/tests/continuum.py @@ -12,7 +12,6 @@ from numpy.testing import assert_allclose import spectral as spy -from spectral.algorithms.spymath import matrix_sqrt from spectral.algorithms.continuum import spectral_continuum, remove_continuum, continuum_points from spectral.tests.spytest import SpyTest @@ -152,26 +151,26 @@ def test_2d_array(self): cnt = spectral_continuum(part, self.bands) # Check some values to make sure results are sane. assert(cnt[0, 200] == 1422) - assert(cnt[1, 200] == 1421) - assert(cnt[2, 200] == 1469) - assert(cnt[3, 200] == 1491) + assert (cnt[1, 200] == 1421) + assert (cnt[2, 200] == 1469) + assert (cnt[3, 200] == 1491) def test_3d_array(self): part = self.image[20:22, 20:22] cnt = spectral_continuum(part, self.bands) # Check some values to make sure results are sane. - assert(cnt[0, 0, 200] == 1422) - assert(cnt[0, 1, 200] == 1421) - assert(cnt[1, 0, 200] == 1469) - assert(cnt[1, 1, 200] == 1491) + assert (cnt[0, 0, 200] == 1422) + assert (cnt[0, 1, 200] == 1421) + assert (cnt[1, 0, 200] == 1469) + assert (cnt[1, 1, 200] == 1491) def test_out_parameter(self): part = self.image[20:22, 20:22] out = np.empty_like(part) cnt = spectral_continuum(part, self.bands, out=out) - assert(cnt is out) + assert (cnt is out) # And just do a quick check if result is sane. - assert(out[1, 1, 200] == 1491) + assert (out[1, 1, 200] == 1491) class FindContinuumPointsTest(ContinuumTest): @@ -179,8 +178,8 @@ class FindContinuumPointsTest(ContinuumTest): def test_points_of_real_spectrum(self): points = continuum_points(self.image[20, 20], self.bands) - assert(np.array_equal(points[0], self.bands[[0, 1, 2, 5, 6, 41, 219]])) - assert(np.array_equal(points[1], np.array( + assert (np.array_equal(points[0], self.bands[[0, 1, 2, 5, 6, 41, 219]])) + assert (np.array_equal(points[1], np.array( [3505, 4141, 4516, 4924, 5002, 4712, 1019], dtype=np.int16))) def test_points_of_real_spectrum_segmented(self): @@ -199,8 +198,8 @@ def test_points_of_real_spectrum_segmented(self): 1333, 1317, 1312, 1219, 1202, 1162, 1126, 1122, 1100, 1068, 1019], dtype=np.int16) ) - assert(np.array_equal(points[0], expected_result[0])) - assert(np.array_equal(points[1], expected_result[1])) + assert (np.array_equal(points[0], expected_result[0])) + assert (np.array_equal(points[1], expected_result[1])) class RemoveContinuumTest(ContinuumTest): @@ -232,8 +231,8 @@ def test_in_and_out_same(self): part = self.image[20:22, 20:22].astype(np.float64) res = remove_continuum(part, self.bands, out=part) # Make sure results are sane. - assert(res[1, 1, 200] == 0.8372113957762342) - assert(res is part) + assert (res[1, 1, 200] == 0.8372113957762342) + assert (res is part) def run(): diff --git a/spectral/tests/database.py b/spectral/tests/database.py index 4a27a39..3ed4a53 100644 --- a/spectral/tests/database.py +++ b/spectral/tests/database.py @@ -11,7 +11,6 @@ from __future__ import absolute_import, division, print_function, unicode_literals -import numpy as np import os from numpy.testing import assert_almost_equal @@ -45,7 +44,7 @@ def test_create_database(self): '''Test creating new database from ECOSTRESS data files.''' db = spy.EcostressDatabase.create(ECOSTRESS_DB, ECOSTRESS_DATA_DIR) - assert(list(db.query('SELECT COUNT() FROM Spectra'))[0][0] == 3) + assert (list(db.query('SELECT COUNT() FROM Spectra'))[0][0] == 3) class ECOSTRESSDatabaseTest(SpyTest): @@ -59,7 +58,7 @@ def setup(self): def test_read_signatures(self): '''Can get spectra from the opened database.''' - assert(list(self.db.query('SELECT COUNT() FROM Spectra'))[0][0] == 3) + assert (list(self.db.query('SELECT COUNT() FROM Spectra'))[0][0] == 3) def test_create_envi_lib(self): '''Can resample spectra and create an ENVI spectral library.''' @@ -69,7 +68,7 @@ def test_create_envi_lib(self): bands.centers = [x / 1000. for x in bands.centers] bands.bandwidths = [x / 1000. for x in bands.bandwidths] slib = self.db.create_envi_spectral_library(ids, bands) - assert(slib.spectra.shape == (3, 220)) + assert (slib.spectra.shape == (3, 220)) class USGSDatabaseCreationTest(SpyTest): @@ -87,8 +86,8 @@ def setup(self): def test_create_database(self): '''Test creating new database from USGS data files.''' db = spy.USGSDatabase.create(USGS_DB, USGS_DATA_DIR) - assert(list(db.query('SELECT COUNT() FROM Samples'))[0][0] == 8) - assert(list(db.query('SELECT COUNT() FROM SpectrometerData')) + assert (list(db.query('SELECT COUNT() FROM Samples'))[0][0] == 8) + assert (list(db.query('SELECT COUNT() FROM SpectrometerData')) [0][0] == 13) @@ -103,8 +102,8 @@ def setup(self): def test_read_signatures(self): '''Can get spectra from the opened database.''' - assert(list(self.db.query('SELECT COUNT() FROM Samples'))[0][0] == 8) - assert(list(self.db.query('SELECT COUNT() FROM SpectrometerData')) + assert (list(self.db.query('SELECT COUNT() FROM Samples'))[0][0] == 8) + assert (list(self.db.query('SELECT COUNT() FROM SpectrometerData')) [0][0] == 13) some_sample = list(self.db.query('''SELECT Chapter, FileName, @@ -115,9 +114,9 @@ def test_read_signatures(self): Description='Material a b0 0 ASDFRa AREF' AND Spectrometer='ASDFR' AND Purity='a' AND MeasurementType='AREF' '''))[0] - assert(some_sample[0] == 'ChapterB_b0') - assert(some_sample[1] == 'liba_Material_a_b0_0_ASDFRa_AREF.txt') - assert(some_sample[3] == 24) + assert (some_sample[0] == 'ChapterB_b0') + assert (some_sample[1] == 'liba_Material_a_b0_0_ASDFRa_AREF.txt') + assert (some_sample[3] == 24) assert_almost_equal(some_sample[4], 0.33387077) assert_almost_equal(some_sample[5], 0.51682192) @@ -126,15 +125,15 @@ def test_read_signatures(self): FROM SpectrometerData WHERE SpectrometerDataID=? ''', (some_sample[2],)))[0] - assert(some_spectrometer_data[0] == 'liba') - assert(some_spectrometer_data[1] == 13) - assert(some_spectrometer_data[2] == 'Wavelengths') - assert(some_spectrometer_data[3] == 'micrometer') - assert(some_spectrometer_data[4] == 'ASD') - assert(some_spectrometer_data[5] == 'Wavelengths ASD 0.35-2.5 um') - assert(some_spectrometer_data[6] == + assert (some_spectrometer_data[0] == 'liba') + assert (some_spectrometer_data[1] == 13) + assert (some_spectrometer_data[2] == 'Wavelengths') + assert (some_spectrometer_data[3] == 'micrometer') + assert (some_spectrometer_data[4] == 'ASD') + assert (some_spectrometer_data[5] == 'Wavelengths ASD 0.35-2.5 um') + assert (some_spectrometer_data[6] == 'liba_Wavelengths_ASD_0.35-2.5_um.txt') - assert(some_spectrometer_data[7] == 24) + assert (some_spectrometer_data[7] == 24) assert_almost_equal(some_spectrometer_data[8], 0.35) assert_almost_equal(some_spectrometer_data[9], 2.5) @@ -144,8 +143,8 @@ def test_get_spectrum(self): WHERE LibName='libc' AND Description='Material D 2 AVIRISb RTGC' '''))[0][0] (x, y) = self.db.get_spectrum(some_sample_id) - assert(len(x) == len(y)) - assert(len(y) == 7) + assert (len(x) == len(y)) + assert (len(y) == 7) assert_almost_equal(y[0], 0.010381651) assert_almost_equal(x[-1], 2.2020326) @@ -157,7 +156,7 @@ def test_create_envi_lib(self): bands.centers = [x / 1000. for x in bands.centers] bands.bandwidths = [x / 1000. for x in bands.bandwidths] slib = self.db.create_envi_spectral_library(ids, bands) - assert(slib.spectra.shape == (8, 220)) + assert (slib.spectra.shape == (8, 220)) def run(): diff --git a/spectral/tests/detectors.py b/spectral/tests/detectors.py index 26bdb21..c981cdf 100644 --- a/spectral/tests/detectors.py +++ b/spectral/tests/detectors.py @@ -9,7 +9,6 @@ from __future__ import absolute_import, division, print_function, unicode_literals import numpy as np -from numpy.testing import assert_allclose import spectral as spy from spectral.tests.spytest import SpyTest @@ -29,7 +28,7 @@ def test_mf_bg_eq_zero(self): '''Matched Filter response of background should be zero.''' (i, j) = self.target_ij np.testing.assert_approx_equal(self.mf(self.background.mean), 0) - + def test_mf_target_eq_one(self): '''Matched Filter response of target should be one.''' from spectral.algorithms.detectors import matched_filter @@ -55,7 +54,7 @@ def test_mf_windowed_target_eq_one(self): '''Windowed Matched Filter response of target pixel should be one.''' X = self.data[:10, :10, :] ij = (3, 3) - y = spy.matched_filter(X, X[ij], window=(3,7), cov=self.background.cov) + y = spy.matched_filter(X, X[ij], window=(3, 7), cov=self.background.cov) np.allclose(1, y[ij]) @@ -65,11 +64,10 @@ def setup(self): self.background = spy.calc_stats(self.data) def test_rx_bg_eq_zero(self): - from spectral.algorithms.detectors import rx, RX - d = rx(self.data) + from spectral.algorithms.detectors import rx stats = spy.calc_stats(self.data) np.testing.assert_approx_equal(rx(stats.mean, background=stats), 0) - + class ACETest(SpyTest): def setup(self): @@ -81,38 +79,38 @@ def test_ace_bg_eq_zero(self): '''ACE score of background mean should be zero.''' ij = (10, 10) y = spy.ace(self.bg.mean, self.X[ij], background=self.bg) - assert(np.allclose(0, y)) - + assert (np.allclose(0, y)) + def test_ace_pixel_target_eq_one(self): '''ACE score of target should be one for single pixel arg.''' ij = (10, 10) y = spy.ace(self.X[ij], self.X[ij], background=self.bg) - assert(np.allclose(1, y)) + assert (np.allclose(1, y)) def test_ace_novec_pixel_target_eq_one(self): '''ACE score of target should be one for single pixel arg.''' ij = (10, 10) y = spy.ace(self.X[ij], self.X[ij], background=self.bg, vectorize=False) - assert(np.allclose(1, y)) + assert (np.allclose(1, y)) def test_ace_target_eq_one(self): '''ACE score of target should be one.''' ij = (10, 10) y = spy.ace(self.X, self.X[ij], background=self.bg) - assert(np.allclose(1, y[ij])) + assert (np.allclose(1, y[ij])) def test_ace_novec_target_eq_one(self): '''ACE score (without vectorization) of target should be one.''' ij = (10, 10) y = spy.ace(self.X, self.X[ij], background=self.bg, vectorize=False) - assert(np.allclose(1, y[ij])) + assert (np.allclose(1, y[ij])) def test_ace_multi_targets_eq_one(self): '''ACE score of multiple targets should each be one.''' ij1 = (10, 10) ij2 = (3, 12) y = spy.ace(self.X, [self.X[ij1], self.X[ij2]], background=self.bg) - assert(np.allclose(1, [y[ij1][0], y[ij2][1]])) + assert (np.allclose(1, [y[ij1][0], y[ij2][1]])) def test_ace_novec_multi_targets_eq_one(self): '''ACE score of multiple targets should each be one.''' @@ -120,7 +118,7 @@ def test_ace_novec_multi_targets_eq_one(self): ij2 = (3, 12) y = spy.ace(self.X, [self.X[ij1], self.X[ij2]], background=self.bg, vectorize=False) - assert(np.allclose(1, [y[ij1][0], y[ij2][1]])) + assert (np.allclose(1, [y[ij1][0], y[ij2][1]])) def test_ace_multi_targets_bg_eq_zero(self): '''ACE score of background for multiple targets should be one.''' @@ -128,7 +126,7 @@ def test_ace_multi_targets_bg_eq_zero(self): ij2 = (3, 12) y = spy.ace(self.bg.mean, [self.X[ij1], self.X[ij2]], background=self.bg) - assert(np.allclose(0, y)) + assert (np.allclose(0, y)) def test_ace_subspace_targets_eq_one(self): '''ACE score of targets defining target subspace should each be one.''' @@ -136,7 +134,7 @@ def test_ace_subspace_targets_eq_one(self): ij2 = (3, 12) y = spy.ace(self.X, np.array([self.X[ij1], self.X[ij2]]), background=self.bg) - assert(np.allclose(1, [y[ij1], y[ij2]])) + assert (np.allclose(1, [y[ij1], y[ij2]])) def test_ace_novec_subspace_targets_eq_one(self): '''ACE score of targets defining target subspace should each be one.''' @@ -144,7 +142,7 @@ def test_ace_novec_subspace_targets_eq_one(self): ij2 = (3, 12) y = spy.ace(self.X, np.array([self.X[ij1], self.X[ij2]]), background=self.bg, vectorize=False) - assert(np.allclose(1, [y[ij1], y[ij2]])) + assert (np.allclose(1, [y[ij1], y[ij2]])) def test_ace_subspace_bg_eq_zero(self): '''ACE score of background for target subspace should be zero.''' @@ -152,13 +150,13 @@ def test_ace_subspace_bg_eq_zero(self): ij2 = (3, 12) y = spy.ace(self.bg.mean, np.array([self.X[ij1], self.X[ij2]]), background=self.bg) - assert(np.allclose(0, y)) + assert (np.allclose(0, y)) def test_ace_windowed_target_eq_one(self): '''ACE score of target for windowed background should be one.''' ij = (10, 10) - y = spy.ace(self.X, self.X[ij], window=(3,7), cov=self.bg.cov) - assert(np.allclose(1, y[ij])) + y = spy.ace(self.X, self.X[ij], window=(3, 7), cov=self.bg.cov) + assert (np.allclose(1, y[ij])) def run(): @@ -168,6 +166,7 @@ def run(): for T in [MatchedFilterTest, RXTest, ACETest]: T().run() + if __name__ == '__main__': from spectral.tests.run import parse_args, reset_stats, print_summary parse_args() diff --git a/spectral/tests/dimensionality.py b/spectral/tests/dimensionality.py index 9ebb49e..5d86b7d 100644 --- a/spectral/tests/dimensionality.py +++ b/spectral/tests/dimensionality.py @@ -11,7 +11,7 @@ import numpy as np import spectral as spy -from spectral.tests.spytest import SpyTest, test_method +from spectral.tests.spytest import SpyTest class DimensionalityTest(SpyTest): @@ -27,17 +27,17 @@ def test_mnf_all_equals_data(self): noise = spy.noise_from_diffs(data[117: 137, 85: 122, :]) mnfr = spy.mnf(signal, noise) denoised = mnfr.denoise(data, num=data.shape[-1]) - assert(np.allclose(denoised, data)) + assert (np.allclose(denoised, data)) def test_ppi(self): '''Tests that ppi function runs''' data = self.data - p = spy.ppi(data, 4) + spy.ppi(data, 4) def test_ppi_threshold(self): '''Tests that ppi function runs with threshold arg''' data = self.data - p = spy.ppi(data, 4, 10) + spy.ppi(data, 4, 10) def test_ppi_continues(self): '''Tests that running ppi with initial indices works as expected.''' @@ -47,18 +47,18 @@ def test_ppi_continues(self): np.random.set_state(s) p2 = spy.ppi(data, 2) p2 = spy.ppi(data, 2, start=p2) - assert(np.all(p == p2)) + assert (np.all(p == p2)) def test_ppi_centered(self): '''Tests that ppi with mean-subtracted data works as expected.''' data = self.data s = np.random.get_state() p = spy.ppi(data, 4) - + np.random.set_state(s) data_centered = data - spy.calc_stats(data).mean p2 = spy.ppi(data_centered, 4) - assert(np.all(p == p2)) + assert (np.all(p == p2)) def test_smacc_minimal(self): '''Tests smacc correctness on minimal example.''' @@ -69,8 +69,8 @@ def test_smacc_minimal(self): [0.0, 1.0, 1.0] ]) S, F, R = spy.smacc(H) - assert(np.allclose(np.matmul(F, S) + R, H)) - assert(np.min(F) == 0.0) + assert (np.allclose(np.matmul(F, S) + R, H)) + assert (np.min(F) == 0.0) expected_S = np.array([ # First two longer ones. [1., 1., 0.], @@ -78,7 +78,7 @@ def test_smacc_minimal(self): # First of the two shorted ones. Other can be expressed other 3. [1., 0., 0.], ]) - assert(np.array_equal(S, expected_S)) + assert (np.array_equal(S, expected_S)) def test_smacc_runs(self): '''Tests that smacc runs without additional arguments.''' @@ -87,9 +87,9 @@ def test_smacc_runs(self): S, F, R = spy.smacc(scaled_data) data_shape = scaled_data.shape H = scaled_data.reshape(data_shape[0] * data_shape[1], data_shape[2]) - assert(np.allclose(np.matmul(F, S) + R, H)) - assert(np.min(F) == 0.0) - assert(len(S.shape) == 2 and S.shape[0] == 9 and S.shape[1] == 220) + assert (np.allclose(np.matmul(F, S) + R, H)) + assert (np.min(F) == 0.0) + assert (len(S.shape) == 2 and S.shape[0] == 9 and S.shape[1] == 220) def test_smacc_min_endmembers(self): '''Tests that smacc runs with min_endmember argument.''' @@ -98,9 +98,9 @@ def test_smacc_min_endmembers(self): S, F, R = spy.smacc(scaled_data, 10) data_shape = scaled_data.shape H = scaled_data.reshape(data_shape[0] * data_shape[1], data_shape[2]) - assert(np.allclose(np.matmul(F, S) + R, H)) - assert(np.min(F) == 0.0) - assert(len(S.shape) == 2 and S.shape[0] == 10 and S.shape[1] == 220) + assert (np.allclose(np.matmul(F, S) + R, H)) + assert (np.min(F) == 0.0) + assert (len(S.shape) == 2 and S.shape[0] == 10 and S.shape[1] == 220) def test_smacc_max_residual_norm(self): '''Tests that smacc runs with max_residual_norm argument.''' @@ -109,21 +109,21 @@ def test_smacc_max_residual_norm(self): S, F, R = spy.smacc(scaled_data, 9, 0.8) data_shape = scaled_data.shape H = scaled_data.reshape(data_shape[0] * data_shape[1], data_shape[2]) - assert(np.allclose(np.matmul(F, S) + R, H)) - assert(np.min(F) == 0.0) + assert (np.allclose(np.matmul(F, S) + R, H)) + assert (np.min(F) == 0.0) residual_norms = np.einsum('ij,ij->i', R, R) - assert(np.max(residual_norms) <= 0.8) + assert (np.max(residual_norms) <= 0.8) def test_pca_runs(self): '''Should be able to compute PCs and transform data.''' data = self.data - xdata = spy.principal_components(data).transform(data) + spy.principal_components(data).transform(data) def test_pca_runs_from_stats(self): '''Should be able to pass image stats to PCA function.''' data = self.data stats = spy.calc_stats(data) - xdata = spy.principal_components(stats).transform(data) + spy.principal_components(stats).transform(data) def test_orthogonalize(self): '''Can correctly create an orthogonal basis from vectors.''' @@ -136,8 +136,8 @@ def test_orthogonalize(self): X = np.array([50 * sin_h, 75 * cos_h]) Y = spy.orthogonalize(X) - assert(np.allclose(Y.dot(Y.T), np.array([[1, 0], [0, 1]]))) - assert(np.allclose(X.dot(Y.T), np.array([[50, 0], [0, 75]]))) + assert (np.allclose(Y.dot(Y.T), np.array([[1, 0], [0, 1]]))) + assert (np.allclose(X.dot(Y.T), np.array([[50, 0], [0, 75]]))) def test_orthogonalize_subset(self): '''Can correctly create an orthogonal basis from vector subset.''' @@ -151,8 +151,8 @@ def test_orthogonalize_subset(self): # First vector in X will already be a unit vector X = np.array([sin_h, 75 * cos_h]) Y = spy.orthogonalize(X, start=1) - assert(np.allclose(Y.dot(Y.T), np.array([[1, 0], [0, 1]]))) - assert(np.allclose(X.dot(Y.T), np.array([[1, 0], [0, 75]]))) + assert (np.allclose(Y.dot(Y.T), np.array([[1, 0], [0, 1]]))) + assert (np.allclose(X.dot(Y.T), np.array([[1, 0], [0, 75]]))) def run(): @@ -162,6 +162,7 @@ def run(): test = DimensionalityTest() test.run() + if __name__ == '__main__': from spectral.tests.run import parse_args, reset_stats, print_summary parse_args() diff --git a/spectral/tests/envi.py b/spectral/tests/envi.py index 815390b..430af13 100644 --- a/spectral/tests/envi.py +++ b/spectral/tests/envi.py @@ -29,6 +29,7 @@ some Param = 0 ''' + class ENVIWriteTest(SpyTest): '''Tests that SpyFile memmap interfaces read and write properly.''' def __init__(self): @@ -37,7 +38,7 @@ def __init__(self): def setup(self): if not os.path.isdir(testdir): os.makedirs(testdir) - + def test_save_image_ndarray(self): '''Test saving an ENVI formated image from a numpy.ndarray.''' (R, B, C) = (10, 20, 30) @@ -57,7 +58,7 @@ def test_save_image_ndarray_no_ext(self): hdr_file = base + '.hdr' spy.envi.save_image(hdr_file, data, ext='') rdata = spy.open_image(hdr_file).load() - assert(np.all(data==rdata)) + assert (np.all(data == rdata)) def test_save_image_ndarray_alt_ext(self): '''Test saving an ENVI formated image with alternate extension.''' @@ -68,7 +69,7 @@ def test_save_image_ndarray_alt_ext(self): img_file = base + ext spy.envi.save_image(hdr_file, data, ext=ext) rdata = spy.envi.open(hdr_file, img_file).load() - assert(np.all(data==rdata)) + assert (np.all(data == rdata)) def test_save_image_spyfile(self): '''Test saving an ENVI formatted image from a SpyFile object.''' @@ -101,7 +102,7 @@ def test_create_image_metadata(self): img = spy.open_image(fname) img._disable_memmap() assert_almost_equal(img[r, b, c], datum) - assert(img.offset == offset) + assert (img.offset == offset) for key in md: assert key.lower() in img.metadata assert str(md[key]) == img.metadata[key.lower()] @@ -113,7 +114,7 @@ def test_create_image_keywords(self): offset = 1024 datum = 33 fname = os.path.join(testdir, 'test_create_image_keywords.hdr') - img = spy.envi.create_image(fname, shape=(R,B,C), + img = spy.envi.create_image(fname, shape=(R, B, C), interleave='bsq', dtype=np.uint16, offset=offset) @@ -124,27 +125,26 @@ def test_create_image_keywords(self): img = spy.open_image(fname) img._disable_memmap() assert_almost_equal(img[r, b, c], datum) - assert(img.offset == offset) + assert (img.offset == offset) def test_save_invalid_dtype_fails(self): - '''Should not be able to write unsupported data type to file.''' + '''Should not be able to write unsupported data type to file.''' from spectral.io.envi import EnviDataTypeError a = np.random.randint(0, 200, 900).reshape((30, 30)).astype(np.int8) - fname = os.path.join(testdir, 'test_save_invalid_dtype_fails.hdr') try: spy.envi.save_image('invalid.hdr', a) - except EnviDataTypeError as e: + except EnviDataTypeError: pass else: raise Exception('Expected EnviDataTypeError to be raised.') - + def test_save_load_classes(self): '''Verify that `envi.save_classification` saves data correctly.''' fname = os.path.join(testdir, 'test_save_load_classes.hdr') gt = spy.open_image('92AV3GT.GIS').read_band(0) spy.envi.save_classification(fname, gt, dtype=np.uint8) gt2 = spy.open_image(fname).read_band(0) - assert(np.all(gt == gt2)) + assert (np.all(gt == gt2)) def test_open_nonzero_frame_offset_fails(self): '''Opening files with nonzero frame offsets should fail.''' @@ -155,7 +155,7 @@ def test_open_nonzero_frame_offset_fails(self): fout.write('major frame offsets = 128\n') fout.close() try: - img2 = spy.envi.open(fname) + spy.envi.open(fname) except spy.envi.EnviFeatureNotSupported: pass else: @@ -170,13 +170,13 @@ def test_open_zero_frame_offset_passes(self): fout.write('major frame offsets = 0\n') fout.write('minor frame offsets = {0, 0}\n') fout.close() - img2 = spy.envi.open(fname) + spy.envi.open(fname) def test_save_nonzero_frame_offset_fails(self): '''Opening files with nonzero frame offsets should fail.''' img = spy.open_image('92AV3C.lan') fname = os.path.join(testdir, 'test_save_nonzero_frame_offset_fails.hdr') - meta = {'major frame offsets' : [128, 0]} + meta = {'major frame offsets': [128, 0]} try: spy.envi.save_image(fname, img, metadata=meta) except spy.envi.EnviFeatureNotSupported: @@ -188,7 +188,7 @@ def test_save_zero_frame_offset_passes(self): '''Opening files with nonzero frame offsets should fail.''' img = spy.open_image('92AV3C.lan') fname = os.path.join(testdir, 'test_save_zero_frame_offset_passes.hdr') - meta = {'major frame offsets' : 0} + meta = {'major frame offsets': 0} spy.envi.save_image(fname, img, metadata=meta) def test_catch_parse_error(self): @@ -200,7 +200,7 @@ def test_catch_parse_error(self): fout.write('foo = {{\n') fout.close() try: - img2 = spy.envi.open(fname) + spy.envi.open(fname) except spy.envi.EnviHeaderParsingError: pass else: @@ -211,14 +211,14 @@ def test_header_missing_mandatory_parameter_fails(self): img = spy.open_image('92AV3C.lan') fname = os.path.join(testdir, 'test_missing_param_fails.hdr') spy.envi.save_image(fname, img) - lines = [line for line in open(fname).readlines() \ + lines = [line for line in open(fname).readlines() if 'bands' not in line] fout = open(fname, 'w') for line in lines: fout.write(line) fout.close() try: - img2 = spy.envi.open(fname) + spy.envi.open(fname) except spy.envi.MissingEnviHeaderParameter: pass else: @@ -229,7 +229,7 @@ def test_param_name_converted_to_lower_case(self): header = 'mixed_case_header.hdr' open(header, 'w').write(MIXED_CASE_HEADER) h = spy.envi.read_envi_header(header) - assert('some param' in h) + assert ('some param' in h) def test_support_nonlowercase_params(self): '''By default, parameter names are converted to lower case.''' @@ -241,7 +241,7 @@ def test_support_nonlowercase_params(self): h = spy.envi.read_envi_header(header) finally: spy.settings.envi_support_nonlowercase_params = orig - assert('some Param' in h) + assert ('some Param' in h) def test_missing_ENVI_in_header_fails(self): '''FileNotAnEnviHeader should be raised if "ENVI" not on first line.''' @@ -254,7 +254,7 @@ def test_missing_ENVI_in_header_fails(self): fout.write(line) fout.close() try: - img2 = spy.envi.open(fname) + spy.envi.open(fname) except spy.envi.FileNotAnEnviHeader: pass else: @@ -267,7 +267,7 @@ def test_open_missing_data_raises_envidatafilenotfounderror(self): spy.envi.save_image(fname, img, ext='.img') os.unlink(os.path.splitext(fname)[0] + '.img') try: - img2 = spy.envi.open(fname) + spy.envi.open(fname) except spy.envi.EnviDataFileNotFoundError: pass else: @@ -282,7 +282,7 @@ def test_create_spectral_lib_with_header(self): basename = os.path.join(testdir, 'slib') slib.save(basename) slib = spy.envi.open(basename + '.hdr') - assert(slib.spectra.shape == (20, nbands)) + assert (slib.spectra.shape == (20, nbands)) def test_create_spectral_lib_without_header(self): '''Can create ENVI spectral library from numpy array without bands.''' @@ -292,7 +292,8 @@ def test_create_spectral_lib_without_header(self): basename = os.path.join(testdir, 'slib') slib.save(basename) slib = spy.envi.open(basename + '.hdr') - assert(slib.spectra.shape == (20, nbands)) + assert (slib.spectra.shape == (20, nbands)) + def run(): print('\n' + '-' * 72) @@ -301,6 +302,7 @@ def run(): write_test = ENVIWriteTest() write_test.run() + if __name__ == '__main__': from spectral.tests.run import parse_args, reset_stats, print_summary parse_args() diff --git a/spectral/tests/iterators.py b/spectral/tests/iterators.py index 1da443e..2912386 100644 --- a/spectral/tests/iterators.py +++ b/spectral/tests/iterators.py @@ -15,6 +15,7 @@ from spectral.algorithms.algorithms import iterator, iterator_ij from spectral.tests.spytest import SpyTest + class IteratorTest(SpyTest): '''Tests various math functions.''' @@ -78,7 +79,7 @@ def test_iterator_spyfile(self): pixels = data.reshape((-1, data.shape[-1])) sum = np.sum(pixels[classes == cls], 0) itsum = np.sum(np.array([x for x in iterator(self.image, self.gt, cls)]), - 0) + 0) assert_allclose(sum, itsum) def test_iterator_spyfile_nomemmap(self): @@ -100,6 +101,7 @@ def run(): test = IteratorTest() test.run() + if __name__ == '__main__': from spectral.tests.run import parse_args, reset_stats, print_summary parse_args() diff --git a/spectral/tests/memmap.py b/spectral/tests/memmap.py index ae9da53..ade4def 100644 --- a/spectral/tests/memmap.py +++ b/spectral/tests/memmap.py @@ -13,14 +13,13 @@ from __future__ import absolute_import, division, print_function, unicode_literals -import numpy as np from numpy.testing import assert_almost_equal import spectral as spy -from spectral.io.spyfile import find_file_path, FileNotFoundError from spectral.tests import testdir from spectral.tests.spytest import SpyTest + class SpyFileMemmapTest(SpyTest): '''Tests that SpyFile memmap interfaces read and write properly.''' def __init__(self, file, datum, value, src_inter): @@ -58,18 +57,17 @@ def setup(self): def create_test_image_file(self): import os - import spectral img = spy.open_image(self.file) fname = os.path.join(testdir, 'memmap_test_%s.hdr' % self.src_inter) spy.envi.save_image(fname, img, - dtype = img.dtype, - interleave = self.src_inter, + dtype=img.dtype, + interleave=self.src_inter, force=True) self.image = spy.open_image(fname) - + def test_spyfile_has_memmap(self): - assert(self.image.using_memmap == True) + assert (self.image.using_memmap is True) def test_bip_memmap_read(self): self.create_test_image_file() @@ -90,7 +88,6 @@ def test_bsq_memmap_read(self): assert_almost_equal(mm[k, i, j], self.value) def test_bip_memmap_write(self): - from spectral import open_image self.create_test_image_file() (i, j, k) = self.datum mm = self.image.open_memmap(interleave='bip', writable=True) @@ -99,7 +96,6 @@ def test_bip_memmap_write(self): assert_almost_equal(self.image.open_memmap()[i, j, k], 2 * self.value) def test_bil_memmap_write(self): - from spectral import open_image self.create_test_image_file() (i, j, k) = self.datum mm = self.image.open_memmap(interleave='bil', writable=True) @@ -108,7 +104,6 @@ def test_bil_memmap_write(self): assert_almost_equal(self.image.open_memmap()[i, j, k], 3 * self.value) def test_bsq_memmap_write(self): - from spectral import open_image self.create_test_image_file() (i, j, k) = self.datum mm = self.image.open_memmap(interleave='bsq', writable=True) @@ -144,8 +139,6 @@ def __init__(self, filename, datum, value): def run(self): import os - import itertools - import spectral print('\n' + '-' * 72) print('Running memmap tests.') print('-' * 72) @@ -163,6 +156,7 @@ def run(): suite = SpyFileMemmapTestSuite('92AV3C.lan', (30, 40, 50), 5420.0) suite.run() + if __name__ == '__main__': from spectral.tests.run import parse_args, reset_stats, print_summary parse_args() diff --git a/spectral/tests/run.py b/spectral/tests/run.py index 38f1448..f5b2fa6 100644 --- a/spectral/tests/run.py +++ b/spectral/tests/run.py @@ -12,6 +12,7 @@ import spectral.tests + def parse_args(): parser = OptionParser() parser.add_option('-c', '--continue', dest='continue_tests', @@ -21,20 +22,23 @@ def parse_args(): (options, args) = parser.parse_args() spectral.tests.abort_on_fail = not options.continue_tests + def reset_stats(): spectral.tests._num_tests_run = 0 spectral.tests._num_tests_failed = 0 + def print_summary(): if spectral.tests._num_tests_failed > 0: - msg = '%d of %d tests FAILED.' % (spectral.tests._num_tests_failed, - spectral.tests._num_tests_run) + msg = '%d of %d tests FAILED.' % (spectral.tests._num_tests_failed, + spectral.tests._num_tests_run) else: - msg = 'All %d tests PASSED!' % spectral.tests._num_tests_run + msg = 'All %d tests PASSED!' % spectral.tests._num_tests_run print('\n' + '-' * 72) print(msg) print('-' * 72) + if __name__ == '__main__': logging.getLogger('spectral').setLevel(logging.ERROR) parse_args() diff --git a/spectral/tests/spatial.py b/spectral/tests/spatial.py index 8e83baf..882c136 100644 --- a/spectral/tests/spatial.py +++ b/spectral/tests/spatial.py @@ -10,7 +10,6 @@ import numpy as np from numpy.testing import assert_allclose -from warnings import warn import spectral as spy from spectral.algorithms.spatial import (get_window_bounds, @@ -27,21 +26,21 @@ def setup(self): self.data = spy.open_image('92AV3C.lan').load() def test_get_window_bounds(self): - assert(get_window_bounds(90, 90, 3, 7, 30, 40) == (29, 32, 37, 44)) + assert (get_window_bounds(90, 90, 3, 7, 30, 40) == (29, 32, 37, 44)) def test_get_window_bounds_border(self): - assert(get_window_bounds(90, 90, 3, 7, 0, 2) == (0, 3, 0, 7)) + assert (get_window_bounds(90, 90, 3, 7, 0, 2) == (0, 3, 0, 7)) def test_get_window_bounds_clipped(self): - assert(get_window_bounds_clipped(90, 90, 3, 7, 30, 40) \ - == (29, 32, 37, 44)) + assert (get_window_bounds_clipped(90, 90, 3, 7, 30, 40) + == (29, 32, 37, 44)) def test_get_window_bounds_clipped_border(self): - assert(get_window_bounds_clipped(90, 90, 3, 7, 0, 2) == (0, 2, 0, 6)) + assert (get_window_bounds_clipped(90, 90, 3, 7, 0, 2) == (0, 2, 0, 6)) def test_map_window(self): '''Test computing spectra average over local window.''' - f = lambda X, ij: np.mean(X.reshape((-1, X.shape[-1])), axis=0) + def f(X, ij): return np.mean(X.reshape((-1, X.shape[-1])), axis=0) X = self.data y = map_window(f, X, (3, 5), (10, 50), (20, 40)) t = np.mean(X[9:12, 18:23].reshape((-1, X.shape[-1])), axis=0) @@ -49,7 +48,7 @@ def test_map_window(self): def test_map_window_clipped(self): '''Test spatial averaging near border with clipped window.''' - f = lambda X, ij: np.mean(X.reshape((-1, X.shape[-1])), axis=0) + def f(X, ij): return np.mean(X.reshape((-1, X.shape[-1])), axis=0) X = self.data y = map_window(f, X, (3, 5), (100, None), (100, None), border='clip') t = np.mean(X[-2:, -3:].reshape((-1, X.shape[-1])), axis=0) @@ -57,7 +56,7 @@ def test_map_window_clipped(self): def test_map_window_shifted(self): '''Test spatial averaging near border with shifted window.''' - f = lambda X, ij: np.mean(X.reshape((-1, X.shape[-1])), axis=0) + def f(X, ij): return np.mean(X.reshape((-1, X.shape[-1])), axis=0) X = self.data y = map_window(f, X, (3, 5), (100, None), (100, None), border='shift') t = np.mean(X[-3:, -5:].reshape((-1, X.shape[-1])), axis=0) @@ -65,12 +64,13 @@ def test_map_window_shifted(self): def test_map_window_stepped(self): '''Test spatial averaging with non-unity row/column step sizes.''' - f = lambda X, ij: np.mean(X.reshape((-1, X.shape[-1])), axis=0) + def f(X, ij): return np.mean(X.reshape((-1, X.shape[-1])), axis=0) X = self.data y = map_window(f, X, (3, 5), (30, 60, 3), (70, 100, 4), border='shift') t = np.mean(X[32:35, 72:77].reshape((-1, X.shape[-1])), axis=0) assert_allclose(y[1, 1], t) + class MapClassesTest(SpyTest): '''Test mapping of class indices between classification images.''' @@ -82,27 +82,27 @@ def test_map_class_ids_identity(self): gt = np.array(self.gt) d = map_class_ids(gt, gt) for i in set(gt.ravel()): - assert(i in d) + assert (i in d) for (i, j) in d.items(): - assert(j == i) + assert (j == i) def test_map_class_ids_identity_unlabeled(self): '''Mapping a class image back to itself with an unlabeled class.''' gt = np.array(self.gt) d = map_class_ids(gt, gt, unlabeled=0) for i in set(gt.ravel()): - assert(i in d) + assert (i in d) for (i, j) in d.items(): - assert(j == i) + assert (j == i) def test_map_class_ids_identity_multiple_unlabeled(self): '''Mapping a class image back to itself with unlabeled classes.''' gt = np.array(self.gt) d = map_class_ids(gt, gt, unlabeled=[2, 4]) for i in set(gt.ravel()): - assert(i in d) + assert (i in d) for (i, j) in d.items(): - assert(j == i) + assert (j == i) def test_map_class_ids_isomorphic(self): '''Test map_class_ids with isomorphic classes.''' @@ -110,49 +110,47 @@ def test_map_class_ids_isomorphic(self): gt2 = gt + 1 d = map_class_ids(gt, gt2) for (i, j) in d.items(): - assert(j == i + 1) + assert (j == i + 1) def test_map_class_ids_isomorphic_background(self): '''Test map_class_ids with isomorphic classes and background arg.''' gt = np.array(self.gt) gt2 = gt + 1 d = map_class_ids(gt, gt2, unlabeled=0) - assert(d[0] == 0) + assert (d[0] == 0) d.pop(0) for (i, j) in d.items(): - assert(j == i + 1) + assert (j == i + 1) def test_map_class_ids_src_gt_dest(self): '''Test map_class_ids with more classes in source image.''' gt = np.array(self.gt) (i, j) = (100, 30) - old_label = gt[i, j] new_label = max(set(gt.ravel())) + 10 gt2 = np.array(gt) gt2[i, j] = new_label - + d = map_class_ids(gt2, gt) # There are enough pixels for each class that a new single-pixel class # should not be mapped to one of the existing classes. - assert(d[new_label] not in gt) + assert (d[new_label] not in gt) d.pop(new_label) for (i, j) in d.items(): - assert(j == i) + assert (j == i) def test_map_class_ids_dest_gt_src(self): '''Test map_class_ids with more classes in dest image.''' gt = np.array(self.gt) (i, j) = (100, 30) - old_label = gt[i, j] new_label = max(set(gt.ravel())) + 10 gt2 = np.array(gt) gt2[i, j] = new_label - + d = map_class_ids(gt, gt2) for (i, j) in d.items(): - assert(j == i) + assert (j == i) def test_map_classes_isomorphic(self): '''map_classes should map isomorphic class image back to original.''' @@ -160,8 +158,8 @@ def test_map_classes_isomorphic(self): gt2 = gt + 1 d = map_class_ids(gt2, gt) result = map_classes(gt2, d) - assert(np.alltrue(result == gt)) - + assert (np.alltrue(result == gt)) + def test_map_fails_allow_unmapped_false(self): '''map_classes should raise ValueError if image has unmapped value.''' gt = np.array(self.gt) @@ -169,12 +167,12 @@ def test_map_fails_allow_unmapped_false(self): d = map_class_ids(gt2, gt) d.pop(1) try: - result = map_classes(gt2, d) + map_classes(gt2, d) except ValueError: pass else: - assert(False) - + assert (False) + def test_map_allow_unmapped_true(self): '''map_classes should raise ValueError if image has unmapped value.''' gt = np.array(self.gt) @@ -182,8 +180,9 @@ def test_map_allow_unmapped_true(self): d = map_class_ids(gt2, gt) d.pop(1) result = map_classes(gt2, d, allow_unmapped=True) - assert(np.alltrue(result[gt2 == 1] == 1)) - + assert (np.alltrue(result[gt2 == 1] == 1)) + + def run(): print('\n' + '-' * 72) print('Running spatial tests.') @@ -191,6 +190,7 @@ def run(): for T in [SpatialWindowTest, MapClassesTest]: T().run() + if __name__ == '__main__': from spectral.tests.run import parse_args, reset_stats, print_summary parse_args() diff --git a/spectral/tests/spyfile.py b/spectral/tests/spyfile.py index 54f7889..51a842c 100644 --- a/spectral/tests/spyfile.py +++ b/spectral/tests/spyfile.py @@ -28,9 +28,11 @@ assert_almost_equal = np.testing.assert_allclose -def assert_allclose (a, b, **kwargs): + +def assert_allclose(a, b, **kwargs): np.testing.assert_allclose(np.array(a), np.array(b), **kwargs) + class SpyFileTest(SpyTest): '''Tests that SpyFile methods read data correctly from files.''' def __init__(self, file, datum, value): @@ -106,7 +108,7 @@ def test_read_bands_nonascending(self): def test_read_bands_duplicates(self): (i, j, k) = self.datum - bands = (k - 5, k - 5, k, k -5) + bands = (k - 5, k - 5, k, k - 5) assert_almost_equal(self.image.read_bands(bands, use_memmap=True)[i, j, 2], self.value) @@ -177,7 +179,7 @@ def test_load(self): spyf.read_subimage([0, 2, 4], [6, 3])) load_assert(data.read_subimage([0, 2], [6, 3], [0, 1]), spyf.read_subimage([0, 2], [6, 3], [0, 1])) - load_assert(data.read_datum(1,2,8), spyf.read_datum(1,2,8)) + load_assert(data.read_datum(1, 2, 8), spyf.read_datum(1, 2, 8)) ufunc_result = data + 1 assert isinstance(ufunc_result, np.ndarray) @@ -202,6 +204,7 @@ def test_getitem_islice_jslice(self): (i, j, k) = self.datum assert_almost_equal(self.image[i-3:i+3, j-3:j+3][3, 3, k], self.value) + def assert_same_shape_almost_equal(obj1, obj2, decimal=7, err_msg='', verbose=True): """ @@ -220,6 +223,7 @@ def assert_same_shape_almost_equal(obj1, obj2, decimal=7, err_msg='', assert_almost_equal(obj1, obj2, decimal=decimal, err_msg=err_msg, verbose=verbose) + class SpyFileTestSuite(object): '''Tests reading by byte orders, data types, and interleaves. For a specified image file name, the test suite will verify proper reading of @@ -300,6 +304,7 @@ def run(self): test = SpyFileTest(testimg, self.datum, self.value) test.run() + def create_complex_test_files(dtypes): '''Create test files with complex data''' if not os.path.isdir(testdir): @@ -315,11 +320,12 @@ def create_complex_test_files(dtypes): tests.append((fname, datum, X[datum])) return tests + def run(): tests = [('92AV3C.lan', (99, 99, 99), 2057.0)] for (fname, datum, value) in tests: try: - check = find_file_path(fname) + find_file_path(fname) suite = SpyFileTestSuite(fname, datum, value, dtypes=('i2', 'i4', 'f4', 'f8', 'c8', 'c16')) suite.run() @@ -342,12 +348,13 @@ def run(): tests = create_complex_test_files(dtypes) for (dtype, (fname, datum, value)) in zip(dtypes, tests): try: - check = find_file_path(fname) + find_file_path(fname) suite = SpyFileTestSuite(fname, datum, value, dtypes=(dtype,)) suite.run() except FileNotFoundError: print('File "%s" not found. Skipping.' % fname) + if __name__ == '__main__': run() diff --git a/spectral/tests/spymath.py b/spectral/tests/spymath.py index f5fb8f7..3be5618 100644 --- a/spectral/tests/spymath.py +++ b/spectral/tests/spymath.py @@ -22,7 +22,7 @@ class SpyMathTest(SpyTest): def setup(self): self.data = spy.open_image('92AV3C.lan').open_memmap() self.C = spy.calc_stats(self.data).cov - self.X = np.array([[2., 1.],[1., 2.]]) + self.X = np.array([[2., 1.], [1., 2.]]) def test_matrix_sqrt(self): S = matrix_sqrt(self.X) @@ -73,12 +73,12 @@ def setup(self): def test_evals_sorted(self): '''Eigenvalues should be sorted in descending order.''' - assert(np.alltrue(np.diff(self.pc.eigenvalues) <= 0)) + assert (np.alltrue(np.diff(self.pc.eigenvalues) <= 0)) def test_evecs_orthonormal(self): '''Eigenvectors should be orthonormal.''' evecs = self.pc.eigenvectors - assert(np.allclose(evecs.T.dot(evecs), np.eye(evecs.shape[0]))) + assert (np.allclose(evecs.T.dot(evecs), np.eye(evecs.shape[0]))) class LDATest(SpyTest): @@ -92,7 +92,6 @@ def test_lda_covw_whitened(self): '''cov_w should be whitened in the transformed space.''' classes = spy.create_training_classes(self.data, self.classes) fld = spy.linear_discriminant(classes) - xdata = fld.transform(self.data) classes.transform(fld.transform) fld2 = spy.linear_discriminant(classes) assert_allclose(np.eye(fld2.cov_w.shape[0]), fld2.cov_w, atol=1e-8) @@ -105,6 +104,7 @@ def run(): for T in [SpyMathTest, PCATest, LDATest]: T().run() + if __name__ == '__main__': from spectral.tests.run import parse_args, reset_stats, print_summary parse_args() diff --git a/spectral/tests/spytest.py b/spectral/tests/spytest.py index 5ae61c6..33ffdcc 100644 --- a/spectral/tests/spytest.py +++ b/spectral/tests/spytest.py @@ -10,6 +10,7 @@ from collections import Callable import sys + class SpyTest(object): '''Base class for test cases. @@ -32,14 +33,17 @@ def run(self): is called. ''' import spectral.tests as tests - from spectral.tests import abort_on_fail self.setup() + class NullStdOut(object): + def write(*args, **kwargs): pass + def flush(self): pass + null = NullStdOut() methods = [getattr(self, s) for s in sorted(dir(self)) if s.startswith('test_')] methods = [m for m in methods if isinstance(m, Callable)] @@ -61,8 +65,10 @@ def flush(self): sys.stdout = stdout self.finish() + # The following test method is now deprecated and should no longer be used. + def test_method(method): '''Decorator function for unit tests.''' import spectral.tests as tests diff --git a/spectral/tests/transforms.py b/spectral/tests/transforms.py index 45a5c69..71dc1db 100644 --- a/spectral/tests/transforms.py +++ b/spectral/tests/transforms.py @@ -145,6 +145,7 @@ def run(): test = LinearTransformTest(image, datum, value / 10000.0) test.run() + if __name__ == '__main__': from spectral.tests.run import parse_args, reset_stats, print_summary parse_args() diff --git a/spectral/utilities/__init__.py b/spectral/utilities/__init__.py index 9c59938..fe56104 100644 --- a/spectral/utilities/__init__.py +++ b/spectral/utilities/__init__.py @@ -1,2 +1 @@ from __future__ import absolute_import, division, print_function, unicode_literals - diff --git a/spectral/utilities/errors.py b/spectral/utilities/errors.py index 2835762..f912ef0 100644 --- a/spectral/utilities/errors.py +++ b/spectral/utilities/errors.py @@ -2,16 +2,20 @@ import numpy as np + class SpyException(Exception): '''Base class for spectral module-specific exceptions.''' pass + class NaNValueWarning(UserWarning): pass + class NaNValueError(ValueError): pass + def has_nan(X): '''returns True if ndarray `X` contains a NaN value.''' return bool(np.isnan(np.min(X))) diff --git a/spectral/utilities/python23.py b/spectral/utilities/python23.py index 5a28ca0..bb56fb9 100644 --- a/spectral/utilities/python23.py +++ b/spectral/utilities/python23.py @@ -8,6 +8,7 @@ IS_PYTHON3 = sys.version_info >= (3,) + def typecode(t): '''Typecode handling for array module. @@ -28,17 +29,19 @@ def typecode(t): else: return chr(ord(t)) + if IS_PYTHON3: def is_string(s): return isinstance(s, (str, bytes)) else: def is_string(s): - return isinstance(s, basestring) + return isinstance(s, basestring) # noqa: F821 + # array.tostring is deprecated in python3 if IS_PYTHON3: - tobytes = lambda array: array.tobytes() - frombytes = lambda array, src: array.frombytes(src) + def tobytes(array): return array.tobytes() + def frombytes(array, src): return array.frombytes(src) else: - tobytes = lambda array: array.tostring() - frombytes = lambda array, src: array.fromstring(src) + def tobytes(array): return array.tostring() + def frombytes(array, src): return array.fromstring(src) diff --git a/spectral/utilities/status.py b/spectral/utilities/status.py index e2ad749..941a0f5 100644 --- a/spectral/utilities/status.py +++ b/spectral/utilities/status.py @@ -7,6 +7,7 @@ import sys from .. import settings + class StatusDisplay: ''' A class to sequentially display percentage completion of an iterative