ANIMA  4.0
animaPyramidalSymmetryBridge.hxx
Go to the documentation of this file.
1 #pragma once
3 
5 #include <itkTransformFileWriter.h>
6 #include <itkMultiResolutionPyramidImageFilter.h>
7 #include <itkImageRegistrationMethod.h>
8 #include <itkLinearInterpolateImageFunction.h>
9 #include <animaNLOPTOptimizers.h>
10 
11 #include <itkMeanSquaresImageToImageMetric.h>
12 #include <itkMutualInformationHistogramImageToImageMetric.h>
13 #include <itkImageMomentsCalculator.h>
14 #include <itkProgressReporter.h>
15 #include <itkMinimumMaximumImageFilter.h>
16 
17 #include <animaVectorOperations.h>
19 
20 namespace anima
21 {
22 
// Runs the whole symmetry-plane estimation:
//   1. builds the image pyramids (SetupPyramids),
//   2. initialises m_OutputTransform with all-zero parameters and a rotation centre placed at
//      the centre of gravity of m_ReferenceImage,
//   3. for each pyramid level, registers the floating level onto the reference level with an
//      NLOPT BOBYQA optimizer under box constraints, seeding each level with the previous result,
//   4. calls ComputeRealignTransform to fill m_OutputRealignTransform,
//   5. resamples m_FloatingImage on the grid of m_ReferenceImage into m_OutputImage.
// NOTE(review): the line carrying the function signature is missing from this listing
// (presumably the class's Update() method) - confirm against the original file.
23 template <class PixelType, typename ScalarType>
25 {
26  typedef typename itk::ImageRegistrationMethod<OutputImageType, OutputImageType> RegistrationType;
27 
28  //progress management
    // One progress unit is reported per pyramid level (see CompletedPixel() at the end of the level loop).
29  itk::ProgressReporter progress(this, 0, GetNumberOfPyramidLevels());
30 
31  if(m_progressCallback)
32  {
33  this->AddObserver ( itk::ProgressEvent(), m_progressCallback );
34  }
35 
36  this->SetupPyramids();
37 
38  typename InputImageType::PointType centralPoint;
39 
40  typedef typename itk::ImageMomentsCalculator <InputImageType> ImageMomentsType;
41 
42  typename ImageMomentsType::Pointer momentsCalculator = ImageMomentsType::New();
43 
44  momentsCalculator->SetImage( m_ReferenceImage );
45  momentsCalculator->Compute();
46 
    // Centre of gravity of the reference image, used as rotation centre of the estimated transform.
47  itk::Vector <double,InputImageType::ImageDimension> centralVector = momentsCalculator->GetCenterOfGravity();
48 
49  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
50  centralPoint[i] = centralVector[i];
51 
52  typename TransformType::ParametersType initialParams(TransformType::ParametersDimension);
53 
54  // Here should come test on orientation matrix -> find the true left/right axis
55  for (unsigned int i = 0;i < TransformType::ParametersDimension;++i)
56  initialParams[i] = 0;
57 
58  m_OutputTransform->SetParameters(initialParams);
59  m_OutputTransform->SetRotationCenter(centralPoint);
60 
61  unsigned int dimension = m_OutputTransform->GetNumberOfParameters();
62  itk::Array<double> lowerBounds(dimension);
63  itk::Array<double> upperBounds(dimension);
64 
    // Parameters 0 and 1 are bounded by +/- GetUpperBoundAngle(); parameter 2 is bounded per level below.
    // NOTE(review): this assumes the transform has exactly three parameters (presumably two angles
    // and one distance); any further entry of the bound arrays would stay unset - confirm.
65  for (unsigned int i = 0;i < 2;++i)
66  {
67  lowerBounds[i] = - GetUpperBoundAngle();
68  upperBounds[i] = GetUpperBoundAngle();
69  }
70 
71  // Iterate over pyramid levels
72  for (int i = 0;i < GetNumberOfPyramidLevels();++i)
73  {
74  std::cout << "Processing pyramid level " << i << std::endl;
75  std::cout << "Image size: " << m_ReferencePyramid->GetOutput(i)->GetLargestPossibleRegion().GetSize() << std::endl;
76 
77  // Init matcher
78  typename RegistrationType::Pointer reg = RegistrationType::New();
79 
80  reg->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
81 
82  typedef anima::NLOPTOptimizers OptimizerType;
83  typename OptimizerType::Pointer optimizer = OptimizerType::New();
84 
85  optimizer->SetAlgorithm(NLOPT_LN_BOBYQA);
86  optimizer->SetXTolRel(1.0e-4);
87  optimizer->SetFTolRel(1.0e-6);
88  optimizer->SetMaxEval(GetOptimizerMaxIterations());
89  optimizer->SetVectorStorageSize(2000);
    // Mean squares is minimised; any other metric value is maximised.
90  optimizer->SetMaximize(GetMetric() != MeanSquares);
91 
    // NOTE(review): despite its name, meanSpacing is the SUM of the spacings of this level;
    // it is never divided by the image dimension. Confirm whether the sum is intended.
92  double meanSpacing = 0;
93  for (unsigned int j = 0;j < InputImageType::ImageDimension;++j)
94  meanSpacing += m_ReferencePyramid->GetOutput(i)->GetSpacing()[j];
95 
    // Parameter 2 may move by at most meanSpacing * GetUpperBoundDistance() around its current value.
96  lowerBounds[2] = m_OutputTransform->GetParameters()[2] - meanSpacing * GetUpperBoundDistance();
97  upperBounds[2] = m_OutputTransform->GetParameters()[2] + meanSpacing * GetUpperBoundDistance();
98 
99  optimizer->SetLowerBoundParameters(lowerBounds);
100  optimizer->SetUpperBoundParameters(upperBounds);
101 
102  reg->SetOptimizer(optimizer);
103  reg->SetTransform(m_OutputTransform);
104 
105  typedef itk::LinearInterpolateImageFunction <OutputImageType, double> InterpolatorType;
106  typename InterpolatorType::Pointer interpolator = InterpolatorType::New();
107 
108  reg->SetInterpolator(interpolator);
109 
    // Similarity metric: histogram-based mutual information, or mean squares for every other value.
110  switch (GetMetric())
111  {
112  case MutualInformation:
113  {
114  typedef itk::MutualInformationHistogramImageToImageMetric < OutputImageType,OutputImageType > MetricType;
115  typename MetricType::Pointer tmpMetric = MetricType::New();
116 
    // Joint histogram with GetHistogramSize() bins along each of its two axes.
117  typename MetricType::HistogramType::SizeType histogramSize;
118  histogramSize.SetSize(2);
119 
120  histogramSize[0] = GetHistogramSize();
121  histogramSize[1] = GetHistogramSize();
122  tmpMetric->SetHistogramSize( histogramSize );
123 
124  reg->SetMetric(tmpMetric);
125  break;
126  }
127  case MeanSquares:
128  default:
129  {
130  typedef itk::MeanSquaresImageToImageMetric < OutputImageType,OutputImageType > MetricType;
131  typename MetricType::Pointer tmpMetric = MetricType::New();
132  reg->SetMetric(tmpMetric);
133  break;
134  }
135  }
136 
137  reg->SetFixedImage(m_ReferencePyramid->GetOutput(i));
138  reg->SetMovingImage(m_FloatingPyramid->GetOutput(i));
139 
140  reg->SetFixedImageRegion(m_ReferencePyramid->GetOutput(i)->GetLargestPossibleRegion());
    // Each level starts from the parameters found at the previous level (zeros for the first level).
141  reg->SetInitialTransformParameters(m_OutputTransform->GetParameters());
142 
143  try
144  {
145  reg->Update();
146  }
147  catch( itk::ExceptionObject & err )
148  {
149  std::cout << "ExceptionObject caught ! " << err << std::endl;
    // NOTE(review): "throw err;" throws a copy typed as itk::ExceptionObject, which slices any
    // derived exception; a bare "throw;" would rethrow the original object.
150  throw err;
151  }
152 
153  progress.CompletedPixel();
154  m_OutputTransform->SetParameters(reg->GetLastTransformParameters());
155  }
156 
157  // Now compute the transform to bring the image back onto its symmetry plane
158 
    // Physical position of the centre of the reference image grid (continuous index (size - 1) / 2).
159  itk::ContinuousIndex <ScalarType,InputImageType::ImageDimension> imageCenter;
160 
161  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
162  imageCenter[i] = (m_ReferenceImage->GetLargestPossibleRegion().GetSize()[i] - 1.0) / 2.0;
163 
164  typename InputImageType::PointType centerReal;
165  m_ReferenceImage->TransformContinuousIndexToPhysicalPoint(imageCenter,centerReal);
166 
    // NOTE(review): directionVox is filled below but never read afterwards in this function.
167  std::vector <double> directionVox(InputImageType::ImageDimension, 0);
168  std::vector <double> directionReal(InputImageType::ImageDimension, 0);
169  std::vector <double> directionSpherical(InputImageType::ImageDimension, 0);
170 
171  //First use real direction to find the real X direction
    // Picks the column of the direction matrix with the largest absolute entry on row 0,
    // i.e. the voxel axis that contributes most to the physical X axis.
172  unsigned int indexAbsMax = 0;
173  typename InputImageType::DirectionType dirMatrix = m_ReferenceImage->GetDirection();
174  double valMax = std::abs(dirMatrix(0,0));
175  for (unsigned int i = 1;i < InputImageType::ImageDimension;++i)
176  {
177  if (std::abs(dirMatrix(0,i)) > valMax)
178  {
179  valMax = std::abs(dirMatrix(0,i));
180  indexAbsMax = i;
181  }
182  }
183 
184  // Now redo it with the real X-direction
185  directionVox[indexAbsMax] = 1;
186 
    // Physical direction of that voxel axis: the corresponding column of the direction matrix.
187  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
188  directionReal[i] = dirMatrix(i,indexAbsMax);
189 
190  anima::TransformCartesianToSphericalCoordinates(directionReal,directionSpherical);
191 
    // Parameters of the image mid-plane transform, derived from the spherical angles of that
    // direction, with a zero third parameter.
    // NOTE(review): the exact angle convention depends on TransformCartesianToSphericalCoordinates
    // and on the transform parameterisation, neither of which is visible here.
192  initialParams.Fill(0);
193  initialParams[0] = M_PI / 2.0 - directionSpherical[0];
194  initialParams[1] = directionSpherical[1];
195  initialParams[2] = 0;
196 
197  this->ComputeRealignTransform(centralVector,centerReal,initialParams);
198 
199  // Compute output image
    // The floating image is resampled through m_OutputRealignTransform on the reference image grid.
200  typedef typename anima::ResampleImageFilter<InputImageType, OutputImageType> ResampleFilterType;
201  typename ResampleFilterType::Pointer tmpResample = ResampleFilterType::New();
202  tmpResample->SetTransform(m_OutputRealignTransform);
203  tmpResample->SetInput(m_FloatingImage);
204 
205  tmpResample->SetSize(m_ReferenceImage->GetLargestPossibleRegion().GetSize());
206  tmpResample->SetOutputOrigin(m_ReferenceImage->GetOrigin());
207  tmpResample->SetOutputSpacing(m_ReferenceImage->GetSpacing());
208  tmpResample->SetOutputDirection(m_ReferenceImage->GetDirection());
209 
    // The minimum of the reference image selects the background value used outside the field of view.
210  using MinMaxFilterType = itk::MinimumMaximumImageFilter <InputImageType>;
211  typename MinMaxFilterType::Pointer minMaxFilter = MinMaxFilterType::New();
212  minMaxFilter->SetInput(m_ReferenceImage);
213  if (this->GetNumberOfWorkUnits() != 0)
214  minMaxFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
215  minMaxFilter->Update();
216 
217  double minValue = minMaxFilter->GetMinimum();
218 
    // Images containing negative values get -1024 as background (presumably a CT air value - confirm),
    // all other images get 0.
219  if (minValue < 0.0)
220  tmpResample->SetDefaultPixelValue(-1024.0);
221  else
222  tmpResample->SetDefaultPixelValue(0.0);
223 
224  tmpResample->Update();
225 
    // Keep the resampled image and detach it from the local resampling filter.
226  m_OutputImage = tmpResample->GetOutput();
227  m_OutputImage->DisconnectPipeline();
228 }
229 
230 template <class PixelType, typename ScalarType>
231 void PyramidalSymmetryBridge<PixelType,ScalarType>::ComputeRealignTransform(typename itk::Vector <double,InputImageType::ImageDimension> centralPoint,
232  typename InputImageType::PointType &centerReal, ParametersType &imageParams)
233 {
234  TransformPointer imageMidPlaneTrsf = TransformType::New();
235  imageMidPlaneTrsf->SetParameters(imageParams);
236  imageMidPlaneTrsf->SetRotationCenter(centerReal);
237 
238  typedef itk::Matrix <ScalarType,InputImageType::ImageDimension+1,InputImageType::ImageDimension+1> HomogeneousMatrixType;
239 
240  HomogeneousMatrixType midPlaneMatrix;
241  midPlaneMatrix.SetIdentity();
242 
243  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
244  {
245  midPlaneMatrix(i,InputImageType::ImageDimension) = imageMidPlaneTrsf->GetOffset()[i];
246 
247  for (unsigned int j = 0;j < InputImageType::ImageDimension;++j)
248  midPlaneMatrix(i,j) = imageMidPlaneTrsf->GetMatrix()(i,j);
249  }
250 
251  HomogeneousMatrixType outputTrsfMatrix;
252  outputTrsfMatrix.SetIdentity();
253 
254  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
255  {
256  outputTrsfMatrix(i,InputImageType::ImageDimension) = m_OutputTransform->GetOffset()[i];
257 
258  for (unsigned int j = 0;j < InputImageType::ImageDimension;++j)
259  outputTrsfMatrix(i,j) = m_OutputTransform->GetMatrix()(i,j);
260  }
261 
262  HomogeneousMatrixType matrixComposition = midPlaneMatrix * outputTrsfMatrix;
263 
264  double theta_x, theta_y, theta_z;
265  double sinus, cosinus, phi, f;
266 
267  // note : le vecteur rotation obtenu est toujours dans le plan
268  // au milieu de l'image, donc une des 3 composantes est forcement nulle
269  // ce qui donne l'idee d'une autre parametrisation du plan de symetrie
270 
271  theta_x = - matrixComposition[1][2] + matrixComposition[2][1]; //X = R(2,3) - R(3,2)
272  theta_y = - matrixComposition[2][0] + matrixComposition[0][2]; //Y = R(3,1) - R(1,3)
273  theta_z = - matrixComposition[0][1] + matrixComposition[1][0]; //Z = R(1,2) - R(2,1)
274 
275  sinus = std::sqrt(theta_x*theta_x + theta_y*theta_y + theta_z*theta_z);
276 
277  if (std::abs(sinus) > 1e-9)
278  {
279  cosinus = matrixComposition[0][0] + matrixComposition[1][1] + matrixComposition[2][2] - 1;
280  phi = std::atan(sinus / cosinus);
281  f = phi / sinus;
282  theta_x = theta_x * f;
283  theta_y = theta_y * f;
284  theta_z = theta_z * f;
285  }
286  else
287  {
288  phi = 0;
289 
290  theta_x = 0;
291  theta_y = 0;
292  theta_z = 0;
293  }
294 
295  // le vecteur rotation de r est celui de R divisé par 2
296 
297  phi = phi / 2.0;
298  theta_x = theta_x / 2.0;
299  theta_y = theta_y / 2.0;
300  theta_z = theta_z / 2.0;
301 
302  // formules : http://www.iau-sofa.rl.ac.uk/2003_0429/sofa/rv2m.for
303 
304  // calcul du vecteur unitaire ayant la même direction que le vecteur rotation
305 
306  double u_x, u_y, u_z;
307 
308  sinus = std::sqrt(theta_x*theta_x + theta_y*theta_y + theta_z*theta_z);
309 
310  if (std::abs(sinus) > 1e-9)
311  {
312  u_x = theta_x / sinus;
313  u_y = theta_y / sinus;
314  u_z = theta_z / sinus;
315  }
316 
317  else {
318  u_x = 0;
319  u_y = 0;
320  u_z = 0;
321  }
322 
323  // calcul de r
324 
325  double c, s, t;
326 
327  c = std::cos(phi);
328  s = std::sin(phi);
329  t = 1 - c;
330 
331  HomogeneousMatrixType sqrtMatrix;
332  sqrtMatrix.SetIdentity();
333 
334  sqrtMatrix[0][0] = t*u_x*u_x + c;
335  sqrtMatrix[0][1] = t*u_x*u_y - s*u_z;
336  sqrtMatrix[0][2] = t*u_x*u_z + s*u_y;
337  sqrtMatrix[1][0] = t*u_x*u_y + s*u_z;
338  sqrtMatrix[1][1] = t*u_y*u_y + c;
339  sqrtMatrix[1][2] = t*u_y*u_z - s*u_x;
340  sqrtMatrix[2][0] = t*u_x*u_z - s*u_y;
341  sqrtMatrix[2][1] = t*u_y*u_z + s*u_x;
342  sqrtMatrix[2][2] = t*u_z*u_z + c;
343 
344  // calcul de transfo5 = r+I
345 
346  MatrixType tmpMatrix;
347 
348  for (unsigned int i=0; i<3; i++)
349  for (unsigned int j=0; j<3; j++)
350  tmpMatrix[i][j] = sqrtMatrix[i][j];
351 
352  for (unsigned int i=0; i<3; i++)
353  tmpMatrix[i][i]++;
354 
355  // calcul de transfo6 = (r+I)^{-1}
356 
357  tmpMatrix = tmpMatrix.GetInverse();
358 
359  // calcul de t = (r+I)^{-1}.T
360  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
361  for (unsigned int j = 0;j < InputImageType::ImageDimension;++j)
362  sqrtMatrix[i][3] += tmpMatrix[i][j] * matrixComposition[j][3];
363 
364  itk::Vector <double,InputImageType::ImageDimension+1> centralPointHom, centralPointTr;
365  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
366  centralPointHom[i] = centralPoint[i];
367  centralPointHom[InputImageType::ImageDimension] = 1;
368 
369  centralPointTr = sqrtMatrix * centralPointHom;
370 
371  sqrtMatrix = sqrtMatrix.GetInverse();
372 
373  HomogeneousMatrixType trToCenterMatrix;
374  trToCenterMatrix.SetIdentity();
375 
376  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
377  trToCenterMatrix(i,3) = centralPointTr[i] - centerReal[i];
378 
379  sqrtMatrix = sqrtMatrix * trToCenterMatrix;
380 
381  MatrixType trsfMatrix;
382  OffsetType offsetVector;
383  for (unsigned int i = 0;i < InputImageType::ImageDimension;++i)
384  {
385  offsetVector[i] = sqrtMatrix(i,3);
386  for (unsigned int j=0; j<3; j++)
387  trsfMatrix(i,j) = sqrtMatrix(i,j);
388  }
389 
390  if (m_OutputRealignTransform.IsNull())
391  m_OutputRealignTransform = BaseTransformType::New();
392 
393  m_OutputRealignTransform->SetMatrix(trsfMatrix);
394  m_OutputRealignTransform->SetOffset(offsetVector);
395 }
396 
// Writes every configured output by calling, in this order, SaveResultFile(),
// SaveRealignTransformFile() and SaveTransformFile().
// NOTE(review): the line carrying the function signature is missing from this listing.
397 template <class PixelType, typename ScalarType>
399 {
400  SaveResultFile();
401 
402  SaveRealignTransformFile();
403 
404  SaveTransformFile();
405 }
406 
407 
// Writes m_OutputImage to the file named by GetResultfile(); does nothing when that name is empty.
// NOTE(review): the line carrying the function signature is missing from this listing.
408 template <class PixelType, typename ScalarType>
410 {
411  if (GetResultfile() != "")
412  {
413  std::cout << "Writing output image to: " << GetResultfile() << std::endl;
414  anima::writeImage <InputImageType> (GetResultfile(),m_OutputImage);
415  }
416 }
417 
418 
// Writes the estimated transform to the file named by GetOutputTransformFile(); does nothing
// when that name is empty. The matrix and offset of m_OutputTransform are copied into a
// BaseTransformType instance, and that copy is what gets written.
// NOTE(review): the line carrying the function signature is missing from this listing.
419 template <class PixelType, typename ScalarType>
421 {
422  if (GetOutputTransformFile() != "")
423  {
424  std::cout << "Writing output transform to: " << GetOutputTransformFile() << std::endl;
425  itk::TransformFileWriter::Pointer writer = itk::TransformFileWriter::New();
426 
427  // SymmetryPlaneTransforms should not be written as is, this loses information
428  typename BaseTransformType::Pointer tmpTrsf = BaseTransformType::New();
429  tmpTrsf->SetMatrix(m_OutputTransform->GetMatrix());
430  tmpTrsf->SetOffset(m_OutputTransform->GetOffset());
431 
432  writer->SetInput(tmpTrsf);
433  writer->SetFileName(GetOutputTransformFile());
434  writer->Update();
435  }
436 }
437 
// Writes m_OutputRealignTransform to the file named by GetOutputRealignTransformFile();
// does nothing when that name is empty.
// NOTE(review): the line carrying the function signature is missing from this listing.
438 template <class PixelType, typename ScalarType>
440 {
441  if (GetOutputRealignTransformFile() != "")
442  {
443  std::cout << "Writing output realign transform to: " << GetOutputRealignTransformFile() << std::endl;
444  itk::TransformFileWriter::Pointer writer = itk::TransformFileWriter::New();
445  writer->SetInput(m_OutputRealignTransform);
446  writer->SetFileName(GetOutputRealignTransformFile());
447  writer->Update();
448  }
449 }
450 
// Builds and updates the two image pyramids used by the registration loop: m_ReferencePyramid
// from m_ReferenceImage and m_FloatingPyramid from m_FloatingImage, each with
// GetNumberOfPyramidLevels() levels.
// NOTE(review): the line carrying the function signature is missing from this listing.
451 template <class PixelType, typename ScalarType>
453 {
    // Create pyramid for reference image
454  m_ReferencePyramid = PyramidType::New();
455 
456  m_ReferencePyramid->SetInput(m_ReferenceImage);
457  m_ReferencePyramid->SetNumberOfLevels(GetNumberOfPyramidLevels());
458 
459  m_ReferencePyramid->Update();
460 
461  // Create pyramid for Floating image
462  m_FloatingPyramid = PyramidType::New();
463 
464  m_FloatingPyramid->SetInput(m_FloatingImage);
465  m_FloatingPyramid->SetNumberOfLevels(GetNumberOfPyramidLevels());
466  m_FloatingPyramid->Update();
467 }
468 
469 } // end of namespace anima
TransformType::ParametersType ParametersType
Implements an ITK wrapper for the NLOPT library.
void ComputeRealignTransform(itk::Vector< double, InputImageType::ImageDimension > centralPoint, typename InputImageType::PointType &centerReal, ParametersType &imageParams)
Superclass::ParametersType ParametersType
void TransformCartesianToSphericalCoordinates(const VectorType &v, VectorType &resVec)