ANIMA  4.0
animaPyramidalDistortionCorrectionBlockMatchingBridge.hxx
Go to the documentation of this file.
1 #pragma once
3 #include <itkImageMomentsCalculator.h>
4 
8 
9 #include <itkResampleImageFilter.h>
10 
11 #include <animaVelocityUtils.h>
13 
14 namespace anima
15 {
16 
17 template <unsigned int ImageDimension>
19 {
20  m_BackwardImage = NULL;
21  m_ForwardImage = NULL;
22 
23  m_InitialTransform = NULL;
24 
25  m_OutputTransform = DisplacementFieldTransformType::New();
26  m_OutputTransform->SetIdentity();
27 
28  m_outputTransformFile = "";
29 
30  m_OutputImage = NULL;
31 
32  m_TransformDirection = 0;
33  m_BlockSize = 5;
34  m_BlockSpacing = 2;
35  m_StDevThreshold = 5;
36  m_MaximumIterations = 10;
37  m_OptimizerMaximumIterations = 100;
38  m_SearchRadius = 2;
39  m_SearchScaleRadius = 0.1;
40  m_SearchSkewRadius = 0.1;
41  m_FinalRadius = 0.001;
42  m_TranlateUpperBound = 50;
43  m_ScaleUpperBound = std::log(5.0);
44  m_SkewUpperBound = std::tan(M_PI / 3.0);
45  m_Agregator = Baloo;
46  m_TransformKind = DirectionScaleSkew;
47  m_Metric = SquaredCorrelation;
48  m_WeightedAgregation = false;
49  m_ExtrapolationSigma = 3;
50  m_ElasticSigma = 3;
51  m_OutlierSigma = 3;
52  m_MEstimateConvergenceThreshold = 0.01;
53  m_NeighborhoodApproximation = 2.5;
54  m_ExponentiationOrder = 1;
55  m_NumberOfPyramidLevels = 3;
56  m_LastPyramidLevel = 0;
57  m_PercentageKept = 0.8;
58  this->SetNumberOfWorkUnits(itk::MultiThreaderBase::GetGlobalDefaultNumberOfThreads());
59 }
60 
61 template <unsigned int ImageDimension>
62 void
65 {
66  if (!m_InitialTransform)
67  m_InitialTransform = DisplacementFieldTransformType::New();
68 
69  m_InitialTransform->SetParametersAsVectorField(field);
70 }
71 
72 template <unsigned int ImageDimension>
74 {
75 }
76 
77 template <unsigned int ImageDimension>
78 void
80 {
81  typedef typename anima::DistortionCorrectionBMRegistrationMethod<InputImageType> BlockMatchRegistrationType;
82 
83  this->SetupPyramids();
84 
85  // Iterate over pyramid levels
86  for (unsigned int i = 0;i < m_BackwardPyramid->GetNumberOfLevels();++i)
87  {
88  if (i + m_LastPyramidLevel >= m_BackwardPyramid->GetNumberOfLevels())
89  continue;
90 
91  typename InputImageType::Pointer backwardImage = m_BackwardPyramid->GetOutput(i);
92  backwardImage->DisconnectPipeline();
93 
94  typename InputImageType::Pointer forwardImage = m_ForwardPyramid->GetOutput(i);
95  forwardImage->DisconnectPipeline();
96 
97  // Update fields to match the current resolution
98  if (m_OutputTransform->GetParametersAsVectorField() != NULL)
99  {
100  typedef itk::ResampleImageFilter<VectorFieldType,VectorFieldType> VectorResampleFilterType;
101  typedef typename VectorResampleFilterType::Pointer VectorResampleFilterPointer;
102 
103  AffineTransformPointer tmpIdentity = AffineTransformType::New();
104  tmpIdentity->SetIdentity();
105 
106  VectorResampleFilterPointer tmpResample = VectorResampleFilterType::New();
107  tmpResample->SetTransform(tmpIdentity);
108  tmpResample->SetInput(m_OutputTransform->GetParametersAsVectorField());
109 
110  tmpResample->SetSize(backwardImage->GetLargestPossibleRegion().GetSize());
111  tmpResample->SetOutputOrigin(backwardImage->GetOrigin());
112  tmpResample->SetOutputSpacing(backwardImage->GetSpacing());
113  tmpResample->SetOutputDirection(backwardImage->GetDirection());
114 
115  tmpResample->Update();
116 
117  VectorFieldType *tmpOut = tmpResample->GetOutput();
118  m_OutputTransform->SetParametersAsVectorField(tmpOut);
119  tmpOut->DisconnectPipeline();
120  }
121 
122  DisplacementFieldTransformPointer initialTransform;
123 
124  if (m_InitialTransform)
125  {
126  typedef itk::ResampleImageFilter<VectorFieldType,VectorFieldType> VectorResampleFilterType;
127  typedef typename VectorResampleFilterType::Pointer VectorResampleFilterPointer;
128 
129  AffineTransformPointer tmpIdentity = AffineTransformType::New();
130  tmpIdentity->SetIdentity();
131 
132  VectorResampleFilterPointer tmpResample = VectorResampleFilterType::New();
133  tmpResample->SetTransform(tmpIdentity);
134  tmpResample->SetInput(m_InitialTransform->GetParametersAsVectorField());
135 
136  tmpResample->SetSize(backwardImage->GetLargestPossibleRegion().GetSize());
137  tmpResample->SetOutputOrigin(backwardImage->GetOrigin());
138  tmpResample->SetOutputSpacing(backwardImage->GetSpacing());
139  tmpResample->SetOutputDirection(backwardImage->GetDirection());
140 
141  tmpResample->Update();
142 
143  VectorFieldType *tmpOut = tmpResample->GetOutput();
144  initialTransform = DisplacementFieldTransformType::New();
145  initialTransform->SetParametersAsVectorField(tmpOut);
146  tmpOut->DisconnectPipeline();
147  }
148 
149  std::cout << "Processing pyramid level " << i << std::endl;
150  std::cout << "Image size: " << backwardImage->GetLargestPossibleRegion().GetSize() << std::endl;
151 
152  double meanSpacing = 0;
153  for (unsigned int j = 0;j < ImageDimension;++j)
154  meanSpacing += backwardImage->GetSpacing()[j];
155  meanSpacing /= ImageDimension;
156 
157  // Init matcher
158  typename BlockMatchRegistrationType::Pointer bmreg = BlockMatchRegistrationType::New();
159 
160  bmreg->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
161 
162  typedef anima::ResampleImageFilter<InputImageType, InputImageType,
163  typename BaseAgregatorType::ScalarType> ResampleFilterType;
164 
165  typename ResampleFilterType::Pointer refResampler = ResampleFilterType::New();
166  refResampler->SetSize(forwardImage->GetLargestPossibleRegion().GetSize());
167  refResampler->SetOutputOrigin(forwardImage->GetOrigin());
168  refResampler->SetOutputSpacing(forwardImage->GetSpacing());
169  refResampler->SetOutputDirection(forwardImage->GetDirection());
170  refResampler->SetDefaultPixelValue(0);
171  refResampler->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
172  refResampler->SetScaleIntensitiesWithJacobian(true);
173  bmreg->SetReferenceImageResampler(refResampler);
174 
175  typename ResampleFilterType::Pointer movingResampler = ResampleFilterType::New();
176  movingResampler->SetSize(backwardImage->GetLargestPossibleRegion().GetSize());
177  movingResampler->SetOutputOrigin(backwardImage->GetOrigin());
178  movingResampler->SetOutputSpacing(backwardImage->GetSpacing());
179  movingResampler->SetOutputDirection(backwardImage->GetDirection());
180  movingResampler->SetDefaultPixelValue(0);
181  movingResampler->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
182  movingResampler->SetScaleIntensitiesWithJacobian(true);
183  bmreg->SetMovingImageResampler(movingResampler);
184 
185  // Init matcher
187 
188  BlockMatcherType *mainMatcher = new BlockMatcherType;
189  mainMatcher->SetBlockPercentageKept(GetPercentageKept());
190  mainMatcher->SetBlockSize(GetBlockSize());
191  mainMatcher->SetBlockSpacing(GetBlockSpacing());
192  mainMatcher->SetBlockVarianceThreshold(GetStDevThreshold() * GetStDevThreshold());
193 
194  bmreg->SetBlockMatcher(mainMatcher);
195 
196  bmreg->SetFixedImage(backwardImage);
197  bmreg->SetMovingImage(forwardImage);
198 
199  // Init agregator mean shift parameters
200  BaseAgregatorType* agregPtr = NULL;
201 
202  if (m_Agregator == MSmoother)
203  {
205  agreg->SetExtrapolationSigma(m_ExtrapolationSigma * meanSpacing);
206  agreg->SetOutlierRejectionSigma(m_OutlierSigma);
207  agreg->SetOutputTransformType(BaseAgregatorType::SVF);
208 
209  agreg->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
210  agreg->SetGeometryInformation(backwardImage.GetPointer());
211 
212  agreg->SetNeighborhoodHalfSize((unsigned int)floor(m_ExtrapolationSigma * m_NeighborhoodApproximation));
213  agreg->SetDistanceBoundary(m_ExtrapolationSigma * meanSpacing * m_NeighborhoodApproximation);
214  agreg->SetMEstimateConvergenceThreshold(m_MEstimateConvergenceThreshold);
215 
216  agregPtr = agreg;
217  }
218  else
219  {
221  agreg->SetExtrapolationSigma(m_ExtrapolationSigma * meanSpacing);
222  agreg->SetOutlierRejectionSigma(m_OutlierSigma);
223  agreg->SetOutputTransformType(BaseAgregatorType::SVF);
224 
225  agreg->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
226  agreg->SetGeometryInformation(backwardImage.GetPointer());
227 
228  agregPtr = agreg;
229  }
230 
231  bmreg->SetAgregator(agregPtr);
232  bmreg->SetExponentiationOrder(m_ExponentiationOrder);
233 
234  mainMatcher->SetBlockTransformType((typename BlockMatcherType::TransformDefinition) m_TransformKind);
235  mainMatcher->SetSimilarityType((typename BlockMatcherType::SimilarityDefinition) m_Metric);
236  mainMatcher->SetOptimizerType(BlockMatcherType::Bobyqa);
237 
238  bmreg->SetSVFElasticRegSigma(m_ElasticSigma * meanSpacing);
239 
240  bmreg->SetMaximumIterations(m_MaximumIterations);
241  mainMatcher->SetOptimizerMaximumIterations(m_OptimizerMaximumIterations);
242  mainMatcher->SetTransformDirection(m_TransformDirection);
243 
244  if (initialTransform)
245  bmreg->SetInitialTransform(initialTransform.GetPointer());
246 
247  bmreg->SetCurrentTransform(m_OutputTransform.GetPointer());
248 
249  mainMatcher->SetSearchRadius(m_SearchRadius);
250  mainMatcher->SetSearchScaleRadius(m_SearchScaleRadius);
251  mainMatcher->SetSearchSkewRadius(m_SearchSkewRadius);
252  mainMatcher->SetFinalRadius(m_FinalRadius);
253  mainMatcher->SetTranslateMax(m_TranlateUpperBound);
254  mainMatcher->SetScaleMax(m_ScaleUpperBound);
255  mainMatcher->SetSkewMax(m_SkewUpperBound);
256  mainMatcher->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
257 
258  bmreg->Update();
259 
260  const DisplacementFieldTransformType *resTrsf = dynamic_cast <const DisplacementFieldTransformType *> (bmreg->GetOutput()->Get());
261  m_OutputTransform->SetParametersAsVectorField(resTrsf->GetParametersAsVectorField());
262 
263  if (agregPtr)
264  delete agregPtr;
265  if (mainMatcher)
266  delete mainMatcher;
267  }
268 
269  if (m_LastPyramidLevel != 0)
270  {
271  // Resample output transform to go back to full resolution
272  typedef itk::ResampleImageFilter<VectorFieldType,VectorFieldType> VectorResampleFilterType;
273  typedef typename VectorResampleFilterType::Pointer VectorResampleFilterPointer;
274 
275  AffineTransformPointer tmpIdentity = AffineTransformType::New();
276  tmpIdentity->SetIdentity();
277 
278  VectorResampleFilterPointer tmpResample = VectorResampleFilterType::New();
279  tmpResample->SetTransform(tmpIdentity);
280  tmpResample->SetInput(m_OutputTransform->GetParametersAsVectorField());
281 
282  tmpResample->SetSize(m_BackwardImage->GetLargestPossibleRegion().GetSize());
283  tmpResample->SetOutputOrigin(m_BackwardImage->GetOrigin());
284  tmpResample->SetOutputSpacing(m_BackwardImage->GetSpacing());
285  tmpResample->SetOutputDirection(m_BackwardImage->GetDirection());
286 
287  tmpResample->Update();
288 
289  VectorFieldType *tmpOut = tmpResample->GetOutput();
290  m_OutputTransform->SetParametersAsVectorField(tmpOut);
291  tmpOut->DisconnectPipeline();
292  }
293 
294  typedef typename anima::ResampleImageFilter<InputImageType, InputImageType,
295  typename BaseAgregatorType::ScalarType> ResampleFilterType;
296 
297  DisplacementFieldTransformPointer oppositeTransform = DisplacementFieldTransformType::New();
298 
299  typedef itk::MultiplyImageFilter <VectorFieldType,itk::Image <double, ImageDimension>, VectorFieldType> MultiplyFilterType;
300  typename MultiplyFilterType::Pointer multiplyFilter = MultiplyFilterType::New();
301  multiplyFilter->SetInput(m_OutputTransform->GetParametersAsVectorField());
302  multiplyFilter->SetConstant(-1.0);
303  multiplyFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
304 
305  multiplyFilter->Update();
306 
307  oppositeTransform->SetParametersAsVectorField(multiplyFilter->GetOutput());
308 
309  if (m_InitialTransform)
310  anima::composeDistortionCorrections<typename BaseAgregatorType::ScalarType,ImageDimension>
311  (m_InitialTransform,m_OutputTransform,oppositeTransform,this->GetNumberOfWorkUnits());
312 
313  typename ResampleFilterType::Pointer tmpResampleFloating = ResampleFilterType::New();
314  tmpResampleFloating->SetTransform(m_OutputTransform);
315  tmpResampleFloating->SetInput(m_ForwardImage);
316 
317  tmpResampleFloating->SetSize(m_BackwardImage->GetLargestPossibleRegion().GetSize());
318  tmpResampleFloating->SetOutputOrigin(m_BackwardImage->GetOrigin());
319  tmpResampleFloating->SetOutputSpacing(m_BackwardImage->GetSpacing());
320  tmpResampleFloating->SetOutputDirection(m_BackwardImage->GetDirection());
321  tmpResampleFloating->SetDefaultPixelValue(0);
322  tmpResampleFloating->SetScaleIntensitiesWithJacobian(true);
323  tmpResampleFloating->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
324  tmpResampleFloating->Update();
325 
326  typename ResampleFilterType::Pointer tmpResampleReference = ResampleFilterType::New();
327  tmpResampleReference->SetTransform(oppositeTransform);
328  tmpResampleReference->SetInput(m_BackwardImage);
329 
330  tmpResampleReference->SetSize(m_BackwardImage->GetLargestPossibleRegion().GetSize());
331  tmpResampleReference->SetOutputOrigin(m_BackwardImage->GetOrigin());
332  tmpResampleReference->SetOutputSpacing(m_BackwardImage->GetSpacing());
333  tmpResampleReference->SetOutputDirection(m_BackwardImage->GetDirection());
334  tmpResampleReference->SetDefaultPixelValue(0);
335  tmpResampleReference->SetScaleIntensitiesWithJacobian(true);
336  tmpResampleReference->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
337 
338  tmpResampleReference->Update();
339 
340  typedef itk::AddImageFilter <InputImageType,InputImageType,InputImageType> AddFilterType;
341  typename AddFilterType::Pointer addFilter = AddFilterType::New();
342  addFilter->SetInput1(tmpResampleFloating->GetOutput());
343  addFilter->SetInput2(tmpResampleReference->GetOutput());
344  addFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
345 
346  addFilter->Update();
347 
348  typedef itk::MultiplyImageFilter <InputImageType,itk::Image <double, ImageDimension>,InputImageType> MultiplyScalarFilterType;
349  typename MultiplyScalarFilterType::Pointer multiplyScalarFilter = MultiplyScalarFilterType::New();
350  multiplyScalarFilter->SetInput(addFilter->GetOutput());
351  multiplyScalarFilter->SetConstant(0.5);
352  multiplyScalarFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
353  multiplyScalarFilter->InPlaceOn();
354 
355  multiplyScalarFilter->Update();
356 
357  m_OutputImage = multiplyScalarFilter->GetOutput();
358  m_OutputImage->DisconnectPipeline();
359 }
360 
361 template <unsigned int ImageDimension>
362 void
364 {
365  std::cout << "Writing output image to: " << m_resultFile << std::endl;
366  anima::writeImage<InputImageType>(m_resultFile,m_OutputImage);
367 
368  if (m_outputTransformFile != "")
369  {
370  std::cout << "Writing output transform to: " << m_outputTransformFile << std::endl;
371  anima::writeImage<VectorFieldType>(m_outputTransformFile, const_cast <VectorFieldType *> (m_OutputTransform->GetParametersAsVectorField()));
372  }
373 }
374 
375 template <unsigned int ImageDimension>
376 void
378 {
379  // Create pyramid here, check images actually are of the same size.
380  m_BackwardPyramid = PyramidType::New();
381 
382  m_BackwardPyramid->SetInput(m_BackwardImage);
383  m_BackwardPyramid->SetNumberOfLevels(m_NumberOfPyramidLevels);
384  m_BackwardPyramid->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
385 
386  typedef typename anima::ResampleImageFilter<InputImageType, InputImageType,
387  typename BaseAgregatorType::ScalarType> ResampleFilterType;
388 
389  typename ResampleFilterType::Pointer backwardResampler = ResampleFilterType::New();
390  m_BackwardPyramid->SetImageResampler(backwardResampler);
391 
392  m_BackwardPyramid->Update();
393 
394  // Create pyramid for Floating image
395  m_ForwardPyramid = PyramidType::New();
396 
397  m_ForwardPyramid->SetInput(m_ForwardImage);
398  m_ForwardPyramid->SetNumberOfLevels(m_NumberOfPyramidLevels);
399  m_ForwardPyramid->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
400 
401  typename ResampleFilterType::Pointer forwardResampler = ResampleFilterType::New();
402  m_ForwardPyramid->SetImageResampler(forwardResampler);
403 
404  m_ForwardPyramid->Update();
405 }
406 
407 } // end namespace anima
void SetGeometryInformation(const TInputImageType *geomImage)
void SetGeometryInformation(const TInputImageType *geomImage)
void SetOutputTransformType(TRANSFORM_TYPE name)
void SetBlockPercentageKept(double val)
rpi::DisplacementFieldTransform< typename BaseAgregatorType::ScalarType, ImageDimension > DisplacementFieldTransformType