计算数据集中的元素与各个簇的中心的距离,将它赋给最近的簇,然后重新计算每个簇的平均值,再将元素按离平均值点最近的原则重新分配直到没有出现重新分配
该算法要事先给出k的值,即划分为几个簇。
vector<int> datoclu(data.size(), -1);用这个来标记每个数据在哪个簇中。
#include <fstream>
#include <sstream>
#include <vector>
#include <iostream>using namespace std;struct Point
{double x;double y;
};double distance(const Point& a, const Point& b)
{return sqrt(pow(a.x - b.x, 2) + pow(a.y - b.y, 2));
}vector<int> KMeans(vector<Point>& data, int k, int maxIterations)
{vector<Point> centroids(k);for (int i = 0; i < k; i++) {centroids[i] = data[rand() % data.size()]; //随机选择k个类聚中心。0到(data.size()-1)}vector<int> datoclu(data.size(), -1); //每个数据属于哪个簇bool flag = 0;while (!flag && maxIterations){flag = 1;for (int i = 0; i < data.size(); i++){double minDis = numeric_limits<double>::max();int index = -1;for (int j = 0; j < centroids.size(); j++){double dis = distance(data[i], centroids[j]);if (dis < minDis){minDis = dis;index = j;}}if (datoclu[i] != index) //记录每个数据属于的聚类中心{datoclu[i] = index;flag = 0;}}vector<Point> newClu(k);vector<int> num(k, 0);//计算每个簇平均值点for (int i = 0; i < data.size(); i++){newClu[datoclu[i]].x += data[i].x;newClu[datoclu[i]].y += data[i].y;num[datoclu[i]]++;}for (int i = 0; i < k; i++){newClu[i].x /= num[i];newClu[i].y /= num[i];}centroids = newClu;maxIterations--;}return datoclu;
}
vector<Point> ReadData(string filename)
{vector<Point> data;ifstream file(filename);if (file.is_open()){string line;while (getline(file, line)){istringstream iss(line);double x, y;string token;Point point;if (getline(iss, token, ',') && istringstream(token) >> point.x &&getline(iss, token, ',') && istringstream(token) >> point.y) {data.push_back(point);}}}else{cout << "open fail";}file.close();return data;
}int main()
{vector<Point> dataset = ReadData("data.txt");vector<int> clusters;int k, maxIterations;cout << "输入簇的个数和最大迭代次数"<<endl;cin >> k >> maxIterations;clusters= KMeans(dataset, k, maxIterations);vector <vector<int>> index(k);for (int j = 0; j < k; j++){for (int i = 0; i < clusters.size(); i++){if (clusters[i] == j){index[j].push_back(i);}}}for (int i = 0; i < index.size(); i++){cout << "{";for (int j = 0; j < index[i].size(); j++){cout << index[i][j]+1;if (j != index[i].size() - 1){cout << ",";}}cout << "}";}
}
数据集
1.0, 1.0
2.0, 1.0
1.0, 2.0
2.0, 2.0
4.0, 3.0
5.0, 3.0
4.0, 4.0
5.0,4.0
运行结果