9 #include <itkComposeDisplacementFieldsImageFilter.h> 10 #include <itkVectorLinearInterpolateNearestNeighborExtrapolateImageFunction.h> 12 #include <itkSubtractImageFilter.h> 13 #include <itkMultiplyImageFilter.h> 18 template <
typename TInputImageType>
20 DistortionCorrectionBMRegistrationMethod <TInputImageType>
// NOTE(review): this listing is an extraction with lines missing (method name and
// parameters, braces, the declaration of tmpTrsf and whatever separates the two
// assignments below); the leading integers are original file line numbers.
// Chooses the transform to optimize: reuse m_CurrentTransform when it is set,
// otherwise (presumably via an else branch on lines not shown) fall back to
// tmpTrsf reset to the identity.
23 if (m_CurrentTransform)
24 optimizedTransform = m_CurrentTransform;
// Fallback path: an identity transform.
28 tmpTrsf->SetIdentity();
29 optimizedTransform = tmpTrsf;
33 template <
typename TInputImageType>
// Prepares the images for one block-matching pass by warping both of them:
// a displacement field positiveTrsf and its opposite negativeTrsf are built from
// the current transform (and the initial transform, if any). The fixed image is
// then resampled with negativeTrsf. The moving image is resampled with the
// moving-image resampler, whose SetTransform call is on lines not shown
// (presumably positiveTrsf).
// NOTE(review): lines are missing from this listing (signature, braces, the
// declarations of currentTrsf/initTrsf/positiveTrsf/negativeTrsf/tmpVec, the else
// keyword and the iterator increments); leading integers are original line numbers.
41 typedef typename DisplacementFieldTransformType::VectorFieldType VectorFieldType;
42 typedef itk::ComposeDisplacementFieldsImageFilter <VectorFieldType,VectorFieldType> ComposeFilterType;
43 typedef itk::MultiplyImageFilter <VectorFieldType,itk::Image <double, InputImageType::ImageDimension>, VectorFieldType> MultiplyFilterType;
44 typedef typename itk::ImageRegionIterator <VectorFieldType> VectorFieldIterator;
45 typedef typename VectorFieldType::PixelType VectorType;
// Per the ITK class name: linear interpolation inside the field,
// nearest-neighbour extrapolation outside it.
47 typedef itk::VectorLinearInterpolateNearestNeighborExtrapolateImageFunction <VectorFieldType,
48 typename TransformType::ParametersValueType> VectorInterpolateFunctionType;
// Branch taken when an initial transform was supplied.
50 if (this->GetInitialTransform())
53 positiveTrsf = DisplacementFieldTransformType::New();
// Positive field: compose the current field (warping field) with the initial
// field (displacement field).
55 typename ComposeFilterType::Pointer composePositiveFilter = ComposeFilterType::New();
57 composePositiveFilter->SetWarpingField(currentTrsf->GetParametersAsVectorField());
60 composePositiveFilter->SetDisplacementField(initTrsf->GetParametersAsVectorField());
61 composePositiveFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
63 typename VectorInterpolateFunctionType::Pointer interpolator = VectorInterpolateFunctionType::New();
65 composePositiveFilter->SetInterpolator(interpolator);
66 composePositiveFilter->Update();
67 positiveTrsf->SetParametersAsVectorField(composePositiveFilter->GetOutput());
// Negative field: compose the negated current field with the negated initial
// field; each negation is a MultiplyImageFilter by -1.0.
69 typename MultiplyFilterType::Pointer multiplyInitFilter = MultiplyFilterType::New();
70 multiplyInitFilter->SetInput(initTrsf->GetParametersAsVectorField());
71 multiplyInitFilter->SetConstant(-1.0);
72 multiplyInitFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
74 multiplyInitFilter->Update();
76 typename MultiplyFilterType::Pointer multiplyCurrentFilter = MultiplyFilterType::New();
77 multiplyCurrentFilter->SetInput(currentTrsf->GetParametersAsVectorField());
78 multiplyCurrentFilter->SetConstant(-1.0);
79 multiplyCurrentFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
81 multiplyCurrentFilter->Update();
83 typename ComposeFilterType::Pointer composeNegativeFilter = ComposeFilterType::New();
84 composeNegativeFilter->SetWarpingField(multiplyCurrentFilter->GetOutput());
85 composeNegativeFilter->SetDisplacementField(multiplyInitFilter->GetOutput());
86 composeNegativeFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
// A fresh interpolator instance for the second composition.
88 interpolator = VectorInterpolateFunctionType::New();
90 composeNegativeFilter->SetInterpolator(interpolator);
91 composeNegativeFilter->Update();
92 negativeTrsf = DisplacementFieldTransformType::New();
93 negativeTrsf->SetParametersAsVectorField(composeNegativeFilter->GetOutput());
// Symmetrize in place: each voxel pair (p, n) becomes (0.5*(p - n), -0.5*(p - n)),
// so the two fields are exact opposites. const_cast is used to obtain writable
// iterators over the transforms' fields.
// NOTE(review): the ++positiveItr / ++negativeItr increments are on lines not
// shown; both fields are assumed to share the same largest possible region.
95 VectorFieldIterator positiveItr(const_cast <VectorFieldType *> (positiveTrsf->GetParametersAsVectorField()),
96 positiveTrsf->GetParametersAsVectorField()->GetLargestPossibleRegion());
98 VectorFieldIterator negativeItr(const_cast <VectorFieldType *> (negativeTrsf->GetParametersAsVectorField()),
99 negativeTrsf->GetParametersAsVectorField()->GetLargestPossibleRegion());
103 while (!positiveItr.IsAtEnd())
105 tmpVec = 0.5 * (positiveItr.Get() - negativeItr.Get());
106 positiveItr.Set(tmpVec);
107 negativeItr.Set(- tmpVec);
// No initial transform (branch header not shown): negativeTrsf is the voxel-wise
// negation (x -1.0) of positiveTrsf, whose construction is also on lines not shown.
117 negativeTrsf = DisplacementFieldTransformType::New();
119 typename MultiplyFilterType::Pointer multiplyFilter = MultiplyFilterType::New();
120 multiplyFilter->SetInput(positiveTrsf->GetParametersAsVectorField());
121 multiplyFilter->SetConstant(-1.0);
122 multiplyFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
124 multiplyFilter->Update();
125 negativeTrsf->SetParametersAsVectorField(multiplyFilter->GetOutput());
// Resampling of both images.
// NOTE(review): the dynamic_cast results are used without a null check (the
// reference-image one is dereferenced immediately); a resampler of another type
// would crash here — confirm InternalFilterType is guaranteed.
129 typedef itk::Image <ImageScalarType, TInputImageType::ImageDimension> InternalScalarImageType;
131 InternalFilterType *resampleFilter = dynamic_cast <InternalFilterType *> (this->GetMovingImageResampler().GetPointer());
134 this->GetMovingImageResampler()->SetInput(this->GetMovingImage());
136 this->GetMovingImageResampler()->Update();
// Detach the output so it persists independently of later resampler updates.
138 movingImage = this->GetMovingImageResampler()->GetOutput();
139 movingImage->DisconnectPipeline();
// The fixed image is warped with the opposite (negative) field.
142 resampleFilter = dynamic_cast <InternalFilterType *> (this->GetReferenceImageResampler().GetPointer());
143 resampleFilter->SetTransform(negativeTrsf);
144 resampleFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
146 this->GetReferenceImageResampler()->SetInput(this->GetFixedImage());
147 this->GetReferenceImageResampler()->Update();
149 refImage = this->GetReferenceImageResampler()->GetOutput();
150 refImage->DisconnectPipeline();
153 template <
typename TInputImageType>
// Composes the SVF add-on with the current transform:
//  - addOnCast is exponentiated forward into positiveDispTrsf;
//  - addOnCast is exponentiated with invert=true into negativeDispTrsf;
//  - both are combined with computedTransformCast by
//    anima::composeDistortionCorrections.
// NOTE(review): signature, braces and the declarations/casts of addOnCast,
// computedTransformCast, positiveDispTrsf, negativeDispTrsf and SmoothingFilterType
// are on lines missing from this listing; leading integers are original line numbers.
161 anima::GetSVFExponential(addOnCast,positiveDispTrsf.GetPointer(),this->GetExponentiationOrder(),this->GetNumberOfWorkUnits(),
false);
164 anima::GetSVFExponential(addOnCast,negativeDispTrsf.GetPointer(),this->GetExponentiationOrder(),this->GetNumberOfWorkUnits(),
true);
167 anima::composeDistortionCorrections<typename AgregatorType::ScalarType, InputImageType::ImageDimension>
168 (computedTransformCast,positiveDispTrsf,negativeDispTrsf,this->GetNumberOfWorkUnits());
// Optional regularization: smooth the composed field with
// sigma = GetSVFElasticRegSigma() when that value is positive.
171 if (this->GetSVFElasticRegSigma() > 0)
173 typedef typename DisplacementFieldTransformType::VectorFieldType VectorFieldType;
175 typename SmoothingFilterType::Pointer smootherPtr = SmoothingFilterType::New();
177 smootherPtr->SetInput(computedTransformCast->GetParametersAsVectorField());
178 smootherPtr->SetSigma(this->GetSVFElasticRegSigma());
179 smootherPtr->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
181 smootherPtr->Update();
183 typename VectorFieldType::Pointer tmpSmoothed = smootherPtr->GetOutput();
184 tmpSmoothed->DisconnectPipeline();
// NOTE(review): Register() adds a manual reference that is never released
// (no matching UnRegister()) in the visible code. If SetParametersAsVectorField
// keeps its own smart pointer to the field, this leaks one vector field per call —
// confirm against the displacement-field transform class.
185 tmpSmoothed->Register();
187 computedTransformCast->SetParametersAsVectorField(tmpSmoothed);
190 computedTransform = computedTransformCast;
195 template <
typename TInputImageType>
// One symmetric block-matching iteration:
//  1. block matching is run with refImage as the reference image;
//  2. it is run again with the roles of refImage and movingImage swapped;
//  3. each result is fed to the agregator;
//  4. the two resulting SVFs are combined as 0.25 * (forward - backward) into
//     the add-on.
// NOTE(review): lines are missing from this listing: signature, braces, the forward
// SetMovingImage call, tmpTime.Start()/Stop(), the agregator update/output
// retrieval, and the definitions of tmpTrsf and positiveAddOn. Leading integers
// are original line numbers.
200 itk::TimeProbe tmpTime;
// Ask the block matcher to recompute its blocks for this pass.
203 this->GetBlockMatcher()->SetForceComputeBlocks(
true);
204 this->GetBlockMatcher()->SetReferenceImage(refImage);
206 this->GetBlockMatcher()->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
207 this->GetBlockMatcher()->Update();
211 if (this->GetVerboseProgression())
212 std::cout <<
"Forward matching performed in " << tmpTime.GetTotal() << std::endl;
// Feed the forward block matches to the agregator.
214 this->GetAgregator()->SetInputRegions(this->GetBlockMatcher()->GetBlockRegions());
215 this->GetAgregator()->SetInputOrigins(this->GetBlockMatcher()->GetBlockPositions());
217 this->GetAgregator()->SetInputWeights(this->GetBlockMatcher()->GetBlockWeights());
218 this->GetAgregator()->SetInputTransforms(this->GetBlockMatcher()->GetBlockTransformPointers());
// Keep the forward SVF. Disconnecting it presumably prevents the next agregation
// from overwriting it (assumes tmpTrsf holds the agregator output — not shown).
222 typedef typename SVFTransformType::VectorFieldType VectorFieldType;
224 typename VectorFieldType::Pointer positiveSVF = const_cast <VectorFieldType *> (tmpTrsf->GetParametersAsVectorField());
225 positiveSVF->DisconnectPipeline();
// Backward matching: swap the reference and moving images.
227 itk::TimeProbe tmpTimeReverse;
228 tmpTimeReverse.Start();
230 this->GetBlockMatcher()->SetReferenceImage(movingImage);
231 this->GetBlockMatcher()->SetMovingImage(refImage);
232 this->GetBlockMatcher()->Update();
234 tmpTimeReverse.Stop();
236 if (this->GetVerboseProgression())
237 std::cout <<
"Backward matching performed in " << tmpTimeReverse.GetTotal() << std::endl;
// Feed the backward block matches to the agregator.
239 this->GetAgregator()->SetInputRegions(this->GetBlockMatcher()->GetBlockRegions());
240 this->GetAgregator()->SetInputOrigins(this->GetBlockMatcher()->GetBlockPositions());
242 this->GetAgregator()->SetInputWeights(this->GetBlockMatcher()->GetBlockWeights());
243 this->GetAgregator()->SetInputTransforms(this->GetBlockMatcher()->GetBlockTransformPointers());
247 typename VectorFieldType::Pointer negativeSVF = const_cast <VectorFieldType *> (tmpTrsf->GetParametersAsVectorField());
248 negativeSVF->DisconnectPipeline();
250 typedef itk::MultiplyImageFilter <VectorFieldType,itk::Image <double,InputImageType::ImageDimension>,VectorFieldType> MultiplyFilterType;
251 typedef itk::SubtractImageFilter <VectorFieldType,VectorFieldType,VectorFieldType> SubtractFilterType;
// Combine the estimates: (forward - backward) * 0.25. Both filters run in place.
// NOTE(review): with InPlaceOn the subtraction may reuse positiveSVF's buffer,
// which was obtained via const_cast from a transform. This is safe only if nothing
// else still reads that field — verify.
// NOTE(review): the 0.25 factor presumably averages the two estimates (1/2) and
// splits the correction between the two opposite half-warps (1/2) — confirm.
253 typename SubtractFilterType::Pointer subFilter = SubtractFilterType::New();
254 subFilter->SetInput1(positiveSVF);
255 subFilter->SetInput2(negativeSVF);
256 subFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
257 subFilter->InPlaceOn();
261 typename MultiplyFilterType::Pointer multiplyFilter = MultiplyFilterType::New();
262 multiplyFilter->SetInput(subFilter->GetOutput());
263 multiplyFilter->SetConstant(0.25);
264 multiplyFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
265 multiplyFilter->InPlaceOn();
267 multiplyFilter->Update();
269 positiveSVF = multiplyFilter->GetOutput();
270 positiveSVF->DisconnectPipeline();
// Store the combined SVF in tmpTrsf and hand positiveAddOn back as the add-on.
273 tmpTrsf->SetParametersAsVectorField(positiveSVF);
274 addOn = positiveAddOn;
virtual void SetMovingImage(InputImageType *_arg)
TInputImageType InputImageType
DisplacementFieldTransformType::Pointer DisplacementFieldTransformPointer
AgregatorType::BaseOutputTransformType TransformType
rpi::DisplacementFieldTransform< AgregatorScalarType, TInputImageType::ImageDimension > DisplacementFieldTransformType
InputImageType::Pointer InputImagePointer
void GetSVFExponential(itk::StationaryVelocityFieldTransform< ScalarType, NDimensions > *baseTrsf, rpi::DisplacementFieldTransform< ScalarType, NDimensions > *resultTransform, unsigned int exponentiationOrder, unsigned int numThreads, bool invert)
itk::StationaryVelocityFieldTransform< AgregatorScalarType, TInputImageType::ImageDimension > SVFTransformType
void SetTransform(TransformType *transform)
TransformType::Pointer TransformPointer