// ---------------------------------------------------------------------------
// Update(): runs the pyramidal registration and produces the realigned image.
// NOTE(review): this excerpt comes from a line-numbered source listing. The
// original line numbers are fused into the text and several original lines
// (the function signature, braces and some statements) are missing, so the
// comments below only describe statements that are visible here.
//
// Visible flow:
//   1. build image pyramids for the reference and floating images;
//   2. initialise m_OutputTransform with its rotation centre at the centre of
//      gravity of m_ReferenceImage;
//   3. for each pyramid level, run a bounded NLopt (BOBYQA) registration and
//      carry the resulting parameters over to the next level;
//   4. derive an initial parameter set from the reference image geometry and
//      call ComputeRealignTransform();
//   5. resample m_FloatingImage with m_OutputRealignTransform onto the
//      reference image grid and store it in m_OutputImage.
// ---------------------------------------------------------------------------
5 #include <itkTransformFileWriter.h> 6 #include <itkMultiResolutionPyramidImageFilter.h> 7 #include <itkImageRegistrationMethod.h> 8 #include <itkLinearInterpolateImageFunction.h> 11 #include <itkMeanSquaresImageToImageMetric.h> 12 #include <itkMutualInformationHistogramImageToImageMetric.h> 13 #include <itkImageMomentsCalculator.h> 14 #include <itkProgressReporter.h> 15 #include <itkMinimumMaximumImageFilter.h> 23 template <
class PixelType,
typename ScalarType>
26 typedef typename itk::ImageRegistrationMethod<OutputImageType, OutputImageType> RegistrationType;
// One progress step per pyramid level; each step is completed by
// progress.CompletedPixel() at the end of the level loop below.
29 itk::ProgressReporter progress(
this, 0, GetNumberOfPyramidLevels());
// Forward ITK progress events to the user-supplied callback, if one is set.
31 if(m_progressCallback)
33 this->AddObserver ( itk::ProgressEvent(), m_progressCallback );
36 this->SetupPyramids();
// Rotation centre for m_OutputTransform: the centre of gravity of
// m_ReferenceImage, as computed by itk::ImageMomentsCalculator.
38 typename InputImageType::PointType centralPoint;
40 typedef typename itk::ImageMomentsCalculator <InputImageType> ImageMomentsType;
42 typename ImageMomentsType::Pointer momentsCalculator = ImageMomentsType::New();
44 momentsCalculator->SetImage( m_ReferenceImage );
45 momentsCalculator->Compute();
47 itk::Vector <double,InputImageType::ImageDimension> centralVector = momentsCalculator->GetCenterOfGravity();
49 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
50 centralPoint[i] = centralVector[i];
// NOTE(review): the body of this loop (original lines 56-57) is not visible;
// it presumably fills initialParams before they are applied below — confirm.
55 for (
unsigned int i = 0;i < TransformType::ParametersDimension;++i)
58 m_OutputTransform->SetParameters(initialParams);
59 m_OutputTransform->SetRotationCenter(centralPoint);
// Optimizer bounds, one entry per transform parameter. Parameters 0 and 1
// are bounded to [-GetUpperBoundAngle(), GetUpperBoundAngle()] once, for all
// levels; parameter 2 is re-bounded at every level inside the loop below.
// NOTE(review): only indices 0-2 are written in the visible code. If the
// transform has more than three parameters, the remaining entries of these
// itk::Array objects are never visibly initialised — confirm.
61 unsigned int dimension = m_OutputTransform->GetNumberOfParameters();
62 itk::Array<double> lowerBounds(dimension);
63 itk::Array<double> upperBounds(dimension);
65 for (
unsigned int i = 0;i < 2;++i)
67 lowerBounds[i] = - GetUpperBoundAngle();
68 upperBounds[i] = GetUpperBoundAngle();
// Loop over pyramid levels. Each level starts from the current parameters of
// m_OutputTransform (SetInitialTransformParameters below), and the level's
// result is written back to m_OutputTransform at the end of the iteration.
72 for (
int i = 0;i < GetNumberOfPyramidLevels();++i)
74 std::cout <<
"Processing pyramid level " << i << std::endl;
75 std::cout <<
"Image size: " << m_ReferencePyramid->GetOutput(i)->GetLargestPossibleRegion().GetSize() << std::endl;
78 typename RegistrationType::Pointer reg = RegistrationType::New();
// NOTE(review): the work-unit count is forwarded unconditionally here, while
// the min/max filter further down only forwards it when it is non-zero —
// consider handling both the same way.
80 reg->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
// NLopt algorithm NLOPT_LN_BOBYQA (per the NLopt docs: local, derivative-free,
// supports bound constraints). It stops on relative x tolerance 1e-4,
// relative f tolerance 1e-6, or GetOptimizerMaxIterations() evaluations.
83 typename OptimizerType::Pointer optimizer = OptimizerType::New();
85 optimizer->SetAlgorithm(NLOPT_LN_BOBYQA);
86 optimizer->SetXTolRel(1.0e-4);
87 optimizer->SetFTolRel(1.0e-6);
88 optimizer->SetMaxEval(GetOptimizerMaxIterations());
89 optimizer->SetVectorStorageSize(2000);
// Bounds for parameter 2 are centred on its current value and scaled by the
// spacing of this pyramid level times GetUpperBoundDistance().
// NOTE(review): as visible, meanSpacing accumulates the SUM of the spacings
// (no division by ImageDimension appears); either the name or the computation
// may be off — confirm against original line 95.
92 double meanSpacing = 0;
93 for (
unsigned int j = 0;j < InputImageType::ImageDimension;++j)
94 meanSpacing += m_ReferencePyramid->GetOutput(i)->GetSpacing()[j];
96 lowerBounds[2] = m_OutputTransform->GetParameters()[2] - meanSpacing * GetUpperBoundDistance();
97 upperBounds[2] = m_OutputTransform->GetParameters()[2] + meanSpacing * GetUpperBoundDistance();
99 optimizer->SetLowerBoundParameters(lowerBounds);
100 optimizer->SetUpperBoundParameters(upperBounds);
102 reg->SetOptimizer(optimizer);
103 reg->SetTransform(m_OutputTransform);
105 typedef itk::LinearInterpolateImageFunction <OutputImageType, double> InterpolatorType;
106 typename InterpolatorType::Pointer interpolator = InterpolatorType::New();
108 reg->SetInterpolator(interpolator);
// The metric is either histogram-based mutual information (GetHistogramSize()
// bins along each of the two histogram axes) or mean squares. The condition
// that selects between them is on original lines not visible in this excerpt.
114 typedef itk::MutualInformationHistogramImageToImageMetric < OutputImageType,OutputImageType > MetricType;
115 typename MetricType::Pointer tmpMetric = MetricType::New();
117 typename MetricType::HistogramType::SizeType histogramSize;
118 histogramSize.SetSize(2);
120 histogramSize[0] = GetHistogramSize();
121 histogramSize[1] = GetHistogramSize();
122 tmpMetric->SetHistogramSize( histogramSize );
124 reg->SetMetric(tmpMetric);
130 typedef itk::MeanSquaresImageToImageMetric < OutputImageType,OutputImageType > MetricType;
131 typename MetricType::Pointer tmpMetric = MetricType::New();
132 reg->SetMetric(tmpMetric);
// Fixed = reference pyramid level i, moving = floating pyramid level i,
// evaluated over the whole fixed image.
137 reg->SetFixedImage(m_ReferencePyramid->GetOutput(i));
138 reg->SetMovingImage(m_FloatingPyramid->GetOutput(i));
140 reg->SetFixedImageRegion(m_ReferencePyramid->GetOutput(i)->GetLargestPossibleRegion());
141 reg->SetInitialTransformParameters(m_OutputTransform->GetParameters());
// NOTE(review): in the visible lines the exception is only printed to
// std::cout. Original lines 150-152 are missing, so it is unclear whether the
// loop aborts or falls through to GetLastTransformParameters() with the
// failed level's state — confirm; silently continuing would hide failures.
147 catch( itk::ExceptionObject & err )
149 std::cout <<
"ExceptionObject caught ! " << err << std::endl;
153 progress.CompletedPixel();
154 m_OutputTransform->SetParameters(reg->GetLastTransformParameters());
// Geometric centre of the reference image: continuous index (size - 1) / 2
// along each axis, converted to a physical point (centerReal).
159 itk::ContinuousIndex <ScalarType,InputImageType::ImageDimension> imageCenter;
161 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
162 imageCenter[i] = (m_ReferenceImage->GetLargestPossibleRegion().GetSize()[i] - 1.0) / 2.0;
164 typename InputImageType::PointType centerReal;
165 m_ReferenceImage->TransformContinuousIndexToPhysicalPoint(imageCenter,centerReal);
167 std::vector <double> directionVox(InputImageType::ImageDimension, 0);
168 std::vector <double> directionReal(InputImageType::ImageDimension, 0);
169 std::vector <double> directionSpherical(InputImageType::ImageDimension, 0);
// Select the image axis whose direction cosine has the largest absolute
// value in row 0 of the direction matrix, then copy that axis' physical
// direction (column indexAbsMax) into directionReal.
// NOTE(review): the statements updating indexAbsMax (original lines 180-184)
// and filling directionSpherical (lines 189-191) are not visible; the latter
// presumably calls TransformCartesianToSphericalCoordinates — confirm.
// directionVox is not read in the visible lines.
172 unsigned int indexAbsMax = 0;
173 typename InputImageType::DirectionType dirMatrix = m_ReferenceImage->GetDirection();
174 double valMax = std::abs(dirMatrix(0,0));
175 for (
unsigned int i = 1;i < InputImageType::ImageDimension;++i)
177 if (std::abs(dirMatrix(0,i)) > valMax)
179 valMax = std::abs(dirMatrix(0,i));
185 directionVox[indexAbsMax] = 1;
187 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
188 directionReal[i] = dirMatrix(i,indexAbsMax);
// Initial parameters for the realign transform:
// [0] = pi/2 - directionSpherical[0], [1] = directionSpherical[1], [2] = 0.
// NOTE(review): M_PI is not standard C++ (MSVC requires _USE_MATH_DEFINES);
// a library-provided pi constant would be more portable.
192 initialParams.Fill(0);
193 initialParams[0] = M_PI / 2.0 - directionSpherical[0];
194 initialParams[1] = directionSpherical[1];
195 initialParams[2] = 0;
197 this->ComputeRealignTransform(centralVector,centerReal,initialParams);
// Resample m_FloatingImage through m_OutputRealignTransform onto the grid of
// m_ReferenceImage (same size, origin, spacing and direction).
201 typename ResampleFilterType::Pointer tmpResample = ResampleFilterType::New();
202 tmpResample->SetTransform(m_OutputRealignTransform);
203 tmpResample->SetInput(m_FloatingImage);
205 tmpResample->SetSize(m_ReferenceImage->GetLargestPossibleRegion().GetSize());
206 tmpResample->SetOutputOrigin(m_ReferenceImage->GetOrigin());
207 tmpResample->SetOutputSpacing(m_ReferenceImage->GetSpacing());
208 tmpResample->SetOutputDirection(m_ReferenceImage->GetDirection());
// Intensity range of the reference image; only the minimum is read in the
// visible lines.
210 using MinMaxFilterType = itk::MinimumMaximumImageFilter <InputImageType>;
211 typename MinMaxFilterType::Pointer minMaxFilter = MinMaxFilterType::New();
212 minMaxFilter->SetInput(m_ReferenceImage);
213 if (this->GetNumberOfWorkUnits() != 0)
214 minMaxFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
215 minMaxFilter->Update();
217 double minValue = minMaxFilter->GetMinimum();
// Default value for output pixels that map outside the floating image: either
// -1024 or 0. The selecting condition (original lines 218-219 and 221) is not
// visible; presumably it depends on minValue — confirm.
220 tmpResample->SetDefaultPixelValue(-1024.0);
222 tmpResample->SetDefaultPixelValue(0.0);
// Run the resampling, keep the result, and detach it from tmpResample's
// pipeline (itk::DataObject::DisconnectPipeline).
224 tmpResample->Update();
226 m_OutputImage = tmpResample->GetOutput();
227 m_OutputImage->DisconnectPipeline();
// ---------------------------------------------------------------------------
// ComputeRealignTransform(centralPoint, centerReal, imageParams)
// Builds m_OutputRealignTransform from m_OutputTransform and a second
// transform configured with imageParams and rotation centre centerReal.
// NOTE(review): extracted listing. The first signature line, braces and
// several statements are missing (see the per-section notes below).
// NOTE(review): `¢erReal` below looks like a mis-encoded `&centerReal` (the
// HTML entity `&cent` decoded as the cent sign); restore it.
// NOTE(review): several indices and loop bounds are hard-coded to 3 / 0..2,
// so this code effectively assumes InputImageType::ImageDimension == 3.
// ---------------------------------------------------------------------------
template <
class PixelType,
typename ScalarType>
232 typename InputImageType::PointType ¢erReal,
ParametersType &imageParams)
// Transform built from imageParams around centerReal.
// NOTE(review): imageMidPlaneTrsf is declared on lines not visible here.
235 imageMidPlaneTrsf->SetParameters(imageParams);
236 imageMidPlaneTrsf->SetRotationCenter(centerReal);
238 typedef itk::Matrix <ScalarType,InputImageType::ImageDimension+1,InputImageType::ImageDimension+1> HomogeneousMatrixType;
// Homogeneous (Dim+1)x(Dim+1) form of imageMidPlaneTrsf: matrix in the upper
// left block, offset in the last column.
240 HomogeneousMatrixType midPlaneMatrix;
241 midPlaneMatrix.SetIdentity();
243 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
245 midPlaneMatrix(i,InputImageType::ImageDimension) = imageMidPlaneTrsf->GetOffset()[i];
247 for (
unsigned int j = 0;j < InputImageType::ImageDimension;++j)
248 midPlaneMatrix(i,j) = imageMidPlaneTrsf->GetMatrix()(i,j);
// Homogeneous form of the registration result m_OutputTransform.
251 HomogeneousMatrixType outputTrsfMatrix;
252 outputTrsfMatrix.SetIdentity();
254 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
256 outputTrsfMatrix(i,InputImageType::ImageDimension) = m_OutputTransform->GetOffset()[i];
258 for (
unsigned int j = 0;j < InputImageType::ImageDimension;++j)
259 outputTrsfMatrix(i,j) = m_OutputTransform->GetMatrix()(i,j);
262 HomogeneousMatrixType matrixComposition = midPlaneMatrix * outputTrsfMatrix;
// Rotation-vector extraction from the 3x3 part R of matrixComposition:
// theta = (R21 - R12, R02 - R20, R10 - R01); its norm is stored in `sinus`
// and `cosinus` = trace(R) - 1. For a pure rotation by phi these equal
// 2 sin(phi) and 2 cos(phi) respectively.
264 double theta_x, theta_y, theta_z;
265 double sinus, cosinus, phi, f;
271 theta_x = - matrixComposition[1][2] + matrixComposition[2][1];
272 theta_y = - matrixComposition[2][0] + matrixComposition[0][2];
273 theta_z = - matrixComposition[0][1] + matrixComposition[1][0];
275 sinus = std::sqrt(theta_x*theta_x + theta_y*theta_y + theta_z*theta_z);
277 if (std::abs(sinus) > 1e-9)
279 cosinus = matrixComposition[0][0] + matrixComposition[1][1] + matrixComposition[2][2] - 1;
// NOTE(review): std::atan(sinus / cosinus) only returns values in
// (-pi/2, pi/2). For rotations beyond 90 degrees (cosinus < 0) phi comes out
// negative/wrong. std::atan2(sinus, cosinus) is correct over [0, pi].
280 phi = std::atan(sinus / cosinus);
// NOTE(review): f is computed on original line 281 (not visible); the
// else-branch (original lines 285-297) is not visible either.
282 theta_x = theta_x * f;
283 theta_y = theta_y * f;
284 theta_z = theta_z * f;
// Halve the vector. Assuming theta now holds the rotation vector (axis *
// angle), this describes the rotation by half the angle about the same axis,
// i.e. the "square root" rotation used to build sqrtMatrix below.
298 theta_x = theta_x / 2.0;
299 theta_y = theta_y / 2.0;
300 theta_z = theta_z / 2.0;
// Unit rotation axis u = theta / |theta|. The else-branch and the definitions
// of c, s and t (original lines 315-330) are not visible.
306 double u_x, u_y, u_z;
308 sinus = std::sqrt(theta_x*theta_x + theta_y*theta_y + theta_z*theta_z);
310 if (std::abs(sinus) > 1e-9)
312 u_x = theta_x / sinus;
313 u_y = theta_y / sinus;
314 u_z = theta_z / sinus;
// Rotational part of sqrtMatrix = t*u*u^T + c*I + s*[u]_x, the axis-angle
// (Rodrigues) rotation-matrix form. Presumably c / s are the cosine / sine
// of the halved angle and t = 1 - c — confirm against the hidden lines.
331 HomogeneousMatrixType sqrtMatrix;
332 sqrtMatrix.SetIdentity();
334 sqrtMatrix[0][0] = t*u_x*u_x + c;
335 sqrtMatrix[0][1] = t*u_x*u_y - s*u_z;
336 sqrtMatrix[0][2] = t*u_x*u_z + s*u_y;
337 sqrtMatrix[1][0] = t*u_x*u_y + s*u_z;
338 sqrtMatrix[1][1] = t*u_y*u_y + c;
339 sqrtMatrix[1][2] = t*u_y*u_z - s*u_x;
340 sqrtMatrix[2][0] = t*u_x*u_z - s*u_y;
341 sqrtMatrix[2][1] = t*u_y*u_z + s*u_x;
342 sqrtMatrix[2][2] = t*u_z*u_z + c;
// Translation part of sqrtMatrix. tmpMatrix (declared on hidden lines) is
// first filled with the 3x3 rotational part Q of sqrtMatrix. The body of the
// second loop (original lines 353-356) is hidden. tmpMatrix is then inverted
// and applied to the translation column of matrixComposition.
// NOTE(review): if the hidden loop adds the identity, this solves
// (Q + I) v = t, so that applying sqrtMatrix twice reproduces the translation
// of matrixComposition — confirm.
348 for (
unsigned int i=0; i<3; i++)
349 for (
unsigned int j=0; j<3; j++)
350 tmpMatrix[i][j] = sqrtMatrix[i][j];
352 for (
unsigned int i=0; i<3; i++)
357 tmpMatrix = tmpMatrix.GetInverse();
360 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
361 for (
unsigned int j = 0;j < InputImageType::ImageDimension;++j)
362 sqrtMatrix[i][3] += tmpMatrix[i][j] * matrixComposition[j][3];
// Map centralPoint (homogeneous coordinates) through sqrtMatrix.
364 itk::Vector <double,InputImageType::ImageDimension+1> centralPointHom, centralPointTr;
365 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
366 centralPointHom[i] = centralPoint[i];
367 centralPointHom[InputImageType::ImageDimension] = 1;
369 centralPointTr = sqrtMatrix * centralPointHom;
// Final homogeneous transform = inverse(sqrtMatrix) * T, where T translates
// by (centralPointTr - centerReal).
371 sqrtMatrix = sqrtMatrix.GetInverse();
373 HomogeneousMatrixType trToCenterMatrix;
374 trToCenterMatrix.SetIdentity();
376 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
377 trToCenterMatrix(i,3) = centralPointTr[i] - centerReal[i];
379 sqrtMatrix = sqrtMatrix * trToCenterMatrix;
// Split the result into matrix / offset (offsetVector and trsfMatrix are
// declared on hidden lines) and store them in m_OutputRealignTransform,
// creating that transform on first use.
383 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
385 offsetVector[i] = sqrtMatrix(i,3);
386 for (
unsigned int j=0; j<3; j++)
387 trsfMatrix(i,j) = sqrtMatrix(i,j);
390 if (m_OutputRealignTransform.IsNull())
391 m_OutputRealignTransform = BaseTransformType::New();
393 m_OutputRealignTransform->SetMatrix(trsfMatrix);
394 m_OutputRealignTransform->SetOffset(offsetVector);
// Output-writing method. Its name and its other statements (original lines
// 398-401 and 403-406) are not visible in this excerpt. The visible
// statement writes the realign transform via SaveRealignTransformFile().
template <
class PixelType,
typename ScalarType>
402 SaveRealignTransformFile();
// SaveResultFile(): if a result file name is set (non-empty), logs the path
// to std::cout and writes m_OutputImage to it with anima::writeImage.
template <
class PixelType,
typename ScalarType>
411 if (GetResultfile() !=
"")
413 std::cout <<
"Writing output image to: " << GetResultfile() << std::endl;
414 anima::writeImage <InputImageType> (GetResultfile(),m_OutputImage);
// SaveTransformFile(): if an output transform file name is set (non-empty),
// writes a BaseTransformType copy of m_OutputTransform, built from only its
// matrix and offset, with itk::TransformFileWriter.
// NOTE(review): original lines 426-427 and 434-437 are not visible; the
// writer->Update() call that actually performs the write is presumably
// among them — confirm.
template <
class PixelType,
typename ScalarType>
422 if (GetOutputTransformFile() !=
"")
424 std::cout <<
"Writing output transform to: " << GetOutputTransformFile() << std::endl;
425 itk::TransformFileWriter::Pointer writer = itk::TransformFileWriter::New();
// Copy only the matrix and offset of m_OutputTransform into a new transform.
428 typename BaseTransformType::Pointer tmpTrsf = BaseTransformType::New();
429 tmpTrsf->SetMatrix(m_OutputTransform->GetMatrix());
430 tmpTrsf->SetOffset(m_OutputTransform->GetOffset());
432 writer->SetInput(tmpTrsf);
433 writer->SetFileName(GetOutputTransformFile());
// SaveRealignTransformFile(): if an output realign transform file name is set
// (non-empty), writes m_OutputRealignTransform with itk::TransformFileWriter.
// NOTE(review): the writer->Update() call is not visible (original lines
// 447-450 are missing). m_OutputRealignTransform is only created lazily in
// ComputeRealignTransform(), so it may still be null if this runs before
// Update() — confirm callers guarantee the ordering.
template <
class PixelType,
typename ScalarType>
441 if (GetOutputRealignTransformFile() !=
"")
443 std::cout <<
"Writing output realign transform to: " << GetOutputRealignTransformFile() << std::endl;
444 itk::TransformFileWriter::Pointer writer = itk::TransformFileWriter::New();
445 writer->SetInput(m_OutputRealignTransform);
446 writer->SetFileName(GetOutputRealignTransformFile());
// SetupPyramids(): creates multi-resolution pyramids with
// GetNumberOfPyramidLevels() levels for m_ReferenceImage and m_FloatingImage
// and updates both, so that Update() can read GetOutput(level) from each.
// NOTE(review): original lines 455, 458, 460-461 and 463 are not visible in
// this excerpt — confirm whether they configure the pyramids further.
template <
class PixelType,
typename ScalarType>
454 m_ReferencePyramid = PyramidType::New();
456 m_ReferencePyramid->SetInput(m_ReferenceImage);
457 m_ReferencePyramid->SetNumberOfLevels(GetNumberOfPyramidLevels());
459 m_ReferencePyramid->Update();
462 m_FloatingPyramid = PyramidType::New();
464 m_FloatingPyramid->SetInput(m_FloatingImage);
465 m_FloatingPyramid->SetNumberOfLevels(GetNumberOfPyramidLevels());
466 m_FloatingPyramid->Update();
TransformType::ParametersType ParametersType
Implements an ITK wrapper for the NLOPT library.
void ComputeRealignTransform(itk::Vector< double, InputImageType::ImageDimension > centralPoint, typename InputImageType::PointType &centerReal, ParametersType &imageParams)
void SaveResultFile(void)
void SaveTransformFile(void)
TransformType::OffsetType OffsetType
void Update() ITK_OVERRIDE
TransformType::MatrixType MatrixType
void TransformCartesianToSphericalCoordinates(const VectorType &v, VectorType &resVec)
void SaveRealignTransformFile(void)
TransformType::Pointer TransformPointer