ANIMA  4.0
animaEddyCurrentCorrection.cxx
Go to the documentation of this file.
1 #include <tclap/CmdLine.h>
2 
6 #include <itkExtractImageFilter.h>
7 
8 #include <itkImageRegionIterator.h>
9 #include <itkCompositeTransform.h>
10 #include <itkStationaryVelocityFieldTransform.h>
11 #include <rpiDisplacementFieldTransform.h>
12 #include <animaVelocityUtils.h>
15 
16 int main(int argc, const char** argv)
17 {
18  const unsigned int Dimension = 3;
19 
20  typedef itk::Image <double,Dimension+1> InputImageType;
21  typedef itk::Image <double,Dimension> InputSubImageType;
22  typedef itk::ImageRegionIterator <InputImageType> InputImageIteratorType;
23  typedef itk::ImageRegionIterator <InputSubImageType> InputSubImageIteratorType;
24 
26  typedef anima::BaseTransformAgregator <Dimension> AgregatorType;
27  typedef itk::AffineTransform<AgregatorType::ScalarType,Dimension> AffineTransformType;
28  typedef AffineTransformType::Pointer AffineTransformPointer;
29 
30  // Parsing arguments
31  TCLAP::CmdLine cmd("INRIA / IRISA - VisAGeS/Empenn Team", ' ',ANIMA_VERSION);
32 
33  // Setting up parameters
34  TCLAP::ValueArg<std::string> inputArg("i","input","Input 4D image",true,"","input 4D image",cmd);
35  TCLAP::ValueArg<std::string> inBVecArg("I","input-bvec","Input gradient vectors file",true,"","input gradients",cmd);
36  TCLAP::ValueArg<std::string> outArg("o","output","Output (corrected) image",true,"","output image",cmd);
37  TCLAP::ValueArg<std::string> outBVecArg("O","output-bvec","Output gradient vectors (bvec format)",true,"","output gradients",cmd);
38 
39  TCLAP::ValueArg<unsigned int> directionArg("d","dir","Affine direction for directional transform output (default: 1 = Y axis)",false,1,"direction of directional affine",cmd);
40  TCLAP::ValueArg<unsigned int> b0Arg("b","b0","Index of the B0 reference image",false,0,"reference image index",cmd);
41 
42  TCLAP::ValueArg<unsigned int> blockSizeArg("","bs","Block size (default: 5)",false,5,"block size",cmd);
43  TCLAP::ValueArg<unsigned int> blockSpacingArg("","sp","Block spacing (default: 5)",false,5,"block spacing",cmd);
44  TCLAP::ValueArg<unsigned int> nlBlockSpacingArg("","nsp","Block spacing (default: 3)",false,3,"non linear matching block spacing",cmd);
45  TCLAP::ValueArg<double> stdevThresholdArg("s","stdev","Threshold block standard deviation (default: 5)",false,5,"block minimal standard deviation",cmd);
46  TCLAP::ValueArg<double> percentageKeptArg("k","per-kept","Percentage of blocks with the highest variance kept (default: 0.8)",false,0.8,"percentage of blocks kept",cmd);
47 
48  TCLAP::ValueArg<unsigned int> blockMetricArg("","metric","Similarity metric between blocks (0: squared correlation coefficient, 1: correlation coefficient, 2: mean squares, default: 0)",false,0,"similarity metric",cmd);
49  TCLAP::ValueArg<unsigned int> optimizerArg("","opt","Optimizer for optimal block search (0: Exhaustive, 1: Bobyqa, default: 1)",false,1,"optimizer",cmd);
50 
51  TCLAP::ValueArg<unsigned int> maxIterationsArg("","mi","Maximum block match iterations (default: 10)",false,10,"maximum iterations",cmd);
52  TCLAP::ValueArg<double> minErrorArg("","me","Minimal distance between consecutive estimated transforms (default: 0.01)",false,0.01,"minimal distance between transforms",cmd);
53 
54  TCLAP::ValueArg<unsigned int> optimizerMaxIterationsArg("","oi","Maximum iterations for local optimizer (default: 100)",false,100,"maximum local optimizer iterations",cmd);
55 
56  TCLAP::ValueArg<double> searchRadiusArg("","sr","Search radius in pixels (exhaustive search window, rho start for bobyqa, default: 2)",false,2,"optimizer search radius",cmd);
57  TCLAP::ValueArg<double> finalRadiusArg("","fr","Final radius (rho end for bobyqa, default: 0.001)",false,0.001,"optimizer final radius",cmd);
58  TCLAP::ValueArg<double> searchStepArg("","st","Search step for exhaustive search (default: 2)",false,2,"exhaustive optimizer search step",cmd);
59  TCLAP::ValueArg<double> translateUpperBoundArg("","tub","Upper bound on translation for bobyqa (in voxels, default: 10)",false,10,"Bobyqa translate upper bound",cmd);
60 
61  TCLAP::ValueArg<unsigned int> symmetryArg("","sym-reg","Registration symmetry type (0: asymmetric, 1: symmetric, 2: kissing, default: 0)",false,0,"symmetry type",cmd);
62  TCLAP::ValueArg<unsigned int> agregatorArg("a","agregator","Transformation agregator type (0: M-Estimation, 1: least squares, 2: least trimmed squares, default: 0)",false,0,"agregator type",cmd);
63  TCLAP::ValueArg<double> agregThresholdArg("","at","Agregator threshold value (for M-estimation or LTS)",false,0.5,"agregator threshold value",cmd);
64  TCLAP::ValueArg<double> extrapolationSigmaArg("","fs","Sigma for extrapolation of local pairings (default: 3)",false,3,"extrapolation sigma",cmd);
65  TCLAP::ValueArg<double> elasticSigmaArg("","es","Sigma for elastic regularization (default: 3)",false,3,"elastic regularization sigma",cmd);
66  TCLAP::ValueArg<double> outlierSigmaArg("","os","Sigma for outlier rejection among local pairings (default: 3)",false,3,"outlier rejection sigma",cmd);
67  TCLAP::ValueArg<double> seStoppingThresholdArg("","lst","LTS Stopping Threshold (default: 0.01)",false,0.01,"LTS stopping threshold",cmd);
68 
69  TCLAP::ValueArg<unsigned int> numPyramidLevelsArg("p","pyr","Number of pyramid levels (default: 3)",false,3,"number of pyramid levels",cmd);
70  TCLAP::ValueArg<unsigned int> lastPyramidLevelArg("l","last-level","Index of the last pyramid level explored (default: 0)",false,0,"last pyramid level",cmd);
71  TCLAP::ValueArg<unsigned int> numThreadsArg("T","threads","Number of execution threads (default: 0 = all cores)",false,0,"number of threads",cmd);
72 
73  try
74  {
75  cmd.parse(argc,argv);
76  }
77  catch (TCLAP::ArgException& e)
78  {
79  std::cerr << "Error: " << e.error() << "for argument " << e.argId() << std::endl;
80  return EXIT_FAILURE;
81  }
82 
83  InputImageType::Pointer inputImage = anima::readImage <InputImageType> (inputArg.getValue());
84  unsigned int numberOfImages = inputImage->GetLargestPossibleRegion().GetSize()[Dimension];
85  typedef itk::ExtractImageFilter <InputImageType, InputSubImageType> ExtractFilterType;
86 
87  ExtractFilterType::Pointer referenceExtractFilter = ExtractFilterType::New();
88  referenceExtractFilter->SetInput(inputImage);
89  InputImageType::RegionType refExtractRegion = inputImage->GetLargestPossibleRegion();
90  refExtractRegion.SetIndex(Dimension,b0Arg.getValue());
91  refExtractRegion.SetSize(Dimension,0);
92  referenceExtractFilter->SetExtractionRegion(refExtractRegion);
93  referenceExtractFilter->SetDirectionCollapseToGuess();
94  referenceExtractFilter->Update();
95 
96  typedef anima::GradientFileReader < vnl_vector_fixed <double,3>, double > GFReaderType;
97  GFReaderType gfReader;
98  gfReader.SetGradientFileName(inBVecArg.getValue());
99  gfReader.SetGradientIndependentNormalization(false);
100  gfReader.Update();
101 
102  GFReaderType::GradientVectorType directions = gfReader.GetGradients();
103 
104  for (unsigned int i = 0;i < numberOfImages;++i)
105  {
106  if (i == b0Arg.getValue())
107  continue;
108 
109  std::cout << "\033[K\rProcessing image " << i+1 << " out of " << numberOfImages << std::flush;
110 
111  ExtractFilterType::Pointer extractFilter = ExtractFilterType::New();
112  extractFilter->SetInput(inputImage);
113  InputImageType::RegionType extractRegion = inputImage->GetLargestPossibleRegion();
114  extractRegion.SetIndex(Dimension,i);
115  extractRegion.SetSize(Dimension,0);
116  extractFilter->SetExtractionRegion(extractRegion);
117  extractFilter->SetDirectionCollapseToGuess();
118  extractFilter->Update();
119 
120  // First perform rigid registration to correct for movement
121  PyramidBMType::Pointer matcher = PyramidBMType::New();
122 
123  matcher->SetBlockSize(blockSizeArg.getValue());
124  matcher->SetBlockSpacing(blockSpacingArg.getValue());
125  matcher->SetStDevThreshold(stdevThresholdArg.getValue());
126  matcher->SetMetric((PyramidBMType::Metric) blockMetricArg.getValue());
127  matcher->SetOptimizer((PyramidBMType::Optimizer) optimizerArg.getValue());
128  matcher->SetMaximumIterations(maxIterationsArg.getValue());
129  matcher->SetMinimalTransformError(minErrorArg.getValue());
130  matcher->SetFinalRadius(finalRadiusArg.getValue());
131  matcher->SetOptimizerMaximumIterations(optimizerMaxIterationsArg.getValue());
132  matcher->SetSearchRadius(searchRadiusArg.getValue());
133  matcher->SetStepSize(searchStepArg.getValue());
134  matcher->SetTranslateUpperBound(translateUpperBoundArg.getValue());
135  matcher->SetSymmetryType((PyramidBMType::SymmetryType) symmetryArg.getValue());
136  matcher->SetAgregator((PyramidBMType::Agregator) agregatorArg.getValue());
137  matcher->SetOutputTransformType(PyramidBMType::outRigid);
138  matcher->SetAffineDirection(directionArg.getValue());
139  matcher->SetAgregThreshold(agregThresholdArg.getValue());
140  matcher->SetSeStoppingThreshold(seStoppingThresholdArg.getValue());
141  matcher->SetNumberOfPyramidLevels(numPyramidLevelsArg.getValue());
142  matcher->SetLastPyramidLevel(lastPyramidLevelArg.getValue());
143  matcher->SetVerbose(false);
144 
145  if (numThreadsArg.getValue() != 0)
146  matcher->SetNumberOfWorkUnits( numThreadsArg.getValue() );
147 
148  matcher->SetPercentageKept( percentageKeptArg.getValue() );
149  matcher->SetTransformInitializationType(PyramidBMType::GravityCenters);
150 
151  matcher->SetFloatingImage(referenceExtractFilter->GetOutput());
152  matcher->SetReferenceImage(extractFilter->GetOutput());
153 
154  AffineTransformPointer rigidTrsf = AffineTransformType::New();
155  rigidTrsf->SetIdentity();
156  matcher->SetOutputTransform(rigidTrsf.GetPointer());
157 
158  try
159  {
160  matcher->Update();
161  }
162  catch (itk::ExceptionObject &e)
163  {
164  std::cerr << e << std::endl;
165  return EXIT_FAILURE;
166  }
167 
168  rigidTrsf = dynamic_cast <AffineTransformType *> (matcher->GetOutputTransform().GetPointer());
169 
170  InputSubImageType::Pointer rigidReference = matcher->GetOutputImage();
171 
172  // Then perform directional affine registration
173  matcher->SetReferenceImage(rigidReference);
174  matcher->SetFloatingImage(extractFilter->GetOutput());
175  matcher->SetTransform(PyramidBMType::Directional_Affine);
176  matcher->SetOutputTransformType(PyramidBMType::outAffine);
177 
178  AffineTransformPointer tmpTrsfDirectional = AffineTransformType::New();
179  tmpTrsfDirectional->SetIdentity();
180  matcher->SetOutputTransform(tmpTrsfDirectional.GetPointer());
181 
182  try
183  {
184  matcher->Update();
185  }
186  catch (itk::ExceptionObject &e)
187  {
188  std::cerr << e << std::endl;
189  return EXIT_FAILURE;
190  }
191 
192  // Finally, perform non linear registration to get rid of non linear distortions
193  typedef anima::PyramidalDenseSVFMatchingBridge <Dimension> NonLinearPyramidBMType;
194  NonLinearPyramidBMType::Pointer nonLinearMatcher = NonLinearPyramidBMType::New();
195 
196  nonLinearMatcher->SetReferenceImage(rigidReference);
197  nonLinearMatcher->SetFloatingImage(matcher->GetOutputImage());
198 
199  // Setting matcher arguments
200  nonLinearMatcher->SetBlockSize(blockSizeArg.getValue());
201  nonLinearMatcher->SetBlockSpacing(nlBlockSpacingArg.getValue());
202  nonLinearMatcher->SetStDevThreshold(stdevThresholdArg.getValue());
203  nonLinearMatcher->SetTransform(NonLinearPyramidBMType::Directional_Affine);
204  nonLinearMatcher->SetAffineDirection(directionArg.getValue());
205  nonLinearMatcher->SetMetric((NonLinearPyramidBMType::Metric) blockMetricArg.getValue());
206  nonLinearMatcher->SetOptimizer((NonLinearPyramidBMType::Optimizer) optimizerArg.getValue());
207  nonLinearMatcher->SetMaximumIterations(maxIterationsArg.getValue());
208  nonLinearMatcher->SetMinimalTransformError(minErrorArg.getValue());
209  nonLinearMatcher->SetFinalRadius(finalRadiusArg.getValue());
210  nonLinearMatcher->SetOptimizerMaximumIterations(optimizerMaxIterationsArg.getValue());
211  nonLinearMatcher->SetSearchRadius(searchRadiusArg.getValue());
212  nonLinearMatcher->SetStepSize(searchStepArg.getValue());
213  nonLinearMatcher->SetTranslateUpperBound(translateUpperBoundArg.getValue());
214  nonLinearMatcher->SetSymmetryType((NonLinearPyramidBMType::SymmetryType) symmetryArg.getValue());
215  nonLinearMatcher->SetAgregator(NonLinearPyramidBMType::Baloo);
216  nonLinearMatcher->SetBCHCompositionOrder(1);
217  nonLinearMatcher->SetExponentiationOrder(0);
218  nonLinearMatcher->SetExtrapolationSigma(extrapolationSigmaArg.getValue());
219  nonLinearMatcher->SetElasticSigma(elasticSigmaArg.getValue());
220  nonLinearMatcher->SetOutlierSigma(outlierSigmaArg.getValue());
221  nonLinearMatcher->SetNumberOfPyramidLevels(numPyramidLevelsArg.getValue());
222  nonLinearMatcher->SetLastPyramidLevel(lastPyramidLevelArg.getValue());
223  nonLinearMatcher->SetVerbose(false);
224 
225  if (numThreadsArg.getValue() != 0)
226  nonLinearMatcher->SetNumberOfWorkUnits(numThreadsArg.getValue());
227 
228  nonLinearMatcher->SetPercentageKept(percentageKeptArg.getValue());
229 
230  try
231  {
232  nonLinearMatcher->Update();
233  }
234  catch (itk::ExceptionObject &e)
235  {
236  std::cerr << e << std::endl;
237  return EXIT_FAILURE;
238  }
239 
240  // Finally, apply transform serie to image
241  typedef itk::CompositeTransform <AgregatorType::ScalarType,Dimension> GeneralTransformType;
242  GeneralTransformType::Pointer transformSerie = GeneralTransformType::New();
243  transformSerie->AddTransform(tmpTrsfDirectional);
244 
245  typedef itk::StationaryVelocityFieldTransform <AgregatorType::ScalarType,Dimension> SVFTransformType;
246  typedef SVFTransformType::Pointer SVFTransformPointer;
247 
248  typedef rpi::DisplacementFieldTransform <AgregatorType::ScalarType,Dimension> DenseTransformType;
249  typedef DenseTransformType::Pointer DenseTransformPointer;
250 
251  SVFTransformPointer svfPointer = nonLinearMatcher->GetOutputTransform();
252 
253  DenseTransformPointer dispTrsf = DenseTransformType::New();
254  anima::GetSVFExponential(svfPointer.GetPointer(),dispTrsf.GetPointer(),0,numThreadsArg.getValue(),false);
255 
256  transformSerie->AddTransform(dispTrsf.GetPointer());
257 
258  // Apply rigid matrix to gradient vectors
259  AffineTransformType::MatrixType rigidMatrix = rigidTrsf->GetMatrix();
260  vnl_vector_fixed <double,3> tmpDir(0.0);
261  for (unsigned int j = 0;j < 3;++j)
262  {
263  for (unsigned int k = 0;k < 3;++k)
264  tmpDir[j] += rigidMatrix(j,k) * directions[i][k];
265  }
266 
267  directions[i] = tmpDir;
268 
269  AffineTransformPointer rigidTrsfInverse = AffineTransformType::New();
270  rigidTrsf->GetInverse(rigidTrsfInverse);
271  transformSerie->AddTransform(rigidTrsfInverse.GetPointer());
272 
274  ResampleFilterType::Pointer scalarResampler = ResampleFilterType::New();
275 
276  InputSubImageType::SizeType size = referenceExtractFilter->GetOutput()->GetLargestPossibleRegion().GetSize();
277  InputSubImageType::PointType origin = referenceExtractFilter->GetOutput()->GetOrigin();
278  InputSubImageType::SpacingType spacing = referenceExtractFilter->GetOutput()->GetSpacing();
279  InputSubImageType::DirectionType direction = referenceExtractFilter->GetOutput()->GetDirection();
280 
281  scalarResampler->SetTransform(transformSerie);
282  scalarResampler->SetSize(size);
283  scalarResampler->SetOutputOrigin(origin);
284  scalarResampler->SetOutputSpacing(spacing);
285  scalarResampler->SetOutputDirection(direction);
286 
287  scalarResampler->SetInput(extractFilter->GetOutput());
288  if (numThreadsArg.getValue() != 0)
289  scalarResampler->SetNumberOfWorkUnits(numThreadsArg.getValue());
290  scalarResampler->Update();
291 
292  InputSubImageType::RegionType regionSubImage = scalarResampler->GetOutput()->GetLargestPossibleRegion();
293  InputImageType::RegionType regionImage = inputImage->GetLargestPossibleRegion();
294  regionImage.SetIndex(Dimension,i);
295  regionImage.SetSize(Dimension,1);
296 
297  InputImageIteratorType outIterator(inputImage,regionImage);
298  InputSubImageIteratorType inIterator(scalarResampler->GetOutput(),regionSubImage);
299 
300  while (!inIterator.IsAtEnd())
301  {
302  outIterator.Set(inIterator.Get());
303 
304  ++inIterator;
305  ++outIterator;
306  }
307  }
308 
309  std::cout << std::endl;
310 
311  anima::writeImage <InputImageType> (outArg.getValue(),inputImage);
312 
313  // Writing output gradients
314  std::ofstream outputFile(outBVecArg.getValue());
315  for (unsigned int i = 0;i < 3;++i)
316  {
317  for (unsigned int j = 0;j < directions.size();++j)
318  outputFile << directions[j][i] << " ";
319 
320  outputFile << std::endl;
321  }
322 
323  outputFile.close();
324 
325  return EXIT_SUCCESS;
326 }
int main(int argc, const char **argv)
void SetGradientFileName(std::string fName)
void GetSVFExponential(itk::StationaryVelocityFieldTransform< ScalarType, NDimensions > *baseTrsf, rpi::DisplacementFieldTransform< ScalarType, NDimensions > *resultTransform, unsigned int exponentiationOrder, unsigned int numThreads, bool invert)