Skip to content

Commit

Permalink
Rename Measure as MeasureCreator
Browse files Browse the repository at this point in the history
  • Loading branch information
onurulgen committed Feb 23, 2024
1 parent e6d09fc commit cee3df5
Show file tree
Hide file tree
Showing 20 changed files with 80 additions and 80 deletions.
2 changes: 1 addition & 1 deletion niftyreg_build_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
417
418
2 changes: 1 addition & 1 deletion reg-lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ add_library(_reg_compute ${NIFTYREG_LIBRARY_TYPE}
F3dContent.cpp
Optimiser.cpp
Platform.cpp
Measure.cpp
MeasureCreator.cpp
)
target_link_libraries(_reg_compute _reg_measure)
install(TARGETS _reg_compute
Expand Down
6 changes: 3 additions & 3 deletions reg-lib/Measure.cpp → reg-lib/MeasureCreator.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "Measure.h"
#include "MeasureCreator.hpp"
#include "_reg_nmi.h"
#include "_reg_ssd.h"
#include "_reg_dti.h"
Expand All @@ -7,7 +7,7 @@
#include "_reg_mind.h"

/* *************************************************************** */
reg_measure* Measure::Create(const MeasureType measureType) {
reg_measure* MeasureCreator::Create(const MeasureType measureType) {
switch (measureType) {
case MeasureType::Nmi:
return new reg_nmi();
Expand All @@ -29,7 +29,7 @@ reg_measure* Measure::Create(const MeasureType measureType) {
}
}
/* *************************************************************** */
void Measure::Initialise(reg_measure& measure, DefContent& con, DefContent *conBw) {
void MeasureCreator::Initialise(reg_measure& measure, DefContent& con, DefContent *conBw) {
measure.InitialiseMeasure(con.GetReference(),
con.GetFloating(),
con.GetReferenceMask(),
Expand Down
2 changes: 1 addition & 1 deletion reg-lib/Measure.h → reg-lib/MeasureCreator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

enum class MeasureType { Nmi, Ssd, Dti, Lncc, Kld, Mind, MindSsc };

class Measure {
class MeasureCreator {
public:
virtual reg_measure* Create(const MeasureType measureType);
virtual void Initialise(reg_measure& measure, DefContent& con, DefContent *conBw = nullptr);
Expand Down
8 changes: 8 additions & 0 deletions reg-lib/MeasureCreatorFactory.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#pragma once

#include "MeasureCreator.hpp"

class MeasureCreatorFactory {
public:
virtual MeasureCreator* Produce() { return new MeasureCreator(); }
};
8 changes: 0 additions & 8 deletions reg-lib/MeasureFactory.h

This file was deleted.

12 changes: 6 additions & 6 deletions reg-lib/Platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "CudaComputeFactory.h"
#include "CudaContentCreatorFactory.h"
#include "CudaKernelFactory.h"
#include "CudaMeasureFactory.h"
#include "CudaMeasureCreatorFactory.hpp"
#include "CudaOptimiser.hpp"
#endif
#ifdef USE_OPENCL
Expand All @@ -24,7 +24,7 @@ Platform::Platform(const PlatformType platformTypeIn) {
computeFactory = new ComputeFactory();
contentCreatorFactory = new ContentCreatorFactory();
kernelFactory = new CpuKernelFactory();
measureFactory = new MeasureFactory();
measureCreatorFactory = new MeasureCreatorFactory();
}
#ifdef USE_CUDA
else if (platformType == PlatformType::Cuda) {
Expand All @@ -33,7 +33,7 @@ Platform::Platform(const PlatformType platformTypeIn) {
computeFactory = new CudaComputeFactory();
contentCreatorFactory = new CudaContentCreatorFactory();
kernelFactory = new CudaKernelFactory();
measureFactory = new CudaMeasureFactory();
measureCreatorFactory = new CudaMeasureCreatorFactory();
}
#endif
#ifdef USE_OPENCL
Expand All @@ -52,7 +52,7 @@ Platform::~Platform() {
delete computeFactory;
delete contentCreatorFactory;
delete kernelFactory;
delete measureFactory;
delete measureCreatorFactory;
}
/* *************************************************************** */
std::string Platform::GetName() const {
Expand Down Expand Up @@ -104,8 +104,8 @@ Kernel* Platform::CreateKernel(const std::string& name, Content *con) const {
return kernelFactory->Produce(name, con);
}
/* *************************************************************** */
Measure* Platform::CreateMeasure() const {
return measureFactory->Produce();
MeasureCreator* Platform::CreateMeasureCreator() const {
return measureCreatorFactory->Produce();
}
/* *************************************************************** */
template<typename Type>
Expand Down
8 changes: 4 additions & 4 deletions reg-lib/Platform.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include "ComputeFactory.h"
#include "ContentCreatorFactory.h"
#include "KernelFactory.h"
#include "MeasureFactory.h"
#include "MeasureCreatorFactory.hpp"
#include "Optimiser.hpp"

enum class PlatformType { Cpu, Cuda, OpenCl };
Expand Down Expand Up @@ -34,7 +34,7 @@ class Platform {
Compute* CreateCompute(Content& con) const;
ContentCreator* CreateContentCreator(const ContentType conType = ContentType::Base) const;
Kernel* CreateKernel(const std::string& name, Content *con) const;
Measure* CreateMeasure() const;
MeasureCreator* CreateMeasureCreator() const;
template<typename Type>
Optimiser<Type>* CreateOptimiser(F3dContent& con,
InterfaceOptimiser& opt,
Expand Down Expand Up @@ -62,8 +62,8 @@ class Platform {
ComputeFactory *computeFactory = nullptr;
ContentCreatorFactory *contentCreatorFactory = nullptr;
KernelFactory *kernelFactory = nullptr;
MeasureFactory *measureFactory = nullptr;
MeasureCreatorFactory *measureCreatorFactory = nullptr;
std::string platformName;
PlatformType platformType;
unsigned gpuIdx;
unsigned gpuIdx = 0;
};
36 changes: 18 additions & 18 deletions reg-lib/_reg_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ void reg_base<T>::CheckParameters() {
// Set the default similarity measure if none has been set
if (!measure_nmi && !measure_ssd && !measure_dti && !measure_lncc &&
!measure_kld && !measure_mind && !measure_mindssc) {
measure_nmi.reset(dynamic_cast<reg_nmi*>(measure->Create(MeasureType::Nmi)));
measure_nmi.reset(dynamic_cast<reg_nmi*>(measureCreator->Create(MeasureType::Nmi)));
for (int i = 0; i < inputReference->nt; ++i)
measure_nmi->SetTimePointWeight(i, 1.0);
}
Expand Down Expand Up @@ -360,25 +360,25 @@ void reg_base<T>::InitialiseSimilarity() {
DefContent& con = dynamic_cast<DefContent&>(*this->con);

if (measure_nmi)
measure->Initialise(*measure_nmi, con);
measureCreator->Initialise(*measure_nmi, con);

if (measure_ssd)
measure->Initialise(*measure_ssd, con);
measureCreator->Initialise(*measure_ssd, con);

if (measure_kld)
measure->Initialise(*measure_kld, con);
measureCreator->Initialise(*measure_kld, con);

if (measure_lncc)
measure->Initialise(*measure_lncc, con);
measureCreator->Initialise(*measure_lncc, con);

if (measure_dti)
measure->Initialise(*measure_dti, con);
measureCreator->Initialise(*measure_dti, con);

if (measure_mind)
measure->Initialise(*measure_mind, con);
measureCreator->Initialise(*measure_mind, con);

if (measure_mindssc)
measure->Initialise(*measure_mindssc, con);
measureCreator->Initialise(*measure_mindssc, con);

NR_FUNC_CALLED();
}
Expand Down Expand Up @@ -551,22 +551,22 @@ void reg_base<T>::GetVoxelBasedGradient() {
//void reg_base<T>::ApproximateParzenWindow()
//{
// if(!measure_nmi)
// measure_nmi.reset(dynamic_cast<reg_nmi*>(measure->Create(MeasureType::Nmi)));
// measure_nmi.reset(dynamic_cast<reg_nmi*>(measureCreator->Create(MeasureType::Nmi)));
// measure_nmi=approxParzenWindow = true;
//}
///* *************************************************************** */
//template<class T>
//void reg_base<T>::DoNotApproximateParzenWindow()
//{
// if(!measure_nmi)
// measure_nmi.reset(dynamic_cast<reg_nmi*>(measure->Create(MeasureType::Nmi)));
// measure_nmi.reset(dynamic_cast<reg_nmi*>(measureCreator->Create(MeasureType::Nmi)));
// measure_nmi=approxParzenWindow = false;
//}
/* *************************************************************** */
template<class T>
void reg_base<T>::UseNMISetReferenceBinNumber(int timePoint, int refBinNumber) {
if (!measure_nmi)
measure_nmi.reset(dynamic_cast<reg_nmi*>(measure->Create(MeasureType::Nmi)));
measure_nmi.reset(dynamic_cast<reg_nmi*>(measureCreator->Create(MeasureType::Nmi)));
measure_nmi->SetTimePointWeight(timePoint, 1.0);//weight initially set to default value of 1.0
// I am here adding 4 to the specified bin number to accommodate for
// the spline support
Expand All @@ -577,7 +577,7 @@ void reg_base<T>::UseNMISetReferenceBinNumber(int timePoint, int refBinNumber) {
template<class T>
void reg_base<T>::UseNMISetFloatingBinNumber(int timePoint, int floBinNumber) {
if (!measure_nmi)
measure_nmi.reset(dynamic_cast<reg_nmi*>(measure->Create(MeasureType::Nmi)));
measure_nmi.reset(dynamic_cast<reg_nmi*>(measureCreator->Create(MeasureType::Nmi)));
measure_nmi->SetTimePointWeight(timePoint, 1.0);//weight initially set to default value of 1.0
// I am here adding 4 to the specified bin number to accommodate for
// the spline support
Expand All @@ -588,7 +588,7 @@ void reg_base<T>::UseNMISetFloatingBinNumber(int timePoint, int floBinNumber) {
template<class T>
void reg_base<T>::UseSSD(int timePoint, bool normalise) {
if (!measure_ssd)
measure_ssd.reset(dynamic_cast<reg_ssd*>(measure->Create(MeasureType::Ssd)));
measure_ssd.reset(dynamic_cast<reg_ssd*>(measureCreator->Create(MeasureType::Ssd)));
measure_ssd->SetTimePointWeight(timePoint, 1.0);//weight initially set to default value of 1.0
measure_ssd->SetNormaliseTimePoint(timePoint, normalise);
NR_FUNC_CALLED();
Expand All @@ -597,7 +597,7 @@ void reg_base<T>::UseSSD(int timePoint, bool normalise) {
template<class T>
void reg_base<T>::UseMIND(int timePoint, int offset) {
if (!measure_mind)
measure_mind.reset(dynamic_cast<reg_mind*>(measure->Create(MeasureType::Mind)));
measure_mind.reset(dynamic_cast<reg_mind*>(measureCreator->Create(MeasureType::Mind)));
measure_mind->SetTimePointWeight(timePoint, 1.0);//weight set to 1.0 to indicate time point is active
measure_mind->SetDescriptorOffset(offset);
NR_FUNC_CALLED();
Expand All @@ -606,7 +606,7 @@ void reg_base<T>::UseMIND(int timePoint, int offset) {
template<class T>
void reg_base<T>::UseMINDSSC(int timePoint, int offset) {
if (!measure_mindssc)
measure_mindssc.reset(dynamic_cast<reg_mindssc*>(measure->Create(MeasureType::MindSsc)));
measure_mindssc.reset(dynamic_cast<reg_mindssc*>(measureCreator->Create(MeasureType::MindSsc)));
measure_mindssc->SetTimePointWeight(timePoint, 1.0);//weight set to 1.0 to indicate time point is active
measure_mindssc->SetDescriptorOffset(offset);
NR_FUNC_CALLED();
Expand All @@ -615,15 +615,15 @@ void reg_base<T>::UseMINDSSC(int timePoint, int offset) {
template<class T>
void reg_base<T>::UseKLDivergence(int timePoint) {
if (!measure_kld)
measure_kld.reset(dynamic_cast<reg_kld*>(measure->Create(MeasureType::Kld)));
measure_kld.reset(dynamic_cast<reg_kld*>(measureCreator->Create(MeasureType::Kld)));
measure_kld->SetTimePointWeight(timePoint, 1.0);//weight initially set to default value of 1.0
NR_FUNC_CALLED();
}
/* *************************************************************** */
template<class T>
void reg_base<T>::UseLNCC(int timePoint, float stddev) {
if (!measure_lncc)
measure_lncc.reset(dynamic_cast<reg_lncc*>(measure->Create(MeasureType::Lncc)));
measure_lncc.reset(dynamic_cast<reg_lncc*>(measureCreator->Create(MeasureType::Lncc)));
measure_lncc->SetKernelStandardDeviation(timePoint, stddev);
measure_lncc->SetTimePointWeight(timePoint, 1.0); // weight initially set to default value of 1.0
NR_FUNC_CALLED();
Expand All @@ -642,7 +642,7 @@ void reg_base<T>::UseDTI(bool *timePoint) {
NR_FATAL_ERROR("The use of DTI has been deactivated as it requires some refactoring");

if (!measure_dti)
measure_dti.reset(dynamic_cast<reg_dti*>(measure->Create(MeasureType::Dti)));
measure_dti.reset(dynamic_cast<reg_dti*>(measureCreator->Create(MeasureType::Dti)));
for (int i = 0; i < inputReference->nt; ++i) {
if (timePoint[i])
measure_dti->SetTimePointWeight(i, 1.0); // weight set to 1.0 to indicate time point is active
Expand Down
4 changes: 2 additions & 2 deletions reg-lib/_reg_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class reg_base: public InterfaceOptimiser {
unique_ptr<Compute> compute;

// Measure
unique_ptr<Measure> measure;
unique_ptr<MeasureCreator> measureCreator;

// Optimiser-related variables
unique_ptr<Optimiser<T>> optimiser;
Expand Down Expand Up @@ -143,7 +143,7 @@ class reg_base: public InterfaceOptimiser {
// Platform
virtual void SetPlatformType(const PlatformType platformType) {
platform.reset(new Platform(platformType));
measure.reset(platform->CreateMeasure());
measureCreator.reset(platform->CreateMeasureCreator());
}
virtual void SetGpuIdx(const unsigned gpuIdx) { platform->SetGpuIdx(gpuIdx); }

Expand Down
14 changes: 7 additions & 7 deletions reg-lib/_reg_f3d2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,25 +489,25 @@ void reg_f3d2<T>::InitialiseSimilarity() {
F3dContent& con = dynamic_cast<F3dContent&>(*this->con);

if (this->measure_nmi)
this->measure->Initialise(*this->measure_nmi, con, conBw.get());
this->measureCreator->Initialise(*this->measure_nmi, con, conBw.get());

if (this->measure_ssd)
this->measure->Initialise(*this->measure_ssd, con, conBw.get());
this->measureCreator->Initialise(*this->measure_ssd, con, conBw.get());

if (this->measure_kld)
this->measure->Initialise(*this->measure_kld, con, conBw.get());
this->measureCreator->Initialise(*this->measure_kld, con, conBw.get());

if (this->measure_lncc)
this->measure->Initialise(*this->measure_lncc, con, conBw.get());
this->measureCreator->Initialise(*this->measure_lncc, con, conBw.get());

if (this->measure_dti)
this->measure->Initialise(*this->measure_dti, con, conBw.get());
this->measureCreator->Initialise(*this->measure_dti, con, conBw.get());

if (this->measure_mind)
this->measure->Initialise(*this->measure_mind, con, conBw.get());
this->measureCreator->Initialise(*this->measure_mind, con, conBw.get());

if (this->measure_mindssc)
this->measure->Initialise(*this->measure_mindssc, con, conBw.get());
this->measureCreator->Initialise(*this->measure_mindssc, con, conBw.get());

NR_FUNC_CALLED();
}
Expand Down
2 changes: 1 addition & 1 deletion reg-lib/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ add_library(${NAME} ${NIFTYREG_LIBRARY_TYPE}
CudaKernelFactory.cpp
CudaLocalTransformation.cu
CudaLtsKernel.cpp
CudaMeasure.cpp
CudaMeasureCreator.cpp
CudaNormaliseGradient.cu
CudaOptimiser.cu
CudaResampleImageKernel.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#include "CudaMeasure.h"
#include "CudaMeasureCreator.hpp"
#include "CudaDefContent.h"
#include "_reg_nmi_gpu.h"
#include "_reg_ssd_gpu.h"

/* *************************************************************** */
reg_measure* CudaMeasure::Create(const MeasureType measureType) {
reg_measure* CudaMeasureCreator::Create(const MeasureType measureType) {
switch (measureType) {
case MeasureType::Nmi:
return new reg_nmi_gpu();
Expand All @@ -26,7 +26,7 @@ reg_measure* CudaMeasure::Create(const MeasureType measureType) {
}
}
/* *************************************************************** */
void CudaMeasure::Initialise(reg_measure& measure, DefContent& con, DefContent *conBw) {
void CudaMeasureCreator::Initialise(reg_measure& measure, DefContent& con, DefContent *conBw) {
reg_measure_gpu& measureGpu = dynamic_cast<reg_measure_gpu&>(measure);
CudaDefContent& cudaCon = dynamic_cast<CudaDefContent&>(con);
CudaDefContent *cudaConBw = dynamic_cast<CudaDefContent*>(conBw);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#pragma once

#include "Measure.h"
#include "MeasureCreator.hpp"

class CudaMeasure: public Measure {
class CudaMeasureCreator: public MeasureCreator {
public:
virtual reg_measure* Create(const MeasureType measureType) override;
virtual void Initialise(reg_measure& measure, DefContent& con, DefContent *conBw = nullptr) override;
Expand Down
8 changes: 8 additions & 0 deletions reg-lib/cuda/CudaMeasureCreatorFactory.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#pragma once

#include "CudaMeasureCreator.hpp"

class CudaMeasureCreatorFactory: public MeasureCreatorFactory {
public:
virtual MeasureCreator* Produce() override { return new CudaMeasureCreator(); }
};
8 changes: 0 additions & 8 deletions reg-lib/cuda/CudaMeasureFactory.h

This file was deleted.

8 changes: 4 additions & 4 deletions reg-test/reg_test_lncc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,13 @@ class LnccTest {
unique_ptr<Compute> compute{ platform->CreateCompute(*content) };
compute->ResampleImage(0, 0);
content->SetWarped(floating.disown());
// Create the measure
unique_ptr<Measure> measure{ platform->CreateMeasure() };
// Create the measure creator
unique_ptr<MeasureCreator> measureCreator{ platform->CreateMeasureCreator() };
// Use LNCC as a measure
unique_ptr<reg_lncc> measure_lncc{ dynamic_cast<reg_lncc*>(measure->Create(MeasureType::Lncc)) };
unique_ptr<reg_lncc> measure_lncc{ dynamic_cast<reg_lncc*>(measureCreator->Create(MeasureType::Lncc)) };
measure_lncc->SetKernelStandardDeviation(0, sigma);
measure_lncc->SetTimePointWeight(0, 1.0); // weight initially set to default value of 1.0
measure->Initialise(*measure_lncc, *content);
measureCreator->Initialise(*measure_lncc, *content);
const double lncc = measure_lncc->GetSimilarityMeasureValue();
// Save for testing
testCases.push_back({ testName, lncc, expLncc });
Expand Down
Loading

0 comments on commit cee3df5

Please sign in to comment.