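The k-means algorithm is a clustering algorithm used in machine learning. It is a centroid-based partitioning method: the center of each cluster is represented by the mean of all objects in that cluster. The procedure is as follows:
Input:
- k: the number of clusters;
- D: a dataset containing n objects.
Method:
- Randomly select k objects from D as the initial centroids;
- For each data point, compute its distance to every centroid and assign the point to the cluster whose centroid is nearest;
- For each cluster, compute the mean of all its points and use this mean as the new centroid;
- Reassign the data points to clusters according to the new centroids;
- Repeat from the third step (recompute the centroids, then reassign the points) until the centroids no longer change (the new centroids equal the previous ones);
- Output the clustering result.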
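As a database-free illustration of the steps above, here is a minimal sketch. It is not Muyang's implementation: it works on plain `double[][]` rows, picks k different rows at random as the initial centroids, and the names `KMeansSketch`, `dist2` and `kmeans` are hypothetical. The comments follow the order of the steps in the list.

```java
import java.util.Arrays;
import java.util.Random;

public class KMeansSketch {

    // Squared Euclidean distance: enough for finding the nearest centroid,
    // because taking the square root does not change which distance is smallest.
    static double dist2(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            s += d * d;
        }
        return s;
    }

    // Returns, for each row of data, the index of the cluster it ends up in.
    static int[] kmeans(double[][] data, int k, long seed) {
        Random rd = new Random(seed);
        int n = data.length, dim = data[0].length;

        // Step 1: shuffle the row indices (Fisher-Yates) and use the first k rows
        // as the initial centroids, so k different rows of D are chosen.
        int[] idx = new int[n];
        for (int i = 0; i < n; i++) idx[i] = i;
        for (int i = n - 1; i > 0; i--) {
            int j = rd.nextInt(i + 1);
            int t = idx[i]; idx[i] = idx[j]; idx[j] = t;
        }
        double[][] centers = new double[k][];
        for (int c = 0; c < k; c++) centers[c] = data[idx[c]].clone();

        int[] label = new int[n];
        while (true) {
            // Steps 2 and 4: assign every point to its nearest centroid.
            for (int p = 0; p < n; p++) {
                int best = 0;
                for (int c = 1; c < k; c++) {
                    if (dist2(data[p], centers[c]) < dist2(data[p], centers[best])) best = c;
                }
                label[p] = best;
            }
            // Step 3: the new centroid of each cluster is the mean of its points.
            double[][] next = new double[k][dim];
            int[] count = new int[k];
            for (int p = 0; p < n; p++) {
                count[label[p]]++;
                for (int d = 0; d < dim; d++) next[label[p]][d] += data[p][d];
            }
            for (int c = 0; c < k; c++) {
                if (count[c] == 0) { next[c] = centers[c]; continue; } // empty cluster: keep old centroid
                for (int d = 0; d < dim; d++) next[c][d] /= count[c];
            }
            // Step 5: stop as soon as no centroid has changed.
            if (Arrays.deepEquals(next, centers)) return label;
            centers = next;
        }
    }
}
```

For example, `KMeansSketch.kmeans(rows, 3, 42L)` clusters `rows` into three groups with a reproducible seed. Exact equality is used as the stopping rule, matching the "centroids no longer change" condition; once the assignment stops changing, the same points are summed in the same order, so the recomputed means are bit-identical.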
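Muyang's (the author's) k-means implementation consists of 5 classes: DBConnection.java connects to the database, SelectData.java reads the data from the database, Point.java holds the point model, ManagePoint.java provides operations on points, and Kmeans.java contains the core of the algorithm and the main entry point. The full code of each class follows: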
DBConnection.java
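Dataset: the data is the Iris dataset from the UCI Machine Learning Repository. Muyang has already inserted it from the txt file into the database and removed the last column (the flower class). Readers unfamiliar with reading and writing a database can look it up online. If time allows, Muyang will cover loading the txt file into the database in a later post.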
package db;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;

/**
 * Database connection class
 */
public class DBConnection {
    // Connector/J 5.x driver class; Connector/J 8 and later use com.mysql.cj.jdbc.Driver
    public static final String driver = "com.mysql.jdbc.Driver";
    public static final String url = "jdbc:mysql://localhost:3306/mydb";
    public static final String user = "root";
    public static final String pwd = "123";

    public static Connection dBConnection() {
        Connection con = null;
        try {
            // Load the MySQL driver
            Class.forName(driver);
            // Open the database connection
            con = DriverManager.getConnection(url, user, pwd);
        } catch (ClassNotFoundException e) {
            System.out.println("Failed to load the driver");
            e.printStackTrace();
        } catch (SQLException e) {
            System.out.println("Failed to connect to the database");
            e.printStackTrace();
        }
        return con;
    }
}
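The post does not show how the txt file was loaded into the database. As a hypothetical first half of that step, the sketch below only parses lines in the UCI `iris.data` layout (for example `5.1,3.5,1.4,0.2,Iris-setosa`) and drops the class column, as described above. The class name `IrisParser` is illustrative, and inserting the rows into MySQL is left out.

```java
import java.util.ArrayList;
import java.util.List;

public class IrisParser {
    // Parses lines such as "5.1,3.5,1.4,0.2,Iris-setosa" and keeps only the
    // four numeric attributes, dropping the last column (the flower class).
    static List<double[]> parse(List<String> lines) {
        List<double[]> rows = new ArrayList<>();
        for (String line : lines) {
            String trimmed = line.trim();
            if (trimmed.isEmpty()) {
                continue; // skip blank lines, e.g. a trailing empty line in the file
            }
            String[] fields = trimmed.split(",");
            double[] row = new double[4];
            for (int i = 0; i < 4; i++) {
                row[i] = Double.parseDouble(fields[i]);
            }
            rows.add(row);
        }
        return rows;
    }
}
```

With a local copy of the file (the path here is an assumption), `IrisParser.parse(java.nio.file.Files.readAllLines(java.nio.file.Path.of("iris.data")))` should yield the 150 four-column rows.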
The data fields in the database are as follows (150 records in total):
SelectData.java
package dao;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;

import model.Point;
import db.DBConnection;

/**
 * Reads the data from the database.
 */
public class SelectData {
    public static final String SELECT = "select * from iris_Kmeans";

    /**
     * @return pointsList, every row of iris_Kmeans as a Point
     */
    public ArrayList<Point> getPoints() throws SQLException {
        ArrayList<Point> pointsList = new ArrayList<Point>();
        Connection con = DBConnection.dBConnection();
        if (con == null) {
            throw new SQLException("No database connection");
        }
        // try-with-resources closes the ResultSet, the statement and the connection
        try (Connection c = con;
             PreparedStatement pstmt = c.prepareStatement(SELECT);
             ResultSet rs = pstmt.executeQuery()) {
            while (rs.next()) {
                // Columns 2 to 5 hold the four attributes; column 1 is not used
                Point point = new Point();
                point.setX(rs.getDouble(2));
                point.setY(rs.getDouble(3));
                point.setZ(rs.getDouble(4));
                point.setW(rs.getDouble(5));
                pointsList.add(point);
            }
        }
        System.out.println("Dataset: " + pointsList);
        return pointsList;
    }
}
Point.java
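Note that equals and hashCode must be overridden here so that centroids can be compared later (the centroids are used as HashMap keys).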
package model;

import java.util.Objects;

public class Point {
    private double x;
    private double y;
    private double z;
    private double w;

    public Point() {
    }

    public Point(double x, double y, double z, double w) {
        this.x = x;
        this.y = y;
        this.z = z;
        this.w = w;
    }

    public double getX() { return x; }
    public void setX(double x) { this.x = x; }
    public double getY() { return y; }
    public void setY(double y) { this.y = y; }
    public double getZ() { return z; }
    public void setZ(double z) { this.z = z; }
    public double getW() { return w; }
    public void setW(double w) { this.w = w; }

    @Override
    public String toString() {
        return "Point [ x=" + x + ", y=" + y + ", z=" + z + ", w=" + w + "]";
    }

    // Two points are equal when all four coordinates are equal
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Point)) {
            return false;
        }
        Point point = (Point) obj;
        return Double.compare(x, point.x) == 0 && Double.compare(y, point.y) == 0
                && Double.compare(z, point.z) == 0 && Double.compare(w, point.w) == 0;
    }

    // Must agree with equals(): equal points return equal hash codes
    @Override
    public int hashCode() {
        return Objects.hash(x, y, z, w);
    }
}
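To see why the equals/hashCode overrides matter for HashMap keys, here is a small sketch. It is an illustration, not part of Muyang's code: `PointKeyDemo`, `Point4` and `PlainPoint` are hypothetical names, and the record form requires Java 16 or later. A record generates equals and hashCode from its components, while a class without overrides falls back to object identity.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class PointKeyDemo {
    // Record: equals() and hashCode() are generated from all four components.
    record Point4(double x, double y, double z, double w) {}

    // Plain class without overrides: equality means "same object".
    static final class PlainPoint {
        final double x, y, z, w;
        PlainPoint(double x, double y, double z, double w) {
            this.x = x; this.y = y; this.z = z; this.w = w;
        }
    }

    // A separately constructed record with the same coordinates finds the entry.
    static boolean recordKeyFound() {
        Map<Point4, List<Point4>> clusters = new HashMap<>();
        clusters.put(new Point4(5.0, 3.4, 1.5, 0.2), new ArrayList<>());
        return clusters.containsKey(new Point4(5.0, 3.4, 1.5, 0.2));
    }

    // The same lookup with PlainPoint fails, because identity differs.
    static boolean plainKeyFound() {
        Map<PlainPoint, List<PlainPoint>> clusters = new HashMap<>();
        clusters.put(new PlainPoint(5.0, 3.4, 1.5, 0.2), new ArrayList<>());
        return clusters.containsKey(new PlainPoint(5.0, 3.4, 1.5, 0.2));
    }
}
```

This is exactly the lookup the convergence check depends on: a freshly computed centroid must be recognized as the same key as an older one with identical coordinates.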
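ManagePoint.java
This class contains three methods: one computes the Euclidean distance between two points, one checks whether the previous and current centroids are the same, and one computes the updated centroid.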
package util;

import java.util.ArrayList;
import java.util.Map;

import model.Point;

public class ManagePoint {
    /**
     * Computes the Euclidean distance between two points.
     *
     * @param p the first point
     * @param q the second point
     * @return distance
     */
    public double getDistance(Point p, Point q) {
        double dx = p.getX() - q.getX();
        double dy = p.getY() - q.getY();
        double dz = p.getZ() - q.getZ();
        double dw = p.getW() - q.getW();
        return Math.sqrt(dx * dx + dy * dy + dz * dz + dw * dw);
    }

    /**
     * Checks whether the previous and current centroids are the same.
     *
     * @param lastCenterCluster the previous centroids (map keys)
     * @param nowCenterCluster  the current centroids (map keys)
     * @return true if both maps hold exactly the same centroids
     */
    public boolean isEqual(Map<Point, ArrayList<Point>> lastCenterCluster,
            Map<Point, ArrayList<Point>> nowCenterCluster) {
        if (lastCenterCluster == null
                || lastCenterCluster.size() != nowCenterCluster.size()) {
            return false;
        }
        // Every current centroid must also be a previous centroid
        for (Point point : nowCenterCluster.keySet()) {
            if (!lastCenterCluster.containsKey(point)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Computes the new centroid as the mean of the points in a cluster.
     *
     * @param value the map value, i.e. all points in the cluster (must not be empty)
     * @return point
     */
    public Point getNewCenter(ArrayList<Point> value) {
        double sumX = 0, sumY = 0, sumZ = 0, sumW = 0;
        for (Point point : value) {
            sumX += point.getX();
            sumY += point.getY();
            sumZ += point.getZ();
            sumW += point.getW();
        }
        int n = value.size();
        Point point = new Point(sumX / n, sumY / n, sumZ / n, sumW / n);
        System.out.println("New centroid: (" + point.getX() + "," + point.getY() + ","
                + point.getZ() + "," + point.getW() + ")");
        return point;
    }
}
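The convergence test only needs to know whether the two maps have exactly the same set of centroid keys, and every key has to be checked, not just one. One compact way to express this, sketched below with the hypothetical name `CenterCheck`, is `Set.equals` on the key sets: per the `java.util.Set` contract it is true only when both sets have the same size and every element of one is contained in the other.

```java
import java.util.Map;

public class CenterCheck {
    // True when both maps hold exactly the same centroids as keys.
    // A null "last" map means there is no previous iteration yet.
    static <P> boolean sameCenters(Map<P, ?> last, Map<P, ?> now) {
        return last != null && last.keySet().equals(now.keySet());
    }
}
```

Used with the maps in Kmeans.java, `CenterCheck.sameCenters(lastCenterClusterMap, nowCenterClusterMap)` would play the role of `isEqual`; it still relies on Point overriding equals and hashCode.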
Kmeans.java
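Muyang stores the clusters in a HashMap: the key is the cluster's centroid and the value holds all points in that cluster. Note in particular that, to get a reasonably good final clustering, the three randomly chosen initial centroids should come from the index ranges [0, 50), [50, 100) and [100, 150) respectively, that is, one from each block of 50 rows (presumably one per species, since the UCI file lists the three species in order).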
package util;

import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;

import model.Point;
import dao.SelectData;

public class Kmeans {
    public Map<Point, ArrayList<Point>> executeKmeans(int k) {
        // Raw data
        ArrayList<Point> dataList = new ArrayList<Point>();
        // Current centroids and the points in their clusters
        Map<Point, ArrayList<Point>> nowCenterClusterMap = new HashMap<Point, ArrayList<Point>>();
        // Previous centroids and the points in their clusters
        Map<Point, ArrayList<Point>> lastCenterClusterMap = null;
        try {
            dataList = new SelectData().getPoints();
            // Randomly pick k points as the initial centroids, one per index range.
            // These two arrays are hard-coded for k = 3 and 150 rows.
            Random rd = new Random();
            int[] initIndex = { 50, 50, 50 };
            int[] tempIndex = { 0, 50, 100 };
            System.out.println("Initial centroid indices: ");
            for (int i = 0; i < k; i++) {
                int index = rd.nextInt(initIndex[i]) + tempIndex[i];
                System.out.println("#" + (i + 1) + " : " + index);
                nowCenterClusterMap.put(dataList.get(index), new ArrayList<Point>());
            }
            // Print the initial centroids
            System.out.println("Initial centroids: ");
            for (Point point : nowCenterClusterMap.keySet())
                System.out.println("key: " + point);
            ManagePoint managePoint = new ManagePoint();
            while (true) {
                // Add each data point to the value list of its nearest centroid
                for (Point point : dataList) {
                    // Start with the largest double as the shortest distance
                    double shortestDistance = Double.MAX_VALUE;
                    Point key = null;
                    for (Entry<Point, ArrayList<Point>> entry : nowCenterClusterMap.entrySet()) {
                        // Distance between this centroid and the point
                        double distance = managePoint.getDistance(entry.getKey(), point);
                        if (distance < shortestDistance) {
                            shortestDistance = distance;
                            key = entry.getKey();
                        }
                    }
                    nowCenterClusterMap.get(key).add(point);
                }
                // Stop once the new centroids equal the previous ones
                if (managePoint.isEqual(lastCenterClusterMap, nowCenterClusterMap)) {
                    System.out.println("Centroids unchanged.");
                    break;
                }
                // Update the centroids
                lastCenterClusterMap = nowCenterClusterMap;
                nowCenterClusterMap = new HashMap<Point, ArrayList<Point>>();
                System.out.println("------------------------------------------------------------------");
                for (Entry<Point, ArrayList<Point>> entry : lastCenterClusterMap.entrySet()) {
                    nowCenterClusterMap.put(managePoint.getNewCenter(entry.getValue()),
                            new ArrayList<Point>());
                }
            }
        } catch (SQLException e) {
            System.out.println("Database operation failed");
            e.printStackTrace();
        }
        return nowCenterClusterMap;
    }

    public static void main(String[] args) {
        // Split the data into three clusters
        int K = 3;
        Map<Point, ArrayList<Point>> result = new Kmeans().executeKmeans(K);
        // Print the clusters
        System.out.println("=========== Clustering result: ============");
        for (Entry<Point, ArrayList<Point>> entry : result.entrySet()) {
            System.out.println("\n" + "Stable centroid: " + entry.getKey());
            System.out.println("Cluster size: " + entry.getValue().size());
            System.out.println("Points in the cluster: " + entry.getValue());
        }
    }
}
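The `initIndex`/`tempIndex` arrays in Kmeans.java only cover k = 3 with 150 rows. A possible generalization, sketched here under the assumption that the rows are ordered so that each of k equal blocks is a sensible place to draw one centroid from, picks one random index from each of k consecutive blocks of `[0, n)`. The names `StratifiedInit` and `pick` are hypothetical.

```java
import java.util.Random;

public class StratifiedInit {
    // Returns one random index from each of k consecutive, (nearly) equal blocks
    // of [0, n). For n = 150 and k = 3 the blocks are [0, 50), [50, 100) and
    // [100, 150), the same ranges as the hard-coded arrays. Requires n >= k.
    static int[] pick(int n, int k, Random rd) {
        int[] idx = new int[k];
        for (int i = 0; i < k; i++) {
            int start = i * n / k;       // inclusive
            int end = (i + 1) * n / k;   // exclusive
            idx[i] = start + rd.nextInt(end - start);
        }
        return idx;
    }
}
```

Inside `executeKmeans`, the loop over `initIndex` could then become a loop over `StratifiedInit.pick(dataList.size(), k, rd)`. This only helps when the data really is grouped in that order; for unordered data, plain random selection is the usual choice.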
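All of the code above was copied from MyEclipse and has been tested to run. The results are as follows:
In testing, no matter which 3 initial centroids were randomly chosen (one from each of the three index ranges above), the final stable centroids were always the same.
(Discussion is welcome. The code still has shortcomings, and corrections are appreciated. Please credit the source when reposting.)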