13 template <
class ScalarType>
14 FuzzyCMeansFilter <ScalarType>
17 m_ClassesMembership.clear();
24 m_MaxIterations = 100;
27 m_SpectralClusterInit =
false;
28 m_SphericalAverageType = Euclidean;
30 m_RelStopCriterion = 1.0e-4;
34 template <
class ScalarType>
43 m_NbInputs = m_InputData.size();
44 m_NDim = m_InputData[0].size();
47 template <
class ScalarType>
52 if (m_NbClass > m_NbInputs)
53 throw itk::ExceptionObject(__FILE__,__LINE__,
"More classes than inputs...",ITK_LOCATION);
55 m_DataWeights.resize(m_NbInputs);
56 std::fill(m_DataWeights.begin(),m_DataWeights.end(),1.0 / m_NbInputs);
58 InitializeCMeansFromData();
61 unsigned int itncount = 0;
62 bool continueLoop =
true;
64 while ((itncount < m_MaxIterations)&&(continueLoop))
69 std::cout <<
"Iteration " << itncount <<
"..." << std::endl;
74 continueLoop = !endConditionReached(oldMemberships);
75 oldMemberships = m_ClassesMembership;
79 template <
class ScalarType>
84 if (m_PowMemberships.size() != m_NbInputs)
85 m_PowMemberships.resize(m_NbInputs);
87 for (
unsigned int i = 0;i < m_NbInputs;++i)
89 if (m_PowMemberships[i].size() != m_NbClass)
90 m_PowMemberships[i].resize(m_NbClass);
92 for (
unsigned int j = 0;j < m_NbClass;++j)
93 m_PowMemberships[i][j] = std::pow(m_ClassesMembership[i][j],m_MValue);
96 if (m_TmpVector.size() != m_NDim)
97 m_TmpVector.resize(m_NDim);
99 if (m_TmpWeights.size() != m_NbInputs)
100 m_TmpWeights.resize(m_NbInputs);
102 for (
unsigned int i = 0;i < m_NbClass;++i)
104 double sumPowMemberShips = 0;
105 std::fill(m_TmpVector.begin(),m_TmpVector.end(),0.0);
107 for (
unsigned int j = 0;j < m_NbInputs;++j)
109 sumPowMemberShips += m_DataWeights[j] * m_PowMemberships[j][i];
110 for (
unsigned int k = 0;k < m_NDim;++k)
111 m_TmpVector[k] += m_DataWeights[j] * m_PowMemberships[j][i] * m_InputData[j][k];
114 for (
unsigned int k = 0;k < m_NDim;++k)
115 m_TmpVector[k] /= sumPowMemberShips;
117 switch (m_SphericalAverageType)
120 m_Centroids[i] = m_TmpVector;
123 case ApproximateSpherical:
126 for (
unsigned int k = 0;k < m_NDim;++k)
127 tmpSum += m_TmpVector[k] * m_TmpVector[k];
129 tmpSum = std::sqrt(tmpSum);
130 for (
unsigned int k = 0;k < m_NDim;++k)
131 m_TmpVector[k] /= tmpSum;
133 m_Centroids[i] = m_TmpVector;
141 for (
unsigned int k = 0;k < m_NDim;++k)
142 tmpSum += m_TmpVector[k] * m_TmpVector[k];
144 tmpSum = std::sqrt(tmpSum);
145 for (
unsigned int k = 0;k < m_NDim;++k)
146 m_TmpVector[k] /= tmpSum;
148 for (
unsigned int k = 0;k < m_NbInputs;++k)
149 m_TmpWeights[k] = m_DataWeights[k] * m_PowMemberships[k][i];
159 template <
class ScalarType>
162 ::UpdateMemberships()
164 long double powFactor = 1.0/(m_MValue - 1.0);
165 m_DistancesPointsCentroids.resize(m_NbClass);
167 for (
unsigned int i = 0;i < m_NbInputs;++i)
169 unsigned int minClassIndex = 0;
170 bool nullDistance =
false;
171 for (
unsigned int j = 0;j < m_NbClass;++j)
173 m_DistancesPointsCentroids[j] = computeDistance(m_InputData[i],m_Centroids[j]);
175 if (m_DistancesPointsCentroids[j] <= 0)
185 for (
unsigned int j = 0;j < m_NbClass;++j)
186 m_ClassesMembership[i][j] = 0;
188 m_ClassesMembership[i][minClassIndex] = 1.0;
192 for (
unsigned int j = 0;j < m_NbClass;++j)
194 long double tmpVal = 0;
198 for (
unsigned int k = 0;k < m_NbClass;++k)
199 tmpVal += m_DistancesPointsCentroids[j] / m_DistancesPointsCentroids[k];
203 for (
unsigned int k = 0;k < m_NbClass;++k)
204 tmpVal += std::pow(m_DistancesPointsCentroids[j] / m_DistancesPointsCentroids[k],powFactor);
207 m_ClassesMembership[i][j] = 1.0 / tmpVal;
212 template <
class ScalarType>
219 for (
unsigned int i = 0;i < m_NbInputs;++i)
221 for (
unsigned int j = 0;j < m_NbClass;++j)
223 double testValue = std::abs(oldMemberships[i][j] - m_ClassesMembership[i][j]);
224 if (testValue > absDiff)
229 if (absDiff > m_RelStopCriterion)
235 template <
class ScalarType>
238 ::InitializeCMeansFromData()
240 m_Centroids.resize(m_NbClass);
241 m_ClassesMembership.resize(m_NbInputs);
244 if (!m_SpectralClusterInit)
246 for (
unsigned int i = 0;i < m_NbClass;++i)
247 m_Centroids[i] = m_InputData[i];
249 double fixVal = 0.95;
250 for (
unsigned int i = 0;i < m_NbInputs;++i)
252 unsigned int tmp = i % m_NbClass;
253 m_ClassesMembership[i] = tmpVec;
254 m_ClassesMembership[i][tmp] = fixVal;
255 for (
unsigned j = 0;j < m_NbClass;++j)
258 m_ClassesMembership[i][j] = (1.0 - fixVal)/(m_NbClass - 1.0);
264 m_Centroids[0] = m_InputData[0];
265 std::vector <unsigned int> alreadyIn(m_NbClass,0);
267 for (
unsigned int i = 1;i < m_NbClass;++i)
269 double minCrossProd = std::numeric_limits <double>::max();
270 unsigned int minIndex = 0;
271 for (
unsigned int j = 0;j < m_NbInputs;++j)
274 for (
unsigned int k = 0;k < i;++k)
276 if (alreadyIn[k] == j)
285 double maxCrossProd = 0;
286 for (
unsigned int l = 0;l < i;++l)
288 double crossProd = 0;
289 for (
unsigned int k = 0;k < m_NDim;++k)
290 crossProd += m_InputData[j][k]*m_Centroids[l][k];
292 if (crossProd > maxCrossProd)
293 maxCrossProd = crossProd;
296 if (maxCrossProd < minCrossProd)
298 minCrossProd = maxCrossProd;
304 m_Centroids[i] = m_InputData[minIndex];
305 alreadyIn[i] = minIndex;
309 for (
unsigned int i = 0;i < m_NbInputs;++i)
310 m_ClassesMembership[i] = tmpVec;
312 this->UpdateMemberships();
316 template <
class ScalarType>
321 if (classM.size() == m_NbInputs)
323 m_ClassesMembership.resize(m_NbInputs);
325 for (
unsigned int i = 0;i < m_NbInputs;++i)
326 m_ClassesMembership[i] = classM[i];
330 template <
class ScalarType>
335 long double resVal = 0;
337 if (m_SphericalAverageType != Euclidean)
339 long double dotProd = 0;
340 for (
unsigned int i = 0;i < m_NDim;++i)
341 dotProd += vec1[i]*vec2[i];
346 resVal = std::abs(std::acos(dotProd));
351 for (
unsigned int i = 0;i < m_NDim;++i)
352 resVal += (vec1[i] - vec2[i])*(vec1[i] - vec2[i]);
std::vector< VectorType > DataHolderType
std::vector< double > VectorType
void ComputeSphericalCentroid(const std::vector< std::vector< ScalarType > > &dataPoints, std::vector< ScalarType > ¢roidValue, const std::vector< ScalarType > &initPoint, const std::vector< ScalarType > &weights, std::vector< ScalarType > *workLogVector=0, std::vector< ScalarType > *workVector=0, double tol=1.0e-4)