ANIMA  4.0
animaComputeSolution.hxx
Go to the documentation of this file.
1 #pragma once
2 
3 #include "animaComputeSolution.h"
4 
5 namespace anima
6 {
7 
// Input setters of ComputeSolution. The signature line of each definition
// below is missing from this listing; only the template header and the body
// survive. Each setter registers its argument as an itk::ProcessObject input
// through SetNthInput (the const_cast drops the argument's const qualifier for
// that call):
//   input 0    -> mask image    (TMaskImage)
//   inputs 1-3 -> input images  (TInputImage)
//   inputs 4-6 -> atlas images  (TAtlasImage)
// NOTE(review): which modality / tissue each slot stands for is presumably
// fixed by the matching getters and the header declarations -- confirm there.

// Mask image -> input 0.
8 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
10 {
11  this->SetNthInput(0, const_cast<TMaskImage*>(mask));
12 }
13 
// First input image -> input 1.
14 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
16 {
17  this->SetNthInput(1, const_cast<TInputImage*>(image));
18 }
19 
// Second input image -> input 2.
20 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
22 {
23  this->SetNthInput(2, const_cast<TInputImage*>(image));
24 }
25 
// Third input image -> input 3.
26 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
28 {
29  this->SetNthInput(3, const_cast<TInputImage*>(image));
30 }
31 
// First atlas image -> input 4.
32 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
34 {
35  this->SetNthInput(4, const_cast<TAtlasImage*>(atlas));
36 }
37 
// Second atlas image -> input 5.
38 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
40 {
41  this->SetNthInput(5, const_cast<TAtlasImage*>(atlas));
42 }
43 
// Third atlas image -> input 6.
44 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
46 {
47  this->SetNthInput(6, const_cast<TAtlasImage*>(atlas));
48 }
49 
50 
51 
// Input getters of ComputeSolution. The signature line of each definition
// below is missing from this listing. Each one fetches ProcessObject input N
// and converts it with an unchecked static_cast to the matching image type.
// CheckInputs() calls IsNull() on these results, so an unset input is
// expected to come back null (ITK's GetInput returns nullptr for unset slots).

// Input 0 (mask).
// NOTE(review): the mask setter stores input 0 as a TMaskImage
// (SetNthInput(0, const_cast<TMaskImage*>(mask))), but it is read back here as
// ImageTypeUC. The static_cast is only valid when TMaskImage is ImageTypeUC;
// casting to TMaskImage would be consistent -- confirm against the header.
52 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
54 {
55  return static_cast< const ImageTypeUC * >
56  ( this->itk::ProcessObject::GetInput(0) );
57 }
58 
// Input 1 (first input image).
59 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
61 {
62  return static_cast< const TInputImage * >
63  ( this->itk::ProcessObject::GetInput(1) );
64 }
65 
// Input 2 (second input image).
66 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
68 {
69  return static_cast< const TInputImage * >
70  ( this->itk::ProcessObject::GetInput(2) );
71 }
72 
// Input 3 (third input image).
73 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
75 {
76  return static_cast< const TInputImage * >
77  ( this->itk::ProcessObject::GetInput(3) );
78 }
79 
// Input 4 (first atlas image).
80 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
82 {
83  return static_cast< const TAtlasImage * >
84  ( this->itk::ProcessObject::GetInput(4) );
85 }
86 
// Input 5 (second atlas image).
87 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
89 {
90  return static_cast< const TAtlasImage * >
91  ( this->itk::ProcessObject::GetInput(5) );
92 }
93 
// Input 6 (third atlas image).
94 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
96 {
97  return static_cast< const TAtlasImage * >
98  ( this->itk::ProcessObject::GetInput(6) );
99 }
100 
101 
102 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
103 void
105 {
106  std::vector<GaussianFunctionType::Pointer> tmpGaussianModel;
107  std::vector<double> tmpAlphas;
108 
109  typedef std::map<double,unsigned int> OrderType;
110  OrderType ordered;
111  OrderType::iterator it;
112 
113  ordered.clear();
114  for(unsigned int i = 0; i < m_GaussianModel.size(); i++)
115  {
116  GaussianFunctionType::MeanVectorType mean = (m_GaussianModel[i])->GetMean();
117  ordered.insert(OrderType::value_type(mean[0],i));
118  }
119 
120  for(it = ordered.begin(); it != ordered.end(); ++it)
121  {
122  tmpGaussianModel.push_back(m_GaussianModel[it->second]);
123  tmpAlphas.push_back(m_Alphas[it->second]);
124  }
125  m_GaussianModel = tmpGaussianModel;
126  m_Alphas = tmpAlphas;
127 }
128 
129 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
130 void
132 ::CheckInputs()
133 {
134  if( (this->GetMask().IsNull()) || (this->GetInputImage1().IsNull()) || (this->GetInputImage2().IsNull()) || (this->GetInputImage3().IsNull()) )
135  {
136  std::cerr << "Error: Inputs are missing... Exiting..." << std::endl;
137  exit(-1);
138  }
139 
140  if(m_UseT2 && m_UseDP && m_UseFLAIR)
141  {
142  std::cerr << "-- Error in Automatic segmentation: only 2 images among T2, DP and FLAIR must be used for automatic segmentation" << std::endl;
143  exit(-1);
144  }
145 
146  if(m_InitMethodType==0)
147  {
148  if((this->GetInputCSFAtlas().IsNull()) || (this->GetInputGMAtlas().IsNull()) || (this->GetInputWMAtlas().IsNull()))
149  {
150  std::cerr << "Error: Some atlas images are missing for the initialization... Exiting..." << std::endl;
151  exit(-1);
152  }
153  }
154 
155  if(m_InitMethodType==1)
156  {
157  if((m_UseT2==false) || (m_UseDP==false))
158  {
159  std::cerr << "Error: Automatic segmentation with Hierarchical DP initialisation requires T2 and DP images... Exiting..." << std::endl;
160  exit(-1);
161  }
162  }
163 
164  if(m_InitMethodType==2)
165  {
166  if(m_UseFLAIR==false)
167  {
168  std::cerr << "Error: Automatic segmentation with Hierarchical FLAIR initialisation requires FLAIR image... Exiting..." << std::endl;
169  exit(-1);
170  }
171  }
172 }
173 
174 
175 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
176 void
178 ::RescaleImages()
179 {
180  m_InputImage_T1_UC = ImageTypeUC::New();
181  m_InputImage_T2_DP_UC = ImageTypeUC::New();
182  m_InputImage_DP_FLAIR_UC = ImageTypeUC::New();
183 
184  double desiredMinimum=0,desiredMaximum=255;
185 
186  typename RescaleFilterType::Pointer rescaleFilter1 = RescaleFilterType::New();
187  rescaleFilter1->SetInput( this->GetInputImage1() );
188  rescaleFilter1->SetOutputMinimum( desiredMinimum );
189  rescaleFilter1->SetOutputMaximum( desiredMaximum );
190  rescaleFilter1->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
191  rescaleFilter1->UpdateLargestPossibleRegion();
192  m_InputImage_T1_UC = rescaleFilter1->GetOutput();
193 
194  typename RescaleFilterType::Pointer rescaleFilter2 = RescaleFilterType::New();
195  rescaleFilter2->SetInput( this->GetInputImage2() );
196  rescaleFilter2->SetOutputMinimum( desiredMinimum );
197  rescaleFilter2->SetOutputMaximum( desiredMaximum );
198  rescaleFilter2->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
199  rescaleFilter2->UpdateLargestPossibleRegion();
200  m_InputImage_T2_DP_UC = rescaleFilter2->GetOutput();
201 
202  typename RescaleFilterType::Pointer rescaleFilter3 = RescaleFilterType::New();
203  rescaleFilter3->SetInput( this->GetInputImage3() );
204  rescaleFilter3->SetOutputMinimum( desiredMinimum );
205  rescaleFilter3->SetOutputMaximum( desiredMaximum );
206  rescaleFilter3->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
207  rescaleFilter3->UpdateLargestPossibleRegion();
208  m_InputImage_DP_FLAIR_UC = rescaleFilter3->GetOutput();
209 }
210 
211 
212 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
213 int
215 ::WriteSolution(std::string filename)
216 {
217  const unsigned int ARows = 1;
218  const unsigned int ACols = 39; // m_NbTissus*(1+m_NbModalities+m_NbModalities*m_NbModalities);
219 
220  typedef itk::Array2D<double> MatrixType;
221  MatrixType matrix(ARows,ACols);
222 
223  unsigned int t = 0;
224  for(unsigned int i = 0; i < m_NbTissus; i++)
225  {
226  matrix[0][t] = m_Alphas[i];
227  t++;
228 
229  GaussianFunctionType::MeanVectorType mu = (m_GaussianModel[i])->GetMean();
230  for(unsigned int j = 0; j < m_NbModalities; j++)
231  {
232  matrix[0][t] = mu[j];
233  t++;
234  }
235 
236  GaussianFunctionType::CovarianceMatrixType covar = (m_GaussianModel[i])->GetCovariance();
237  for(unsigned int k = 0; k < m_NbModalities; k++)
238  {
239  for(unsigned int j = 0; j < m_NbModalities; j++)
240  {
241  matrix[0][t] = covar(k,j);
242  t++;
243  }
244  }
245  }
246 
247  // write out the array2D object
248  typedef itk::CSVNumericObjectFileWriter<double, ARows, ACols> WriterType;
249  WriterType::Pointer writer = WriterType::New();
250 
251  writer->SetFieldDelimiterCharacter(';');
252  writer->SetFileName( filename );
253  writer->SetInput( &matrix );
254  try
255  {
256  writer->Write();
257  }
258  catch (itk::ExceptionObject& exp)
259  {
260  std::cerr << "Exception caught!" << std::endl;
261  std::cerr << exp << std::endl;
262  return -1;
263  }
264 
265  return 0;
266 }
267 
268 
269 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
270 int
272 ::ReadSolution(std::string filename)
273 {
274  typedef itk::Array2D<double> MatrixType;
275  typedef itk::CSVArray2DFileReader<double > ReaderType;
276  ReaderType::Pointer reader = ReaderType::New();
277  reader->SetFileName( filename );
278  reader->SetFieldDelimiterCharacter( ';' );
279  reader->SetStringDelimiterCharacter( '"' );
280  reader->HasColumnHeadersOff();
281  reader->HasRowHeadersOff();
282 
283  // read the file
284  try
285  {
286  reader->Update();
287  }
288  catch (itk::ExceptionObject& exp)
289  {
290  std::cerr << "Exception caught!" << std::endl;
291  std::cerr << exp << std::endl;
292  return -1;
293  }
294 
295  typedef itk::CSVArray2DDataObject<double> DataFrameObjectType;
296  DataFrameObjectType::Pointer dfo = reader->GetOutput();
297  MatrixType matrix = dfo->GetMatrix();
298 
299  unsigned int nbCols = m_NbTissus*(1+m_NbModalities+m_NbModalities*m_NbModalities);
300  if(matrix.rows()!=1 || matrix.cols()!=nbCols)
301  {
302  std::cout<< "wrong type of matrix file... cannot read solution" << std::endl;
303  return -1;
304  }
305 
306  unsigned int t = 0;
307  for(unsigned int i = 0; i < m_NbTissus; i++)
308  {
309  m_Alphas.push_back(matrix[0][t]);
310  t++;
311 
312  GaussianFunctionType::Pointer densityFunction = GaussianFunctionType::New();
313  densityFunction->SetMeasurementVectorSize( m_NbTissus );
314  GaussianFunctionType::MeanVectorType mean( m_NbModalities );
315  GaussianFunctionType::CovarianceMatrixType cov;
316  cov.SetSize( m_NbModalities, m_NbModalities );
317  cov.Fill(0);
318 
319  for(unsigned int j = 0; j < m_NbModalities; j++)
320  {
321  mean[j] = matrix[0][t];
322  t++;
323  }
324 
325  for(unsigned int k = 0; k < m_NbModalities; k++)
326  {
327  for(unsigned int j = 0; j < m_NbModalities; j++)
328  {
329  cov[k][j] = matrix[0][t];
330  t++;
331  }
332  }
333 
334  densityFunction->SetMean( mean );
335  densityFunction->SetCovariance( cov );
336  m_GaussianModel.push_back( densityFunction );
337  }
338 
339  return 0;
340 }
341 
342 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
343 int
345 ::PrintSolution(std::vector<double> alphas, std::vector<GaussianFunctionType::Pointer> model)
346 {
347  unsigned int nbTissus = alphas.size();
348  unsigned int nbModalities = (model[0])->GetMean().Size();
349  for(unsigned int i = 0; i < nbTissus; i++)
350  {
351  std::cout << "* Class: " << i << std::endl;
352  std::cout << " Alpha: " << alphas[i] << std::endl;
353  GaussianFunctionType::MeanVectorType mu = (model[i])->GetMean();
354  std::cout << " Mean: " << std::endl;
355  for(unsigned int j = 0; j < nbModalities; j++)
356  {
357  std::cout << " " << mu[j] << std::endl;
358  }
359 
360  GaussianFunctionType::CovarianceMatrixType covar = (model[i])->GetCovariance();
361  std::cout << " covar: " << std::endl;
362  for(unsigned int k = 0; k < nbModalities; k++)
363  {
364  std::cout << " " << covar(k,0) << " " << covar(k,1) << " " << covar(k,2) << std::endl;
365  }
366  }
367  return 0;
368 }
369 
370 
371 template <typename TInputImage, typename TMaskImage, typename TAtlasImage>
372 void
374 ::Update()
375 {
376  this->CheckInputs();
377  this->RescaleImages();
378 
379  if((m_GaussianModel.size() == m_NbTissus) && (m_Alphas.size() == m_NbTissus))
380  {
381  m_SolutionSet = true;
382  std::cout << "Solution already set..."<< std::endl;
383  if( m_Verbose )
384  {
385  this->PrintSolution(m_Alphas, m_GaussianModel);
386  }
387  }
388  else
389  {
390  if(m_SolutionReadFilename!="")
391  {
392  m_SolutionSet = true;
393  std::cout << "Reading solution file..." << std::endl;
394  if(ReadSolution(m_SolutionReadFilename))
395  {
396  m_SolutionSet = false;
397  }
398  if( m_Verbose )
399  {
400  this->PrintSolution(m_Alphas, m_GaussianModel);
401  }
402  }
403  }
404 
405  if(!m_SolutionSet)
406  {
407  // Choose intializer
408  ModelInitializer::Pointer initializer;
409  switch (m_InitMethodType)
410  {
411  case 0:
412  {
413  initializer = AtlasInitializerType::New();
414  dynamic_cast<AtlasInitializerType *>( initializer.GetPointer() ) ->SetMask( this->GetMask() );
415  dynamic_cast<AtlasInitializerType *>( initializer.GetPointer() ) ->SetInputImage1( m_InputImage_T1_UC );
416  dynamic_cast<AtlasInitializerType *>( initializer.GetPointer() ) ->SetInputImage2( m_InputImage_T2_DP_UC );
417  dynamic_cast<AtlasInitializerType *>( initializer.GetPointer() ) ->SetInputImage3( m_InputImage_DP_FLAIR_UC );
418  dynamic_cast<AtlasInitializerType *>( initializer.GetPointer() ) ->SetAtlasImage1( this->GetInputCSFAtlas() );
419  dynamic_cast<AtlasInitializerType *>( initializer.GetPointer() ) ->SetAtlasImage2( this->GetInputGMAtlas() );
420  dynamic_cast<AtlasInitializerType *>( initializer.GetPointer() ) ->SetAtlasImage3( this->GetInputWMAtlas() );
421  std::cout<< "Choosen initializer: Atlas" << std::endl;
422  break;
423  }
424  case 1:
425  {
426  bool use_HierarFLAIR = false;
427  initializer = HierarchicalType::New();
428  dynamic_cast<HierarchicalType *>( initializer.GetPointer() ) ->SetMask( this->GetMask() );
429  dynamic_cast<HierarchicalType *>( initializer.GetPointer() ) ->SetInputImage1( m_InputImage_T1_UC );
430  dynamic_cast<HierarchicalType *>( initializer.GetPointer() ) ->SetInputImage2( m_InputImage_T2_DP_UC );
431  dynamic_cast<HierarchicalType *>( initializer.GetPointer() ) ->SetInputImage3( m_InputImage_DP_FLAIR_UC );
432  dynamic_cast<HierarchicalType *>( initializer.GetPointer() ) ->SetThirdIsFLAIR( use_HierarFLAIR );
433  dynamic_cast<HierarchicalType *>( initializer.GetPointer() ) ->SetRobust( m_RejRatioHierar );
434  dynamic_cast<HierarchicalType *>( initializer.GetPointer() ) ->SetTol( m_Tol );
435  std::cout<< "Choosen initializer: Hierarchical DP " << std::endl;
436  break;
437  }
438  case 2:
439  {
440  bool use_HierarFLAIR = true;
441  initializer = HierarchicalType::New();
442  dynamic_cast<HierarchicalType *>( initializer.GetPointer() ) ->SetMask( this->GetMask() );
443  dynamic_cast<HierarchicalType *>( initializer.GetPointer() ) ->SetInputImage1( m_InputImage_T1_UC );
444  dynamic_cast<HierarchicalType *>( initializer.GetPointer() ) ->SetInputImage2( m_InputImage_T2_DP_UC );
445  dynamic_cast<HierarchicalType *>( initializer.GetPointer() ) ->SetInputImage3( m_InputImage_DP_FLAIR_UC );
446  dynamic_cast<HierarchicalType *>( initializer.GetPointer() ) ->SetThirdIsFLAIR( use_HierarFLAIR );
447  dynamic_cast<HierarchicalType *>( initializer.GetPointer() ) ->SetRobust( m_RejRatioHierar );
448  dynamic_cast<HierarchicalType *>( initializer.GetPointer() ) ->SetTol( m_Tol );
449  std::cout<< "Choosen initializer: Hierarchical FLAIR" << std::endl;
450  break;
451  }
452  default:
453  {
454  std::cerr<< "-- Error in compute solution filter: initialisation failed" << std::endl;
455  exit(-1);
456  }
457  }//switch InitMethod
458 
459  std::cout << "Computing initialization for EM..." << std::endl;
460  initializer->Update();
461  std::vector<GaussianFunctionType::Pointer> initia = initializer->GetInitialization();
462  std::vector<double> initiaAlphas = initializer->GetAlphas();
463 
464  std::cout << "Computing gaussian model..." << std::endl;
465  typename GaussianREMEstimatorType::Pointer estimator = GaussianREMEstimatorType::New();
466  estimator ->SetMaxIterationsConc( m_EmIter_concentration );
467  estimator ->SetStremMode( m_EM_before_concentration );
468  estimator ->SetRejectionRatio( m_RejRatio );
469  estimator ->SetMaxIterations( m_EmIter );
470  estimator ->SetModelMinDistance( m_MinDistance );
471  estimator ->SetMask( this->GetMask() );
472  estimator ->SetInputImage1( m_InputImage_T1_UC );
473  estimator ->SetInputImage2( m_InputImage_T2_DP_UC );
474  estimator ->SetInputImage3( m_InputImage_DP_FLAIR_UC );
475  estimator ->SetVerbose( m_Verbose );
476 
477  itk::CStyleCommand::Pointer callback = itk::CStyleCommand::New();
478  callback ->SetCallback(eventCallback);
479  estimator ->AddObserver(itk::ProgressEvent(), callback );
480  estimator ->SetInitialGaussianModel(initia);
481  estimator ->SetInitialAlphas(initiaAlphas);
482  estimator ->Update();
483 
484  m_GaussianModel = estimator->GetGaussianModel();
485  m_Alphas = estimator->GetAlphas();
486 
487  this->SortGaussianModel();
488 
489  if( m_Verbose )
490  {
491  std::cout << std::endl;
492  std::cout<< "EM summary: " << std::endl << std::endl;
493  std::cout<< "* Initial model: " << std::endl;
494  PrintSolution(initiaAlphas, initia);
495  std::cout << std::endl;
496  std::cout<< "* Final model: " << std::endl;
497  this->PrintSolution(m_Alphas, m_GaussianModel);
498  std::cout << std::endl;
499  }
500  }
501 
502  // Print NABT solution if necessary
503  if(m_SolutionWriteFilename!="")
504  {
505  std::cout << "Writing solution in csv file..." << std::endl;
506  WriteSolution(m_SolutionWriteFilename);
507  }
508 }
509 
510 } //end of namespace anima
Class computing the 3-class GMM representing the NABT, where each Gaussian represents one of the bra...
itk::SmartPointer< Self > Pointer
Class initializing a gaussian mixture with hierarchical information. It uses 'a priori' knowledge of t...
itk::Image< PixelTypeUC, 3 > ImageTypeUC
void eventCallback(itk::Object *caller, const itk::EventObject &event, void *clientData)
itk::SmartPointer< Self > Pointer