dwidenoise: Optimal shrinkage
- Default behaviour is now to use optimal shrinkage based on minimisation of the Frobenius norm.
- Prior behaviour can be accessed using "-filter truncate".
Closes #3022.
Lestropie committed Nov 8, 2024
1 parent 4165276 commit aec1d06
Showing 1 changed file with 110 additions and 28 deletions.
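In terms of the diff below, the new default (`-filter frobenius`) replaces the previous binary keep/reject decision per principal component with a continuous per-component weight. As a reading aid for the new `filter_type::FROBENIUS` branch, the weight applied to each component is the shrinker below, written in the notation of the code (β = r/q, and y the running mean of the scaled eigenvalues divided by the current noise variance estimate):

```latex
\nu(y) =
  \begin{cases}
    \dfrac{\sqrt{\left(y^{2}-\beta-1\right)^{2}-4\beta}}{y}, & y > 1+\sqrt{\beta} \\
    0, & \text{otherwise}
  \end{cases}
\qquad
w(y) = \frac{\nu(y)}{y}
```

Restoring the hard-truncation behaviour of 3.0.x is then e.g. `dwidenoise dwi.mif out.mif -filter truncate` (file names here purely illustrative).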
138 changes: 110 additions & 28 deletions cmd/dwidenoise.cpp
@@ -31,6 +31,9 @@ const std::vector<std::string> shapes = {"cuboid", "sphere"};
enum class shape_type { CUBOID, SPHERE };
constexpr default_type sphere_multiplier_default = 1.0 / 0.85;

const std::vector<std::string> filters = {"truncate", "frobenius"};
enum class filter_type { TRUNCATE, FROBENIUS };

// clang-format off
void usage() {

@@ -76,11 +79,18 @@ void usage() {
"the command will select the smallest isotropic patch size "
"that exceeds the number of DW images in the input data; "
"e.g., 5x5x5 for data with <= 125 DWI volumes, "
"7x7x7 for data with <= 343 DWI volumes, etc.";
"7x7x7 for data with <= 343 DWI volumes, etc."

+ "By default, optimal value shrinkage based on minimisation of the Frobenius norm "
"will be used to attenuate eigenvectors based on the estimated noise level. "
"Hard truncation of sub-threshold components"
"---which was the behaviour of the dwidenoise command in version 3.0.x---"
"can be activated using -filter truncate.";

AUTHOR = "Daan Christiaens ([email protected])"
" and Jelle Veraart ([email protected])"
" and J-Donald Tournier ([email protected])";
" and J-Donald Tournier ([email protected])"
" and Robert E. Smith ([email protected])";

REFERENCES
+ "Veraart, J.; Novikov, D.S.; Christiaens, D.; Ades-aron, B.; Sijbers, J. & Fieremans, E. " // Internal
@@ -117,6 +127,10 @@ void usage() {
"* Exp1: the original estimator used in Veraart et al. (2016), or \n"
"* Exp2: the improved estimator introduced in Cordero-Grande et al. (2019).")
+ Argument("Exp1/Exp2").type_choice(estimators)
+ Option("filter",
"Modulate how components are filtered based on their eigenvalues; "
"options are: " + join(filters, ",") + "; default: frobenius")
+ Argument("choice").type_choice(filters)

+ OptionGroup("Options for exporting additional data regarding PCA behaviour")
+ Option("noise",
@@ -125,10 +139,13 @@ void usage() {
"Note that on complex input data,"
" this will be the total noise level across real and imaginary channels,"
" so a scale factor sqrt(2) applies.")
+ Argument("level").type_image_out()
+ Argument("image").type_image_out()
+ Option("rank",
"The selected signal rank of the output denoised image.")
+ Argument("cutoff").type_image_out()
+ Argument("image").type_image_out()
+ Option("sumweights",
"the sum of eigenvector weights contributed to the output image")
+ Argument("image").type_image_out()
+ Option("max_dist",
"The maximum distance between a voxel and another voxel that was included in the local denoising patch")
+ Argument("image").type_image_out()
@@ -146,10 +163,11 @@ void usage() {
+ Argument("value").type_float(0.0)
+ Option("radius_ratio",
"Set the spherical kernel size as a ratio of number of voxels to number of input volumes "
"(default: ~1.18)")
"(default: 1.0/0.85 ~= 1.18)")
+ Argument("value").type_float(0.0)
// TODO Command-line option that allows user to specify minimum absolute number of voxels in kernel
+ Option("extent",
"Set the patch size of the cuboid filter; "
"Set the patch size of the cuboid kernel; "
"can be either a single odd integer or a comma-separated triplet of odd integers")
+ Argument("window").type_sequence_int();

@@ -184,6 +202,7 @@ void usage() {

using real_type = float;
using voxel_type = Eigen::Array<int, 3, 1>;
using vector_type = Eigen::VectorXd;

class KernelVoxel {
public:
@@ -415,26 +434,31 @@ template <typename F> class DenoisingFunctor {

public:
using MatrixType = Eigen::Matrix<F, Eigen::Dynamic, Eigen::Dynamic>;
using SValsType = Eigen::VectorXd;

DenoisingFunctor(int ndwi,
std::shared_ptr<KernelBase> kernel,
filter_type filter,
Image<bool> &mask,
Image<real_type> &noise,
Image<uint16_t> &rank,
Image<float> &sum_weights,
Image<float> &max_dist,
Image<uint16_t> &voxels,
estimator_type estimator)
: kernel(kernel),
filter(filter),
m(ndwi),
estimator(estimator),
X(ndwi, kernel->estimated_size()),
XtX(std::min(m, kernel->estimated_size()), std::min(m, kernel->estimated_size())),
eig(std::min(m, kernel->estimated_size())),
s(std::min(m, kernel->estimated_size())),
clam(std::min(m, kernel->estimated_size())),
w(std::min(m, kernel->estimated_size())),
mask(mask),
noise(noise),
rankmap(rank),
sumweightsmap(sum_weights),
maxdistmap(max_dist),
voxelsmap(voxels) {}

@@ -461,6 +485,8 @@ template <typename F> class DenoisingFunctor {
DEBUG("Expanding decomposition matrix storage from " + str(X.rows()) + " to " + str(r));
XtX.resize(r, r);
s.resize(r);
clam.resize(r);
w.resize(r);
}

// Fill matrices with NaN when in debug mode;
@@ -471,6 +497,8 @@ template <typename F> class DenoisingFunctor {
X.fill(std::numeric_limits<F>::signaling_NaN());
XtX.fill(std::numeric_limits<F>::signaling_NaN());
s.fill(std::numeric_limits<default_type>::signaling_NaN());
clam.fill(std::numeric_limits<default_type>::signaling_NaN());
w.fill(std::numeric_limits<default_type>::signaling_NaN());
#endif

load_data(dwi, neighbourhood.voxels);
@@ -486,13 +514,12 @@ template <typename F> class DenoisingFunctor {

// Marchenko-Pastur optimal threshold
const double lam_r = std::max(s[0], 0.0) / q;
double clam = 0.0;
double sigma2 = 0.0;
ssize_t cutoff_p = 0;
for (ssize_t p = 0; p < r; ++p) // p+1 is the number of noise components
{ // (as opposed to the paper where p is defined as the number of signal components)
const double lam = std::max(s[p], 0.0) / q;
clam += lam;
clam[p] = (p == 0 ? 0.0 : clam[p - 1]) + lam;
double denominator = std::numeric_limits<double>::signaling_NaN();
switch (estimator) {
case estimator_type::EXP1:
@@ -505,7 +532,7 @@ template <typename F> class DenoisingFunctor {
assert(false);
}
const double gam = double(p + 1) / denominator;
const double sigsq1 = clam / double(p + 1);
const double sigsq1 = clam[p] / double(p + 1);
const double sigsq2 = (lam - lam_r) / (4.0 * std::sqrt(gam));
// sigsq2 > sigsq1 if signal else noise
if (sigsq2 < sigsq1) {
@@ -514,20 +541,38 @@ template <typename F> class DenoisingFunctor {
}
}

if (cutoff_p > 0) {
// recombine data using only eigenvectors above threshold:
s.head(cutoff_p).setZero();
s.segment(cutoff_p, r - cutoff_p).setOnes();
if (m <= n)
X.col(neighbourhood.centre_index) =
eig.eigenvectors() *
(s.head(r).cast<F>().asDiagonal() * (eig.eigenvectors().adjoint() * X.col(neighbourhood.centre_index)));
else
X.col(neighbourhood.centre_index) =
X.leftCols(n) * (eig.eigenvectors() * (s.head(r).cast<F>().asDiagonal() *
eig.eigenvectors().adjoint().col(neighbourhood.centre_index)));
// Generate weights vector
double sum_weights = 0.0;
switch (filter) {
case filter_type::TRUNCATE:
w.head(cutoff_p).setZero();
w.segment(cutoff_p, r - cutoff_p).setOnes();
sum_weights = r - cutoff_p;
break;
case filter_type::FROBENIUS: {
const double beta = r / q;
const double threshold = 1.0 + std::sqrt(beta);
for (ssize_t i = 0; i != r; ++i) {
const double y = clam[i] / (sigma2 * (i + 1));
const double nu = y > threshold ? std::sqrt(Math::pow2(Math::pow2(y) - beta - 1.0) - (4.0 * beta)) / y : 0.0;
w[i] = nu / y;
sum_weights += w[i];
}
} break;
default:
assert(false);
}

// recombine data using only eigenvectors above threshold:
if (m <= n)
X.col(neighbourhood.centre_index) =
eig.eigenvectors() *
(w.head(r).cast<F>().asDiagonal() * (eig.eigenvectors().adjoint() * X.col(neighbourhood.centre_index)));
else
X.col(neighbourhood.centre_index) =
X.leftCols(n) * (eig.eigenvectors() * (w.head(r).cast<F>().asDiagonal() *
eig.eigenvectors().adjoint().col(neighbourhood.centre_index)));

// Store output
assign_pos_of(dwi).to(out);
out.row(3) = X.col(neighbourhood.centre_index);
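In the notation of the rank-selection loop above (eigenvalues λ_k in ascending order and pre-scaled by 1/q, with γ(p) determined by the -estimator choice), the two noise-level estimates compared at each candidate p are:

```latex
\hat{\sigma}_1^2(p) = \frac{1}{p+1} \sum_{k=0}^{p} \lambda_k ,
\qquad
\hat{\sigma}_2^2(p) = \frac{\lambda_p - \lambda_0}{4 \sqrt{\gamma(p)}}
```

The number of noise components is taken as the largest p+1 for which the second estimate falls below the first, and the first estimate at that p is retained as the noise variance. The recombination step then reconstructs the centre voxel (here for the m <= n branch) as

```latex
\hat{x} = V \, \operatorname{diag}(w) \, V^{H} x
```

where V holds the eigenvectors of XtX and w is either the 0/1 indicator produced by `-filter truncate` or the shrinkage weights noted above the diff for `-filter frobenius`.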
@@ -541,6 +586,10 @@ template <typename F> class DenoisingFunctor {
assign_pos_of(dwi, 0, 3).to(rankmap);
rankmap.value() = uint16_t(r - cutoff_p);
}
if (sumweightsmap.valid()) {
assign_pos_of(dwi, 0, 3).to(sumweightsmap);
sumweightsmap.value() = sum_weights;
}
if (maxdistmap.valid()) {
assign_pos_of(dwi, 0, 3).to(maxdistmap);
maxdistmap.value() = float(neighbourhood.max_distance);
@@ -552,16 +601,27 @@ template <typename F> class DenoisingFunctor {
}

private:
// Denoising configuration
std::shared_ptr<KernelBase> kernel;
filter_type filter;
const ssize_t m;
const estimator_type estimator;

// Reusable memory
MatrixType X;
MatrixType XtX;
Eigen::SelfAdjointEigenSolver<MatrixType> eig;
SValsType s;
vector_type s;
vector_type clam;
vector_type w;

Image<bool> mask;

// Export images
// TODO Group these into a class?
Image<real_type> noise;
Image<uint16_t> rankmap;
Image<float> sumweightsmap;
Image<float> maxdistmap;
Image<uint16_t> voxelsmap;

@@ -580,18 +640,20 @@ void run(Header &data,
Image<bool> &mask,
Image<real_type> &noise,
Image<uint16_t> &rank,
Image<float> &sum_weights,
Image<float> &max_dist,
Image<uint16_t> &voxels,
const std::string &output_name,
std::shared_ptr<KernelBase> kernel,
filter_type filter,
estimator_type estimator) {
auto input = data.get_image<T>().with_direct_io(3);
// create output
Header header(data);
header.datatype() = DataType::from<T>();
auto output = Image<T>::create(output_name, header);
// run
DenoisingFunctor<T> func(data.size(3), kernel, mask, noise, rank, max_dist, voxels, estimator);
DenoisingFunctor<T> func(data.size(3), kernel, filter, mask, noise, rank, sum_weights, max_dist, voxels, estimator);
ThreadedLoop("running MP-PCA denoising", data, 0, 3).run(func, input, output);
}

@@ -613,6 +675,11 @@ void run() {
if (!opt.empty())
estimator = estimator_type(int(opt[0][0]));

filter_type filter = filter_type::FROBENIUS;
opt = get_options("filter");
if (!opt.empty())
filter = filter_type(int(opt[0][0]));

Image<real_type> noise;
opt = get_options("noise");
if (!opt.empty()) {
@@ -632,6 +699,21 @@ void run() {
rank = Image<uint16_t>::create(opt[0][0], header);
}

Image<float> sum_weights;
opt = get_options("sumweights");
if (!opt.empty()) {
Header header(dwi);
header.ndim() = 3;
header.datatype() = DataType::Float32;
header.datatype().set_byte_order_native();
header.reset_intensity_scaling();
sum_weights = Image<float>::create(opt[0][0], header);
if (filter == filter_type::TRUNCATE) {
WARN("Note that with a truncation filter, "
"output image from -sumweights option will be equivalent to rank");
}
}

Image<float> max_dist;
opt = get_options("max_dist");
if (!opt.empty()) {
@@ -714,19 +796,19 @@ void run() {
switch (prec) {
case 0:
INFO("select real float32 for processing");
run<float>(dwi, mask, noise, rank, max_dist, voxels, argument[1], kernel, estimator);
run<float>(dwi, mask, noise, rank, sum_weights, max_dist, voxels, argument[1], kernel, filter, estimator);
break;
case 1:
INFO("select real float64 for processing");
run<double>(dwi, mask, noise, rank, max_dist, voxels, argument[1], kernel, estimator);
run<double>(dwi, mask, noise, rank, sum_weights, max_dist, voxels, argument[1], kernel, filter, estimator);
break;
case 2:
INFO("select complex float32 for processing");
run<cfloat>(dwi, mask, noise, rank, max_dist, voxels, argument[1], kernel, estimator);
run<cfloat>(dwi, mask, noise, rank, sum_weights, max_dist, voxels, argument[1], kernel, filter, estimator);
break;
case 3:
INFO("select complex float64 for processing");
run<cdouble>(dwi, mask, noise, rank, max_dist, voxels, argument[1], kernel, estimator);
run<cdouble>(dwi, mask, noise, rank, sum_weights, max_dist, voxels, argument[1], kernel, filter, estimator);
break;
}
}
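As a self-contained illustration (not part of the commit), the sketch below applies the same per-component weighting as the `-filter frobenius` branch to a toy ascending eigenvalue spectrum; the eigenvalue scaling follows the functor's convention of dividing by q, while the values of `sigma2` and `beta` are purely illustrative:

```cpp
// Minimal standalone sketch of the "-filter frobenius" weighting, applied to a
// toy ascending eigenvalue spectrum. Values of sigma2 and beta are illustrative.
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// lam: ascending eigenvalues already scaled by 1/q (as in the functor);
// sigma2: estimated noise variance; beta: aspect ratio r/q.
std::vector<double> frobenius_weights(const std::vector<double> &lam, double sigma2, double beta) {
  const double threshold = 1.0 + std::sqrt(beta);
  std::vector<double> w(lam.size());
  double cumulative = 0.0;
  for (std::size_t i = 0; i != lam.size(); ++i) {
    cumulative += lam[i];
    const double y = cumulative / (sigma2 * (i + 1));
    const double nu =
        y > threshold ? std::sqrt(std::pow(y * y - beta - 1.0, 2.0) - 4.0 * beta) / y : 0.0;
    w[i] = nu / y; // weight in [0, 1): 0 for noise-like components, approaching 1 for strong signal
  }
  return w;
}

int main() {
  // Toy spectrum: four noise-like components plus two strong signal components.
  const std::vector<double> lam{0.2, 0.3, 0.4, 0.5, 5.0, 9.0};
  const double sigma2 = 0.4;
  const double beta = 0.5;
  for (const double wi : frobenius_weights(lam, sigma2, beta))
    std::cout << wi << "\n";
  return 0;
}
```

Compiled and run, this prints zero weights for the four noise-like components and weights just below one for the two strong components, which is the qualitative difference from `-filter truncate`, where every retained component is passed through with weight exactly one.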
