ANIMA  4.0
animaVelocityUtils.hxx
Go to the documentation of this file.
1 #pragma once
2 #include "animaVelocityUtils.h"
3 
4 #include <itkAddImageFilter.h>
5 #include <itkMultiplyImageFilter.h>
6 
7 #include <itkComposeDisplacementFieldsImageFilter.h>
8 #include <itkVectorLinearInterpolateNearestNeighborExtrapolateImageFunction.h>
11 
12 namespace anima
13 {
14 
15 template <class ScalarType, unsigned int NDimensions>
16 void composeSVF(itk::StationaryVelocityFieldTransform <ScalarType,NDimensions> *baseTrsf,
17  itk::StationaryVelocityFieldTransform <ScalarType,NDimensions> *addonTrsf,
18  unsigned int numThreads, unsigned int bchOrder)
19 {
20  if ((baseTrsf->GetParametersAsVectorField() == NULL)&&(addonTrsf->GetParametersAsVectorField() == NULL))
21  return;
22 
23  if ((bchOrder > 4)||(bchOrder < 1))
24  throw itk::ExceptionObject(__FILE__,__LINE__,"Invalid BCH order, not implemented yet",ITK_LOCATION);
25 
26  typedef typename itk::StationaryVelocityFieldTransform <ScalarType,NDimensions>::VectorFieldType VelocityFieldType;
27 
28  if (baseTrsf->GetParametersAsVectorField() == NULL)
29  {
30  baseTrsf->SetParametersAsVectorField(addonTrsf->GetParametersAsVectorField());
31  return;
32  }
33 
34  typedef itk::AddImageFilter <VelocityFieldType, VelocityFieldType> AddFilterType;
35  typedef itk::MultiplyImageFilter <VelocityFieldType, itk::Image <double, NDimensions>,
36  VelocityFieldType> MultiplyConstFilterType;
37 
38  typename AddFilterType::Pointer bchAdder = AddFilterType::New();
39  bchAdder->SetInput(0,baseTrsf->GetParametersAsVectorField());
40  bchAdder->SetInput(1,addonTrsf->GetParametersAsVectorField());
41 
42  if (numThreads > 0)
43  bchAdder->SetNumberOfWorkUnits(numThreads);
44 
45  bchAdder->Update();
46 
47  typename VelocityFieldType::Pointer resField = bchAdder->GetOutput();
48  resField->DisconnectPipeline();
49 
51  typename LieBracketFilterType::JacobianImagePointer baseTrsfJac, addonTrsfJac;
52  typename LieBracketFilterType::OutputImagePointer previousLieBracket;
53 
54  if (bchOrder >= 2)
55  {
56  // Compute Lie bracket and add half of it to the output
57  typename LieBracketFilterType::Pointer lieBracketFilter = LieBracketFilterType::New();
58 
59  lieBracketFilter->SetInput(0,baseTrsf->GetParametersAsVectorField());
60  lieBracketFilter->SetInput(1,addonTrsf->GetParametersAsVectorField());
61 
62  if (numThreads > 0)
63  lieBracketFilter->SetNumberOfWorkUnits(numThreads);
64 
65  lieBracketFilter->Update();
66  baseTrsfJac = lieBracketFilter->GetFirstFieldJacobian();
67  baseTrsfJac->DisconnectPipeline();
68  addonTrsfJac = lieBracketFilter->GetSecondFieldJacobian();
69  addonTrsfJac->DisconnectPipeline();
70  previousLieBracket = lieBracketFilter->GetOutput();
71  previousLieBracket->DisconnectPipeline();
72 
73  typename MultiplyConstFilterType::Pointer bracketMultiplier = MultiplyConstFilterType::New();
74  bracketMultiplier->SetInput(0,previousLieBracket);
75  bracketMultiplier->SetConstant(0.5);
76 
77  if (numThreads > 0)
78  bracketMultiplier->SetNumberOfWorkUnits(numThreads);
79 
80  bracketMultiplier->Update();
81 
82  typename AddFilterType::Pointer bchSecondAdder = AddFilterType::New();
83  bchSecondAdder->SetInput(0,resField);
84  bchSecondAdder->SetInput(1,bracketMultiplier->GetOutput());
85 
86  if (numThreads > 0)
87  bchSecondAdder->SetNumberOfWorkUnits(numThreads);
88 
89  bchSecondAdder->Update();
90 
91  resField = bchSecondAdder->GetOutput();
92  resField->DisconnectPipeline();
93  }
94 
95  if (bchOrder >= 3)
96  {
97  // Compute Lie bracket one way and add 1/12 of it to the output
98  typename LieBracketFilterType::Pointer lieBracketFilter = LieBracketFilterType::New();
99 
100  lieBracketFilter->SetInput(0,baseTrsf->GetParametersAsVectorField());
101  lieBracketFilter->SetInput(1,previousLieBracket);
102  lieBracketFilter->SetFirstFieldJacobian(baseTrsfJac);
103 
104  if (numThreads > 0)
105  lieBracketFilter->SetNumberOfWorkUnits(numThreads);
106 
107  lieBracketFilter->Update();
108 
109  typename MultiplyConstFilterType::Pointer bracketMultiplier = MultiplyConstFilterType::New();
110  bracketMultiplier->SetInput(0,lieBracketFilter->GetOutput());
111  bracketMultiplier->SetConstant(1.0 / 12);
112 
113  if (numThreads > 0)
114  bracketMultiplier->SetNumberOfWorkUnits(numThreads);
115 
116  bracketMultiplier->Update();
117 
118  typename AddFilterType::Pointer bchSecondAdder = AddFilterType::New();
119  bchSecondAdder->SetInput(0,resField);
120  bchSecondAdder->SetInput(1,bracketMultiplier->GetOutput());
121 
122  if (numThreads > 0)
123  bchSecondAdder->SetNumberOfWorkUnits(numThreads);
124 
125  bchSecondAdder->Update();
126 
127  resField = bchSecondAdder->GetOutput();
128  resField->DisconnectPipeline();
129 
130  // Compute Lie bracket the other way round and add 1/12 of it to the output
131  typename LieBracketFilterType::Pointer reverseLieBracketFilter = LieBracketFilterType::New();
132 
133  reverseLieBracketFilter->SetInput(0,previousLieBracket);
134  reverseLieBracketFilter->SetInput(1,addonTrsf->GetParametersAsVectorField());
135  reverseLieBracketFilter->SetFirstFieldJacobian(lieBracketFilter->GetSecondFieldJacobian());
136  reverseLieBracketFilter->SetSecondFieldJacobian(addonTrsfJac);
137 
138  if (numThreads > 0)
139  reverseLieBracketFilter->SetNumberOfWorkUnits(numThreads);
140 
141  reverseLieBracketFilter->Update();
142 
143  bracketMultiplier = MultiplyConstFilterType::New();
144  bracketMultiplier->SetInput(0,reverseLieBracketFilter->GetOutput());
145  bracketMultiplier->SetConstant(1.0 / 12);
146 
147  if (numThreads > 0)
148  bracketMultiplier->SetNumberOfWorkUnits(numThreads);
149 
150  bracketMultiplier->Update();
151 
152  bchSecondAdder = AddFilterType::New();
153  bchSecondAdder->SetInput(0,resField);
154  bchSecondAdder->SetInput(1,bracketMultiplier->GetOutput());
155 
156  if (numThreads > 0)
157  bchSecondAdder->SetNumberOfWorkUnits(numThreads);
158 
159  bchSecondAdder->Update();
160 
161  resField = bchSecondAdder->GetOutput();
162  resField->DisconnectPipeline();
163 
164  previousLieBracket = lieBracketFilter->GetOutput();
165  previousLieBracket->DisconnectPipeline();
166  }
167 
168  if (bchOrder == 4)
169  {
170  // Compute Lie bracket one way and add 1/24 of it to the output
171  typename LieBracketFilterType::Pointer lieBracketFilter = LieBracketFilterType::New();
172 
173  lieBracketFilter->SetInput(0,previousLieBracket);
174  lieBracketFilter->SetInput(1,addonTrsf->GetParametersAsVectorField());
175  lieBracketFilter->SetSecondFieldJacobian(addonTrsfJac);
176 
177  if (numThreads > 0)
178  lieBracketFilter->SetNumberOfWorkUnits(numThreads);
179 
180  lieBracketFilter->Update();
181 
182  typename MultiplyConstFilterType::Pointer bracketMultiplier = MultiplyConstFilterType::New();
183  bracketMultiplier->SetInput(0,lieBracketFilter->GetOutput());
184  bracketMultiplier->SetConstant(1.0 / 24);
185 
186  if (numThreads > 0)
187  bracketMultiplier->SetNumberOfWorkUnits(numThreads);
188 
189  bracketMultiplier->Update();
190 
191  typename AddFilterType::Pointer bchSecondAdder = AddFilterType::New();
192  bchSecondAdder->SetInput(0,resField);
193  bchSecondAdder->SetInput(1,bracketMultiplier->GetOutput());
194 
195  if (numThreads > 0)
196  bchSecondAdder->SetNumberOfWorkUnits(numThreads);
197 
198  bchSecondAdder->Update();
199 
200  resField = bchSecondAdder->GetOutput();
201  resField->DisconnectPipeline();
202  }
203 
204  baseTrsf->SetParametersAsVectorField(resField.GetPointer());
205 }
206 
207 template <class ScalarType, unsigned int NDimensions>
208 void GetSVFExponential(itk::StationaryVelocityFieldTransform <ScalarType,NDimensions> *baseTrsf,
209  rpi::DisplacementFieldTransform <ScalarType,NDimensions> *resultTransform,
210  unsigned int exponentiationOrder, unsigned int numThreads, bool invert)
211 {
212  if (baseTrsf->GetParametersAsVectorField() == NULL)
213  return;
214 
215  typedef itk::StationaryVelocityFieldTransform <ScalarType,NDimensions> SVFType;
216  typedef typename SVFType::VectorFieldType FieldType;
217  typedef typename FieldType::Pointer FieldPointer;
218 
220 
221  FieldPointer tmpPtr = const_cast <FieldType *> (baseTrsf->GetParametersAsVectorField());
222  if (invert)
223  {
224  typedef itk::MultiplyImageFilter <FieldType,itk::Image <double, NDimensions>, FieldType> MultiplyFilterType;
225  typename MultiplyFilterType::Pointer multiplier = MultiplyFilterType::New();
226  multiplier->SetInput(tmpPtr);
227  multiplier->SetConstant(-1.0);
228 
229  multiplier->SetNumberOfWorkUnits(numThreads);
230  multiplier->Update();
231 
232  tmpPtr = multiplier->GetOutput();
233  tmpPtr->DisconnectPipeline();
234  }
235 
236  typename ExponentialFilterType::Pointer expFilter = ExponentialFilterType::New();
237  expFilter->SetInput(tmpPtr);
238  expFilter->SetExponentiationOrder(exponentiationOrder);
239  expFilter->SetNumberOfWorkUnits(numThreads);
240  expFilter->SetMaximalDisplacementAmplitude(0.25);
241 
242  expFilter->Update();
243 
244  FieldPointer resField = expFilter->GetOutput();
245  resField->DisconnectPipeline();
246 
247  resultTransform->SetParametersAsVectorField(resField.GetPointer());
248 }
249 
250 template <class ScalarType, unsigned int NDimensions>
251 void composeDistortionCorrections(typename rpi::DisplacementFieldTransform <ScalarType,NDimensions>::Pointer &baseTrsf,
252  typename rpi::DisplacementFieldTransform <ScalarType,NDimensions>::Pointer &positiveAddOn,
253  typename rpi::DisplacementFieldTransform <ScalarType,NDimensions>::Pointer &negativeAddOn,
254  unsigned int numThreads)
255 {
256  typedef rpi::DisplacementFieldTransform <ScalarType,NDimensions> DisplacementFieldTransformType;
257  typedef typename DisplacementFieldTransformType::VectorFieldType VectorFieldType;
258  typedef itk::ComposeDisplacementFieldsImageFilter <VectorFieldType,VectorFieldType> ComposeFilterType;
259  typedef itk::MultiplyImageFilter <VectorFieldType,itk::Image <double, NDimensions>, VectorFieldType> MultiplyFilterType;
260  typedef typename itk::ImageRegionIterator <VectorFieldType> VectorFieldIterator;
261  typedef typename VectorFieldType::PixelType VectorType;
262 
263  typedef itk::VectorLinearInterpolateNearestNeighborExtrapolateImageFunction <VectorFieldType,
264  typename VectorType::ValueType> VectorInterpolateFunctionType;
265 
266  typename ComposeFilterType::Pointer composePositiveFilter = ComposeFilterType::New();
267  composePositiveFilter->SetWarpingField(positiveAddOn->GetParametersAsVectorField());
268  composePositiveFilter->SetDisplacementField(baseTrsf->GetParametersAsVectorField());
269  composePositiveFilter->SetNumberOfWorkUnits(numThreads);
270 
271  typename VectorInterpolateFunctionType::Pointer interpolator = VectorInterpolateFunctionType::New();
272 
273  composePositiveFilter->SetInterpolator(interpolator);
274  composePositiveFilter->Update();
275  positiveAddOn->SetParametersAsVectorField(composePositiveFilter->GetOutput());
276 
277  typename MultiplyFilterType::Pointer multiplyFilter = MultiplyFilterType::New();
278  multiplyFilter->SetInput(baseTrsf->GetParametersAsVectorField());
279  multiplyFilter->SetConstant(-1.0);
280 
281  if (numThreads > 0)
282  multiplyFilter->SetNumberOfWorkUnits(numThreads);
283 
284  multiplyFilter->InPlaceOn();
285 
286  multiplyFilter->Update();
287 
288  typename ComposeFilterType::Pointer composeNegativeFilter = ComposeFilterType::New();
289  composeNegativeFilter->SetWarpingField(negativeAddOn->GetParametersAsVectorField());
290  composeNegativeFilter->SetDisplacementField(multiplyFilter->GetOutput());
291 
292  if (numThreads > 0)
293  composeNegativeFilter->SetNumberOfWorkUnits(numThreads);
294 
295  interpolator = VectorInterpolateFunctionType::New();
296 
297  composeNegativeFilter->SetInterpolator(interpolator);
298  composeNegativeFilter->Update();
299  negativeAddOn->SetParametersAsVectorField(composeNegativeFilter->GetOutput());
300 
301  VectorFieldIterator positiveItr(const_cast <VectorFieldType *> (positiveAddOn->GetParametersAsVectorField()),
302  positiveAddOn->GetParametersAsVectorField()->GetLargestPossibleRegion());
303 
304  VectorFieldIterator negativeItr(const_cast <VectorFieldType *> (negativeAddOn->GetParametersAsVectorField()),
305  negativeAddOn->GetParametersAsVectorField()->GetLargestPossibleRegion());
306 
307  // And compose them to get a transformation conform to disto correction requirements
308  VectorType tmpVec;
309  while (!positiveItr.IsAtEnd())
310  {
311  tmpVec = 0.5 * (positiveItr.Get() - negativeItr.Get());
312  positiveItr.Set(tmpVec);
313  negativeItr.Set(- tmpVec);
314 
315  ++positiveItr;
316  ++negativeItr;
317  }
318 
319  baseTrsf = positiveAddOn;
320 }
321 
322 } // end of namespace anima
Computes the Lie bracket between two fields u and v as expressed by Bossa et al.
Computes the exponentiation of a stationary velocity field using scaling and squaring and approximate...
void GetSVFExponential(itk::StationaryVelocityFieldTransform< ScalarType, NDimensions > *baseTrsf, rpi::DisplacementFieldTransform< ScalarType, NDimensions > *resultTransform, unsigned int exponentiationOrder, unsigned int numThreads, bool invert)
void composeSVF(itk::StationaryVelocityFieldTransform< ScalarType, NDimensions > *baseTrsf, itk::StationaryVelocityFieldTransform< ScalarType, NDimensions > *addonTrsf, unsigned int numThreads, unsigned int bchOrder)
void composeDistortionCorrections(typename rpi::DisplacementFieldTransform< ScalarType, NDimensions >::Pointer &baseTrsf, typename rpi::DisplacementFieldTransform< ScalarType, NDimensions >::Pointer &positiveAddOn, typename rpi::DisplacementFieldTransform< ScalarType, NDimensions >::Pointer &negativeAddOn, unsigned int numThreads)