ANIMA  4.0
animaBaseBMRegistrationMethod.hxx
Go to the documentation of this file.
1 #pragma once
3 
7 
8 #include <animaVelocityUtils.h>
9 #include <itkImageRegionIterator.h>
10 
11 namespace anima
12 {
13 
14 template <typename TInputImageType>
15 BaseBMRegistrationMethod <TInputImageType>
17 {
18  m_Abort = false;
19  m_FixedImage = 0;
20  m_MovingImage = 0;
21 
22  m_SVFElasticRegSigma = 0;
23  m_BCHCompositionOrder = 1;
24  m_ExponentiationOrder = 0;
25 
26  m_MaximumIterations = 10;
27  m_MinimalTransformError = 0.0001;
28 
29  m_ReferenceImageResampler = 0;
30  m_MovingImageResampler = 0;
31 
32  m_VerboseProgression = true;
33 
34  this->SetNumberOfWorkUnits(this->GetMultiThreader()->GetNumberOfWorkUnits());
35 
36  m_InitialTransform = 0;
37  this->SetNumberOfRequiredOutputs(1);
38  TransformOutputPointer transformDecorator = static_cast <TransformOutputType *> (this->MakeOutput(0).GetPointer());
39  this->itk::ProcessObject::SetNthOutput(0, transformDecorator.GetPointer());
40 }
41 
/**
 * Returns output 0 of the filter, i.e. the decorator holding the registration transform.
 * Output 0 is created in the constructor through MakeOutput(0) and replaced at the end of
 * StartOptimization() by a decorator wrapping the optimized transform.
 */
45 template <typename TInputImageType>
48 ::GetOutput()
49 {
50  return static_cast <TransformOutputType *> (this->ProcessObject::GetOutput(0)); // unchecked cast: relies on output 0 always being set as a TransformOutputType (constructor / StartOptimization)
51 }
52 
53 template <typename TInputImageType>
54 itk::DataObject::Pointer
56 ::MakeOutput(DataObjectPointerArraySizeType output)
57 {
58  switch (output)
59  {
60  case 0:
61  return static_cast <itk::DataObject*> (TransformOutputType::New().GetPointer());
62  break;
63  default:
64  itkExceptionMacro("MakeOutput request for an output number larger than the expected number of outputs");
65  return ITK_NULLPTR;
66  }
67 }
68 
72 template <typename TInputImageType>
73 void
75 ::GenerateData()
76 {
77  m_Abort = false;
78  this->StartOptimization();
79 }
80 
84 template <typename TInputImageType>
85 void
87 ::StartOptimization()
88 {
89  m_Agregator->SetInputTransformType(m_BlockMatcher->GetAgregatorInputTransformType());
90 
91  TransformPointer computedTransform = ITK_NULLPTR;
92  this->SetupTransform(computedTransform);
93 
94  //progress management
95  itk::ProgressReporter progress(this, 0, m_MaximumIterations);
96 
97  // Real work goes here
98  InputImagePointer fixedResampled, movingResampled;
99  for (unsigned int iterations = 0; iterations < m_MaximumIterations && !m_Abort; ++iterations)
100  {
101  // Resample fixed and moving image here
102  this->ResampleImages(computedTransform, fixedResampled, movingResampled);
103 
104  // Perform one iteration of registration between the images
105  // Calls pure virtual method that can use the block matching class available here
106  this->GetAgregator()->SetCurrentLinearTransform(computedTransform);
107  TransformPointer addOn;
108  this->PerformOneIteration(fixedResampled, movingResampled, addOn);
109 
110  bool continueLoop = this->ComposeAddOnWithTransform(computedTransform,addOn);
111 
112  if (m_VerboseProgression)
113  std::cout << "Iteration " << iterations << " done..." << std::endl;
114 
115  if (iterations != m_MaximumIterations - 1)
116  progress.CompletedPixel();
117 
118  if (!continueLoop)
119  break;
120  }
121 
122  TransformOutputPointer transformDecorator = TransformOutputType::New();
123  transformDecorator->Set(computedTransform.GetPointer());
124 
125  this->itk::ProcessObject::SetNthOutput(0, transformDecorator.GetPointer());
126 }
127 
128 template <typename TInputImageType>
129 void
131 ::SetupTransform(TransformPointer &optimizedTransform)
132 {
133  if (m_Agregator->GetOutputTransformType() != AgregatorType::SVF)
134  {
135  if (m_InitialTransform)
136  {
137  optimizedTransform = AffineTransformType::New();
138  optimizedTransform->SetParameters(m_InitialTransform->GetParameters());
139  }
140  else
141  {
142  typename AffineTransformType::Pointer tmpTrsf = AffineTransformType::New();
143  tmpTrsf->SetIdentity();
144  optimizedTransform = tmpTrsf;
145  }
146  }
147  else
148  {
149  if (m_InitialTransform)
150  optimizedTransform = m_InitialTransform;
151  else
152  {
153  SVFTransformPointer tmpTrsf = SVFTransformType::New();
154  tmpTrsf->SetIdentity();
155  optimizedTransform = tmpTrsf;
156  }
157  }
158 }
159 
/**
 * Resamples the moving and fixed images according to the current transform estimate.
 *
 * Moving image: resampled by m_MovingImageResampler with currentTransform. For SVF
 * output, this uses the displacement field from anima::GetSVFExponential with
 * invert = false.
 *
 * Fixed image: resampled by m_ReferenceImageResampler with the inverse transform. For
 * SVF output, this uses GetSVFExponential with invert = true; otherwise it uses the
 * affine GetInverseTransform().
 *
 * Multi-component ("Model") and scalar ("Anatomical") images take separate branches.
 * InternalFilterType is not declared in any visible line of these branches; presumably
 * each branch typedefs a different resampler type -- confirm against the original source.
 *
 * @param currentTransform current estimate (SVF or affine, depending on the agregator output type)
 * @param refImage [out] resampled fixed image, disconnected from the pipeline
 * @param movingImage [out] resampled moving image, disconnected from the pipeline
 *
 * NOTE(review): none of the dynamic_cast results (resampleFilter, svfCast, affCast) are
 * checked for null, so an unexpected resampler or transform type would dereference null.
 * The four branches are near-duplicates, and in the SVF case the exponential is
 * computed twice per call (once direct, once inverted).
 */
160 template <typename TInputImageType>
161 void
163 ::ResampleImages(TransformType *currentTransform, InputImagePointer &refImage, InputImagePointer &movingImage)
164 {
165  // Moving image resampling
166  if (m_MovingImage->GetNumberOfComponentsPerPixel() > 1)
167  {
168  // Model resampling
170  InternalFilterType *resampleFilter = dynamic_cast <InternalFilterType *> (m_MovingImageResampler.GetPointer());
171 
172  if (m_Agregator->GetOutputTransformType() == AgregatorType::SVF)
173  {
174  // Compute temporary field and set it to resampler
175  DisplacementFieldTransformPointer dispTrsf = DisplacementFieldTransformType::New();
176  SVFTransformType *svfCast = dynamic_cast<SVFTransformType *> (currentTransform);
177 
178  anima::GetSVFExponential(svfCast,dispTrsf.GetPointer(),m_ExponentiationOrder,this->GetNumberOfWorkUnits(),false); // invert = false: forward field for the moving image
179 
180  resampleFilter->SetTransform(dispTrsf);
181  }
182  else
183  resampleFilter->SetTransform(currentTransform);
184  }
185  else
186  {
187  // Anatomical resampling
188  typedef itk::Image <ImageScalarType, TInputImageType::ImageDimension> InternalScalarImageType; // not used by any visible line; presumably used by the InternalFilterType typedef
190  InternalFilterType *resampleFilter = dynamic_cast <InternalFilterType *> (m_MovingImageResampler.GetPointer());
191 
192  if (m_Agregator->GetOutputTransformType() == AgregatorType::SVF)
193  {
194  // Compute temporary field and set it to resampler
195  DisplacementFieldTransformPointer dispTrsf = DisplacementFieldTransformType::New();
196  SVFTransformType *svfCast = dynamic_cast<SVFTransformType *> (currentTransform);
197 
198  anima::GetSVFExponential(svfCast,dispTrsf.GetPointer(),m_ExponentiationOrder,this->GetNumberOfWorkUnits(),false);
199 
200  resampleFilter->SetTransform(dispTrsf);
201  }
202  else
203  resampleFilter->SetTransform(currentTransform);
204  }
205 
206  m_MovingImageResampler->SetInput(m_MovingImage);
207  m_MovingImageResampler->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
208  m_MovingImageResampler->Update();
209 
210  movingImage = m_MovingImageResampler->GetOutput();
211  movingImage->DisconnectPipeline(); // detach so the next Update() of the resampler does not overwrite this image
212 
213  // Fixed image resampling
214  if (m_FixedImage->GetNumberOfComponentsPerPixel() > 1)
215  {
216  // Model resampling
218  InternalFilterType *resampleFilter = dynamic_cast <InternalFilterType *> (m_ReferenceImageResampler.GetPointer());
219 
220  if (m_Agregator->GetOutputTransformType() == AgregatorType::SVF)
221  {
222  // Compute temporary field and set it to resampler
223  DisplacementFieldTransformPointer dispTrsf = DisplacementFieldTransformType::New();
224  SVFTransformType *svfCast = dynamic_cast<SVFTransformType *> (currentTransform);
225 
226  anima::GetSVFExponential(svfCast,dispTrsf.GetPointer(),m_ExponentiationOrder,this->GetNumberOfWorkUnits(),true); // invert = true: the fixed image is resampled with the inverse field
227 
228  resampleFilter->SetTransform(dispTrsf);
229  }
230  else
231  {
232  AffineTransformType *affCast = dynamic_cast<AffineTransformType *> (currentTransform);
233  TransformPointer reverseTrsf = affCast->GetInverseTransform();
234  resampleFilter->SetTransform(reverseTrsf);
235  }
236  }
237  else
238  {
239  // Anatomical resampling
240  typedef itk::Image <ImageScalarType, TInputImageType::ImageDimension> InternalScalarImageType; // not used by any visible line; presumably used by the InternalFilterType typedef
242  InternalFilterType *resampleFilter = dynamic_cast <InternalFilterType *> (m_ReferenceImageResampler.GetPointer());
243 
244  if (m_Agregator->GetOutputTransformType() == AgregatorType::SVF)
245  {
246  // Compute temporary field and set it to resampler
247  DisplacementFieldTransformPointer dispTrsf = DisplacementFieldTransformType::New();
248  SVFTransformType *svfCast = dynamic_cast<SVFTransformType *> (currentTransform);
249 
250  anima::GetSVFExponential(svfCast,dispTrsf.GetPointer(),m_ExponentiationOrder,this->GetNumberOfWorkUnits(),true);
251 
252  resampleFilter->SetTransform(dispTrsf);
253  }
254  else
255  {
256  AffineTransformType *affCast = dynamic_cast<AffineTransformType *> (currentTransform);
257  TransformPointer reverseTrsf = affCast->GetInverseTransform();
258  resampleFilter->SetTransform(reverseTrsf);
259  }
260  }
261 
262  m_ReferenceImageResampler->SetInput(m_FixedImage);
263  m_ReferenceImageResampler->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
264  m_ReferenceImageResampler->Update();
265 
266  refImage = m_ReferenceImageResampler->GetOutput();
267  refImage->DisconnectPipeline();
268 }
269 
/**
 * Composes the update addOn into computedTransform (modified in place) and tests convergence.
 *
 * Linear case (output type other than SVF):
 * - addOn is composed into the affine transform with Compose(addOn, true);
 * - for ANISOTROPIC_SIM, computedTransform instead takes addOn's parameters directly;
 * - convergence is declared when the squared Euclidean distance between the parameter
 *   vectors before and after the update is <= m_MinimalTransformError.
 *
 * SVF case:
 * - addOn is merged into computedTransform with anima::composeSVF (BCH order
 *   m_BCHCompositionOrder);
 * - convergence is declared when every voxel of the update field has a squared norm
 *   <= m_MinimalTransformError;
 * - otherwise, if m_SVFElasticRegSigma > 0, the composed velocity field is smoothed with
 *   that sigma. On the converging iteration no smoothing is applied.
 *
 * @param computedTransform current transform, updated in place
 * @param addOn update produced by the last iteration
 * @return false when converged (the caller then stops iterating), true otherwise
 *
 * NOTE(review): none of the dynamic_cast results are checked for null.
 */
270 template <typename TInputImageType>
271 bool
273 ::ComposeAddOnWithTransform(TransformPointer &computedTransform, TransformType *addOn)
274 {
275  if (m_Agregator->GetOutputTransformType() != AgregatorType::SVF)
276  {
277  typename TransformType::ParametersType oldPars = computedTransform->GetParameters();
278 
279  if (m_Agregator->GetOutputTransformType() != AgregatorType::ANISOTROPIC_SIM)
280  {
281  AffineTransformType *tmpTrsf = dynamic_cast<AffineTransformType *>(computedTransform.GetPointer());
282  AffineTransformType *tmpAddOn = dynamic_cast<AffineTransformType *>(addOn);
283  tmpTrsf->Compose(tmpAddOn, true); // pre = true: per ITK Compose docs, tmpAddOn is applied first, then tmpTrsf -- confirm
284  }
285  else
286  {
287  // no composition since we compute global transform only
288  computedTransform->SetParameters(addOn->GetParameters());
289  }
290 
291  typename TransformType::ParametersType newPars = computedTransform->GetParameters();
292 
293  // Compute the distance between consecutive solutions, until a certain threshold
294  double err = 0; // squared distance (no square root is taken)
295  for (unsigned int i = 0; i < newPars.Size(); ++i)
296  err += std::pow(newPars[i] - oldPars[i], 2.); // NOTE(review): a plain product of the difference with itself would avoid std::pow
297 
298  if (err <= m_MinimalTransformError)
299  return false;
300  }
301  else
302  {
303  // Add update to current velocity field (cf. Vercauteren et al, 2008)
304  SVFTransformType *tmpTrsf = dynamic_cast<SVFTransformType *>(computedTransform.GetPointer());
305  SVFTransformType *tmpAddOn = dynamic_cast<SVFTransformType *>(addOn);
306 
307  anima::composeSVF(tmpTrsf,tmpAddOn,this->GetNumberOfWorkUnits(),m_BCHCompositionOrder); // void return: presumably updates tmpTrsf in place -- confirm
308 
309  typedef typename SVFTransformType::VectorFieldType VectorFieldType;
310  typedef itk::ImageRegionConstIterator <VectorFieldType> IteratorType;
311  IteratorType diffItr(tmpAddOn->GetParametersAsVectorField(),
312  tmpAddOn->GetParametersAsVectorField()->GetLargestPossibleRegion());
313 
314  bool smallEnoughTransform = true;
315  while (!diffItr.IsAtEnd())
316  {
317  double err = 0; // squared norm of the update vector at this voxel
318  for (unsigned int i = 0;i < VectorFieldType::ImageDimension;++i)
319  err += diffItr.Get()[i] * diffItr.Get()[i]; // NOTE(review): Get() is called twice per component; caching the pixel would avoid repeated lookups
320 
321  if (err > m_MinimalTransformError)
322  {
323  smallEnoughTransform = false;
324  break;
325  }
326 
327  ++diffItr;
328  }
329 
330  if (smallEnoughTransform)
331  return false;
332 
333  if (m_SVFElasticRegSigma > 0)
334  {
335  typedef typename SVFTransformType::VectorFieldType VelocityFieldType; // same type as VectorFieldType above
336 
// SmoothingFilterType is not declared in any visible line of this listing (the doxygen numbering skips 337)
338  typename SmoothingFilterType::Pointer smootherPtr = SmoothingFilterType::New();
339 
340  smootherPtr->SetInput(tmpTrsf->GetParametersAsVectorField());
341  smootherPtr->SetSigma(m_SVFElasticRegSigma);
342  smootherPtr->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
343 
344  smootherPtr->Update();
345 
346  typename VelocityFieldType::Pointer tmpSmoothed = smootherPtr->GetOutput();
347  tmpSmoothed->DisconnectPipeline();
348 
349  tmpTrsf->SetParametersAsVectorField(tmpSmoothed);
350  }
351  }
352 
353  return true;
354 }
355 
359 template <typename TInputImageType>
360 void
362 ::PrintSelf(std::ostream& os, itk::Indent indent) const
363 {
364  Superclass::PrintSelf( os, indent );
365 
366  os << indent << "Fixed Image: " << m_FixedImage.GetPointer() << std::endl;
367  os << indent << "Moving Image: " << m_MovingImage.GetPointer() << std::endl;
368 
369  os << indent << "Maximum Iterations: " << m_MaximumIterations << std::endl;
370 }
371 
372 } // end namespace anima
TransformOutputType::Pointer TransformOutputPointer
DisplacementFieldTransformType::Pointer DisplacementFieldTransformPointer
AgregatorType::BaseOutputTransformType TransformType
itk::AffineTransform< typename AgregatorType::ScalarType, TInputImageType::ImageDimension > AffineTransformType
itk::DataObjectDecorator< TransformType > TransformOutputType
void GetSVFExponential(itk::StationaryVelocityFieldTransform< ScalarType, NDimensions > *baseTrsf, rpi::DisplacementFieldTransform< ScalarType, NDimensions > *resultTransform, unsigned int exponentiationOrder, unsigned int numThreads, bool invert)
void composeSVF(itk::StationaryVelocityFieldTransform< ScalarType, NDimensions > *baseTrsf, itk::StationaryVelocityFieldTransform< ScalarType, NDimensions > *addonTrsf, unsigned int numThreads, unsigned int bchOrder)
itk::StationaryVelocityFieldTransform< AgregatorScalarType, TInputImageType::ImageDimension > SVFTransformType
itk::ProcessObject::DataObjectPointerArraySizeType DataObjectPointerArraySizeType