4 #include <itkTransformFactoryBase.h> 5 #include <itkTransformFileWriter.h> 6 #include <itkMultiResolutionPyramidImageFilter.h> 7 #include <itkImageRegistrationMethod.h> 8 #include <itkLinearInterpolateImageFunction.h> 15 #include <itkMeanSquaresImageToImageMetric.h> 16 #include <itkMutualInformationHistogramImageToImageMetric.h> 17 #include <itkNormalizedMutualInformationHistogramImageToImageMetric.h> 18 #include <itkCenteredTransformInitializer.h> 19 #include <itkMinimumMaximumImageFilter.h> 24 template <
typename ScalarType>
// Constructor body (original listing lines 24-52): assigns default values to the
// registration members. Its signature and braces, along with original lines
// 25-26, 29, 32, 34, 37-40, 42, 45 and 49, do not appear in this listing.
//
// Input images start unset.
// NOTE(review): prefer nullptr over NULL in modern C++.
27 m_ReferenceImage = NULL;
28 m_FloatingImage = NULL;
// The transform being optimized starts as the identity.
30 m_OutputTransform = TransformType::New();
31 m_OutputTransform->SetIdentity();
// Empty by default; the output-writing routine only writes a transform file
// when this path is non-empty.
33 m_outputTransformFile =
"";
// Transform used to pre-resample the floating image before the pyramids are
// built; identity until it is recomputed in the pyramid setup.
35 m_InitialTransform = BaseTransformType::New();
36 m_InitialTransform->SetIdentity();
// Passed to the optimizer's SetMaxEval() at each pyramid level.
41 m_OptimizerMaximumIterations = 100;
// Background values for resampling; overwritten from the image minima at the
// start of the registration run.
43 m_ReferenceMinimalValue = 0.0;
44 m_FloatingMinimalValue = 0.0;
// Optimizer bounds: parameter 0 is limited to [-m_UpperBoundAngle, m_UpperBoundAngle]
// (presumably radians, given M_PI — confirm). m_TranslateUpperBound is multiplied by
// the per-level spacing term to bound parameters 1 and 2.
// NOTE(review): M_PI is not defined by ISO C++ (POSIX extension; MSVC needs
// _USE_MATH_DEFINES). A portable constant would avoid build issues.
46 m_UpperBoundAngle = M_PI;
47 m_TranslateUpperBound = 10;
// Bins per axis for the histogram-based metrics.
48 m_HistogramSize = 128;
50 m_NumberOfPyramidLevels = 3;
// When true, the fixed-image region at each level is reduced to one voxel of
// thickness along the dominant reference axis, through the image center.
51 m_FastRegistration =
false;
// Default to ITK's global default thread count.
52 this->SetNumberOfWorkUnits(itk::MultiThreaderBase::GetGlobalDefaultNumberOfThreads());
// Orphaned template header (original listing line 55). The definition it
// introduces (original lines 56-59) is absent from this listing.
// NOTE(review): presumably the destructor declared in the member list; confirm
// against the full source.
55 template <
typename ScalarType>
// Main registration routine (original listing lines 60-308). It is presumably
// Update(), declared in the member list as `void Update() ITK_OVERRIDE`; the
// signature line is absent. It runs a bounded, multi-resolution registration of
// the floating image onto the reference image, then stores the composed result
// in m_OutputRealignTransform and the resampled floating image in m_OutputImage.
// Many original lines are absent from this listing, including braces, several
// declarations, the metric selection and the try block.
60 template <
typename ScalarType>
63 typedef typename itk::ImageRegistrationMethod<InputImageType, InputImageType> RegistrationType;
// Compute the intensity minimum of the reference image, then of the floating
// image. The filter's work-unit count is only overridden when this object's
// count is non-zero.
66 using MinMaxFilterType = itk::MinimumMaximumImageFilter <InputImageType>;
67 typename MinMaxFilterType::Pointer minMaxFilter = MinMaxFilterType::New();
68 minMaxFilter->SetInput(m_ReferenceImage);
69 if (this->GetNumberOfWorkUnits() != 0)
70 minMaxFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
71 minMaxFilter->Update();
73 m_ReferenceMinimalValue = minMaxFilter->GetMinimum();
75 minMaxFilter = MinMaxFilterType::New();
76 minMaxFilter->SetInput(m_FloatingImage);
77 if (this->GetNumberOfWorkUnits() != 0)
78 minMaxFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
79 minMaxFilter->Update();
81 m_FloatingMinimalValue = minMaxFilter->GetMinimum();
// Snap each minimum to a background value: -1024 if the image has negative
// intensities, 0 otherwise.
// NOTE(review): original lines 86 and 91 are absent. They presumably hold `else`;
// otherwise the -1024 assignment would be overwritten immediately.
84 if (m_ReferenceMinimalValue < 0.0)
85 m_ReferenceMinimalValue = -1024;
87 m_ReferenceMinimalValue = 0.0;
89 if (m_FloatingMinimalValue < 0.0)
90 m_FloatingMinimalValue = -1024;
92 m_FloatingMinimalValue = 0.0;
// Build the reference and floating pyramids. Their resamplers use the minimal
// values computed above as default pixel values.
94 this->SetupPyramids();
// The loop body (original lines 98-99) and the declaration of initialParams are
// absent. Presumably the loop initialises initialParams — confirm.
97 for (
unsigned int i = 0;i < TransformType::ParametersDimension;++i)
// Find the column of the reference direction matrix whose first-row entry has
// the largest magnitude.
// NOTE(review): the update of indexAbsRefMax (original lines 104-112) is absent;
// presumably `indexAbsRefMax = i`.
100 unsigned int indexAbsRefMax = 0;
101 typename InputImageType::DirectionType dirRefMatrix = m_ReferenceImage->GetDirection();
102 double valRefMax = std::abs(dirRefMatrix(0,0));
103 for (
unsigned int i = 1;i < InputImageType::ImageDimension;++i)
105 if (std::abs(dirRefMatrix(0,i)) > valRefMax)
107 valRefMax = std::abs(dirRefMatrix(0,i));
// Use that direction column as the reference plane normal of the transform
// being optimized (directionRefReal is declared in absent lines).
113 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
114 directionRefReal[i] = dirRefMatrix(i,indexAbsRefMax);
116 m_OutputTransform->SetReferencePlaneNormal(directionRefReal);
117 m_OutputTransform->SetParameters(initialParams);
// Physical position of the reference image's central voxel
// (continuous index size/2 along every axis).
// NOTE(review): if InputImageType depends on ScalarType, the dependent type names
// `InputImageType::PointType` here, and `InterpolatorType::Pointer`,
// `MetricType::HistogramType::SizeType` and `InputImageRegionType::IndexType`
// below, need a leading `typename` to compile on conforming compilers.
119 InputImageType::PointType centralPoint;
120 itk::ContinuousIndex <ScalarType,InputImageType::ImageDimension> centralVoxIndex;
122 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
123 centralVoxIndex[i] = m_ReferenceImage->GetLargestPossibleRegion().GetSize()[i] / 2.0;
125 m_ReferenceImage->TransformContinuousIndexToPhysicalPoint(centralVoxIndex,centralPoint);
// Box constraints for the optimizer. Parameter 0 (presumably the rotation angle)
// is limited to [-m_UpperBoundAngle, m_UpperBoundAngle] for all levels.
127 unsigned int dimension = m_OutputTransform->GetNumberOfParameters();
128 itk::Array<double> lowerBounds(dimension);
129 itk::Array<double> upperBounds(dimension);
131 lowerBounds[0] = - m_UpperBoundAngle;
132 upperBounds[0] = m_UpperBoundAngle;
// One registration per pyramid level (ITK pyramid output 0 is the coarsest).
// Each level starts from the parameters left by the previous one.
135 for (
unsigned int i = 0;i < this->GetNumberOfPyramidLevels();++i)
137 std::cout <<
"Processing pyramid level " << i << std::endl;
138 std::cout <<
"Image size: " << m_ReferencePyramid->GetOutput(i)->GetLargestPossibleRegion().GetSize() << std::endl;
141 typename RegistrationType::Pointer reg = RegistrationType::New();
143 reg->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
// Derivative-free bounded optimizer (NLopt BOBYQA) with relative x/f tolerances
// and an evaluation budget of m_OptimizerMaximumIterations.
147 typename OptimizerType::Pointer optimizer = OptimizerType::New();
149 optimizer->SetAlgorithm(NLOPT_LN_BOBYQA);
150 optimizer->SetXTolRel(1.0e-4);
151 optimizer->SetFTolRel(1.0e-6);
152 optimizer->SetMaxEval(m_OptimizerMaximumIterations);
153 optimizer->SetVectorStorageSize(2000);
// Bounds for parameters 1 and 2: current value +/- m_TranslateUpperBound * meanSpacing,
// recomputed at every level from that level's spacing.
// NOTE(review): meanSpacing is accumulated as the SUM of the spacings. Original
// line 159 is absent; confirm it divides by ImageDimension, otherwise the bound is
// ImageDimension times wider than the name suggests.
156 double meanSpacing = 0;
157 for (
unsigned int j = 0;j < InputImageType::ImageDimension;++j)
158 meanSpacing += m_ReferencePyramid->GetOutput(i)->GetSpacing()[j];
160 for (
unsigned int j = 0;j < 2;++j)
162 lowerBounds[j + 1] = m_OutputTransform->GetParameters()[j + 1] - m_TranslateUpperBound * meanSpacing;
163 upperBounds[j + 1] = m_OutputTransform->GetParameters()[j + 1] + m_TranslateUpperBound * meanSpacing;
166 optimizer->SetLowerBoundParameters(lowerBounds);
167 optimizer->SetUpperBoundParameters(upperBounds);
169 reg->SetOptimizer(optimizer);
170 reg->SetTransform(m_OutputTransform);
// Linear interpolation of the moving image.
172 typedef itk::LinearInterpolateImageFunction <InputImageType,double> InterpolatorType;
173 InterpolatorType::Pointer interpolator = InterpolatorType::New();
175 reg->SetInterpolator(interpolator);
// Similarity metric: one of three alternatives. The code that selects between them
// (original lines 176-180, 192-196, 208-213) is absent from this listing.
// Alternative 1: histogram mutual information, m_HistogramSize bins per axis.
181 typedef itk::MutualInformationHistogramImageToImageMetric < InputImageType,InputImageType > MetricType;
182 typename MetricType::Pointer tmpMetric = MetricType::New();
184 MetricType::HistogramType::SizeType histogramSize;
185 histogramSize.SetSize(2);
187 histogramSize[0] = m_HistogramSize;
188 histogramSize[1] = m_HistogramSize;
189 tmpMetric->SetHistogramSize( histogramSize );
191 reg->SetMetric(tmpMetric);
// Alternative 2: histogram normalized mutual information, same binning.
197 typedef itk::NormalizedMutualInformationHistogramImageToImageMetric < InputImageType,InputImageType > MetricType;
198 typename MetricType::Pointer tmpMetric = MetricType::New();
200 MetricType::HistogramType::SizeType histogramSize;
201 histogramSize.SetSize(2);
203 histogramSize[0] = m_HistogramSize;
204 histogramSize[1] = m_HistogramSize;
205 tmpMetric->SetHistogramSize( histogramSize );
207 reg->SetMetric(tmpMetric);
// Alternative 3: mean squared intensity difference.
214 typedef itk::MeanSquaresImageToImageMetric < InputImageType,InputImageType > MetricType;
215 typename MetricType::Pointer tmpMetric = MetricType::New();
216 reg->SetMetric(tmpMetric);
// Fixed/moving images for this level: the reference and floating pyramid outputs.
221 reg->SetFixedImage(m_ReferencePyramid->GetOutput(i));
222 reg->SetMovingImage(m_FloatingPyramid->GetOutput(i));
// Fast mode: along the dominant reference axis (indexAbsRefMax), the fixed region
// is reduced to one voxel at the index of the central point. workRegion is
// declared and initialised in absent lines (225-227).
// NOTE(review): centralIndex components are signed. Storing one in
// `unsigned int baseIndex` would wrap if the point mapped outside this level's grid.
224 if (m_FastRegistration)
228 InputImageRegionType::IndexType centralIndex;
229 m_ReferencePyramid->GetOutput(i)->TransformPhysicalPointToIndex(centralPoint,centralIndex);
231 unsigned int baseIndex = centralIndex[indexAbsRefMax];
232 workRegion.SetIndex(indexAbsRefMax,baseIndex);
233 workRegion.SetSize(indexAbsRefMax,1);
235 reg->SetFixedImageRegion(workRegion);
// Otherwise use the whole level. Original lines 236-237 are absent; this is
// presumably the `else` branch.
238 reg->SetFixedImageRegion(m_ReferencePyramid->GetOutput(i)->GetLargestPossibleRegion());
// Start from the current transform parameters (identity-derived at the first level,
// previous level's result afterwards).
240 reg->SetInitialTransformParameters(m_OutputTransform->GetParameters());
// The try block (original lines 241-245) is absent; presumably it runs the
// registration. ITK exceptions are printed to stdout. Original lines 249-251 are
// absent, so whether execution continues after a failure is not visible here.
246 catch( itk::ExceptionObject & err )
248 std::cout <<
"ExceptionObject caught ! " << err << std::endl;
// Keep this level's result as the starting point for the next level.
252 m_OutputTransform->SetParameters(reg->GetLastTransformParameters());
// Compose the final realignment in homogeneous (ImageDimension+1)^2 form:
// output = initial * optimized * inverse(reference symmetry transform).
// NOTE(review): the translation column is hard-coded as 3. That is only correct
// when ImageDimension == 3; InputImageType::ImageDimension would be dimension-agnostic.
256 typedef itk::Matrix <ScalarType,InputImageType::ImageDimension+1,InputImageType::ImageDimension+1> TransformMatrixType;
257 TransformMatrixType refSymPlaneMatrix, initialMatrix, outputMatrix;
258 refSymPlaneMatrix.SetIdentity();
259 initialMatrix.SetIdentity();
260 outputMatrix.SetIdentity();
262 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
264 for (
unsigned int j = 0;j < InputImageType::ImageDimension;++j)
266 refSymPlaneMatrix(i,j) = m_RefSymmetryTransform->GetMatrix()(i,j);
267 outputMatrix(i,j) = m_OutputTransform->GetMatrix()(i,j);
268 initialMatrix(i,j) = m_InitialTransform->GetMatrix()(i,j);
271 refSymPlaneMatrix(i,3) = m_RefSymmetryTransform->GetOffset()[i];
272 outputMatrix(i,3) = m_OutputTransform->GetOffset()[i];
273 initialMatrix(i,3) = m_InitialTransform->GetOffset()[i];
276 refSymPlaneMatrix = refSymPlaneMatrix.GetInverse();
280 outputMatrix = initialMatrix * outputMatrix * refSymPlaneMatrix;
// Split the composed matrix back into a linear part and an offset
// (tmpOutMatrix/tmpOffset are declared in absent lines). Store them in
// m_OutputRealignTransform, creating it on first use.
282 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
284 for (
unsigned int j = 0;j < InputImageType::ImageDimension;++j)
285 tmpOutMatrix(i,j) = outputMatrix(i,j);
287 tmpOffset[i] = outputMatrix(i,3);
290 if (m_OutputRealignTransform.IsNull())
291 m_OutputRealignTransform = BaseTransformType::New();
293 m_OutputRealignTransform->SetMatrix(tmpOutMatrix);
294 m_OutputRealignTransform->SetOffset(tmpOffset);
// Resample the original floating image onto the reference image grid through the
// realignment transform. Voxels mapped outside the input get m_FloatingMinimalValue.
297 typename ResampleFilterType::Pointer tmpResample = ResampleFilterType::New();
298 tmpResample->SetTransform(m_OutputRealignTransform);
299 tmpResample->SetInput(m_FloatingImage);
301 tmpResample->SetSize(m_ReferenceImage->GetLargestPossibleRegion().GetSize());
302 tmpResample->SetOutputOrigin(m_ReferenceImage->GetOrigin());
303 tmpResample->SetOutputSpacing(m_ReferenceImage->GetSpacing());
304 tmpResample->SetOutputDirection(m_ReferenceImage->GetDirection());
305 tmpResample->SetDefaultPixelValue(m_FloatingMinimalValue);
306 tmpResample->Update();
308 m_OutputImage = tmpResample->GetOutput();
// Output writer (original listing lines 311-327; the function name is not visible).
// Writes m_OutputImage to m_resultFile. If m_outputTransformFile is non-empty, it
// also prepares a writer for m_OutputRealignTransform to that path.
311 template <
typename ScalarType>
314 std::cout <<
"Writing output image to: " << m_resultFile << std::endl;
316 anima::writeImage <InputImageType> (m_resultFile,m_OutputImage);
// An empty path means "do not write the transform".
// NOTE(review): m_outputTransformFile.empty() would express this more directly.
318 if (m_outputTransformFile !=
"")
320 std::cout <<
"Writing output transform to: " << m_outputTransformFile << std::endl;
321 itk::TransformFileWriter::Pointer writer = itk::TransformFileWriter::New();
322 writer->SetInput(m_OutputRealignTransform);
323 writer->SetFileName(m_outputTransformFile);
// NOTE(review): original lines 324-327 are absent. Confirm they call
// writer->Update(); TransformFileWriter only writes the file when Update() runs.
// Pyramid setup (original listing lines 328-482), presumably SetupPyramids(), which
// the registration routine calls via this->SetupPyramids(). It computes the initial
// floating-to-reference transform, then builds the reference pyramid (from the
// symmetry-resampled reference) and the floating pyramid (from the floating image
// resampled through the initial transform). It reads m_ReferenceMinimalValue and
// m_FloatingMinimalValue as resampling defaults.
328 template <
typename ScalarType>
331 m_InitialTransform->SetIdentity();
// NOTE(review): this initializer type is declared, but no use of it appears in the
// listing (original lines 334-338 are absent).
333 typedef typename itk::CenteredTransformInitializer<BaseTransformType, InputImageType, InputImageType> TransformInitializerType;
// Resample the reference image through m_RefSymmetryTransform (set elsewhere) onto
// its own grid. Detach the result so it outlives the temporary filter.
339 typename ResampleFilterType::Pointer tmpResample = ResampleFilterType::New();
340 tmpResample->SetTransform(m_RefSymmetryTransform);
341 tmpResample->SetInput(m_ReferenceImage);
343 tmpResample->SetSize(m_ReferenceImage->GetLargestPossibleRegion().GetSize());
344 tmpResample->SetOutputOrigin(m_ReferenceImage->GetOrigin());
345 tmpResample->SetOutputSpacing(m_ReferenceImage->GetSpacing());
346 tmpResample->SetOutputDirection(m_ReferenceImage->GetDirection());
347 tmpResample->SetDefaultPixelValue(m_ReferenceMinimalValue);
348 tmpResample->Update();
350 initialReferenceImage = tmpResample->GetOutput();
351 initialReferenceImage->DisconnectPipeline();
// For each image, find the column of its direction matrix whose first-row entry has
// the largest magnitude. The assignments to indexAbsRefMax / indexAbsFloMax fall in
// absent lines (364-378); presumably they store i.
354 OffsetType directionRefReal, directionFloReal;
357 unsigned int indexAbsRefMax = 0;
358 unsigned int indexAbsFloMax = 0;
359 typename InputImageType::DirectionType dirRefMatrix = m_ReferenceImage->GetDirection();
360 typename InputImageType::DirectionType dirFloMatrix = m_FloatingImage->GetDirection();
361 double valRefMax = std::abs(dirRefMatrix(0,0));
362 double valFloMax = std::abs(dirFloMatrix(0,0));
363 for (
unsigned int i = 1;i < InputImageType::ImageDimension;++i)
365 if (std::abs(dirRefMatrix(0,i)) > valRefMax)
367 valRefMax = std::abs(dirRefMatrix(0,i));
371 if (std::abs(dirFloMatrix(0,i)) > valFloMax)
373 valFloMax = std::abs(dirFloMatrix(0,i));
// The selected columns become the reference and floating direction vectors.
379 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
381 directionRefReal[i] = dirRefMatrix(i,indexAbsRefMax);
382 directionFloReal[i] = dirFloMatrix(i,indexAbsFloMax);
// Homogeneous (ImageDimension+1)^2 matrix. Its linear part is copied from tmpMatrix,
// which is computed in absent lines — presumably the rotation between the two
// direction vectors (e.g. via GetRotationMatrixFromVectors); confirm.
385 typedef itk::Matrix <ScalarType,InputImageType::ImageDimension+1,InputImageType::ImageDimension+1> TransformMatrixType;
387 TransformMatrixType floRefMatrix;
388 floRefMatrix.SetIdentity();
390 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
392 for (
unsigned int j = 0;j < InputImageType::ImageDimension;++j)
393 floRefMatrix(i,j) = tmpMatrix(i,j);
// Physical centers of both images (continuous index size/2 along every axis).
396 itk::ContinuousIndex <ScalarType,InputImageType::ImageDimension> refImageCenter, floImageCenter;
398 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
399 refImageCenter[i] = m_ReferenceImage->GetLargestPossibleRegion().GetSize()[i] / 2.0;
400 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
401 floImageCenter[i] = m_FloatingImage->GetLargestPossibleRegion().GetSize()[i] / 2.0;
403 typename InputImageType::PointType refCenter, floCenter;
404 m_ReferenceImage->TransformContinuousIndexToPhysicalPoint(refImageCenter,refCenter);
405 m_FloatingImage->TransformContinuousIndexToPhysicalPoint(floImageCenter,floCenter);
// floRefMatrix = T(+floCenter) * R * T(-refCenter). A point is first shifted by
// -refCenter, then rotated, then shifted by +floCenter.
// NOTE(review): the translation column is hard-coded as 3 here and below, which
// assumes ImageDimension == 3.
407 TransformMatrixType refTranslationMatrix, floTranslationMatrix;
408 refTranslationMatrix.SetIdentity();
409 floTranslationMatrix.SetIdentity();
411 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
413 refTranslationMatrix(i,3) = - refCenter[i];
414 floTranslationMatrix(i,3) = floCenter[i];
417 floRefMatrix = floTranslationMatrix * floRefMatrix * refTranslationMatrix;
// Left-multiply by the floating image's symmetry transform (set elsewhere).
419 TransformMatrixType FloatingSymmetry;
420 FloatingSymmetry.SetIdentity();
422 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
424 for (
unsigned int j = 0;j < InputImageType::ImageDimension;++j)
425 FloatingSymmetry(i,j) = m_FloSymmetryTransform->GetMatrix()(i,j);
427 FloatingSymmetry(i,3) = m_FloSymmetryTransform->GetOffset()[i];
430 floRefMatrix = FloatingSymmetry * floRefMatrix;
// The optimized transform rotates about the reference image center.
431 m_OutputTransform->SetCenter(refCenter);
// Split the composed matrix into a linear part and an offset (FloatingMatrix /
// FloatingOffset are declared in absent lines) and store them as the initial transform.
436 for (
unsigned int i = 0;i < InputImageType::ImageDimension;++i)
438 for (
unsigned int j = 0;j < InputImageType::ImageDimension;++j)
439 FloatingMatrix(i,j) = floRefMatrix(i,j);
441 FloatingOffset[i] = floRefMatrix(i,3);
444 m_InitialTransform->SetMatrix(FloatingMatrix);
445 m_InitialTransform->SetOffset(FloatingOffset);
// Reference pyramid built from the symmetry-resampled reference image. Its resampler
// fills out-of-image voxels with m_ReferenceMinimalValue.
// NOTE(review): GetNumberOfWorkUnits() is called unqualified here and below, while
// the rest of the file uses this->. If the base class depends on ScalarType,
// two-phase lookup requires this->.
448 m_ReferencePyramid = PyramidType::New();
450 m_ReferencePyramid->SetInput(initialReferenceImage);
451 m_ReferencePyramid->SetNumberOfLevels(m_NumberOfPyramidLevels);
452 typename ResampleFilterType::Pointer refResampler = ResampleFilterType::New();
453 refResampler->SetDefaultPixelValue(m_ReferenceMinimalValue);
454 m_ReferencePyramid->SetImageResampler(refResampler);
455 m_ReferencePyramid->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
456 m_ReferencePyramid->Update();
// Resample the floating image onto the reference grid through the initial transform,
// then detach the result from the temporary filter.
458 tmpResample = ResampleFilterType::New();
459 tmpResample->SetTransform(m_InitialTransform);
460 tmpResample->SetInput(m_FloatingImage);
462 tmpResample->SetSize(m_ReferenceImage->GetLargestPossibleRegion().GetSize());
463 tmpResample->SetOutputOrigin(m_ReferenceImage->GetOrigin());
464 tmpResample->SetOutputSpacing(m_ReferenceImage->GetSpacing());
465 tmpResample->SetOutputDirection(m_ReferenceImage->GetDirection());
466 tmpResample->SetDefaultPixelValue(m_FloatingMinimalValue);
467 tmpResample->Update();
469 initialFloatingImage = tmpResample->GetOutput();
470 initialFloatingImage->DisconnectPipeline();
// Floating pyramid, built the same way with m_FloatingMinimalValue as background.
473 m_FloatingPyramid = PyramidType::New();
475 m_FloatingPyramid->SetInput(initialFloatingImage);
476 m_FloatingPyramid->SetNumberOfLevels(m_NumberOfPyramidLevels);
477 m_FloatingPyramid->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
479 typename ResampleFilterType::Pointer floResampler = ResampleFilterType::New();
480 floResampler->SetDefaultPixelValue(m_FloatingMinimalValue);
481 m_FloatingPyramid->SetImageResampler(floResampler);
482 m_FloatingPyramid->Update();
// Member-list fragments (no original line numbers, no bodies): the constructor, the
// overridden Update(), a rotation helper, type aliases and the virtual destructor.
// The one-line description below presumably belongs to the NLopt optimizer class
// used for OptimizerType — confirm.
// NOTE(review): ITK 5, whose SetNumberOfWorkUnits API this code uses, deprecates
// ITK_OVERRIDE in favour of the C++11 `override` keyword.
PyramidalSymmetryConstrainedRegistrationBridge()
void Update() ITK_OVERRIDE
Implements an ITK wrapper for the NLOPT library.
itk::Matrix< double, 3, 3 > GetRotationMatrixFromVectors(const VectorType &first_direction, const VectorType &second_direction, const unsigned int dimension)
TransformType::MatrixType MatrixType
InputImageType::RegionType InputImageRegionType
TransformType::OffsetType OffsetType
InputImageType::Pointer InputImagePointer
virtual ~PyramidalSymmetryConstrainedRegistrationBridge()