ANIMA  4.0
animaPyramidalDenseMCMSVFMatchingBridge.hxx
Go to the documentation of this file.
1 #pragma once
3 
5 #include <animaMCMFileWriter.h>
6 
10 
11 #include <itkResampleImageFilter.h>
12 
13 #include <animaVelocityUtils.h>
15 #include <animaMCMConstants.h>
16 
17 namespace anima
18 {
19 
// Presumably the class constructor: its signature line (original line 21) is
// missing from this listing. Sets the default value of every registration
// parameter and the default number of work units.
20 template <unsigned int ImageDimension>
22 {
23  m_ReferenceImage = NULL;
24  m_FloatingImage = NULL;
25 
// Start from an identity transform; Update() passes it as the initial
// transform of the first pyramid level and overwrites it level after level.
26  m_OutputTransform = BaseTransformType::New();
27  m_OutputTransform->SetIdentity();
28 
// Empty file name: the output-writing method only writes the SVF when this
// string is non-empty.
29  m_outputTransformFile = "";
30 
31  m_OutputImage = NULL;
32 
// Block-matching parameters. m_StDevThreshold is squared into a variance
// threshold when the block matchers are configured.
33  m_BlockSize = 5;
34  m_BlockSpacing = 2;
35  m_StDevThreshold = 5;
36 
// Method choices: registration symmetry, metric reorientation, block
// transform, similarity metric and block optimizer.
37  m_SymmetryType = Asymmetric;
38  m_MetricOrientation = FiniteStrain;
39  m_FiniteStrainImageReorientation = true;
40  m_Transform = Translation;
41  m_Metric = MCMOneToOneBasicMeanSquares;
42  m_Optimizer = Bobyqa;
43 
// Diffusion timing defaults, taken from animaMCMConstants.h.
44  m_SmallDelta = anima::DiffusionSmallDelta;
45  m_BigDelta = anima::DiffusionBigDelta;
46 
// Outer registration loop and block optimizer settings.
47  m_MaximumIterations = 10;
48  m_MinimalTransformError = 0.01;
49  m_OptimizerMaximumIterations = 100;
50  m_SearchRadius = 2;
51  m_SearchAngleRadius = 5;
52  m_SearchScaleRadius = 0.1;
53  m_FinalRadius = 0.001;
54  m_StepSize = 1;
55  m_TranslateUpperBound = 50;
56  m_AngleUpperBound = 180;
57  m_ScaleUpperBound = 3;
// Agregator / regularization settings. m_ExtrapolationSigma and m_ElasticSigma
// are multiplied by each level's mean voxel spacing in Update(), so they are
// expressed in voxel units.
58  m_Agregator = Baloo;
59  m_ExtrapolationSigma = 3;
60  m_ElasticSigma = 3;
61  m_OutlierSigma = 3;
62  m_MEstimateConvergenceThreshold = 0.01;
// Scales the M-estimation agregator neighborhood relative to m_ExtrapolationSigma.
63  m_NeighborhoodApproximation = 2.5;
64  m_BCHCompositionOrder = 1;
65  m_ExponentiationOrder = 1;
66  m_NumberOfPyramidLevels = 3;
// Update() skips this many pyramid levels (those with the highest indices).
67  m_LastPyramidLevel = 0;
68  m_PercentageKept = 0.8;
// Default to ITK's global default thread count.
69  this->SetNumberOfWorkUnits(itk::MultiThreaderBase::GetGlobalDefaultNumberOfThreads());
70 }
71 
// Presumably the class destructor: its signature line (original line 73) is
// missing from this listing. The body is empty. Per-level block matchers and
// agregators are deleted inside Update(), so nothing is released explicitly here.
72 template <unsigned int ImageDimension>
74 {
75 }
76 
// Creates an interpolator whose reference output model is the description
// model of `image`, and returns it as a raw pointer.
// Register() adds one reference so the object survives the destruction of the
// local smart pointer when the function returns.
// NOTE(review): every caller in this file stores the result in a
// `typename InterpolatorType::Pointer`, which takes its own reference. The extra
// reference added by Register() is never released with UnRegister() in the
// visible code, so each interpolator presumably leaks. Returning
// `typename InterpolatorType::Pointer` would avoid this but changes the virtual
// signature.
// NOTE(review): the signature lines (original 78-80) are missing from this listing.
77 template <unsigned int ImageDimension>
81 {
82  typename InterpolatorType::Pointer interpolator = InterpolatorType::New();
83  interpolator->SetReferenceOutputModel(image->GetDescriptionModel());
84 
85  interpolator->Register();
86  return interpolator;
87 }
88 
// Allocates a new block matcher with `new` and returns the raw pointer.
// Ownership passes to the caller: Update() deletes the matchers it gets from
// here at the end of each pyramid level.
// NOTE(review): the signature lines (original 90-92) are missing from this listing.
89 template <unsigned int ImageDimension>
93 {
94  BlockMatcherType *matcher = new BlockMatcherType;
95  return matcher;
96 }
97 
// Main entry point. Runs the multi-resolution block-matching registration of
// the floating MCM image onto the reference MCM image and stores the results in
// m_OutputTransform (an SVF) and m_OutputImage (the resampled floating image).
//
// Steps, as visible below:
//  1. Build the reference/floating pyramids (SetupPyramids()).
//  2. For each pyramid level, skipping the last m_LastPyramidLevel indices:
//     resample the current SVF onto the level's grid, build the agregator, the
//     block matcher(s) and the registration method selected by m_SymmetryType,
//     run it, and keep its SVF as the new m_OutputTransform.
//  3. For Kissing symmetry, multiply the final SVF by 2.
//  4. Exponentiate the SVF into a displacement field and resample the
//     full-resolution floating image onto the reference grid.
//
// NOTE(review): several source lines are missing from this listing: the
// signature (original line 100), the agregator allocations (154, 172), and the
// per-case setter calls inside the transform/optimizer/metric/orientation
// switches. The code as rendered here does not compile on its own.
98 template <unsigned int ImageDimension>
99 void
101 {
102  this->SetupPyramids();
103 
104  // Iterate over pyramid levels
105  for (unsigned int i = 0;i < m_ReferencePyramid->GetNumberOfLevels();++i)
106  {
// Skip the last m_LastPyramidLevel levels (the highest level indices).
107  if (i + m_LastPyramidLevel >= m_ReferencePyramid->GetNumberOfLevels())
108  continue;
109 
110  typename InputImageType::Pointer refImage = m_ReferencePyramid->GetOutput(i);
111  refImage->DisconnectPipeline();
112 
113  typename InputImageType::Pointer floImage = m_FloatingPyramid->GetOutput(i);
114  floImage->DisconnectPipeline();
115 
116  // Update field to match the current resolution
// Resample the SVF estimated so far onto this level's grid (identity transform,
// output geometry of refImage). Skipped while the transform holds no field.
117  if (m_OutputTransform->GetParametersAsVectorField() != NULL)
118  {
119  typedef itk::ResampleImageFilter<VelocityFieldType,VelocityFieldType> VectorResampleFilterType;
120  typedef typename VectorResampleFilterType::Pointer VectorResampleFilterPointer;
121 
122  AffineTransformPointer tmpIdentity = AffineTransformType::New();
123  tmpIdentity->SetIdentity();
124 
125  VectorResampleFilterPointer tmpResample = VectorResampleFilterType::New();
126  tmpResample->SetTransform(tmpIdentity);
127  tmpResample->SetInput(m_OutputTransform->GetParametersAsVectorField());
128 
129  tmpResample->SetSize(refImage->GetLargestPossibleRegion().GetSize());
130  tmpResample->SetOutputOrigin(refImage->GetOrigin());
131  tmpResample->SetOutputSpacing(refImage->GetSpacing());
132  tmpResample->SetOutputDirection(refImage->GetDirection());
133 
134  tmpResample->Update();
135 
136  VelocityFieldType *tmpOut = tmpResample->GetOutput();
137  m_OutputTransform->SetParametersAsVectorField(tmpOut);
138  tmpOut->DisconnectPipeline();
139  }
140 
141  std::cout << "Processing pyramid level " << i << std::endl;
142  std::cout << "Image size: " << refImage->GetLargestPossibleRegion().GetSize() << std::endl;
143 
// Mean voxel spacing of this level, used to convert voxel-unit sigmas into
// physical units.
144  double meanSpacing = 0;
145  for (unsigned int j = 0;j < ImageDimension;++j)
146  meanSpacing += refImage->GetSpacing()[j];
147  meanSpacing /= ImageDimension;
148 
149  // Init agregator mean shift parameters
150  BaseAgregatorType* agregPtr = NULL;
151 
// NOTE(review): the lines creating `agreg` (original 154 and 172) are missing
// from this listing. Because agregPtr is released with `delete` at the end of
// the loop, they are presumably `new` expressions: an M-estimation agregator
// for MSmoother and another agregator type otherwise.
152  if (m_Agregator == MSmoother)
153  {
155  agreg->SetExtrapolationSigma(m_ExtrapolationSigma * meanSpacing);
156  agreg->SetOutlierRejectionSigma(m_OutlierSigma);
157  agreg->SetOutputTransformType(BaseAgregatorType::SVF);
158 
159  if (this->GetNumberOfWorkUnits() != 0)
160  agreg->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
161 
162  agreg->SetGeometryInformation(refImage.GetPointer());
163 
// Neighborhood half-size in voxels; distance boundary in physical units.
164  agreg->SetNeighborhoodHalfSize((unsigned int)floor(m_ExtrapolationSigma * m_NeighborhoodApproximation));
165  agreg->SetDistanceBoundary(m_ExtrapolationSigma * meanSpacing * m_NeighborhoodApproximation);
166  agreg->SetMEstimateConvergenceThreshold(m_MEstimateConvergenceThreshold);
167 
168  agregPtr = agreg;
169  }
170  else
171  {
173  agreg->SetExtrapolationSigma(m_ExtrapolationSigma * meanSpacing);
174  agreg->SetOutlierRejectionSigma(m_OutlierSigma);
175  agreg->SetOutputTransformType(BaseAgregatorType::SVF);
176 
177  if (this->GetNumberOfWorkUnits() != 0)
178  agreg->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
179 
180  agreg->SetGeometryInformation(refImage.GetPointer());
181 
182  agregPtr = agreg;
183  }
184 
185  // Init matcher
// Matchers are heap-allocated by CreateBlockMatcher() and deleted at the end of
// this iteration. The standard deviation threshold is squared into a variance
// threshold.
186  BlockMatcherType *mainMatcher = this->CreateBlockMatcher();
187 
188  BlockMatcherType *reverseMatcher = 0;
189  mainMatcher->SetBlockPercentageKept(GetPercentageKept());
190  mainMatcher->SetBlockSize(GetBlockSize());
191  mainMatcher->SetBlockSpacing(GetBlockSpacing());
192  mainMatcher->SetBlockVarianceThreshold(GetStDevThreshold() * GetStDevThreshold());
193  mainMatcher->SetGradientDirections(m_GradientDirections);
194  mainMatcher->SetSmallDelta(m_SmallDelta);
195  mainMatcher->SetBigDelta(m_BigDelta);
196  mainMatcher->SetGradientStrengths(m_GradientStrengths);
197 
// Select the registration method. Only the Symmetric method gets a second
// (reverse) block matcher, configured the same way as the main one.
198  switch (m_SymmetryType)
199  {
200  case Asymmetric:
201  {
202  typedef typename anima::AsymmetricBMRegistrationMethod <InputImageType> BlockMatchRegistrationType;
203  m_bmreg = BlockMatchRegistrationType::New();
204  break;
205  }
206 
207  case Symmetric:
208  {
209  typedef typename anima::SymmetricBMRegistrationMethod <InputImageType> BlockMatchRegistrationType;
210  typename BlockMatchRegistrationType::Pointer tmpReg = BlockMatchRegistrationType::New();
211  reverseMatcher = this->CreateBlockMatcher();
212  reverseMatcher->SetBlockPercentageKept(GetPercentageKept());
213  reverseMatcher->SetBlockSize(GetBlockSize());
214  reverseMatcher->SetBlockSpacing(GetBlockSpacing());
215  reverseMatcher->SetBlockVarianceThreshold(GetStDevThreshold() * GetStDevThreshold());
216  reverseMatcher->SetGradientDirections(m_GradientDirections);
217  reverseMatcher->SetSmallDelta(m_SmallDelta);
218  reverseMatcher->SetBigDelta(m_BigDelta);
219  reverseMatcher->SetGradientStrengths(m_GradientStrengths);
220 
221  tmpReg->SetReverseBlockMatcher(reverseMatcher);
222  m_bmreg = tmpReg;
223  break;
224  }
225 
226  case Kissing:
227  {
228  typedef typename anima::KissingSymmetricBMRegistrationMethod <InputImageType> BlockMatchRegistrationType;
229  m_bmreg = BlockMatchRegistrationType::New();
230  break;
231  }
232  }
233 
234  m_bmreg->SetBlockMatcher(mainMatcher);
235  m_bmreg->SetAgregator(agregPtr);
236  m_bmreg->SetBCHCompositionOrder(m_BCHCompositionOrder);
237  m_bmreg->SetExponentiationOrder(m_ExponentiationOrder);
238 
239  if (this->GetNumberOfWorkUnits() != 0)
240  m_bmreg->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
241 
242  m_bmreg->SetFixedImage(refImage);
243  m_bmreg->SetMovingImage(floImage);
244 
245  m_bmreg->SetSVFElasticRegSigma(m_ElasticSigma * meanSpacing);
246 
// NOTE(review): original line 247 is missing from this listing. In the visible
// code mainMatcher never receives SetNumberOfWorkUnits, unlike reverseMatcher
// below; presumably that call was on this line.
248 
// Resampler that maps the reference image onto the floating image's geometry.
// NOTE(review): unlike m_bmreg above, both resamplers receive
// GetNumberOfWorkUnits() unconditionally, even when it is 0.
249  typename InterpolatorType::Pointer interpolator = this->CreateInterpolator(floImage);
250 
251  typename ResampleFilterType::Pointer refResampler = ResampleFilterType::New();
252  refResampler->SetOutputLargestPossibleRegion(floImage->GetLargestPossibleRegion());
253  refResampler->SetOutputOrigin(floImage->GetOrigin());
254  refResampler->SetOutputSpacing(floImage->GetSpacing());
255  refResampler->SetOutputDirection(floImage->GetDirection());
256  refResampler->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
257  refResampler->SetReferenceOutputModel(floImage->GetDescriptionModel());
258  refResampler->SetFiniteStrainReorientation(this->GetFiniteStrainImageReorientation());
259  refResampler->SetInterpolator(interpolator.GetPointer());
260  m_bmreg->SetReferenceImageResampler(refResampler);
261 
// Resampler that maps the moving (floating) image onto the reference image's geometry.
262  interpolator = this->CreateInterpolator(refImage);
263 
264  typename ResampleFilterType::Pointer movingResampler = ResampleFilterType::New();
265  movingResampler->SetOutputLargestPossibleRegion(refImage->GetLargestPossibleRegion());
266  movingResampler->SetOutputOrigin(refImage->GetOrigin());
267  movingResampler->SetOutputSpacing(refImage->GetSpacing());
268  movingResampler->SetOutputDirection(refImage->GetDirection());
269  movingResampler->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
270  movingResampler->SetReferenceOutputModel(refImage->GetDescriptionModel());
271  movingResampler->SetFiniteStrainReorientation(this->GetFiniteStrainImageReorientation());
272  movingResampler->SetInterpolator(interpolator.GetPointer());
273  m_bmreg->SetMovingImageResampler(movingResampler);
274 
// NOTE(review): in the four switches below, the setter call after each case
// label and after each `if (reverseMatcher)` is missing from this listing
// (e.g. original lines 278/280, 283/285, ...). As rendered, each
// `if (reverseMatcher)` guards the `break;` that follows it, which is not the
// original logic.
275  switch (GetTransform())
276  {
277  case Translation:
279  if (reverseMatcher)
281  break;
282  case Rigid:
284  if (reverseMatcher)
286  break;
287  case Affine:
288  default:
290  if (reverseMatcher)
292  break;
293  }
294 
295  switch (GetOptimizer())
296  {
297  case Exhaustive:
299  if (reverseMatcher)
301  break;
302 
303  case Bobyqa:
304  default:
306  if (reverseMatcher)
308  break;
309  }
310 
// NOTE(review): a case label (original lines 318-319) is also missing before
// the second `if (reverseMatcher)` below. The constructor's default metric,
// MCMOneToOneBasicMeanSquares, has no visible label; presumably it was there.
311  switch (m_Metric)
312  {
313  case MCMBasicMeanSquares:
315  if (reverseMatcher)
317  break;
320  if (reverseMatcher)
322  break;
323  case MCMCorrelation:
325  if (reverseMatcher)
327  break;
328  case MTCorrelation:
330  if (reverseMatcher)
332  break;
333  case MCMMeanSquares:
334  default:
336  if (reverseMatcher)
338  break;
339  }
340 
341  switch (m_MetricOrientation)
342  {
343  case None:
345  if (reverseMatcher)
347  break;
348  case PPD:
350  if (reverseMatcher)
352  break;
353  case FiniteStrain:
354  default:
356  if (reverseMatcher)
358  break;
359  }
360 
// The SVF estimated so far initializes this level's registration.
361  m_bmreg->SetMaximumIterations(m_MaximumIterations);
362  m_bmreg->SetMinimalTransformError(m_MinimalTransformError);
363  m_bmreg->SetInitialTransform(m_OutputTransform.GetPointer());
364 
365  mainMatcher->SetOptimizerMaximumIterations(m_OptimizerMaximumIterations);
366 
// Optimizer search radii and bounds; the same values are copied to the reverse
// matcher below when there is one.
367  double sr = m_SearchRadius;
368  mainMatcher->SetSearchRadius(sr);
369 
370  double sar = m_SearchAngleRadius;
371  mainMatcher->SetSearchAngleRadius(sar);
372 
373  double scr = m_SearchScaleRadius;
374  mainMatcher->SetSearchScaleRadius(scr);
375 
376  double fr = m_FinalRadius;
377  mainMatcher->SetFinalRadius(fr);
378 
379  double ss = m_StepSize;
380  mainMatcher->SetStepSize(ss);
381 
382  double tub = m_TranslateUpperBound;
383  mainMatcher->SetTranslateMax(tub);
384 
385  double aub = m_AngleUpperBound;
386  mainMatcher->SetAngleMax(aub);
387 
388  double scub = m_ScaleUpperBound;
389  mainMatcher->SetScaleMax(scub);
390 
391  if (reverseMatcher)
392  {
393  reverseMatcher->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
394  reverseMatcher->SetOptimizerMaximumIterations(GetOptimizerMaximumIterations());
395 
396  reverseMatcher->SetSearchRadius(sr);
397  reverseMatcher->SetSearchAngleRadius(sar);
398  reverseMatcher->SetSearchScaleRadius(scr);
399  reverseMatcher->SetFinalRadius(fr);
400  reverseMatcher->SetStepSize(ss);
401  reverseMatcher->SetTranslateMax(tub);
402  reverseMatcher->SetAngleMax(aub);
403  reverseMatcher->SetScaleMax(scub);
404  }
405 
// NOTE(review): on an ITK exception this prints the error and terminates the
// whole process with exit(-1). mainMatcher, reverseMatcher and agregPtr are not
// freed on that path. Rethrowing would let callers handle the error.
406  try
407  {
408  m_bmreg->Update();
409  }
410  catch (itk::ExceptionObject & err)
411  {
412  std::cout << "Exception: " << err << std::endl;
413  exit(-1);
414  }
415 
// Keep this level's SVF as the current estimate.
// NOTE(review): the dynamic_cast result is dereferenced without a null check.
416  const BaseTransformType *resTrsf = dynamic_cast <const BaseTransformType *> (m_bmreg->GetOutput()->Get());
417  m_OutputTransform->SetParametersAsVectorField(resTrsf->GetParametersAsVectorField());
418 
// Release the per-level objects allocated above with raw new.
419  delete mainMatcher;
420  if (reverseMatcher)
421  delete reverseMatcher;
422  if (agregPtr)
423  delete agregPtr;
424  }
425 
// Kissing symmetry: the final SVF is multiplied by 2 in place. Presumably this
// is because the kissing method estimates a half-way transform; confirm against
// KissingSymmetricBMRegistrationMethod.
426  if (m_SymmetryType == Kissing)
427  {
428  VelocityFieldType *finalTrsfField = const_cast <VelocityFieldType *> (m_OutputTransform->GetParametersAsVectorField());
429  typedef itk::MultiplyImageFilter <VelocityFieldType,itk::Image <double,ImageDimension>, VelocityFieldType> MultiplyFilterType;
430 
431  typename MultiplyFilterType::Pointer fieldMultiplier = MultiplyFilterType::New();
432  fieldMultiplier->SetInput(finalTrsfField);
433  fieldMultiplier->SetConstant(2.0);
434  fieldMultiplier->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
435  fieldMultiplier->InPlaceOn();
436 
437  fieldMultiplier->Update();
438 
439  VelocityFieldType *outputField = fieldMultiplier->GetOutput();
440  m_OutputTransform->SetParametersAsVectorField(fieldMultiplier->GetOutput());
441  outputField->DisconnectPipeline();
442  }
443 
// Exponentiate the final SVF into a dense displacement field. The last argument
// is GetSVFExponential's `invert` flag, set to false here.
444  DisplacementFieldTransformPointer outputDispTrsf = DisplacementFieldTransformType::New();
445  anima::GetSVFExponential(m_OutputTransform.GetPointer(), outputDispTrsf.GetPointer(), m_ExponentiationOrder, GetNumberOfWorkUnits(), false);
446 
// Resample the full-resolution floating image onto the reference grid with that
// displacement field; the result becomes m_OutputImage.
448  typename ResampleFilterType::Pointer tmpResample = ResampleFilterType::New();
449 
450  typename InterpolatorType::Pointer interpolator = this->CreateInterpolator(m_ReferenceImage);
451 
// NOTE(review): this local typedef shadows the class-level BaseTransformType
// used earlier in this function.
452  typedef itk::Transform<typename BaseAgregatorType::ScalarType,ImageDimension,ImageDimension> BaseTransformType;
453  typename BaseTransformType::Pointer baseTrsf = outputDispTrsf.GetPointer();
454  tmpResample->SetTransform(baseTrsf);
455  tmpResample->SetFiniteStrainReorientation(this->GetFiniteStrainImageReorientation());
456  tmpResample->SetInput(m_FloatingImage);
457 
458  if (this->GetNumberOfWorkUnits() != 0)
459  tmpResample->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
460 
461  tmpResample->SetOutputLargestPossibleRegion(m_ReferenceImage->GetLargestPossibleRegion());
462  tmpResample->SetOutputOrigin(m_ReferenceImage->GetOrigin());
463  tmpResample->SetOutputSpacing(m_ReferenceImage->GetSpacing());
464  tmpResample->SetOutputDirection(m_ReferenceImage->GetDirection());
465  tmpResample->SetReferenceOutputModel(m_ReferenceImage->GetDescriptionModel());
466  tmpResample->SetInterpolator(interpolator.GetPointer());
467  tmpResample->Update();
468 
469  m_OutputImage = tmpResample->GetOutput();
470  m_OutputImage->DisconnectPipeline();
471 }
472 
// Writes the registered MCM image (m_OutputImage) to m_resultFile. When
// m_outputTransformFile is non-empty, also writes the final SVF to that file.
// NOTE(review): the signature (original line 475) and the declaration of
// `mcmWriter` (original line 478) are missing from this listing. Given the
// animaMCMFileWriter.h include, mcmWriter is presumably an anima::MCMFileWriter.
473 template <unsigned int ImageDimension>
474 void
476 {
477  std::cout << "Writing output image to: " << m_resultFile << std::endl;
479  mcmWriter.SetInputImage(m_OutputImage);
480  mcmWriter.SetFileName(m_resultFile);
481 
482  mcmWriter.Update();
483 
484  if (m_outputTransformFile != "")
485  {
486  std::cout << "Writing output SVF to: " << m_outputTransformFile << std::endl;
// const_cast: the transform exposes its field as const, while writeImage
// presumably takes a non-const pointer.
487  anima::writeImage <VelocityFieldType> (m_outputTransformFile,
488  const_cast <VelocityFieldType *> (m_OutputTransform->GetParametersAsVectorField()));
489  }
490 }
491 
// Builds and updates the reference and floating image pyramids, each with
// m_NumberOfPyramidLevels levels. Each pyramid gets an image resampler
// configured with that image's description model, the finite-strain
// reorientation setting, and an interpolator from CreateInterpolator().
// NOTE(review): the signature (original line 494) and original line 505 are
// missing from this listing.
492 template <unsigned int ImageDimension>
493 void
495 {
496  // Create pyramid for Reference image. NOTE(review): no size check between the reference and floating images is performed here.
497  m_ReferencePyramid = PyramidType::New();
498 
499  m_ReferencePyramid->SetInput(m_ReferenceImage);
500  m_ReferencePyramid->SetNumberOfLevels(m_NumberOfPyramidLevels);
501 
502  if (this->GetNumberOfWorkUnits() != 0)
503  m_ReferencePyramid->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
504 
506  typename InterpolatorType::Pointer interpolator = this->CreateInterpolator(m_ReferenceImage);
507 
508  typename ResampleFilterType::Pointer refResampler = ResampleFilterType::New();
509  refResampler->SetReferenceOutputModel(m_ReferenceImage->GetDescriptionModel());
510  refResampler->SetFiniteStrainReorientation(this->GetFiniteStrainImageReorientation());
511  refResampler->SetInterpolator(interpolator.GetPointer());
512  m_ReferencePyramid->SetImageResampler(refResampler);
513 
514  m_ReferencePyramid->Update();
515 
516  // Create pyramid for Floating image
517  m_FloatingPyramid = PyramidType::New();
518 
519  m_FloatingPyramid->SetInput(m_FloatingImage);
520  m_FloatingPyramid->SetNumberOfLevels(m_NumberOfPyramidLevels);
521 
522  if (this->GetNumberOfWorkUnits() != 0)
523  m_FloatingPyramid->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
524 
525  typename ResampleFilterType::Pointer floResampler = ResampleFilterType::New();
526  interpolator = this->CreateInterpolator(m_FloatingImage);
527 
528  floResampler->SetReferenceOutputModel(m_FloatingImage->GetDescriptionModel());
529  floResampler->SetFiniteStrainReorientation(this->GetFiniteStrainImageReorientation());
530  floResampler->SetInterpolator(interpolator.GetPointer());
531  m_FloatingPyramid->SetImageResampler(floResampler);
532 
533  m_FloatingPyramid->Update();
534 }
535 
536 } // end of namespace anima
MEstimateAgregatorType::BaseOutputTransformType BaseTransformType
MCMPointer & GetDescriptionModel()
void SetGradientDirections(std::vector< vnl_vector_fixed< double, 3 > > &grads)
void SetBigDelta(double val)
void SetBlockSize(unsigned int val)
void SetSmallDelta(double val)
void SetNumberOfWorkUnits(unsigned int val)
void SetOptimizerType(OptimizerDefinition val)
void SetInputImage(InputImageType *input)
const double DiffusionBigDelta
Default big delta value (classical values)
DisplacementFieldTransformType::Pointer DisplacementFieldTransformPointer
void SetGeometryInformation(const TInputImageType *geomImage)
void SetGradientStrengths(std::vector< double > &val)
const double DiffusionSmallDelta
Default small delta value (classical values)
void SetSimilarityType(SimilarityDefinition val)
virtual InterpolatorType * CreateInterpolator(InputImageType *image)
void SetGeometryInformation(const TInputImageType *geomImage)
void SetOutputTransformType(TRANSFORM_TYPE name)
void SetBlockSpacing(unsigned int val)
void SetFileName(std::string fileName)
void SetBlockPercentageKept(double val)
void SetOptimizerMaximumIterations(unsigned int val)
void GetSVFExponential(itk::StationaryVelocityFieldTransform< ScalarType, NDimensions > *baseTrsf, rpi::DisplacementFieldTransform< ScalarType, NDimensions > *resultTransform, unsigned int exponentiationOrder, unsigned int numThreads, bool invert)
void SetBlockTransformType(TransformDefinition val)
void SetModelRotationType(ModelRotationType val)
void SetBlockVarianceThreshold(double val)
itk::SmartPointer< Self > Pointer
Definition: animaMCMImage.h:15
void SetSearchRadius(double val)