ANIMA  4.0
animaN4BiasCorrection.cxx
Go to the documentation of this file.
1 #include <tclap/CmdLine.h>
2 #include <itkN4BiasFieldCorrectionImageFilter.h>
4 #include <itkConstantPadImageFilter.h>
5 #include <itkOtsuThresholdImageFilter.h>
6 #include <itkShrinkImageFilter.h>
7 #include <itkCommand.h>
8 #include <itkTimeProbe.h>
9 
10 //Update progression of the process
11 void eventCallback(itk::Object* caller, const itk::EventObject& event, void* clientData)
12 {
13  itk::ProcessObject * processObject = (itk::ProcessObject*) caller;
14  std::cout << "\033[K\rProgression: " << (int)(processObject->GetProgress() * 100) << "%" << std::flush;
15 }
16 
// Reads one value of type T from the beginning of pi_rsValue using stream
// extraction (operator>>).
// Returns a value-initialized T (0 for arithmetic types) when nothing could be
// read, e.g. for an empty string or a string that does not start with a number.
template<typename T>
T convertStringToNumber(std::string const &pi_rsValue)
{
    // Value-initialize: when the stream fails before extraction (empty input),
    // operator>> never writes numRes, and returning it uninitialized was UB.
    T numRes{};

    std::istringstream(pi_rsValue) >> numRes;

    return numRes;
}
26 
27 template<typename T>
28 std::vector<T> ConvertVector(std::string &pi_rsVector)
29 {
30  std::vector<T> oValuesVectorRes;
31 
32  size_t iSeparatorPosition = pi_rsVector.find('x', 0);
33 
34  if (iSeparatorPosition == std::string::npos)
35  {
36  oValuesVectorRes.push_back(convertStringToNumber<T>(pi_rsVector));
37  }
38  else
39  {
40  std::string sChunk = pi_rsVector.substr(0, iSeparatorPosition);
41 
42  oValuesVectorRes.push_back(convertStringToNumber<T>(sChunk));
43  while (iSeparatorPosition != std::string::npos)
44  {
45  std::string::size_type crossposfrom = iSeparatorPosition;
46  iSeparatorPosition = pi_rsVector.find('x', crossposfrom + 1);
47  if (iSeparatorPosition == std::string::npos)
48  {
49  sChunk = pi_rsVector.substr(crossposfrom + 1, pi_rsVector.length());
50  }
51  else
52  {
53  sChunk = pi_rsVector.substr(crossposfrom + 1, iSeparatorPosition);
54  }
55  oValuesVectorRes.push_back(convertStringToNumber<unsigned int>(sChunk));
56  }
57  }
58 
59  return oValuesVectorRes;
60 }
61 
// Validates the --iterations option, which must look like "50x40x30".
// Only decimal digits and 'x' separators are allowed, and every 'x' must sit
// between two numbers, so empty input and leading, trailing or doubled 'x' are
// rejected (they would otherwise yield empty, meaningless entries).
// On invalid input this prints an error to std::cerr and terminates the
// process with EXIT_FAILURE. Otherwise it returns normally.
void checkIterationArg(const std::string &sIterations)
{
    bool bValid = true;
    // Treat the start of the string as a separator so a leading 'x' (or an
    // empty string) is detected as a missing number.
    bool bPreviousWasSeparator = true;

    for (const char cCurrent : sIterations)
    {
        if (cCurrent >= '0' && cCurrent <= '9')
        {
            bPreviousWasSeparator = false;
        }
        else if (cCurrent == 'x' && !bPreviousWasSeparator)
        {
            bPreviousWasSeparator = true;
        }
        else
        {
            bValid = false;
            break;
        }
    }

    // Ending on a separator (or never seeing a digit) means an empty last number.
    if (bPreviousWasSeparator)
        bValid = false;

    if (!bValid)
    {
        std::cerr << "error: for argument iterations does not respect format strictly positive number separated by 'x' like 100x50x50..." << std::endl;
        exit(EXIT_FAILURE);
    }
}
74 
75 std::vector<unsigned int> extractMaxNumberOfIterationsVector(std::string pi_rsIterations)
76 {
77  checkIterationArg(pi_rsIterations);
78  return ConvertVector<unsigned int>(pi_rsIterations);
79 }
80 
81 int main(int argc, char *argv[])
82 {
83  TCLAP::CmdLine cmd("INRIA / IRISA - VisAGeS/Empenn Team", ' ', ANIMA_VERSION);
84  TCLAP::ValueArg<std::string> oArgInputImg("i", "input", "Input image.", true, "", "input Name", cmd);
85  TCLAP::ValueArg<std::string> oArgOutputName("o", "output", "Name for output file", true, "", "output Name", cmd);
86  TCLAP::ValueArg<std::string> oArgIterationsString("I", "iterations", "Table of number of iterations, default=50x40x30", false, "50x40x30", "iterations table", cmd);
87  TCLAP::ValueArg<int> oArgShrinkFactors("S", "shrinkFactor", "Shrink factor, default=4", false, 4, "shrink Factor", cmd);
88  TCLAP::ValueArg<double> oArgWienerFilterNoise("W", "wiener", "Wiener Filter Noise, default=0.01", false, 0.01, "wiener noise", cmd);
89  TCLAP::ValueArg<double> oArgbfFWHM("B", "bfFWHM", "Bias field Full Width at Half Maximum, default=0.15", false, 0.15, "Bias field Full Width at Half Maximum", cmd);
90  TCLAP::ValueArg<double> oArgConvergenceThreshold("T", "threshold", "Convergence Threshold, default=0.0001", false, 0.0001, "threshold", cmd);
91  TCLAP::ValueArg<int> oArgSplineOrder("O", "splineOrder", "BSpline Order, default=3", false, 3, "spline Order", cmd);
92  TCLAP::ValueArg<double> oArgSplineDistance("D", "splineDistance", "B-Spline distance, default=0.0", false, 0.0, "spline Distance", cmd);
93  TCLAP::ValueArg<std::string> oArgInitialMeshResolutionString("G", "splineGrid", "B-Spline grid resolution. It is ignored if splineDistance>0 or if dimention of it <>3, default=1x1x1", false, "1x1x1", "spline Grid", cmd);
94  TCLAP::ValueArg<unsigned int> oArgNbP("p", "numberofthreads", "Number of threads to run on (default: all cores)", false, itk::MultiThreaderBase::GetGlobalDefaultNumberOfThreads(), "number of threads", cmd);
95 
96  try
97  {
98  cmd.parse( argc, argv );
99  }
100  catch (TCLAP::ArgException & e)
101  {
102  std::cerr << "error: " << e.error() << " for arg " << e.argId() << std::endl;
103  return EXIT_FAILURE;
104  }
105 
106  typedef itk::Image<double, 3 > ImageType;
107  typedef itk::Image<unsigned char, 3> MaskImageType;
108  typedef itk::N4BiasFieldCorrectionImageFilter<ImageType, MaskImageType, ImageType> BiasFilter;
109  typedef itk::ConstantPadImageFilter<ImageType, ImageType> PadderType;
110  typedef itk::ConstantPadImageFilter<MaskImageType, MaskImageType> MaskPadderType;
111  typedef itk::ShrinkImageFilter<ImageType, ImageType> ShrinkerType;
112  typedef itk::ShrinkImageFilter<MaskImageType, MaskImageType> MaskShrinkerType;
113  typedef itk::BSplineControlPointImageFilter< BiasFilter::BiasFieldControlPointLatticeType, BiasFilter::ScalarImageType> BSplinerType;
114  typedef itk::ExpImageFilter<ImageType, ImageType> ExpFilterType;
115  typedef itk::DivideImageFilter<ImageType, ImageType, ImageType> DividerType;
116  typedef itk::ExtractImageFilter<ImageType, ImageType> CropperType;
117 
118  std::vector<unsigned int> oMaxNumbersIterationsVector = extractMaxNumberOfIterationsVector(oArgIterationsString.getValue());
119  std::vector<double> oInitialMeshResolutionVect = ConvertVector<double>(oArgInitialMeshResolutionString.getValue());
120  double fSplineDistance = oArgSplineDistance.getValue();
121 
122 
123  /********************************************************************************/
124  /***************************** PREPARING STARTING *******************************/
125  /********************************************************************************/
126 
127  /*** 0 ******************* Create filter and accessories ******************/
128  BiasFilter::Pointer filter = BiasFilter::New();
129  BiasFilter::ArrayType oNumberOfControlPointsArray;
130 
131  /*** 1 ******************* Read input image *******************************/
132  ImageType::Pointer image = anima::readImage<ImageType>( oArgInputImg.getValue());
133 
134  /*** 2 ******************* Creating Otsu mask *****************************/
135  std::cout << "Creating Otsu mask." << std::endl;
136  itk::TimeProbe timer;
137  timer.Start();
138  MaskImageType::Pointer maskImage = ITK_NULLPTR;
139  typedef itk::OtsuThresholdImageFilter<ImageType, MaskImageType> ThresholderType;
140  ThresholderType::Pointer otsu = ThresholderType::New();
141  otsu->SetInput(image);
142  otsu->SetNumberOfHistogramBins(200);
143  otsu->SetInsideValue(0);
144  otsu->SetOutsideValue(1);
145 
146  otsu->SetNumberOfWorkUnits(oArgNbP.getValue());
147  otsu->Update();
148  maskImage = otsu->GetOutput();
149 
150  /*** 3A *************** Set Maximum number of Iterations for the filter ***/
151  BiasFilter::VariableSizeArrayType itkTabMaximumIterations;
152  itkTabMaximumIterations.SetSize(oMaxNumbersIterationsVector.size());
153  for (int i=0; i<oMaxNumbersIterationsVector.size(); ++i)
154  {
155  itkTabMaximumIterations[i] = oMaxNumbersIterationsVector[i];
156  }
157  filter->SetMaximumNumberOfIterations(itkTabMaximumIterations);
158 
159  /*** 3B *************** Set Fitting Levels for the filter *****************/
160  BiasFilter::ArrayType oFittingLevelsTab;
161  oFittingLevelsTab.Fill(oMaxNumbersIterationsVector.size());
162  filter->SetNumberOfFittingLevels(oFittingLevelsTab);
163 
164  /*** 4 ******************* Save image's index, size, origine **************/
165  ImageType::IndexType oImageIndex = image->GetLargestPossibleRegion().GetIndex();
166  ImageType::SizeType oImageSize = image->GetLargestPossibleRegion().GetSize();
167  ImageType::PointType newOrigin = image->GetOrigin();
168 
169  if (fSplineDistance>0)
170  {
171  /*** 5 ******************* Compute number of control points **************/
172  itk::SizeValueType lowerBound[3];
173  itk::SizeValueType upperBound[3];
174 
175  for (unsigned int i = 0; i < 3; i++)
176  {
177  double domain = static_cast<double>(image->GetLargestPossibleRegion().GetSize()[i] - 1) * image->GetSpacing()[i];
178  unsigned int numberOfSpans = static_cast<unsigned int>(std::ceil(domain / fSplineDistance));
179  unsigned long extraPadding = static_cast<unsigned long>((numberOfSpans * fSplineDistance - domain) / image->GetSpacing()[i] + 0.5);
180  lowerBound[i] = static_cast<unsigned long>(0.5 * extraPadding);
181  upperBound[i] = extraPadding - lowerBound[i];
182  newOrigin[i] -= (static_cast<double>(lowerBound[i]) * image->GetSpacing()[i]);
183  oNumberOfControlPointsArray[i] = numberOfSpans + filter->GetSplineOrder();
184  }
185 
186  /*** 6 ******************* Padder ****************************************/
187  PadderType::Pointer imagePadder = PadderType::New();
188  imagePadder->SetInput(image);
189  imagePadder->SetPadLowerBound(lowerBound);
190  imagePadder->SetPadUpperBound(upperBound);
191  imagePadder->SetConstant(0);
192  imagePadder->SetNumberOfWorkUnits(oArgNbP.getValue());
193  imagePadder->Update();
194 
195  image = imagePadder->GetOutput();
196 
197  /*** 7 ******************** Handle the mask image *************************/
198  MaskPadderType::Pointer maskPadder = MaskPadderType::New();
199  maskPadder->SetInput(maskImage);
200  maskPadder->SetPadLowerBound(lowerBound);
201  maskPadder->SetPadUpperBound(upperBound);
202  maskPadder->SetConstant(0);
203 
204  maskPadder->SetNumberOfWorkUnits(oArgNbP.getValue());
205  maskPadder->Update();
206 
207  maskImage = maskPadder->GetOutput();
208 
209  /*** 8 ******************** SetNumber Of Control Points *******************/
210  filter->SetNumberOfControlPoints(oNumberOfControlPointsArray);
211  }
212  else if(oInitialMeshResolutionVect.size() == 3)
213  {
214  /*** 9 ******************** SetNumber Of Control Points alternative *******/
215  for (unsigned i = 0; i < 3; i++)
216  {
217  oNumberOfControlPointsArray[i] = static_cast<unsigned int>(oInitialMeshResolutionVect[i]) + filter->GetSplineOrder();
218  }
219  filter->SetNumberOfControlPoints(oNumberOfControlPointsArray);
220  }
221  else
222  {
223  std::cout << "No BSpline distance and Mesh Resolution is ignored because not 3 dimensions" << std::endl;
224  }
225 
226  /*** 10 ******************* Shrinker image ********************************/
227  ShrinkerType::Pointer imageShrinker = ShrinkerType::New();
228  imageShrinker->SetInput(image);
229  imageShrinker->SetShrinkFactors(1);
230 
231  /*** 11 ******************* Shrinker mask *********************************/
232  MaskShrinkerType::Pointer maskShrinker = MaskShrinkerType::New();
233  maskShrinker->SetInput(maskImage);
234  maskShrinker->SetShrinkFactors(1);
235 
236  /*** 12 ******************* Shrink mask and image *************************/
237  imageShrinker->SetShrinkFactors(oArgShrinkFactors.getValue());
238  maskShrinker->SetShrinkFactors(oArgShrinkFactors.getValue());
239  imageShrinker->SetNumberOfWorkUnits(oArgNbP.getValue());
240  maskShrinker->SetNumberOfWorkUnits(oArgNbP.getValue());
241  imageShrinker->Update();
242  maskShrinker->Update();
243 
244  /*** 13 ******************* Filter setings ********************************/
245  filter->SetSplineOrder(oArgSplineOrder.getValue());
246  filter->SetWienerFilterNoise(oArgWienerFilterNoise.getValue());
247  filter->SetBiasFieldFullWidthAtHalfMaximum(oArgbfFWHM.getValue());
248  filter->SetConvergenceThreshold(oArgConvergenceThreshold.getValue());
249  filter->SetInput(imageShrinker->GetOutput());
250  filter->SetMaskImage(maskShrinker->GetOutput());
251 
252  /*** 14 ******************* Apply filter **********************************/
253  itk::CStyleCommand::Pointer callback = itk::CStyleCommand::New();
254  callback->SetCallback(eventCallback);
255  filter->AddObserver(itk::ProgressEvent(), callback);
256  try
257  {
258  filter->SetNumberOfWorkUnits(oArgNbP.getValue());
259  filter->Update();
260  }
261  catch (itk::ExceptionObject & err)
262  {
263  std::cerr << "ExceptionObject caught !" << std::endl;
264  std::cerr << err << std::endl;
265  return EXIT_FAILURE;
266  }
267 
273  BSplinerType::Pointer bspliner = BSplinerType::New();
274  bspliner->SetInput(filter->GetLogBiasFieldControlPointLattice());
275  bspliner->SetSplineOrder(filter->GetSplineOrder());
276  bspliner->SetSize(image->GetLargestPossibleRegion().GetSize());
277  bspliner->SetOrigin(newOrigin);
278  bspliner->SetDirection(image->GetDirection());
279  bspliner->SetSpacing(image->GetSpacing());
280  bspliner->SetNumberOfWorkUnits(oArgNbP.getValue());
281  bspliner->Update();
282 
283  ImageType::Pointer logField = ImageType::New();
284  logField->SetOrigin(image->GetOrigin());
285  logField->SetSpacing(image->GetSpacing());
286  logField->SetRegions(image->GetLargestPossibleRegion());
287  logField->SetDirection(image->GetDirection());
288  logField->Allocate();
289 
290  itk::ImageRegionIterator<BiasFilter::ScalarImageType> IB(bspliner->GetOutput(), bspliner->GetOutput()->GetLargestPossibleRegion());
291 
292  itk::ImageRegionIterator<ImageType> IF(logField, logField->GetLargestPossibleRegion());
293 
294  for (IB.GoToBegin(), IF.GoToBegin(); !IB.IsAtEnd(); ++IB, ++IF)
295  {
296  IF.Set(IB.Get()[0]);
297  }
298 
299  ExpFilterType::Pointer expFilter = ExpFilterType::New();
300  expFilter->SetInput(logField);
301  expFilter->SetNumberOfWorkUnits(oArgNbP.getValue());
302  expFilter->Update();
303 
304  DividerType::Pointer divider = DividerType::New();
305  divider->SetInput1(image);
306  divider->SetInput2(expFilter->GetOutput());
307  divider->SetNumberOfWorkUnits(oArgNbP.getValue());
308  divider->Update();
309 
310  ImageType::RegionType inputRegion;
311  inputRegion.SetIndex(oImageIndex);
312  inputRegion.SetSize(oImageSize);
313 
314  CropperType::Pointer cropper = CropperType::New();
315  cropper->SetInput(divider->GetOutput());
316  cropper->SetExtractionRegion(inputRegion);
317  cropper->SetDirectionCollapseToSubmatrix();
318  cropper->SetNumberOfWorkUnits(oArgNbP.getValue());
319  cropper->Update();
320 
321  timer.Stop();
322  std::cout << "\nComputation time : " << timer.GetTotal() << std::endl;
323 
324  /********************** Write output image *************************/
325  try
326  {
327  anima::writeImage<ImageType>(oArgOutputName.getValue(), cropper->GetOutput());
328  }
329  catch (itk::ExceptionObject & err)
330  {
331  std::cerr << "ExceptionObject caught !" << std::endl;
332  std::cerr << err << std::endl;
333  return EXIT_FAILURE;
334  }
335 
336  return EXIT_SUCCESS;
337 }
338 
int main(int argc, char *argv[])
void checkIterationArg(std::string &sIterations)
T convertStringToNumber(std::string const &pi_rsValue)
std::vector< unsigned int > extractMaxNumberOfIterationsVector(std::string pi_rsIterations)
std::vector< T > ConvertVector(std::string &pi_rsVector)
void eventCallback(itk::Object *caller, const itk::EventObject &event, void *clientData)