K-中心點(K-medoids,PAM)聚類算法
(1)任意選擇k個對象作為初始的簇中心點
(2)將每個剩餘對象指派給離它最近的中心點所代表的簇
(3)依次選擇一個未被選擇過的中心點 $o_j$,對其執行步驟(4),直到所有的中心點都被選擇過
(4)依次選擇一個未被選擇過的非中心點對象 $o_{random}$,計算用 $o_{random}$ 代替 $o_j$ 的總代價 $S$ 並記錄下來,直到所有非中心點都被選擇過。
(5)如果在所有記錄下來的總代價 $S$(即每個非中心點替代每個中心點後的總代價)中存在小於 0 的值,則找出其中代價最小的一組替換,並用該非中心點替代對應的中心點,形成一個新的 k 個中心點的集合
(6)重複步驟(2)–(5),直到不再發生簇的重新分配,即所有的 $S$ 都大於或等於 0。
代碼(注意:以下代碼以簇內各點的均值作為新的中心點,實現的是 K-means 算法,而非上述的 K-中心點算法)
public class Cluster {private int id;// 標識private Point center;// 中心private List<Point> members = new ArrayList<Point>();// 成員public Cluster(int id, Point center) {this.id = id;this.center = center;}public Cluster(int id, Point center, List<Point> members) {this.id = id;this.center = center;this.members = members;}public void addPoint(Point newPoint) {if (!members.contains(newPoint)){members.add(newPoint);}else{System.out.println("樣本數據點 {"+newPoint.toString()+"} 已經存在!");}}public float getdis() {float cur=0;for (Point point : members) {cur+=point.getDist()*point.getDist();}return cur;}public int getId() {return id;}public Point getCenter() {return center;}public void setCenter(Point center) {this.center = center;}public List<Point> getMembers() {return members;}@Overridepublic String toString() {String toString = "-----------Cluster"+this.getId()+"---------\n";toString+="Mid_Point: "+center+" Points_num: "+members.size();for (Point point : members) {toString+="\n"+point.toString();}return toString+"\n";}
}
public class datahandler {public static List<float[]> readTxt(String fileName){List<float[]> list=new ArrayList<>();try {File filename = new File(fileName); // 讀取input.txt文件InputStreamReader reader = new InputStreamReader(new FileInputStream(filename)); // 建立一個輸入流對象readerBufferedReader br = new BufferedReader(reader);String line = "";line = br.readLine();while (true) {line = br.readLine();if(line==null) break;String[] temp=line.split(",");float[] c=new float[temp.length];for(int i=0;i<temp.length;i++){c[i]=Float.parseFloat(temp[i]);}list.add(c);}} catch (Exception e) {e.printStackTrace();}return list;}public static void writeTxt(String content){try { // 防止文件建立或讀取失敗,用catch捕捉錯誤并打印,也可以throw/* 讀入TXT文件 */File writename = new File("src/k/output.txt"); // 相對路徑,如果沒有則要建立一個新的output。txt文件writename.createNewFile(); // 創建新文件BufferedWriter out = new BufferedWriter(new FileWriter(writename));out.write(content); // \r\n即為換行out.flush(); // 把緩存區內容壓入文件out.close(); // 最后記得關閉文件} catch (Exception e) {e.printStackTrace();}}public static void main(String[] args) {
/* List<float[]> ret = readTxt("src/k/t2.txt");long s=System.currentTimeMillis();KMeansRun kRun = new KMeansRun(5, ret);Set<Cluster> clusterSet = kRun.run();System.out.println("K-means聚類算法運行時間:"+(System.currentTimeMillis()-s)+"ms");System.out.println("單次迭代運行次數:" + kRun.getIterTimes());StringBuilder stringBuilder=new StringBuilder();for (Cluster cluster : clusterSet) {System.out.println("Mid_Point: "+cluster.getCenter()+" clusterId: "+cluster.getId()+" Points_num: "+cluster.getMembers().size());stringBuilder.append(cluster).append("\n");}writeTxt(stringBuilder.toString());*/List<float[]> ret = readTxt("src/k/t2.txt");XYSeries series = new XYSeries("xySeries");for (int x = 1; x < 20; x++) {KMeansRun kRun = new KMeansRun(x, ret);Set<Cluster> clusterSet = kRun.run();float y = 0;for (Cluster cluster : clusterSet){y+=cluster.getdis();}series.add(x, y);}XYSeriesCollection dataset = new XYSeriesCollection();dataset.addSeries(series);JFreeChart chart = ChartFactory.createXYLineChart("sum of the squared errors", // chart title"K", // x axis label"SSE", // y axis labeldataset, // dataPlotOrientation.VERTICAL,false, // include legendfalse, // tooltipsfalse // urls);ChartFrame frame = new ChartFrame("my picture", chart);frame.pack();frame.setVisible(true);frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);}
}
/**
 * Distance functions between {@link Point}s.
 */
public class DistanceCompute {

    /**
     * Returns the Euclidean (L2) distance between two points.
     *
     * @param p1 first point
     * @param p2 second point
     * @return square root of the sum of squared per-coordinate differences
     * @throws IllegalArgumentException if the points have different dimensionality
     */
    public double getEuclideanDis(Point p1, Point p2) {
        final float[] a = p1.getlocalArray();
        final float[] b = p2.getlocalArray();
        if (a.length != b.length) {
            throw new IllegalArgumentException("length of array must be equal!");
        }
        double sumOfSquares = 0;
        int i = 0;
        while (i < a.length) {
            float delta = a[i] - b[i];
            sumOfSquares += Math.pow(delta, 2);
            i++;
        }
        return Math.sqrt(sumOfSquares);
    }
}
import java.util.*;public class KMeansRun {private int kNum; //簇的個數private int iterNum = 200; //迭代次數private int iterMaxTimes = 100000; //單次迭代最大運行次數private int iterRunTimes = 0; //單次迭代實際運行次數private float disDiff = (float) 0.01; //單次迭代終止條件,兩次運行中類中心的距離差private List<float[]> original_data =null; //用于存放,原始數據集private static List<Point> pointList = null; //用于存放,原始數據集所構建的點集private DistanceCompute disC = new DistanceCompute();private int len = 0; //用于記錄每個數據點的維度public KMeansRun(int k, List<float[]> original_data) {this.kNum = k;this.original_data = original_data;this.len = original_data.get(0).length;//檢查規范check();//初始化點集。init();}/*** 檢查規范*/private void check() {if (kNum == 0){throw new IllegalArgumentException("k must be the number > 0");}if (original_data == null){throw new IllegalArgumentException("program can't get real data");}}/*** 初始化數據集,把數組轉化為Point類型。*/private void init() {pointList = new ArrayList<Point>();for (int i = 0, j = original_data.size(); i < j; i++){pointList.add(new Point(i, original_data.get(i)));}}/*** 隨機選取中心點,構建成中心類。*/private Set<Cluster> chooseCenterCluster() {Set<Cluster> clusterSet = new HashSet<Cluster>();Random random = new Random();for (int id = 0; id < kNum; ) {Point point = pointList.get(random.nextInt(pointList.size()));// 用于標記是否已經選擇過該數據。boolean flag =true;for (Cluster cluster : clusterSet) {if (cluster.getCenter().equals(point)) {flag = false;}}// 如果隨機選取的點沒有被選中過,則生成一個clusterif (flag) {Cluster cluster =new Cluster(id, point);clusterSet.add(cluster);id++;}}return clusterSet;}/*** 為每個點分配一個類!*/public void cluster(Set<Cluster> clusterSet){// 計算每個點到K個中心的距離,并且為每個點標記類別號for (Point point : pointList) {float min_dis = Integer.MAX_VALUE;for (Cluster cluster : clusterSet) {float tmp_dis = (float) Math.min(disC.getEuclideanDis(point, cluster.getCenter()), min_dis);if (tmp_dis != min_dis) {min_dis = tmp_dis;point.setClusterId(cluster.getId());point.setDist(min_dis);}}}// 新清除原來所有的類中成員。把所有的點,分別加入每個類別for (Cluster cluster : clusterSet) 
{cluster.getMembers().clear();for (Point point : pointList) {if (point.getClusterid()==cluster.getId()) {cluster.addPoint(point);}}}}/*** 計算每個類的中心位置!*/public boolean calculateCenter(Set<Cluster> clusterSet) {boolean ifNeedIter = false;for (Cluster cluster : clusterSet) {List<Point> point_list = cluster.getMembers();float[] sumAll =new float[len];// 所有點,對應各個維度進行求和for (int i = 0; i < len; i++) {for (int j = 0; j < point_list.size(); j++) {sumAll[i] += point_list.get(j).getlocalArray()[i];}}// 計算平均值for (int i = 0; i < sumAll.length; i++) {sumAll[i] = (float) sumAll[i]/point_list.size();}// 計算兩個新、舊中心的距離,如果任意一個類中心移動的距離大于dis_diff則繼續迭代。if(disC.getEuclideanDis(cluster.getCenter(), new Point(sumAll)) > disDiff){ifNeedIter = true;}// 設置新的類中心位置cluster.setCenter(new Point(sumAll));}return ifNeedIter;}/*** 運行 k-means*/public Set<Cluster> run() {Set<Cluster> clusterSet= chooseCenterCluster();boolean ifNeedIter = true;while (ifNeedIter) {cluster(clusterSet);ifNeedIter = calculateCenter(clusterSet);iterRunTimes ++ ;}return clusterSet;}/*** 返回實際運行次數*/public int getIterTimes() {return iterRunTimes;}}
/**
 * A data point: an id, its coordinate vector, and bookkeeping for its assigned cluster.
 *
 * <p>Equality and hashing depend only on the coordinates. Coordinates are compared the way
 * {@link Float#compare} does: NaN equals NaN, and 0.0 differs from -0.0.
 */
public class Point {
    private float[] localArray; // coordinates (stored by reference, not copied)
    private int id; // sample id; -1 for points that are not input samples (e.g. computed centers)
    private int clusterId; // id of the cluster this point is assigned to
    private float dist; // distance to the center of the assigned cluster

    public Point(int id, float[] localArray) {
        this.id = id;
        this.localArray = localArray;
    }

    /** Creates a point with id -1, meaning it is not one of the input samples. */
    public Point(float[] localArray) {
        this.id = -1;
        this.localArray = localArray;
    }

    public float[] getlocalArray() {
        return localArray;
    }

    public int getId() {
        return id;
    }

    public void setClusterId(int clusterId) {
        this.clusterId = clusterId;
    }

    public int getClusterid() {
        return clusterId;
    }

    public float getDist() {
        return dist;
    }

    public void setDist(float dist) {
        this.dist = dist;
    }

    /** Formats as {@code Point_id=<id> [c0 c1 ...] clusterId: <clusterId>}. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("Point_id=").append(id).append(" [");
        for (int i = 0; i < localArray.length; i++) {
            if (i > 0) {
                sb.append(' ');
            }
            sb.append(localArray[i]);
        }
        return sb.append("] clusterId: ").append(clusterId).toString();
    }

    /** Two points are equal when their coordinate arrays are element-wise equal. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return java.util.Arrays.equals(localArray, ((Point) obj).localArray);
    }

    /**
     * Hashes all coordinates, consistent with {@link #equals}. Also safe for zero-length
     * vectors.
     */
    @Override
    public int hashCode() {
        return java.util.Arrays.hashCode(localArray);
    }
}