🥑原理:數字水印 | 奇異值分解 SVD 的定義、原理及性質
🥑參考:Python 機器學習筆記:奇異值分解(SVD)算法
正文
對于一個圖像矩陣,我們總可以將其分解為以下形式:

通過選取不同個數 Σ \Sigma Σ 矩陣中的奇異值,就可以實現圖像的壓縮。
如果你沒有了解過原理,那么你當然看不懂這是什么意思😇
如果想要實現圖像的壓縮,那么可以先使用 n u m p y \mathsf{numpy} numpy 庫中的 linalg.svd 函數對圖像矩陣進行分解,然后提取前 k k k 個奇異值以實現 SVD 圖像壓縮效果。下面讓我們看一下代碼。
1?核心代碼
定義 s v d _ c o m p r e s s i o n \mathsf{svd\_compression} svd_compression 函數:
def svd_compression(img, k):res_image = np.zeros_like(img)for i in range(img.shape[2]):U, Sigma, VT = np.linalg.svd(img[:, :, i])res_image[:, :, i] = U[:, :k].dot(np.diag(Sigma[:k])).dot(VT[:k, :])return res_image
參數說明:
- i m g \mathsf{img} img 是待處理的圖像
- k \mathsf{k} k 用于設置選定前 k k k 個奇異值
代碼說明:
初始化 r e s _ i m a g e \mathsf{res\_image} res_image 變量,用于存放處理結果:
res_image = np.zeros_like(img)
循環壓縮每一個通道:
for i in range(img.shape[2]):U, Sigma, VT = np.linalg.svd(img[:, :, i])res_image[:, :, i] = U[:, :k].dot(np.diag(Sigma[:k])).dot(VT[:k, :])
- 參數: i m g . s h a p e [ 2 ] \mathsf{img.shape[2]} img.shape[2] 是圖像的通道個數
- 第一行:對第 i i i 個通道進行 SVD 分解
- 第二行:取前 k k k 個奇異值重新構造圖像
說明:由于 S i g m a \mathsf{Sigma} Sigma 矩陣除對角元素外,其余元素都為 0 \mathsf{0} 0,因此
linalg.svd函數將其處理為一維矩陣返回。在重新構造圖像時,我們需要使用np.diag函數將其還原為對角矩陣。
2?完整代碼
import numpy as np
import cv2
from matplotlib import pyplot as pltimg = cv2.imread('white_bear.jpg')
img = img[:, :, [2, 1, 0]]
print('image shape is ', img.shape)def svd_compression(img, k):res_image = np.zeros_like(img)for i in range(img.shape[2]):U, Sigma, VT = np.linalg.svd(img[:, :, i])res_image[:, :, i] = U[:, :k].dot(np.diag(Sigma[:k])).dot(VT[:k, :])return res_image# 保留前 k 個奇異值
res1 = svd_compression(img, k=300)
res2 = svd_compression(img, k=200)
res3 = svd_compression(img, k=100)
res4 = svd_compression(img, k=50)plt.subplot(1, 5, 1)
plt.title("image", fontsize=12, loc="center")
plt.axis('off')
plt.imshow(img, cmap='gray')plt.subplot(1, 5, 2)
plt.title("image", fontsize=12, loc="center")
plt.axis('off')
plt.imshow(res1, cmap='gray')plt.subplot(1, 5, 3)
plt.title("u", fontsize=12, loc="center")
plt.axis('off')
plt.imshow(res2, cmap='gray')plt.subplot(1, 5, 4)
plt.title("s", fontsize=12, loc="center")
plt.axis('off')
plt.imshow(res3, cmap='gray')plt.subplot(1, 5, 5)
plt.title("v", fontsize=12, loc="center")
plt.axis('off')
plt.imshow(res4, cmap='gray')plt.show()