ANIMA  4.0
animaDistortionCorrectionBMRegistrationMethod.hxx
Go to the documentation of this file.
1 #pragma once
3 
8 
9 #include <itkComposeDisplacementFieldsImageFilter.h>
10 #include <itkVectorLinearInterpolateNearestNeighborExtrapolateImageFunction.h>
11 
12 #include <itkSubtractImageFilter.h>
13 #include <itkMultiplyImageFilter.h>
14 
15 namespace anima
16 {
17 
18 template <typename TInputImageType>
19 void
20 DistortionCorrectionBMRegistrationMethod <TInputImageType>
21 ::SetupTransform(TransformPointer &optimizedTransform)
22 {
23  if (m_CurrentTransform)
24  optimizedTransform = m_CurrentTransform;
25  else
26  {
27  DisplacementFieldTransformPointer tmpTrsf = DisplacementFieldTransformType::New();
28  tmpTrsf->SetIdentity();
29  optimizedTransform = tmpTrsf;
30  }
31 }
32 
// ResampleImages: builds two opposite displacement-field transforms (positiveTrsf / negativeTrsf)
// from currentTransform, combined with the initial transform when one is set. It then resamples
// the moving image through positiveTrsf and the fixed image through negativeTrsf, and returns
// both results detached from the resampling pipelines.
//   currentTransform - transform being optimized; dynamic_cast to DisplacementFieldTransformType
//                      (NOTE(review): the cast result is never checked for null)
//   refImage         - output: fixed image resampled through negativeTrsf
//   movingImage      - output: moving image resampled through positiveTrsf
33 template <typename TInputImageType>
34 void
36 ::ResampleImages(TransformType *currentTransform, InputImagePointer &refImage, InputImagePointer &movingImage)
37 {
40 
41  typedef typename DisplacementFieldTransformType::VectorFieldType VectorFieldType;
42  typedef itk::ComposeDisplacementFieldsImageFilter <VectorFieldType,VectorFieldType> ComposeFilterType;
43  typedef itk::MultiplyImageFilter <VectorFieldType,itk::Image <double, InputImageType::ImageDimension>, VectorFieldType> MultiplyFilterType;
44  typedef typename itk::ImageRegionIterator <VectorFieldType> VectorFieldIterator;
45  typedef typename VectorFieldType::PixelType VectorType;
46 
47  typedef itk::VectorLinearInterpolateNearestNeighborExtrapolateImageFunction <VectorFieldType,
48  typename TransformType::ParametersValueType> VectorInterpolateFunctionType;
49 
// With an initial transform:
//   - positive = composition of the current field (warping) with the initial field (displacement);
//   - negative = the same composition applied to the negated fields.
// Both are then re-symmetrized voxel-wise in the loop below.
50  if (this->GetInitialTransform())
51  {
52  // Here compose and make things straight again
53  positiveTrsf = DisplacementFieldTransformType::New();
54 
55  typename ComposeFilterType::Pointer composePositiveFilter = ComposeFilterType::New();
56  DisplacementFieldTransformType *currentTrsf = dynamic_cast <DisplacementFieldTransformType *> (currentTransform);
57  composePositiveFilter->SetWarpingField(currentTrsf->GetParametersAsVectorField());
58 
// NOTE(review): assumes the initial transform is also a DisplacementFieldTransformType; not checked
59  DisplacementFieldTransformType *initTrsf = dynamic_cast <DisplacementFieldTransformType *> (this->GetInitialTransform().GetPointer());
60  composePositiveFilter->SetDisplacementField(initTrsf->GetParametersAsVectorField());
61  composePositiveFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
62 
63  typename VectorInterpolateFunctionType::Pointer interpolator = VectorInterpolateFunctionType::New();
64 
65  composePositiveFilter->SetInterpolator(interpolator);
66  composePositiveFilter->Update();
67  positiveTrsf->SetParametersAsVectorField(composePositiveFilter->GetOutput());
68 
// Negated copies of the initial and current fields, used to build the negative composition
69  typename MultiplyFilterType::Pointer multiplyInitFilter = MultiplyFilterType::New();
70  multiplyInitFilter->SetInput(initTrsf->GetParametersAsVectorField());
71  multiplyInitFilter->SetConstant(-1.0);
72  multiplyInitFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
73 
74  multiplyInitFilter->Update();
75 
76  typename MultiplyFilterType::Pointer multiplyCurrentFilter = MultiplyFilterType::New();
77  multiplyCurrentFilter->SetInput(currentTrsf->GetParametersAsVectorField());
78  multiplyCurrentFilter->SetConstant(-1.0);
79  multiplyCurrentFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
80 
81  multiplyCurrentFilter->Update();
82 
83  typename ComposeFilterType::Pointer composeNegativeFilter = ComposeFilterType::New();
84  composeNegativeFilter->SetWarpingField(multiplyCurrentFilter->GetOutput());
85  composeNegativeFilter->SetDisplacementField(multiplyInitFilter->GetOutput());
86  composeNegativeFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
87 
// A new interpolator instance is created for the second composition
88  interpolator = VectorInterpolateFunctionType::New();
89 
90  composeNegativeFilter->SetInterpolator(interpolator);
91  composeNegativeFilter->Update();
92  negativeTrsf = DisplacementFieldTransformType::New();
93  negativeTrsf->SetParametersAsVectorField(composeNegativeFilter->GetOutput());
94 
// The iterators write directly into the fields held by positiveTrsf / negativeTrsf (via const_cast).
// NOTE(review): this relies on the transforms not caching anything derived from the field at
// SetParametersAsVectorField time; confirm in the transform implementation.
95  VectorFieldIterator positiveItr(const_cast <VectorFieldType *> (positiveTrsf->GetParametersAsVectorField()),
96  positiveTrsf->GetParametersAsVectorField()->GetLargestPossibleRegion());
97 
98  VectorFieldIterator negativeItr(const_cast <VectorFieldType *> (negativeTrsf->GetParametersAsVectorField()),
99  negativeTrsf->GetParametersAsVectorField()->GetLargestPossibleRegion());
100 
101  // And compose them to get a transformation conform to disto correction requirements
// positive <- 0.5 * (positive - negative), negative <- -positive: the two fields become exact opposites.
// Both iterators walk their LargestPossibleRegion in lockstep (fields assumed to share the same grid).
102  VectorType tmpVec;
103  while (!positiveItr.IsAtEnd())
104  {
105  tmpVec = 0.5 * (positiveItr.Get() - negativeItr.Get());
106  positiveItr.Set(tmpVec);
107  negativeItr.Set(- tmpVec);
108 
109  ++positiveItr;
110  ++negativeItr;
111  }
112  }
113  else
114  {
115  // Just use the current transform which is already in good shape
// positiveTrsf shares currentTransform itself (no copy); negativeTrsf holds its voxel-wise negation
116  positiveTrsf = dynamic_cast <DisplacementFieldTransformType *> (currentTransform);
117  negativeTrsf = DisplacementFieldTransformType::New();
118 
119  typename MultiplyFilterType::Pointer multiplyFilter = MultiplyFilterType::New();
120  multiplyFilter->SetInput(positiveTrsf->GetParametersAsVectorField());
121  multiplyFilter->SetConstant(-1.0);
122  multiplyFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
123 
124  multiplyFilter->Update();
125  negativeTrsf->SetParametersAsVectorField(multiplyFilter->GetOutput());
126  }
127 
128  // Anatomical resampling
129  typedef itk::Image <ImageScalarType, TInputImageType::ImageDimension> InternalScalarImageType;
// NOTE(review): the resampler casts are not null-checked
131  InternalFilterType *resampleFilter = dynamic_cast <InternalFilterType *> (this->GetMovingImageResampler().GetPointer());
132 
// NOTE(review): unlike the reference resampler below, SetNumberOfWorkUnits is not called on the
// moving resampler here; confirm whether it is configured elsewhere or this is an omission
133  resampleFilter->SetTransform(positiveTrsf);
134  this->GetMovingImageResampler()->SetInput(this->GetMovingImage());
135 
136  this->GetMovingImageResampler()->Update();
137 
// Detach the result so later resampler updates do not overwrite the returned image
138  movingImage = this->GetMovingImageResampler()->GetOutput();
139  movingImage->DisconnectPipeline();
140 
141  // Fixed image resampling
142  resampleFilter = dynamic_cast <InternalFilterType *> (this->GetReferenceImageResampler().GetPointer());
143  resampleFilter->SetTransform(negativeTrsf);
144  resampleFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
145 
146  this->GetReferenceImageResampler()->SetInput(this->GetFixedImage());
147  this->GetReferenceImageResampler()->Update();
148 
149  refImage = this->GetReferenceImageResampler()->GetOutput();
150  refImage->DisconnectPipeline();
151 }
152 
// ComposeAddOnWithTransform: applies an SVF update (addOn) to computedTransform.
// Steps:
//   1. Exponentiate addOn twice: forward (invert = false) and inverse (invert = true).
//   2. Merge both into computedTransform with anima::composeDistortionCorrections.
//   3. Optionally smooth the resulting field with a Gaussian-type filter of sigma GetSVFElasticRegSigma().
//   4. Store the result back in computedTransform.
//   computedTransform - in/out: expected to hold a DisplacementFieldTransformType
//                       (NOTE(review): dynamic_cast result not checked for null)
//   addOn             - update, expected to be an SVFTransformType (cast not checked either)
// Always returns true.
153 template <typename TInputImageType>
154 bool
156 ::ComposeAddOnWithTransform(TransformPointer &computedTransform, TransformType *addOn)
157 {
158  // Now compute positive and negative updated transform
159  DisplacementFieldTransformPointer positiveDispTrsf = DisplacementFieldTransformType::New();
160  SVFTransformType *addOnCast = dynamic_cast <SVFTransformType *> (addOn);
161  anima::GetSVFExponential(addOnCast,positiveDispTrsf.GetPointer(),this->GetExponentiationOrder(),this->GetNumberOfWorkUnits(),false);
162 
// Same exponentiation with invert = true gives the opposite-direction update
163  DisplacementFieldTransformPointer negativeDispTrsf = DisplacementFieldTransformType::New();
164  anima::GetSVFExponential(addOnCast,negativeDispTrsf.GetPointer(),this->GetExponentiationOrder(),this->GetNumberOfWorkUnits(),true);
165 
// computedTransformCast carries the composition result (it is what gets smoothed and returned below);
// presumably composeDistortionCorrections updates it in place - confirm in that helper
166  DisplacementFieldTransformPointer computedTransformCast = dynamic_cast <DisplacementFieldTransformType *> (computedTransform.GetPointer());
167  anima::composeDistortionCorrections<typename AgregatorType::ScalarType, InputImageType::ImageDimension>
168  (computedTransformCast,positiveDispTrsf,negativeDispTrsf,this->GetNumberOfWorkUnits());
169 
170  // Smooth (elastic)
171  if (this->GetSVFElasticRegSigma() > 0)
172  {
173  typedef typename DisplacementFieldTransformType::VectorFieldType VectorFieldType;
175  typename SmoothingFilterType::Pointer smootherPtr = SmoothingFilterType::New();
176 
177  smootherPtr->SetInput(computedTransformCast->GetParametersAsVectorField());
178  smootherPtr->SetSigma(this->GetSVFElasticRegSigma());
179  smootherPtr->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
180 
181  smootherPtr->Update();
182 
183  typename VectorFieldType::Pointer tmpSmoothed = smootherPtr->GetOutput();
184  tmpSmoothed->DisconnectPipeline();
// NOTE(review): this manual Register() adds a reference with no matching UnRegister() visible here.
// The smart pointer already keeps the field alive, so unless SetParametersAsVectorField stores only
// a raw pointer, this leaks one vector field per call - confirm before removing.
185  tmpSmoothed->Register();
186 
187  computedTransformCast->SetParametersAsVectorField(tmpSmoothed);
188  }
189 
190  computedTransform = computedTransformCast;
191 
192  return true;
193 }
194 
195 template <typename TInputImageType>
196 void
198 ::PerformOneIteration(InputImageType *refImage, InputImageType *movingImage, TransformPointer &addOn)
199 {
200  itk::TimeProbe tmpTime;
201  tmpTime.Start();
202 
203  this->GetBlockMatcher()->SetForceComputeBlocks(true);
204  this->GetBlockMatcher()->SetReferenceImage(refImage);
205  this->GetBlockMatcher()->SetMovingImage(movingImage);
206  this->GetBlockMatcher()->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
207  this->GetBlockMatcher()->Update();
208 
209  tmpTime.Stop();
210 
211  if (this->GetVerboseProgression())
212  std::cout << "Forward matching performed in " << tmpTime.GetTotal() << std::endl;
213 
214  this->GetAgregator()->SetInputRegions(this->GetBlockMatcher()->GetBlockRegions());
215  this->GetAgregator()->SetInputOrigins(this->GetBlockMatcher()->GetBlockPositions());
216 
217  this->GetAgregator()->SetInputWeights(this->GetBlockMatcher()->GetBlockWeights());
218  this->GetAgregator()->SetInputTransforms(this->GetBlockMatcher()->GetBlockTransformPointers());
219 
220  TransformPointer positiveAddOn = this->GetAgregator()->GetOutput();
221 
222  typedef typename SVFTransformType::VectorFieldType VectorFieldType;
223  SVFTransformType *tmpTrsf = dynamic_cast <SVFTransformType *> (positiveAddOn.GetPointer());
224  typename VectorFieldType::Pointer positiveSVF = const_cast <VectorFieldType *> (tmpTrsf->GetParametersAsVectorField());
225  positiveSVF->DisconnectPipeline();
226 
227  itk::TimeProbe tmpTimeReverse;
228  tmpTimeReverse.Start();
229 
230  this->GetBlockMatcher()->SetReferenceImage(movingImage);
231  this->GetBlockMatcher()->SetMovingImage(refImage);
232  this->GetBlockMatcher()->Update();
233 
234  tmpTimeReverse.Stop();
235 
236  if (this->GetVerboseProgression())
237  std::cout << "Backward matching performed in " << tmpTimeReverse.GetTotal() << std::endl;
238 
239  this->GetAgregator()->SetInputRegions(this->GetBlockMatcher()->GetBlockRegions());
240  this->GetAgregator()->SetInputOrigins(this->GetBlockMatcher()->GetBlockPositions());
241 
242  this->GetAgregator()->SetInputWeights(this->GetBlockMatcher()->GetBlockWeights());
243  this->GetAgregator()->SetInputTransforms(this->GetBlockMatcher()->GetBlockTransformPointers());
244 
245  TransformPointer negativeAddOn = this->GetAgregator()->GetOutput();
246  tmpTrsf = dynamic_cast <SVFTransformType *> (negativeAddOn.GetPointer());
247  typename VectorFieldType::Pointer negativeSVF = const_cast <VectorFieldType *> (tmpTrsf->GetParametersAsVectorField());
248  negativeSVF->DisconnectPipeline();
249 
250  typedef itk::MultiplyImageFilter <VectorFieldType,itk::Image <double,InputImageType::ImageDimension>,VectorFieldType> MultiplyFilterType;
251  typedef itk::SubtractImageFilter <VectorFieldType,VectorFieldType,VectorFieldType> SubtractFilterType;
252 
253  typename SubtractFilterType::Pointer subFilter = SubtractFilterType::New();
254  subFilter->SetInput1(positiveSVF);
255  subFilter->SetInput2(negativeSVF);
256  subFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
257  subFilter->InPlaceOn();
258 
259  subFilter->Update();
260 
261  typename MultiplyFilterType::Pointer multiplyFilter = MultiplyFilterType::New();
262  multiplyFilter->SetInput(subFilter->GetOutput());
263  multiplyFilter->SetConstant(0.25);
264  multiplyFilter->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
265  multiplyFilter->InPlaceOn();
266 
267  multiplyFilter->Update();
268 
269  positiveSVF = multiplyFilter->GetOutput();
270  positiveSVF->DisconnectPipeline();
271 
272  tmpTrsf = dynamic_cast <SVFTransformType *> (positiveAddOn.GetPointer());
273  tmpTrsf->SetParametersAsVectorField(positiveSVF);
274  addOn = positiveAddOn;
275 }
276 
277 }
virtual void SetMovingImage(InputImageType *_arg)
DisplacementFieldTransformType::Pointer DisplacementFieldTransformPointer
AgregatorType::BaseOutputTransformType TransformType
rpi::DisplacementFieldTransform< AgregatorScalarType, TInputImageType::ImageDimension > DisplacementFieldTransformType
void GetSVFExponential(itk::StationaryVelocityFieldTransform< ScalarType, NDimensions > *baseTrsf, rpi::DisplacementFieldTransform< ScalarType, NDimensions > *resultTransform, unsigned int exponentiationOrder, unsigned int numThreads, bool invert)
itk::StationaryVelocityFieldTransform< AgregatorScalarType, TInputImageType::ImageDimension > SVFTransformType
void SetTransform(TransformType *transform)