ANIMA  4.0
animaDistortionCorrectionBlockMatcher.hxx
Go to the documentation of this file.
1 #pragma once
3 
4 /* Similarity measures */
7 
8 /* Transforms */
10 
11 #include <itkLinearInterpolateImageFunction.h>
12 
13 #include <itkImageRegionConstIterator.h>
14 
15 namespace anima
16 {
17 
// Default constructor.
// NOTE(review): the class-qualified signature lines (19-20) are missing from
// this copy of the file; only the template header and the body are visible.
//
// Sets the default parameters of the matcher. The limit and radius values set
// here are read by TransformDependantOptimizerSetup() to build the optimizer
// bounds and scales.
18 template <typename TInputImageType>
21 {
// Default similarity metric and per-block transform model
22  m_SimilarityType = SquaredCorrelation;
23  m_BlockTransformType = DirectionScaleSkew;
24 
// Magnitudes of the translation, skew (pi/4) and scale parameter bounds
// (translation and skew bounds may further be multiplied by a scale factor
// in the optimizer setup).
// NOTE(review): the units of m_TranslateMax are not established here — confirm.
// NOTE(review): M_PI is not standard C++ (POSIX extension); some compilers
// (e.g. MSVC) only define it when _USE_MATH_DEFINES is set.
25  m_TranslateMax = 10;
26  m_SkewMax = M_PI / 4.0;
27  m_ScaleMax = 3;
28 
// Search radii: the optimizer setup divides the block-matching search radius
// by these values to obtain the skew and scale parameter scales.
29  m_SearchSkewRadius = 5;
30  m_SearchScaleRadius = 0.1;
31 
// Index of the image axis (a column of the reference image direction matrix)
// used as the correction direction.
32  m_TransformDirection = 1;
33 }
34 
// Returns false for the MeanSquares similarity and true for every other
// similarity type (Correlation, SquaredCorrelation).
// NOTE(review): the method-name lines (37-38) are missing from this copy;
// presumably this tells the base block matcher whether the metric has to be
// maximized (correlation) rather than minimized (mean squares) — confirm.
35 template <typename TInputImageType>
36 bool
39 {
40  if (m_SimilarityType == MeanSquares)
41  return false;
42 
43  return true;
44 }
45 
// Returns AgregatorType::AFFINE unconditionally, regardless of the block
// transform type.
// NOTE(review): the signature lines (47-49) are missing from this copy;
// presumably this selects the agregator used to combine block matches — confirm.
46 template <typename TInputImageType>
50 {
51  return AgregatorType::AFFINE;
52 }
53 
// Builds the similarity metric used to match blocks, according to
// m_SimilarityType, and binds it to the reference (fixed) and moving images
// through a linear interpolator. Returns the configured metric.
// NOTE(review): the signature lines (55-57) and the LocalMetricType typedef
// lines (66, 79) are missing from this copy; given the casts performed in
// BlockMatchingSetup(), they presumably name
// anima::FastCorrelationImageToImageMetric and
// anima::FastMeanSquaresImageToImageMetric respectively — confirm.
54 template <typename TInputImageType>
58 {
59  MetricPointer metric;
60 
61  switch(m_SimilarityType)
62  {
63  case Correlation:
64  case SquaredCorrelation:
65  {
67 
// Same correlation metric class for both types; only the squared flag differs
68  typename LocalMetricType::Pointer tmpMetric = LocalMetricType::New();
69  tmpMetric->SetSquaredCorrelation(m_SimilarityType == SquaredCorrelation);
70  tmpMetric->SetScaleIntensities(true);
71 
72  metric = tmpMetric;
73  break;
74  }
75 
// Any value other than the correlation types gets the mean squares metric
76  case MeanSquares:
77  default:
78  {
80 
81  typename LocalMetricType::Pointer tmpMetric = LocalMetricType::New();
82  tmpMetric->SetScaleIntensities(true);
83 
84  metric = tmpMetric;
85  break;
86  }
87  }
88 
// Common setup through the itk::ImageToImageMetric interface.
// NOTE(review): the dynamic_cast result is dereferenced without a null check;
// this assumes both concrete metrics derive from itk::ImageToImageMetric.
89  typedef itk::ImageToImageMetric <InputImageType,InputImageType> BaseMetricType;
90  BaseMetricType *baseMetric = dynamic_cast <BaseMetricType *> (metric.GetPointer());
91 
92  typedef itk::LinearInterpolateImageFunction<InputImageType,double> LocalInterpolatorType;
93  typename LocalInterpolatorType::Pointer interpolator = LocalInterpolatorType::New();
94 
// Metric gradients are not requested (ComputeGradientOff)
95  baseMetric->SetInterpolator(interpolator);
96  baseMetric->ComputeGradientOff();
97 
98  baseMetric->SetFixedImage(this->GetReferenceImage());
99  baseMetric->SetMovingImage(this->GetMovingImage());
100  interpolator->SetInputImage(this->GetMovingImage());
101 
102  return metric;
103 }
104 
// Creates the transform attached to a block whose center is blockCenter.
// The transform geometry is a homogeneous matrix whose upper-left 3x3 part is
// the reference image direction matrix with column j scaled by spacing[j], and
// whose translation column is blockCenter. The transform is restricted to the
// single direction m_TransformDirection and starts as identity.
// NOTE(review): the signature lines (106-108), the BaseTransformType typedef
// (112) and the transform creation lines for the Direction (119) and
// DirectionScale (125) cases are missing from this copy — confirm against the
// original source before editing.
105 template <typename TInputImageType>
109 {
110  BaseInputTransformPointer outputValue;
111 
113  typename BaseTransformType::Pointer tmpTr;
114 
115  switch(m_BlockTransformType)
116  {
117  case Direction:
118  {
120  break;
121  }
122 
123  case DirectionScale:
124  {
126  break;
127  }
128 
129  case DirectionScaleSkew:
130  default:
131  {
132  tmpTr = BaseTransformType::New();
133  break;
134  }
135  }
136 
137  typename BaseTransformType::HomogeneousMatrixType geometry;
138 
// Linear part: direction * diag(spacing), i.e. index axes in physical space.
// NOTE(review): the loops hard-code 3 while the translation column below uses
// InputImageType::ImageDimension; this only matches for 3D images — confirm
// that 3D input is an invariant of this matcher.
139  geometry.SetIdentity();
140  for (unsigned int i = 0;i < 3;++i)
141  for (unsigned int j = 0;j < 3;++j)
142  geometry(i,j) = this->GetReferenceImage()->GetDirection()(i,j) * this->GetReferenceImage()->GetSpacing()[j];
143 
// Translation part: the block center
144  tmpTr->SetIdentity();
145  for (unsigned int j = 0;j < 3;++j)
146  geometry(j,InputImageType::ImageDimension) = blockCenter[j];
147 
148  tmpTr->SetUniqueDirection(m_TransformDirection);
149  tmpTr->SetGeometry(geometry, false);
150 
151  outputValue = tmpTr;
152 
153  return outputValue;
154 }
155 
156 template <typename TInputImageType>
157 double
159 ::ComputeBlockWeight(double val, unsigned int block)
160 {
161  double similarityWeight = 0;
162 
163  switch (m_SimilarityType)
164  {
165  case MeanSquares:
166  similarityWeight = 1;
167 
168  case Correlation:
169  similarityWeight = (val + 1) / 2.0;
170 
171  case SquaredCorrelation:
172  default:
173  similarityWeight = val;
174  }
175 
176  // Structure weight
177  std::vector <double> localGradient(InputImageType::ImageDimension,0);
178  itk::ImageRegionConstIterator <InputImageType> fixedItr(this->GetReferenceImage(),this->GetBlockRegion(block));
179  typedef typename InputImageType::RegionType ImageRegionType;
180  typename ImageRegionType::IndexType currentIndex, modifiedIndex;
181 
182  typename InputImageType::DirectionType orientationMatrix = this->GetReferenceImage()->GetDirection();
183  typename InputImageType::SpacingType imageSpacing = this->GetReferenceImage()->GetSpacing();
184  typename InputImageType::SizeType imageSize = this->GetReferenceImage()->GetLargestPossibleRegion().GetSize();
185 
186  std::vector <double> correctionDirection(InputImageType::ImageDimension);
187  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
188  correctionDirection[i] = this->GetReferenceImage()->GetDirection()(i,m_TransformDirection);
189 
190  vnl_matrix <double> meanStructureTensor(InputImageType::ImageDimension,InputImageType::ImageDimension);
191  meanStructureTensor.fill(0);
192 
193  while (!fixedItr.IsAtEnd())
194  {
195  currentIndex = fixedItr.GetIndex();
196  std::fill(localGradient.begin(),localGradient.end(),0);
197 
198  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
199  {
200  modifiedIndex = currentIndex;
201  modifiedIndex[i] = std::max(0,(int)(currentIndex[i] - 1));
202  double previousValue = this->GetReferenceImage()->GetPixel(modifiedIndex);
203  modifiedIndex[i] = std::min((int)(imageSize[i] - 1),(int)(currentIndex[i] + 1));
204  double postValue = this->GetReferenceImage()->GetPixel(modifiedIndex);
205 
206  for (unsigned int j = 0;j < InputImageType::ImageDimension;++j)
207  localGradient[j] += (postValue - previousValue) * orientationMatrix(j,i) / (2.0 * imageSpacing[i]);
208  }
209 
210  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
211  for (unsigned int j = i;j < InputImageType::ImageDimension;++j)
212  {
213  meanStructureTensor(i,j) += localGradient[i] * localGradient[j];
214  if (j != i)
215  meanStructureTensor(j,i) = meanStructureTensor(i,j);
216  }
217 
218  ++fixedItr;
219  }
220 
221  meanStructureTensor /= this->GetBlockRegion(block).GetNumberOfPixels();
222 
223  itk::SymmetricEigenAnalysis < vnl_matrix <double>, vnl_diag_matrix<double>, vnl_matrix <double> > eigenComputer(InputImageType::ImageDimension);
224  vnl_matrix <double> eVec(InputImageType::ImageDimension,InputImageType::ImageDimension);
225  vnl_diag_matrix <double> eVals(InputImageType::ImageDimension);
226 
227  eigenComputer.ComputeEigenValuesAndVectors(meanStructureTensor, eVals, eVec);
228  double linearCoef = (eVals[InputImageType::ImageDimension - 1] - eVals[InputImageType::ImageDimension - 2]) / eVals[InputImageType::ImageDimension - 1];
229 
230  double scalarProduct = 0;
231  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
232  scalarProduct += eVec[InputImageType::ImageDimension - 1][i] * correctionDirection[i];
233 
234  double structureWeight = linearCoef * std::abs(scalarProduct);
235 
236  return std::sqrt(structureWeight * similarityWeight);
237 }
238 
// Prepares the metric and the block transform before optimizing block `block`:
// resets the block transform to identity, restricts the metric's fixed-image
// region to the block region, attaches the block transform, initializes the
// metric, then lets the concrete metric pre-compute its fixed-image values.
// NOTE(review): the class-qualifier line (241) and line 244 (presumably the
// BaseTransformType typedef) are missing from this copy — confirm.
239 template <typename TInputImageType>
240 void
242 ::BlockMatchingSetup(MetricPointer &metric, unsigned int block)
243 {
// NOTE(review): dynamic_cast result used without a null check
245  BaseTransformType *tr = dynamic_cast <BaseTransformType *> (this->GetBlockTransformPointer(block).GetPointer());
246  tr->SetIdentity();
247 
248  // Metric specific init
249  typedef itk::ImageToImageMetric <InputImageType, InputImageType> InternalMetricType;
250  InternalMetricType *tmpMetric = dynamic_cast <InternalMetricType *> (metric.GetPointer());
251  tmpMetric->SetFixedImageRegion(this->GetBlockRegion(block));
252  tmpMetric->SetTransform(this->GetBlockTransformPointer(block));
253  tmpMetric->Initialize();
254 
// NOTE(review): these C-style downcasts assume the concrete metric class
// matches m_SimilarityType (correlation metric for every non-MeanSquares type,
// mean squares metric otherwise); a mismatch would be undefined behaviour.
// A checked dynamic_cast would make the assumption explicit.
255  if (m_SimilarityType != MeanSquares)
256  ((anima::FastCorrelationImageToImageMetric<InputImageType, InputImageType> *)metric.GetPointer())->PreComputeFixedValues();
257  else
258  ((anima::FastMeanSquaresImageToImageMetric<InputImageType, InputImageType> *)metric.GetPointer())->PreComputeFixedValues();
259 }
260 
// Configures the Bobyqa optimizer scales and bounds according to
// m_BlockTransformType. Throws itk::ExceptionObject when the exhaustive
// optimizer is selected, since it is not supported here.
// Parameter layout implied by the bounds below:
//  - DirectionScaleSkew: [0] bounded by +/-m_ScaleMax, [1] and [2] by
//    +/-m_SkewMax * scaleFactor, [3] by +/-m_TranslateMax * scaleFactor
//  - DirectionScale: [0] bounded by +/-m_ScaleMax, [1] by +/-m_TranslateMax * scaleFactor
//  - Direction (and default): [0] bounded by +/-m_TranslateMax
// NOTE(review): the signature lines (263-264) are missing from this copy.
261 template <typename TInputImageType>
262 void
265 {
266  if (this->GetOptimizerType() == Superclass::Exhaustive)
267  throw itk::ExceptionObject(__FILE__, __LINE__,"Exhaustive optimizer not supported in distortion correction",ITK_LOCATION);
268 
// Array sizes follow the parameter count of block 0's transform (all blocks
// presumably share the same transform type).
// NOTE(review): fixedSpacing is never used below.
269  typedef anima::BobyqaOptimizer LocalOptimizerType;
270  LocalOptimizerType::ScalesType tmpScales(this->GetBlockTransformPointer(0)->GetNumberOfParameters());
271  LocalOptimizerType::ScalesType lowerBounds(this->GetBlockTransformPointer(0)->GetNumberOfParameters());
272  LocalOptimizerType::ScalesType upperBounds(this->GetBlockTransformPointer(0)->GetNumberOfParameters());
273  typename InputImageType::SpacingType fixedSpacing = this->GetReferenceImage()->GetSpacing();
274 
275  // Scale factor to ensure that max translations and skew can be reached
276  // Based on the fact that non diagonal terms log is a = x * log(y) / (exp(y) - 1)
277  // where y is the diagonal scaling factor, x the desired term
// Equivalently scaleFactor = m_ScaleMax / (1 - exp(-m_ScaleMax)), which is >= 1;
// it stays 1 for the Direction model or a non-positive m_ScaleMax.
278  double scaleFactor = 1.0;
279  if ((m_BlockTransformType != Direction)&&(m_ScaleMax > 0))
280  scaleFactor = - m_ScaleMax / (std::exp(- m_ScaleMax) - 1.0);
281 
282  switch (m_BlockTransformType)
283  {
284  case DirectionScaleSkew:
285  {
286  tmpScales[0] = this->GetSearchRadius() / m_SearchScaleRadius;
287  tmpScales[1] = this->GetSearchRadius() / m_SearchSkewRadius;
288  tmpScales[2] = this->GetSearchRadius() / m_SearchSkewRadius;
289  tmpScales[3] = 1.0;
290 
291  lowerBounds[0] = - m_ScaleMax;
292  upperBounds[0] = m_ScaleMax;
293  lowerBounds[1] = - m_SkewMax * scaleFactor;
294  upperBounds[1] = m_SkewMax * scaleFactor;
295  lowerBounds[2] = - m_SkewMax * scaleFactor;
296  upperBounds[2] = m_SkewMax * scaleFactor;
297  lowerBounds[3] = - m_TranslateMax * scaleFactor;
298  upperBounds[3] = m_TranslateMax * scaleFactor;
299 
300  break;
301  }
302 
303  case DirectionScale:
304  {
305  tmpScales[0] = this->GetSearchRadius() / m_SearchScaleRadius;
306  tmpScales[1] = 1.0;
307 
308  lowerBounds[0] = - m_ScaleMax;
309  upperBounds[0] = m_ScaleMax;
310  lowerBounds[1] = - m_TranslateMax * scaleFactor;
311  upperBounds[1] = m_TranslateMax * scaleFactor;
312 
313  break;
314  }
315 
316 
317  case Direction:
318  default:
319  {
320  tmpScales[0] = 1.0;
321  lowerBounds[0] = - m_TranslateMax;
322  upperBounds[0] = m_TranslateMax;
323 
324  break;
325  }
326  }
327 
// NOTE(review): assumes `optimizer` is an anima::BobyqaOptimizer; the
// dynamic_cast result is dereferenced without a null check.
328  LocalOptimizerType * tmpOpt = dynamic_cast <LocalOptimizerType *> (optimizer.GetPointer());
329  tmpOpt->SetScales(tmpScales);
330  tmpOpt->SetLowerBounds(lowerBounds);
331  tmpOpt->SetUpperBounds(upperBounds);
332 }
333 
334 } // end namespace anima
MetricType::Pointer MetricPointer
Superclass::BaseInputTransformPointer BaseInputTransformPointer
virtual void BlockMatchingSetup(MetricPointer &metric, unsigned int block)
BaseInputTransformType::Pointer BaseInputTransformPointer
virtual void TransformDependantOptimizerSetup(OptimizerPointer &optimizer)
InputImageType::PointType PointType
OptimizerType::Pointer OptimizerPointer
InputImageType::RegionType ImageRegionType
virtual BaseInputTransformPointer GetNewBlockTransform(PointType &blockCenter)
virtual double ComputeBlockWeight(double val, unsigned int block)