ANIMA  4.0
animaMEstTransformAgregator.hxx
Go to the documentation of this file.
1 #pragma once
2 
6 #include <algorithm>
7 
8 namespace anima
9 {
10 
11 template <unsigned int NDimensions>
12 MEstTransformAgregator <NDimensions>::
13 MEstTransformAgregator() : Superclass()
14 {
15  m_MEstimateFactor = 0.5;
16  m_StoppingThreshold = 1.0e-2;
17  m_EstimationBarycenter.Fill(0);
18 }
19 
// Returns the estimation barycenter stored by the last translation-based
// estimation (zero-filled by the constructor until an estimator writes it).
// NOTE(review): source lines 21-23 (return type, class qualifier and method
// name) are not shown in this listing; per the tooltip index this is
// presumably `PointType GetEstimationBarycenter()`.
20 template <unsigned int NDimensions>
24 {
25  return m_EstimationBarycenter;
26 }
27 
// Runs the M-estimation agregation for the current inputs. The estimator
// called below stores its result through SetOutput.
// Returns false when the input weight / transform (/ origin) counts are
// inconsistent; otherwise returns the estimator's result. Throws
// itk::ExceptionObject for input transform types not handled here.
// NOTE(review): source lines 30-31 (class qualifier and method name) are not
// shown in this listing; presumably this is Update().
28 template <unsigned int NDimensions>
29 bool
32 {
33  this->SetUpToDate(false);
34  bool returnValue = false;
35 
36  if (this->GetInputWeights().size() != this->GetInputTransforms().size())
37  return false;
38 
39  switch (this->GetInputTransformType())
40  {
// NOTE(review): source lines 41-42 (this first branch's case label(s)) are not
// shown in this listing. This branch uses the point-correspondence estimator,
// so it additionally requires exactly one origin per transform.
43  if ((this->GetInputWeights().size() != this->GetInputOrigins().size())||
44  (this->GetInputTransforms().size() != this->GetInputOrigins().size()))
45  return false;
46 
47  returnValue = this->mestEstimateTranslationsToAny();
48  this->SetUpToDate(returnValue);
49  return returnValue;
50 
51  case Superclass::RIGID:
52  case Superclass::AFFINE:
// NOTE(review): unlike the branch above, this branch does not call
// SetUpToDate(returnValue), so the up-to-date flag stays false even after a
// successful estimation. It also does not check the origin count, although
// mestEstimateAnyToAffine reads one origin per transform. Confirm intent.
53  returnValue = this->mestEstimateAnyToAffine();
54  return returnValue;
55 
56  default:
57  throw itk::ExceptionObject(__FILE__, __LINE__,"Specific M-estimation agregation not handled yet...",ITK_LOCATION);
// Unreachable: the throw above never falls through.
58  return false;
59  }
60 }
61 
// Estimates the output transform from point pairs by iteratively re-weighted
// (M-estimation) weighted least squares.
//  - Pairs: each input origin, and that origin mapped through its input
//    transform (then presumably through the current linear transform when one
//    is set; the guarding condition, source line 86, is not shown).
//  - Each iteration does the following:
//    * It runs the least-squares estimator that matches
//      GetOutputTransformType() with the current weights.
//    * It recomputes the squared residual r_i^2 of every pair with a
//      positive input weight.
//    * It sets that pair's weight to w_i * exp(-r_i^2 / (m * m_MEstimateFactor)).
//      Here m is the mean r_i^2 measured at the first iteration, or 1 if that
//      mean is not positive.
//  - Stops when endLTSCondition reports that no parameter moved by more than
//    m_StoppingThreshold, or after 100 iterations.
// Stores the result with SetOutput and returns true. Throws
// itk::ExceptionObject when NDimensions > 3 or the output type is not handled.
// NOTE(review): source line 64 (the class qualifier) is not shown here.
62 template <unsigned int NDimensions>
63 bool
65 mestEstimateTranslationsToAny()
66 {
67  unsigned int nbPts = this->GetInputOrigins().size();
68 
69  if (NDimensions > 3)
70  throw itk::ExceptionObject(__FILE__, __LINE__,"Dimension not supported",ITK_LOCATION);
71 
72  std::vector <PointType> originPoints(nbPts);
73  std::vector <PointType> transformedPoints(nbPts);
74  std::vector <double> weights = this->GetInputWeights();
75 
76  BaseInputTransformType * currTrsf = 0;
// NOTE(review): source line 77 is not shown; it presumably guards the
// assignment below (e.g. only when a current linear transform is set).
78  currTrsf = this->GetCurrentLinearTransform();
79 
80  for (unsigned int i = 0;i < nbPts;++i)
81  {
82  PointType tmpOrig = this->GetInputOrigin(i);
83  BaseInputTransformType * tmpTrsf = this->GetInputTransform(i);
84  PointType tmpDisp = tmpTrsf->TransformPoint(tmpOrig);
85  originPoints[i] = tmpOrig;
// NOTE(review): source line 86 (the `if` selecting between the two
// assignments below) is not shown in this listing.
87  transformedPoints[i] = currTrsf->TransformPoint(tmpDisp);
88  else
89  transformedPoints[i] = tmpDisp;
90  }
91 
// Orthonormal basis passed to the anisotropic-similarity estimator. It is
// either the user-provided orthogonal direction matrix, or a PCA of the
// origin points.
92  vnl_matrix <ScalarType> covPcaOriginPoints(NDimensions, NDimensions, 0);
// NOTE(review): source line 93 is not shown; it presumably restricts this
// basis computation to the output type that consumes covPcaOriginPoints
// (the computeAnisotropSimLSWFromTranslations call below). TODO confirm.
94  {
95  itk::Matrix<ScalarType, NDimensions, NDimensions> emptyMatrix;
96  emptyMatrix.Fill(0);
97 
// An all-zero direction matrix means "not provided": fall back to PCA.
98  if (this->GetOrthogonalDirectionMatrix() != emptyMatrix)
99  {
100  covPcaOriginPoints = this->GetOrthogonalDirectionMatrix().GetVnlMatrix().as_matrix();
101  }
102  else
103  {
// NOTE(review): itk::Point's default constructor does not initialize its
// coordinates, so the += accumulation below appears to start from
// indeterminate values; an explicit unweightedBarX.Fill(0) looks necessary.
104  itk::Point <ScalarType, NDimensions> unweightedBarX;
105  vnl_matrix <ScalarType> covOriginPoints(NDimensions, NDimensions, 0);
// Unweighted mean of the origin points.
106  for (unsigned int i = 0; i < nbPts; ++i)
107  {
108  for (unsigned int j = 0; j < NDimensions; ++j)
109  unweightedBarX[j] += originPoints[i][j] / nbPts;
110  }
// Unnormalized scatter matrix of the origins around that mean (scaling does
// not affect the eigenvectors).
111  for (unsigned int i = 0; i < nbPts; ++i)
112  {
113  for (unsigned int j = 0; j < NDimensions; ++j)
114  {
115  for (unsigned int k = 0; k < NDimensions; ++k)
116  covOriginPoints(j, k) += (originPoints[i][j] - unweightedBarX[j])*(originPoints[i][k] - unweightedBarX[k]);
117  }
118  }
// NOTE(review): the eigen-analysis dimension is hard-coded to 3 instead of
// NDimensions, which mismatches the covariance size when NDimensions < 3.
// eValsCov is also declared with double rather than ScalarType, unlike the
// vnl_diag_matrix<ScalarType> the analysis is templated on. Confirm.
119  itk::SymmetricEigenAnalysis < vnl_matrix <ScalarType>, vnl_diag_matrix<ScalarType>, vnl_matrix <ScalarType> > eigenSystem(3);
120  vnl_diag_matrix <double> eValsCov(NDimensions);
121  eigenSystem.SetOrderEigenValues(true);
122  eigenSystem.ComputeEigenValuesAndVectors(covOriginPoints, eValsCov, covPcaOriginPoints);
123  /* return eigen vectors in row !!!!!!! */
// Eigenvectors come back as rows: transpose so they become columns, then
// negate to try to obtain a positive determinant (a proper rotation).
// NOTE(review): negating an NxN matrix multiplies its determinant by
// (-1)^N, so this sign fix only has an effect for odd NDimensions.
124  covPcaOriginPoints = covPcaOriginPoints.transpose();
125  if (vnl_determinant(covPcaOriginPoints) < 0)
126  covPcaOriginPoints *= -1.0;
127  }
128  }
129 
// Effective weights = input weights x M-estimation weights (all 1 at start).
130  std::vector <double> weightsFiltered = weights;
131 
132  std::vector < double > residualErrors;
133  std::vector < double > mestWeights(nbPts,1);
134 
135  PointType tmpOutPoint;
// NOTE(review): the difference vector has 3 components regardless of
// NDimensions; assigning a Point difference (an NDimensions vector) to it
// only type-checks when NDimensions == 3.
136  itk::Vector <double,3> tmpDiff;
137 
138  bool continueLoop = true;
139  unsigned int numMaxIter = 100;
140  unsigned int num_itr = 0;
141  double averageResidualValue = 1;
142 
143  typename BaseOutputTransformType::Pointer resultTransform, resultTransformOld;
144 
145  while(num_itr < numMaxIter)
146  {
147  ++num_itr;
148 
// Weighted least-squares fit for the requested output transform type.
149  switch (this->GetOutputTransformType())
150  {
// NOTE(review): source line 151 (this case label) is not shown.
152  anima::computeTranslationLSWFromTranslations<InternalScalarType,ScalarType,NDimensions>
153  (originPoints,transformedPoints,weightsFiltered,resultTransform);
154  break;
155 
156  case Superclass::RIGID:
157  anima::computeRigidLSWFromTranslations<InternalScalarType,ScalarType,NDimensions>
158  (originPoints,transformedPoints,weightsFiltered,resultTransform);
159  break;
160 
// NOTE(review): source line 161 (this case label) is not shown; it is the
// only branch that uses covPcaOriginPoints.
162  m_EstimationBarycenter = anima::computeAnisotropSimLSWFromTranslations<InternalScalarType, ScalarType, NDimensions>
163  (originPoints, transformedPoints, weightsFiltered, resultTransform, covPcaOriginPoints);
164  break;
165 
166  case Superclass::AFFINE:
167  m_EstimationBarycenter = anima::computeAffineLSWFromTranslations<InternalScalarType,ScalarType,NDimensions>
168  (originPoints,transformedPoints,weightsFiltered,resultTransform);
169  break;
170 
171  default:
172  throw itk::ExceptionObject(__FILE__, __LINE__,"Not implemented yet...",ITK_LOCATION);
// Unreachable: the throw above never falls through.
173  return false;
174  }
175 
// endLTSCondition returns true while parameters still move (or on the
// first iteration); false means converged.
176  continueLoop = endLTSCondition(resultTransformOld,resultTransform);
177 
178  if (!continueLoop)
179  break;
180 
181  resultTransformOld = resultTransform;
// Squared residuals, only for pairs with a positive input weight, in index
// order; residualIndex below walks them in the same order.
182  residualErrors.clear();
183  for (unsigned int i = 0;i < nbPts;++i)
184  {
185  if (weights[i] <= 0)
186  continue;
187 
188  tmpOutPoint = resultTransform->TransformPoint(originPoints[i]);
189  tmpDiff = tmpOutPoint - transformedPoints[i];
190  double tmpRes = tmpDiff.GetNorm();
191  residualErrors.push_back(tmpRes * tmpRes);
192  }
193 
194  if (num_itr == 1)
195  {
196  // At first iteration, compute factor for M-estimation
// NOTE(review): if every input weight is <= 0, residualErrors is empty and
// this is 0/0 = NaN, which the `<= 0` guard below does not catch.
197  double averageDist = 0;
198  for (unsigned int i = 0;i < residualErrors.size();++i)
199  averageDist += residualErrors[i];
200 
201  averageResidualValue = averageDist / residualErrors.size();
202 
203  if (averageResidualValue <= 0)
204  averageResidualValue = 1;
205  }
206 
// Gaussian-like M-estimation weight: exp(-r^2 / (mean_r2 * factor)).
207  unsigned int residualIndex = 0;
208  for (unsigned int i = 0;i < nbPts;++i)
209  {
210  if (weights[i] <= 0)
211  continue;
212 
213  mestWeights[i] = exp(- residualErrors[residualIndex] / (averageResidualValue * m_MEstimateFactor));
214  ++residualIndex;
215  }
216 
217  for (unsigned int i = 0;i < nbPts;++i)
218  weightsFiltered[i] = weights[i] * mestWeights[i];
219  }
220 
221  this->SetOutput(resultTransform);
222  return true;
223 }
224 
// Estimates the output transform as a weighted log-Euclidean average of the
// input transforms, with iterative M-estimation re-weighting.
//  - Each input transform is turned into a matrix logarithm in one of two ways:
//    * computed with anima::GetLogarithm from its homogeneous
//      (NDimensions+1) x (NDimensions+1) matrix;
//    * read directly from a LogRigid3DTransform.
//    The condition selecting between them (source line 247) is not shown here.
//  - Each iteration does the following:
//    * It averages the logs with anima::computeLogEuclideanAverage using the
//      current weights.
//    * It re-weights every input with a positive weight by
//      exp(-r_i^2 / (m * m_MEstimateFactor)). Here r_i is the distance between
//      the origin mapped by input transform i and by the current average, and
//      m is the mean r_i^2 at the first iteration.
//  - Stops when endLTSCondition reports convergence, or after 100 iterations.
// Stores the result with SetOutput and returns true.
// NOTE(review): source lines 227 (class qualifier) and 230 (presumably the
// condition guarding the throw below, e.g. a rigid output type) are not
// shown in this listing.
225 template <unsigned int NDimensions>
226 bool
228 mestEstimateAnyToAffine()
229 {
231  throw itk::ExceptionObject(__FILE__, __LINE__,"Agregation from affine transforms to rigid is not supported yet...",ITK_LOCATION);
232 
233  typedef itk::MatrixOffsetTransformBase <InternalScalarType, NDimensions> BaseMatrixTransformType;
234  typedef anima::LogRigid3DTransform <InternalScalarType> LogRigidTransformType;
235 
236  unsigned int nbPts = this->GetInputTransforms().size();
// Local copy: later SetInputWeight calls do not update it.
237  std::vector <InternalScalarType> weights = this->GetInputWeights();
238 
239  std::vector < vnl_matrix <InternalScalarType> > logTransformations(nbPts);
// Homogeneous matrix buffer: its last row stays (0,...,0,1) because only the
// top NDimensions rows are overwritten below.
// NOTE(review): tmpLogMatrix is not used in the visible lines.
240  vnl_matrix <InternalScalarType> tmpMatrix(NDimensions+1,NDimensions+1,0), tmpLogMatrix(NDimensions+1,NDimensions+1,0);
241  tmpMatrix(NDimensions,NDimensions) = 1;
242  typename BaseMatrixTransformType::MatrixType affinePart;
243  itk::Vector <InternalScalarType, NDimensions> offsetPart;
244 
245  for (unsigned int i = 0;i < nbPts;++i)
246  {
// NOTE(review): source line 247 (the branch condition, presumably on the
// input transform type) is not shown in this listing.
248  {
// Unchecked C-style downcast: assumes the input really is a
// MatrixOffsetTransformBase.
249  BaseMatrixTransformType *tmpTrsf = (BaseMatrixTransformType *)this->GetInputTransform(i);
250  affinePart = tmpTrsf->GetMatrix();
251  offsetPart = tmpTrsf->GetOffset();
252 
253  for (unsigned int j = 0;j < NDimensions;++j)
254  {
255  tmpMatrix(j,NDimensions) = offsetPart[j];
256  for (unsigned int k = 0;k < NDimensions;++k)
257  tmpMatrix(j,k) = affinePart(j,k);
258  }
259 
260  logTransformations[i] = anima::GetLogarithm(tmpMatrix);
// A non-finite logarithm is replaced by the zero log (identity transform),
// and the stored input weight is cleared.
261  if (!std::isfinite(logTransformations[i](0,0)))
262  {
263  logTransformations[i].fill(0);
// NOTE(review): `weights` (and weightsFiltered below) were copied before this
// loop, so they keep the original weight. The zeroed log therefore still
// takes part in the average with its original weight. Zeroing weights[i]
// here as well appears to be intended; confirm.
264  this->SetInputWeight(i,0);
265  }
266  }
267  else
268  {
// Unchecked C-style downcast: assumes the input is a LogRigid3DTransform.
269  LogRigidTransformType *tmpTrsf = (LogRigidTransformType *)this->GetInputTransform(i);
270  logTransformations[i] = tmpTrsf->GetLogTransform();
271  }
272  }
273 
274  std::vector <InternalScalarType> weightsFiltered = weights;
275 
276  // For LTS
// Point pairs used only to measure residuals: origin i and its image by input
// transform i.
// NOTE(review): assumes at least nbPts input origins; not checked here.
277  std::vector < PointType > originPoints(nbPts);
278  std::vector < PointType > transformedPoints(nbPts);
279 
280  for (unsigned int i = 0;i < nbPts;++i)
281  {
282  PointType tmpOrig = this->GetInputOrigin(i);
283  BaseInputTransformType * tmpTrsf = this->GetInputTransform(i);
284  PointType tmpDisp = tmpTrsf->TransformPoint(tmpOrig);
285  originPoints[i] = tmpOrig;
286  transformedPoints[i] = tmpDisp;
287  }
288 
289  std::vector < double > residualErrors;
290  std::vector < double > mestWeights(nbPts,1);
291 
292  bool continueLoop = true;
293  unsigned int numMaxIter = 100;
294  unsigned int num_itr = 0;
295  double averageResidualValue = 1.0;
296 
297  typename BaseOutputTransformType::Pointer resultTransform, resultTransformOld;
298 
299  while(num_itr < numMaxIter)
300  {
301  ++num_itr;
302 
303  anima::computeLogEuclideanAverage<InternalScalarType,ScalarType,NDimensions>(logTransformations,weightsFiltered,resultTransform);
// true = keep iterating (first iteration, or some parameter still moved).
304  continueLoop = endLTSCondition(resultTransformOld,resultTransform);
305 
306  if (!continueLoop)
307  break;
308 
309  resultTransformOld = resultTransform;
310  residualErrors.clear();
311 
// Unchecked C-style cast: assumes the averaged output is a
// MatrixOffsetTransformBase.
312  BaseMatrixTransformType *tmpTrsf = (BaseMatrixTransformType *)resultTransform.GetPointer();
313 
// Squared Euclidean residuals for positive-weight inputs only, in index order.
314  for (unsigned int i = 0;i < nbPts;++i)
315  {
316  if (weights[i] <= 0)
317  continue;
318 
319  double tmpDiff = 0;
320  PointType tmpDisp = tmpTrsf->TransformPoint(originPoints[i]);
321 
322  for (unsigned int j = 0;j < NDimensions;++j)
323  tmpDiff += (transformedPoints[i][j] - tmpDisp[j]) * (transformedPoints[i][j] - tmpDisp[j]);
324 
325  residualErrors.push_back(tmpDiff);
326  }
327 
328  if (num_itr == 1)
329  {
330  // At first iteration, compute factor for M-estimation
// NOTE(review): an empty residualErrors (all weights <= 0) gives 0/0 = NaN
// here, which the `<= 0` guard does not catch.
331  double averageDist = 0;
332  for (unsigned int i = 0;i < residualErrors.size();++i)
333  averageDist += residualErrors[i];
334 
335  averageResidualValue = averageDist / residualErrors.size();
336 
337  if (averageResidualValue <= 0)
338  averageResidualValue = 1;
339  }
340 
341  unsigned int residualIndex = 0;
342  for (unsigned int i = 0;i < nbPts;++i)
343  {
344  if (weights[i] <= 0)
345  continue;
346 
347  mestWeights[i] = exp(- residualErrors[residualIndex] / (averageResidualValue * m_MEstimateFactor));
348  ++residualIndex;
349  }
350 
351  for (unsigned int i = 0;i < nbPts;++i)
352  weightsFiltered[i] = weights[i] * mestWeights[i];
353  }
354 
355  this->SetOutput(resultTransform);
356  return true;
357 }
358 
// Convergence test for the re-weighting loops.
// Returns true ("keep iterating") in two cases:
//  - there is no previous transform yet (oldTrsf is NULL);
//  - at least one parameter differs from its previous value by more than
//    m_StoppingThreshold in absolute value.
// Returns false once every parameter moved by at most m_StoppingThreshold.
// NOTE(review): oldParams is indexed with newParams' size, so both transforms
// are assumed to have the same parameter count. Source line 361 (the class
// qualifier) is not shown in this listing.
359 template <unsigned int NDimensions>
360 bool
362 endLTSCondition(BaseOutputTransformType *oldTrsf, BaseOutputTransformType *newTrsf)
363 {
364  if (oldTrsf == NULL)
365  return true;
366 
367  typename BaseOutputTransformType::ParametersType oldParams = oldTrsf->GetParameters();
368  typename BaseOutputTransformType::ParametersType newParams = newTrsf->GetParameters();
369 
370  for (unsigned int i = 0;i < newParams.GetSize();++i)
371  {
372  double diffParam = fabs(newParams[i] - oldParams[i]);
373  if (diffParam > m_StoppingThreshold)
374  return true;
375  }
376 
377  return false;
378 }
379 
380 }// end of namespace anima
itk::Point< InternalScalarType, NDimensions > PointType
BaseInputTransformType * GetInputTransform(unsigned int i)
std::vector< BaseInputTransformPointer > & GetInputTransforms()
BaseInputTransformPointer & GetCurrentLinearTransform()
std::vector< InternalScalarType > & GetInputWeights()
PointType & GetInputOrigin(unsigned int i)
itk::Transform< ScalarType, NDimensions, NDimensions > BaseOutputTransformType
void SetInputWeight(unsigned int i, double w)
vnl_matrix< T > GetLogarithm(const vnl_matrix< T > &m, const double precision=0.00000000001, const int numApprox=1)
Computation of the matrix logarithm. Algorithm: inverse scaling and squaring, variant proposed by Cheng et al.
void SetOutput(BaseOutputTransformType *output)
PointType GetEstimationBarycenter() ITK_OVERRIDE
std::vector< PointType > & GetInputOrigins()
itk::Transform< InternalScalarType, NDimensions, NDimensions > BaseInputTransformType