ANIMA  4.0
animaBaseProbabilisticTractographyImageFilter.h
Go to the documentation of this file.
1 #pragma once
2 
3 #include <itkVectorImage.h>
4 #include <itkImage.h>
5 
6 #include <vtkPolyData.h>
7 #include <vtkSmartPointer.h>
8 #include <itkProcessObject.h>
9 #include <itkLinearInterpolateImageFunction.h>
10 #include <mutex>
11 #include <itkProgressReporter.h>
12 
13 #include <vector>
14 #include <random>
15 
16 namespace anima
17 {
18 
/**
 * Base class for particle-based probabilistic tractography filters, templated
 * on the model image type. It derives from itk::ProcessObject, exposes the
 * tracking inputs/parameters through ITK set/get macros, produces a
 * vtkPolyData of fibers (see GetOutput()), and leaves the model-specific steps
 * (direction proposal, log-weight update, model evaluation, first-direction
 * initialization, stopping checks) as pure virtual hooks for subclasses.
 *
 * NOTE(review): this text is a Doxygen source-browser extract; gaps in the
 * listing numbers show that several original lines (doc comments and a few
 * declarations) are missing from it.
 */
19 template <class TInputModelImageType>
20 class BaseProbabilisticTractographyImageFilter : public itk::ProcessObject
21 {
22 public:
// NOTE(review): listing lines 23-24 are missing; they presumably declare the
// `Self` typedef used by Pointer/ConstPointer below — confirm against the header.
25  typedef itk::ProcessObject Superclass;
26 
27  typedef itk::SmartPointer<Self> Pointer;
28  typedef itk::SmartPointer<const Self> ConstPointer;
29 
30  itkTypeMacro(BaseProbabilisticTractographyImageFilter,itk::ProcessObject)
31 
32  // Typedefs for scalar types for reading/writing images and for math operations
33  typedef double ScalarType;
34 
35  // Typedef for input model image
36  typedef TInputModelImageType InputModelImageType;
37  typedef typename InputModelImageType::Pointer InputModelImagePointer;
38 
39  // Typedefs for B0 and noise images
40  typedef itk::Image <ScalarType, 3> ScalarImageType;
41  typedef typename ScalarImageType::Pointer ScalarImagePointer;
42  typedef itk::LinearInterpolateImageFunction <ScalarImageType> ScalarInterpolatorType;
43  typedef typename ScalarInterpolatorType::Pointer ScalarInterpolatorPointer;
44 
45  // Typedefs for input mask image
46  typedef itk::Image <unsigned short, 3> MaskImageType;
47  typedef MaskImageType::Pointer MaskImagePointer;
48  typedef MaskImageType::PointType PointType;
49  typedef MaskImageType::IndexType IndexType;
50 
51  // Typedefs for vectors and matrices
52  typedef itk::Matrix <ScalarType,3,3> Matrix3DType;
53  typedef itk::Vector <ScalarType,3> Vector3DType;
54  typedef itk::VariableLengthVector <ScalarType> VectorType;
55  typedef std::vector <ScalarType> ListType;
56  typedef std::vector <Vector3DType> DirectionVectorType;
57 
58  // Typedefs for model images interpolator
59  typedef itk::InterpolateImageFunction <InputModelImageType> InterpolatorType;
60  typedef typename InterpolatorType::Pointer InterpolatorPointer;
61  typedef typename InterpolatorType::ContinuousIndexType ContinuousIndexType;
62 
63  // Typedefs for fibers: a fiber is a list of physical points
64  typedef std::vector <PointType> FiberType;
65  typedef std::vector <FiberType> FiberProcessVectorType;
66  typedef std::vector <unsigned int> MembershipType;
67 
// Per-thread result storage: one fiber set and one weight list per thread.
// Presumably this is what is passed as `arg` to ThreadTracker — confirm in the .hxx.
// NOTE(review): listing lines 69 and 72 (another member and the closing
// `} <name>;` of this typedef) are missing from this extract.
68  typedef struct {
70  std::vector <FiberProcessVectorType> resultFibersFromThreads;
71  std::vector <ListType> resultWeightsFromThreads;
73 
// Comparator ordering (index, value) pairs by ascending value (`second`).
// NOTE(review): its struct header (listing line 74) is missing from this extract.
75  {
76  bool operator() (const std::pair<unsigned int, double> & f, const std::pair<unsigned int, double> & s)
77  { return (f.second < s.second); }
78  };
79 
// Presumably ColinearityDirectionType (per the member index): which way the
// very first direction should point, used together with InitialDirectionMode.
// NOTE(review): the enum header and the enumerators on listing lines 91 and
// 93-97 are missing; only Center and Top survive in this extract.
89  {
90  Center = 0,
92  Top,
98  };
99 
// Presumably InitialDirectionModeType (per the member index): how the very
// first direction of each particle is chosen (Colinear: most colinear to the
// colinearity direction). NOTE(review): the enum header and the enumerator on
// listing line 108 are missing from this extract.
106  {
107  Colinear = 0,
109  };
110 
// Presumably FiberWorkType, the per-fiber working state passed to
// UpdateClassesMemberships and MergeParticleClassFibers — confirm in the header.
// NOTE(review): listing lines 113-114 and 116-118 (other members) and the
// closing typedef name are missing from this extract.
112  {
115  std::vector <MembershipType> reverseClassMemberships;
119  std::vector <bool> stoppedParticles;
120  };
121 
/** Sets which way the very first tracking direction should point. */
122  void SetInitialColinearityDirection(const ColinearityDirectionType &colDir) {m_InitialColinearityDirection = colDir;}
/** Sets how the very first direction of each particle is chosen. */
123  void SetInitialDirectionMode(const InitialDirectionModeType &dir) {m_InitialDirectionMode = dir;}
124  itkGetMacro(InitialDirectionMode,InitialDirectionModeType)
125 
/** Input model image accessors; the setter is virtual so subclasses may override it. */
126  virtual void SetInputModelImage(InputModelImageType *inImage) {m_InputModelImage = inImage;}
127  InputModelImageType *GetInputModelImage() {return m_InputModelImage;}
129 
// Region-of-interest masks used by the tracking (seed, filter, cut, forbidden).
130  itkSetObjectMacro(SeedMask,MaskImageType)
131  itkSetObjectMacro(FilterMask,MaskImageType)
132  itkSetObjectMacro(CutMask,MaskImageType)
133  itkSetObjectMacro(ForbiddenMask,MaskImageType)
134 
// B0 and noise images, passed on to ComputeLogWeightUpdate / CheckModelProperties
// as b0 / noise values (assumed to be sampled through m_B0Interpolator / m_NoiseInterpolator).
135  itkSetObjectMacro(B0Image,ScalarImageType)
136  itkSetObjectMacro(NoiseImage,ScalarImageType)
137 
// Particle filter parameters.
138  itkSetMacro(NumberOfParticles,unsigned int)
139  itkSetMacro(NumberOfFibersPerPixel,unsigned int)
140  itkSetMacro(ResamplingThreshold,double)
141 
142  itkSetMacro(StepProgression,double)
143 
// Fiber length bounds.
144  itkSetMacro(MinLengthFiber,double)
145  itkSetMacro(MaxLengthFiber,double)
146 
147  itkSetMacro(FiberTrashThreshold,double)
148 
149  itkSetMacro(KappaOfPriorDistribution,double)
150  itkGetMacro(KappaOfPriorDistribution,double)
151 
// Multimodal splitting and merging thresholds.
152  itkSetMacro(PositionDistanceFuseThreshold,double)
153  itkSetMacro(KappaSplitThreshold,double)
154 
155  itkSetMacro(ClusterDistance,unsigned int)
156 
157  itkSetMacro(ComputeLocalColors,bool)
158  itkSetMacro(MAPMergeFibers,bool)
159 
160  itkSetMacro(MinimalNumberOfParticlesPerClass,unsigned int)
161 
// Model vector dimension: must be set by the child class.
162  itkSetMacro(ModelDimension, unsigned int)
163  itkGetMacro(ModelDimension, unsigned int)
164 
/** Overrides itk::ProcessObject::Update().
 *  NOTE(review): ITK_OVERRIDE is a legacy ITK macro deprecated in ITK 5; the
 *  C++11 `override` keyword is the equivalent. */
165  void Update() ITK_OVERRIDE;
166 
/** Presumably builds m_Output (returned by GetOutput()) from the filtered fibers
 *  and their weights — confirm in the .hxx.
 *  NOTE(review): lowerCamelCase name is inconsistent with the other PascalCase methods. */
167  void createVTKOutput(FiberProcessVectorType &filteredFibers, ListType &filteredWeights);
/** Returns the fiber output; the filter keeps ownership through its vtkSmartPointer. */
168  vtkPolyData *GetOutput() {return m_Output;}
169 
170 protected:
// NOTE(review): listing lines 171-172 are missing; presumably the constructor
// and destructor — confirm against the header.
173 
/** Multithreading utility function: thread callback receiving a type-erased argument. */
175  static ITK_THREAD_RETURN_FUNCTION_CALL_CONVENTION ThreadTracker(void *arg);
176 
/** Dispatches the tracking work for thread numThread, filling resultFibers / resultWeights. */
178  void ThreadTrack(unsigned int numThread, FiberProcessVectorType &resultFibers, ListType &resultWeights);
179 
/** Does the actual tracking for the seeds between startSeedIndex and endSeedIndex,
 *  calling ComputeFiber and merging its results. */
181  void ThreadedTrackComputer(unsigned int numThread, FiberProcessVectorType &resultFibers,
182  ListType &resultWeights, unsigned int startSeedIndex,
183  unsigned int endSeedIndex);
184 
// NOTE(review): listing lines 185-186 are missing; per the member index, the line
// below is the tail of
//   FiberProcessVectorType ComputeFiber(FiberType &fiber, InterpolatorPointer &modelInterpolator,
//                                       unsigned int numThread, ListType &resultWeights);
// the routine that handles the probabilistic tracking.
187  unsigned int numThread, ListType &resultWeights);
188 
/** Generates seed points. May be re-implemented, but overrides must still call this base version. */
190  virtual void PrepareTractography();
191 
/** Heart of multi-modal probabilistic tractography: makes the split and merge
 *  decisions on particle classes. */
193  unsigned int UpdateClassesMemberships(FiberWorkType &fiberData, DirectionVectorType &directions, std::mt19937 &random_generator);
194 
196  // Merges the particle classes produced by ComputeFiber, each class becoming one fiber stored in outputMerged.
197  // If some particles are still active, only their merge is returned and the function returns true;
// otherwise it returns false and one merge per stopped fiber length.
198  bool MergeParticleClassFibers(FiberWorkType &fiberData, FiberProcessVectorType &outputMerged, unsigned int classNumber);
199 
// NOTE(review): listing lines 200-201 are missing; per the member index they declare
//   FiberProcessVectorType FilterOutputFibers(FiberProcessVectorType &fibers, ListType &weights);
// which filters output fibers by ROIs and computes local colors.
202 
/** Proposes a new direction for a particle, given its old direction and the model value
 *  (model dependent, implemented by subclasses). */
204  virtual Vector3DType ProposeNewDirection(Vector3DType &oldDirection, VectorType &modelValue,
205  Vector3DType &sampling_direction, double &log_prior, double &log_proposal,
206  std::mt19937 &random_generator, unsigned int threadId) = 0;
207 
/** Computes the particle log-weight update from the model and the chosen direction
 *  (model dependent, implemented by subclasses). */
209  virtual double ComputeLogWeightUpdate(double b0Value, double noiseValue, Vector3DType &newDirection, VectorType &modelValue,
210  double &log_prior, double &log_proposal, unsigned int threadId) = 0;
211 
/** Estimates the model value at a continuous index through modelInterpolator
 *  (model dependent, implemented by subclasses). */
213  virtual void ComputeModelValue(InterpolatorPointer &modelInterpolator, ContinuousIndexType &index, VectorType &modelValue) = 0;
214 
/** Initializes the first direction from user input (model dependent, implemented by subclasses). */
216  virtual Vector3DType InitializeFirstIterationFromModel(Vector3DType &colinearDir, VectorType &modelValue, unsigned int threadId) = 0;
217 
/** Checks the stopping criteria for a particle (model dependent, implemented by subclasses). */
219  virtual bool CheckModelProperties(double estimatedB0Value, double estimatedNoiseValue, VectorType &modelValue, unsigned int threadId) = 0;
220 
/** Computes additional model-dependent scalar maps to add to the output; no-op by default. */
222  virtual void ComputeAdditionalScalarMaps() {}
223 
224 private:
// The filter is non-copyable.
225  ITK_DISALLOW_COPY_AND_ASSIGN(BaseProbabilisticTractographyImageFilter);
226 
227  // Internal variable for the model vector dimension; it must be set by the child class.
228  unsigned int m_ModelDimension;
229 
230  unsigned int m_NumberOfParticles;
231  unsigned int m_NumberOfFibersPerPixel;
232  unsigned int m_MinimalNumberOfParticlesPerClass;
233 
234  double m_StepProgression;
235 
236  double m_MinLengthFiber;
237  double m_MaxLengthFiber;
238 
239  double m_FiberTrashThreshold;
240 
241  double m_ResamplingThreshold;
242 
243  double m_KappaOfPriorDistribution;
244 
245  InputModelImagePointer m_InputModelImage;
246 
247  MaskImagePointer m_SeedMask;
248  MaskImagePointer m_FilterMask;
249  MaskImagePointer m_CutMask;
250  MaskImagePointer m_ForbiddenMask;
251 
252  ScalarImagePointer m_B0Image, m_NoiseImage;
253  ScalarInterpolatorPointer m_B0Interpolator, m_NoiseInterpolator;
254 
// Random number generators; presumably one per thread (cf. the threadId / random_generator
// parameters of the hooks above) — confirm in the .hxx.
255  std::vector <std::mt19937> m_Generators;
256 
257  ColinearityDirectionType m_InitialColinearityDirection;
258  InitialDirectionModeType m_InitialDirectionMode;
259  Vector3DType m_DWIGravityCenter;
260 
261  FiberProcessVectorType m_PointsToProcess;
262  MembershipType m_FilteringValues;
263 
264  // Multimodal splitting and merging thresholds
265  double m_PositionDistanceFuseThreshold;
266  double m_KappaSplitThreshold;
267 
268  unsigned int m_ClusterDistance;
269 
270  bool m_MAPMergeFibers;
271  bool m_ComputeLocalColors;
272 
273  vtkSmartPointer<vtkPolyData> m_Output;
274 
// Presumably m_LockHighestProcessedSeed guards m_HighestProcessedSeed across tracking
// threads — confirm in the .hxx. NOTE(review): seed indices elsewhere are unsigned int
// while this counter is int; check for signed/unsigned comparisons.
275  std::mutex m_LockHighestProcessedSeed;
276  int m_HighestProcessedSeed;
// NOTE(review): raw pointer; its ownership is not visible here — confirm it is
// released (a std::unique_ptr would make the ownership explicit).
277  itk::ProgressReporter *m_ProgressReport;
278 };
279 
280 }// end of namespace anima
281 
virtual Vector3DType InitializeFirstIterationFromModel(Vector3DType &colinearDir, VectorType &modelValue, unsigned int threadId)=0
Initialize first direction from user input (model dependent, not implemented here) ...
void createVTKOutput(FiberProcessVectorType &filteredFibers, ListType &filteredWeights)
FiberProcessVectorType ComputeFiber(FiberType &fiber, InterpolatorPointer &modelInterpolator, unsigned int numThread, ListType &resultWeights)
The core routine that handles the probabilistic tracking.
virtual bool CheckModelProperties(double estimatedB0Value, double estimatedNoiseValue, VectorType &modelValue, unsigned int threadId)=0
Checks the stopping criteria for a particle (model dependent, not implemented here) ...
virtual void ComputeModelValue(InterpolatorPointer &modelInterpolator, ContinuousIndexType &index, VectorType &modelValue)=0
Estimate model from raw diffusion data (model dependent, not implemented here)
void ThreadedTrackComputer(unsigned int numThread, FiberProcessVectorType &resultFibers, ListType &resultWeights, unsigned int startSeedIndex, unsigned int endSeedIndex)
Performs the actual tracking by calling ComputeFiber and merging its results.
STL namespace.
bool operator()(const std::pair< unsigned int, double > &f, const std::pair< unsigned int, double > &s)
FiberProcessVectorType FilterOutputFibers(FiberProcessVectorType &fibers, ListType &weights)
Filter output fibers by ROIs and compute local colors.
virtual double ComputeLogWeightUpdate(double b0Value, double noiseValue, Vector3DType &newDirection, VectorType &modelValue, double &log_prior, double &log_proposal, unsigned int threadId)=0
Update particle weight based on an underlying model and the chosen direction (model dependent...
bool MergeParticleClassFibers(FiberWorkType &fiberData, FiberProcessVectorType &outputMerged, unsigned int classNumber)
Takes the result of ComputeFiber and merges the particle classes, each class becoming one fiber...
unsigned int UpdateClassesMemberships(FiberWorkType &fiberData, DirectionVectorType &directions, std::mt19937 &random_generator)
The heart of multi-modal probabilistic tractography, making the decisions on splitting and merging...
void SetInitialColinearityDirection(const ColinearityDirectionType &colDir)
static ITK_THREAD_RETURN_FUNCTION_CALL_CONVENTION ThreadTracker(void *arg)
Multithreading utility function.
ColinearityDirectionType
Which way should the very first direction point? (Used in conjunction with InitialDirectionMode...)
void ThreadTrack(unsigned int numThread, FiberProcessVectorType &resultFibers, ListType &resultWeights)
Dispatches the work among threads.
virtual Vector3DType ProposeNewDirection(Vector3DType &oldDirection, VectorType &modelValue, Vector3DType &sampling_direction, double &log_prior, double &log_proposal, std::mt19937 &random_generator, unsigned int threadId)=0
Propose new direction for a particle, given the old direction, and a model (model dependent...
virtual void PrepareTractography()
Generates seed points (can be re-implemented, but this base version must still be called).
InitialDirectionModeType
Tells how to choose the very first direction of each particle. Colinear: the direction most colinear to the colinear dir...
virtual void ComputeAdditionalScalarMaps()
Computes additional scalar maps that are model dependent to add to the output.