ANIMA  4.0
animaPyramidalSymmetryConstrainedRegistrationBridge.hxx
Go to the documentation of this file.
1 #pragma once
3 
4 #include <itkTransformFactoryBase.h>
5 #include <itkTransformFileWriter.h>
6 #include <itkMultiResolutionPyramidImageFilter.h>
7 #include <itkImageRegistrationMethod.h>
8 #include <itkLinearInterpolateImageFunction.h>
9 
11 #include <animaNLOPTOptimizers.h>
12 #include <animaMatrixOperations.h>
14 
15 #include <itkMeanSquaresImageToImageMetric.h>
16 #include <itkMutualInformationHistogramImageToImageMetric.h>
17 #include <itkNormalizedMutualInformationHistogramImageToImageMetric.h>
18 #include <itkCenteredTransformInitializer.h>
19 #include <itkMinimumMaximumImageFilter.h>
20 
21 namespace anima
22 {
23 
24 template <typename ScalarType>
26 {
27  m_ReferenceImage = NULL;
28  m_FloatingImage = NULL;
29 
30  m_OutputTransform = TransformType::New();
31  m_OutputTransform->SetIdentity();
32 
33  m_outputTransformFile = "";
34 
35  m_InitialTransform = BaseTransformType::New();
36  m_InitialTransform->SetIdentity();
37 
38  m_OutputImage = NULL;
39 
40  m_Metric = MutualInformation;
41  m_OptimizerMaximumIterations = 100;
42 
43  m_ReferenceMinimalValue = 0.0;
44  m_FloatingMinimalValue = 0.0;
45 
46  m_UpperBoundAngle = M_PI;
47  m_TranslateUpperBound = 10;
48  m_HistogramSize = 128;
49 
50  m_NumberOfPyramidLevels = 3;
51  m_FastRegistration = false;
52  this->SetNumberOfWorkUnits(itk::MultiThreaderBase::GetGlobalDefaultNumberOfThreads());
53 }
54 
55 template <typename ScalarType>
57 {
58 }
59 
60 template <typename ScalarType>
62 {
63  typedef typename itk::ImageRegistrationMethod<InputImageType, InputImageType> RegistrationType;
64 
65  // Compute minimal value of reference and Floating images
66  using MinMaxFilterType = itk::MinimumMaximumImageFilter <InputImageType>;
67  typename MinMaxFilterType::Pointer minMaxFilter = MinMaxFilterType::New();
68  minMaxFilter->SetInput(m_ReferenceImage);
69  if (this->GetNumberOfWorkUnits() != 0)
70  minMaxFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
71  minMaxFilter->Update();
72 
73  m_ReferenceMinimalValue = minMaxFilter->GetMinimum();
74 
75  minMaxFilter = MinMaxFilterType::New();
76  minMaxFilter->SetInput(m_FloatingImage);
77  if (this->GetNumberOfWorkUnits() != 0)
78  minMaxFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
79  minMaxFilter->Update();
80 
81  m_FloatingMinimalValue = minMaxFilter->GetMinimum();
82 
83  // Only CT images are below zero, little hack to set minimal values to either -1024 or 0
84  if (m_ReferenceMinimalValue < 0.0)
85  m_ReferenceMinimalValue = -1024;
86  else
87  m_ReferenceMinimalValue = 0.0;
88 
89  if (m_FloatingMinimalValue < 0.0)
90  m_FloatingMinimalValue = -1024;
91  else
92  m_FloatingMinimalValue = 0.0;
93 
94  this->SetupPyramids();
95 
96  typename TransformType::ParametersType initialParams(TransformType::ParametersDimension);
97  for (unsigned int i = 0;i < TransformType::ParametersDimension;++i)
98  initialParams[i] = 0;
99 
100  unsigned int indexAbsRefMax = 0;
101  typename InputImageType::DirectionType dirRefMatrix = m_ReferenceImage->GetDirection();
102  double valRefMax = std::abs(dirRefMatrix(0,0));
103  for (unsigned int i = 1;i < InputImageType::ImageDimension;++i)
104  {
105  if (std::abs(dirRefMatrix(0,i)) > valRefMax)
106  {
107  valRefMax = std::abs(dirRefMatrix(0,i));
108  indexAbsRefMax = i;
109  }
110  }
111 
112  OffsetType directionRefReal;
113  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
114  directionRefReal[i] = dirRefMatrix(i,indexAbsRefMax);
115 
116  m_OutputTransform->SetReferencePlaneNormal(directionRefReal);
117  m_OutputTransform->SetParameters(initialParams);
118 
119  InputImageType::PointType centralPoint;
120  itk::ContinuousIndex <ScalarType,InputImageType::ImageDimension> centralVoxIndex;
121 
122  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
123  centralVoxIndex[i] = m_ReferenceImage->GetLargestPossibleRegion().GetSize()[i] / 2.0;
124 
125  m_ReferenceImage->TransformContinuousIndexToPhysicalPoint(centralVoxIndex,centralPoint);
126 
127  unsigned int dimension = m_OutputTransform->GetNumberOfParameters();
128  itk::Array<double> lowerBounds(dimension);
129  itk::Array<double> upperBounds(dimension);
130 
131  lowerBounds[0] = - m_UpperBoundAngle;
132  upperBounds[0] = m_UpperBoundAngle;
133 
134  // Iterate over pyramid levels
135  for (unsigned int i = 0;i < this->GetNumberOfPyramidLevels();++i)
136  {
137  std::cout << "Processing pyramid level " << i << std::endl;
138  std::cout << "Image size: " << m_ReferencePyramid->GetOutput(i)->GetLargestPossibleRegion().GetSize() << std::endl;
139 
140  // Init matcher
141  typename RegistrationType::Pointer reg = RegistrationType::New();
142 
143  reg->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
144 
145  typedef anima::NLOPTOptimizers OptimizerType;
146 
147  typename OptimizerType::Pointer optimizer = OptimizerType::New();
148 
149  optimizer->SetAlgorithm(NLOPT_LN_BOBYQA);
150  optimizer->SetXTolRel(1.0e-4);
151  optimizer->SetFTolRel(1.0e-6);
152  optimizer->SetMaxEval(m_OptimizerMaximumIterations);
153  optimizer->SetVectorStorageSize(2000);
154  optimizer->SetMaximize(m_Metric != MeanSquares);
155 
156  double meanSpacing = 0;
157  for (unsigned int j = 0;j < InputImageType::ImageDimension;++j)
158  meanSpacing += m_ReferencePyramid->GetOutput(i)->GetSpacing()[j];
159 
160  for (unsigned int j = 0;j < 2;++j)
161  {
162  lowerBounds[j + 1] = m_OutputTransform->GetParameters()[j + 1] - m_TranslateUpperBound * meanSpacing;
163  upperBounds[j + 1] = m_OutputTransform->GetParameters()[j + 1] + m_TranslateUpperBound * meanSpacing;
164  }
165 
166  optimizer->SetLowerBoundParameters(lowerBounds);
167  optimizer->SetUpperBoundParameters(upperBounds);
168 
169  reg->SetOptimizer(optimizer);
170  reg->SetTransform(m_OutputTransform);
171 
172  typedef itk::LinearInterpolateImageFunction <InputImageType,double> InterpolatorType;
173  InterpolatorType::Pointer interpolator = InterpolatorType::New();
174 
175  reg->SetInterpolator(interpolator);
176 
177  switch (m_Metric)
178  {
179  case MutualInformation:
180  {
181  typedef itk::MutualInformationHistogramImageToImageMetric < InputImageType,InputImageType > MetricType;
182  typename MetricType::Pointer tmpMetric = MetricType::New();
183 
184  MetricType::HistogramType::SizeType histogramSize;
185  histogramSize.SetSize(2);
186 
187  histogramSize[0] = m_HistogramSize;
188  histogramSize[1] = m_HistogramSize;
189  tmpMetric->SetHistogramSize( histogramSize );
190 
191  reg->SetMetric(tmpMetric);
192  break;
193  }
194 
196  {
197  typedef itk::NormalizedMutualInformationHistogramImageToImageMetric < InputImageType,InputImageType > MetricType;
198  typename MetricType::Pointer tmpMetric = MetricType::New();
199 
200  MetricType::HistogramType::SizeType histogramSize;
201  histogramSize.SetSize(2);
202 
203  histogramSize[0] = m_HistogramSize;
204  histogramSize[1] = m_HistogramSize;
205  tmpMetric->SetHistogramSize( histogramSize );
206 
207  reg->SetMetric(tmpMetric);
208  break;
209  }
210 
211  case MeanSquares:
212  default:
213  {
214  typedef itk::MeanSquaresImageToImageMetric < InputImageType,InputImageType > MetricType;
215  typename MetricType::Pointer tmpMetric = MetricType::New();
216  reg->SetMetric(tmpMetric);
217  break;
218  }
219  }
220 
221  reg->SetFixedImage(m_ReferencePyramid->GetOutput(i));
222  reg->SetMovingImage(m_FloatingPyramid->GetOutput(i));
223 
224  if (m_FastRegistration)
225  {
226  // We can work in 2D since the rest should be perfectly ok
227  InputImageRegionType workRegion = m_ReferencePyramid->GetOutput(i)->GetLargestPossibleRegion();
228  InputImageRegionType::IndexType centralIndex;
229  m_ReferencePyramid->GetOutput(i)->TransformPhysicalPointToIndex(centralPoint,centralIndex);
230 
231  unsigned int baseIndex = centralIndex[indexAbsRefMax];
232  workRegion.SetIndex(indexAbsRefMax,baseIndex);
233  workRegion.SetSize(indexAbsRefMax,1);
234 
235  reg->SetFixedImageRegion(workRegion);
236  }
237  else
238  reg->SetFixedImageRegion(m_ReferencePyramid->GetOutput(i)->GetLargestPossibleRegion());
239 
240  reg->SetInitialTransformParameters(m_OutputTransform->GetParameters());
241 
242  try
243  {
244  reg->Update();
245  }
246  catch( itk::ExceptionObject & err )
247  {
248  std::cout << "ExceptionObject caught ! " << err << std::endl;
249  throw err;
250  }
251 
252  m_OutputTransform->SetParameters(reg->GetLastTransformParameters());
253  }
254 
255  // Now compute the final transform
256  typedef itk::Matrix <ScalarType,InputImageType::ImageDimension+1,InputImageType::ImageDimension+1> TransformMatrixType;
257  TransformMatrixType refSymPlaneMatrix, initialMatrix, outputMatrix;
258  refSymPlaneMatrix.SetIdentity();
259  initialMatrix.SetIdentity();
260  outputMatrix.SetIdentity();
261 
262  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
263  {
264  for (unsigned int j = 0;j < InputImageType::ImageDimension;++j)
265  {
266  refSymPlaneMatrix(i,j) = m_RefSymmetryTransform->GetMatrix()(i,j);
267  outputMatrix(i,j) = m_OutputTransform->GetMatrix()(i,j);
268  initialMatrix(i,j) = m_InitialTransform->GetMatrix()(i,j);
269  }
270 
271  refSymPlaneMatrix(i,3) = m_RefSymmetryTransform->GetOffset()[i];
272  outputMatrix(i,3) = m_OutputTransform->GetOffset()[i];
273  initialMatrix(i,3) = m_InitialTransform->GetOffset()[i];
274  }
275 
276  refSymPlaneMatrix = refSymPlaneMatrix.GetInverse();
277  MatrixType tmpOutMatrix;
278  OffsetType tmpOffset;
279 
280  outputMatrix = initialMatrix * outputMatrix * refSymPlaneMatrix;
281 
282  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
283  {
284  for (unsigned int j = 0;j < InputImageType::ImageDimension;++j)
285  tmpOutMatrix(i,j) = outputMatrix(i,j);
286 
287  tmpOffset[i] = outputMatrix(i,3);
288  }
289 
290  if (m_OutputRealignTransform.IsNull())
291  m_OutputRealignTransform = BaseTransformType::New();
292 
293  m_OutputRealignTransform->SetMatrix(tmpOutMatrix);
294  m_OutputRealignTransform->SetOffset(tmpOffset);
295 
296  typedef typename anima::ResampleImageFilter<InputImageType, InputImageType> ResampleFilterType;
297  typename ResampleFilterType::Pointer tmpResample = ResampleFilterType::New();
298  tmpResample->SetTransform(m_OutputRealignTransform);
299  tmpResample->SetInput(m_FloatingImage);
300 
301  tmpResample->SetSize(m_ReferenceImage->GetLargestPossibleRegion().GetSize());
302  tmpResample->SetOutputOrigin(m_ReferenceImage->GetOrigin());
303  tmpResample->SetOutputSpacing(m_ReferenceImage->GetSpacing());
304  tmpResample->SetOutputDirection(m_ReferenceImage->GetDirection());
305  tmpResample->SetDefaultPixelValue(m_FloatingMinimalValue);
306  tmpResample->Update();
307 
308  m_OutputImage = tmpResample->GetOutput();
309 }
310 
311 template <typename ScalarType>
313 {
314  std::cout << "Writing output image to: " << m_resultFile << std::endl;
315 
316  anima::writeImage <InputImageType> (m_resultFile,m_OutputImage);
317 
318  if (m_outputTransformFile != "")
319  {
320  std::cout << "Writing output transform to: " << m_outputTransformFile << std::endl;
321  itk::TransformFileWriter::Pointer writer = itk::TransformFileWriter::New();
322  writer->SetInput(m_OutputRealignTransform);
323  writer->SetFileName(m_outputTransformFile);
324  writer->Update();
325  }
326 }
327 
328 template <typename ScalarType>
330 {
331  m_InitialTransform->SetIdentity();
332 
333  typedef typename itk::CenteredTransformInitializer<BaseTransformType, InputImageType, InputImageType> TransformInitializerType;
334  typedef typename anima::ResampleImageFilter<InputImageType, InputImageType> ResampleFilterType;
335 
336  InputImagePointer initialReferenceImage;
337  InputImagePointer initialFloatingImage;
338 
339  typename ResampleFilterType::Pointer tmpResample = ResampleFilterType::New();
340  tmpResample->SetTransform(m_RefSymmetryTransform);
341  tmpResample->SetInput(m_ReferenceImage);
342 
343  tmpResample->SetSize(m_ReferenceImage->GetLargestPossibleRegion().GetSize());
344  tmpResample->SetOutputOrigin(m_ReferenceImage->GetOrigin());
345  tmpResample->SetOutputSpacing(m_ReferenceImage->GetSpacing());
346  tmpResample->SetOutputDirection(m_ReferenceImage->GetDirection());
347  tmpResample->SetDefaultPixelValue(m_ReferenceMinimalValue);
348  tmpResample->Update();
349 
350  initialReferenceImage = tmpResample->GetOutput();
351  initialReferenceImage->DisconnectPipeline();
352 
353  // Tricky part: align the central planes of the two input images
354  OffsetType directionRefReal, directionFloReal;
355 
356  //First use real direction to find the real X direction
357  unsigned int indexAbsRefMax = 0;
358  unsigned int indexAbsFloMax = 0;
359  typename InputImageType::DirectionType dirRefMatrix = m_ReferenceImage->GetDirection();
360  typename InputImageType::DirectionType dirFloMatrix = m_FloatingImage->GetDirection();
361  double valRefMax = std::abs(dirRefMatrix(0,0));
362  double valFloMax = std::abs(dirFloMatrix(0,0));
363  for (unsigned int i = 1;i < InputImageType::ImageDimension;++i)
364  {
365  if (std::abs(dirRefMatrix(0,i)) > valRefMax)
366  {
367  valRefMax = std::abs(dirRefMatrix(0,i));
368  indexAbsRefMax = i;
369  }
370 
371  if (std::abs(dirFloMatrix(0,i)) > valFloMax)
372  {
373  valFloMax = std::abs(dirFloMatrix(0,i));
374  indexAbsFloMax = i;
375  }
376  }
377 
378  // Now redo it with the real X-direction
379  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
380  {
381  directionRefReal[i] = dirRefMatrix(i,indexAbsRefMax);
382  directionFloReal[i] = dirFloMatrix(i,indexAbsFloMax);
383  }
384 
385  typedef itk::Matrix <ScalarType,InputImageType::ImageDimension+1,InputImageType::ImageDimension+1> TransformMatrixType;
386  MatrixType tmpMatrix = anima::GetRotationMatrixFromVectors(directionRefReal,directionFloReal);
387  TransformMatrixType floRefMatrix;
388  floRefMatrix.SetIdentity();
389 
390  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
391  {
392  for (unsigned int j = 0;j < InputImageType::ImageDimension;++j)
393  floRefMatrix(i,j) = tmpMatrix(i,j);
394  }
395 
396  itk::ContinuousIndex <ScalarType,InputImageType::ImageDimension> refImageCenter, floImageCenter;
397 
398  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
399  refImageCenter[i] = m_ReferenceImage->GetLargestPossibleRegion().GetSize()[i] / 2.0;
400  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
401  floImageCenter[i] = m_FloatingImage->GetLargestPossibleRegion().GetSize()[i] / 2.0;
402 
403  typename InputImageType::PointType refCenter, floCenter;
404  m_ReferenceImage->TransformContinuousIndexToPhysicalPoint(refImageCenter,refCenter);
405  m_FloatingImage->TransformContinuousIndexToPhysicalPoint(floImageCenter,floCenter);
406 
407  TransformMatrixType refTranslationMatrix, floTranslationMatrix;
408  refTranslationMatrix.SetIdentity();
409  floTranslationMatrix.SetIdentity();
410 
411  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
412  {
413  refTranslationMatrix(i,3) = - refCenter[i];
414  floTranslationMatrix(i,3) = floCenter[i];
415  }
416 
417  floRefMatrix = floTranslationMatrix * floRefMatrix * refTranslationMatrix;
418 
419  TransformMatrixType FloatingSymmetry;
420  FloatingSymmetry.SetIdentity();
421 
422  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
423  {
424  for (unsigned int j = 0;j < InputImageType::ImageDimension;++j)
425  FloatingSymmetry(i,j) = m_FloSymmetryTransform->GetMatrix()(i,j);
426 
427  FloatingSymmetry(i,3) = m_FloSymmetryTransform->GetOffset()[i];
428  }
429 
430  floRefMatrix = FloatingSymmetry * floRefMatrix;
431  m_OutputTransform->SetCenter(refCenter);
432 
433  MatrixType FloatingMatrix;
434  OffsetType FloatingOffset;
435 
436  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
437  {
438  for (unsigned int j = 0;j < InputImageType::ImageDimension;++j)
439  FloatingMatrix(i,j) = floRefMatrix(i,j);
440 
441  FloatingOffset[i] = floRefMatrix(i,3);
442  }
443 
444  m_InitialTransform->SetMatrix(FloatingMatrix);
445  m_InitialTransform->SetOffset(FloatingOffset);
446 
447  // Now, create pyramid
448  m_ReferencePyramid = PyramidType::New();
449 
450  m_ReferencePyramid->SetInput(initialReferenceImage);
451  m_ReferencePyramid->SetNumberOfLevels(m_NumberOfPyramidLevels);
452  typename ResampleFilterType::Pointer refResampler = ResampleFilterType::New();
453  refResampler->SetDefaultPixelValue(m_ReferenceMinimalValue);
454  m_ReferencePyramid->SetImageResampler(refResampler);
455  m_ReferencePyramid->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
456  m_ReferencePyramid->Update();
457 
458  tmpResample = ResampleFilterType::New();
459  tmpResample->SetTransform(m_InitialTransform);
460  tmpResample->SetInput(m_FloatingImage);
461 
462  tmpResample->SetSize(m_ReferenceImage->GetLargestPossibleRegion().GetSize());
463  tmpResample->SetOutputOrigin(m_ReferenceImage->GetOrigin());
464  tmpResample->SetOutputSpacing(m_ReferenceImage->GetSpacing());
465  tmpResample->SetOutputDirection(m_ReferenceImage->GetDirection());
466  tmpResample->SetDefaultPixelValue(m_FloatingMinimalValue);
467  tmpResample->Update();
468 
469  initialFloatingImage = tmpResample->GetOutput();
470  initialFloatingImage->DisconnectPipeline();
471 
472  // Create pyramid for Floating image
473  m_FloatingPyramid = PyramidType::New();
474 
475  m_FloatingPyramid->SetInput(initialFloatingImage);
476  m_FloatingPyramid->SetNumberOfLevels(m_NumberOfPyramidLevels);
477  m_FloatingPyramid->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
478 
479  typename ResampleFilterType::Pointer floResampler = ResampleFilterType::New();
480  floResampler->SetDefaultPixelValue(m_FloatingMinimalValue);
481  m_FloatingPyramid->SetImageResampler(floResampler);
482  m_FloatingPyramid->Update();
483 }
484 
485 } // end of namespace
Superclass::ParametersType ParametersType
Implements an ITK wrapper for the NLOPT library.
itk::Matrix< double, 3, 3 > GetRotationMatrixFromVectors(const VectorType &first_direction, const VectorType &second_direction, const unsigned int dimension)