7 #include <itkImageRegionIterator.h> 8 #include <itkImageRegionIteratorWithIndex.h> 9 #include <itkImageRegionConstIterator.h> 10 #include <itkImageRegionConstIteratorWithIndex.h> 15 template <
class TInputImage>
17 NonLocalMeansImageFilter <TInputImage>
18 ::BeforeThreadedGenerateData()
// ITK pre-threading hook. It runs the following steps in order:
//   1. let the superclass prepare;
//   2. estimate the noise variance (computeAverageLocalVariance, which
//      stores m_noiseCovariance);
//   3. build the local mean/variance images (computeMeanAndVarImages);
//   4. fix the maximum displacement used by the patch search.
20 Superclass::BeforeThreadedGenerateData();
22 this->computeAverageLocalVariance();
23 this->computeMeanAndVarImages();
// Largest displacement explored by the patch search: the search neighbourhood
// rounded down to a whole multiple of m_SearchStepSize.
// NOTE(review): if m_SearchNeighborhood and m_SearchStepSize are integral, the
// division already truncates before the cast to double, so std::floor is a
// no-op. This also assumes m_SearchStepSize != 0 — confirm member types and
// validation.
24 m_maxAbsDisp = std::floor((
double)(m_SearchNeighborhood / m_SearchStepSize)) * m_SearchStepSize;
27 template <
class TInputImage>
30 ::computeMeanAndVarImages()
// Builds the local mean and local variance images of the input and stores them
// in m_meanImage and m_varImage. They are later handed to the patch searcher
// together with m_MeanMinThreshold / m_VarMinThreshold (presumably to
// pre-select candidate patches — confirm in PatchSearcherType).
33 MeanAndVarianceImagesFilterType;
34 typename MeanAndVarianceImagesFilterType::Pointer filter = MeanAndVarianceImagesFilterType::New();
35 filter->SetInput(this->GetInput());
// The window half-width is m_PatchHalfSize along every axis, so the statistics
// are taken over patch-sized neighbourhoods.
36 typename InputImageType::SizeType radius;
38 for (
unsigned int j = 0;j < InputImageDimension;++j)
40 radius[j] = m_PatchHalfSize;
43 filter->SetRadius(radius);
// Run the helper filter with the same work-unit count as this filter.
44 filter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
// NOTE(review): this listing skips a line before the outputs are read
// (presumably filter->Update()). Confirm that the helper filter is updated
// before GetMeanImage()/GetVarImage() are used.
46 m_meanImage = filter->GetMeanImage();
47 m_varImage = filter->GetVarImage();
50 template <
class TInputImage>
53 ::computeAverageLocalVariance()
// Estimates the image noise variance and stores it in m_noiseCovariance.
// For every voxel y visited in largestRegion:
//   - average the N = 2 * InputImageDimension neighbours that lie
//     m_localNeighborhood voxels away along each axis (one on each side,
//     clamped to the image bounds);
//   - form the scaled residual sqrt(N / (N + 1)) * (y - neighbourMean);
//   - accumulate its square.
// The result is the accumulated sum divided by numEstimations (presumably the
// number of voxels visited).
55 typedef itk::ImageRegionConstIteratorWithIndex< InputImageType > InIteratorType;
57 InIteratorType dataIterator (this->GetInput(), largestRegion);
59 double averageLocalSignal, diffSignal;
61 typename InputImageRegionType::IndexType baseIndex;
63 double averageCovariance = 0;
64 unsigned int numEstimations = 0;
// One neighbour on each side of the voxel along every axis.
65 unsigned int numLocalPixels = 2 * InputImageDimension;
67 while (!dataIterator.IsAtEnd())
69 baseSignal =
static_cast<double>(dataIterator.Get());
70 baseIndex = dataIterator.GetIndex();
71 averageLocalSignal = 0;
73 typename InputImageRegionType::IndexType valueIndex;
74 for (
unsigned int d = 0; d < InputImageDimension; ++d)
// Lower neighbour along axis d, clamped at index 0.
// NOTE(review): clamping to [0, size - 1] assumes largestRegion starts at
// index 0 — confirm for images with a non-zero region index. Near a border
// the clamped neighbour may be the voxel itself (e.g. baseIndex[d] == 0) or
// may repeat a border voxel.
76 valueIndex = baseIndex;
77 int tmpIndex = baseIndex[d] - m_localNeighborhood;
78 valueIndex[d] = std::max(tmpIndex,0);
79 averageLocalSignal +=
static_cast<double> (this->GetInput()->GetPixel(valueIndex));
// Upper neighbour along axis d, clamped at the last index of the region.
81 valueIndex = baseIndex;
82 tmpIndex = baseIndex[d] + m_localNeighborhood;
83 int maxIndex = largestRegion.GetSize()[d] - 1;
84 valueIndex[d] = std::min(tmpIndex, maxIndex);
85 averageLocalSignal +=
static_cast<double> (this->GetInput()->GetPixel(valueIndex));
// Why the sqrt(N / (N + 1)) factor: for a locally constant signal with
// independent noise of variance sigma^2,
//   Var(y - mean of N neighbours) = sigma^2 * (N + 1) / N.
// Scaling by sqrt(N / (N + 1)) makes each squared residual an estimate of
// sigma^2.
88 averageLocalSignal /= numLocalPixels;
89 diffSignal = sqrt(numLocalPixels / (numLocalPixels + 1.0)) * (baseSignal - averageLocalSignal);
91 averageCovariance += diffSignal * diffSignal;
// NOTE(review): the statements that increment numEstimations and advance
// dataIterator are not visible in this listing — confirm they exist. With
// numEstimations == 0 this division yields inf/NaN.
98 m_noiseCovariance = averageCovariance / numEstimations;
101 template <
class TInputImage>
// Per-thread denoising pass over outputRegionForThread (the signature lines are
// not visible in this listing). For each output voxel, a patch searcher
// collects candidate sample values (databaseSamples) and their weights
// (databaseWeights) from the input. The output is a weighted combination of
// those samples and the voxel's own input value, chosen by m_WeightMethod.
107 typename OutputImageType::Pointer output = this->GetOutput();
// NOTE(review): const_cast is presumably needed because
// PatchSearcherType::SetInputImage takes a non-const pointer. Confirm that the
// searcher never modifies the input.
108 typename InputImageType::Pointer input =
const_cast<InputImageType *
> (this->GetInput());
// Input and output iterators cover the same region, so they must be advanced
// together (the increments are not visible in this listing).
110 typedef itk::ImageRegionConstIterator< InputImageType > InIteratorType;
111 typedef itk::ImageRegionIteratorWithIndex< OutputImageType > OutRegionIteratorType;
113 InIteratorType inputIterator(input, outputRegionForThread);
114 OutRegionIteratorType outputIterator(output, outputRegionForThread);
// Declared outside the voxel loop so their storage is reused between voxels.
116 std::vector <InputPixelType> databaseSamples;
117 std::vector <double> databaseWeights;
// The searcher is local to this call, so each thread gets its own instance.
// It is configured once and reused for every voxel of the region.
121 PatchSearcherType patchSearcher;
123 patchSearcher.SetSearchStepSize(m_SearchStepSize);
124 patchSearcher.SetMaxAbsDisp(m_maxAbsDisp);
125 patchSearcher.SetInputImage(input);
126 patchSearcher.SetBetaParameter(m_BetaParameter);
127 patchSearcher.SetNoiseCovariance(m_noiseCovariance);
128 patchSearcher.SetWeightThreshold(m_WeightThreshold);
129 patchSearcher.SetMeanImage(m_meanImage);
130 patchSearcher.SetVarImage(m_varImage);
131 patchSearcher.SetMeanMinThreshold(m_MeanMinThreshold);
132 patchSearcher.SetVarMinThreshold(m_VarMinThreshold);
134 while (!outputIterator.IsAtEnd())
136 patchSearcher.UpdateAtPosition(outputIterator.GetIndex());
138 databaseSamples = patchSearcher.GetDatabaseSamples();
139 databaseWeights = patchSearcher.GetDatabaseWeights();
// These accumulators are declared inside the voxel loop, so they restart at
// zero for every voxel.
142 double average = 0, sum = 0, w_max = 0;
144 switch (m_WeightMethod)
// First weighting case (its case label is not visible). Weighted mean:
//   out = (sum_i w_i * s_i + w_max * y) / (sum_i w_i + w_max)
// where y is the input voxel and w_max is the largest database weight. The
// voxel itself is therefore weighted like its best match.
147 for (
unsigned int d = 0;d < databaseSamples.size();++d)
149 average += databaseSamples[d] * databaseWeights[d];
150 sum += databaseWeights[d];
152 if (w_max < databaseWeights[d])
153 w_max = databaseWeights[d];
157 outputIterator.Set((average + w_max * inputIterator.Get()) / (sum + w_max));
// Fallback: copy the input voxel unchanged. The guarding condition is not
// visible here — presumably it avoids dividing by sum + w_max == 0 when there
// are no usable weights.
159 outputIterator.Set(inputIterator.Get());
// Second weighting case (its case label is not visible):
//   out = sqrt( (sum_i w_i * s_i^2 + w_max * y^2) / (sum_i w_i + w_max)
//               - 2 * m_noiseCovariance )
// This has the form of a Rician-noise bias correction — confirm the intended
// noise model.
// NOTE(review): s_i * s_i and y * y are evaluated in InputPixelType
// arithmetic. For integral pixel types (e.g. unsigned short promoted to int)
// the squares can overflow before they reach double.
164 for (
unsigned int d=0; d < databaseSamples.size(); d++)
166 average += databaseWeights[d] * (databaseSamples[d] * databaseSamples[d]);
167 sum += databaseWeights[d];
168 if (w_max < databaseWeights[d])
169 w_max = databaseWeights[d];
174 double t = ((average + (inputIterator.Get() * inputIterator.Get())
175 * w_max) / (sum + w_max)) - (2.0 * m_noiseCovariance);
// The square root is only meaningful for t >= 0. The guard on t is not visible
// here (presumably t > 0; otherwise the fallback below applies).
180 outputIterator.Set(std::sqrt(t));
// Fallback: copy the input voxel unchanged.
183 outputIterator.Set(inputIterator.Get());
// Progress reporting. From the numbering it appears to close each voxel
// iteration — confirm.
188 this->IncrementNumberOfProcessedPoints();
Applies a variance filter to an image.
InputImageType::RegionType InputImageRegionType
void SetPatchHalfSize(unsigned int arg)
TInputImage InputImageType
OutputImageType::RegionType OutputImageRegionType