ANIMA  4.0
animaPyramidalDenseTensorSVFMatchingBridge.hxx
Go to the documentation of this file.
1 #pragma once
3 
8 
9 #include <itkResampleImageFilter.h>
11 
14 
15 #include <animaVelocityUtils.h>
17 
18 namespace anima
19 {
20 
// Constructor: gives every registration parameter of the bridge its default value.
// NOTE(review): this listing has no line 22 (between the template header and
// "{"); it presumably held the out-of-class constructor signature — restore it
// from the original source.
21 template <unsigned int ImageDimension>
23 {
// Input and output images are unset until the caller provides / computes them.
24  m_ReferenceImage = NULL;
25  m_FloatingImage = NULL;
26 
// The output transform starts as identity. It is handed to the block-matching
// registration as its initial transform at every pyramid level.
27  m_OutputTransform = BaseTransformType::New();
28  m_OutputTransform->SetIdentity();
29 
// An empty file name means the final SVF is not written out.
30  m_outputTransformFile = "";
31 
32  m_OutputImage = NULL;
33 
// Block generation: block size, spacing between blocks, and a standard
// deviation threshold (squared into a variance threshold for the matchers).
34  m_BlockSize = 5;
35  m_BlockSpacing = 2;
36  m_StDevThreshold = 5;
37 
// Registration scheme, tensor reorientation model and block transform type.
38  m_SymmetryType = Asymmetric;
39  m_MetricOrientation = FiniteStrain;
40  m_FiniteStrainImageReorientation = true;
41  m_Transform = Translation;
// NOTE(review): this listing has no line 42 here; it presumably initialised
// the similarity metric (m_Metric) — confirm against the original source.
43  m_Optimizer = Bobyqa;
44 
// Outer registration iterations, and optimizer search ranges / bounds applied
// to each block.
45  m_MaximumIterations = 10;
46  m_MinimalTransformError = 0.01;
47  m_OptimizerMaximumIterations = 100;
48  m_SearchRadius = 2;
49  m_SearchAngleRadius = 5;
50  m_SearchScaleRadius = 0.1;
51  m_FinalRadius = 0.001;
52  m_StepSize = 1;
53  m_TranslateUpperBound = 50;
54  m_AngleUpperBound = 180;
55  m_ScaleUpperBound = 3;
// Agregator settings. The level loop multiplies m_ExtrapolationSigma and
// m_ElasticSigma by the mean voxel spacing of the level, so those two are in
// voxel units. m_OutlierSigma is passed through unscaled.
56  m_Agregator = Baloo;
57  m_ExtrapolationSigma = 3;
58  m_ElasticSigma = 3;
59  m_OutlierSigma = 3;
60  m_MEstimateConvergenceThreshold = 0.01;
61  m_NeighborhoodApproximation = 2.5;
// Orders used for BCH composition and SVF exponentiation.
62  m_BCHCompositionOrder = 1;
63  m_ExponentiationOrder = 1;
// Pyramid depth, and how many levels at the end of the pyramid are skipped.
64  m_NumberOfPyramidLevels = 3;
65  m_LastPyramidLevel = 0;
66  m_PercentageKept = 0.8;
// Default to ITK's global default thread count.
67  this->SetNumberOfWorkUnits(itk::MultiThreaderBase::GetGlobalDefaultNumberOfThreads());
68 }
69 
// Destructor: empty body, so nothing is released explicitly here.
// NOTE(review): this listing has no line 71 (between the template header and
// "{"); it presumably held the out-of-class destructor signature.
70 template <unsigned int ImageDimension>
72 {
73 }
74 
// Runs the pyramidal dense block-matching registration of the floating tensor
// image onto the reference tensor image.
// - For each processed pyramid level, an (asymmetric, symmetric or kissing)
//   block-matching registration refines the stationary velocity field (SVF)
//   held in m_OutputTransform.
// - Afterwards, the full-resolution floating image is resampled with exp(SVF)
//   onto the reference grid, passed through ExpTensorFilterType, and stored in
//   m_OutputImage.
// NOTE(review): this listing has no line 77 (between "void" and "{"); it
// presumably held the out-of-class member function signature.
75 template <unsigned int ImageDimension>
76 void
78 {
79  this->SetupPyramids();
80 
81  // Iterate over pyramid levels
82  for (unsigned int i = 0;i < m_ReferencePyramid->GetNumberOfLevels();++i)
83  {
// The last m_LastPyramidLevel level indices are skipped, so registration stops
// early (presumably at a coarser resolution — confirm the pyramid's ordering).
84  if (i + m_LastPyramidLevel >= m_ReferencePyramid->GetNumberOfLevels())
85  continue;
86 
// Take this level's images and detach them from the pyramid pipeline.
87  typename InputImageType::Pointer refImage = m_ReferencePyramid->GetOutput(i);
88  refImage->DisconnectPipeline();
89 
90  typename InputImageType::Pointer floImage = m_FloatingPyramid->GetOutput(i);
91  floImage->DisconnectPipeline();
92 
// Optional block-generation mask for this level (stays null without a mask pyramid).
93  typename MaskImageType::Pointer maskGenerationImage = ITK_NULLPTR;
94  if (m_BlockGenerationPyramid)
95  {
96  maskGenerationImage = m_BlockGenerationPyramid->GetOutput(i);
97  maskGenerationImage->DisconnectPipeline();
98  }
99 
100  // Update field to match the current resolution
// When a velocity field is already set (i.e. from a previous level), resample
// it with an identity transform onto this level's reference grid.
101  if (m_OutputTransform->GetParametersAsVectorField() != NULL)
102  {
103  typedef itk::ResampleImageFilter<VelocityFieldType,VelocityFieldType> VectorResampleFilterType;
104  typedef typename VectorResampleFilterType::Pointer VectorResampleFilterPointer;
105 
106  AffineTransformPointer tmpIdentity = AffineTransformType::New();
107  tmpIdentity->SetIdentity();
108 
109  VectorResampleFilterPointer tmpResample = VectorResampleFilterType::New();
110  tmpResample->SetTransform(tmpIdentity);
111  tmpResample->SetInput(m_OutputTransform->GetParametersAsVectorField());
112 
113  tmpResample->SetSize(refImage->GetLargestPossibleRegion().GetSize());
114  tmpResample->SetOutputOrigin(refImage->GetOrigin());
115  tmpResample->SetOutputSpacing(refImage->GetSpacing());
116  tmpResample->SetOutputDirection(refImage->GetDirection());
117 
118  tmpResample->Update();
119 
120  VelocityFieldType *tmpOut = tmpResample->GetOutput();
121  m_OutputTransform->SetParametersAsVectorField(tmpOut);
122  tmpOut->DisconnectPipeline();
123  }
124 
125  std::cout << "Processing pyramid level " << i << std::endl;
126  std::cout << "Image size: " << refImage->GetLargestPossibleRegion().GetSize() << std::endl;
127 
// Mean voxel spacing of this level. It converts the voxel-unit sigmas into
// physical units below.
128  double meanSpacing = 0;
129  for (unsigned int j = 0;j < ImageDimension;++j)
130  meanSpacing += refImage->GetSpacing()[j];
131  meanSpacing /= ImageDimension;
132 
133  // Init agregator mean shift parameters
// agregPtr is raw-owned: it is deleted at the end of this iteration.
134  BaseAgregatorType* agregPtr = NULL;
135 
136  if (m_Agregator == MSmoother)
137  {
// NOTE(review): this listing has no line 138; it presumably allocated the
// M-estimation agregator "agreg" with new.
139  agreg->SetExtrapolationSigma(m_ExtrapolationSigma * meanSpacing);
140  agreg->SetOutlierRejectionSigma(m_OutlierSigma);
141  agreg->SetOutputTransformType(BaseAgregatorType::SVF);
142 
143  if (this->GetNumberOfWorkUnits() != 0)
144  agreg->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
145 
146  agreg->SetGeometryInformation(refImage.GetPointer());
147 
// Neighborhood half-size (voxels) and distance boundary (physical units) both
// derive from the extrapolation sigma times m_NeighborhoodApproximation.
148  agreg->SetNeighborhoodHalfSize((unsigned int)floor(m_ExtrapolationSigma * m_NeighborhoodApproximation));
149  agreg->SetDistanceBoundary(m_ExtrapolationSigma * meanSpacing * m_NeighborhoodApproximation);
150  agreg->SetMEstimateConvergenceThreshold(m_MEstimateConvergenceThreshold);
151 
152  agregPtr = agreg;
153  }
154  else
155  {
// NOTE(review): this listing has no line 156; it presumably allocated the
// other (Baloo) agregator "agreg" with new.
157  agreg->SetExtrapolationSigma(m_ExtrapolationSigma * meanSpacing);
158  agreg->SetOutlierRejectionSigma(m_OutlierSigma);
159  agreg->SetOutputTransformType(BaseAgregatorType::SVF);
160 
161  if (this->GetNumberOfWorkUnits() != 0)
162  agreg->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
163 
164  agreg->SetGeometryInformation(refImage.GetPointer());
165 
166  agregPtr = agreg;
167  }
168 
169  // Init matcher
// Matchers are raw-owned and deleted at the end of this iteration.
// NOTE(review): m_bmreg is a member and keeps referring to mainMatcher,
// reverseMatcher and agregPtr after they are deleted (assuming its setters
// store the raw pointers). Confirm nothing uses m_bmreg after this function,
// or move to owning smart pointers.
170  typedef anima::TensorBlockMatcher <InputImageType> BlockMatcherType;
171 
172  BlockMatcherType *mainMatcher = new BlockMatcherType;
173  BlockMatcherType *reverseMatcher = 0;
174  mainMatcher->SetBlockPercentageKept(GetPercentageKept());
175  mainMatcher->SetBlockSize(GetBlockSize());
176  mainMatcher->SetBlockSpacing(GetBlockSpacing());
177  mainMatcher->SetBlockVarianceThreshold(GetStDevThreshold() * GetStDevThreshold());
178  mainMatcher->SetBlockGenerationMask(maskGenerationImage);
179 
// Create the registration method for the requested symmetry scheme. Only the
// Symmetric scheme uses a second (reverse) matcher, configured like the main one.
180  switch (m_SymmetryType)
181  {
182  case Asymmetric:
183  {
184  typedef typename anima::AsymmetricBMRegistrationMethod <InputImageType> BlockMatchRegistrationType;
185  m_bmreg = BlockMatchRegistrationType::New();
186  break;
187  }
188 
189  case Symmetric:
190  {
191  typedef typename anima::SymmetricBMRegistrationMethod <InputImageType> BlockMatchRegistrationType;
192  typename BlockMatchRegistrationType::Pointer tmpReg = BlockMatchRegistrationType::New();
193  reverseMatcher = new BlockMatcherType;
194  reverseMatcher->SetBlockPercentageKept(GetPercentageKept());
195  reverseMatcher->SetBlockSize(GetBlockSize());
196  reverseMatcher->SetBlockSpacing(GetBlockSpacing());
197  reverseMatcher->SetBlockVarianceThreshold(GetStDevThreshold() * GetStDevThreshold());
198  reverseMatcher->SetBlockGenerationMask(maskGenerationImage);
199 
200  tmpReg->SetReverseBlockMatcher(reverseMatcher);
201  m_bmreg = tmpReg;
202  break;
203  }
204 
205  case Kissing:
206  {
207  typedef typename anima::KissingSymmetricBMRegistrationMethod <InputImageType> BlockMatchRegistrationType;
208  m_bmreg = BlockMatchRegistrationType::New();
209  break;
210  }
211  }
212 
213  m_bmreg->SetBlockMatcher(mainMatcher);
214  m_bmreg->SetAgregator(agregPtr);
215  m_bmreg->SetBCHCompositionOrder(m_BCHCompositionOrder);
216  m_bmreg->SetExponentiationOrder(m_ExponentiationOrder);
217 
218  if (this->GetNumberOfWorkUnits() != 0)
219  m_bmreg->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
220 
221  m_bmreg->SetFixedImage(refImage);
222  m_bmreg->SetMovingImage(floImage);
223 
224  m_bmreg->SetSVFElasticRegSigma(m_ElasticSigma * meanSpacing);
225 
// NOTE(review): this listing has no line 226; it presumably declared the
// ResampleFilterType used below.
227 
// The reference-image resampler outputs on the floating level grid, and the
// moving-image resampler outputs on the reference level grid.
228  typename ResampleFilterType::Pointer refResampler = ResampleFilterType::New();
229  refResampler->SetOutputLargestPossibleRegion(floImage->GetLargestPossibleRegion());
230  refResampler->SetOutputOrigin(floImage->GetOrigin());
231  refResampler->SetOutputSpacing(floImage->GetSpacing());
232  refResampler->SetOutputDirection(floImage->GetDirection());
233  refResampler->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
234  refResampler->SetFiniteStrainReorientation(this->GetFiniteStrainImageReorientation());
235  m_bmreg->SetReferenceImageResampler(refResampler);
236 
237  typename ResampleFilterType::Pointer movingResampler = ResampleFilterType::New();
238  movingResampler->SetOutputLargestPossibleRegion(refImage->GetLargestPossibleRegion());
239  movingResampler->SetOutputOrigin(refImage->GetOrigin());
240  movingResampler->SetOutputSpacing(refImage->GetSpacing());
241  movingResampler->SetOutputDirection(refImage->GetDirection());
242  movingResampler->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
243  movingResampler->SetFiniteStrainReorientation(this->GetFiniteStrainImageReorientation());
244  m_bmreg->SetMovingImageResampler(movingResampler);
245 
// The four switches below copy the bridge's block transform type, optimizer,
// similarity metric and tensor reorientation model onto the main matcher, and
// onto the reverse matcher when one exists.
246  switch (GetTransform())
247  {
248  case Translation:
249  mainMatcher->SetBlockTransformType(BlockMatcherType::Superclass::Translation);
250  if (reverseMatcher)
251  reverseMatcher->SetBlockTransformType(BlockMatcherType::Superclass::Translation);
252  break;
253  case Rigid:
254  mainMatcher->SetBlockTransformType(BlockMatcherType::Superclass::Rigid);
255  if (reverseMatcher)
256  reverseMatcher->SetBlockTransformType(BlockMatcherType::Superclass::Rigid);
257  break;
258  case Affine:
259  default:
260  mainMatcher->SetBlockTransformType(BlockMatcherType::Superclass::Affine);
261  if (reverseMatcher)
262  reverseMatcher->SetBlockTransformType(BlockMatcherType::Superclass::Affine);
263  break;
264  }
265 
266  switch (GetOptimizer())
267  {
268  case Exhaustive:
269  mainMatcher->SetOptimizerType(BlockMatcherType::Exhaustive);
270  if (reverseMatcher)
271  reverseMatcher->SetOptimizerType(BlockMatcherType::Exhaustive);
272  break;
273 
274  case Bobyqa:
275  default:
276  mainMatcher->SetOptimizerType(BlockMatcherType::Bobyqa);
277  if (reverseMatcher)
278  reverseMatcher->SetOptimizerType(BlockMatcherType::Bobyqa);
279  break;
280  }
281 
// NOTE(review): lines 289 and 294 are missing from this listing. They
// presumably held the case labels for the generalized-correlation metrics; as
// listed, the statements after the first two "break;"s would be unreachable.
282  switch (m_Metric)
283  {
284  case TensorCorrelation:
285  mainMatcher->SetSimilarityType(BlockMatcherType::TensorCorrelation);
286  if (reverseMatcher)
287  reverseMatcher->SetSimilarityType(BlockMatcherType::TensorCorrelation);
288  break;
290  mainMatcher->SetSimilarityType(BlockMatcherType::TensorGeneralizedCorrelation);
291  if (reverseMatcher)
292  reverseMatcher->SetSimilarityType(BlockMatcherType::TensorGeneralizedCorrelation);
293  break;
295  mainMatcher->SetSimilarityType(BlockMatcherType::TensorOrientedGeneralizedCorrelation);
296  if (reverseMatcher)
297  reverseMatcher->SetSimilarityType(BlockMatcherType::TensorOrientedGeneralizedCorrelation);
298  break;
299  case TensorMeanSquares:
300  default:
301  mainMatcher->SetSimilarityType(BlockMatcherType::TensorMeanSquares);
302  if (reverseMatcher)
303  reverseMatcher->SetSimilarityType(BlockMatcherType::TensorMeanSquares);
304  break;
305  }
306 
307  switch (m_MetricOrientation)
308  {
309  case None:
310  mainMatcher->SetModelRotationType(BlockMatcherType::None);
311  if (reverseMatcher)
312  reverseMatcher->SetModelRotationType(BlockMatcherType::None);
313  break;
314  case PPD:
315  mainMatcher->SetModelRotationType(BlockMatcherType::PPD);
316  if (reverseMatcher)
317  reverseMatcher->SetModelRotationType(BlockMatcherType::PPD);
318  break;
319  case FiniteStrain:
320  default:
321  mainMatcher->SetModelRotationType(BlockMatcherType::FiniteStrain);
322  if (reverseMatcher)
323  reverseMatcher->SetModelRotationType(BlockMatcherType::FiniteStrain);
324  break;
325  }
326 
// Each level starts from the current m_OutputTransform (identity on the first
// processed level, the upsampled previous result afterwards).
327  m_bmreg->SetMaximumIterations(m_MaximumIterations);
328  m_bmreg->SetMinimalTransformError(m_MinimalTransformError);
329  m_bmreg->SetInitialTransform(m_OutputTransform.GetPointer());
330 
// NOTE(review): unlike reverseMatcher (below), mainMatcher never gets
// SetNumberOfWorkUnits here. Confirm the registration method propagates it.
331  mainMatcher->SetOptimizerMaximumIterations(m_OptimizerMaximumIterations);
332 
333  double sr = m_SearchRadius;
334  mainMatcher->SetSearchRadius(sr);
335 
336  double sar = m_SearchAngleRadius;
337  mainMatcher->SetSearchAngleRadius(sar);
338 
339  double scr = m_SearchScaleRadius;
340  mainMatcher->SetSearchScaleRadius(scr);
341 
342  double fr = m_FinalRadius;
343  mainMatcher->SetFinalRadius(fr);
344 
345  double ss = m_StepSize;
346  mainMatcher->SetStepSize(ss);
347 
348  double tub = m_TranslateUpperBound;
349  mainMatcher->SetTranslateMax(tub);
350 
351  double aub = m_AngleUpperBound;
352  mainMatcher->SetAngleMax(aub);
353 
354  double scub = m_ScaleUpperBound;
355  mainMatcher->SetScaleMax(scub);
356 
357  if (reverseMatcher)
358  {
359  reverseMatcher->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
360  reverseMatcher->SetOptimizerMaximumIterations(GetOptimizerMaximumIterations());
361 
362  reverseMatcher->SetSearchRadius(sr);
363  reverseMatcher->SetSearchAngleRadius(sar);
364  reverseMatcher->SetSearchScaleRadius(scr);
365  reverseMatcher->SetFinalRadius(fr);
366  reverseMatcher->SetStepSize(ss);
367  reverseMatcher->SetTranslateMax(tub);
368  reverseMatcher->SetAngleMax(aub);
369  reverseMatcher->SetScaleMax(scub);
370  }
371 
// NOTE(review): exit(-1) in library code terminates the host process and skips
// the deletes below. Rethrowing would let callers handle the failure.
372  try
373  {
374  m_bmreg->Update();
375  }
376  catch( itk::ExceptionObject & err )
377  {
378  std::cout << "ExceptionObject caught !" << err << std::endl;
379  exit(-1);
380  }
381 
// Keep this level's resulting SVF as the current output transform.
// NOTE(review): the dynamic_cast result is not checked for null before use.
382  const BaseTransformType *resTrsf = dynamic_cast <const BaseTransformType *> (m_bmreg->GetOutput()->Get());
383  m_OutputTransform->SetParametersAsVectorField(resTrsf->GetParametersAsVectorField());
384 
385  delete mainMatcher;
386  if (reverseMatcher)
387  delete reverseMatcher;
388  if (agregPtr)
389  delete agregPtr;
390  }
391 
// Kissing scheme: the final SVF is multiplied by 2 in place. This is presumably
// because the kissing registration estimates a half-way transformation; confirm.
392  if (m_SymmetryType == Kissing)
393  {
394  VelocityFieldType *finalTrsfField = const_cast <VelocityFieldType *> (m_OutputTransform->GetParametersAsVectorField());
395  typedef itk::MultiplyImageFilter <VelocityFieldType,itk::Image <double,ImageDimension>, VelocityFieldType> MultiplyFilterType;
396 
397  typename MultiplyFilterType::Pointer fieldMultiplier = MultiplyFilterType::New();
398  fieldMultiplier->SetInput(finalTrsfField);
399  fieldMultiplier->SetConstant(2.0);
400  fieldMultiplier->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
401  fieldMultiplier->InPlaceOn();
402 
403  fieldMultiplier->Update();
404 
405  VelocityFieldType *outputField = fieldMultiplier->GetOutput();
406  m_OutputTransform->SetParametersAsVectorField(fieldMultiplier->GetOutput());
407  outputField->DisconnectPipeline();
408  }
409 
// Exponentiate the final SVF into a displacement field transform (last
// argument false = not inverted).
410  DisplacementFieldTransformPointer outputDispTrsf = DisplacementFieldTransformType::New();
411  anima::GetSVFExponential(m_OutputTransform.GetPointer(), outputDispTrsf.GetPointer(), m_ExponentiationOrder, GetNumberOfWorkUnits(), false);
412 
// NOTE(review): this listing has no line 413; it presumably declared the
// ResampleFilterType used here.
// Resample the full-resolution floating image onto the full-resolution
// reference grid with the displacement field.
414  typename ResampleFilterType::Pointer tmpResample = ResampleFilterType::New();
415 
// NOTE(review): this local typedef reuses the name BaseTransformType, which
// earlier in this function refers to the SVF transform type. A distinct name
// would avoid confusion.
416  typedef itk::Transform<typename BaseAgregatorType::ScalarType,ImageDimension,ImageDimension> BaseTransformType;
417  typename BaseTransformType::Pointer baseTrsf = outputDispTrsf.GetPointer();
418  tmpResample->SetTransform(baseTrsf);
419  tmpResample->SetFiniteStrainReorientation(this->GetFiniteStrainImageReorientation());
420  tmpResample->SetInput(m_FloatingImage);
421 
422  if (this->GetNumberOfWorkUnits() != 0)
423  tmpResample->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
424 
425  tmpResample->SetOutputLargestPossibleRegion(m_ReferenceImage->GetLargestPossibleRegion());
426  tmpResample->SetOutputOrigin(m_ReferenceImage->GetOrigin());
427  tmpResample->SetOutputSpacing(m_ReferenceImage->GetSpacing());
428  tmpResample->SetOutputDirection(m_ReferenceImage->GetDirection());
429  tmpResample->Update();
430 
// NOTE(review): this listing has no line 431; it presumably declared
// ExpTensorFilterType. The resampled image is passed through this "exp" tensor
// filter (presumably because images are handled as log-tensors — confirm)
// before becoming m_OutputImage.
432  typename ExpTensorFilterType::Pointer tensorExper = ExpTensorFilterType::New();
433 
434  typename InputImageType::Pointer tmpImage = tmpResample->GetOutput();
435  tmpImage->DisconnectPipeline();
436 
437  tensorExper->SetInput(tmpImage);
438  tensorExper->SetScaleNonDiagonal(true);
439 
440  if (this->GetNumberOfWorkUnits() != 0)
441  tensorExper->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
442 
443  tensorExper->Update();
444 
445  m_OutputImage = tensorExper->GetOutput();
446  m_OutputImage->DisconnectPipeline();
447 }
448 
// Writes m_OutputImage to m_resultFile. When m_outputTransformFile is not
// empty, also writes the velocity field of m_OutputTransform to that file.
// Assumes the registration has already run: m_OutputImage is not null-checked.
// NOTE(review): this listing has no line 451 (between "void" and "{"); it
// presumably held the out-of-class member function signature.
449 template <unsigned int ImageDimension>
450 void
452 {
453  std::cout << "Writing output image to: " << m_resultFile << std::endl;
454  anima::writeImage <InputImageType> (m_resultFile,m_OutputImage);
455 
456  if (m_outputTransformFile != "")
457  {
// The const_cast drops the constness of the stored field pointer so it can be
// passed to writeImage.
458  std::cout << "Writing output SVF to: " << m_outputTransformFile << std::endl;
459  anima::writeImage <VelocityFieldType> (m_outputTransformFile,
460  const_cast <VelocityFieldType *> (m_OutputTransform->GetParametersAsVectorField()));
461  }
462 }
463 
// Builds the reference and floating image pyramids (m_NumberOfPyramidLevels
// levels each). Their resamplers honour the finite-strain reorientation flag.
// When m_BlockGenerationMask is set, a mask pyramid with the same number of
// levels is also built; otherwise m_BlockGenerationPyramid is left null.
// NOTE(review): this listing has no line 466 (presumably the out-of-class
// signature) and no line 477 (presumably the ResampleFilterType typedef used
// below).
464 template <unsigned int ImageDimension>
465 void
467 {
468  // Create pyramid here, check images actually are of the same size.
// NOTE(review): despite the comment above, no size check is performed in this
// function.
469  m_ReferencePyramid = PyramidType::New();
470 
471  m_ReferencePyramid->SetInput(m_ReferenceImage);
472  m_ReferencePyramid->SetNumberOfLevels(m_NumberOfPyramidLevels);
473 
474  if (this->GetNumberOfWorkUnits() != 0)
475  m_ReferencePyramid->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
476 
478 
479  typename ResampleFilterType::Pointer refResampler = ResampleFilterType::New();
480  refResampler->SetFiniteStrainReorientation(this->GetFiniteStrainImageReorientation());
481  m_ReferencePyramid->SetImageResampler(refResampler);
482 
483  m_ReferencePyramid->Update();
484 
485  // Create pyramid for Floating image
486  m_FloatingPyramid = PyramidType::New();
487 
488  m_FloatingPyramid->SetInput(m_FloatingImage);
489  m_FloatingPyramid->SetNumberOfLevels(m_NumberOfPyramidLevels);
490 
491  if (this->GetNumberOfWorkUnits() != 0)
492  m_FloatingPyramid->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
493 
494  typename ResampleFilterType::Pointer floResampler = ResampleFilterType::New();
495  floResampler->SetFiniteStrainReorientation(this->GetFiniteStrainImageReorientation());
496  m_FloatingPyramid->SetImageResampler(floResampler);
497 
498  m_FloatingPyramid->Update();
499 
// Optional block-generation mask pyramid, resampled with a plain
// anima::ResampleImageFilter (no tensor reorientation).
// NOTE(review): unlike the two pyramids above, SetNumberOfWorkUnits is called
// here without the "!= 0" guard.
500  m_BlockGenerationPyramid = 0;
501  if (m_BlockGenerationMask)
502  {
503  typedef anima::ResampleImageFilter<MaskImageType, MaskImageType,
504  typename BaseAgregatorType::ScalarType> MaskResampleFilterType;
505 
506  typename MaskResampleFilterType::Pointer maskResampler = MaskResampleFilterType::New();
507 
508  m_BlockGenerationPyramid = MaskPyramidType::New();
509  m_BlockGenerationPyramid->SetImageResampler(maskResampler);
510  m_BlockGenerationPyramid->SetInput(m_BlockGenerationMask);
511  m_BlockGenerationPyramid->SetNumberOfLevels(GetNumberOfPyramidLevels());
512  m_BlockGenerationPyramid->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
513  m_BlockGenerationPyramid->Update();
514  }
515 }
516 
517 } // end of namespace anima
DisplacementFieldTransformType::Pointer DisplacementFieldTransformPointer
void SetGeometryInformation(const TInputImageType *geomImage)
MEstimateAgregatorType::BaseOutputTransformType BaseTransformType
void SetGeometryInformation(const TInputImageType *geomImage)
void SetOutputTransformType(TRANSFORM_TYPE name)
void SetBlockPercentageKept(double val)
void GetSVFExponential(itk::StationaryVelocityFieldTransform< ScalarType, NDimensions > *baseTrsf, rpi::DisplacementFieldTransform< ScalarType, NDimensions > *resultTransform, unsigned int exponentiationOrder, unsigned int numThreads, bool invert)