ANIMA  4.0
animaBLMLambdaCostFunction.cxx
Go to the documentation of this file.
3 #include <animaQRDecomposition.h>
4
5 namespace anima
6 {
7
10 {
 // Evaluates the cost for the damping value parameters[0] (lambda): computes the
 // permuted step p solving the damped system described below, clamps it to the
 // bounds if needed, and returns ||D p|| - m_DeltaParameter.
 // Besides the return value, this updates m_PPermutted, m_PPermuttedShrunk,
 // m_RAlphaTranspose, m_SolutionInBounds and m_SolutionVector (plus the work buffers
 // m_WorkMatrix / m_WResiduals when lambda > 0); the derivative computation reads
 // m_RAlphaTranspose and m_SolutionVector afterwards.
 // NOTE(review): the declaration is const, so these members are presumably mutable - confirm in the header.
11  unsigned int nbParams = m_LowerBoundsPermutted.size();
12  unsigned int numLines = m_InputWResiduals.size();
13
14  m_PPermutted.set_size(nbParams);
15  m_PPermuttedShrunk.set_size(m_JRank);
16
17  // Solve (JtJ + lambda d^2) x = - Jt r, for a given lambda as parameter
18  // Use the fact that, for a permutation matrix pi and a diagonal matrix D, pi^T D pi is again diagonal: in vector form it is simply d[pivotVector[i]]
19
 // m_RAlphaTranspose receives the transpose of the triangular factor used for the solve;
 // it is reset here so that entries not written below stay at zero.
20  m_RAlphaTranspose.set_size(nbParams,nbParams);
21  m_RAlphaTranspose.fill(0.0);
22
 // NOTE(review): a negative parameters[0] falls into the "zero" branch here, whereas the
 // derivative code tests parameters[0] == 0.0 exactly - confirm lambda is never negative.
23  if (parameters[0] > 0.0)
24  {
25  // Compute QR solution for any lambda > 0.0
26  m_WorkMatrix = m_InputWorkMatrix;
 // Append the scaled, permuted diagonal sqrt(lambda) * D below the first m_JRank rows
27  for (unsigned int i = 0;i < nbParams;++i)
28  m_WorkMatrix.put(m_JRank + i,i,std::sqrt(parameters[0]) * m_DValues[m_PivotVector[i]]);
29
 // Work on a copy of the right-hand side so the stored input stays untouched
30  for (unsigned int i = 0;i < numLines;++i)
31  m_WResiduals[i] = m_InputWResiduals[i];
32
 // Decompose the augmented system, then solve the resulting upper triangular system for the step
33  anima::QRGivensDecomposition(m_WorkMatrix,m_WResiduals);
34  anima::UpperTriangularSolver(m_WorkMatrix,m_WResiduals,m_PPermutted,nbParams);
35
 // Keep the transpose of the upper triangle of the decomposed matrix (lower triangular storage)
36  for (unsigned int i = 0;i < nbParams;++i)
37  {
38  for (unsigned int j = i;j < nbParams;++j)
39  m_RAlphaTranspose.put(j,i,m_WorkMatrix.get(i,j));
40  }
41
42  m_SolutionInBounds = this->CheckSolutionIsInBounds(m_PPermutted);
43  }
44  else
45  {
46  // Compute simpler solution if tested parameter is zero
47  // (solver uses only square rank subpart of work matrix, and rank first wresiduals)
48  anima::UpperTriangularSolver(m_ZeroWorkMatrix,m_ZeroWResiduals,m_PPermuttedShrunk,m_JRank);
 // Only the first m_JRank components are checked: the remaining ones are set to a zero step below
49  m_SolutionInBounds = this->CheckSolutionIsInBounds(m_PPermuttedShrunk);
50
 // Only the leading m_JRank x m_JRank part of m_RAlphaTranspose is filled in this case
51  for (unsigned int i = 0;i < m_JRank;++i)
52  {
53  for (unsigned int j = i;j < m_JRank;++j)
54  m_RAlphaTranspose.put(j,i,m_ZeroWorkMatrix.get(i,j));
55  }
56
 // Expand the shrunk solution to full size, padding with zeros beyond the rank
57  for (unsigned int i = 0;i < m_JRank;++i)
58  m_PPermutted[i] = m_PPermuttedShrunk[i];
59  for (unsigned int i = m_JRank;i < nbParams;++i)
60  m_PPermutted[i] = 0.0;
61  }
62
 // Out of bounds: clamp (previous parameters + step) to [lower, upper] component-wise
 // and convert the clamped position back into a step
63  if (!m_SolutionInBounds)
64  {
65  for (unsigned int i = 0;i < nbParams;++i)
66  {
67  double value = m_PPermutted[i] + m_PreviousParametersPermutted[i];
68  value = std::min(m_UpperBoundsPermutted[i],std::max(m_LowerBoundsPermutted[i],value));
69  m_PPermutted[i] = value - m_PreviousParametersPermutted[i];
70  }
71  }
72
73  m_SolutionVector.set_size(nbParams);
74
 // Reorder the step through m_InversePivotVector (m_DValues is indexed without permutation here)
 // and accumulate the squared norm of the scaled step D * p
75  double phiNorm = 0.0;
76  for (unsigned int i = 0;i < nbParams;++i)
77  {
78  m_SolutionVector[i] = m_PPermutted[m_InversePivotVector[i]];
79  double phiVal = m_DValues[i] * m_SolutionVector[i];
80  phiNorm += phiVal * phiVal;
81  }
82
83  phiNorm = std::sqrt(phiNorm);
84
 // Signed gap between the scaled step norm and m_DeltaParameter
85  return phiNorm - m_DeltaParameter;
86 }
87
89 {
 // Returns true when, for every index i below solutionVector.size(),
 // m_PreviousParametersPermutted[i] + solutionVector[i] lies within
 // [m_LowerBoundsPermutted[i], m_UpperBoundsPermutted[i]] (bounds included); false at the first violation.
 // Only as many components as solutionVector holds are tested, so a shorter (rank-sized)
 // vector leaves the trailing parameters unchecked.
90  unsigned int rank = solutionVector.size();
91  for (unsigned int i = 0;i < rank;++i)
92  {
93  double value = m_PreviousParametersPermutted[i] + solutionVector[i];
94  if (value < m_LowerBoundsPermutted[i])
95  return false;
96
97  if (value > m_UpperBoundsPermutted[i])
98  return false;
99  }
100
101  return true;
102 }
103
104 void
106 {
107  // Computes the derivative assuming GetValue has been called right before and RAlphaTranspose is thus defined
108  // Valid only if the solution is in bounds, otherwise barely an approximation that should not be trusted
 // Output: derivative is resized to a single entry, derivative[0] = - normQ * ||y||^2, where
 // normQ is the norm of the scaled step D * m_SolutionVector and y is the solution of the lower
 // triangular system m_RAlphaTranspose * y = workQ built below.
109  derivative.set_size(1);
110
111  // Regular case when in bounds
112  unsigned int nbParams = m_LowerBoundsPermutted.size();
113  ParametersType q(nbParams), workQ(nbParams);
114
 // q = D * step (unpermuted ordering) and its Euclidean norm
115  double normQ = 0.0;
116  for (unsigned int i = 0;i < nbParams;++i)
117  {
118  q[i] = m_DValues[i] * m_SolutionVector[i];
119  normQ += q[i] * q[i];
120  }
121
122  normQ = std::sqrt(normQ);
 // Right-hand side: D * (q / normQ), with entries taken in pivot order
 // NOTE(review): normQ is zero when the scaled step is the zero vector, in which case this
 // division yields non-finite values - confirm callers never reach this state.
123  for (unsigned int i = 0;i < nbParams;++i)
124  workQ[i] = m_DValues[m_PivotVector[i]] * q[m_PivotVector[i]] / normQ;
125
126  derivative[0] = 0.0;
 // NOTE(review): the value computation switches branches on parameters[0] > 0.0, so a negative
 // parameters[0] would reach the else branch here while only the leading m_JRank x m_JRank part
 // of m_RAlphaTranspose was filled - confirm lambda is never negative.
127  if (parameters[0] == 0.0)
128  {
 // Zero lambda: restrict the solve and the sum to the first m_JRank components; q is reused as output
129  anima::LowerTriangularSolver(m_RAlphaTranspose,workQ,q,m_JRank);
130  for (unsigned int i = 0;i < m_JRank;++i)
131  derivative[0] -= q[i] * q[i];
132  }
133  else
134  {
 // Solver called with its default rank argument (presumably meaning the full system - confirm against its definition)
135  anima::LowerTriangularSolver(m_RAlphaTranspose,workQ,q);
136  for (unsigned int i = 0;i < nbParams;++i)
137  derivative[0] -= q[i] * q[i];
138  }
139
140  derivative[0] *= normQ;
141 }
142
143 void
145  ParametersType &qtResiduals, unsigned int rank)
146 {
 // Sizes and fills the internal buffers from qrDerivative and qtResiduals:
 //  - m_ZeroWorkMatrix (rank x rank) / m_ZeroWResiduals (rank): used when lambda is zero
 //  - m_InputWorkMatrix ((rank + nbParams) x nbParams) / m_InputWResiduals (rank + nbParams):
 //    template for the lambda > 0 system; rows from index rank onwards are left at zero here
 //  - m_WorkMatrix / m_WResiduals: scratch buffers of the same size, zero-initialised
 // qrDerivative is read as an upper triangular factor (only entries with j >= i in its first
 // rank rows are used); presumably it comes from a pivoted QR of the Jacobian - confirm with the caller.
147  unsigned int nbParams = qrDerivative.cols();
148  unsigned int numLinesWorkMatrix = rank + nbParams;
149
150  m_ZeroWorkMatrix.set_size(rank,rank);
151  m_ZeroWorkMatrix.fill(0.0);
152  m_ZeroWResiduals.set_size(rank);
153  m_ZeroWResiduals.fill(0.0);
154
155  m_InputWorkMatrix.set_size(numLinesWorkMatrix,nbParams);
156  m_InputWorkMatrix.fill(0.0);
157  m_InputWResiduals.set_size(numLinesWorkMatrix);
158  m_InputWResiduals.fill(0.0);
159
160  m_WorkMatrix.set_size(numLinesWorkMatrix,nbParams);
161  m_WorkMatrix.fill(0.0);
162  m_WResiduals.set_size(numLinesWorkMatrix);
163  m_WResiduals.fill(0.0);
164
165  for (unsigned int i = 0;i < rank;++i)
166  {
 // Upper triangle of the leading rank x rank block goes to both matrices
167  for (unsigned int j = i;j < rank;++j)
168  {
169  double tmpVal = qrDerivative.get(i,j);
170  m_ZeroWorkMatrix.put(i,j,tmpVal);
171  m_InputWorkMatrix.put(i,j,tmpVal);
172  }
173
 // Columns beyond the rank are only kept in the full-size input matrix
174  for (unsigned int j = rank;j < nbParams;++j)
175  m_InputWorkMatrix.put(i,j,qrDerivative.get(i,j));
176
 // Right-hand sides store the negated residuals (the system solved has - Jt r on its right-hand side)
177  m_InputWResiduals[i] = - qtResiduals[i];
178  m_ZeroWResiduals[i] = - qtResiduals[i];
179  }
180 }
181
182 } // end namespace anima
void LowerTriangularSolver(vnl_matrix< ScalarType > &matrix, VectorType &rhs, VectorType &result, unsigned int rank=0)
void SetInputWorkMatricesAndVectorsFromQRDerivative(vnl_matrix< double > &qrDerivative, ParametersType &qtResiduals, unsigned int rank)
virtual void GetDerivative(const ParametersType &parameters, DerivativeType &derivative) const ITK_OVERRIDE
virtual MeasureType GetValue(const ParametersType &parameters) const ITK_OVERRIDE
void QRGivensDecomposition(vnl_matrix< ScalarType > &aMatrix, vnl_vector< ScalarType > &bVector)
Superclass::DerivativeType DerivativeType
void UpperTriangularSolver(const vnl_matrix< ScalarType > &matrix, const VectorType &rhs, VectorType &result, unsigned int rank=0)
Superclass::ParametersType ParametersType
bool CheckSolutionIsInBounds(ParametersType &solutionVector) const