17.整體代碼講解

從入門AI到手寫Transformer-17.整體代碼講解

  • 17.整體代碼講解
  • 代碼

整理自視頻「老袁不說話」。

17.整體代碼講解

代碼

import collections
import hashlib
import math
import os
import tarfile
import time
import zipfile

import numpy as np
import requests
import torch
from IPython import display
from matplotlib import pyplot as plt
from matplotlib_inline import backend_inline
from torch import nn
from torch.utils import data

# Registry of downloadable datasets: name -> (URL, expected SHA-1 of the file).
DATA_HUB = dict()
DATA_URL = "http://d2l-data.s3-accelerate.amazonaws.com/"
DATA_HUB["fra-eng"] = (
    DATA_URL + "fra-eng.zip",
    "94646ad1522d915e7b0f9296181140edcf86a4f5",
)


def try_gpu(i=0):
    """Return the device ``cuda:i`` if that GPU exists, otherwise the CPU."""
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f"cuda:{i}")
    return torch.device("cpu")


def bleu(pred_seq, label_seq, k):
    """Compute the BLEU score of ``pred_seq`` against ``label_seq``.

    Both arguments are space-separated token strings; ``k`` is the largest
    n-gram order considered. The precision of order ``n`` is raised to the
    power ``0.5 ** n`` and the product is multiplied by the brevity penalty
    ``exp(min(0, 1 - len_label / len_pred))``.

    Returns 0.0 when the prediction is too short to contain any n-gram of
    some order ``n <= k`` (previously this divided by zero).
    """
    pred_tokens, label_tokens = pred_seq.split(" "), label_seq.split(" ")
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_pred))
    for n in range(1, k + 1):
        num_pred_ngrams = len_pred - n + 1
        if num_pred_ngrams <= 0:
            # No n-grams of this order in the prediction: precision is 0.
            return 0.0
        num_matches, label_subs = 0, collections.defaultdict(int)
        for i in range(len_label - n + 1):
            label_subs[" ".join(label_tokens[i : i + n])] += 1
        for i in range(num_pred_ngrams):
            ngram = " ".join(pred_tokens[i : i + n])
            # Each label n-gram can be matched at most as often as it occurs.
            if label_subs[ngram] > 0:
                num_matches += 1
                label_subs[ngram] -= 1
        score *= math.pow(num_matches / num_pred_ngrams, math.pow(0.5, n))
    return score


def count_corpus(tokens):  # @save
    """Count token frequencies.

    ``tokens`` is either a flat list of tokens (e.g. ``["a", "b", "a"]``) or
    a list of token lists, one per line; the 2-D form is flattened first.
    Returns a ``collections.Counter``.
    """
    if len(tokens) == 0 or isinstance(tokens[0], list):
        # Flatten an empty or 2-D token list into a single list.
        tokens = [token for line in tokens for token in line]
    return collections.Counter(tokens)


def download(name, cache_dir=os.path.join(".", "./data")):
    """Download the ``DATA_HUB`` entry ``name`` into ``cache_dir``.

    A cached file is reused when its SHA-1 digest matches the registered
    one. Returns the local file path.
    """
    assert name in DATA_HUB, f"{name} 不存在于{DATA_HUB}"
    url, sha1_hash = DATA_HUB[name]
    os.makedirs(cache_dir, exist_ok=True)
    fname = os.path.join(cache_dir, url.split("/")[-1])
    if os.path.exists(fname):
        sha1 = hashlib.sha1()
        with open(fname, "rb") as f:
            while True:
                chunk = f.read(1048576)
                if not chunk:
                    break
                sha1.update(chunk)
        if sha1.hexdigest() == sha1_hash:
            return fname  # cache hit
    print(f"正在從{url}下載{fname}...")
    with requests.get(url, stream=True, verify=True) as r:
        # Fail loudly rather than caching an HTTP error page as the archive.
        r.raise_for_status()
        with open(fname, "wb") as f:
            # Stream to disk instead of buffering the whole body in memory.
            for chunk in r.iter_content(chunk_size=1048576):
                f.write(chunk)
    return fname


def download_extract(name, folder=None):  # @save
    """Download and extract a zip/tar archive registered in ``DATA_HUB``.

    Returns ``<archive dir>/folder`` when ``folder`` is given, otherwise the
    archive path with its extension stripped.

    Raises:
        ValueError: if the archive is neither ``.zip`` nor ``.tar``/``.gz``.
    """
    fname = download(name)
    base_dir = os.path.dirname(fname)
    data_dir, ext = os.path.splitext(fname)
    if ext == ".zip":
        fp = zipfile.ZipFile(fname, "r")
    elif ext in (".tar", ".gz"):
        fp = tarfile.open(fname, "r")
    else:
        raise ValueError("只有zip/tar文件可以被解壓縮")
    # NOTE(review): extractall trusts member paths inside the archive; this
    # is only reasonable because the archive comes from a fixed DATA_HUB URL.
    with fp:
        fp.extractall(base_dir)
    return os.path.join(base_dir, folder) if folder else data_dir


def read_data_nmt():
    """Load the raw English-French dataset (the text of ``fra.txt``)."""
    data_dir = download_extract("fra-eng")
    with open(os.path.join(data_dir, "fra.txt"), "r", encoding="utf-8") as f:
        return f.read()


def masked_softmax(X, valid_lens):
    """Softmax over the last axis of ``X``, masking positions past ``valid_lens``.

    ``X`` is a 3-D tensor. ``valid_lens`` is ``None`` (no masking), a 1-D
    tensor with one length per batch element, or a 2-D tensor with one
    length per (batch element, row of the second axis).
    """
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    shape = X.shape
    if valid_lens.dim() == 1:
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)
    # Masked positions get a very large negative value so softmax gives ~0.
    X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
    return nn.functional.softmax(X.reshape(shape), dim=-1)


def sequence_mask(X, valid_len, value=0):
    """Set entries of ``X`` at positions ``>= valid_len`` to ``value``.

    ``X`` is indexed as (row, position, ...) and ``valid_len`` holds one
    length per row. ``X`` is modified in place and also returned.
    """
    maxlen = X.size(1)
    mask = (
        torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :]
        < valid_len[:, None]
    )
    X[~mask] = value
    return X


def preprocess_nmt(text):
    """Normalize the English-French text.

    Replaces non-breaking spaces with plain spaces, lower-cases everything,
    and inserts a space before ``, . ! ?`` when not already preceded by one.
    """
    punctuation = set(",.!?")

    def no_space(char, prev_char):
        return char in punctuation and prev_char != " "

    text = text.replace("\u202f", " ").replace("\xa0", " ").lower()
    out = [
        " " + char if i > 0 and no_space(char, text[i - 1]) else char
        for i, char in enumerate(text)
    ]
    return "".join(out)


def tokenize_nmt(text, num_examples=None):
    """Split preprocessed text into source and target token lists.

    Each line holds a source sentence and a target sentence separated by a
    single tab; lines without exactly one tab are skipped. When
    ``num_examples`` is truthy, only the first ``num_examples`` lines are
    read (the previous ``>`` test read one line too many).
    """
    source, target = [], []
    for i, line in enumerate(text.split("\n")):
        if num_examples and i >= num_examples:
            break
        parts = line.split("\t")
        if len(parts) == 2:
            source.append(parts[0].split(" "))
            target.append(parts[1].split(" "))
    return source, target


def grad_clipping(net, theta):  # @save
    """Rescale gradients in place so their global L2 norm is at most ``theta``.

    ``net`` is an ``nn.Module`` or any object exposing a ``params`` list.
    """
    if isinstance(net, nn.Module):
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        params = net.params
    # Global norm over all parameter gradients (a 0-d tensor).
    norm = torch.sqrt(sum(torch.sum(p.grad**2) for p in params))
    if norm > theta:
        for param in params:
            param.grad[:] *= theta / norm


def truncate_pad(line, num_steps, padding_token):
    """Truncate or pad ``line`` to exactly ``num_steps`` tokens."""
    if len(line) > num_steps:
        return line[:num_steps]
    return line + [padding_token] * (num_steps - len(line))


def build_array_nmt(lines, vocab, num_steps):
    """Convert token lines to an index tensor plus per-line valid lengths.

    Each line gets ``<eos>`` appended and is then truncated/padded to
    ``num_steps``; the valid length counts the non-``<pad>`` entries.
    """
    lines = [vocab[l] for l in lines]
    lines = [l + [vocab["<eos>"]] for l in lines]
    array = torch.tensor([truncate_pad(l, num_steps, vocab["<pad>"]) for l in lines])
    valid_len = (array != vocab["<pad>"]).type(torch.int32).sum(1)
    return array, valid_len


def load_array(data_arrays, batch_size, is_train=True):  # @save
    """Build a PyTorch ``DataLoader`` over tensors (shuffled when training)."""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)


def load_data_nmt(batch_size, num_steps, num_examples=600):
    """Return the translation data iterator and the source/target vocabularies."""
    text = preprocess_nmt(read_data_nmt())
    source, target = tokenize_nmt(text, num_examples)
    src_vocab = Vocab(source, min_freq=2, reserved_tokens=["<pad>", "<bos>", "<eos>"])
    tgt_vocab = Vocab(target, min_freq=2, reserved_tokens=["<pad>", "<bos>", "<eos>"])
    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
    data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)
    data_iter = load_array(data_arrays, batch_size)
    return data_iter, src_vocab, tgt_vocab


def transpose_qkv(X, num_heads):
    """Reshape for parallel multi-head attention.

    (batch_size, n, num_hiddens)
    -> (batch_size * num_heads, n, num_hiddens / num_heads)
    where ``n`` is the number of queries or key-value pairs.
    """
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    X = X.permute(0, 2, 1, 3)  # (batch_size, num_heads, n, head_dim)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    """Invert ``transpose_qkv``: back to (batch_size, n, num_hiddens)."""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)


def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
    """Train a sequence-to-sequence model with teacher forcing.

    Linear layers (and GRU weights, if present) are Xavier-initialized. The
    model is optimized with Adam on ``MaskedSoftmaxCELoss``, gradients are
    clipped to norm 1, and the loss curve is plotted every 10 epochs.
    """

    def xavier_init_weights(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])

    net.apply(xavier_init_weights)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()
    net.train()
    animator = Animator(xlabel="epoch", ylabel="loss", xlim=[10, num_epochs])
    for epoch in range(num_epochs):
        timer = Timer()
        metric = Accumulator(2)  # sum of training loss, number of tokens
        for batch in data_iter:
            optimizer.zero_grad()
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            bos = torch.tensor(
                [tgt_vocab["<bos>"]] * Y.shape[0], device=device
            ).reshape(-1, 1)
            # Teacher forcing: decoder input is <bos> followed by Y shifted right.
            dec_input = torch.cat([bos, Y[:, :-1]], 1)
            Y_hat, _ = net(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()  # backpropagate the scalar loss
            grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), num_tokens)
        if (epoch + 1) % 10 == 0:
            animator.add(epoch + 1, (metric[0] / metric[1],))
    print(
        f"loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} "
        f"tokens/sec on {str(device)}"
    )


def predict_seq2seq(
    net,
    src_sentence,
    src_vocab,
    tgt_vocab,
    num_steps,
    device,
    save_attention_weights=False,
):
    """Greedily translate ``src_sentence`` with an encoder-decoder model.

    Returns ``(translation, attention_weight_seq)``. ``translation`` is the
    space-joined predicted tokens, excluding ``<eos>``. When
    ``save_attention_weights`` is true, ``attention_weight_seq`` holds
    ``net.decoder.attention_weights`` after each decoding step.
    """
    net.to(device)
    # Evaluation mode disables dropout; DecoderBlock also skips its causal mask.
    net.eval()
    src_tokens = src_vocab[src_sentence.lower().split(" ")] + [src_vocab["<eos>"]]
    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
    src_tokens = truncate_pad(src_tokens, num_steps, src_vocab["<pad>"])
    # Add a batch axis of size 1.
    enc_X = torch.unsqueeze(
        torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0
    )
    # The encoder runs once; the decoder is then stepped one token at a time.
    enc_outputs = net.encoder(enc_X, enc_valid_len)
    dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)
    dec_X = torch.unsqueeze(
        torch.tensor([tgt_vocab["<bos>"]], dtype=torch.long, device=device), dim=0
    )
    output_seq, attention_weight_seq = [], []
    for _ in range(num_steps):
        # Y: (1, 1, len(tgt_vocab)) logits for the next token.
        Y, dec_state = net.decoder(dec_X, dec_state)
        # Greedy decoding: the most likely token becomes the next input.
        dec_X = Y.argmax(dim=2)
        pred = dec_X.squeeze(dim=0).type(torch.int32).item()
        if save_attention_weights:
            attention_weight_seq.append(net.decoder.attention_weights)
        # Generation stops once the end-of-sequence token is predicted.
        if pred == tgt_vocab["<eos>"]:
            break
        output_seq.append(pred)
    return " ".join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq


def use_svg_display():  # @save
    """Render matplotlib figures as SVG in Jupyter."""
    backend_inline.set_matplotlib_formats("svg")


def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
    """Configure labels, scales, limits, legend and grid of a matplotlib axis."""
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.set_xscale(xscale)
    axes.set_yscale(yscale)
    axes.set_xlim(xlim)
    axes.set_ylim(ylim)
    if legend:
        axes.legend(legend)
    axes.grid()


def set_figsize(figsize=(3.5, 2.5)):  # @save
    """Set the default matplotlib figure size."""
    use_svg_display()
    plt.rcParams["figure.figsize"] = figsize


def dropout_layer(X, dropout):
    """Inverted dropout: zero each element of ``X`` with probability ``dropout``.

    Surviving elements are scaled by ``1 / (1 - dropout)``. The random mask
    is drawn on ``X``'s device (it used to be drawn on the CPU, which fails
    for GPU tensors).
    """
    assert 0 <= dropout <= 1
    if dropout == 1:
        return torch.zeros_like(X)  # every element is dropped
    if dropout == 0:
        return X  # every element is kept
    mask = (torch.rand(X.shape, device=X.device) > dropout).float()
    return mask * X / (1.0 - dropout)


class Accumulator:  # @save
    """Accumulate running sums over ``n`` variables."""

    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class Timer:  # @save
    """Record multiple running times."""

    def __init__(self):
        self.times = []
        self.start()

    def start(self):
        """Start the timer."""
        self.tik = time.time()

    def stop(self):
        """Stop the timer, record the elapsed time and return it."""
        self.times.append(time.time() - self.tik)
        return self.times[-1]

    def avg(self):
        """Return the average recorded time."""
        return sum(self.times) / len(self.times)

    def sum(self):
        """Return the total recorded time."""
        return sum(self.times)

    def cumsum(self):
        """Return the cumulative recorded times."""
        return np.array(self.times).cumsum().tolist()


class Animator:
    """Plot data incrementally (used for the training-loss curve)."""

    def __init__(
        self,
        xlabel=None,
        ylabel=None,
        legend=None,
        xlim=None,
        ylim=None,
        xscale="linear",
        yscale="linear",
        fmts=("-", "m--", "g-.", "r:"),
        nrows=1,
        ncols=1,
        figsize=(3.5, 2.5),
    ):
        if legend is None:
            legend = []
        use_svg_display()
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [self.axes]
        # Capture the axis settings so they can be re-applied after each clear.
        self.config_axes = lambda: set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend
        )
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        """Append one point per curve and redraw the figure."""
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        # Each call emits a new figure; in a notebook, calling
        # display.clear_output(wait=True) here would replace the previous one.
        display.display(self.fig)
        plt.draw()
        plt.pause(0.001)


class Vocab:
    """Text vocabulary mapping tokens to integer indices and back.

    Index 0 is ``<unk>``, followed by ``reserved_tokens``, then the remaining
    tokens in descending frequency order. Counting stops at the first token
    whose frequency is below ``min_freq``.
    """

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        if tokens is None:
            tokens = []
        if reserved_tokens is None:
            reserved_tokens = []
        counter = count_corpus(tokens)
        # (token, frequency) pairs, most frequent first.
        self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        self.idx_to_token = ["<unk>"] + reserved_tokens
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}
        for token, freq in self._token_freqs:
            if freq < min_freq:
                break  # the list is sorted, so every later token is rarer
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        """Map a token, or a list/tuple of tokens, to indices (unknown -> 0)."""
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        """Map an index, or a list/tuple of indices, back to tokens."""
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]

    @property
    def unk(self):
        """Index of the unknown token (always 0)."""
        return 0

    @property
    def token_freqs(self):
        """All (token, frequency) pairs, including those below ``min_freq``."""
        return self._token_freqs


class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """Softmax cross-entropy that zeroes the loss past each valid length.

    Shapes: ``pred`` (batch_size, num_steps, vocab_size), ``label``
    (batch_size, num_steps), ``valid_len`` (batch_size,). Returns one value
    per sequence: the mean over all ``num_steps`` positions, where masked
    positions contribute 0.
    """

    def forward(self, pred, label, valid_len):
        weights = torch.ones_like(label)
        weights = sequence_mask(weights, valid_len)
        self.reduction = "none"
        # CrossEntropyLoss expects the class dimension second.
        unweighted_loss = super(MaskedSoftmaxCELoss, self).forward(
            pred.permute(0, 2, 1), label
        )
        weighted_loss = (unweighted_loss * weights).mean(dim=1)
        return weighted_loss


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention."""

    def __init__(
        self,
        key_size,
        query_size,
        value_size,
        num_hiddens,
        num_heads,
        dropout,
        bias=False,
        **kwargs,
    ):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        """Attend ``queries`` over ``keys``/``values``.

        Inputs are (batch_size, n, feature) and ``valid_lens`` is
        (batch_size,) or (batch_size, num_queries). Returns
        (batch_size, num_queries, num_hiddens).
        """
        # After projection and transpose_qkv:
        # (batch_size * num_heads, n, num_hiddens / num_heads).
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            # Repeat each batch element's lengths num_heads times along axis 0
            # to line up with the head-expanded batch.
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0
            )
        # (batch_size * num_heads, num_queries, num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens)
        # (batch_size, num_queries, num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding followed by dropout."""

    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # Precompute a (1, max_len, num_hiddens) table P; the sin/cos slice
        # assignment below requires num_hiddens to be even.
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
            10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens
        )
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, : X.shape[1], :].to(X.device)
        return self.dropout(X)


class PositionWiseFFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear."""

    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))


class AddNorm(nn.Module):
    """Residual connection followed by layer normalization: LN(X + dropout(Y))."""

    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        return self.ln(self.dropout(Y) + X)


class Encoder(nn.Module):
    """Base encoder interface of the encoder-decoder architecture."""

    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError


class Decoder(nn.Module):
    """Base decoder interface of the encoder-decoder architecture."""

    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError


class EncoderDecoder(nn.Module):
    """Base class chaining an encoder and a decoder."""

    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)


class DotProductAttention(nn.Module):
    """Scaled dot-product attention.

    Shapes: queries (batch_size, num_queries, d), keys
    (batch_size, num_kv, d), values (batch_size, num_kv, value_dim),
    valid_lens (batch_size,) or (batch_size, num_queries).
    """

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Swap the last two axes of keys to get (batch, num_queries, num_kv).
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)


class AttentionDecoder(Decoder):
    """Base interface for decoders that expose attention weights."""

    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    @property
    def attention_weights(self):
        raise NotImplementedError


class EncoderBlock(nn.Module):
    """Transformer encoder block: self-attention + FFN, each with AddNorm."""

    def __init__(
        self,
        key_size,
        query_size,
        value_size,
        num_hiddens,
        norm_shape,
        ffn_num_input,
        ffn_num_hiddens,
        num_heads,
        dropout,
        use_bias=False,
        **kwargs,
    ):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias
        )
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))


class DecoderBlock(nn.Module):
    """The ``i``-th block of the Transformer decoder.

    Self-attention, then encoder-decoder attention, then a position-wise FFN,
    each followed by ``AddNorm`` (which applies the dropout).
    """

    def __init__(
        self,
        key_size,
        query_size,
        value_size,
        num_hiddens,
        norm_shape,
        ffn_num_input,
        ffn_num_hiddens,
        num_heads,
        dropout,
        i,
        **kwargs,
    ):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i  # index of this block inside the decoder
        self.attention1 = MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout
        )
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout
        )
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        """Run the block; returns ``(output, state)``.

        ``state`` is ``[enc_outputs, enc_valid_lens, cache]``, where
        ``cache[self.i]`` holds all decoder inputs this block has seen so far
        (``None`` before the first call).
        """
        enc_outputs, enc_valid_lens = state[0], state[1]
        # Training processes the whole target sequence at once, so the cache
        # starts as None. Prediction feeds one token per call, and the cache
        # accumulates earlier positions to serve as self-attention keys/values.
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), dim=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Causal mask, (batch_size, num_steps): each row is
            # [1, 2, ..., num_steps], so query t attends to positions 0..t only.
            dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(
                batch_size, 1
            )
        else:
            dec_valid_lens = None
        # Self-attention.
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Encoder-decoder attention: queries come from the decoder (Y), while
        # keys/values are the encoder outputs masked by the source lengths.
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state


class TransformerEncoder(Encoder):
    """Transformer encoder: embedding + positional encoding + EncoderBlocks."""

    def __init__(
        self,
        vocab_size,
        key_size,
        query_size,
        value_size,
        num_hiddens,
        norm_shape,
        ffn_num_input,
        ffn_num_hiddens,
        num_heads,
        num_layers,
        dropout,
        use_bias=False,
        **kwargs,
    ):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(
                "block" + str(i),
                EncoderBlock(
                    key_size,
                    query_size,
                    value_size,
                    num_hiddens,
                    norm_shape,
                    ffn_num_input,
                    ffn_num_hiddens,
                    num_heads,
                    dropout,
                    use_bias,
                ),
            )

    def forward(self, X, valid_lens, *args):
        # Positional encodings lie in [-1, 1], so embeddings are scaled by
        # sqrt(num_hiddens) before the two are added.
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X


class TransformerDecoder(AttentionDecoder):
    """Transformer decoder: embedding + positional encoding + DecoderBlocks.

    The final linear layer returns vocabulary logits; no softmax is applied,
    which does not change the argmax used for prediction.
    """

    def __init__(
        self,
        vocab_size,
        key_size,
        query_size,
        value_size,
        num_hiddens,
        norm_shape,
        ffn_num_input,
        ffn_num_hiddens,
        num_heads,
        num_layers,
        dropout,
        **kwargs,
    ):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(
                "block" + str(i),
                DecoderBlock(
                    key_size,
                    query_size,
                    value_size,
                    num_hiddens,
                    norm_shape,
                    ffn_num_input,
                    ffn_num_hiddens,
                    num_heads,
                    dropout,
                    i,
                ),
            )
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        """Return ``[enc_outputs, enc_valid_lens, per-block key/value cache]``."""
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        # [0]: decoder self-attention weights, [1]: encoder-decoder weights.
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            self._attention_weights[0][i] = blk.attention1.attention.attention_weights
            self._attention_weights[1][i] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        return self._attention_weights


if __name__ == "__main__":
    num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
    lr, num_epochs, device = 0.005, 200, try_gpu()
    ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
    key_size, query_size, value_size = 32, 32, 32
    norm_shape = [32]
    train_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size, num_steps)
    encoder = TransformerEncoder(
        len(src_vocab),
        key_size,
        query_size,
        value_size,
        num_hiddens,
        norm_shape,
        ffn_num_input,
        ffn_num_hiddens,
        num_heads,
        num_layers,
        dropout,
    )
    decoder = TransformerDecoder(
        len(tgt_vocab),
        key_size,
        query_size,
        value_size,
        num_hiddens,
        norm_shape,
        ffn_num_input,
        ffn_num_hiddens,
        num_heads,
        num_layers,
        dropout,
    )
    net = EncoderDecoder(encoder, decoder)
    train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
    engs = ["go .", "i lost .", "he's calm .", "i'm home ."]
    fras = ["va !", "j'ai perdu .", "il est calme .", "je suis chez moi ."]
    for eng, fra in zip(engs, fras):
        translation, dec_attention_weight_seq = predict_seq2seq(
            net, eng, src_vocab, tgt_vocab, num_steps, device, True
        )
        print(f"{eng} => {translation}, ", f"bleu {bleu(translation, fra, k=2):.3f}")

輸出結果
```text
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
<Figure size 350x250 with 1 Axes>
loss 0.034, 10150.2 tokens/sec on cpu
go . => va !,  bleu 1.000
i lost . => je vous en <unk> .,  bleu 0.000
he's calm . => il est calme .,  bleu 1.000
i'm home . => je suis chez moi .,  bleu 1.000
```

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/77647.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/77647.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/77647.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

前端性能優化:所有權轉移

前端性能優化：所有權轉移 在學習Rust過程中，學到了所有權概念，于是便聯想到了前端，前端是否有相關內容，于是進行了一些實驗，并整理了這些內容。 所有權轉移（Transfer of Ownership）…

Missashe考研日記-day23

Missashe考研日記-day23 0 寫在前面 博主前幾天有事回家去了，斷更幾天了不好意思，就當回家休息一下調整一下狀態了，今天接著開始更新。雖然每天的博客寫的內容不算多，但其實還是挺費時間的，比如這篇就花了我40多分鐘…

Docker 中將文件映射到 Linux 宿主機

在 Docker 中，有多種方式可以將文件映射到 Linux 宿主機，以下是常見的幾種方法： 使用-v參數 基本語法：docker run -v [宿主機文件路徑]:[容器內文件路徑] 容器名稱 示例：docker run -it -v /home/user/myfile.txt:…

HarmonyOS-ArkUI-動畫分類簡介

本文的目的是,了解一下HarmonyOS動畫體系中的分類。有個大致的了解即可。 動效與動畫簡介 動畫,是客戶端提升界面交互用戶體驗的一個重要的方式。可以使應用程序更加生動靈越,提高用戶體驗。 HarmonyOS對于界面的交互方面,圍繞回歸本源的設計理念,打造自然,流暢品質一提…

C++如何處理多線程環境下的異常?如何確保資源在異常情況下也能正確釋放

多線程編程的基本概念與挑戰 多線程編程的核心思想是將程序的執行劃分為多個并行運行的線程，每個線程可以獨立處理任務，從而充分利用多核處理器的性能優勢。在C++中，開發者可以通過std::thread創建線程，并使用同步原語如std::mutex、…

區間選點詳解

步驟 operator< 的作用 在 C++ 中，operator< 是一個運算符重載函數，它定義了如何比較兩個對象的大小。在 std::sort 函數中，它會用到這個比較函數來決定排序的順序。 在 sort 中，默認會使用 < 運算符來比較兩個對象…

前端配置代理解決發送cookie問題

場景： 在開發任務管理系統時，我遇到了一個典型的身份認證問題：用戶登錄成功后，調用獲取當前用戶信息接口卻提示"用戶未登錄"。系統核心流程如下： 用戶登錄：調用 /login 接口…

8.1 線性變換的思想

一、線性變換的概念 當一個矩陣 $A$ 乘一個向量 $\boldsymbol v$ 時，它將 $\boldsymbol v$ “變換” 成另一個向量 $A\boldsymbol v$。輸入 $\boldsymbol v$，輸出 $T(\boldsymbol v)=A\boldsymbol v$。變換 $T$…

【java實現+4種變體完整例子】排序算法中【冒泡排序】的詳細解析,包含基礎實現、常見變體的完整代碼示例,以及各變體的對比表格

以下是冒泡排序的詳細解析，包含基礎實現、常見變體的完整代碼示例，以及各變體的對比表格： 一、冒泡排序基礎實現 原理 通過重復遍歷數組，比較相鄰元素并交換逆序對，逐步將最大值“冒泡”到數組末尾。 代碼示例 pu…

系統架構設計(二):基于架構的軟件設計方法ABSD

“基于架構的軟件設計方法”（Architecture-Based Software Design, ABSD）是一種通過從軟件架構層面出發指導詳細設計的系統化方法。它旨在橋接架構設計與詳細設計之間的鴻溝，確保系統的高層結構能夠有效指導后續開發。 ABSD 的核心思想 ABS…

Office文件內容提取 | 獲取Word文件內容 |Javascript提取PDF文字內容 |PPT文檔文字內容提取

關于Office系列文件文字內容的提取 本文主要通過接口的方式獲取Office文件和PDF、OFD文件的文字內容。適用于需要獲取Word、OFD、PDF、PPT等文件內容的提取實現。例如在線文字統計以及論文文字內容的提取。 一、提取Word及WPS文檔的文字內容。 支持以下文件格式： …

Cesium學習筆記——dem/tif地形的分塊與加載

前言 在Cesium的學習中，學會讀文檔十分重要！！！在這里附上Cesium中英文文檔1.117。 在Cesium項目中，在平坦的地球中加入三維地形不僅可以增強真實感與可視化效果，還可以提升用戶體驗與交互性，…

Spring Boot 斷點續傳實戰:大文件上傳不再怕網絡中斷

精心整理了最新的面試資料和簡歷模板，有需要的可以自行獲取 點擊前往百度網盤獲取 點擊前往夸克網盤獲取 一、痛點與挑戰 在網絡傳輸大文件（如視頻、數據集、設計稿）時，常面臨： 上傳中途網絡中斷需重新開始服務器內…

數碼管LED顯示屏矩陣驅動技術詳解

1. 矩陣驅動原理 矩陣驅動是LED顯示屏常用的一種高效驅動方式，利用COM（Common，公共端）和SEG（Segment，段選）線的交叉點控制單個LED的亮滅。相比直接驅動，矩陣驅動可以顯著減少所需I/…

【上位機——MFC】菜單類與工具欄

菜單類 CMenu，封裝了關于菜單的各種操作成員函數，另外還封裝了一個非常重要的成員變量m_hMenu(菜單句柄) 菜單使用 添加菜單資源、加載菜單 工具欄相關類 CToolBarCtrl->父類是CWnd，封裝了關于工具欄控件的各種操作。 CToolBar->父類是CC…

Linux中常用操作

查看或修改linux本地mysql端口 cat /etc/my.cnf 如果沒有port可以添加，有可以修改 查看本地端口占用情況 bash netstat -nlt | grep 3307 HADOOP集群 hdfs啟動與停止 # 一鍵啟動hdfs集群 start-dfs.sh # 一鍵關閉hdfs集群 stop-dfs.sh #除了一鍵啟停外，…

衡石chatbi如何通過 iframe 集成

iframe 集成方式是最簡單的一種，您只需要在您的 HTML 文件中（或 Vue/React 組件中）添加一個 iframe 元素，并設置其 src 屬性為 AI 助手的 URL。 <iframe src="https://develop.hengshi.org/copilot" width="100%"…

Java集合框架深度解析:HashMap、HashSet、TreeMap、TreeSet與哈希表原理詳解

一、核心數據結構總覽 1. 核心類繼承體系 graph TD; Map接口 --> HashMap; Map接口 --> TreeMap; Set接口 --> HashSet; Set接口 --> TreeSet; HashMap --> LinkedHashMap; HashSet --> LinkedHashSet; TreeMap --> NavigableMap; TreeSet --> NavigableSet 2. 核心特…

HTTP 1.0 和 2.0 的區別

HTTP 1.0 和 2.0 的核心區別體現在性能優化、協議設計和功能擴展上，以下是具體對比： 一、核心區別對比 特性 | HTTP 1.0 | HTTP 2.0；連接方式 | 非持久連接（默認每次請求新建 TCP 連接） | 持久連接（默認保持連接，可復用…

gnome中刪除application中失效的圖標

什么是Application 這一塊的東西應該叫application，準確來說應該是applications。 正文 系統級：/usr/share/applications 用戶級：~/.local/share/applications ying192 ~/.l/s/applications> ls | grep xampp xampp.desktop rm ~/.local…