4 #include <itkResampleImageFilter.h> 5 #include <itkMinimumMaximumImageFilter.h> 21 template <
unsigned int ImageDimension>
// Constructor body: puts every member into its default state.
// NOTE(review): the signature line and braces are not part of this listing; the
// leading numbers look like line numbers of the original file (gaps = lines not shown).

// Input images start out unset.
24 m_ReferenceImage =
nullptr;
25 m_FloatingImage =
nullptr;
// The output transform is created here and initialised to identity.
27 m_OutputTransform = BaseTransformType::New();
28 m_OutputTransform->SetIdentity();
// An empty transform file name is tested against "" elsewhere in this file
// before the velocity field is written to disk.
30 m_outputTransformFile =
"";
32 m_OutputImage =
nullptr;
// Background values; both are recomputed from the input images before registration.
34 m_ReferenceMinimalValue = 0.0;
35 m_FloatingMinimalValue = 0.0;
43 m_AffineDirection = 1;
// Defaults for the registration loop and for the block-matching optimizer.
47 m_MaximumIterations = 10;
48 m_MinimalTransformError = 0.01;
49 m_OptimizerMaximumIterations = 100;
// Default search radii forwarded to the block matchers.
51 m_SearchAngleRadius = 5;
52 m_SearchScaleRadius = 0.1;
53 m_FinalRadius = 0.001;
// Default upper bounds forwarded to the block matchers (translation, angle, scale).
55 m_TranslateUpperBound = 50;
56 m_AngleUpperBound = 180;
57 m_ScaleUpperBound = 3;
// Multiplied by the mean voxel spacing and m_NeighborhoodApproximation to obtain
// the agregator distance boundary.
59 m_ExtrapolationSigma = 3;
62 m_MEstimateConvergenceThreshold = 0.01;
63 m_NeighborhoodApproximation = 2.5;
64 m_BCHCompositionOrder = 1;
65 m_ExponentiationOrder = 1;
// Multi-resolution defaults: three pyramid levels, last processed level offset 0.
66 m_NumberOfPyramidLevels = 3;
67 m_LastPyramidLevel = 0;
68 m_PercentageKept = 0.8;
// Use ITK's global default thread count unless the caller overrides it.
69 this->SetNumberOfWorkUnits(itk::MultiThreaderBase::GetGlobalDefaultNumberOfThreads());
// C-style command that calls the static ManageProgress function with this
// object as client data.
74 m_callback = itk::CStyleCommand::New();
75 m_callback->SetClientData ((
void *)
this);
76 m_callback->SetCallback (ManageProgress);
79 template <
unsigned int ImageDimension>
84 template <
unsigned int ImageDimension>
// Main registration driver (presumably Update(); the signature line is not part of
// this listing). Visible steps: set up progress reporting, derive background values
// from the two input images, build the pyramids, run one block-matching registration
// per pyramid level, then scale and exponentiate the resulting velocity field and
// resample the floating image onto the reference grid.
// NOTE(review): many original lines (braces, else/case/try lines, some local
// declarations) are missing from this listing; comments below only describe what is visible.
94 template <
unsigned int ImageDimension>
// Progress reporter sized as pyramid levels times maximum iterations.
// NOTE(review): allocated with raw new; no matching delete is visible here - confirm ownership.
101 m_progressReporter =
new itk::ProgressReporter(
this, 0, GetNumberOfPyramidLevels()*m_MaximumIterations);
102 this->AddObserver(itk::ProgressEvent(), m_progressCallback);
104 this->InvokeEvent(itk::StartEvent());
// Minimum intensity of the reference image.
107 using MinMaxFilterType = itk::MinimumMaximumImageFilter <InputImageType>;
108 typename MinMaxFilterType::Pointer minMaxFilter = MinMaxFilterType::New();
109 minMaxFilter->SetInput(m_ReferenceImage);
110 if (this->GetNumberOfWorkUnits() != 0)
111 minMaxFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
112 minMaxFilter->Update();
114 m_ReferenceMinimalValue = minMaxFilter->GetMinimum();
// Same computation for the floating image, using a fresh filter instance.
116 minMaxFilter = MinMaxFilterType::New();
117 minMaxFilter->SetInput(m_FloatingImage);
118 if (this->GetNumberOfWorkUnits() != 0)
119 minMaxFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
120 minMaxFilter->Update();
122 m_FloatingMinimalValue = minMaxFilter->GetMinimum();
// Background values are snapped to -1024 when the minimum is negative, 0.0 otherwise
// (the lines between the two assignments are not shown; presumably an else).
125 if (m_ReferenceMinimalValue < 0.0)
126 m_ReferenceMinimalValue = -1024;
128 m_ReferenceMinimalValue = 0.0;
130 if (m_FloatingMinimalValue < 0.0)
131 m_FloatingMinimalValue = -1024;
133 m_FloatingMinimalValue = 0.0;
135 this->SetupPyramids();
// One pass per pyramid level, indexed through the reference pyramid.
138 for (
unsigned int i = 0;i < m_ReferencePyramid->GetNumberOfLevels();++i)
// Levels at or beyond (number of levels - m_LastPyramidLevel) trigger the statement
// on the following original line, which is not shown here.
140 if (i + m_LastPyramidLevel >= m_ReferencePyramid->GetNumberOfLevels())
// Detach the level images from their pyramids.
143 typename InputImageType::Pointer refImage = m_ReferencePyramid->GetOutput(i);
144 refImage->DisconnectPipeline();
146 typename InputImageType::Pointer floImage = m_FloatingPyramid->GetOutput(i);
147 floImage->DisconnectPipeline();
// The block generation mask stays null unless a mask pyramid was built.
149 typename MaskImageType::Pointer maskGenerationImage = ITK_NULLPTR;
150 if (m_BlockGenerationPyramid)
152 maskGenerationImage = m_BlockGenerationPyramid->GetOutput(i);
153 maskGenerationImage->DisconnectPipeline();
// If the output transform already holds a velocity field, resample it with an
// identity transform onto this level's reference grid (size, origin, spacing, direction).
// tmpIdentity and tmpOut are declared on lines not shown in this listing.
157 if (m_OutputTransform->GetParametersAsVectorField() != NULL)
159 typedef itk::ResampleImageFilter<VelocityFieldType,VelocityFieldType> VectorResampleFilterType;
160 typedef typename VectorResampleFilterType::Pointer VectorResampleFilterPointer;
163 tmpIdentity->SetIdentity();
165 VectorResampleFilterPointer tmpResample = VectorResampleFilterType::New();
166 tmpResample->SetTransform(tmpIdentity);
167 tmpResample->SetInput(m_OutputTransform->GetParametersAsVectorField());
169 tmpResample->SetSize(refImage->GetLargestPossibleRegion().GetSize());
170 tmpResample->SetOutputOrigin(refImage->GetOrigin());
171 tmpResample->SetOutputSpacing(refImage->GetSpacing());
172 tmpResample->SetOutputDirection(refImage->GetDirection());
174 tmpResample->Update();
177 m_OutputTransform->SetParametersAsVectorField(tmpOut);
178 tmpOut->DisconnectPipeline();
183 std::cout <<
"Processing pyramid level " << i << std::endl;
184 std::cout <<
"Image size: " << refImage->GetLargestPossibleRegion().GetSize() << std::endl;
// Mean voxel spacing of the current reference level; used below to scale distances.
187 double meanSpacing = 0;
188 for (
unsigned int j = 0;j < ImageDimension;++j)
189 meanSpacing += refImage->GetSpacing()[j];
190 meanSpacing /= ImageDimension;
// Agregator configuration; its creation and the body of this test are not shown.
202 if (this->GetNumberOfWorkUnits() != 0)
208 agreg->
SetDistanceBoundary(m_ExtrapolationSigma * meanSpacing * m_NeighborhoodApproximation);
220 if (this->GetNumberOfWorkUnits() != 0)
// Forward block matcher, always created; the reverse matcher is only created in one
// of the symmetry branches below and is otherwise left at 0.
// NOTE(review): raw new; a delete is visible for reverseMatcher only - confirm mainMatcher is released.
233 BlockMatcherType *mainMatcher =
new BlockMatcherType;
234 BlockMatcherType *reverseMatcher = 0;
236 mainMatcher->SetBlockSize(GetBlockSize());
237 mainMatcher->SetBlockSpacing(GetBlockSpacing());
// The matcher takes a variance threshold, hence the squared standard deviation.
238 mainMatcher->SetBlockVarianceThreshold(GetStDevThreshold() * GetStDevThreshold());
239 mainMatcher->SetBlockGenerationMask(maskGenerationImage);
240 mainMatcher->SetDefaultBackgroundValue(m_FloatingMinimalValue);
// Registration object selection; the case labels are not shown in this listing.
242 switch (m_SymmetryType)
247 m_bmreg = BlockMatchRegistrationType::New();
// Branch that additionally builds a reverse matcher, configured like the main one
// but with the reference background value.
254 typename BlockMatchRegistrationType::Pointer tmpReg = BlockMatchRegistrationType::New();
256 reverseMatcher =
new BlockMatcherType;
257 reverseMatcher->SetBlockPercentageKept(GetPercentageKept());
258 reverseMatcher->SetBlockSize(GetBlockSize());
259 reverseMatcher->SetBlockSpacing(GetBlockSpacing());
260 reverseMatcher->SetBlockVarianceThreshold(GetStDevThreshold() * GetStDevThreshold());
261 reverseMatcher->SetBlockGenerationMask(maskGenerationImage);
262 reverseMatcher->SetDefaultBackgroundValue(m_ReferenceMinimalValue);
263 reverseMatcher->SetVerbose(m_Verbose);
265 tmpReg->SetReverseBlockMatcher(reverseMatcher);
// Branch whose registration object receives both background values.
273 typename BlockMatchRegistrationType::Pointer tmpReg = BlockMatchRegistrationType::New();
274 tmpReg->SetReferenceBackgroundValue(m_ReferenceMinimalValue);
275 tmpReg->SetFloatingBackgroundValue(m_FloatingMinimalValue);
// Settings common to all branches.
282 mainMatcher->SetVerbose(m_Verbose);
283 m_bmreg->SetBlockMatcher(mainMatcher);
284 m_bmreg->SetBCHCompositionOrder(m_BCHCompositionOrder);
285 m_bmreg->SetExponentiationOrder(m_ExponentiationOrder);
// Progress events of the registration are routed to m_callback; whether this call
// sits inside the test below cannot be told from the lines shown.
287 if (m_progressCallback)
292 m_bmreg->AddObserver(itk::ProgressEvent(), m_callback);
295 m_bmreg->SetAgregator(agregPtr);
297 if (this->GetNumberOfWorkUnits() != 0)
298 m_bmreg->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
300 m_bmreg->SetFixedImage(refImage);
301 m_bmreg->SetMovingImage(floImage);
// Elastic regularisation sigma is scaled by the mean voxel spacing.
303 m_bmreg->SetSVFElasticRegSigma(m_ElasticSigma * meanSpacing);
// Resampler producing the reference on the floating grid, padded with the
// reference background value.
308 typename ResampleFilterType::Pointer refResampler = ResampleFilterType::New();
309 refResampler->SetSize(floImage->GetLargestPossibleRegion().GetSize());
310 refResampler->SetOutputOrigin(floImage->GetOrigin());
311 refResampler->SetOutputSpacing(floImage->GetSpacing());
312 refResampler->SetOutputDirection(floImage->GetDirection());
313 refResampler->SetDefaultPixelValue(m_ReferenceMinimalValue);
314 refResampler->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
315 m_bmreg->SetReferenceImageResampler(refResampler);
// Resampler producing the floating image on the reference grid, padded with the
// floating background value.
317 typename ResampleFilterType::Pointer movingResampler = ResampleFilterType::New();
318 movingResampler->SetSize(refImage->GetLargestPossibleRegion().GetSize());
319 movingResampler->SetOutputOrigin(refImage->GetOrigin());
320 movingResampler->SetOutputSpacing(refImage->GetSpacing());
321 movingResampler->SetOutputDirection(refImage->GetDirection());
322 movingResampler->SetDefaultPixelValue(m_FloatingMinimalValue);
323 movingResampler->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
324 m_bmreg->SetMovingImageResampler(movingResampler);
// Block transform type selection; only the Directional_Affine case is visible.
// NOTE(review): reverseMatcher can still be 0 here; the guard around the two
// reverseMatcher calls is on a line not shown - confirm it exists.
326 switch (GetTransform())
340 case Directional_Affine:
341 mainMatcher->SetBlockTransformType(BlockMatcherType::Superclass::Directional_Affine);
342 mainMatcher->SetAffineDirection(m_AffineDirection);
345 reverseMatcher->SetBlockTransformType(BlockMatcherType::Superclass::Directional_Affine);
346 reverseMatcher->SetAffineDirection(m_AffineDirection);
// Optimizer selection; the case labels and bodies are not shown.
358 switch (GetOptimizer())
// The current output transform is the starting point for this level.
394 m_bmreg->SetMaximumIterations(m_MaximumIterations);
395 m_bmreg->SetMinimalTransformError(m_MinimalTransformError);
396 m_bmreg->SetInitialTransform(m_OutputTransform.GetPointer());
// Search and optimizer parameters are read once into locals so that the same
// values can be handed to the reverse matcher below.
398 mainMatcher->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
399 mainMatcher->SetOptimizerMaximumIterations(GetOptimizerMaximumIterations());
401 double sr = GetSearchRadius();
402 mainMatcher->SetSearchRadius(sr);
404 double sar = GetSearchAngleRadius();
405 mainMatcher->SetSearchAngleRadius(sar);
407 double scr = GetSearchScaleRadius();
408 mainMatcher->SetSearchScaleRadius(scr);
410 double fr = GetFinalRadius();
411 mainMatcher->SetFinalRadius(fr);
413 double ss = GetStepSize();
414 mainMatcher->SetStepSize(ss);
416 double tub = GetTranslateUpperBound();
417 mainMatcher->SetTranslateMax(tub);
419 double aub = GetAngleUpperBound();
420 mainMatcher->SetAngleMax(aub);
422 double scub = GetScaleUpperBound();
423 mainMatcher->SetScaleMax(scub);
// Same parameters for the reverse matcher.
// NOTE(review): the null check guarding these calls is presumably on the lines not shown - confirm.
427 reverseMatcher->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
428 reverseMatcher->SetOptimizerMaximumIterations(GetOptimizerMaximumIterations());
430 reverseMatcher->SetSearchRadius(sr);
431 reverseMatcher->SetSearchAngleRadius(sar);
432 reverseMatcher->SetSearchScaleRadius(scr);
433 reverseMatcher->SetFinalRadius(fr);
434 reverseMatcher->SetStepSize(ss);
435 reverseMatcher->SetTranslateMax(tub);
436 reverseMatcher->SetAngleMax(aub);
437 reverseMatcher->SetScaleMax(scub);
440 m_bmreg->SetVerboseProgression(m_Verbose);
// ITK exceptions are caught and printed to std::cout; the guarded try block and
// whatever follows the print are not shown.
446 catch( itk::ExceptionObject & err )
448 std::cout <<
"ExceptionObject caught !" << err << std::endl;
// Copy this level's resulting velocity field into the output transform
// (resTrsf is declared on a line not shown).
453 m_OutputTransform->SetParametersAsVectorField(resTrsf->GetParametersAsVectorField());
457 delete reverseMatcher;
// Abort message; the condition that leads here is not shown.
463 std::cout <<
"Process aborted" << std::endl;
465 this->InvokeEvent(itk::EndEvent());
// The final velocity field is multiplied by 2.0 in place before exponentiation.
// NOTE(review): the reason for the factor 2 and the declarations of finalTrsfField
// and outputField are not visible in this listing.
470 typedef itk::MultiplyImageFilter <VelocityFieldType,itk::Image <double,ImageDimension>,
VelocityFieldType> MultiplyFilterType;
472 typename MultiplyFilterType::Pointer fieldMultiplier = MultiplyFilterType::New();
473 fieldMultiplier->SetInput(finalTrsfField);
474 fieldMultiplier->SetConstant(2.0);
475 fieldMultiplier->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
476 fieldMultiplier->InPlaceOn();
478 fieldMultiplier->Update();
481 m_OutputTransform->SetParametersAsVectorField(fieldMultiplier->GetOutput());
482 outputField->DisconnectPipeline();
// Exponentiate the velocity field into outputDispTrsf; the trailing false is
// passed as the last argument of GetSVFExponential.
486 anima::GetSVFExponential(m_OutputTransform.GetPointer(), outputDispTrsf.GetPointer(), m_ExponentiationOrder, GetNumberOfWorkUnits(),
false);
// Resample the full-resolution floating image onto the reference grid with that
// transform, padding with the floating background value, and keep the detached result.
490 typename ResampleFilterType::Pointer tmpResample = ResampleFilterType::New();
491 tmpResample->SetTransform(outputDispTrsf);
492 tmpResample->SetInput(m_FloatingImage);
494 tmpResample->SetSize(m_ReferenceImage->GetLargestPossibleRegion().GetSize());
495 tmpResample->SetOutputOrigin(m_ReferenceImage->GetOrigin());
496 tmpResample->SetOutputSpacing(m_ReferenceImage->GetSpacing());
497 tmpResample->SetOutputDirection(m_ReferenceImage->GetDirection());
498 tmpResample->SetDefaultPixelValue(m_FloatingMinimalValue);
499 tmpResample->Update();
501 m_OutputImage = tmpResample->GetOutput();
502 m_OutputImage->DisconnectPipeline();
// Returns a displacement field transform obtained by exponentiating the stored
// velocity field transform (m_OutputTransform) with order m_ExponentiationOrder.
// NOTE(review): the signature and the declaration of outputDispTrsf are on lines
// not shown in this listing.
505 template <
unsigned int ImageDimension>
// The trailing false is the last argument of GetSVFExponential; the thread count
// comes from this object's number of work units.
511 anima::GetSVFExponential(m_OutputTransform.GetPointer(), outputDispTrsf.GetPointer(), m_ExponentiationOrder, this->GetNumberOfWorkUnits(),
false);
513 return outputDispTrsf;
// Advances the progress reporter by one step when a reporter exists
// (presumably EmitProgress; the signature line is not part of this listing).
// The visible body does not use any progress value passed by the caller.
516 template <
unsigned int ImageDimension>
// m_progressReporter is checked for null before use.
520 if (m_progressReporter)
521 m_progressReporter->CompletedPixel();
// Observer callback body: forwards the progress of the calling process object,
// scaled by 100, to source->EmitProgress.
// NOTE(review): the signature and the declaration of source are on lines not
// shown; source is presumably derived from the callback's client data - confirm.
524 template <
unsigned int ImageDimension>
// C-style cast of the event emitter to a process object so its progress can be read.
528 itk::ProcessObject *processObject = (itk::ProcessObject *) caller;
// Nothing is forwarded unless both pointers are non-null.
530 if (source && processObject)
531 source->
EmitProgress(processObject->GetProgress() * 100);
// Writes the registration results to disk: always the output image, and the
// velocity field only when an output transform file name has been set.
// NOTE(review): the signature line is not part of this listing; the method name
// is presumed.
534 template <
unsigned int ImageDimension>
538 std::cout <<
"Writing output image to: " << m_resultFile << std::endl;
539 anima::writeImage <InputImageType> (m_resultFile,m_OutputImage);
// An empty file name means no transform output is written.
541 if (m_outputTransformFile !=
"")
543 std::cout <<
"Writing output SVF to: " << m_outputTransformFile << std::endl;
// The const_cast strips constness from the field pointer before it is handed to writeImage.
544 anima::writeImage <VelocityFieldType> (m_outputTransformFile,
545 const_cast <
VelocityFieldType *> (m_OutputTransform->GetParametersAsVectorField()));
// Builds the multi-resolution pyramids used by the registration: one for the
// reference image, one for the floating image, and one for the block generation
// mask when a mask is set (presumably SetupPyramids; the signature line and the
// type aliases are not part of this listing).
549 template <
unsigned int ImageDimension>
// Reference pyramid.
554 m_ReferencePyramid = PyramidType::New();
556 m_ReferencePyramid->SetInput(m_ReferenceImage);
557 m_ReferencePyramid->SetNumberOfLevels(m_NumberOfPyramidLevels);
559 if (this->GetNumberOfWorkUnits() != 0)
560 m_ReferencePyramid->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
// The pyramid's resampler pads with the reference background value.
565 typename ResampleFilterType::Pointer refResampler = ResampleFilterType::New();
566 refResampler->SetDefaultPixelValue(m_ReferenceMinimalValue);
567 m_ReferencePyramid->SetImageResampler(refResampler);
569 m_ReferencePyramid->Update();
// Floating pyramid, configured the same way with the floating background value.
572 m_FloatingPyramid = PyramidType::New();
574 m_FloatingPyramid->SetInput(m_FloatingImage);
575 m_FloatingPyramid->SetNumberOfLevels(m_NumberOfPyramidLevels);
577 if (this->GetNumberOfWorkUnits() != 0)
578 m_FloatingPyramid->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
580 typename ResampleFilterType::Pointer floResampler = ResampleFilterType::New();
581 floResampler->SetDefaultPixelValue(m_FloatingMinimalValue);
582 m_FloatingPyramid->SetImageResampler(floResampler);
584 m_FloatingPyramid->Update();
// The mask pyramid is reset first and rebuilt only when a block generation mask exists.
586 m_BlockGenerationPyramid = 0;
587 if (m_BlockGenerationMask)
592 typename MaskResampleFilterType::Pointer maskResampler = MaskResampleFilterType::New();
594 m_BlockGenerationPyramid = MaskPyramidType::New();
595 m_BlockGenerationPyramid->SetImageResampler(maskResampler);
596 m_BlockGenerationPyramid->SetInput(m_BlockGenerationMask);
597 m_BlockGenerationPyramid->SetNumberOfLevels(GetNumberOfPyramidLevels());
// Unlike the two image pyramids, the work unit count is set here without a != 0 test.
598 m_BlockGenerationPyramid->SetNumberOfWorkUnits(GetNumberOfWorkUnits());
599 m_BlockGenerationPyramid->Update();
void Update() ITK_OVERRIDE
PyramidalDenseSVFMatchingBridge()
itk::Image< unsigned char, ImageDimension > MaskImageType
virtual ~PyramidalDenseSVFMatchingBridge()
void EmitProgress(int prog)
BaseTransformType::VectorFieldType VelocityFieldType
itk::Image< double, ImageDimension > InputImageType
static void ManageProgress(itk::Object *caller, const itk::EventObject &event, void *clientData)
AffineTransformType::Pointer AffineTransformPointer
DisplacementFieldTransformType::Pointer DisplacementFieldTransformPointer
void SetBlockPercentageKept(double val)
void GetSVFExponential(itk::StationaryVelocityFieldTransform< ScalarType, NDimensions > *baseTrsf, rpi::DisplacementFieldTransform< ScalarType, NDimensions > *resultTransform, unsigned int exponentiationOrder, unsigned int numThreads, bool invert)
DisplacementFieldTransformPointer GetOutputDisplacementFieldTransform()
MEstimateAgregatorType::BaseOutputTransformType BaseTransformType