ANIMA  4.0
animaFuzzyCMeansFilter.hxx
Go to the documentation of this file.
1 #pragma once
3 
4 #include <iostream>
5 #include <cmath>
6 #include <limits>
7 
9 
10 namespace anima
11 {
12 
13 template <class ScalarType>
14 FuzzyCMeansFilter <ScalarType>
16 {
17  m_ClassesMembership.clear();
18  m_Centroids.clear();
19  m_InputData.clear();
20 
21  m_NbClass = 0;
22  m_NbInputs = 0;
23  m_NDim = 0;
24  m_MaxIterations = 100;
25 
26  m_Verbose = true;
27  m_SpectralClusterInit = false;
28  m_SphericalAverageType = Euclidean;
29 
30  m_RelStopCriterion = 1.0e-4;
31  m_MValue = 2;
32 }
33 
34 template <class ScalarType>
35 void
37 ::SetInputData(DataHolderType &data)
38 {
39  if (data.size() == 0)
40  return;
41 
42  m_InputData = data;
43  m_NbInputs = m_InputData.size();
44  m_NDim = m_InputData[0].size();
45 }
46 
47 template <class ScalarType>
48 void
50 ::Update()
51 {
52  if (m_NbClass > m_NbInputs)
53  throw itk::ExceptionObject(__FILE__,__LINE__,"More classes than inputs...",ITK_LOCATION);
54 
55  m_DataWeights.resize(m_NbInputs);
56  std::fill(m_DataWeights.begin(),m_DataWeights.end(),1.0 / m_NbInputs);
57 
58  InitializeCMeansFromData();
59  DataHolderType oldMemberships = m_ClassesMembership;
60 
61  unsigned int itncount = 0;
62  bool continueLoop = true;
63 
64  while ((itncount < m_MaxIterations)&&(continueLoop))
65  {
66  itncount++;
67 
68  if (m_Verbose)
69  std::cout << "Iteration " << itncount << "..." << std::endl;
70 
71  ComputeCentroids();
72  UpdateMemberships();
73 
74  continueLoop = !endConditionReached(oldMemberships);
75  oldMemberships = m_ClassesMembership;
76  }
77 }
78 
79 template <class ScalarType>
80 void
82 ::ComputeCentroids()
83 {
84  if (m_PowMemberships.size() != m_NbInputs)
85  m_PowMemberships.resize(m_NbInputs);
86 
87  for (unsigned int i = 0;i < m_NbInputs;++i)
88  {
89  if (m_PowMemberships[i].size() != m_NbClass)
90  m_PowMemberships[i].resize(m_NbClass);
91 
92  for (unsigned int j = 0;j < m_NbClass;++j)
93  m_PowMemberships[i][j] = std::pow(m_ClassesMembership[i][j],m_MValue);
94  }
95 
96  if (m_TmpVector.size() != m_NDim)
97  m_TmpVector.resize(m_NDim);
98 
99  if (m_TmpWeights.size() != m_NbInputs)
100  m_TmpWeights.resize(m_NbInputs);
101 
102  for (unsigned int i = 0;i < m_NbClass;++i)
103  {
104  double sumPowMemberShips = 0;
105  std::fill(m_TmpVector.begin(),m_TmpVector.end(),0.0);
106 
107  for (unsigned int j = 0;j < m_NbInputs;++j)
108  {
109  sumPowMemberShips += m_DataWeights[j] * m_PowMemberships[j][i];
110  for (unsigned int k = 0;k < m_NDim;++k)
111  m_TmpVector[k] += m_DataWeights[j] * m_PowMemberships[j][i] * m_InputData[j][k];
112  }
113 
114  for (unsigned int k = 0;k < m_NDim;++k)
115  m_TmpVector[k] /= sumPowMemberShips;
116 
117  switch (m_SphericalAverageType)
118  {
119  case Euclidean:
120  m_Centroids[i] = m_TmpVector;
121  break;
122 
123  case ApproximateSpherical:
124  {
125  double tmpSum = 0;
126  for (unsigned int k = 0;k < m_NDim;++k)
127  tmpSum += m_TmpVector[k] * m_TmpVector[k];
128 
129  tmpSum = std::sqrt(tmpSum);
130  for (unsigned int k = 0;k < m_NDim;++k)
131  m_TmpVector[k] /= tmpSum;
132 
133  m_Centroids[i] = m_TmpVector;
134  break;
135  }
136 
137  case Spherical:
138  default:
139  {
140  double tmpSum = 0;
141  for (unsigned int k = 0;k < m_NDim;++k)
142  tmpSum += m_TmpVector[k] * m_TmpVector[k];
143 
144  tmpSum = std::sqrt(tmpSum);
145  for (unsigned int k = 0;k < m_NDim;++k)
146  m_TmpVector[k] /= tmpSum;
147 
148  for (unsigned int k = 0;k < m_NbInputs;++k)
149  m_TmpWeights[k] = m_DataWeights[k] * m_PowMemberships[k][i];
150 
151  anima::ComputeSphericalCentroid(m_InputData,m_Centroids[i],m_TmpVector,m_TmpWeights,&m_WorkLogVector,&m_WorkVector);
152 
153  break;
154  }
155  }
156  }
157 }
158 
159 template <class ScalarType>
160 void
162 ::UpdateMemberships()
163 {
164  long double powFactor = 1.0/(m_MValue - 1.0);
165  m_DistancesPointsCentroids.resize(m_NbClass);
166 
167  for (unsigned int i = 0;i < m_NbInputs;++i)
168  {
169  unsigned int minClassIndex = 0;
170  bool nullDistance = false;
171  for (unsigned int j = 0;j < m_NbClass;++j)
172  {
173  m_DistancesPointsCentroids[j] = computeDistance(m_InputData[i],m_Centroids[j]);
174 
175  if (m_DistancesPointsCentroids[j] <= 0)
176  {
177  nullDistance = true;
178  minClassIndex = j;
179  break;
180  }
181  }
182 
183  if (nullDistance)
184  {
185  for (unsigned int j = 0;j < m_NbClass;++j)
186  m_ClassesMembership[i][j] = 0;
187 
188  m_ClassesMembership[i][minClassIndex] = 1.0;
189  continue;
190  }
191 
192  for (unsigned int j = 0;j < m_NbClass;++j)
193  {
194  long double tmpVal = 0;
195 
196  if (m_MValue == 2.0)
197  {
198  for (unsigned int k = 0;k < m_NbClass;++k)
199  tmpVal += m_DistancesPointsCentroids[j] / m_DistancesPointsCentroids[k];
200  }
201  else
202  {
203  for (unsigned int k = 0;k < m_NbClass;++k)
204  tmpVal += std::pow(m_DistancesPointsCentroids[j] / m_DistancesPointsCentroids[k],powFactor);
205  }
206 
207  m_ClassesMembership[i][j] = 1.0 / tmpVal;
208  }
209  }
210 }
211 
212 template <class ScalarType>
213 bool
215 ::endConditionReached(DataHolderType &oldMemberships)
216 {
217  double absDiff = 0;
218 
219  for (unsigned int i = 0;i < m_NbInputs;++i)
220  {
221  for (unsigned int j = 0;j < m_NbClass;++j)
222  {
223  double testValue = std::abs(oldMemberships[i][j] - m_ClassesMembership[i][j]);
224  if (testValue > absDiff)
225  absDiff = testValue;
226  }
227  }
228 
229  if (absDiff > m_RelStopCriterion)
230  return false;
231  else
232  return true;
233 }
234 
235 template <class ScalarType>
236 void
238 ::InitializeCMeansFromData()
239 {
240  m_Centroids.resize(m_NbClass);
241  m_ClassesMembership.resize(m_NbInputs);
242  VectorType tmpVec(m_NbClass,0);
243 
244  if (!m_SpectralClusterInit)
245  {
246  for (unsigned int i = 0;i < m_NbClass;++i)
247  m_Centroids[i] = m_InputData[i];
248 
249  double fixVal = 0.95;
250  for (unsigned int i = 0;i < m_NbInputs;++i)
251  {
252  unsigned int tmp = i % m_NbClass;
253  m_ClassesMembership[i] = tmpVec;
254  m_ClassesMembership[i][tmp] = fixVal;
255  for (unsigned j = 0;j < m_NbClass;++j)
256  {
257  if (j != tmp)
258  m_ClassesMembership[i][j] = (1.0 - fixVal)/(m_NbClass - 1.0);
259  }
260  }
261  }
262  else
263  {
264  m_Centroids[0] = m_InputData[0];
265  std::vector <unsigned int> alreadyIn(m_NbClass,0);
266 
267  for (unsigned int i = 1;i < m_NbClass;++i)
268  {
269  double minCrossProd = std::numeric_limits <double>::max();
270  unsigned int minIndex = 0;
271  for (unsigned int j = 0;j < m_NbInputs;++j)
272  {
273  bool useIt = true;
274  for (unsigned int k = 0;k < i;++k)
275  {
276  if (alreadyIn[k] == j)
277  {
278  useIt = false;
279  break;
280  }
281  }
282 
283  if (useIt)
284  {
285  double maxCrossProd = 0;
286  for (unsigned int l = 0;l < i;++l)
287  {
288  double crossProd = 0;
289  for (unsigned int k = 0;k < m_NDim;++k)
290  crossProd += m_InputData[j][k]*m_Centroids[l][k];
291 
292  if (crossProd > maxCrossProd)
293  maxCrossProd = crossProd;
294  }
295 
296  if (maxCrossProd < minCrossProd)
297  {
298  minCrossProd = maxCrossProd;
299  minIndex = j;
300  }
301  }
302  }
303 
304  m_Centroids[i] = m_InputData[minIndex];
305  alreadyIn[i] = minIndex;
306  }
307 
308  //Centroids initialized, now compute memberships
309  for (unsigned int i = 0;i < m_NbInputs;++i)
310  m_ClassesMembership[i] = tmpVec;
311 
312  this->UpdateMemberships();
313  }
314 }
315 
316 template <class ScalarType>
317 void
319 ::InitializeClassesMemberships(DataHolderType &classM)
320 {
321  if (classM.size() == m_NbInputs)
322  {
323  m_ClassesMembership.resize(m_NbInputs);
324 
325  for (unsigned int i = 0;i < m_NbInputs;++i)
326  m_ClassesMembership[i] = classM[i];
327  }
328 }
329 
330 template <class ScalarType>
331 long double
333 ::computeDistance(VectorType &vec1, VectorType &vec2)
334 {
335  long double resVal = 0;
336 
337  if (m_SphericalAverageType != Euclidean)
338  {
339  long double dotProd = 0;
340  for (unsigned int i = 0;i < m_NDim;++i)
341  dotProd += vec1[i]*vec2[i];
342 
343  if (dotProd > 1)
344  dotProd = 1;
345 
346  resVal = std::abs(std::acos(dotProd));
347  }
348  else
349  {
350  resVal = 0;
351  for (unsigned int i = 0;i < m_NDim;++i)
352  resVal += (vec1[i] - vec2[i])*(vec1[i] - vec2[i]);
353  }
354 
355  return resVal;
356 }
357 
358 } // end namespace anima
std::vector< VectorType > DataHolderType
void ComputeSphericalCentroid(const std::vector< std::vector< ScalarType > > &dataPoints, std::vector< ScalarType > &centroidValue, const std::vector< ScalarType > &initPoint, const std::vector< ScalarType > &weights, std::vector< ScalarType > *workLogVector=0, std::vector< ScalarType > *workVector=0, double tol=1.0e-4)