1. RMSNorm 歸一化層
class RMSNorm(nn.Module):def __init__(self, dim: int, eps: float = 1e-8):super().__init__()self.eps = epsself.weight = nn.Parameter(torch.ones(dim)) # 可學習的縮放參數def _norm(self, x: torch.Tensor):return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)def forward(self, x: torch.Tensor):return self.weight * self._norm(x.float()).type_as(x)
通俗解釋
這個部分相當于 數據清理工序,它的作用是 對數據進行歸一化(標準化),確保數據的數值分布合理,