Paper: HiDDeN: Hiding Data With Deep Networks
Authors: Jiren Zhu, Russell Kaplan, Justin Johnson, Li Fei-Fei
I. Background
In image data hiding there are two typical application scenarios:
- Steganography
  - Goal: covert communication.
  - Requirement: the intended receiver can decode the message from the image, while an attacker finds it hard to tell which images carry hidden data.
  - Key property: secrecy, i.e. the hidden message is hard to detect.
- Digital watermarking
  - Goal: mainly copyright protection and authentication.
  - Requirement: the watermark can still be recovered correctly even after the image has been compressed, cropped, blurred, or otherwise distorted.
  - Key property: robustness, i.e. the message remains recoverable.
Traditional methods mostly rely on hand-crafted features, for example:
- modifying the least significant bits (LSB) of pixels;
- embedding information in the low-frequency part of the frequency domain.
These methods work well in specific settings but adapt poorly to others. HiDDeN takes a different route: it replaces hand-crafted features with an end-to-end trainable convolutional neural network, gaining both flexibility and robustness.
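For concreteness, here is a minimal sketch of classical LSB embedding (the `lsb_embed`/`lsb_extract` helpers are illustrative, not from the paper). Note how fragile it is: any re-quantization, e.g. one round of JPEG compression, destroys the hidden bits, which is exactly the weakness HiDDeN's noise layers train against.

```python
import numpy as np

def lsb_embed(cover: np.ndarray, bits: np.ndarray) -> np.ndarray:
    """Hide one bit per pixel in the least significant bit (illustrative helper)."""
    flat = cover.astype(np.uint8).flatten()  # flatten() returns a copy
    flat[:len(bits)] = (flat[:len(bits)] & 0xFE) | bits.astype(np.uint8)
    return flat.reshape(cover.shape)

def lsb_extract(stego: np.ndarray, n_bits: int) -> np.ndarray:
    """Recover the first n_bits hidden bits."""
    return stego.flatten()[:n_bits] & 1

# usage: embed 30 random bits into a random 8-bit grayscale image
img = np.random.randint(0, 256, (32, 32), dtype=np.uint8)
bits = np.random.randint(0, 2, 30)
assert (lsb_extract(lsb_embed(img, bits), 30) == bits).all()
```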
II. Core Idea
HiDDeN casts data hiding as a differentiable end-to-end pipeline and lets deep learning discover the embedding strategy automatically:
1. Encoder
- Extracts features of the cover image with several convolutional layers.
- Expands the message vector into a "message volume" with the same spatial dimensions as the image and concatenates it with the features.
- Outputs the encoded image, which should stay visually close to the cover image.
2. Noise layer
Distortions are injected between the encoder and decoder during training to simulate real-world conditions:
- Dropout / Cropout: randomly replace pixels or regions of the encoded image with the cover image.
- Crop: keep one part of the image and discard the rest.
- Gaussian blur: simulates image blurring.
- JPEG-Mask / JPEG-Drop: differentiable approximations of real JPEG compression, so gradients can still propagate during training (a sketch of JPEG-Mask follows at the end of this section).
3. Decoder
- Extracts features from the distorted image with several convolutional layers.
- Aggregates message-related information with global average pooling.
- Produces the message bits through a final linear layer.
- Because of the global pooling, the decoder accepts inputs of varying size and therefore adapts to operations such as cropping.
4. Adversary
- Structurally similar to the decoder.
- Outputs a binary probability: whether the image contains a hidden message.
- Adversarial training makes encoded images harder to detect, improving secrecy.
This design lets the model trade off capacity (how many bits are embedded), secrecy (how hard the embedding is to detect), and robustness (how well it survives distortion).
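To make the JPEG-Mask idea concrete, below is a hedged sketch of such a differentiable layer: a blockwise DCT, a mask that zeroes high-frequency coefficients, and an inverse DCT. The paper keeps channel-specific low-frequency sets; this simplified version keeps a `keep × keep` corner in every channel. It matches the `forward(ico, ien)` interface used by the `NoiseLayer` in Section III (`ico` is accepted but unused).

```python
import math
import torch
import torch.nn as nn

def dct_matrix(n: int = 8) -> torch.Tensor:
    """Orthonormal DCT-II matrix: row u, column x."""
    x = torch.arange(n, dtype=torch.float32)
    u = x.view(-1, 1)
    D = torch.cos(math.pi * (2 * x + 1) * u / (2 * n))
    D[0] *= 1.0 / math.sqrt(2.0)
    return D * math.sqrt(2.0 / n)

class JpegMask(nn.Module):
    """Differentiable JPEG approximation (sketch): block DCT -> zero high
    frequencies -> inverse DCT. H and W must be divisible by `block`
    (true for 32x32 CIFAR-10 images)."""
    def __init__(self, keep: int = 5, block: int = 8):
        super().__init__()
        mask = torch.zeros(block, block)
        mask[:keep, :keep] = 1.0  # keep only the low-frequency corner
        self.register_buffer("D", dct_matrix(block))
        self.register_buffer("mask", mask)
        self.block = block

    def forward(self, ico, ien):
        B, C, H, W = ien.shape
        n = self.block
        # split into non-overlapping n x n blocks: [B,C,h,w,n,n]
        x = ien.view(B, C, H // n, n, W // n, n).permute(0, 1, 2, 4, 3, 5)
        coeff = self.D @ x @ self.D.t()   # per-block 2-D DCT
        coeff = coeff * self.mask         # drop high-frequency coefficients
        x = self.D.t() @ coeff @ self.D   # inverse DCT
        return x.permute(0, 1, 2, 4, 3, 5).reshape(B, C, H, W)
```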
III. Code Implementation
Below is a simplified HiDDeN watermarking experiment on CIFAR-10: it embeds binary watermark messages into CIFAR-10 images, recovers them, and checks robustness under several noise conditions. The complete program is the main.py file listed here.
import argparse, math, os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

# ------------------------------
# Utils
# ------------------------------
def psnr(img1, img2, eps=1e-8):
    # img in [0,1]
    mse = F.mse_loss(img1, img2, reduction='mean').item()
    if mse < eps:
        return 99.0
    return 10.0 * math.log10(1.0 / mse)


def make_gaussian_kernel(ks=5, sigma=1.0, device="cpu"):
    ax = torch.arange(ks, dtype=torch.float32) - (ks - 1) / 2.0
    xx, yy = torch.meshgrid(ax, ax, indexing="ij")
    kernel = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2))
    kernel = kernel / kernel.sum()
    return kernel.to(device)


# ------------------------------
# Noise layers (Identity / PixelDropout / GaussianBlur / Combined)
# ------------------------------
class NoiseLayer(nn.Module):
    def __init__(self, kind="identity", p=0.3, gs_ks=5, gs_sigma=1.0):
        super().__init__()
        self.kind = kind
        self.p = p
        self.gs_ks = gs_ks
        self.gs_sigma = gs_sigma
        self.register_buffer("gs_kernel", torch.empty(0))  # init later per device

    def forward(self, ico, ien):
        if self.kind == "identity":
            return ien
        elif self.kind == "dropout":
            # pixel-wise mix: with prob p use cover pixel, else use encoded pixel
            # (same behavior at train and test time, for robustness eval)
            mask = (torch.rand_like(ien[:, :1, :, :]) < self.p).float()  # [B,1,H,W]
            return mask * ico + (1.0 - mask) * ien
        elif self.kind == "gaussian":
            # depthwise gaussian blur
            if self.gs_kernel.numel() == 0 or self.gs_kernel.device != ien.device:
                k = make_gaussian_kernel(self.gs_ks, self.gs_sigma, ien.device)
                self.gs_kernel = k[None, None, :, :]  # [1,1,ks,ks]
            padding = self.gs_ks // 2
            C = ien.size(1)
            weight = self.gs_kernel.expand(C, 1, self.gs_ks, self.gs_ks)
            return F.conv2d(ien, weight, bias=None, stride=1, padding=padding, groups=C)
        elif self.kind == "combined":
            # randomly pick one distortion per batch
            choice = torch.randint(0, 3, (1,), device=ien.device).item()
            if choice == 0:
                # identity (the original called self.forward here, which recursed
                # back into the "combined" branch instead of falling through)
                return ien
            elif choice == 1:
                # dropout with p uniformly sampled around the target p
                p = float(torch.empty(1).uniform_(self.p * 0.8, min(0.95, self.p * 1.2)))
                mask = (torch.rand_like(ien[:, :1, :, :]) < p).float()
                return mask * ico + (1.0 - mask) * ien
            else:
                # gaussian with slight sigma jitter
                sigma = float(torch.empty(1).uniform_(max(0.5, self.gs_sigma * 0.6),
                                                      self.gs_sigma * 1.5))
                ks = self.gs_ks
                k = make_gaussian_kernel(ks, sigma, ien.device)
                weight = k[None, None, :, :].expand(ien.size(1), 1, ks, ks)
                return F.conv2d(ien, weight, stride=1, padding=ks // 2, groups=ien.size(1))
        else:
            return ien


# ------------------------------
# Encoder / Decoder
# ------------------------------
class ConvBNReLU(nn.Module):
    def __init__(self, c_in, c_out, k=3, s=1, p=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(c_in, c_out, k, s, p, bias=False),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class Encoder(nn.Module):
    """Input: cover image Ico [B,3,H,W], message bits M [B,L] in {0,1}
    Output: encoded image Ien [B,3,H,W] (clamped to [0,1])"""
    def __init__(self, L, img_ch=3, base=64):
        super().__init__()
        self.L = L
        self.stem = nn.Sequential(
            ConvBNReLU(img_ch, base),
            ConvBNReLU(base, base),
            ConvBNReLU(base, base),
            ConvBNReLU(base, base),
        )
        # after concat([feat, msg_volume, image]), reduce to base then to 3
        self.fuse = ConvBNReLU(base + L + img_ch, base)
        self.to_img = nn.Conv2d(base, img_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, ico, m_bits):
        B, C, H, W = ico.size()
        feat = self.stem(ico)
        # replicate bits spatially: [B,L] -> [B,L,H,W]
        m = m_bits.view(B, self.L, 1, 1).float().expand(B, self.L, H, W)
        x = torch.cat([feat, m, ico], dim=1)
        x = self.fuse(x)
        out = self.to_img(x)
        ien = torch.clamp(out, 0.0, 1.0)  # directly predict the encoded image (simple & stable)
        return ien


class Decoder(nn.Module):
    """Input: possibly noised image Ino [B,3,H,W]
    Output: logits over bits [B,L]"""
    def __init__(self, L, img_ch=3, base=64):
        super().__init__()
        self.L = L
        self.body = nn.Sequential(
            ConvBNReLU(img_ch, base),
            ConvBNReLU(base, base),
            ConvBNReLU(base, base),
            ConvBNReLU(base, base),
            ConvBNReLU(base, base),
            ConvBNReLU(base, base),
        )
        self.head = ConvBNReLU(base, L)  # produce L feature maps
        self.fc = nn.Linear(L, L)

    def forward(self, ino):
        x = self.body(ino)
        x = self.head(x)                 # [B,L,H,W]
        x = F.adaptive_avg_pool2d(x, 1)  # [B,L,1,1]
        x = x.view(x.size(0), self.L)    # [B,L]
        logits = self.fc(x)              # [B,L]
        return logits


# ------------------------------
# Training / Evaluation
# ------------------------------
def train(args):
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")

    # Data: CIFAR-10 in [0,1]
    tfm = transforms.ToTensor()
    train_set = datasets.CIFAR10(root="./data", train=True, download=True, transform=tfm)
    test_set = datasets.CIFAR10(root="./data", train=False, download=True, transform=tfm)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
                              num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False,
                             num_workers=2, pin_memory=True)

    encoder = Encoder(L=args.L).to(device)
    decoder = Decoder(L=args.L).to(device)

    if args.noise == "identity":
        noise = NoiseLayer("identity")
    elif args.noise == "dropout":
        noise = NoiseLayer("dropout", p=args.drop_p)
    elif args.noise == "gaussian":
        noise = NoiseLayer("gaussian", gs_ks=args.gs_ks, gs_sigma=args.gs_sigma)
    else:
        noise = NoiseLayer("combined", p=args.drop_p, gs_ks=args.gs_ks, gs_sigma=args.gs_sigma)
    noise = noise.to(device)

    opt = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=args.lr)
    bce = nn.BCEWithLogitsLoss()
    mse = nn.MSELoss()

    os.makedirs(args.out_dir, exist_ok=True)

    for epoch in range(1, args.epochs + 1):
        encoder.train(); decoder.train(); noise.train()
        run_msg_loss = run_img_loss = run_acc = 0.0
        for imgs, _ in train_loader:
            imgs = imgs.to(device)
            B = imgs.size(0)
            # random bits per image
            m_bits = torch.randint(0, 2, (B, args.L), device=device)

            ien = encoder(imgs, m_bits)  # encode
            ino = noise(imgs, ien)       # distort
            logits = decoder(ino)        # decode

            msg_loss = bce(logits, m_bits.float())
            img_loss = mse(ien, imgs)
            loss = msg_loss + args.lambda_img * img_loss

            opt.zero_grad()
            loss.backward()
            opt.step()

            with torch.no_grad():
                pred = (torch.sigmoid(logits) > 0.5).long()
                acc = (pred == m_bits).float().mean().item()
            run_msg_loss += msg_loss.item() * B
            run_img_loss += img_loss.item() * B
            run_acc += acc * B

        n = len(train_loader.dataset)
        print(f"[Epoch {epoch}] msg_loss={run_msg_loss/n:.4f} "
              f"img_loss={run_img_loss/n:.5f} bit_acc={run_acc/n:.4f}")

        # quick eval on the test set + save a visualization
        if epoch % args.eval_every == 0:
            test_bit_acc, test_psnr = evaluate(encoder, decoder, noise, test_loader, device)
            print(f" -> Test bit_acc={test_bit_acc:.4f} PSNR(cover,encoded)={test_psnr:.2f} dB")
            dump_examples(encoder, decoder, noise, test_loader, device, args.out_dir, epoch)

    torch.save({"encoder": encoder.state_dict(), "decoder": decoder.state_dict()},
               os.path.join(args.out_dir, "ckpt.pt"))
    print("Training done. Checkpoints & samples saved to:", args.out_dir)


@torch.no_grad()
def evaluate(encoder, decoder, noise, loader, device):
    encoder.eval(); decoder.eval(); noise.eval()
    acc_sum, psnr_sum, cnt = 0.0, 0.0, 0
    for imgs, _ in loader:
        imgs = imgs.to(device)
        B = imgs.size(0)
        m_bits = torch.randint(0, 2, (B, decoder.L), device=device)
        ien = encoder(imgs, m_bits)
        ino = noise(imgs, ien)
        logits = decoder(ino)
        pred = (torch.sigmoid(logits) > 0.5).long()
        acc = (pred == m_bits).float().mean().item()
        acc_sum += acc * B
        psnr_sum += psnr(imgs, ien) * B
        cnt += B
    return acc_sum / cnt, psnr_sum / cnt


@torch.no_grad()
def dump_examples(encoder, decoder, noise, loader, device, out_dir, epoch):
    encoder.eval(); decoder.eval(); noise.eval()
    imgs, _ = next(iter(loader))
    imgs = imgs.to(device)[:8]
    B = imgs.size(0)
    m_bits = torch.randint(0, 2, (B, decoder.L), device=device)
    ien = encoder(imgs, m_bits)
    ino = noise(imgs, ien)
    # save grids
    utils.save_image(imgs, os.path.join(out_dir, f"epoch{epoch:03d}_cover.png"), nrow=4)
    utils.save_image(ien, os.path.join(out_dir, f"epoch{epoch:03d}_encoded.png"), nrow=4)
    utils.save_image(ino, os.path.join(out_dir, f"epoch{epoch:03d}_noised.png"), nrow=4)
    # print the first sample's bits for sanity
    logits = decoder(ino)
    pred = (torch.sigmoid(logits) > 0.5).long()
    print("[viz] sample#0 GT bits:", m_bits[0].tolist())
    print("[viz] sample#0 PR bits:", pred[0].tolist())


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--L", type=int, default=30, help="number of bits to embed")
    parser.add_argument("--noise", type=str, default="combined",
                        choices=["identity", "dropout", "gaussian", "combined"])
    parser.add_argument("--drop-p", type=float, default=0.3)
    parser.add_argument("--gs-ks", type=int, default=5)
    parser.add_argument("--gs-sigma", type=float, default=1.0)
    parser.add_argument("--lambda-img", type=float, default=0.7, help="weight for image MSE")
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--eval-every", type=int, default=1)
    parser.add_argument("--out-dir", type=str, default="runs_hidden_cifar10")
    parser.add_argument("--cpu", action="store_true")
    args = parser.parse_args()
    train(args)


if __name__ == "__main__":
    main()
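The listing implements the identity / dropout / gaussian / combined noise kinds, but not the Cropout and Crop distortions described in Section II. A minimal Cropout branch, sketched under the same `forward(ico, ien)` interface (square kept region, uniformly placed; the `frac` parameter is illustrative):

```python
import torch

def cropout(ico: torch.Tensor, ien: torch.Tensor, frac: float = 0.5) -> torch.Tensor:
    """Cropout: keep a random square of the encoded image, fill the rest
    with the cover image. `frac` is the side fraction of the kept square."""
    B, C, H, W = ien.shape
    h, w = int(H * frac), int(W * frac)
    top = torch.randint(0, H - h + 1, (1,)).item()
    left = torch.randint(0, W - w + 1, (1,)).item()
    mask = torch.zeros(1, 1, H, W, device=ien.device)
    mask[..., top:top + h, left:left + w] = 1.0
    return mask * ien + (1.0 - mask) * ico
```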
Directory layout:
Hidden
├── data
│ └── cifar-10-python.tar.gz
└── main.py
Run commands:
# fidelity only, no noise (capacity / secrecy oriented)
python main.py --epochs 5 --noise identity

# robust watermark training against pixel-level dropout
python main.py --epochs 5 --noise dropout --drop-p 0.3

# mixed distortions (randomly sampled dropout / gaussian blur), more general
python main.py --epochs 5 --noise combined
IV. Mapping Key Code to the Paper's Equations
This section walks through the paper's core equations and matches each one to its implementation in main.py.
1. Encoder
Paper equation (1):

$$I_{en} = E_\theta(I_{co}, M)$$

- Meaning: the encoder $E_\theta$ takes the cover image $I_{co}$ and the message bits $M$, and outputs the watermarked image $I_{en}$.
- Code:

ien = encoder(imgs, m_bits)  # equation (1)
2. Noise layer
Paper equation (2):

$$I_{no} = N(I_{co}, I_{en})$$

- Meaning: a noise layer $N$ between the encoder and decoder outputs the distorted image $I_{no}$. It receives both the cover and the encoded image so that distortions such as Dropout can mix cover pixels back in.
- Code:

ino = noise(imgs, ien)  # equation (2)
3. Decoder
Paper equation (3):

$$M' = D_\phi(I_{no})$$

- Meaning: the decoder $D_\phi$ receives $I_{no}$ and outputs the predicted message $M'$.
- Code:

logits = decoder(ino)  # equation (3)
4. Loss function
The paper's total loss, equation (4):

$$\mathcal{L} = \lambda_I \cdot \mathcal{L}_I(I_{co}, I_{en}) + \lambda_M \cdot \mathcal{L}_M(M, M') + \lambda_A \cdot \mathcal{L}_A$$

- Code:

msg_loss = bce(logits, m_bits.float())        # L_M (message loss)
img_loss = mse(ien, imgs)                     # L_I (image distortion loss)
loss = msg_loss + args.lambda_img * img_loss  # total loss (λ_A = 0: no adversary here)
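The simplified experiment drops the adversarial term entirely. If one wanted an $\mathcal{L}_A$, a discriminator in the spirit of the paper's adversary could look like the sketch below; it reuses `ConvBNReLU` from the listing, and the structure, `lambda_adv`, and the commented training step are assumptions, not the paper's exact configuration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Adversary(nn.Module):
    """Binary classifier: does this image carry a hidden message? (sketch)"""
    def __init__(self, img_ch=3, base=64):
        super().__init__()
        self.body = nn.Sequential(
            ConvBNReLU(img_ch, base),  # ConvBNReLU as defined in main.py above
            ConvBNReLU(base, base),
            ConvBNReLU(base, base),
        )
        self.fc = nn.Linear(base, 1)

    def forward(self, img):
        x = self.body(img)
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)  # [B, base]
        return self.fc(x)                           # logit: "encoded" vs "cover"

# inside the training loop (sketch):
# d_logits = torch.cat([adv(imgs), adv(ien.detach())])        # train the adversary
# d_labels = torch.cat([torch.zeros(B, 1), torch.ones(B, 1)]).to(device)
# d_loss = bce(d_logits, d_labels)
# adv_loss = bce(adv(ien), torch.zeros(B, 1, device=device))  # encoder fools the adversary
# loss = msg_loss + args.lambda_img * img_loss + lambda_adv * adv_loss
```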
5. Evaluation metrics
- Bit accuracy:

$$\text{Acc} = \frac{1}{L} \sum_{i=1}^{L} \mathbb{1}(M_i = M'_i)$$

Code:

acc = (pred == m_bits).float().mean().item()

- PSNR (images are in [0,1], so the peak value is 1):

$$\text{PSNR}(I_{co}, I_{en}) = 10 \cdot \log_{10}\!\left(\frac{1}{\text{MSE}(I_{co}, I_{en})}\right)$$

Code:

def psnr(img1, img2, eps=1e-8):
    mse = F.mse_loss(img1, img2).item()
    return 10.0 * math.log10(1.0 / mse)
In summary:
- Equation (1) $I_{en} = E_\theta(I_{co}, M)$ → encoder.forward
- Equation (2) $I_{no} = N(I_{co}, I_{en})$ → noise.forward
- Equation (3) $M' = D_\phi(I_{no})$ → decoder.forward
- Equation (4) total loss → msg_loss + λ * img_loss
- Metrics (bit accuracy & PSNR) → acc, psnr()
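Finally, a hedged sketch of how the saved ckpt.pt could be reloaded to embed and decode a message on a single image. The file name and checkpoint keys follow the training script above; the random cover image and everything else are illustrative.

```python
import torch
from main import Encoder, Decoder  # classes from the listing above

device = torch.device("cpu")
ckpt = torch.load("runs_hidden_cifar10/ckpt.pt", map_location=device)

L = 30  # must match the --L used at training time
encoder, decoder = Encoder(L=L).to(device), Decoder(L=L).to(device)
encoder.load_state_dict(ckpt["encoder"]); encoder.eval()
decoder.load_state_dict(ckpt["decoder"]); decoder.eval()

with torch.no_grad():
    cover = torch.rand(1, 3, 32, 32)  # stand-in for a real CIFAR-10 image
    bits = torch.randint(0, 2, (1, L))
    encoded = encoder(cover, bits)                          # equation (1)
    pred = (torch.sigmoid(decoder(encoded)) > 0.5).long()   # equation (3), no noise
    print("bit accuracy:", (pred == bits).float().mean().item())
```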