ANIMA 4.0
animaKMeansFilter.hxx
Go to the documentation of this file.
1 #pragma once
2 #include "animaKMeansFilter.h"
3 
4 namespace anima {
5 
6 template <class DataType, unsigned int PointDimension>
7 KMeansFilter <DataType,PointDimension>::
8 KMeansFilter()
9 {
10  m_ClassesMembership.clear();
11  m_Centroids.clear();
12  m_InputData.clear();
13  m_NumberPerClass.clear();
14 
15  m_NbClass = 0;
16  m_NbInputs = 0;
17  m_MaxIterations = 100;
18 
19  m_Verbose = true;
20 }
21 
22 template <class DataType, unsigned int PointDimension>
24 ~KMeansFilter()
25 {
26 }
27 
28 template <class DataType, unsigned int PointDimension>
29 void
31 SetInputData(DataHolderType &data)
32 {
33  if (data.size() == 0)
34  return;
35 
36  m_InputData = data;
37  m_NbInputs = m_InputData.size();
38 }
39 
40 template <class DataType, unsigned int PointDimension>
41 void
43 Update()
44 {
45  if (m_NbClass > m_NbInputs)
46  throw itk::ExceptionObject(__FILE__, __LINE__,"More classes than inputs...",ITK_LOCATION);
47 
48  this->InitializeKMeansFromData();
49  MembershipType oldMemberships = m_ClassesMembership;
50  unsigned int itncount = 0;
51  bool continueLoop = true;
52 
53  while ((itncount < m_MaxIterations)&&(continueLoop))
54  {
55  itncount++;
56 
57  if (m_Verbose)
58  std::cout << "Iteration " << itncount << "..." << std::endl;
59 
60  this->ComputeCentroids();
61  this->UpdateMemberships();
62 
63  continueLoop = !this->endConditionReached(oldMemberships);
64  oldMemberships = m_ClassesMembership;
65  }
66 }
67 
68 template <class DataType, unsigned int PointDimension>
69 void
71 ComputeCentroids()
72 {
73  for (unsigned int i = 0;i < m_NbClass;++i)
74  m_Centroids[i].Fill(0);
75 
76  for (unsigned int j = 0;j < m_NbInputs;++j)
77  {
78  for (unsigned int k = 0;k < PointDimension;++k)
79  m_Centroids[m_ClassesMembership[j]][k] += m_InputData[j][k];
80  }
81 
82  for (unsigned int j = 0;j < m_NbClass;++j)
83  {
84  if (m_NumberPerClass[j] != 0)
85  {
86  for (unsigned int k = 0;k < PointDimension;++k)
87  m_Centroids[j][k] /= m_NumberPerClass[j];
88  }
89  }
90 }
91 
92 template <class DataType, unsigned int PointDimension>
93 void
95 UpdateMemberships()
96 {
97  std::fill(m_NumberPerClass.begin(),m_NumberPerClass.end(),0);
98  for (unsigned int i = 0;i < m_NbInputs;++i)
99  {
100  unsigned int bestClass = 0;
101  double bestDistance = this->computeDistance(m_InputData[i],m_Centroids[0]);
102 
103  for (unsigned int j = 1;j < m_NbClass;++j)
104  {
105  double tmpDist = this->computeDistance(m_InputData[i],m_Centroids[j]);
106  if (tmpDist < bestDistance)
107  {
108  bestDistance = tmpDist;
109  bestClass = j;
110  }
111  }
112 
113  m_ClassesMembership[i] = bestClass;
114  ++m_NumberPerClass[bestClass];
115  }
116 }
117 
118 template <class DataType, unsigned int PointDimension>
119 bool
121 endConditionReached(MembershipType &oldMemberships)
122 {
123  for (unsigned int i = 0;i < m_NbInputs;++i)
124  {
125  if (oldMemberships[i] != m_ClassesMembership[i])
126  return false;
127  }
128 
129  return true;
130 }
131 
132 template <class DataType, unsigned int PointDimension>
133 void
135 InitializeKMeansFromData()
136 {
137  m_Centroids.clear();
138 
139  for (unsigned int i = 0;i < m_NbClass;++i)
140  m_Centroids.push_back(m_InputData[i]);
141 
142  //Centroids initialized, now compute memberships
143  if (m_ClassesMembership.size() != m_NbInputs)
144  {
145  m_ClassesMembership.resize(m_NbInputs);
146  std::fill(m_ClassesMembership.begin(),m_ClassesMembership.end(),0);
147 
148  this->UpdateMemberships();
149  }
150 }
151 
152 template <class DataType, unsigned int PointDimension>
153 void
155 InitializeClassesMemberships(MembershipType &classM)
156 {
157  if (classM.size() == m_NbInputs)
158  m_ClassesMembership = classM;
159 
160  for (unsigned int i = 0;i < m_NbInputs;++i)
161  m_NumberPerClass[m_ClassesMembership[i]]++;
162 }
163 
164 template <class DataType, unsigned int PointDimension>
165 double
167 computeDistance(VectorType &vec1, VectorType &vec2)
168 {
169  double resVal = 0;
170 
171  for (unsigned int i = 0;i < PointDimension;++i)
172  resVal += (vec1[i] - vec2[i])*(vec1[i] - vec2[i]);
173 
174  return resVal;
175 }
176 
177 } // end namespace anima
std::vector< VectorType > DataHolderType
std::vector< unsigned int > MembershipType