diff --git a/README.md b/README.md
index 9047b49..a3c5f75 100644
--- a/README.md
+++ b/README.md
@@ -8,8 +8,11 @@ GPU-accelerated stain normalization tools for histopathological images. Compatible with PyTorch, TensorFlow, and Numpy.
 
 Normalization algorithms currently implemented:
 
-- Macenko et al. [\[1\]](#reference) (ported from [numpy implementation](https://github.com/schaugf/HEnorm_python))
-- Reinhard et al. [\[2\]](#reference) (only numpy & TensorFlow backend support)
+| Algorithm | numpy | torch | tensorflow |
+|-|-|-|-|
+| Macenko [\[1\]](#references) | ✓ | ✓ | ✓ |
+| Reinhard [\[2\]](#references) | ✓ | ✗ | ✓ |
+| Vahadane [\[3\]](#references) | ✓ | ✗ | ✗ |
 
 ## Installation
 
@@ -44,16 +47,9 @@ norm, H, E = normalizer.normalize(I=t_to_transform, stains=True)
 
 ![alt text](data/result.png)
 
-## Implemented algorithms
-
-| Algorithm | numpy | torch | tensorflow |
-|-|-|-|-|
-| Macenko | ✓ | ✓ | ✓ |
-| Reinhard | ✓ | ✗ | ✓ |
-
 ## Backend comparison
 
-Results with 10 runs per size on a Intel(R) Core(TM) i5-8365U CPU @ 1.60GHz
+Runtime results for the Macenko algorithm, averaged over 10 runs per image size, on an Intel(R) Core(TM) i5-8365U CPU @ 1.60GHz.
 
 | size   | numpy avg. time   | torch avg. time   | tf avg. time     |
 |--------|-------------------|-------------------|------------------|
@@ -66,10 +62,11 @@ Results with 10 runs per size on a Intel(R) Core(TM) i5-8365U CPU @ 1.60GHz
 | 1568   | 1.1935s ± 0.0739  | 0.2590s ± 0.0088  | 0.2531s ± 0.0031 |
 | 1792   | 1.4523s ± 0.0207  | 0.3402s ± 0.0114  | 0.3080s ± 0.0188 |
 
-## Reference
+## References
 
-- [1] Macenko, Marc, et al. "A method for normalizing histology slides for quantitative analysis." 2009 IEEE International Symposium on Biomedical Imaging: From Nano to Macro. IEEE, 2009.
-- [2] Reinhard, Erik, et al. "Color transfer between images." IEEE Computer Graphics and Applications. IEEE, 2001.
+- [1] Macenko, Marc et al. "A method for normalizing histology slides for quantitative analysis." 2009 IEEE International Symposium on Biomedical Imaging: From Nano to Macro. IEEE, 2009.
+- [2] Reinhard, Erik et al. "Color transfer between images." IEEE Computer Graphics and Applications. IEEE, 2001.
+- [3] Vahadane, Abhishek et al. "Structure-preserved color normalization for histopathological images." 2015 IEEE 12th International Symposium on Biomedical Imaging (ISBI). IEEE, 2015.
 
 ## Citing
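
The README changes above document the new Vahadane backend but keep only the existing Macenko usage example. A minimal usage sketch for the numpy-only Vahadane path, assuming `VahadaneNormalizer` is re-exported through `torchstain.normalizers` like the other factories; the image paths and variable names are illustrative:

```python
import cv2
import torchstain

# illustrative file paths
target = cv2.cvtColor(cv2.imread("./data/target.png"), cv2.COLOR_BGR2RGB)
to_transform = cv2.cvtColor(cv2.imread("./data/source.png"), cv2.COLOR_BGR2RGB)

# only backend='numpy' is implemented in this change; 'torch' and 'tensorflow' raise NotImplementedError
normalizer = torchstain.normalizers.VahadaneNormalizer(backend='numpy')
normalizer.fit(target)

# unlike MacenkoNormalizer, this normalize() returns only the normalized uint8 RGB image,
# without separate H and E stain images
norm = normalizer.normalize(to_transform)
```
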
diff --git a/torchstain/base/normalizers/__init__.py b/torchstain/base/normalizers/__init__.py
index a8f8bb7..82f59f1 100644
--- a/torchstain/base/normalizers/__init__.py
+++ b/torchstain/base/normalizers/__init__.py
@@ -1,3 +1,4 @@
 from .he_normalizer import HENormalizer
 from .macenko import MacenkoNormalizer
 from .reinhard import ReinhardNormalizer
+from .vahadane import VahadaneNormalizer
diff --git a/torchstain/base/normalizers/vahadane.py b/torchstain/base/normalizers/vahadane.py
new file mode 100644
index 0000000..4defc66
--- /dev/null
+++ b/torchstain/base/normalizers/vahadane.py
@@ -0,0 +1,10 @@
+def VahadaneNormalizer(backend='numpy'):
+    if backend == 'numpy':
+        from torchstain.numpy.normalizers import NumpyVahadaneNormalizer
+        return NumpyVahadaneNormalizer()
+    elif backend == "torch":
+        raise NotImplementedError
+    elif backend == "tensorflow":
+        raise NotImplementedError
+    else:
+        raise Exception(f'Unknown backend {backend}')
diff --git a/torchstain/numpy/normalizers/__init__.py b/torchstain/numpy/normalizers/__init__.py
index d453cf1..9c9e61c 100644
--- a/torchstain/numpy/normalizers/__init__.py
+++ b/torchstain/numpy/normalizers/__init__.py
@@ -1,2 +1,3 @@
 from .macenko import NumpyMacenkoNormalizer
-from .reinhard import NumpyReinhardNormalizer
\ No newline at end of file
+from .reinhard import NumpyReinhardNormalizer
+from .vahadane import NumpyVahadaneNormalizer
diff --git a/torchstain/numpy/normalizers/vahadane.py b/torchstain/numpy/normalizers/vahadane.py
new file mode 100644
index 0000000..0813674
--- /dev/null
+++ b/torchstain/numpy/normalizers/vahadane.py
@@ -0,0 +1,33 @@
+import numpy as np
+from torchstain.base.normalizers import HENormalizer
+from torchstain.numpy.utils.lasso import lasso
+from torchstain.numpy.utils.stats import standardize_brightness
+from torchstain.numpy.utils.extract import get_stain_matrix, get_concentrations
+from torchstain.numpy.utils.od2rgb import od2rgb
+
+"""
+Source code adapted from:
+https://github.com/wanghao14/Stain_Normalization/blob/master/stainNorm_Vahadane.py
+https://github.com/Peter554/StainTools/blob/master/staintools/stain_normalizer.py
+"""
+class NumpyVahadaneNormalizer(HENormalizer):
+    def __init__(self):
+        super().__init__()
+        self.stain_matrix_target = None
+        self.maxC_target = None
+
+    def fit(self, target):
+        # target = target.astype("float32")
+        self.stain_matrix_target = get_stain_matrix(target)
+        concentration_target = get_concentrations(target, self.stain_matrix_target)
+        self.maxC_target = np.percentile(concentration_target, 99, axis=0).reshape((1, 2))
+
+    def normalize(self, I):
+        # I = I.astype("float32")
+        # I = standardize_brightness(I)
+        stain_matrix = get_stain_matrix(I)
+        concentrations = get_concentrations(I, stain_matrix)
+        maxC = np.percentile(concentrations, 99, axis=0).reshape((1, 2))
+        concentrations *= (self.maxC_target / maxC)
+        out = 255 * np.exp(-1 * np.dot(concentrations, self.stain_matrix_target))
+        return out.reshape(I.shape).astype("uint8")
diff --git a/torchstain/numpy/utils/__init__.py b/torchstain/numpy/utils/__init__.py
index f440077..56f2aea 100644
--- a/torchstain/numpy/utils/__init__.py
+++ b/torchstain/numpy/utils/__init__.py
@@ -2,3 +2,6 @@
 from torchstain.numpy.utils.lab2rgb import *
 from torchstain.numpy.utils.split import *
 from torchstain.numpy.utils.stats import *
+from torchstain.numpy.utils.lasso import *
+from torchstain.numpy.utils.od2rgb import *
+from torchstain.numpy.utils.rgb2od import *
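
The `fit`/`normalize` pair above relies on the Beer-Lambert colour model also used by the utilities that follow: RGB is mapped to optical density with `OD = -log(I / 255)`, OD is factorized as `OD ≈ C @ S` with per-pixel concentrations `C` of shape `(N, 2)` and a stain matrix `S` of shape `(2, 3)`, the concentrations are rescaled to the target's 99th percentile, and the image is rebuilt with `255 * exp(-C @ S_target)`. A self-contained numpy sketch of that round trip on synthetic data (no torchstain imports, all names illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

# synthetic stain matrix (2 stains x 3 channels, unit-norm rows) and concentrations (N x 2)
S = rng.random((2, 3))
S /= np.linalg.norm(S, axis=1, keepdims=True)
C = rng.random((100, 2))

# forward Beer-Lambert model: optical density -> RGB intensities
OD = C @ S            # shape (100, 3)
I = 255 * np.exp(-OD)

# inverting the colour model recovers the optical density, which is what
# normalize() exploits when it rebuilds the image from rescaled concentrations
OD_back = -np.log(I / 255)
assert np.allclose(OD, OD_back)
```
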
diff --git a/torchstain/numpy/utils/extract.py b/torchstain/numpy/utils/extract.py
new file mode 100644
index 0000000..794c562
--- /dev/null
+++ b/torchstain/numpy/utils/extract.py
@@ -0,0 +1,51 @@
+import numpy as np
+from torchstain.numpy.utils.rgb2od import rgb2od
+from torchstain.numpy.utils.rgb2lab import rgb2lab
+from torchstain.numpy.utils.lasso import lasso
+import spams
+from sklearn.linear_model import LassoLars
+from sklearn.decomposition import DictionaryLearning
+
+def extract_tissue(I, th=0.8):
+    # tissue mask: keep pixels whose luminosity is below the threshold (background glass is bright)
+    LAB = rgb2lab(I / 255)
+    L = LAB[:, :, 0] / 255.0
+    return L < th
+
+def get_stain_matrix(I, th=0.8, alpha=0.1):
+    # convert RGB -> OD and flatten channel-wise
+    OD = rgb2od(I).reshape((-1, 3))
+
+    # detect glass and remove it from OD image
+    mask = extract_tissue(I, th).reshape((-1,))
+    OD = OD[mask]
+
+    # perform dictionary learning @TODO: Implement DL train method
+    #model = DictionaryLearning(
+    #    fit_algorithm="lars",
+    #    transform_algorithm="lasso_lars", n_components=2, transform_n_nonzero_coefs=0,
+    #    transform_alpha=alpha, positive_dict=True, verbose=False, split_sign=True,
+    #    # positive_code=True,
+    #)
+    #dictionary = model.fit_transform(OD)
+    dictionary = spams.trainDL(OD.T, K=2, lambda1=alpha, mode=2, modeD=0,
+                               posAlpha=True, posD=True, verbose=False)
+    dictionary = dictionary.T
+
+    # order rows so that hematoxylin (larger red-channel OD) comes first
+    if dictionary[0, 0] < dictionary[1, 0]:
+        dictionary = dictionary[[1, 0], :]
+
+    # normalize rows and return result
+    return dictionary / np.linalg.norm(dictionary, axis=1)[:, None]
+
+def get_concentrations(I, stain_matrix, alpha=0.01):
+    # convert RGB -> OD and flatten channel-wise
+    OD = rgb2od(I).reshape((-1, 3))
+
+    # perform LASSO regression
+    #model = LassoLars(alpha=alpha, positive=True, fit_intercept=False)
+    #model.fit(X=OD.T, y=stain_matrix.T)
+    #return model.predict(OD.T).T
+    return spams.lasso(OD.T, D=stain_matrix.T, mode=2, lambda1=alpha, pos=True).toarray().T
+    #return lasso(OD.T, y=stain_matrix.T).T # @TODO: Implement LARS-LASSO
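
`get_stain_matrix` uses `spams.trainDL` to learn a 2-atom non-negative dictionary over the tissue OD pixels, and `get_concentrations` uses `spams.lasso` to sparsely code every pixel against that dictionary, which together give the `OD ≈ C @ S` factorization at the heart of Vahadane's method. A hedged sketch of how these two helpers might be smoke-tested on random data, assuming the `spams` dependency is installed; only the two function names come from this change, the rest is illustrative:

```python
import numpy as np
from torchstain.numpy.utils.extract import get_stain_matrix, get_concentrations

# random RGB tile standing in for an H&E patch
rng = np.random.default_rng(42)
I = rng.integers(30, 220, size=(64, 64, 3), dtype=np.uint8)

S = get_stain_matrix(I)        # (2, 3): one unit-norm OD vector per stain, hematoxylin first
C = get_concentrations(I, S)   # (64*64, 2): per-pixel stain concentrations

assert S.shape == (2, 3)
assert np.allclose(np.linalg.norm(S, axis=1), 1.0)
assert C.shape == (I.shape[0] * I.shape[1], 2)
assert (C >= 0).all()          # spams is called with posD/posAlpha/pos=True
```
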
diff --git a/torchstain/numpy/utils/lasso.py b/torchstain/numpy/utils/lasso.py
new file mode 100644
index 0000000..ad35c89
--- /dev/null
+++ b/torchstain/numpy/utils/lasso.py
@@ -0,0 +1,52 @@
+import numpy as np
+
+
+"""
+LASSO implementation was adapted from:
+https://www.kaggle.com/code/mcweng24/lasso-regression-using-numpy
+"""
+def predicted_values(X, w):
+    return np.matmul(X, w)
+
+
+def rho_compute(y, X, w, j):
+    X_k = np.delete(X, j, 1)
+    w_k = np.delete(w, j)
+    predict_k = predicted_values(X_k, w_k)
+    residual = y - predict_k
+    return np.sum(X[:, j] * residual)
+
+
+def z_compute(X):
+    # column-wise sum of squares, used as the per-coordinate normalizer
+    return np.sum(X * X, axis=0)
+
+
+def coordinate_descent(y, X, w, alpha, z, tol):
+    max_step = 100
+    iteration = 0
+    while max_step > tol:
+        iteration += 1
+        old_weights = np.copy(w)
+        for j in range(len(w)):
+            rho_j = rho_compute(y, X, w, j)
+            if j == 0:
+                # the first coefficient is treated as an unpenalized intercept
+                w[j] = rho_j / z[j]
+            elif rho_j < -alpha * len(y):
+                w[j] = (rho_j + (alpha * len(y))) / z[j]
+            elif rho_j > alpha * len(y):
+                w[j] = (rho_j - (alpha * len(y))) / z[j]
+            else:
+                # soft-thresholding region: |rho_j| <= alpha * n zeroes the coefficient
+                w[j] = 0.
+        step_sizes = np.abs(old_weights - w)
+        max_step = step_sizes.max()
+    return w, iteration, max_step
+
+
+def lasso(x, y, alpha=0.1, tol=0.0001):
+    w = np.zeros(x.shape[1], dtype="float32")
+    z = z_compute(x)
+    w_opt, iterations, max_step = coordinate_descent(y, x, w, alpha, z, tol)
+    return w_opt
diff --git a/torchstain/numpy/utils/od2rgb.py b/torchstain/numpy/utils/od2rgb.py
new file mode 100644
index 0000000..904a025
--- /dev/null
+++ b/torchstain/numpy/utils/od2rgb.py
@@ -0,0 +1,6 @@
+import numpy as np
+
+# https://github.com/Peter554/StainTools/blob/2089900d11173ee5ea7de95d34532932afd3181a/staintools/utils/optical_density_conversion.py#L18
+def od2rgb(OD):
+    OD = np.maximum(OD, 1e-6)
+    return (255 * np.exp(-1 * OD)).astype(np.uint8)
diff --git a/torchstain/numpy/utils/rgb2od.py b/torchstain/numpy/utils/rgb2od.py
new file mode 100644
index 0000000..bc34847
--- /dev/null
+++ b/torchstain/numpy/utils/rgb2od.py
@@ -0,0 +1,10 @@
+import numpy as np
+
+# https://github.com/Peter554/StainTools/blob/2089900d11173ee5ea7de95d34532932afd3181a/staintools/utils/optical_density_conversion.py#L4
+def rgb2od(I):
+    # remove zeros on a copy, so the caller's image is not modified in-place
+    I = np.copy(I)
+    I[I == 0] = 1
+
+    # convert to OD and return
+    return np.maximum(-1 * np.log(I / 255), 1e-6)
diff --git a/torchstain/numpy/utils/stats.py b/torchstain/numpy/utils/stats.py
index b43ebe1..fcf24ed 100644
--- a/torchstain/numpy/utils/stats.py
+++ b/torchstain/numpy/utils/stats.py
@@ -5,3 +5,7 @@ def get_mean_std(I):
 
 def standardize(x, mu, std):
     return (x - mu) / std
+
+def standardize_brightness(x, alpha=99):
+    p = np.percentile(x, alpha)
+    return np.clip(x * 255 / p, 0, 255).astype("uint8")
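
The coordinate-descent `lasso` helper is not wired into `get_concentrations` yet (the `@TODO` keeps `spams.lasso` in place), but it can be exercised on its own. A small sketch on synthetic data, assuming the convention inherited from the adapted Kaggle implementation that the first column of the design matrix is an all-ones intercept column, which `coordinate_descent` leaves unpenalized:

```python
import numpy as np
from torchstain.numpy.utils.lasso import lasso

rng = np.random.default_rng(0)

# synthetic sparse regression problem: y = X @ w_true + noise
n, d = 200, 5
X = np.hstack([np.ones((n, 1)), rng.normal(size=(n, d - 1))])  # column 0 acts as the intercept
w_true = np.array([0.5, 2.0, 0.0, -1.5, 0.0])
y = X @ w_true + 0.01 * rng.normal(size=n)

w_hat = lasso(X, y, alpha=0.01, tol=1e-6)
print(np.round(w_hat, 2))  # close to w_true, with the zero coefficients shrunk exactly to 0
```
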