ANIMA  4.0
animaLTSWTransformAgregator.hxx
Go to the documentation of this file.
1 #pragma once
2 
6 
7 #include <algorithm>
8 #include <itkMacro.h>
9 
10 namespace anima
11 {
12 
13 template <unsigned int NDimensions>
14 LTSWTransformAgregator <NDimensions>::
15 LTSWTransformAgregator() : Superclass()
16 {
17  m_LTSCut = 0.5;
18  m_StoppingThreshold = 1.0e-2;
19  m_EstimationBarycenter.Fill(0);
20 }
21 
// Accessor for the barycenter stored by the last LTSW estimation.
// It is only assigned by the anisotropic-similarity and affine estimators of
// ltswEstimateTranslationsToAny(); otherwise it keeps its previous value
// (zero-filled by the constructor).
// NOTE(review): the return-type / qualified-name lines of this definition are
// missing from this rendered listing.
22 template <unsigned int NDimensions>
26 {
27  return m_EstimationBarycenter;
28 }
29 
// Entry point of the LTSW aggregation. The name and class-qualifier lines of
// this definition are missing from this rendered listing; presumably this is
// Update(), so verify against the header.
// Clears the up-to-date flag, checks that there is one weight per input
// transform, then dispatches on the input transform type:
//  - first branch (its case labels are not visible here): also requires one
//    origin per weight/transform, then runs ltswEstimateTranslationsToAny();
//  - RIGID / AFFINE (plus one label not visible here): runs
//    ltswEstimateAnyToAffine();
//  - anything else: throws itk::ExceptionObject.
// Returns false on any size mismatch, otherwise the estimator's result.
30 template <unsigned int NDimensions>
31 bool
34 {
35  this->SetUpToDate(false);
36  bool returnValue = false;
37 
38  if (this->GetInputWeights().size() != this->GetInputTransforms().size())
39  return false;
40 
41  switch (this->GetInputTransformType())
42  {
45  if ((this->GetInputWeights().size() != this->GetInputOrigins().size())||
46  (this->GetInputTransforms().size() != this->GetInputOrigins().size()))
47  return false;
48 
49  returnValue = this->ltswEstimateTranslationsToAny();
50  this->SetUpToDate(returnValue);
51  return returnValue;
52 
// NOTE(review): unlike the branch above, this one never calls
// SetUpToDate(returnValue), so the up-to-date flag stays false even when the
// estimation succeeds. It also does not check the origins count, although
// ltswEstimateAnyToAffine() reads one origin per transform. Confirm intent.
53  case Superclass::RIGID:
55  case Superclass::AFFINE:
56  returnValue = this->ltswEstimateAnyToAffine();
57  return returnValue;
58 
59  default:
60  throw itk::ExceptionObject(__FILE__, __LINE__,"Specific LTSW agregation not handled yet...",ITK_LOCATION);
61  }
62 }
63 
// Least-trimmed-squares weighted (LTSW) estimation of the output transform
// from point correspondences. Each input origin is paired with its image by
// the matching input transform. That image is further mapped by the current
// linear transform in one branch; the guarding `if` is not visible here and
// presumably tests currTrsf.
// Each iteration fits the output transform on the filtered correspondences,
// ranks all positive-weight points by Euclidean residual, and keeps the
// floor(m_LTSCut * count) best ones for the next fit. Iteration stops when
// endLTSCondition() returns false or after 100 iterations. The last estimate
// is stored with SetOutput(), and the function always returns true when it
// does not throw.
// Throws itk::ExceptionObject when NDimensions > 3, or when the output
// transform type has no estimator in the switch below.
// NOTE(review): the class-qualifier line of the signature, several `if`
// guards and two case labels are missing from this rendered listing.
64 template <unsigned int NDimensions>
65 bool
67 ltswEstimateTranslationsToAny()
68 {
69  unsigned int nbPts = this->GetInputOrigins().size();
70 
71  if (NDimensions > 3)
72  throw itk::ExceptionObject(__FILE__, __LINE__,"Dimension not supported",ITK_LOCATION);
73 
74  std::vector <PointType> originPoints(nbPts);
75  std::vector <PointType> transformedPoints(nbPts);
76  std::vector <double> weights = this->GetInputWeights();
77 
78  BaseInputTransformType * currTrsf = 0;
80  currTrsf = this->GetCurrentLinearTransform();
81 
// Build the (origin, transformed origin) correspondences.
82  for (unsigned int i = 0; i < nbPts; ++i)
83  {
84  PointType tmpOrig = this->GetInputOrigin(i);
85  BaseInputTransformType * tmpTrsf = this->GetInputTransform(i);
86  PointType tmpDisp = tmpTrsf->TransformPoint(tmpOrig);
87  originPoints[i] = tmpOrig;
89  transformedPoints[i] = currTrsf->TransformPoint(tmpDisp);
90  else
91  transformedPoints[i] = tmpDisp;
92  }
93 
// Principal-axes frame of the origin points. It is consumed only by the
// anisotropic-similarity estimator in the switch below. It comes from the
// user-supplied orthogonal direction matrix when that matrix is non-zero;
// otherwise it is computed from the (unnormalized) covariance of the
// origins. The condition guarding this block is missing from this listing.
94  vnl_matrix <ScalarType> covPcaOriginPoints(NDimensions, NDimensions, 0);
96  {
97  itk::Matrix<ScalarType, NDimensions, NDimensions> emptyMatrix;
98  emptyMatrix.Fill(0);
99 
100  if (this->GetOrthogonalDirectionMatrix() != emptyMatrix)
101  {
102  covPcaOriginPoints = this->GetOrthogonalDirectionMatrix().GetVnlMatrix().as_matrix();
103  }
104  else
105  {
// NOTE(review): unweightedBarX is accumulated into without ever being
// Fill(0)'d. itk::Point's default constructor is not documented to zero its
// components, so the mean likely starts from indeterminate values.
// Probably needs unweightedBarX.Fill(0).
106  itk::Point <ScalarType, NDimensions> unweightedBarX;
107  vnl_matrix <ScalarType> covOriginPoints(NDimensions, NDimensions, 0);
108  for (unsigned int i = 0; i < nbPts; ++i)
109  {
110  for (unsigned int j = 0; j < NDimensions; ++j)
111  unweightedBarX[j] += originPoints[i][j] / nbPts;
112  }
113  for (unsigned int i = 0; i < nbPts; ++i)
114  {
115  for (unsigned int j = 0; j < NDimensions; ++j)
116  {
117  for (unsigned int k = 0; k < NDimensions; ++k)
118  covOriginPoints(j, k) += (originPoints[i][j] - unweightedBarX[j])*(originPoints[i][k] - unweightedBarX[k]);
119  }
120  }
// NOTE(review): the eigen-analysis dimension is hard-coded to 3 rather than
// NDimensions, so it does not match covOriginPoints for NDimensions < 3.
121  itk::SymmetricEigenAnalysis < vnl_matrix <ScalarType>, vnl_diag_matrix<ScalarType>, vnl_matrix <ScalarType> > eigenSystem(3);
122  vnl_diag_matrix <double> eValsCov(NDimensions);
123  eigenSystem.SetOrderEigenValues(true);
124  eigenSystem.ComputeEigenValuesAndVectors(covOriginPoints, eValsCov, covPcaOriginPoints);
// Eigenvectors come back one per row; transpose so each column is an axis.
125  /* return eigen vectors in row !!!!!!! */
126  covPcaOriginPoints = covPcaOriginPoints.transpose();
// Negate the basis when its determinant is negative, aiming for a proper
// rotation. NOTE(review): negating an N x N matrix multiplies its
// determinant by (-1)^N, so this only fixes the sign for odd NDimensions
// (e.g. 3D), not 2D.
127  if (vnl_determinant(covPcaOriginPoints) < 0)
128  covPcaOriginPoints *= -1.0;
129  }
130  }
131 
// The first fit uses every correspondence.
132  std::vector <PointType> originPointsFiltered = originPoints;
133  std::vector <PointType> transformedPointsFiltered = transformedPoints;
134  std::vector <double> weightsFiltered = weights;
135 
136  std::vector < std::pair <unsigned int, double> > residualErrors;
137  PointType tmpOutPoint;
// NOTE(review): fixed to 3 components of double. The difference of two
// PointType values only matches this when NDimensions == 3 (and
// InternalScalarType is double), even though NDimensions < 3 passes the
// check above.
138  itk::Vector <double,3> tmpDiff;
139 
140  bool continueLoop = true;
141  unsigned int numMaxIter = 100;
142  unsigned int num_itr = 0;
143 
144  typename BaseOutputTransformType::Pointer resultTransform, resultTransformOld;
145 
146  while(num_itr < numMaxIter)
147  {
148  ++num_itr;
149 
// Fit the requested output transform on the current (trimmed) set. The first
// case label and the anisotropic-similarity label are missing from this
// listing. Only the last two estimators update m_EstimationBarycenter.
150  switch (this->GetOutputTransformType())
151  {
153  anima::computeTranslationLSWFromTranslations<InternalScalarType,ScalarType,NDimensions>
154  (originPointsFiltered,transformedPointsFiltered,weightsFiltered,resultTransform);
155  break;
156 
157  case Superclass::RIGID:
158  anima::computeRigidLSWFromTranslations<InternalScalarType,ScalarType,NDimensions>
159  (originPointsFiltered,transformedPointsFiltered,weightsFiltered,resultTransform);
160  break;
161 
163  m_EstimationBarycenter = anima::computeAnisotropSimLSWFromTranslations<InternalScalarType, ScalarType, NDimensions>
164  (originPointsFiltered, transformedPointsFiltered, weightsFiltered, resultTransform, covPcaOriginPoints);
165  break;
166 
167  case Superclass::AFFINE:
168  m_EstimationBarycenter = anima::computeAffineLSWFromTranslations<InternalScalarType,ScalarType,NDimensions>
169  (originPointsFiltered,transformedPointsFiltered,weightsFiltered,resultTransform);
170  break;
171 
172  default:
173  throw itk::ExceptionObject(__FILE__, __LINE__,"Not implemented yet...",ITK_LOCATION);
174  }
175 
// endLTSCondition() returns true while parameters are still moving.
176  continueLoop = endLTSCondition(resultTransformOld,resultTransform);
177 
178  if (!continueLoop)
179  break;
180 
181  resultTransformOld = resultTransform;
182  residualErrors.clear();
// Residual (Euclidean norm) of every positive-weight correspondence under the
// new estimate; zero/negative-weight points never re-enter the fit.
183  for (unsigned int i = 0;i < nbPts;++i)
184  {
185  if (weights[i] <= 0)
186  continue;
187 
188  tmpOutPoint = resultTransform->TransformPoint(originPoints[i]);
189  tmpDiff = tmpOutPoint - transformedPoints[i];
190  residualErrors.push_back(std::make_pair (i,tmpDiff.GetNorm()));
191  }
192 
// Number of correspondences kept for the next fit.
// NOTE(review): this can be 0 when few points have positive weight, which
// would leave the next estimator call with no correspondences.
193  unsigned int numLts = floor(residualErrors.size() * m_LTSCut);
194 
195  std::vector < std::pair <unsigned int, double> >::iterator begIt = residualErrors.begin();
196  std::vector < std::pair <unsigned int, double> >::iterator sortPart = begIt + numLts;
197 
// Only the numLts best residuals need to be in sorted order.
198  std::partial_sort(begIt,sortPart,residualErrors.end(),anima::errors_pair_comparator());
199 
200  originPointsFiltered.resize(numLts);
201  transformedPointsFiltered.resize(numLts);
202  weightsFiltered.resize(numLts);
203 
204  for (unsigned int i = 0;i < numLts;++i)
205  {
206  originPointsFiltered[i] = originPoints[residualErrors[i].first];
207  transformedPointsFiltered[i] = transformedPoints[residualErrors[i].first];
208  weightsFiltered[i] = weights[residualErrors[i].first];
209  }
210  }
211 
212  this->SetOutput(resultTransform);
213  return true;
214 }
215 
// LTSW log-Euclidean averaging of linear input transforms. Each input is
// converted to a matrix logarithm in one of two ways:
//  - matrix/offset transforms: anima::GetLogarithm is applied to their
//    homogeneous (NDimensions+1)-square matrix;
//  - the other branch: the log is read directly from a LogRigid3DTransform.
// The weighted log-Euclidean average is then recomputed on the transforms
// whose input origins are best mapped by the current average (squared point
// distance). Trimming keeps floor(m_LTSCut * count) transforms per pass, for
// up to 100 iterations or until endLTSCondition() returns false. The result
// is stored with SetOutput(), and the function returns true when it does not
// throw.
// NOTE(review): several lines are missing from this rendered listing:
//  - the class-qualifier line;
//  - the condition that presumably guards the leading throw (its message
//    concerns affine-to-rigid aggregation);
//  - the if-condition choosing between the matrix branch and the log-rigid
//    branch.
216 template <unsigned int NDimensions>
217 bool
219 ltswEstimateAnyToAffine()
220 {
222  throw itk::ExceptionObject(__FILE__, __LINE__,"Agregation from affine transforms to rigid is not supported yet...",ITK_LOCATION);
223 
224  typedef itk::MatrixOffsetTransformBase <InternalScalarType, NDimensions> BaseMatrixTransformType;
225  typedef anima::LogRigid3DTransform <InternalScalarType> LogRigidTransformType;
226 
227  unsigned int nbPts = this->GetInputTransforms().size();
// Local copy of the input weights, taken before the loop below runs.
228  std::vector <InternalScalarType> weights = this->GetInputWeights();
229 
230  std::vector < vnl_matrix <InternalScalarType> > logTransformations(nbPts);
// NOTE(review): tmpLogMatrix is never used.
231  vnl_matrix <InternalScalarType> tmpMatrix(NDimensions+1,NDimensions+1,0), tmpLogMatrix(NDimensions+1,NDimensions+1,0);
232  tmpMatrix(NDimensions,NDimensions) = 1;
233  typename BaseMatrixTransformType::MatrixType affinePart;
234  itk::Vector <InternalScalarType, NDimensions> offsetPart;
235 
236  for (unsigned int i = 0;i < nbPts;++i)
237  {
239  {
// Unchecked C-style cast: assumes the input really is a
// MatrixOffsetTransformBase.
240  BaseMatrixTransformType *tmpTrsf = (BaseMatrixTransformType *)this->GetInputTransform(i);
241  affinePart = tmpTrsf->GetMatrix();
242  offsetPart = tmpTrsf->GetOffset();
243 
// Fill the homogeneous matrix [A t; 0 1].
244  for (unsigned int j = 0;j < NDimensions;++j)
245  {
246  tmpMatrix(j,NDimensions) = offsetPart[j];
247  for (unsigned int k = 0;k < NDimensions;++k)
248  tmpMatrix(j,k) = affinePart(j,k);
249  }
250 
251  logTransformations[i] = anima::GetLogarithm(tmpMatrix);
252  if (!std::isfinite(logTransformations[i](0,0)))
253  {
// NOTE(review): SetInputWeight(i,0) updates the stored input weights but not
// the local `weights` copy made above. This transform therefore still enters
// the average as a zero (identity) log with its original weight, and still
// takes part in the residual ranking. Consider also setting weights[i] = 0.
254  logTransformations[i].fill(0);
255  this->SetInputWeight(i,0);
256  }
257  }
258  else
259  {
260  LogRigidTransformType *tmpTrsf = (LogRigidTransformType *)this->GetInputTransform(i);
261  logTransformations[i] = tmpTrsf->GetLogTransform();
262  }
263  }
264 
265  std::vector < vnl_matrix <InternalScalarType> > logTransformationsFiltered = logTransformations;
266  std::vector <InternalScalarType> weightsFiltered = weights;
267 
// Point correspondences used only to rank transforms by residual.
// NOTE(review): one origin is read per transform, but the dispatching branch
// in this file that leads here does not check the origins count.
268  // For LTS
269  std::vector < PointType > originPoints(nbPts);
270  std::vector < PointType > transformedPoints(nbPts);
271 
272  for (unsigned int i = 0;i < nbPts;++i)
273  {
274  PointType tmpOrig = this->GetInputOrigin(i);
275  BaseInputTransformType * tmpTrsf = this->GetInputTransform(i);
276  PointType tmpDisp = tmpTrsf->TransformPoint(tmpOrig);
277  originPoints[i] = tmpOrig;
278  transformedPoints[i] = tmpDisp;
279  }
280 
281  std::vector < std::pair <unsigned int, double> > residualErrors;
282 
283  bool continueLoop = true;
284  unsigned int numMaxIter = 100;
285  unsigned int num_itr = 0;
286 
287  typename BaseOutputTransformType::Pointer resultTransform, resultTransformOld;
288 
289  while(num_itr < numMaxIter)
290  {
291  ++num_itr;
292 
293  anima::computeLogEuclideanAverage<InternalScalarType,ScalarType,NDimensions>(logTransformationsFiltered,weightsFiltered,resultTransform);
// endLTSCondition() returns true while parameters are still moving.
294  continueLoop = endLTSCondition(resultTransformOld,resultTransform);
295 
296  if (!continueLoop)
297  break;
298 
299  resultTransformOld = resultTransform;
300  residualErrors.clear();
301 
// Unchecked C-style cast: assumes computeLogEuclideanAverage produced a
// MatrixOffsetTransformBase-derived transform.
302  BaseMatrixTransformType *tmpTrsf = (BaseMatrixTransformType *)resultTransform.GetPointer();
303 
// Squared distance between each input's mapped origin and the average's
// mapped origin. Ranking by squared distance gives the same order as ranking
// by the norm.
304  for (unsigned int i = 0;i < nbPts;++i)
305  {
306  if (weights[i] <= 0)
307  continue;
308 
309  double tmpDiff = 0;
310  PointType tmpDisp = tmpTrsf->TransformPoint(originPoints[i]);
311 
312  for (unsigned int j = 0;j < NDimensions;++j)
313  tmpDiff += (transformedPoints[i][j] - tmpDisp[j]) * (transformedPoints[i][j] - tmpDisp[j]);
314 
315  residualErrors.push_back(std::make_pair (i,tmpDiff));
316  }
317 
// NOTE(review): can be 0 when few inputs have positive weight, leaving the
// next average with no transforms.
318  unsigned int numLts = floor(residualErrors.size() * m_LTSCut);
319 
320  std::vector < std::pair <unsigned int, double> >::iterator begIt = residualErrors.begin();
321  std::vector < std::pair <unsigned int, double> >::iterator sortPart = begIt + numLts;
322 
323  std::partial_sort(begIt,sortPart,residualErrors.end(),anima::errors_pair_comparator());
324 
325  logTransformationsFiltered.resize(numLts);
326  weightsFiltered.resize(numLts);
327 
328  for (unsigned int i = 0;i < numLts;++i)
329  {
330  logTransformationsFiltered[i] = logTransformations[residualErrors[i].first];
331  weightsFiltered[i] = weights[residualErrors[i].first];
332  }
333  }
334 
335  this->SetOutput(resultTransform);
336  return true;
337 }
338 
// Convergence test shared by the LTS loops. Despite its name, it returns true
// when iterating should CONTINUE (callers store the result in continueLoop):
//  - true when there is no previous estimate (oldTrsf is null);
//  - true when any parameter differs from its previous value by more than
//    m_StoppingThreshold (absolute difference);
//  - false otherwise.
// NOTE(review): assumes oldTrsf has at least as many parameters as newTrsf;
// no size check is performed. The class-qualifier line of the signature is
// missing from this rendered listing.
339 template <unsigned int NDimensions>
340 bool
342 endLTSCondition(BaseOutputTransformType *oldTrsf, BaseOutputTransformType *newTrsf)
343 {
344  if (oldTrsf == NULL)
345  return true;
346 
347  typename BaseOutputTransformType::ParametersType oldParams = oldTrsf->GetParameters();
348  typename BaseOutputTransformType::ParametersType newParams = newTrsf->GetParameters();
349 
350  for (unsigned int i = 0;i < newParams.GetSize();++i)
351  {
352  double diffParam = fabs(newParams[i] - oldParams[i]);
353  if (diffParam > m_StoppingThreshold)
354  return true;
355  }
356 
357  return false;
358 }
359 
360 } // end of namespace anima
itk::Point< InternalScalarType, NDimensions > PointType
BaseInputTransformType * GetInputTransform(unsigned int i)
std::vector< BaseInputTransformPointer > & GetInputTransforms()
BaseInputTransformPointer & GetCurrentLinearTransform()
std::vector< InternalScalarType > & GetInputWeights()
PointType & GetInputOrigin(unsigned int i)
itk::Transform< ScalarType, NDimensions, NDimensions > BaseOutputTransformType
void SetInputWeight(unsigned int i, double w)
vnl_matrix< T > GetLogarithm(const vnl_matrix< T > &m, const double precision=0.00000000001, const int numApprox=1)
Computation of the matrix logarithm. Algo: inverse scaling and squaring, variant proposed by Cheng et...
PointType GetEstimationBarycenter() ITK_OVERRIDE
void SetOutput(BaseOutputTransformType *output)
std::vector< PointType > & GetInputOrigins()
itk::Transform< InternalScalarType, NDimensions, NDimensions > BaseInputTransformType