生成模型實戰 | GLOW詳解與實現
- 0. 前言
- 1. 歸一化流模型
- 1.1 歸一化流與變換公式
- 1.2 RealNVP 的通道翻轉
- 2. GLOW 架構
- 2.1 ActNorm
- 2.2 可逆 1×1 卷積
- 2.3 仿射耦合層
- 2.4 多尺度架構
- 3. 使用 PyTorch 實現 GLOW
- 3.1 數據處理
- 3.2 模型構建
- 3.3 模型訓練
0. 前言
GLOW
(Generative Flow
) 是一種基于歸一化流的生成模型,通過在每個流步驟中引入可逆的 1 × 1
卷積層,替代了 RealNVP 中通道翻轉或固定置換的策略,從而使通道重排更具表達力,同時保持雅可比行列式和逆變換的高效計算能力。本文首先回顧歸一化流與 RealNVP
的基本原理,接著剖析 GLOW
的四大核心模塊:ActNorm
、可逆 1×1
卷積、仿射耦合層和多尺度架構,隨后基于 PyTorch
實現 GLOW
模型,并在 CIFAR-10
數據集上進行訓練。
1. 歸一化流模型
1.1 歸一化流與變換公式
在本節中,我們首先簡要回顧歸一化流模型的核心原理,歸一化流利用可逆映射 fff 將簡單分布 pZ(z)p_Z(z)pZ?(z) 轉換到樣本分布 pX?(x)p_X?(x)pX??(x),并通過以下變換公式實現實現精確對數似然計算和采樣:
pX(x)=pZ(f(x))∣?det?(??f(x)?x)∣?p_X(x)=p_Z(f(x))?|\text{?det}? (?\frac {?f(x)}{?x})| ?pX?(x)=pZ?(f(x))?∣?det?(??x?f(x)?)∣?
1.2 RealNVP 的通道翻轉
RealNVP
通過交替使用掩碼耦合層 (masking coupling
) 和按通道翻轉 (reverse channels
) 或固定置換,保證每個通道都能被多次變換。
2. GLOW 架構
2.1 ActNorm
ActNorm
是一種專為流模型設計的通道級歸一化方法,于 GLOW
中首次提出。該層對輸入激活 xxx 執行可學習仿射變換:
y=s⊙x+by=s⊙x+b y=s⊙x+b
其中 s,b∈RCs,b∈\mathbb R^Cs,b∈RC 分別為每個通道的尺度與偏移參數。這些參數在首次前向傳播時通過對一個 minibatch
計算輸出通道的均值 μμμ 和標準差 σσσ 進行數據依賴的初始化,使得初始化后的 yyy 滿足 E[y]=0\mathbb E[y]=0E[y]=0 和 Var[y]=1Var[y]=1Var[y]=1。
與批歸一化不同,ActNorm
僅在初始化時依賴 minibatch
,之后無需維護運行時統計量,從而提升了小批數據和解耦訓練的穩定性。
import torch.nn as nn
import torchdef mean_dim(tensor, dim=None, keepdims=False):if dim is None:return tensor.mean()else:if isinstance(dim, int):dim = [dim]dim = sorted(dim)for d in dim:tensor = tensor.mean(dim=d, keepdim=True)if not keepdims:for i, d in enumerate(dim):tensor.squeeze_(d-i)return tensorclass ActNorm(nn.Module):def __init__(self, num_features, scale=1., return_ldj=False):super(ActNorm, self).__init__()self.register_buffer('is_initialized', torch.zeros(1))self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))self.logs = nn.Parameter(torch.zeros(1, num_features, 1, 1))self.num_features = num_featuresself.scale = float(scale)self.eps = 1e-6self.return_ldj = return_ldjdef initialize_parameters(self, x):if not self.training:returnwith torch.no_grad():bias = -mean_dim(x.clone(), dim=[0, 2, 3], keepdims=True)v = mean_dim((x.clone() + bias) ** 2, dim=[0, 2, 3], keepdims=True)logs = (self.scale / (v.sqrt() + self.eps)).log()self.bias.data.copy_(bias.data)self.logs.data.copy_(logs.data)self.is_initialized += 1.def _center(self, x, reverse=False):if reverse:return x - self.biaselse:return x + self.biasdef _scale(self, x, sldj, reverse=False):logs = self.logsif reverse:x = x * logs.mul(-1).exp()else:x = x * logs.exp()if sldj is not None:ldj = logs.sum() * x.size(2) * x.size(3)if reverse:sldj = sldj - ldjelse:sldj = sldj + ldjreturn x, sldjdef forward(self, x, ldj=None, reverse=False):if not self.is_initialized:self.initialize_parameters(x)if reverse:x, ldj = self._scale(x, ldj, reverse)x = self._center(x, reverse)else:x = self._center(x, reverse)x, ldj = self._scale(x, ldj, reverse)if self.return_ldj:return x, ldjreturn x
2.2 可逆 1×1 卷積
GLOW
中的可逆 1×1
卷積用一個 C×CC×CC×C 的可學習矩陣 WWW 取代了 RealNVP
中的固定通道翻轉或置換操作。在空間位置 (i,j)(i,j)(i,j) 上,其映射可寫為:
yi,j=Wxi,jy_{i,j}=W?x_{i,j} yi,j?=W?xi,j?
對應的對數行列式為:
log??det???∣?y?x∣=H×W×log∣?detW∣\text {log}\ \text{?det}???|\frac{?y}{?x}|=H×W×\text {log}|\text{?det}W| log??det???∣?x?y?∣=H×W×log∣?detW∣
其中 H,WH,WH,W 分別為空間高寬。
為了進一步加速行列式與逆矩陣的計算,通常將 WWW 參數化為 LU
分解形式,即 W=PLUW=PLUW=PLU,只需學習下三角矩陣 LLL 和上三角矩陣 UUU 的非對角元素,行列式則為 ∏iUii∏_iU_{ii}∏i?Uii?。
通過這種可學習的通道重排,模型能夠自動挖掘最優的特征混合方式,從而在生成質量與訓練效率上均取得顯著提升。
import numpy as npclass InvConv(nn.Module):def __init__(self, num_channels):super(InvConv, self).__init__()self.num_channels = num_channels# Initialize with a random orthogonal matrixw_init = np.random.randn(num_channels, num_channels)w_init = np.linalg.qr(w_init)[0].astype(np.float32)self.weight = nn.Parameter(torch.from_numpy(w_init))def forward(self, x, sldj, reverse=False):ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)if reverse:weight = torch.inverse(self.weight.double()).float()sldj = sldj - ldjelse:weight = self.weightsldj = sldj + ldjweight = weight.view(self.num_channels, self.num_channels, 1, 1)z = F.conv2d(x, weight)return z, sldj
2.3 仿射耦合層
仿射耦合層最早在 RealNVP
中提出,是 GLOW
中不可或缺的組成部分。該層將輸入 x∈RC×H×Wx∈\mathbb R^{C×H×W}x∈RC×H×W 沿通道維度劃分為兩部分 (xa,xb)(x_a,x_b)(xa?,xb?),并通過神經網絡生成尺度和平移參數保證了整個變換的可逆性:
(s,t)=NN(xb),ya=s(xb)⊙xa+t(xb),yb=xb(s,t)=NN(x_b),ya=s(x_b)⊙x_a+t(x_b),y_b=x_b (s,t)=NN(xb?),ya=s(xb?)⊙xa?+t(xb?),yb?=xb?
其對數雅可比行列式可高效地計算為:
∑h,w,c∈alog?sc(xb[h,w])∑_{h,w,??c∈a}\text {log}?s_c(x_b[h,w]) h,w,??c∈a∑?log?sc?(xb?[h,w])
僅與輸出尺度參數 sss 的元素相加相關,計算復雜度隨輸入維度線性增長。
import torch.nn.functional as Fclass Coupling(nn.Module):def __init__(self, in_channels, mid_channels):super(Coupling, self).__init__()self.nn = NN(in_channels, mid_channels, 2 * in_channels)self.scale = nn.Parameter(torch.ones(in_channels, 1, 1))def forward(self, x, ldj, reverse=False):x_change, x_id = x.chunk(2, dim=1)st = self.nn(x_id)s, t = st[:, 0::2, ...], st[:, 1::2, ...]s = self.scale * torch.tanh(s)# Scale and translateif reverse:x_change = x_change * s.mul(-1).exp() - tldj = ldj - s.flatten(1).sum(-1)else:x_change = (x_change + t) * s.exp()ldj = ldj + s.flatten(1).sum(-1)x = torch.cat((x_change, x_id), dim=1)return x, ldjclass NN(nn.Module):def __init__(self, in_channels, mid_channels, out_channels,use_act_norm=False):super(NN, self).__init__()norm_fn = ActNorm if use_act_norm else nn.BatchNorm2dself.in_norm = norm_fn(in_channels)self.in_conv = nn.Conv2d(in_channels, mid_channels,kernel_size=3, padding=1, bias=False)nn.init.normal_(self.in_conv.weight, 0., 0.05)self.mid_norm = norm_fn(mid_channels)self.mid_conv = nn.Conv2d(mid_channels, mid_channels,kernel_size=1, padding=0, bias=False)nn.init.normal_(self.mid_conv.weight, 0., 0.05)self.out_norm = norm_fn(mid_channels)self.out_conv = nn.Conv2d(mid_channels, out_channels,kernel_size=3, padding=1, bias=True)nn.init.zeros_(self.out_conv.weight)nn.init.zeros_(self.out_conv.bias)def forward(self, x):x = self.in_norm(x)x = F.relu(x)x = self.in_conv(x)x = self.mid_norm(x)x = F.relu(x)x = self.mid_conv(x)x = self.out_norm(x)x = F.relu(x)x = self.out_conv(x)return x
2.4 多尺度架構
GLOW
延續了 RealNVP
的多尺度架構思想,通過分層的流步驟和因子化操作將中間表示逐級分解。整體模型由 LLL 個尺度 (level
) 組成,每個尺度內部包含 KKK 次完整的流步驟 (step
),每步依次執行 ActNorm
、可逆 1×1
卷積和仿射耦合層。
在每個尺度結束時,先通過 squeeze
操作將特征圖空間大小減少至原像素的四分之一(同時通道數擴大四倍),然后使用 split
操作將部分通道因子化為潛變量 zzz,余下通道繼續進入下一級流。
這種多尺度分解在保持對數似然精度的同時,有效降低了計算與存儲開銷,并在不同尺度上捕捉圖像的全局與局部結構信息。
3. 使用 PyTorch 實現 GLOW
在本節中,使用 PyTorch
實現 GLOW
,并在 CIFAR-10
數據集上進行訓練。
3.1 數據處理
用 torchvision.transforms
對 CIFAR-10
圖像進行處理:
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.ToTensor()])transform_test = transforms.Compose([transforms.ToTensor()
])trainset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform_train)
trainloader = data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform_test)
testloader = data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
3.2 模型構建
基于 ActNorm
、可逆 1×1
卷積、仿射耦合層和多尺度架構實現 GLOW
模型:
class Glow(nn.Module):def __init__(self, num_channels, num_levels, num_steps):super(Glow, self).__init__()# Use bounds to rescale images before converting to logits, not learnedself.register_buffer('bounds', torch.tensor([0.9], dtype=torch.float32))self.flows = _Glow(in_channels=4 * 3, # RGB image after squeezemid_channels=num_channels,num_levels=num_levels,num_steps=num_steps)def forward(self, x, reverse=False):if reverse:sldj = torch.zeros(x.size(0), device=x.device)else:# Expect inputs in [0, 1]if x.min() < 0 or x.max() > 1:raise ValueError('Expected x in [0, 1], got min/max {}/{}'.format(x.min(), x.max()))# De-quantize and convert to logitsx, sldj = self._pre_process(x)x = squeeze(x)x, sldj = self.flows(x, sldj, reverse)x = squeeze(x, reverse=True)return x, sldjdef _pre_process(self, x):y = (x * 255. + torch.rand_like(x)) / 256.y = (2 * y - 1) * self.boundsy = (y + 1) / 2y = y.log() - (1. - y).log()# Save log-determinant of Jacobian of initial transformldj = F.softplus(y) + F.softplus(-y) \- F.softplus((1. - self.bounds).log() - self.bounds.log())sldj = ldj.flatten(1).sum(-1)return y, sldjclass _Glow(nn.Module):def __init__(self, in_channels, mid_channels, num_levels, num_steps):super(_Glow, self).__init__()self.steps = nn.ModuleList([_FlowStep(in_channels=in_channels,mid_channels=mid_channels)for _ in range(num_steps)])if num_levels > 1:self.next = _Glow(in_channels=2 * in_channels,mid_channels=mid_channels,num_levels=num_levels - 1,num_steps=num_steps)else:self.next = Nonedef forward(self, x, sldj, reverse=False):if not reverse:for step in self.steps:x, sldj = step(x, sldj, reverse)if self.next is not None:x = squeeze(x)x, x_split = x.chunk(2, dim=1)x, sldj = self.next(x, sldj, reverse)x = torch.cat((x, x_split), dim=1)x = squeeze(x, reverse=True)if reverse:for step in reversed(self.steps):x, sldj = step(x, sldj, reverse)return x, sldjclass _FlowStep(nn.Module):def __init__(self, in_channels, mid_channels):super(_FlowStep, self).__init__()# Activation normalization, invertible 1x1 convolution, affine couplingself.norm = ActNorm(in_channels, return_ldj=True)self.conv = InvConv(in_channels)self.coup = Coupling(in_channels // 2, mid_channels)def forward(self, x, sldj=None, reverse=False):if reverse:x, sldj = self.coup(x, sldj, reverse)x, sldj = self.conv(x, sldj, reverse)x, sldj = self.norm(x, sldj, reverse)else:x, sldj = self.norm(x, sldj, reverse)x, sldj = self.conv(x, sldj, reverse)x, sldj = self.coup(x, sldj, reverse)return x, sldjdef squeeze(x, reverse=False):b, c, h, w = x.size()if reverse:# Unsqueezex = x.view(b, c // 4, 2, 2, h, w)x = x.permute(0, 1, 4, 2, 5, 3).contiguous()x = x.view(b, c // 4, h * 2, w * 2)else:# Squeezex = x.view(b, c, h // 2, 2, w // 2, 2)x = x.permute(0, 1, 3, 5, 2, 4).contiguous()x = x.view(b, c * 2 * 2, h // 2, w // 2)return x
3.3 模型訓練
實例化模型、損失函數和優化器,并進行訓練:
net = Glow(num_channels=num_channels,num_levels=num_levels,num_steps=num_steps)
net = net.to(device)loss_fn = NLLLoss().to(device)
optimizer = optim.Adam(net.parameters(), lr=lr)
scheduler = sched.LambdaLR(optimizer, lambda s: min(1., s / warm_up))@torch.enable_grad()
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn, max_grad_norm):global global_stepprint('\nEpoch: %d' % epoch)net.train()loss_meter = AverageMeter()with tqdm(total=len(trainloader.dataset)) as progress_bar:for x, _ in trainloader:x = x.to(device)optimizer.zero_grad()z, sldj = net(x, reverse=False)loss = loss_fn(z, sldj)loss_meter.update(loss.item(), x.size(0))loss.backward()if max_grad_norm > 0:clip_grad_norm(optimizer, max_grad_norm)optimizer.step()scheduler.step(global_step)progress_bar.set_postfix(nll=loss_meter.avg,bpd=bits_per_dim(x, loss_meter.avg),lr=optimizer.param_groups[0]['lr'])progress_bar.update(x.size(0))global_step += x.size(0)@torch.no_grad()
def sample(net, batch_size, device):z = torch.randn((batch_size, 3, 32, 32), dtype=torch.float32, device=device)x, _ = net(z, reverse=True)x = torch.sigmoid(x)return x@torch.no_grad()
def test(epoch, net, testloader, device, loss_fn, num_samples):global best_lossnet.eval()loss_meter = AverageMeter()with tqdm(total=len(testloader.dataset)) as progress_bar:for x, _ in testloader:x = x.to(device)z, sldj = net(x, reverse=False)loss = loss_fn(z, sldj)loss_meter.update(loss.item(), x.size(0))progress_bar.set_postfix(nll=loss_meter.avg,bpd=bits_per_dim(x, loss_meter.avg))progress_bar.update(x.size(0))# Save checkpointif loss_meter.avg < best_loss:print('Saving...')state = {'net': net.state_dict(),'test_loss': loss_meter.avg,'epoch': epoch,}os.makedirs('ckpts', exist_ok=True)torch.save(state, 'ckpts/best.pth.tar')best_loss = loss_meter.avg# Save samples and dataimages = sample(net, num_samples, device)os.makedirs('samples', exist_ok=True)images_concat = torchvision.utils.make_grid(images, nrow=int(num_samples ** 0.5), padding=2, pad_value=255)torchvision.utils.save_image(images_concat, 'samples/epoch_{}.png'.format(epoch))start_epoch = 0
for epoch in range(start_epoch, start_epoch + num_epochs):train(epoch, net, trainloader, device, optimizer, scheduler,loss_fn, max_grad_norm)test(epoch, net, testloader, device, loss_fn, num_samples)
在 100
個 epoch
后,模型可生成逼真度較高的 32 × 32
彩色圖像,樣本在多通道細節和整體結構上均有良好效果,下圖展示了訓練過程中,不同 epoch
生成的圖像對比: