Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vahadane numpy backend #28

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
23 changes: 10 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
GPU-accelerated stain normalization tools for histopathological images. Compatible with PyTorch, TensorFlow, and Numpy.
Normalization algorithms currently implemented:

- Macenko et al. [\[1\]](#reference) (ported from [numpy implementation](https://github.com/schaugf/HEnorm_python))
- Reinhard et al. [\[2\]](#reference) (only numpy & TensorFlow backend support)
| Algorithm | numpy | torch | tensorflow |
|-|-|-|-|
| Macenko [\[1\]](#references) | ✓ | ✓ | ✓ |
| Reinhard [\[2\]](#references)| ✓ | ✗ | ✓ |
| Vahadane [\[3\]](#references) | ✓ | ✗ | ✗ |

## Installation

Expand Down Expand Up @@ -44,16 +47,9 @@ norm, H, E = normalizer.normalize(I=t_to_transform, stains=True)

![alt text](data/result.png)

## Implemented algorithms

| Algorithm | numpy | torch | tensorflow |
|-|-|-|-|
| Macenko | ✓ | ✓ | ✓ |
| Reinhard | ✓ | ✗ | ✓ |

## Backend comparison

Results with 10 runs per size on a Intel(R) Core(TM) i5-8365U CPU @ 1.60GHz
Runtime results using Macenko with 10 runs per size on a Intel(R) Core(TM) i5-8365U CPU @ 1.60GHz.

| size | numpy avg. time | torch avg. time | tf avg. time |
|--------|-------------------|-------------------|------------------|
Expand All @@ -66,10 +62,11 @@ Results with 10 runs per size on a Intel(R) Core(TM) i5-8365U CPU @ 1.60GHz
| 1568 | 1.1935s ± 0.0739 | 0.2590s ± 0.0088 | 0.2531s ± 0.0031 |
| 1792 | 1.4523s ± 0.0207 | 0.3402s ± 0.0114 | 0.3080s ± 0.0188 |

## Reference
## References

- [1] Macenko, Marc, et al. "A method for normalizing histology slides for quantitative analysis." 2009 IEEE International Symposium on Biomedical Imaging: From Nano to Macro. IEEE, 2009.
- [2] Reinhard, Erik, et al. "Color transfer between images." IEEE Computer Graphics and Applications. IEEE, 2001.
- [1] Macenko, Marc et al. "A method for normalizing histology slides for quantitative analysis." 2009 IEEE International Symposium on Biomedical Imaging: From Nano to Macro. IEEE, 2009.
- [2] Reinhard, Erik et al. "Color transfer between images." IEEE Computer Graphics and Applications. 2015 IEEE 12th International Symposium on Biomedical Imaging (ISBI), IEEE, 2001.
- [3] Vahadane, Abhishek et al. "Structure-preserved color normalization for histopathological images". IEEE, 2015

## Citing

Expand Down
1 change: 1 addition & 0 deletions torchstain/base/normalizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .he_normalizer import HENormalizer
from .macenko import MacenkoNormalizer
from .reinhard import ReinhardNormalizer
from .vahadane import VahadaneNormalizer
10 changes: 10 additions & 0 deletions torchstain/base/normalizers/vahadane.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
def VahadaneNormalizer(backend='numpy'):
if backend == 'numpy':
from torchstain.numpy.normalizers import NumpyVahadaneNormalizer
return NumpyVahadaneNormalizer()
elif backend == "torch":
raise NotImplementedError
elif backend == "tensorflow":
raise NotImplementedError
else:
raise Exception(f'Unknown backend {backend}')
3 changes: 2 additions & 1 deletion torchstain/numpy/normalizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .macenko import NumpyMacenkoNormalizer
from .reinhard import NumpyReinhardNormalizer
from .reinhard import NumpyReinhardNormalizer
from .vahadane import NumpyVahadaneNormalizer
33 changes: 33 additions & 0 deletions torchstain/numpy/normalizers/vahadane.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import numpy as np
from torchstain.base.normalizers import HENormalizer
from torchstain.numpy.utils.lasso import lasso
from torchstain.numpy.utils.stats import standardize_brightness
from torchstain.numpy.utils.extract import get_stain_matrix, get_concentrations
from torchstain.numpy.utils.od2rgb import od2rgb

"""
Source code adapted from:
https://github.com/wanghao14/Stain_Normalization/blob/master/stainNorm_Vahadane.py
https://github.com/Peter554/StainTools/blob/master/staintools/stain_normalizer.py
"""
class NumpyVahadaneNormalizer(HENormalizer):
def __init__(self):
super().__init__()
self.stain_matrix_target = None
self.maxC_target = None

def fit(self, target):
# target = target.astype("float32")
self.stain_matrix_target = get_stain_matrix(target)
concentration_target = get_concentrations(target, self.stain_matrix_target)
self.maxC_target = np.percentile(concentration_target, 99, axis=0).reshape((1, 2))

def normalize(self, I):
# I = I.astype("float32")
# I = standardize_brightness(I)
stain_matrix = get_stain_matrix(I)
concentrations = get_concentrations(I, stain_matrix)
maxC = np.percentile(concentrations, 99, axis=0).reshape((1, 2))
concentrations *= (self.maxC_target / maxC)
out = 255 * np.exp(-1 * np.dot(concentrations, self.stain_matrix_target))
return out.reshape(I.shape).astype("uint8")
3 changes: 3 additions & 0 deletions torchstain/numpy/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@
from torchstain.numpy.utils.lab2rgb import *
from torchstain.numpy.utils.split import *
from torchstain.numpy.utils.stats import *
from torchstain.numpy.utils.lasso import *
from torchstain.numpy.utils.od2rgb import *
from torchstain.numpy.utils.rgb2od import *
53 changes: 53 additions & 0 deletions torchstain/numpy/utils/extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import numpy as np
from torchstain.numpy.utils.rgb2od import rgb2od
from torchstain.numpy.utils.rgb2lab import rgb2lab
from torchstain.numpy.utils.lasso import lasso
import spams
from sklearn.linear_model import LassoLars
from sklearn.decomposition import DictionaryLearning

def extract_tissue(I, th=0.8):
LAB = rgb2lab(I / 255)
L = LAB[:, :, 0] / 255.0
return L < th

def get_stain_matrix(I, th=0.8, alpha=0.1):
# convert RGB -> OD and flatten channel-wise
OD = rgb2od(I).reshape((-1, 3))

# detect glass and remove it from OD image
mask = extract_tissue(I, th).reshape((-1,))
OD = OD[mask]

# perform dictionary learning @TODO: Implement DL train method
#model = DictionaryLearning(
# fit_algorithm="lars",
# transform_algorithm="lasso_lars", n_components=2, transform_n_nonzero_coefs=0,
# transform_alpha=alpha, positive_dict=True, verbose=False, split_sign=True,
# # positive_code=True,
#)
#dictionary1 = model.fit_transform(OD)
#print(dictionary1)
dictionary = spams.trainDL(OD.T, K=2, lambda1=alpha, mode=2, modeD=0,
posAlpha=True, posD=True, verbose=False)

#print(dictionary)
#exit()
dictionary = dictionary.T
if dictionary[0, 0] < dictionary[1, 0]:
dictionary = dictionary[[1, 0], :]

# normalize rows and return result
return dictionary / np.linalg.norm(dictionary, axis=1)[:, None]

def get_concentrations(I, stain_matrix, alpha=0.01):
# convert RGB -> OD and flatten channel-wise
OD = rgb2od(I).reshape((-1, 3))

# perform LASSO regression
#model = LassoLars(alpha=alpha, positive=True, fit_intercept=False)
#model.fit(X=OD.T, y=stain_matrix.T)
#print(OD.T)
#return model.predict(OD.T).T
return spams.lasso(OD.T, D=stain_matrix.T, mode=2, lambda1=alpha, pos=True).toarray().T
#return lasso(OD.T, y=stain_matrix.T).T # @TODO: Implement LARS-LASSO
52 changes: 52 additions & 0 deletions torchstain/numpy/utils/lasso.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import numpy as np


"""
LASSO implementation was adapted from:
https://www.kaggle.com/code/mcweng24/lasso-regression-using-numpy
"""
def predicted_values(X, w):
return np.matmul(X, w)


def rho_compute(y, X, w, j):
X_k = np.delete(X, j, 1)
w_k = np.delete(w, j)
predict_k = predicted_values(X_k, w_k)
residual = y - predict_k
return np.sum(X[:, j] * residual)


def z_compute(X):
z_vector = np.sum(X * X, axis=0)
return np.sum(X * X, axis = 0)


def coordinate_descent(y, X, w, alpha, z, tol):
max_step = 100
iteration = 0
while max_step > tol:
iteration += 1
old_weights = np.copy(w)
for j in range(len(w)):
rho_j = rho_compute(y, X, w, j)
if j == 0:
w[j] = rho_j / z[j]
elif rho_j < -alpha * len(y):
w[j] = (rho_j + (alpha * len(y))) / z[j]
elif rho_j > -alpha * len(y) and rho_j < alpha * len(y):
w[j] = 0.
elif rho_j > alpha * len(y):
w[j] = (rho_j - (alpha * len(y))) / z[j]
else:
w[j] = np.NaN
step_sizes = np.abs(old_weights - w)
max_step = step_sizes.max()
return w, iteration, max_step


def lasso(x, y, alpha=0.1, tol=0.0001):
w = np.zeros(x.shape[1], dtype="float32")
z = z_compute(x)
w_opt, iterations, max_step = coordinate_descent(y, x, w, alpha, z, tol)
return w_opt
6 changes: 6 additions & 0 deletions torchstain/numpy/utils/od2rgb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import numpy as np

# https://github.com/Peter554/StainTools/blob/2089900d11173ee5ea7de95d34532932afd3181a/staintools/utils/optical_density_conversion.py#L18
def od2rgb(OD):
OD = np.maximum(OD, 1e-6)
return (255 * np.exp(-1 * OD)).astype(np.uint8)
9 changes: 9 additions & 0 deletions torchstain/numpy/utils/rgb2od.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import numpy as np

# https://github.com/Peter554/StainTools/blob/2089900d11173ee5ea7de95d34532932afd3181a/staintools/utils/optical_density_conversion.py#L4
def rgb2od(I):
# remove zeros
I[I == 0] = 1

# convert to OD and return
return np.maximum(-1 * np.log(I / 255), 1e-6)
4 changes: 4 additions & 0 deletions torchstain/numpy/utils/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@ def get_mean_std(I):

def standardize(x, mu, std):
return (x - mu) / std

def standardize_brightness(x, alpha=99):
p = np.percentile(x, alpha)
return np.clip(x * 255 / p, 0, 255).astype("uint8")