對同一像素點值的像素點歸為一類,通過平均值進行取代,從而將圖像進行壓縮并且保證圖像盡可能不失真,關鍵信息仍保留。
from PIL import Image
import numpy as np
from sklearn.cluster import KMeans
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3Ddef restore_image(cb, cluster, shape):row, col, dummy = shapeimage = np.empty((row, col, 3))index = 0for r in range(row):for c in range(col):image[r, c] = cb[cluster[index]]index += 1return imagedef show_scatter(a):N = 10print('原始數據:\n', a)density, edges = np.histogramdd(a, bins=[N,N,N], range=[(0,1), (0,1), (0,1)])density /= density.max()x = y = z = np.arange(N)d = np.meshgrid(x, y, z)fig = plt.figure(1, facecolor='w')ax = fig.add_subplot(111, projection='3d')ax.scatter(d[1], d[0], d[2], c='r', s=100*density, marker='o', depthshade=True)ax.set_xlabel(u'紅色分量')ax.set_ylabel(u'綠色分量')ax.set_zlabel(u'藍色分量')plt.title(u'圖像顏色三維頻數分布', fontsize=20)plt.figure(2, facecolor='w')den = density[density > 0]den = np.sort(den)[::-1]t = np.arange(len(den))plt.plot(t, den, 'r-', t, den, 'go', lw=2)plt.title(u'圖像顏色頻數分布', fontsize=18)plt.grid(True)plt.show()if __name__ == '__main__':matplotlib.rcParams['font.sans-serif'] = [u'SimHei']matplotlib.rcParams['axes.unicode_minus'] = Falsenum_vq = 256 #256個像素,最后想降維到256個維度im = Image.open('Lena.png') # flower2.png(200)/lena.png(50) 讀取圖片image = np.array(im).astype(np.float) / 255 # 將圖像數據轉化為array類型,方便后續的操作image = image[:, :, :3]#所有行、列、前三個維度,因為png有四個屬性,RGBα,alpha為透明度,不需要,只用前三個維度信息即可image_v = image.reshape((-1, 3))#拉伸像素點,行不關心,列為3列===轉換為二維數據,每一列均為一個維度的全部數據,RGB變成了3列,每行為一個像素點model = KMeans(num_vq)#通過Kmeans對每一行進行處理,也就是每個像素點進行處理,對每個像素點進行分類,像素點相似度的歸類;創建聚類對象#每類都有一個中心像素點,將該中心像素點代替這一類,這里手動傳入的是分成256個類別。不管像素點位置,只考慮相似與否是否為一類show_scatter(image_v)#畫圖N = image_v.shape[0] # 圖像像素總數# 選擇足夠多的樣本(如1000個),計算聚類中心idx = np.random.randint(0, N, size=1000)#從圖像中隨機選取1000個像素點image_sample = image_v[idx]model.fit(image_sample)#將這1000個像素點去訓練模型,聚類結果,從1000個像素點中找到最重要的256個像素點作為中心點c = model.predict(image_v) # 將圖像全部的像素點進行預測,看看圖像中的所有像素點離這256個簇哪一個最近,把圖像的所有像素點進行分類print('聚類結果:\n', c)print('聚類中心:\n', model.cluster_centers_)plt.figure(figsize=(15, 8), facecolor='w')plt.subplot(121)plt.axis('off')plt.title(u'原始圖片', fontsize=18)plt.imshow(image)#plt.savefig('1.png')plt.subplot(122)vq_image = restore_image(model.cluster_centers_, c, image.shape)#聚類中心點、聚類結果、模型圖像的形狀 作為參數進行恢復圖像plt.axis('off')plt.title(u'矢量量化后圖片:%d色' % num_vq, fontsize=20)plt.imshow(vq_image)#plt.savefig('2.png')plt.tight_layout()plt.show()