3 #include <itkImageRegionIteratorWithIndex.h> 4 #include <itkLinearInterpolateImageFunction.h> 6 #include <itkImageRegionIterator.h> 7 #include <itkImageRegionConstIterator.h> 8 #include <itkExtractImageFilter.h> 9 #include <itkImageMomentsCalculator.h> 14 #include <vnl/algo/vnl_matrix_inverse.h> 16 #include <vtkPointData.h> 17 #include <vtkCellData.h> 18 #include <vtkDoubleArray.h> 27 template <
class TInputModelImageType>
// Constructor fragment: assigns the default value of the tracker's tuning parameters.
// NOTE(review): this listing is a line-numbered extract with gaps (the `::Name()` line, the
// braces and original lines 29-30 / 56-58 are not visible), so members initialised on those
// lines cannot be documented here.
28 BaseProbabilisticTractographyImageFilter <TInputModelImageType>
// The seed list starts empty; PrepareTractography() is what fills it.
31 m_PointsToProcess.clear();
// Number of sub-voxel seed positions generated along each image axis (see PrepareTractography()).
33 m_NumberOfFibersPerPixel = 1;
// Number of particles that ComputeFiber() propagates from each seed.
34 m_NumberOfParticles = 1000;
// UpdateClassesMemberships() only considers splitting a class that counts at least twice this many particles.
35 m_MinimalNumberOfParticlesPerClass = 10;
// A class is resampled when its effective particle count falls below this fraction of its size.
37 m_ResamplingThreshold = 0.8;
// Factor applied to the proposed direction when a particle position is advanced (physical-space coordinates).
39 m_StepProgression = 1.0;
// Not referenced in the visible part of this file; presumably consumed by subclasses - TODO confirm.
41 m_KappaOfPriorDistribution = 30.0;
// Both lengths are divided by m_StepProgression before use: the minimum is compared with a fiber's
// point count (ThreadedTrackComputer()), the maximum with the iteration counter (ComputeFiber()).
43 m_MinLengthFiber = 10.0;
44 m_MaxLengthFiber = 150.0;
// Two particle classes are fused when their distance does not exceed this threshold.
46 m_PositionDistanceFuseThreshold = 0.5;
// Classes whose estimated kappa is at or below this value are candidates for a split.
47 m_KappaSplitThreshold = 30.0;
// Selector of the distance `switch` in UpdateClassesMemberships(); the case labels are not visible here.
49 m_ClusterDistance = 1;
// When true, createVTKOutput() also calls ComputeAdditionalScalarMaps().
51 m_ComputeLocalColors =
true;
// When true, ComputeFiber() merges each particle class and returns class weights instead of particle weights.
52 m_MAPMergeFibers =
true;
// Enumerators below are declared outside this view.
54 m_InitialColinearityDirection = Center;
55 m_InitialDirectionMode = Weight;
// Index of the next seed to hand out; shared by the worker threads in ThreadTrack().
59 m_HighestProcessedSeed = 0;
// Destructor fragment: frees the progress reporter that the driver routine allocates with `new`.
// NOTE(review): original lines 66-67 are not visible, so any guard around the delete is unknown.
63 template <
class TInputModelImageType>
65 ::~BaseProbabilisticTractographyImageFilter()
68 delete m_ProgressReport;
// Driver fragment (the signature line is not visible in this extract; presumably Update() - TODO confirm).
// Visible steps: prepare the seeds, create the progress reporter, run ThreadTracker() on every work
// unit, concatenate the per-work-unit results and build the VTK output.
71 template <
class TInputModelImageType>
76 this->PrepareTractography();
// NOTE(review): createVTKOutput(), called at the end of this routine, assigns a fresh
// vtkPolyData::New() to m_Output again - check whether this first allocation is released.
77 m_Output = vtkPolyData::New();
// Drops the reporter of a previous run before a new one is allocated (any guard is on unseen lines 78-79).
80 delete m_ProgressReport;
// Seeds are handed out in batches of at most 100.
// NOTE(review): original lines 83-85 are missing from this view; with an empty seed list stepData
// would be 0 and the modulo below would divide by zero - confirm that a guard exists there.
82 unsigned int stepData = std::min((
int)m_PointsToProcess.size(),100);
86 unsigned int numSteps = std::floor(m_PointsToProcess.size() / (double)stepData);
87 if (m_PointsToProcess.size() % stepData != 0)
// The reporter is advanced from ThreadTrack() after each processed batch.
90 m_ProgressReport =
new itk::ProgressReporter(
this,0,numSteps);
// Argument bundle handed to the static thread callback: a back pointer to this filter plus one
// output slot per work unit.
95 trackerArguments tmpStr;
96 tmpStr.trackerPtr =
this;
97 tmpStr.resultFibersFromThreads.resize(this->GetNumberOfWorkUnits());
98 tmpStr.resultWeightsFromThreads.resize(this->GetNumberOfWorkUnits());
// resultFibers / resultWeights are declared on lines not visible here.
100 for (
unsigned int i = 0;i < this->GetNumberOfWorkUnits();++i)
102 tmpStr.resultFibersFromThreads[i] = resultFibers;
103 tmpStr.resultWeightsFromThreads[i] = resultWeights;
// Run ThreadTracker() once per work unit.
106 this->GetMultiThreader()->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
107 this->GetMultiThreader()->SetSingleMethod(this->ThreadTracker,&tmpStr);
108 this->GetMultiThreader()->SingleMethodExecute();
// Append every work unit's fibers and weights, in work-unit order.
110 for (
unsigned int j = 0;j < this->GetNumberOfWorkUnits();++j)
112 resultFibers.insert(resultFibers.end(),tmpStr.resultFibersFromThreads[j].begin(),tmpStr.resultFibersFromThreads[j].end());
113 resultWeights.insert(resultWeights.end(),tmpStr.resultWeightsFromThreads[j].begin(),tmpStr.resultWeightsFromThreads[j].end());
116 std::cout <<
"\nKept " << resultFibers.size() <<
" fibers after filtering" << std::endl;
117 this->createVTKOutput(resultFibers, resultWeights);
// PrepareTractography fragment: checks the required inputs, builds the B0 and noise interpolators,
// creates one random generator per work unit, collects the filter-mask labels and generates the
// seed points from the seed mask.
// NOTE(review): the conditions guarding the two exceptions and several branch bodies are on lines
// missing from this extract.
120 template <
class TInputModelImageType>
123 ::PrepareTractography()
126 itkExceptionMacro(
"No B0 image, required");
129 itkExceptionMacro(
"No sigma square noise image, required");
131 m_B0Interpolator = ScalarInterpolatorType::New();
132 m_B0Interpolator->SetInputImage(m_B0Image);
134 m_NoiseInterpolator = ScalarInterpolatorType::New();
135 m_NoiseInterpolator->SetInputImage(m_NoiseImage);
// One generator per work unit, so each worker draws from its own generator object.
138 m_Generators.resize(this->GetNumberOfWorkUnits());
// NOTE(review): the mother generator is seeded from the wall clock, so two runs started at
// different times draw different random sequences.
140 std::mt19937 motherGenerator(time(0));
142 for (
unsigned int i = 0;i < this->GetNumberOfWorkUnits();++i)
143 m_Generators[i] = std::mt19937(motherGenerator());
// An image with a single slice along the third axis is treated as 2D: Top / Bottom are remapped to
// Front / Back.
145 bool is2d = m_InputModelImage->GetLargestPossibleRegion().GetSize()[2] == 1;
146 if (is2d && (m_InitialColinearityDirection == Top))
147 m_InitialColinearityDirection = Front;
148 if (is2d && (m_InitialColinearityDirection == Bottom))
149 m_InitialColinearityDirection = Back;
// Outward / Center need the centre of gravity of the B0 image; ComputeFiber() reads it.
152 if ((m_InitialColinearityDirection == Outward)||(m_InitialColinearityDirection == Center))
154 itk::ImageMomentsCalculator <ScalarImageType>::Pointer momentsCalculator = itk::ImageMomentsCalculator <ScalarImageType>::New();
155 momentsCalculator->SetImage(m_B0Image);
156 momentsCalculator->Compute();
157 m_DWIGravityCenter = momentsCalculator->GetCenterOfGravity();
160 typedef itk::ImageRegionIteratorWithIndex <MaskImageType> MaskImageIteratorType;
// The seed mask is walked over the largest possible region of the input model image.
162 MaskImageIteratorType maskItr(m_SeedMask, m_InputModelImage->GetLargestPossibleRegion());
163 m_PointsToProcess.clear();
169 m_FilteringValues.clear();
// startN + i * stepN, for i in [0, N), are the centres of N equal sub-intervals of [-0.5, 0.5]:
// the sub-voxel offsets added to the voxel index below.
170 double startN = -0.5 + 1.0 / (2.0 * m_NumberOfFibersPerPixel);
171 double stepN = 1.0 / m_NumberOfFibersPerPixel;
// Collect each distinct label of the filter mask once into m_FilteringValues.
// NOTE(review): the guard on m_FilterMask and the bodies of the `if` statements are not visible;
// presumably zero voxels and already recorded labels are skipped - TODO confirm.
176 MaskImageIteratorType filterItr(m_FilterMask, m_InputModelImage->GetLargestPossibleRegion());
177 while (!filterItr.IsAtEnd())
179 if (filterItr.Get() == 0)
185 bool isAlreadyIn =
false;
186 for (
unsigned int i = 0;i < m_FilteringValues.size();++i)
188 if (m_FilteringValues[i] == filterItr.Get())
196 m_FilteringValues.push_back(filterItr.Get());
// Seed generation: one single-point fiber per sub-voxel position of every seed-mask voxel.
202 while (!maskItr.IsAtEnd())
204 if (maskItr.Get() == 0)
210 tmpIndex = maskItr.GetIndex();
// First variant: the third index component is kept at the voxel index and only axes 0 and 1 are
// subdivided (N * N seeds per voxel). The condition selecting this variant is not visible;
// presumably the 2D case - TODO confirm.
214 realIndex[2] = tmpIndex[2];
215 for (
unsigned int j = 0;j < m_NumberOfFibersPerPixel;++j)
217 realIndex[1] = tmpIndex[1] + startN + j * stepN;
218 for (
unsigned int i = 0;i < m_NumberOfFibersPerPixel;++i)
220 realIndex[0] = tmpIndex[0] + startN + i * stepN;
221 m_SeedMask->TransformContinuousIndexToPhysicalPoint(realIndex,tmpPoint);
222 tmpFiber[0] = tmpPoint;
223 m_PointsToProcess.push_back(tmpFiber);
// Second variant: all three axes are subdivided (N * N * N seeds per voxel).
229 for (
unsigned int k = 0;k < m_NumberOfFibersPerPixel;++k)
231 realIndex[2] = tmpIndex[2] + startN + k * stepN;
232 for (
unsigned int j = 0;j < m_NumberOfFibersPerPixel;++j)
234 realIndex[1] = tmpIndex[1] + startN + j * stepN;
235 for (
unsigned int i = 0;i < m_NumberOfFibersPerPixel;++i)
237 realIndex[0] = tmpIndex[0] + startN + i * stepN;
239 m_SeedMask->TransformContinuousIndexToPhysicalPoint(realIndex,tmpPoint);
240 tmpFiber[0] = tmpPoint;
241 m_PointsToProcess.push_back(tmpFiber);
250 std::cout <<
"Generated " << m_PointsToProcess.size() <<
" seed points from ROI mask" << std::endl;
// Interpolator factory fragment (function name and return type are not visible in this extract).
// Creates a linear interpolator bound to m_InputModelImage and returns it.
253 template <
class TInputModelImageType>
257 typedef itk::LinearInterpolateImageFunction <InputModelImageType> InternalInterpolatorType;
259 typename InternalInterpolatorType::Pointer outInterpolator = InternalInterpolatorType::New();
260 outInterpolator->SetInputImage(m_InputModelImage);
// NOTE(review): Register() adds a reference on top of the one held by the local smart pointer;
// presumably this keeps the object alive when a raw pointer is returned - confirm that the caller
// balances it, otherwise every call leaks an interpolator.
262 outInterpolator->Register();
263 return outInterpolator;
// ThreadTracker fragment: callback given to the ITK multi-threader. It unpacks the work-unit info,
// recovers the trackerArguments bundle from UserData and forwards to ThreadTrack() with the output
// slots that belong to this work unit.
266 template <
class TInputModelImageType>
267 ITK_THREAD_RETURN_FUNCTION_CALL_CONVENTION
269 ::ThreadTracker(
void *arg)
271 itk::MultiThreaderBase::WorkUnitInfo *threadArgs = (itk::MultiThreaderBase::WorkUnitInfo *)arg;
// The work-unit id is used as the index into the per-work-unit result vectors.
272 unsigned int nbThread = threadArgs->WorkUnitID;
274 trackerArguments *tmpArg = (trackerArguments *)threadArgs->UserData;
275 tmpArg->trackerPtr->
ThreadTrack(nbThread,tmpArg->resultFibersFromThreads[nbThread],tmpArg->resultWeightsFromThreads[nbThread]);
277 return ITK_THREAD_RETURN_DEFAULT_VALUE;
// ThreadTrack fragment: work dispatch loop run by each work unit. Each pass claims the next batch
// of seeds [startPoint, endPoint) while holding m_LockHighestProcessedSeed, tracks that batch
// outside the lock, then advances the progress reporter under the same lock.
// NOTE(review): the loop statement itself and original lines 290-292 are not visible in this extract.
280 template <
class TInputModelImageType>
286 bool continueLoop =
true;
287 unsigned int highestToleratedSeedIndex = m_PointsToProcess.size();
// Batch size: at most 100 seeds, the same expression as in the driver routine.
289 unsigned int stepData = std::min((
int)m_PointsToProcess.size(),100);
295 m_LockHighestProcessedSeed.lock();
// Every seed has been handed out: release the lock and leave the loop.
297 if (m_HighestProcessedSeed >= highestToleratedSeedIndex)
299 m_LockHighestProcessedSeed.unlock();
300 continueLoop =
false;
// Claim the next batch and move the shared cursor past it.
304 unsigned int startPoint = m_HighestProcessedSeed;
305 unsigned int endPoint = m_HighestProcessedSeed + stepData;
306 if (endPoint > highestToleratedSeedIndex)
307 endPoint = highestToleratedSeedIndex;
309 m_HighestProcessedSeed = endPoint;
311 m_LockHighestProcessedSeed.unlock();
313 this->ThreadedTrackComputer(numThread,resultFibers,resultWeights,startPoint,endPoint);
// The reporter is shared between work units, hence the lock around the update.
315 m_LockHighestProcessedSeed.lock();
316 m_ProgressReport->CompletedPixel();
317 m_LockHighestProcessedSeed.unlock();
// ThreadedTrackComputer fragment: tracks the seeds in [startSeedIndex, endSeedIndex) and appends
// the retained fibers and their weights to resultFibers / resultWeights.
// NOTE(review): the first parameters and the local declarations (modelInterpolator, tmpFibers,
// tmpWeights, startIndex) are on lines missing from this extract.
321 template <
class TInputModelImageType>
325 ListType &resultWeights,
unsigned int startSeedIndex,
326 unsigned int endSeedIndex)
333 for (
unsigned int i = startSeedIndex;i < endSeedIndex;++i)
// The use made of startIndex is not visible in this extract.
335 m_SeedMask->TransformPhysicalPointToContinuousIndex(m_PointsToProcess[i][0],startIndex);
337 tmpFibers = this->ComputeFiber(m_PointsToProcess[i], modelInterpolator, numThread, tmpWeights);
339 tmpFibers = this->FilterOutputFibers(tmpFibers, tmpWeights);
// Keep only fibers with more than m_MinLengthFiber / m_StepProgression points.
341 for (
unsigned int j = 0;j < tmpFibers.size();++j)
343 if (tmpFibers[j].size() > m_MinLengthFiber / m_StepProgression)
345 resultFibers.push_back(tmpFibers[j]);
346 resultWeights.push_back(tmpWeights[j]);
// Fiber filtering fragment (presumably FilterOutputFibers(), the routine called from
// ThreadedTrackComputer(); the signature line is not visible). When filter labels or a forbidden
// mask exist, a fiber is retained only if the set of filter-mask values collected along it has as
// many entries as m_FilteringValues. Otherwise the weights are passed through unchanged.
// NOTE(review): the parameter list, the declarations of resVal / tmpWeights / touchingLabels and
// the bodies of several `if` statements are missing from this extract.
352 template <
class TInputModelImageType>
361 if ((m_FilteringValues.size() > 0)||(m_ForbiddenMask))
367 for (
unsigned int i = 0;i < fibers.size();++i)
369 touchingLabels.clear();
370 bool forbiddenTouched =
false;
372 for (
unsigned int j = 0;j < fibers[i].size();++j)
374 tmpPoint = fibers[i][j];
// The voxel index is computed with the seed mask geometry and then used to read both the filter and
// the forbidden masks; this assumes all masks share one grid - TODO confirm.
375 m_SeedMask->TransformPhysicalPointToIndex(tmpPoint,tmpIndex);
377 unsigned int maskValue = 0;
378 unsigned int forbiddenMaskValue = 0;
// The null checks guarding the two reads below are not visible (original lines 380 and 383).
381 maskValue = m_FilterMask->GetPixel(tmpIndex);
384 forbiddenMaskValue = m_ForbiddenMask->GetPixel(tmpIndex);
386 if (forbiddenMaskValue != 0)
388 forbiddenTouched =
true;
// Record maskValue once in touchingLabels (the handling of the value 0 is not visible).
394 bool alreadyIn =
false;
395 for (
unsigned int k = 0;k < touchingLabels.size();++k)
397 if (maskValue == touchingLabels[k])
405 touchingLabels.push_back(maskValue);
// The body of this test is not visible; presumably the fiber is discarded - TODO confirm.
409 if (forbiddenTouched)
412 if (touchingLabels.size() == m_FilteringValues.size())
414 resVal.push_back(fibers[i]);
415 weights.push_back(tmpWeights[i]);
// Reached on a branch whose condition is not visible; presumably the no-filtering case.
422 weights = tmpWeights;
// VTK output fragment (presumably createVTKOutput(), called by the driver routine; the signature
// line is not visible). Each fiber becomes one VTK_POLY_LINE cell of m_Output, and the fiber weight
// is stored once per point in the "Fiber weights" point-data array.
428 template <
class TInputModelImageType>
433 m_Output = vtkPolyData::New();
434 m_Output->Initialize();
435 m_Output->Allocate();
// NOTE(review): a vtkSmartPointer assigned from a raw T::New() takes an additional reference, so
// the object is not freed when the smart pointer goes out of scope; vtkSmartPointer<T>::New() is
// the documented way to avoid this - confirm and fix.
437 vtkSmartPointer <vtkPoints> myPoints = vtkPoints::New();
438 vtkSmartPointer <vtkDoubleArray> weights = vtkDoubleArray::New();
439 weights->SetNumberOfComponents(1);
440 weights->SetName(
"Fiber weights");
442 for (
unsigned int i = 0;i < filteredFibers.size();++i)
444 unsigned int npts = filteredFibers[i].size();
// NOTE(review): the matching delete[] is not visible (original lines 454-456 are missing) - confirm
// that the id buffer is released on every iteration.
445 vtkIdType* ids =
new vtkIdType[npts];
447 for (
unsigned int j = 0;j < npts;++j)
449 ids[j] = myPoints->InsertNextPoint(filteredFibers[i][j][0],filteredFibers[i][j][1],filteredFibers[i][j][2]);
// One value per inserted point, so the array stays aligned with the points.
450 weights->InsertNextValue(filteredWeights[i]);
453 m_Output->InsertNextCell (VTK_POLY_LINE, npts, ids);
457 m_Output->SetPoints(myPoints);
458 if (m_ComputeLocalColors)
459 this->ComputeAdditionalScalarMaps();
462 m_Output->GetPointData()->AddArray(weights);
// Particle propagation fragment (presumably ComputeFiber(), called from ThreadedTrackComputer();
// only the last two parameters are visible). Starting from the seed `fiber`, m_NumberOfParticles
// particles are advanced step by step, their weights are updated in log space, classes are
// resampled when degenerate, and class memberships are updated. Returns the particle fibers
// (merged per class when m_MAPMergeFibers is set) and writes the matching weights to resultWeights.
// NOTE(review): the main loop statement, many local declarations and all else / continue / break
// lines are missing from this extract; comments below describe only the visible statements.
465 template <
class TInputModelImageType>
469 unsigned int numThread,
ListType &resultWeights)
// Tracking starts with a single class containing every particle.
471 unsigned int numberOfClasses = 1;
// Initial state: every particle is a copy of the seed fiber, weights are uniform, no particle is stopped.
473 FiberWorkType fiberComputationData;
474 fiberComputationData.fiberParticles.resize(m_NumberOfParticles);
475 for (
unsigned int i = 0;i < m_NumberOfParticles;++i)
476 fiberComputationData.fiberParticles[i] = fiber;
478 fiberComputationData.particleWeights =
ListType(m_NumberOfParticles, 1.0 / m_NumberOfParticles);
479 fiberComputationData.stoppedParticles = std::vector <bool> (m_NumberOfParticles,
false);
480 fiberComputationData.classSizes =
MembershipType(numberOfClasses,m_NumberOfParticles);
481 fiberComputationData.classWeights =
ListType(numberOfClasses,1.0 / numberOfClasses);
483 fiberComputationData.classMemberships =
MembershipType(m_NumberOfParticles,0);
484 fiberComputationData.reverseClassMemberships.resize(numberOfClasses);
// tmpVec is declared and filled on lines not visible here.
486 for (
unsigned int j = 0;j < m_NumberOfParticles;++j)
488 fiberComputationData.reverseClassMemberships[0] = tmpVec;
490 ListType logWeightVals(m_NumberOfParticles, 1.0 / m_NumberOfParticles);
491 ListType oldFiberWeights(m_NumberOfParticles, 1.0 / m_NumberOfParticles);
492 std::vector <bool> emptyClasses;
493 ListType tmpVector(m_NumberOfParticles,0);
495 ListType logWeightSums(numberOfClasses,0);
496 ListType effectiveNumberOfParticles(numberOfClasses,0);
505 std::vector <bool> usedFibers;
508 bool is2d = m_InputModelImage->GetLargestPossibleRegion().GetSize()[2] == 1;
517 unsigned int numIter = 0;
518 bool stopLoop =
false;
// Weights before this iteration's update; they enter the log-weight computation below.
525 oldFiberWeights = fiberComputationData.particleWeights;
527 logWeightSums.resize(numberOfClasses);
528 std::fill(logWeightSums.begin(),logWeightSums.end(),0.0);
// Per-particle propagation.
529 for (
unsigned int i = 0;i < m_NumberOfParticles;++i)
532 if (fiberComputationData.stoppedParticles[i])
535 currentPoint = fiberComputationData.fiberParticles[i].back();
537 m_SeedMask->TransformPhysicalPointToContinuousIndex(currentPoint,currentIndex);
// A particle whose current position is outside the interpolation buffer is stopped with zero weight.
540 if (!modelInterpolator->IsInsideBuffer(currentIndex))
542 fiberComputationData.stoppedParticles[i] =
true;
543 fiberComputationData.particleWeights[i] = 0;
548 m_SeedMask->TransformPhysicalPointToIndex(currentPoint,closestIndex);
// A particle located in a non-zero voxel of the cut mask is stopped with zero weight (the null
// check on m_CutMask is not visible).
552 if (m_CutMask->GetPixel(closestIndex) != 0)
554 fiberComputationData.stoppedParticles[i] =
true;
555 fiberComputationData.particleWeights[i] = 0;
561 modelValue.Fill(0.0);
// NOTE(review): the initial 20.0 is overwritten three lines below before any visible read.
562 double estimatedNoiseValue = 20.0;
563 this->ComputeModelValue(modelInterpolator, currentIndex, modelValue);
564 double estimatedB0Value = m_B0Interpolator->EvaluateAtContinuousIndex(currentIndex);
565 estimatedNoiseValue = m_NoiseInterpolator->EvaluateAtContinuousIndex(currentIndex);
// Particles rejected by CheckModelProperties() are stopped with zero weight.
567 if (!this->CheckModelProperties(estimatedB0Value,estimatedNoiseValue,modelValue,numThread))
569 fiberComputationData.stoppedParticles[i] =
true;
570 fiberComputationData.particleWeights[i] = 0;
// Initial direction selection; the case labels are not visible. The two visible bodies build a
// vector from the gravity centre to the point, and from the point to the gravity centre.
578 switch (m_InitialColinearityDirection)
599 for (
unsigned int j = 0;j < InputModelImageType::ImageDimension;++j)
600 initDir[j] = currentPoint[j] - m_DWIGravityCenter[j];
604 for (
unsigned int j = 0;j < InputModelImageType::ImageDimension;++j)
605 initDir[j] = m_DWIGravityCenter[j] - currentPoint[j];
613 previousDirections[i] = this->InitializeFirstIterationFromModel(initDir,modelValue,numThread);
// log_prior / log_proposal are passed to ProposeNewDirection() and then to ComputeLogWeightUpdate();
// presumably they are output arguments of the former - TODO confirm.
617 double log_prior = 0, log_proposal = 0;
618 newDirection = this->ProposeNewDirection(previousDirections[i], modelValue, sampling_direction, log_prior,
619 log_proposal, m_Generators[numThread], numThread);
// Advance the particle by m_StepProgression along the proposed direction.
622 for (
unsigned int j = 0;j < InputModelImageType::ImageDimension;++j)
623 currentPoint[j] += m_StepProgression * newDirection[j];
626 m_SeedMask->TransformPhysicalPointToContinuousIndex(currentPoint,newIndex);
629 previousDirections[i] = newDirection;
631 modelValue.Fill(0.0);
// A particle that stepped outside the buffer is stopped; the new point is not appended in that
// case if the unseen line after this block skips the rest of the iteration - TODO confirm.
633 if (!modelInterpolator->IsInsideBuffer(newIndex))
635 fiberComputationData.stoppedParticles[i] =
true;
636 fiberComputationData.particleWeights[i] = 0;
640 fiberComputationData.fiberParticles[i].push_back(currentPoint);
642 this->ComputeModelValue(modelInterpolator, newIndex, modelValue);
643 estimatedB0Value = m_B0Interpolator->EvaluateAtContinuousIndex(newIndex);
644 estimatedNoiseValue = m_NoiseInterpolator->EvaluateAtContinuousIndex(newIndex);
// New log weight = log-weight update at the new position + log of the previous weight.
647 double updateWeightLogVal = this->ComputeLogWeightUpdate(estimatedB0Value, estimatedNoiseValue, newDirection,
648 modelValue, log_prior, log_proposal, numThread);
650 logWeightVals[i] = updateWeightLogVal +
anima::safe_log(oldFiberWeights[i]);
// Scan for particles that are still running (the body is not visible).
655 for (
unsigned int i = 0;i < m_NumberOfParticles;++i)
657 if (!fiberComputationData.stoppedParticles[i])
// Per class: gather the log weights of running particles; a class with none is flagged empty.
667 emptyClasses.resize(numberOfClasses);
669 for (
unsigned int i = 0;i < numberOfClasses;++i)
671 emptyClasses[i] =
false;
672 unsigned int classSize = fiberComputationData.reverseClassMemberships[i].size();
675 for (
unsigned int j = 0;j < classSize;++j)
677 if (!fiberComputationData.stoppedParticles[fiberComputationData.reverseClassMemberships[i][j]])
678 tmpVector.push_back(logWeightVals[fiberComputationData.reverseClassMemberships[i][j]]);
681 if (tmpVector.size() != 0)
685 logWeightSums[i] = 0;
686 emptyClasses[i] =
true;
// Class weight update: log(class weight) + class log-sum, gathered for the non-empty classes.
692 for (
unsigned int i = 0;i < numberOfClasses;++i)
694 if (!emptyClasses[i])
695 tmpVector.push_back(
anima::safe_log(fiberComputationData.classWeights[i]) + logWeightSums[i]);
// logSumTmpVector and tmpSum are computed on lines not visible here.
701 for (
unsigned int i = 0;i < numberOfClasses;++i)
703 if (!emptyClasses[i])
705 double t = std::exp(
anima::safe_log(fiberComputationData.classWeights[i]) + logWeightSums[i] - logSumTmpVector);
706 fiberComputationData.classWeights[i] = t;
710 fiberComputationData.classWeights[i] = 0;
713 for (
unsigned int i = 0;i < numberOfClasses;++i)
714 fiberComputationData.classWeights[i] /= tmpSum;
// Particle weights: the class log-sum is subtracted (when finite) before exponentiation; stopped
// particles get weight 0.
716 for (
unsigned int i = 0;i < m_NumberOfParticles;++i)
718 if (fiberComputationData.stoppedParticles[i])
720 fiberComputationData.particleWeights[i] = 0;
724 double tmpWeight = logWeightSums[fiberComputationData.classMemberships[i]];
726 if (std::isfinite(tmpWeight))
727 logWeightVals[i] -= tmpWeight;
729 fiberComputationData.particleWeights[i] = std::exp(logWeightVals[i]);
// Effective number of particles per class: inverse of the sum of squared weights.
734 effectiveNumberOfParticles.resize(numberOfClasses);
735 std::fill(effectiveNumberOfParticles.begin(),effectiveNumberOfParticles.end(),0);
736 for (
unsigned int i = 0;i < m_NumberOfParticles;++i)
738 double weight = fiberComputationData.particleWeights[i];
739 effectiveNumberOfParticles[fiberComputationData.classMemberships[i]] += weight * weight;
742 for (
unsigned int m = 0;m < numberOfClasses;++m)
744 if (effectiveNumberOfParticles[m] != 0)
745 effectiveNumberOfParticles[m] = 1.0 / effectiveNumberOfParticles[m];
// Resample class m when its effective size is below m_ResamplingThreshold times its actual size.
750 if (effectiveNumberOfParticles[m] < m_ResamplingThreshold * fiberComputationData.classSizes[m])
752 weightSpecificClassValues.resize(fiberComputationData.classSizes[m]);
753 previousDirectionsCopy.resize(fiberComputationData.classSizes[m]);
754 fiberParticlesCopy.resize(fiberComputationData.classSizes[m]);
// Snapshot of the class (weights, directions, fibers) to draw from.
756 for (
unsigned int i = 0;i < fiberComputationData.classSizes[m];++i)
758 unsigned int posIndex = fiberComputationData.reverseClassMemberships[m][i];
759 weightSpecificClassValues[i] = fiberComputationData.particleWeights[posIndex];
760 previousDirectionsCopy[i] = previousDirections[posIndex];
761 fiberParticlesCopy[i] = fiberComputationData.fiberParticles[posIndex];
// Draw replacements with probability proportional to the particle weights, using this work unit's generator.
764 std::discrete_distribution<> dist(weightSpecificClassValues.begin(),weightSpecificClassValues.end());
765 usedFibers.resize(fiberComputationData.classSizes[m]);
766 std::fill(usedFibers.begin(),usedFibers.end(),
false);
768 for (
unsigned int i = 0;i < fiberComputationData.classSizes[m];++i)
770 unsigned int z = dist(m_Generators[numThread]);
771 unsigned int iReal = fiberComputationData.reverseClassMemberships[m][i];
772 previousDirections[iReal] = previousDirectionsCopy[z];
773 fiberComputationData.fiberParticles[iReal] = fiberParticlesCopy[z];
// Every resampled slot is marked as running again.
776 fiberComputationData.stoppedParticles[iReal] =
false;
777 usedFibers[z] =
true;
// Handling of particles dropped by the resampling; most of this block is on lines not visible here.
780 for (
unsigned int i = 0;i < fiberComputationData.classSizes[m];++i)
782 unsigned int iReal = fiberComputationData.reverseClassMemberships[m][i];
786 if (oldFiberWeights[iReal] > m_FiberTrashThreshold / m_NumberOfParticles)
788 if (fiberComputationData.particleWeights[iReal] != 0)
789 fiberParticlesCopy[i].pop_back();
// After resampling, weights inside the class are reset to uniform.
801 for (
unsigned int i = 0;i < fiberComputationData.classSizes[m];++i)
802 fiberComputationData.particleWeights[fiberComputationData.reverseClassMemberships[m][i]] = 1.0 / fiberComputationData.classSizes[m];
// Length limit expressed as an iteration count (the body is not visible; presumably sets stopLoop).
809 if (numIter > m_MaxLengthFiber / m_StepProgression)
812 numberOfClasses = this->UpdateClassesMemberships(fiberComputationData,previousDirections,m_Generators[numThread]);
// Sanity check: a non-finite particle weight after the membership update raises an ITK exception.
814 for (
unsigned int i = 0;i < fiberComputationData.particleWeights.size();++i)
816 if (!std::isfinite(fiberComputationData.particleWeights[i]))
817 itkExceptionMacro(
"Nan weights after update class membership");
// MAP merge: the particle fibers are replaced by the merged fibers of every class.
822 if (m_MAPMergeFibers)
825 for (
unsigned int i = 0;i < numberOfClasses;++i)
827 this->MergeParticleClassFibers(fiberComputationData,classMergedOutput,i);
828 mergedOutput.insert(mergedOutput.end(),classMergedOutput.begin(),classMergedOutput.end());
831 fiberComputationData.fiberParticles = mergedOutput;
// NOTE(review): MergeParticleClassFibers() resizes its output to the number of stopped-particle
// length groups, so one class may contribute several fibers while classWeights has one entry per
// class; the caller indexes the weights by fiber index - confirm that the sizes always match.
834 if (m_MAPMergeFibers)
835 resultWeights = fiberComputationData.classWeights;
837 resultWeights = fiberComputationData.particleWeights;
839 return fiberComputationData.fiberParticles;
// UpdateClassesMemberships fragment: fuses particle classes whose merged fibers are close, splits
// classes flagged by the kappa test with a 2-class k-means on the last fiber points, rebuilds
// memberships, class weights and class sizes in fiberData, and returns the final class count.
// NOTE(review): many declarations (classesFusion, maxVal, afterMerge* vectors, tmpSum, ...) and
// all else / continue / break lines are missing from this extract.
842 template <
class TInputModelImageType>
845 ::UpdateClassesMemberships(FiberWorkType &fiberData,
DirectionVectorType &directions, std::mt19937 &random_generator)
847 const unsigned int p = PointType::PointDimension;
849 unsigned int numClasses = fiberData.classSizes.size();
853 unsigned int newNumClasses = numClasses;
// Fusion table: each class starts mapped to itself.
856 for (
unsigned int i = 0;i < numClasses;++i)
857 classesFusion[i] = i;
// Pairwise comparison of classes i < j through their merged fibers.
863 for (
unsigned int i = 0;i < numClasses;++i)
867 bool activeClass = this->MergeParticleClassFibers(fiberData,mapMergedFibersRef,i);
871 for (
unsigned int j = i+1;j < numClasses;++j)
// A class already mapped to another one is tested here (the body is not visible; presumably skipped).
873 if (classesFusion[j] != j)
877 bool activeSubClass = this->MergeParticleClassFibers(fiberData,mapMergedFibersFlo,j);
// Distance measure selected by m_ClusterDistance; the case labels and distance calls are not visible.
882 switch (m_ClusterDistance)
// Visible variant 1: maximum of a per-point distance over the common length of both fibers.
887 unsigned int minSizeFiber = std::min(mapMergedFibersRef[0].size(),mapMergedFibersFlo[0].size());
889 for (
unsigned int l = 0;l < minSizeFiber;++l)
893 if (positionDist > maxVal)
894 maxVal = positionDist;
896 if (maxVal > m_PositionDistanceFuseThreshold)
// Visible variant 2: loops over the reference fiber, then over the floating fiber if the first
// pass stayed within the threshold.
906 for (
unsigned int l = 0;l < mapMergedFibersRef[0].size();++l)
913 if (maxVal > m_PositionDistanceFuseThreshold)
917 if (maxVal <= m_PositionDistanceFuseThreshold)
919 for (
unsigned int l = 0;l < mapMergedFibersFlo[0].size();++l)
926 if (maxVal > m_PositionDistanceFuseThreshold)
939 if (maxVal <= m_PositionDistanceFuseThreshold)
// Fuse: class j takes the fusion label of class i when the distance is within the threshold.
952 if (maxVal <= m_PositionDistanceFuseThreshold)
955 classesFusion[j] = classesFusion[i];
// Relabel the fusion labels to consecutive indexes.
963 unsigned int currentIndex = 0;
964 std::map <unsigned int, unsigned int> mapFusion;
965 for (
unsigned int i = 0;i < numClasses;++i)
967 if (maxVal < (
int)classesFusion[i])
969 mapFusion.insert(std::make_pair(classesFusion[i],currentIndex));
971 maxVal = classesFusion[i];
975 for (
unsigned int i = 0;i < numClasses;++i)
976 classesFusion[i] = mapFusion[classesFusion[i]];
// For every fused class, the list of original classes it contains.
979 std::vector <MembershipType> fusedClassesIndexes(newNumClasses);
980 for (
unsigned int i = 0;i < numClasses;++i)
981 fusedClassesIndexes[classesFusion[i]].push_back(i);
989 std::vector <bool> splitClasses(newNumClasses,
false);
992 std::fill(afterMergeClassesDirections.begin(),afterMergeClassesDirections.end(),zeroDirection);
// Sum the directions and count the particles of each fused class; particles with zero weight are
// tested first (the body of that test is not visible; presumably they are skipped).
995 for (
unsigned int i = 0;i < m_NumberOfParticles;++i)
997 if (fiberData.particleWeights[i] == 0)
1000 unsigned int classIndex = classesFusion[fiberData.classMemberships[i]];
1001 for (
unsigned int j = 0;j < p;++j)
1002 afterMergeClassesDirections[classIndex][j] += directions[i][j];
1003 afterMergeNumPoints[classIndex]++;
// Kappa estimate per fused class from the norm of the summed directions; the formula itself is on
// lines not visible here.
1006 unsigned int numSplits = 0;
1007 std::vector < std::pair <unsigned int, double> > afterMergeKappaValues;
1009 for (
unsigned int i = 0;i < newNumClasses;++i)
1012 for (
unsigned int j = 0;j < p;++j)
1013 norm += afterMergeClassesDirections[i][j] * afterMergeClassesDirections[i][j];
// Default kappa sits just above the split threshold.
1017 double kappa = m_KappaSplitThreshold + 1;
1018 if (afterMergeNumPoints[i] != 0)
1020 R = norm / afterMergeNumPoints[i];
// R is clamped so that R*R stays strictly below 1.
1022 if (R*R > 1.0 - 1.0e-16)
1023 R = sqrt(1.0 - 1.0e-16);
1028 afterMergeKappaValues.push_back(std::make_pair(i,kappa));
// Split candidate: kappa at or below the threshold and at least 2 * m_MinimalNumberOfParticlesPerClass counted particles.
1032 if ((kappa <= m_KappaSplitThreshold)&&(afterMergeNumPoints[i] >= 2 * m_MinimalNumberOfParticlesPerClass)&&(afterMergeNumPoints[i] != 0))
1035 afterMergeKappaValues[i].second = m_KappaSplitThreshold + 1;
// The first numSplits entries under pair_comparator (ordering defined outside this view) are split.
1038 std::partial_sort(afterMergeKappaValues.begin(),afterMergeKappaValues.begin() + numSplits,afterMergeKappaValues.end(),pair_comparator());
1040 for (
unsigned int i = 0;i < numSplits;++i)
1041 splitClasses[afterMergeKappaValues[i].first] =
true;
// Each split adds one class.
1046 unsigned int finalNumClasses = newNumClasses + numSplits;
1049 std::vector <MembershipType> newReverseClassesMemberships(finalNumClasses);
1051 ListType newParticleWeights = fiberData.particleWeights;
1052 ListType newClassWeights(finalNumClasses,0);
1054 unsigned int currentIndex = 0;
1060 for (
unsigned int i = 0;i < newNumClasses;++i)
// Non-split fused class: all particles of its original classes move to class currentIndex.
1062 if (!splitClasses[i])
1066 for (
unsigned int j = 0;j < fusedClassesIndexes[i].size();++j)
1068 unsigned int classIndex = fusedClassesIndexes[i][j];
1069 for (
unsigned int k = 0;k < fiberData.reverseClassMemberships[classIndex].size();++k)
1071 unsigned int particleNumber = fiberData.reverseClassMemberships[classIndex][k];
1072 newClassesMemberships[particleNumber] = currentIndex;
1073 newReverseClassesMemberships[currentIndex].push_back(particleNumber);
// Several classes were fused: the new class weight is the sum of (old class weight * particle weight).
1077 if (fusedClassesIndexes[i].size() != 1)
1080 newClassWeights[currentIndex] = 0;
1081 for (
unsigned int j = 0;j < fusedClassesIndexes[i].size();++j)
1083 unsigned int classIndex = fusedClassesIndexes[i][j];
1084 for (
unsigned int k = 0;k < fiberData.reverseClassMemberships[classIndex].size();++k)
1085 newClassWeights[currentIndex] += fiberData.classWeights[classIndex] * fiberData.particleWeights[fiberData.reverseClassMemberships[classIndex][k]];
// Particle weight update for the fused class (the assignment is on lines not visible here).
1089 for (
unsigned int j = 0;j < fusedClassesIndexes[i].size();++j)
1091 unsigned int classIndex = fusedClassesIndexes[i][j];
1092 for (
unsigned int k = 0;k < fiberData.reverseClassMemberships[classIndex].size();++k)
1094 unsigned int posIndex = fiberData.reverseClassMemberships[classIndex][k];
// Single original class: its weight is carried over unchanged.
1100 newClassWeights[currentIndex] = fiberData.classWeights[fusedClassesIndexes[i][0]];
// Split case: cluster the last point of each particle fiber of the class into two groups.
1108 vectorToCluster.clear();
1111 for (
unsigned int j = 0;j < fusedClassesIndexes[i].size();++j)
1113 unsigned int classIndex = fusedClassesIndexes[i][j];
1114 for (
unsigned int k = 0;k < fiberData.reverseClassMemberships[classIndex].size();++k)
1115 vectorToCluster.push_back(fiberData.fiberParticles[fiberData.reverseClassMemberships[classIndex][k]].back());
1118 clustering.resize(vectorToCluster.size());
// Random initial memberships, then k-means; retried until both clusters are non-empty.
1121 bool loopOnClustering =
true;
1122 std::uniform_int_distribution <unsigned int> uniInt(0,1);
1123 while (loopOnClustering)
// NOTE(review): uniInt already yields only 0 or 1, so the `% 2` below has no effect.
1125 for (
unsigned int j = 0;j < clustering.size();++j)
1126 clustering[j] = uniInt(random_generator) % 2;
1128 KMeansFilterType kmFilter;
1129 kmFilter.SetInputData(vectorToCluster);
1130 kmFilter.SetNumberOfClasses(2);
1131 kmFilter.InitializeClassesMemberships(clustering);
1132 kmFilter.SetMaxIterations(100);
1133 kmFilter.SetVerbose(
false);
// The call that runs the filter is on a line not visible here.
1138 clustering = kmFilter.GetClassesMemberships();
1140 if ((kmFilter.GetNumberPerClass(0) > 0)&&(kmFilter.GetNumberPerClass(1) > 0))
1141 loopOnClustering =
false;
// The two resulting classes take indexes currentIndex and currentIndex + 1.
1145 unsigned int newClassIndex = currentIndex + 1;
1147 newClassWeights[currentIndex] = 0;
1148 newClassWeights[newClassIndex] = 0;
// Assign each particle to its cluster and accumulate the class weights (the increment of pos is
// on lines not visible here).
1150 unsigned int pos = 0;
1151 for (
unsigned int j = 0;j < fusedClassesIndexes[i].size();++j)
1153 unsigned int classIndex = fusedClassesIndexes[i][j];
1154 for (
unsigned int k = 0;k < fiberData.reverseClassMemberships[classIndex].size();++k)
1156 unsigned int classPos = currentIndex + clustering[pos];
1158 newClassesMemberships[fiberData.reverseClassMemberships[classIndex][k]] = classPos;
1159 newReverseClassesMemberships[classPos].push_back(fiberData.reverseClassMemberships[classIndex][k]);
1161 newClassWeights[classPos] += fiberData.classWeights[classIndex] * fiberData.particleWeights[fiberData.reverseClassMemberships[classIndex][k]];
// Second pass over the same particles (its assignments are on lines not visible here).
1169 for (
unsigned int j = 0;j < fusedClassesIndexes[i].size();++j)
1171 unsigned int classIndex = fusedClassesIndexes[i][j];
1172 for (
unsigned int k = 0;k < fiberData.reverseClassMemberships[classIndex].size();++k)
1174 unsigned int classPos = currentIndex + clustering[pos];
1176 unsigned int posIndex = fiberData.reverseClassMemberships[classIndex][k];
// Non-empty classes get a weight of at least 1e-16; the class weights are then divided by tmpSum.
1188 for (
unsigned int i = 0;i < finalNumClasses;++i)
1190 if (newReverseClassesMemberships[i].size() > 0)
1191 newClassWeights[i] = std::max(1.0e-16,newClassWeights[i]);
1192 tmpSum += newClassWeights[i];
1195 for (
unsigned int i = 0;i < finalNumClasses;++i)
1196 newClassWeights[i] /= tmpSum;
1198 for (
unsigned int i = 0;i < finalNumClasses;++i)
1199 newClassSizes[i] = newReverseClassesMemberships[i].size();
// Commit the new class structure into fiberData.
1202 fiberData.classSizes = newClassSizes;
1203 fiberData.classWeights = newClassWeights;
1204 fiberData.classMemberships = newClassesMemberships;
1205 fiberData.particleWeights = newParticleWeights;
1206 fiberData.reverseClassMemberships = newReverseClassesMemberships;
1208 return finalNumClasses;
// Class merging fragment (presumably MergeParticleClassFibers(); the signature line is not visible,
// but the parameter names fiberData / outputMerged / classNumber are used below). Builds
// representative fibers for one class into outputMerged: a weight-averaged fiber when the class has
// running particles, and one plain average per group of equal-length stopped particles.
// NOTE(review): the return statements and all else / break lines are missing from this extract.
1211 template <
class TInputModelImageType>
1216 unsigned int numClasses = fiberData.classSizes.size();
1217 outputMerged.clear();
// Out-of-range class number (the body is not visible).
1218 if (classNumber >= numClasses)
1221 outputMerged.resize(1);
// Partition the particles of the class into stopped and running ones.
1223 std::vector <unsigned int> runningIndexes, stoppedIndexes;
1224 for (
unsigned int j = 0;j < fiberData.classSizes[classNumber];++j)
1226 if (fiberData.stoppedParticles[fiberData.reverseClassMemberships[classNumber][j]])
1227 stoppedIndexes.push_back(fiberData.reverseClassMemberships[classNumber][j]);
1229 runningIndexes.push_back(fiberData.reverseClassMemberships[classNumber][j]);
1234 unsigned int sizeMerged = 0;
1235 unsigned int p = PointType::GetPointDimension();
// Total weight of the running particles, used as the divisor of the weighted average.
1237 double sumWeights = 0;
1238 for (
unsigned int j = 0;j < runningIndexes.size();++j)
1239 sumWeights += fiberData.particleWeights[runningIndexes[j]];
// Running particles: point-by-point weighted sum of their fibers, then division by sumWeights.
1241 if (runningIndexes.size() != 0)
1244 for (
unsigned int j = 0;j < runningIndexes.size();++j)
1246 double tmpWeight = fiberData.particleWeights[runningIndexes[j]];
1250 tmpFiber = fiberData.fiberParticles[runningIndexes[j]];
1251 for (
unsigned int k = 0;k < tmpFiber.size();++k)
// Points already present in classFiber are accumulated; new positions are appended and scaled.
// The condition choosing between the two paths is not visible.
1255 for (
unsigned int l = 0;l < p;++l)
1256 classFiber[k][l] += tmpWeight * tmpFiber[k][l];
1263 classFiber.push_back(tmpFiber[k]);
1264 for (
unsigned int l = 0;l < p;++l)
1265 classFiber[k][l] *= tmpWeight;
1271 for (
unsigned int j = 0;j < sizeMerged;++j)
1273 for (
unsigned int k = 0;k < p;++k)
1274 classFiber[j][k] /= sumWeights;
1277 outputMerged[0] = classFiber;
// Stopped particles are grouped by fiber length (number of points).
1282 std::vector < std::vector <unsigned int> > particleGroups;
1283 std::vector <unsigned int> particleSizes;
1284 for (
unsigned int i = 0;i < stoppedIndexes.size();++i)
1286 unsigned int particleSize = fiberData.fiberParticles[stoppedIndexes[i]].size();
1287 bool sizeFound =
false;
1288 for (
unsigned int j = 0;j < particleSizes.size();++j)
1290 if (particleSize == particleSizes[j])
1292 particleGroups[j].push_back(stoppedIndexes[i]);
// A length not seen before opens a new group.
1300 particleSizes.push_back(particleSize);
1301 std::vector <unsigned int> tmpVec(1,stoppedIndexes[i]);
1302 particleGroups.push_back(tmpVec);
// One output fiber per group: the unweighted point-by-point mean of the group's fibers.
1307 outputMerged.resize(particleGroups.size());
1308 for (
unsigned int i = 0;i < particleGroups.size();++i)
1313 for (
unsigned int j = 0;j < particleGroups[i].size();++j)
1315 tmpFiber = fiberData.fiberParticles[particleGroups[i][j]];
1316 for (
unsigned int k = 0;k < tmpFiber.size();++k)
1320 for (
unsigned int l = 0;l < p;++l)
1321 classFiber[k][l] += tmpFiber[k][l];
1326 classFiber.push_back(tmpFiber[k]);
1331 for (
unsigned int j = 0;j < sizeMerged;++j)
1333 for (
unsigned int k = 0;k < p;++k)
1334 classFiber[j][k] /= particleGroups[i].size();
1337 outputMerged[i] = classFiber;
double ExponentialSum(const VectorType &x, const unsigned int NDimension)
MaskImageType::IndexType IndexType
MaskImageType::PointType PointType
double ComputeEuclideanDistance(const VectorType &x1, const VectorType &x2, const unsigned int NDimension)
std::vector< unsigned int > MembershipType
std::vector< FiberType > FiberProcessVectorType
itk::InterpolateImageFunction< InputModelImageType > InterpolatorType
itk::Vector< ScalarType, 3 > Vector3DType
double ComputePointToSetDistance(const VectorType &x, const std::vector< VectorType > &s)
double ComputeModifiedDirectedHausdorffDistance(const std::vector< VectorType > &s1, const std::vector< VectorType > &s2)
double safe_log(const ScalarType x)
void ThreadTrack(unsigned int numThread, FiberProcessVectorType &resultFibers, ListType &resultWeights)
Doing the thread work dispatch.
std::vector< ScalarType > ListType
itk::VariableLengthVector< ScalarType > VectorType
std::vector< PointType > FiberType
InterpolatorType::ContinuousIndexType ContinuousIndexType
InterpolatorType::Pointer InterpolatorPointer
std::vector< Vector3DType > DirectionVectorType