ANIMA  4.0
animaBaseProbabilisticTractographyImageFilter.hxx
Go to the documentation of this file.
2 
#include <itkImageRegionIteratorWithIndex.h>
#include <itkLinearInterpolateImageFunction.h>

#include <itkImageRegionIterator.h>
#include <itkImageRegionConstIterator.h>
#include <itkExtractImageFilter.h>
#include <itkImageMomentsCalculator.h>

#include <animaVectorOperations.h>

#include <vnl/algo/vnl_matrix_inverse.h>

#include <vtkPointData.h>
#include <vtkCellData.h>
#include <vtkDoubleArray.h>

#include <animaKMeansFilter.h>

#include <algorithm>
#include <ctime>
#include <iterator>
#include <mutex>
#include <utility>
#include <vector>
23 
24 namespace anima
25 {
26 
27 template <class TInputModelImageType>
28 BaseProbabilisticTractographyImageFilter <TInputModelImageType>
30 {
31  m_PointsToProcess.clear();
32 
33  m_NumberOfFibersPerPixel = 1;
34  m_NumberOfParticles = 1000;
35  m_MinimalNumberOfParticlesPerClass = 10;
36 
37  m_ResamplingThreshold = 0.8;
38 
39  m_StepProgression = 1.0;
40 
41  m_KappaOfPriorDistribution = 30.0;
42 
43  m_MinLengthFiber = 10.0;
44  m_MaxLengthFiber = 150.0;
45 
46  m_PositionDistanceFuseThreshold = 0.5;
47  m_KappaSplitThreshold = 30.0;
48 
49  m_ClusterDistance = 1;
50 
51  m_ComputeLocalColors = true;
52  m_MAPMergeFibers = true;
53 
54  m_InitialColinearityDirection = Center;
55  m_InitialDirectionMode = Weight;
56 
57  m_Generators.clear();
58 
59  m_HighestProcessedSeed = 0;
60  m_ProgressReport = 0;
61 }
62 
63 template <class TInputModelImageType>
65 ::~BaseProbabilisticTractographyImageFilter()
66 {
67  if (m_ProgressReport)
68  delete m_ProgressReport;
69 }
70 
71 template <class TInputModelImageType>
72 void
74 ::Update()
75 {
76  this->PrepareTractography();
77  m_Output = vtkPolyData::New();
78 
79  if (m_ProgressReport)
80  delete m_ProgressReport;
81 
82  unsigned int stepData = std::min((int)m_PointsToProcess.size(),100);
83  if (stepData == 0)
84  stepData = 1;
85 
86  unsigned int numSteps = std::floor(m_PointsToProcess.size() / (double)stepData);
87  if (m_PointsToProcess.size() % stepData != 0)
88  numSteps++;
89 
90  m_ProgressReport = new itk::ProgressReporter(this,0,numSteps);
91 
92  FiberProcessVectorType resultFibers;
93  ListType resultWeights;
94 
95  trackerArguments tmpStr;
96  tmpStr.trackerPtr = this;
97  tmpStr.resultFibersFromThreads.resize(this->GetNumberOfWorkUnits());
98  tmpStr.resultWeightsFromThreads.resize(this->GetNumberOfWorkUnits());
99 
100  for (unsigned int i = 0;i < this->GetNumberOfWorkUnits();++i)
101  {
102  tmpStr.resultFibersFromThreads[i] = resultFibers;
103  tmpStr.resultWeightsFromThreads[i] = resultWeights;
104  }
105 
106  this->GetMultiThreader()->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
107  this->GetMultiThreader()->SetSingleMethod(this->ThreadTracker,&tmpStr);
108  this->GetMultiThreader()->SingleMethodExecute();
109 
110  for (unsigned int j = 0;j < this->GetNumberOfWorkUnits();++j)
111  {
112  resultFibers.insert(resultFibers.end(),tmpStr.resultFibersFromThreads[j].begin(),tmpStr.resultFibersFromThreads[j].end());
113  resultWeights.insert(resultWeights.end(),tmpStr.resultWeightsFromThreads[j].begin(),tmpStr.resultWeightsFromThreads[j].end());
114  }
115 
116  std::cout << "\nKept " << resultFibers.size() << " fibers after filtering" << std::endl;
117  this->createVTKOutput(resultFibers, resultWeights);
118 }
119 
120 template <class TInputModelImageType>
121 void
123 ::PrepareTractography()
124 {
125  if (!m_B0Image)
126  itkExceptionMacro("No B0 image, required");
127 
128  if (!m_NoiseImage)
129  itkExceptionMacro("No sigma square noise image, required");
130 
131  m_B0Interpolator = ScalarInterpolatorType::New();
132  m_B0Interpolator->SetInputImage(m_B0Image);
133 
134  m_NoiseInterpolator = ScalarInterpolatorType::New();
135  m_NoiseInterpolator->SetInputImage(m_NoiseImage);
136 
137  // Initialize random generator
138  m_Generators.resize(this->GetNumberOfWorkUnits());
139 
140  std::mt19937 motherGenerator(time(0));
141 
142  for (unsigned int i = 0;i < this->GetNumberOfWorkUnits();++i)
143  m_Generators[i] = std::mt19937(motherGenerator());
144 
145  bool is2d = m_InputModelImage->GetLargestPossibleRegion().GetSize()[2] == 1;
146  if (is2d && (m_InitialColinearityDirection == Top))
147  m_InitialColinearityDirection = Front;
148  if (is2d && (m_InitialColinearityDirection == Bottom))
149  m_InitialColinearityDirection = Back;
150 
151  // If needed, ensure DWI gravity center is computed
152  if ((m_InitialColinearityDirection == Outward)||(m_InitialColinearityDirection == Center))
153  {
154  itk::ImageMomentsCalculator <ScalarImageType>::Pointer momentsCalculator = itk::ImageMomentsCalculator <ScalarImageType>::New();
155  momentsCalculator->SetImage(m_B0Image);
156  momentsCalculator->Compute();
157  m_DWIGravityCenter = momentsCalculator->GetCenterOfGravity();
158  }
159 
160  typedef itk::ImageRegionIteratorWithIndex <MaskImageType> MaskImageIteratorType;
161 
162  MaskImageIteratorType maskItr(m_SeedMask, m_InputModelImage->GetLargestPossibleRegion());
163  m_PointsToProcess.clear();
164 
165  IndexType tmpIndex;
166  PointType tmpPoint;
167  ContinuousIndexType realIndex;
168 
169  m_FilteringValues.clear();
170  double startN = -0.5 + 1.0 / (2.0 * m_NumberOfFibersPerPixel);
171  double stepN = 1.0 / m_NumberOfFibersPerPixel;
172  FiberType tmpFiber(1);
173 
174  if (m_FilterMask)
175  {
176  MaskImageIteratorType filterItr(m_FilterMask, m_InputModelImage->GetLargestPossibleRegion());
177  while (!filterItr.IsAtEnd())
178  {
179  if (filterItr.Get() == 0)
180  {
181  ++filterItr;
182  continue;
183  }
184 
185  bool isAlreadyIn = false;
186  for (unsigned int i = 0;i < m_FilteringValues.size();++i)
187  {
188  if (m_FilteringValues[i] == filterItr.Get())
189  {
190  isAlreadyIn = true;
191  break;
192  }
193  }
194 
195  if (!isAlreadyIn)
196  m_FilteringValues.push_back(filterItr.Get());
197 
198  ++filterItr;
199  }
200  }
201 
202  while (!maskItr.IsAtEnd())
203  {
204  if (maskItr.Get() == 0)
205  {
206  ++maskItr;
207  continue;
208  }
209 
210  tmpIndex = maskItr.GetIndex();
211 
212  if (is2d)
213  {
214  realIndex[2] = tmpIndex[2];
215  for (unsigned int j = 0;j < m_NumberOfFibersPerPixel;++j)
216  {
217  realIndex[1] = tmpIndex[1] + startN + j * stepN;
218  for (unsigned int i = 0;i < m_NumberOfFibersPerPixel;++i)
219  {
220  realIndex[0] = tmpIndex[0] + startN + i * stepN;
221  m_SeedMask->TransformContinuousIndexToPhysicalPoint(realIndex,tmpPoint);
222  tmpFiber[0] = tmpPoint;
223  m_PointsToProcess.push_back(tmpFiber);
224  }
225  }
226  }
227  else
228  {
229  for (unsigned int k = 0;k < m_NumberOfFibersPerPixel;++k)
230  {
231  realIndex[2] = tmpIndex[2] + startN + k * stepN;
232  for (unsigned int j = 0;j < m_NumberOfFibersPerPixel;++j)
233  {
234  realIndex[1] = tmpIndex[1] + startN + j * stepN;
235  for (unsigned int i = 0;i < m_NumberOfFibersPerPixel;++i)
236  {
237  realIndex[0] = tmpIndex[0] + startN + i * stepN;
238 
239  m_SeedMask->TransformContinuousIndexToPhysicalPoint(realIndex,tmpPoint);
240  tmpFiber[0] = tmpPoint;
241  m_PointsToProcess.push_back(tmpFiber);
242  }
243  }
244  }
245  }
246 
247  ++maskItr;
248  }
249 
250  std::cout << "Generated " << m_PointsToProcess.size() << " seed points from ROI mask" << std::endl;
251 }
252 
253 template <class TInputModelImageType>
256 {
257  typedef itk::LinearInterpolateImageFunction <InputModelImageType> InternalInterpolatorType;
258 
259  typename InternalInterpolatorType::Pointer outInterpolator = InternalInterpolatorType::New();
260  outInterpolator->SetInputImage(m_InputModelImage);
261 
262  outInterpolator->Register();
263  return outInterpolator;
264 }
265 
266 template <class TInputModelImageType>
267 ITK_THREAD_RETURN_FUNCTION_CALL_CONVENTION
269 ::ThreadTracker(void *arg)
270 {
271  itk::MultiThreaderBase::WorkUnitInfo *threadArgs = (itk::MultiThreaderBase::WorkUnitInfo *)arg;
272  unsigned int nbThread = threadArgs->WorkUnitID;
273 
274  trackerArguments *tmpArg = (trackerArguments *)threadArgs->UserData;
275  tmpArg->trackerPtr->ThreadTrack(nbThread,tmpArg->resultFibersFromThreads[nbThread],tmpArg->resultWeightsFromThreads[nbThread]);
276 
277  return ITK_THREAD_RETURN_DEFAULT_VALUE;
278 }
279 
280 template <class TInputModelImageType>
281 void
283 ::ThreadTrack(unsigned int numThread, FiberProcessVectorType &resultFibers,
284  ListType &resultWeights)
285 {
286  bool continueLoop = true;
287  unsigned int highestToleratedSeedIndex = m_PointsToProcess.size();
288 
289  unsigned int stepData = std::min((int)m_PointsToProcess.size(),100);
290  if (stepData == 0)
291  stepData = 1;
292 
293  while (continueLoop)
294  {
295  m_LockHighestProcessedSeed.lock();
296 
297  if (m_HighestProcessedSeed >= highestToleratedSeedIndex)
298  {
299  m_LockHighestProcessedSeed.unlock();
300  continueLoop = false;
301  continue;
302  }
303 
304  unsigned int startPoint = m_HighestProcessedSeed;
305  unsigned int endPoint = m_HighestProcessedSeed + stepData;
306  if (endPoint > highestToleratedSeedIndex)
307  endPoint = highestToleratedSeedIndex;
308 
309  m_HighestProcessedSeed = endPoint;
310 
311  m_LockHighestProcessedSeed.unlock();
312 
313  this->ThreadedTrackComputer(numThread,resultFibers,resultWeights,startPoint,endPoint);
314 
315  m_LockHighestProcessedSeed.lock();
316  m_ProgressReport->CompletedPixel();
317  m_LockHighestProcessedSeed.unlock();
318  }
319 }
320 
321 template <class TInputModelImageType>
322 void
324 ::ThreadedTrackComputer(unsigned int numThread, FiberProcessVectorType &resultFibers,
325  ListType &resultWeights, unsigned int startSeedIndex,
326  unsigned int endSeedIndex)
327 {
328  InterpolatorPointer modelInterpolator = this->GetModelInterpolator();
329  FiberProcessVectorType tmpFibers;
330  ListType tmpWeights;
331  ContinuousIndexType startIndex;
332 
333  for (unsigned int i = startSeedIndex;i < endSeedIndex;++i)
334  {
335  m_SeedMask->TransformPhysicalPointToContinuousIndex(m_PointsToProcess[i][0],startIndex);
336 
337  tmpFibers = this->ComputeFiber(m_PointsToProcess[i], modelInterpolator, numThread, tmpWeights);
338 
339  tmpFibers = this->FilterOutputFibers(tmpFibers, tmpWeights);
340 
341  for (unsigned int j = 0;j < tmpFibers.size();++j)
342  {
343  if (tmpFibers[j].size() > m_MinLengthFiber / m_StepProgression)
344  {
345  resultFibers.push_back(tmpFibers[j]);
346  resultWeights.push_back(tmpWeights[j]);
347  }
348  }
349  }
350 }
351 
352 template <class TInputModelImageType>
355 ::FilterOutputFibers(FiberProcessVectorType &fibers, ListType &weights)
356 {
357  FiberProcessVectorType resVal;
358  ListType tmpWeights = weights;
359  weights.clear();
360 
361  if ((m_FilteringValues.size() > 0)||(m_ForbiddenMask))
362  {
363  MembershipType touchingLabels;
364  IndexType tmpIndex;
365  PointType tmpPoint;
366 
367  for (unsigned int i = 0;i < fibers.size();++i)
368  {
369  touchingLabels.clear();
370  bool forbiddenTouched = false;
371 
372  for (unsigned int j = 0;j < fibers[i].size();++j)
373  {
374  tmpPoint = fibers[i][j];
375  m_SeedMask->TransformPhysicalPointToIndex(tmpPoint,tmpIndex);
376 
377  unsigned int maskValue = 0;
378  unsigned int forbiddenMaskValue = 0;
379 
380  if (m_FilterMask)
381  maskValue = m_FilterMask->GetPixel(tmpIndex);
382 
383  if (m_ForbiddenMask)
384  forbiddenMaskValue = m_ForbiddenMask->GetPixel(tmpIndex);
385 
386  if (forbiddenMaskValue != 0)
387  {
388  forbiddenTouched = true;
389  break;
390  }
391 
392  if (maskValue != 0)
393  {
394  bool alreadyIn = false;
395  for (unsigned int k = 0;k < touchingLabels.size();++k)
396  {
397  if (maskValue == touchingLabels[k])
398  {
399  alreadyIn = true;
400  break;
401  }
402  }
403 
404  if (!alreadyIn)
405  touchingLabels.push_back(maskValue);
406  }
407  }
408 
409  if (forbiddenTouched)
410  continue;
411 
412  if (touchingLabels.size() == m_FilteringValues.size())
413  {
414  resVal.push_back(fibers[i]);
415  weights.push_back(tmpWeights[i]);
416  }
417  }
418  }
419  else
420  {
421  resVal = fibers;
422  weights = tmpWeights;
423  }
424 
425  return resVal;
426 }
427 
428 template <class TInputModelImageType>
429 void
431 ::createVTKOutput(FiberProcessVectorType &filteredFibers, ListType &filteredWeights)
432 {
433  m_Output = vtkPolyData::New();
434  m_Output->Initialize();
435  m_Output->Allocate();
436 
437  vtkSmartPointer <vtkPoints> myPoints = vtkPoints::New();
438  vtkSmartPointer <vtkDoubleArray> weights = vtkDoubleArray::New();
439  weights->SetNumberOfComponents(1);
440  weights->SetName("Fiber weights");
441 
442  for (unsigned int i = 0;i < filteredFibers.size();++i)
443  {
444  unsigned int npts = filteredFibers[i].size();
445  vtkIdType* ids = new vtkIdType[npts];
446 
447  for (unsigned int j = 0;j < npts;++j)
448  {
449  ids[j] = myPoints->InsertNextPoint(filteredFibers[i][j][0],filteredFibers[i][j][1],filteredFibers[i][j][2]);
450  weights->InsertNextValue(filteredWeights[i]);
451  }
452 
453  m_Output->InsertNextCell (VTK_POLY_LINE, npts, ids);
454  delete[] ids;
455  }
456 
457  m_Output->SetPoints(myPoints);
458  if (m_ComputeLocalColors)
459  this->ComputeAdditionalScalarMaps();
460 
461  // Add particle weights to data
462  m_Output->GetPointData()->AddArray(weights);
463 }
464 
465 template <class TInputModelImageType>
468 ::ComputeFiber(FiberType &fiber, InterpolatorPointer &modelInterpolator,
469  unsigned int numThread, ListType &resultWeights)
470 {
471  unsigned int numberOfClasses = 1;
472 
473  FiberWorkType fiberComputationData;
474  fiberComputationData.fiberParticles.resize(m_NumberOfParticles);
475  for (unsigned int i = 0;i < m_NumberOfParticles;++i)
476  fiberComputationData.fiberParticles[i] = fiber;
477 
478  fiberComputationData.particleWeights = ListType(m_NumberOfParticles, 1.0 / m_NumberOfParticles);
479  fiberComputationData.stoppedParticles = std::vector <bool> (m_NumberOfParticles,false);
480  fiberComputationData.classSizes = MembershipType(numberOfClasses,m_NumberOfParticles);
481  fiberComputationData.classWeights = ListType(numberOfClasses,1.0 / numberOfClasses);
482  //We need membership vectors in each direction
483  fiberComputationData.classMemberships = MembershipType(m_NumberOfParticles,0);
484  fiberComputationData.reverseClassMemberships.resize(numberOfClasses);
485  MembershipType tmpVec(m_NumberOfParticles);
486  for (unsigned int j = 0;j < m_NumberOfParticles;++j)
487  tmpVec[j] = j;
488  fiberComputationData.reverseClassMemberships[0] = tmpVec;
489 
490  ListType logWeightVals(m_NumberOfParticles, 1.0 / m_NumberOfParticles);
491  ListType oldFiberWeights(m_NumberOfParticles, 1.0 / m_NumberOfParticles);
492  std::vector <bool> emptyClasses;
493  ListType tmpVector(m_NumberOfParticles,0);
494 
495  ListType logWeightSums(numberOfClasses,0);
496  ListType effectiveNumberOfParticles(numberOfClasses,0);
497 
498  DirectionVectorType previousDirections(m_NumberOfParticles);
499 
500  // Data structures for resampling
501  FiberProcessVectorType fiberParticlesCopy;
502  DirectionVectorType previousDirectionsCopy;
503  ListType weightSpecificClassValues;
504  FiberProcessVectorType fiberTrash;
505  std::vector <bool> usedFibers;
506 
507  // Here to constrain directions to 2D plane if needed
508  bool is2d = m_InputModelImage->GetLargestPossibleRegion().GetSize()[2] == 1;
509 
510  VectorType modelValue(m_ModelDimension);
511 
512  Vector3DType sampling_direction(0.0), newDirection;
513  PointType currentPoint;
514  ContinuousIndexType currentIndex, newIndex;
515  IndexType closestIndex;
516 
517  unsigned int numIter = 0;
518  bool stopLoop = false;
519  while (!stopLoop)
520  {
521  ++numIter;
522 
523  // Store previous weights for the resampling step
524  if (numIter > 1)
525  oldFiberWeights = fiberComputationData.particleWeights;
526 
527  logWeightSums.resize(numberOfClasses);
528  std::fill(logWeightSums.begin(),logWeightSums.end(),0.0);
529  for (unsigned int i = 0;i < m_NumberOfParticles;++i)
530  {
531  // Do not compute trashed fibers
532  if (fiberComputationData.stoppedParticles[i])
533  continue;
534 
535  currentPoint = fiberComputationData.fiberParticles[i].back();
536 
537  m_SeedMask->TransformPhysicalPointToContinuousIndex(currentPoint,currentIndex);
538 
539  // Trash fiber if it goes outside of the brain
540  if (!modelInterpolator->IsInsideBuffer(currentIndex))
541  {
542  fiberComputationData.stoppedParticles[i] = true;
543  fiberComputationData.particleWeights[i] = 0;
544  continue;
545  }
546 
547  // Trash fiber if it goes through the cut mask
548  m_SeedMask->TransformPhysicalPointToIndex(currentPoint,closestIndex);
549 
550  if (m_CutMask)
551  {
552  if (m_CutMask->GetPixel(closestIndex) != 0)
553  {
554  fiberComputationData.stoppedParticles[i] = true;
555  fiberComputationData.particleWeights[i] = 0;
556  continue;
557  }
558  }
559 
560  // Computes diffusion information at current position
561  modelValue.Fill(0.0);
562  double estimatedNoiseValue = 20.0;
563  this->ComputeModelValue(modelInterpolator, currentIndex, modelValue);
564  double estimatedB0Value = m_B0Interpolator->EvaluateAtContinuousIndex(currentIndex);
565  estimatedNoiseValue = m_NoiseInterpolator->EvaluateAtContinuousIndex(currentIndex);
566 
567  if (!this->CheckModelProperties(estimatedB0Value,estimatedNoiseValue,modelValue,numThread))
568  {
569  fiberComputationData.stoppedParticles[i] = true;
570  fiberComputationData.particleWeights[i] = 0;
571  continue;
572  }
573 
574  // Set initial direction to the principal eigenvector of the tensor
575  if (numIter == 1)
576  {
577  Vector3DType initDir(0.0);
578  switch (m_InitialColinearityDirection)
579  {
580  case Top:
581  initDir[2] = 1;
582  break;
583  case Bottom:
584  initDir[2] = -1;
585  break;
586  case Left:
587  initDir[0] = -1;
588  break;
589  case Right:
590  initDir[0] = 1;
591  break;
592  case Front:
593  initDir[1] = -1;
594  break;
595  case Back:
596  initDir[1] = 1;
597  break;
598  case Outward:
599  for (unsigned int j = 0;j < InputModelImageType::ImageDimension;++j)
600  initDir[j] = currentPoint[j] - m_DWIGravityCenter[j];
601  break;
602  case Center:
603  default:
604  for (unsigned int j = 0;j < InputModelImageType::ImageDimension;++j)
605  initDir[j] = m_DWIGravityCenter[j] - currentPoint[j];
606  break;
607  }
608 
609  if (is2d)
610  initDir[2] = 0;
611  initDir.Normalize();
612 
613  previousDirections[i] = this->InitializeFirstIterationFromModel(initDir,modelValue,numThread);
614  }
615 
616  // Propose a new direction based on the previous one and the diffusion information at current position
617  double log_prior = 0, log_proposal = 0;
618  newDirection = this->ProposeNewDirection(previousDirections[i], modelValue, sampling_direction, log_prior,
619  log_proposal, m_Generators[numThread], numThread);
620 
621  // Update the position of the particle
622  for (unsigned int j = 0;j < InputModelImageType::ImageDimension;++j)
623  currentPoint[j] += m_StepProgression * newDirection[j];
624 
625  // Log-weight update must be done at new position (except for prior and proposal)
626  m_SeedMask->TransformPhysicalPointToContinuousIndex(currentPoint,newIndex);
627 
628  // Set the new proposed direction as the current direction
629  previousDirections[i] = newDirection;
630 
631  modelValue.Fill(0.0);
632 
633  if (!modelInterpolator->IsInsideBuffer(newIndex))
634  {
635  fiberComputationData.stoppedParticles[i] = true;
636  fiberComputationData.particleWeights[i] = 0;
637  continue;
638  }
639 
640  fiberComputationData.fiberParticles[i].push_back(currentPoint);
641 
642  this->ComputeModelValue(modelInterpolator, newIndex, modelValue);
643  estimatedB0Value = m_B0Interpolator->EvaluateAtContinuousIndex(newIndex);
644  estimatedNoiseValue = m_NoiseInterpolator->EvaluateAtContinuousIndex(newIndex);
645 
646  // Update the weight of the particle
647  double updateWeightLogVal = this->ComputeLogWeightUpdate(estimatedB0Value, estimatedNoiseValue, newDirection,
648  modelValue, log_prior, log_proposal, numThread);
649 
650  logWeightVals[i] = updateWeightLogVal + anima::safe_log(oldFiberWeights[i]);
651  }
652 
653  // Continue only if some particles are still moving
654  stopLoop = true;
655  for (unsigned int i = 0;i < m_NumberOfParticles;++i)
656  {
657  if (!fiberComputationData.stoppedParticles[i])
658  {
659  stopLoop = false;
660  break;
661  }
662  }
663 
664  if (stopLoop)
665  continue;
666 
667  emptyClasses.resize(numberOfClasses);
668  // Computes weight sum for further weight normalization
669  for (unsigned int i = 0;i < numberOfClasses;++i)
670  {
671  emptyClasses[i] = false;
672  unsigned int classSize = fiberComputationData.reverseClassMemberships[i].size();
673  tmpVector.clear();
674 
675  for (unsigned int j = 0;j < classSize;++j)
676  {
677  if (!fiberComputationData.stoppedParticles[fiberComputationData.reverseClassMemberships[i][j]])
678  tmpVector.push_back(logWeightVals[fiberComputationData.reverseClassMemberships[i][j]]);
679  }
680 
681  if (tmpVector.size() != 0)
682  logWeightSums[i] = anima::ExponentialSum(tmpVector);
683  else
684  {
685  logWeightSums[i] = 0;
686  emptyClasses[i] = true;
687  }
688  }
689 
690  // Weight normalization
691  tmpVector.clear();
692  for (unsigned int i = 0;i < numberOfClasses;++i)
693  {
694  if (!emptyClasses[i])
695  tmpVector.push_back(anima::safe_log(fiberComputationData.classWeights[i]) + logWeightSums[i]);
696  }
697 
698  double tmpSum = 0;
699  double logSumTmpVector = anima::ExponentialSum(tmpVector);
700 
701  for (unsigned int i = 0;i < numberOfClasses;++i)
702  {
703  if (!emptyClasses[i])
704  {
705  double t = std::exp(anima::safe_log(fiberComputationData.classWeights[i]) + logWeightSums[i] - logSumTmpVector);
706  fiberComputationData.classWeights[i] = t;
707  tmpSum += t;
708  }
709  else
710  fiberComputationData.classWeights[i] = 0;
711  }
712 
713  for (unsigned int i = 0;i < numberOfClasses;++i)
714  fiberComputationData.classWeights[i] /= tmpSum;
715 
716  for (unsigned int i = 0;i < m_NumberOfParticles;++i)
717  {
718  if (fiberComputationData.stoppedParticles[i])
719  {
720  fiberComputationData.particleWeights[i] = 0;
721  continue;
722  }
723 
724  double tmpWeight = logWeightSums[fiberComputationData.classMemberships[i]];
725 
726  if (std::isfinite(tmpWeight))
727  logWeightVals[i] -= tmpWeight;
728 
729  fiberComputationData.particleWeights[i] = std::exp(logWeightVals[i]);
730  // Q: shouldn't we be treating this case as an empty cluster that shouldn't even exist?
731  }
732 
733  // Resampling if necessary, done class by class
734  effectiveNumberOfParticles.resize(numberOfClasses);
735  std::fill(effectiveNumberOfParticles.begin(),effectiveNumberOfParticles.end(),0);
736  for (unsigned int i = 0;i < m_NumberOfParticles;++i)
737  {
738  double weight = fiberComputationData.particleWeights[i];
739  effectiveNumberOfParticles[fiberComputationData.classMemberships[i]] += weight * weight;
740  }
741 
742  for (unsigned int m = 0;m < numberOfClasses;++m)
743  {
744  if (effectiveNumberOfParticles[m] != 0)
745  effectiveNumberOfParticles[m] = 1.0 / effectiveNumberOfParticles[m];
746  else
747  continue; // Q: shouldn't we be treating this case as an empty cluster that shouldn't even exist? (same as previous Q)
748 
749  // Actual class resampling
750  if (effectiveNumberOfParticles[m] < m_ResamplingThreshold * fiberComputationData.classSizes[m])
751  {
752  weightSpecificClassValues.resize(fiberComputationData.classSizes[m]);
753  previousDirectionsCopy.resize(fiberComputationData.classSizes[m]);
754  fiberParticlesCopy.resize(fiberComputationData.classSizes[m]);
755 
756  for (unsigned int i = 0;i < fiberComputationData.classSizes[m];++i)
757  {
758  unsigned int posIndex = fiberComputationData.reverseClassMemberships[m][i];
759  weightSpecificClassValues[i] = fiberComputationData.particleWeights[posIndex];
760  previousDirectionsCopy[i] = previousDirections[posIndex];
761  fiberParticlesCopy[i] = fiberComputationData.fiberParticles[posIndex];
762  }
763 
764  std::discrete_distribution<> dist(weightSpecificClassValues.begin(),weightSpecificClassValues.end());
765  usedFibers.resize(fiberComputationData.classSizes[m]);
766  std::fill(usedFibers.begin(),usedFibers.end(),false);
767 
768  for (unsigned int i = 0;i < fiberComputationData.classSizes[m];++i)
769  {
770  unsigned int z = dist(m_Generators[numThread]);
771  unsigned int iReal = fiberComputationData.reverseClassMemberships[m][i];
772  previousDirections[iReal] = previousDirectionsCopy[z];
773  fiberComputationData.fiberParticles[iReal] = fiberParticlesCopy[z];
774  // In all of this, we suppose that stopped particles have zero weights and will therefore
775  // be lost when resampling
776  fiberComputationData.stoppedParticles[iReal] = false;
777  usedFibers[z] = true;
778  }
779 
780  for (unsigned int i = 0;i < fiberComputationData.classSizes[m];++i)
781  {
782  unsigned int iReal = fiberComputationData.reverseClassMemberships[m][i];
783 
784  if (!usedFibers[i])
785  {
786  if (oldFiberWeights[iReal] > m_FiberTrashThreshold / m_NumberOfParticles)
787  {
788  if (fiberComputationData.particleWeights[iReal] != 0)
789  fiberParticlesCopy[i].pop_back();
790 
791  // The fiber trash used to contain fibers that were lost with a sufficient weight
792  // However, using way too much memory so removed for now
793  //if (fiberParticlesCopy[i].size() > m_MinLengthFiber / m_StepProgression)
794  // fiberTrash.push_back(fiberParticlesCopy[i]);
795  }
796  }
797  }
798 
799  // Update only weightVals, oldWeightVals will get updated when starting back the loop
800  // Same here for stopped fibers, they get rejected when resampling
801  for (unsigned int i = 0;i < fiberComputationData.classSizes[m];++i)
802  fiberComputationData.particleWeights[fiberComputationData.reverseClassMemberships[m][i]] = 1.0 / fiberComputationData.classSizes[m];
803  }
804  }
805 
806  // We need stopping criterions
807  // Length is easy, given that each step is constant we just need to check the fiber size: numIter
808  // Example :
809  if (numIter > m_MaxLengthFiber / m_StepProgression)
810  stopLoop = true;
811 
812  numberOfClasses = this->UpdateClassesMemberships(fiberComputationData,previousDirections,m_Generators[numThread]);
813 
814  for (unsigned int i = 0;i < fiberComputationData.particleWeights.size();++i)
815  {
816  if (!std::isfinite(fiberComputationData.particleWeights[i]))
817  itkExceptionMacro("Nan weights after update class membership");
818  }
819  }
820 
821  // Now that we're done, if we don't keep individual particles, merge them cluster by cluster
822  if (m_MAPMergeFibers)
823  {
824  FiberProcessVectorType mergedOutput,classMergedOutput;
825  for (unsigned int i = 0;i < numberOfClasses;++i)
826  {
827  this->MergeParticleClassFibers(fiberComputationData,classMergedOutput,i);
828  mergedOutput.insert(mergedOutput.end(),classMergedOutput.begin(),classMergedOutput.end());
829  }
830 
831  fiberComputationData.fiberParticles = mergedOutput;
832  }
833 
834  if (m_MAPMergeFibers)
835  resultWeights = fiberComputationData.classWeights;
836  else
837  resultWeights = fiberComputationData.particleWeights;
838 
839  return fiberComputationData.fiberParticles;
840 }
841 
842 template <class TInputModelImageType>
843 unsigned int
845 ::UpdateClassesMemberships(FiberWorkType &fiberData, DirectionVectorType &directions, std::mt19937 &random_generator)
846 {
847  const unsigned int p = PointType::PointDimension;
848  typedef anima::KMeansFilter <PointType,p> KMeansFilterType;
849  unsigned int numClasses = fiberData.classSizes.size();
850 
851  // Deciding on cluster merges
852  FiberProcessVectorType mapMergedFibersRef, mapMergedFibersFlo;
853  unsigned int newNumClasses = numClasses;
854 
855  MembershipType classesFusion(numClasses);
856  for (unsigned int i = 0;i < numClasses;++i)
857  classesFusion[i] = i;
858 
859  // As described in IPMI, we take the input classes and first try to fuse them
860  // This is based on a range of possible criterions specified by the user
861  if (numClasses > 1)
862  {
863  for (unsigned int i = 0;i < numClasses;++i)
864  {
865  // Fuse test is done on average cluster fiber, if it is an active cluster,
866  // i.e. at least one of its particle is still moving
867  bool activeClass = this->MergeParticleClassFibers(fiberData,mapMergedFibersRef,i);
868  if (!activeClass)
869  continue;
870 
871  for (unsigned int j = i+1;j < numClasses;++j)
872  {
873  if (classesFusion[j] != j)
874  continue;
875 
876  double maxVal = 0;
877  bool activeSubClass = this->MergeParticleClassFibers(fiberData,mapMergedFibersFlo,j);
878  if (!activeSubClass)
879  continue;
880 
881  // Compute a distance between the two clusters, based on user input
882  switch (m_ClusterDistance)
883  {
884  case 0:
885  {
886  // Former method (quickest)
887  unsigned int minSizeFiber = std::min(mapMergedFibersRef[0].size(),mapMergedFibersFlo[0].size());
888 
889  for (unsigned int l = 0;l < minSizeFiber;++l)
890  {
891  double positionDist = anima::ComputeEuclideanDistance(mapMergedFibersRef[0][l], mapMergedFibersFlo[0][l]);
892 
893  if (positionDist > maxVal)
894  maxVal = positionDist;
895 
896  if (maxVal > m_PositionDistanceFuseThreshold)
897  break;
898  }
899 
900  break;
901  }
902 
903  case 1:
904  {
905  // Hausdorff distance
906  for (unsigned int l = 0;l < mapMergedFibersRef[0].size();++l)
907  {
908  double tmpVal = anima::ComputePointToSetDistance(mapMergedFibersRef[0][l], mapMergedFibersFlo[0]);
909 
910  if (tmpVal > maxVal)
911  maxVal = tmpVal;
912 
913  if (maxVal > m_PositionDistanceFuseThreshold)
914  break;
915  }
916 
917  if (maxVal <= m_PositionDistanceFuseThreshold)
918  {
919  for (unsigned int l = 0;l < mapMergedFibersFlo[0].size();++l)
920  {
921  double tmpVal = anima::ComputePointToSetDistance(mapMergedFibersFlo[0][l], mapMergedFibersRef[0]);
922 
923  if (tmpVal > maxVal)
924  maxVal = tmpVal;
925 
926  if (maxVal > m_PositionDistanceFuseThreshold)
927  break;
928  }
929  }
930 
931  break;
932  }
933 
934  case 2:
935  {
936  // Modified Hausdorff distance
937  maxVal = anima::ComputeModifiedDirectedHausdorffDistance(mapMergedFibersRef[0], mapMergedFibersFlo[0]);
938 
939  if (maxVal <= m_PositionDistanceFuseThreshold)
940  maxVal = std::max(maxVal, anima::ComputeModifiedDirectedHausdorffDistance(mapMergedFibersFlo[0], mapMergedFibersRef[0]));
941 
942  break;
943  }
944 
945  default:
946  break;
947  }
948 
949  // If computed distance is smaller than a threshold, we fuse
950  // To do so, an index table (classesFusion) is updated, each of its cells tells
951  // to which new class the current class belongs. newNumClasses is the new number of classes
952  if (maxVal <= m_PositionDistanceFuseThreshold)
953  {
954  newNumClasses--;
955  classesFusion[j] = classesFusion[i];
956  }
957  }
958  }
959 
960  // Some post-processing to have contiguous class numbers as an output
961  // mapFusion will hold the correspondance between non contiguous and contiguous indexes
962  int maxVal = -1;
963  unsigned int currentIndex = 0;
964  std::map <unsigned int, unsigned int> mapFusion;
965  for (unsigned int i = 0;i < numClasses;++i)
966  {
967  if (maxVal < (int)classesFusion[i])
968  {
969  mapFusion.insert(std::make_pair(classesFusion[i],currentIndex));
970  ++currentIndex;
971  maxVal = classesFusion[i];
972  }
973  }
974 
975  for (unsigned int i = 0;i < numClasses;++i)
976  classesFusion[i] = mapFusion[classesFusion[i]];
977  }
978 
979  std::vector <MembershipType> fusedClassesIndexes(newNumClasses);
980  for (unsigned int i = 0;i < numClasses;++i)
981  fusedClassesIndexes[classesFusion[i]].push_back(i);
982 
983  // Now we're done with selecting what to fuse
984  // Therefore, deciding on cluster splits. The trick here is that we don't want to actually really perform fuse
985  // if it is to split right after (for speed reasons). So we'll play with classesFusion indexes all along
986  // to keep track of the original indexes.
987  DirectionVectorType afterMergeClassesDirections(newNumClasses);
988  MembershipType afterMergeNumPoints(newNumClasses,0);
989  std::vector <bool> splitClasses(newNumClasses,false);
990 
991  Vector3DType zeroDirection(0.0);
992  std::fill(afterMergeClassesDirections.begin(),afterMergeClassesDirections.end(),zeroDirection);
993 
994  // Compute average directions after fusion, directions contains the last directions taken by particles
995  for (unsigned int i = 0;i < m_NumberOfParticles;++i)
996  {
997  if (fiberData.particleWeights[i] == 0)
998  continue;
999 
1000  unsigned int classIndex = classesFusion[fiberData.classMemberships[i]];
1001  for (unsigned int j = 0;j < p;++j)
1002  afterMergeClassesDirections[classIndex][j] += directions[i][j];
1003  afterMergeNumPoints[classIndex]++;
1004  }
1005 
1006  unsigned int numSplits = 0;
1007  std::vector < std::pair <unsigned int, double> > afterMergeKappaValues;
1008  // From those averaged directions, we compute a dispersion kappa value, that will be used to decide on split
1009  for (unsigned int i = 0;i < newNumClasses;++i)
1010  {
1011  double norm = 0;
1012  for (unsigned int j = 0;j < p;++j)
1013  norm += afterMergeClassesDirections[i][j] * afterMergeClassesDirections[i][j];
1014  norm = sqrt(norm);
1015 
1016  double R = 1.0;
1017  double kappa = m_KappaSplitThreshold + 1;
1018  if (afterMergeNumPoints[i] != 0)
1019  {
1020  R = norm / afterMergeNumPoints[i];
1021 
1022  if (R*R > 1.0 - 1.0e-16)
1023  R = sqrt(1.0 - 1.0e-16);
1024 
1025  kappa = std::exp( anima::safe_log(R) + anima::safe_log(p-R*R) - anima::safe_log(1-R*R) );
1026  }
1027 
1028  afterMergeKappaValues.push_back(std::make_pair(i,kappa));
1029  // We do not allow any split resulting in less than m_MinimalNumberOfParticlesPerClass
1030  // so testing with respect to 2*m_MinimalNumberOfParticlesPerClass
1031  // If it is ok, and kappa is small enough, there is too much dispersion inside the cluster -> splitting
1032  if ((kappa <= m_KappaSplitThreshold)&&(afterMergeNumPoints[i] >= 2 * m_MinimalNumberOfParticlesPerClass)&&(afterMergeNumPoints[i] != 0))
1033  numSplits++;
1034  else
1035  afterMergeKappaValues[i].second = m_KappaSplitThreshold + 1;
1036  }
1037 
1038  std::partial_sort(afterMergeKappaValues.begin(),afterMergeKappaValues.begin() + numSplits,afterMergeKappaValues.end(),pair_comparator());
1039 
1040  for (unsigned int i = 0;i < numSplits;++i)
1041  splitClasses[afterMergeKappaValues[i].first] = true;
1042 
1043  // Finally apply all this to get our final clusters
1044  // Each split class will be split into two, so new number of classes
1045  // after merge and split is newNumClasses + numSplits
1046  unsigned int finalNumClasses = newNumClasses + numSplits;
1047 
1048  MembershipType newClassesMemberships(m_NumberOfParticles,0);
1049  std::vector <MembershipType> newReverseClassesMemberships(finalNumClasses);
1050  MembershipType newClassSizes(finalNumClasses,0);
1051  ListType newParticleWeights = fiberData.particleWeights;
1052  ListType newClassWeights(finalNumClasses,0);
1053 
1054  unsigned int currentIndex = 0;
1055 
1056  FiberType vectorToCluster;
1057  MembershipType clustering;
1058 
1059  // Now, do the real merge/split part
1060  for (unsigned int i = 0;i < newNumClasses;++i)
1061  {
1062  if (!splitClasses[i])
1063  {
1064  // ith class is just a potential merge of classes. Easy case: just take all particles
1065  // from classes marked as new ith class
1066  for (unsigned int j = 0;j < fusedClassesIndexes[i].size();++j)
1067  {
1068  unsigned int classIndex = fusedClassesIndexes[i][j];
1069  for (unsigned int k = 0;k < fiberData.reverseClassMemberships[classIndex].size();++k)
1070  {
1071  unsigned int particleNumber = fiberData.reverseClassMemberships[classIndex][k];
1072  newClassesMemberships[particleNumber] = currentIndex;
1073  newReverseClassesMemberships[currentIndex].push_back(particleNumber);
1074  }
1075  }
1076 
1077  if (fusedClassesIndexes[i].size() != 1)
1078  {
1079  // Recompute class weights after fusion
1080  newClassWeights[currentIndex] = 0;
1081  for (unsigned int j = 0;j < fusedClassesIndexes[i].size();++j)
1082  {
1083  unsigned int classIndex = fusedClassesIndexes[i][j];
1084  for (unsigned int k = 0;k < fiberData.reverseClassMemberships[classIndex].size();++k)
1085  newClassWeights[currentIndex] += fiberData.classWeights[classIndex] * fiberData.particleWeights[fiberData.reverseClassMemberships[classIndex][k]];
1086  }
1087 
1088  // Recompute particle weights after fusion
1089  for (unsigned int j = 0;j < fusedClassesIndexes[i].size();++j)
1090  {
1091  unsigned int classIndex = fusedClassesIndexes[i][j];
1092  for (unsigned int k = 0;k < fiberData.reverseClassMemberships[classIndex].size();++k)
1093  {
1094  unsigned int posIndex = fiberData.reverseClassMemberships[classIndex][k];
1095  newParticleWeights[posIndex] = std::exp(anima::safe_log(fiberData.classWeights[classIndex]) + anima::safe_log(fiberData.particleWeights[posIndex]) - anima::safe_log(newClassWeights[currentIndex]));
1096  }
1097  }
1098  }
1099  else
1100  newClassWeights[currentIndex] = fiberData.classWeights[fusedClassesIndexes[i][0]];
1101 
1102  ++currentIndex;
1103  }
1104  else
1105  {
1106  // ith class is a split at the end. In that case, first gather all particles
1107  // from classes merged before. Then, plug k-means
1108  vectorToCluster.clear();
1109 
1110  // Gather particles
1111  for (unsigned int j = 0;j < fusedClassesIndexes[i].size();++j)
1112  {
1113  unsigned int classIndex = fusedClassesIndexes[i][j];
1114  for (unsigned int k = 0;k < fiberData.reverseClassMemberships[classIndex].size();++k)
1115  vectorToCluster.push_back(fiberData.fiberParticles[fiberData.reverseClassMemberships[classIndex][k]].back());
1116  }
1117 
1118  clustering.resize(vectorToCluster.size());
1119 
1120  // Now cluster them, loop until the two classes are not empty
1121  bool loopOnClustering = true;
1122  std::uniform_int_distribution <unsigned int> uniInt(0,1);
1123  while (loopOnClustering)
1124  {
1125  for (unsigned int j = 0;j < clustering.size();++j)
1126  clustering[j] = uniInt(random_generator) % 2;
1127 
1128  KMeansFilterType kmFilter;
1129  kmFilter.SetInputData(vectorToCluster);
1130  kmFilter.SetNumberOfClasses(2);
1131  kmFilter.InitializeClassesMemberships(clustering);
1132  kmFilter.SetMaxIterations(100);
1133  kmFilter.SetVerbose(false);
1134 
1135  kmFilter.Update();
1136 
1137  // Otherwise, go ahead and do the splitting
1138  clustering = kmFilter.GetClassesMemberships();
1139 
1140  if ((kmFilter.GetNumberPerClass(0) > 0)&&(kmFilter.GetNumberPerClass(1) > 0))
1141  loopOnClustering = false;
1142  }
1143 
1144  // Now assigne new class indexes to particles, plus update class weights
1145  unsigned int newClassIndex = currentIndex + 1;
1146 
1147  newClassWeights[currentIndex] = 0;
1148  newClassWeights[newClassIndex] = 0;
1149 
1150  unsigned int pos = 0;
1151  for (unsigned int j = 0;j < fusedClassesIndexes[i].size();++j)
1152  {
1153  unsigned int classIndex = fusedClassesIndexes[i][j];
1154  for (unsigned int k = 0;k < fiberData.reverseClassMemberships[classIndex].size();++k)
1155  {
1156  unsigned int classPos = currentIndex + clustering[pos];
1157 
1158  newClassesMemberships[fiberData.reverseClassMemberships[classIndex][k]] = classPos;
1159  newReverseClassesMemberships[classPos].push_back(fiberData.reverseClassMemberships[classIndex][k]);
1160 
1161  newClassWeights[classPos] += fiberData.classWeights[classIndex] * fiberData.particleWeights[fiberData.reverseClassMemberships[classIndex][k]];
1162 
1163  ++pos;
1164  }
1165  }
1166 
1167  // Finally, update particle weights
1168  pos = 0;
1169  for (unsigned int j = 0;j < fusedClassesIndexes[i].size();++j)
1170  {
1171  unsigned int classIndex = fusedClassesIndexes[i][j];
1172  for (unsigned int k = 0;k < fiberData.reverseClassMemberships[classIndex].size();++k)
1173  {
1174  unsigned int classPos = currentIndex + clustering[pos];
1175 
1176  unsigned int posIndex = fiberData.reverseClassMemberships[classIndex][k];
1177  newParticleWeights[posIndex] = std::exp(anima::safe_log(fiberData.classWeights[classIndex]) + anima::safe_log(fiberData.particleWeights[posIndex]) - anima::safe_log(newClassWeights[classPos]));
1178 
1179  ++pos;
1180  }
1181  }
1182 
1183  currentIndex += 2;
1184  }
1185  }
1186 
1187  double tmpSum = 0;
1188  for (unsigned int i = 0;i < finalNumClasses;++i)
1189  {
1190  if (newReverseClassesMemberships[i].size() > 0)
1191  newClassWeights[i] = std::max(1.0e-16,newClassWeights[i]);
1192  tmpSum += newClassWeights[i];
1193  }
1194 
1195  for (unsigned int i = 0;i < finalNumClasses;++i)
1196  newClassWeights[i] /= tmpSum;
1197 
1198  for (unsigned int i = 0;i < finalNumClasses;++i)
1199  newClassSizes[i] = newReverseClassesMemberships[i].size();
1200 
1201  // Replace all fiber data by new ones computed here and we're done
1202  fiberData.classSizes = newClassSizes;
1203  fiberData.classWeights = newClassWeights;
1204  fiberData.classMemberships = newClassesMemberships;
1205  fiberData.particleWeights = newParticleWeights;
1206  fiberData.reverseClassMemberships = newReverseClassesMemberships;
1207 
1208  return finalNumClasses;
1209 }
1210 
1211 template <class TInputModelImageType>
1212 bool
1214 ::MergeParticleClassFibers(FiberWorkType &fiberData, FiberProcessVectorType &outputMerged, unsigned int classNumber)
1215 {
1216  unsigned int numClasses = fiberData.classSizes.size();
1217  outputMerged.clear();
1218  if (classNumber >= numClasses)
1219  return false;
1220 
1221  outputMerged.resize(1);
1222 
1223  std::vector <unsigned int> runningIndexes, stoppedIndexes;
1224  for (unsigned int j = 0;j < fiberData.classSizes[classNumber];++j)
1225  {
1226  if (fiberData.stoppedParticles[fiberData.reverseClassMemberships[classNumber][j]])
1227  stoppedIndexes.push_back(fiberData.reverseClassMemberships[classNumber][j]);
1228  else
1229  runningIndexes.push_back(fiberData.reverseClassMemberships[classNumber][j]);
1230  }
1231 
1232  FiberType classFiber;
1233  FiberType tmpFiber;
1234  unsigned int sizeMerged = 0;
1235  unsigned int p = PointType::GetPointDimension();
1236 
1237  double sumWeights = 0;
1238  for (unsigned int j = 0;j < runningIndexes.size();++j)
1239  sumWeights += fiberData.particleWeights[runningIndexes[j]];
1240 
1241  if (runningIndexes.size() != 0)
1242  {
1243  // Use weights provided
1244  for (unsigned int j = 0;j < runningIndexes.size();++j)
1245  {
1246  double tmpWeight = fiberData.particleWeights[runningIndexes[j]];
1247  if (tmpWeight <= 0)
1248  continue;
1249 
1250  tmpFiber = fiberData.fiberParticles[runningIndexes[j]];
1251  for (unsigned int k = 0;k < tmpFiber.size();++k)
1252  {
1253  if (k < sizeMerged)
1254  {
1255  for (unsigned int l = 0;l < p;++l)
1256  classFiber[k][l] += tmpWeight * tmpFiber[k][l];
1257  }
1258  else
1259  {
1260  if (tmpWeight != 0)
1261  {
1262  sizeMerged++;
1263  classFiber.push_back(tmpFiber[k]);
1264  for (unsigned int l = 0;l < p;++l)
1265  classFiber[k][l] *= tmpWeight;
1266  }
1267  }
1268  }
1269  }
1270 
1271  for (unsigned int j = 0;j < sizeMerged;++j)
1272  {
1273  for (unsigned int k = 0;k < p;++k)
1274  classFiber[j][k] /= sumWeights;
1275  }
1276 
1277  outputMerged[0] = classFiber;
1278  return true;
1279  }
1280 
1281  // Treat all fibers equivalently, first construct groups of equal lengths
1282  std::vector < std::vector <unsigned int> > particleGroups;
1283  std::vector <unsigned int> particleSizes;
1284  for (unsigned int i = 0;i < stoppedIndexes.size();++i)
1285  {
1286  unsigned int particleSize = fiberData.fiberParticles[stoppedIndexes[i]].size();
1287  bool sizeFound = false;
1288  for (unsigned int j = 0;j < particleSizes.size();++j)
1289  {
1290  if (particleSize == particleSizes[j])
1291  {
1292  particleGroups[j].push_back(stoppedIndexes[i]);
1293  sizeFound = true;
1294  break;
1295  }
1296  }
1297 
1298  if (!sizeFound)
1299  {
1300  particleSizes.push_back(particleSize);
1301  std::vector <unsigned int> tmpVec(1,stoppedIndexes[i]);
1302  particleGroups.push_back(tmpVec);
1303  }
1304  }
1305 
1306  // For each group of equal length, build fiber
1307  outputMerged.resize(particleGroups.size());
1308  for (unsigned int i = 0;i < particleGroups.size();++i)
1309  {
1310  classFiber.clear();
1311  sizeMerged = 0;
1312 
1313  for (unsigned int j = 0;j < particleGroups[i].size();++j)
1314  {
1315  tmpFiber = fiberData.fiberParticles[particleGroups[i][j]];
1316  for (unsigned int k = 0;k < tmpFiber.size();++k)
1317  {
1318  if (k < sizeMerged)
1319  {
1320  for (unsigned int l = 0;l < p;++l)
1321  classFiber[k][l] += tmpFiber[k][l];
1322  }
1323  else
1324  {
1325  sizeMerged++;
1326  classFiber.push_back(tmpFiber[k]);
1327  }
1328  }
1329  }
1330 
1331  for (unsigned int j = 0;j < sizeMerged;++j)
1332  {
1333  for (unsigned int k = 0;k < p;++k)
1334  classFiber[j][k] /= particleGroups[i].size();
1335  }
1336 
1337  outputMerged[i] = classFiber;
1338  }
1339 
1340  return false;
1341 }
1342 
1343 } // end of namespace anima
double ExponentialSum(const VectorType &x, const unsigned int NDimension)
double ComputeEuclideanDistance(const VectorType &x1, const VectorType &x2, const unsigned int NDimension)
itk::InterpolateImageFunction< InputModelImageType > InterpolatorType
double ComputePointToSetDistance(const VectorType &x, const std::vector< VectorType > &s)
double ComputeModifiedDirectedHausdorffDistance(const std::vector< VectorType > &s1, const std::vector< VectorType > &s2)
double safe_log(const ScalarType x)
void ThreadTrack(unsigned int numThread, FiberProcessVectorType &resultFibers, ListType &resultWeights)
Doing the thread work dispatch.