transformer demo

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import pytestclass PositionalEncoding(nn.Module):def __init__(self, d_model, max_seq_length=5000):super(PositionalEncoding, self).__init__()# 創建位置編碼矩陣pe = torch.zeros(max_seq_length, d_model)position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))# 計算正弦和余弦位置編碼pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)# 注冊為非訓練參數self.register_buffer('pe', pe)def forward(self, x):# 添加位置編碼到輸入張量return x + self.pe[:, :x.size(1)]class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()assert d_model % num_heads == 0self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_heads# 定義線性變換層self.q_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, d_model)self.v_linear = nn.Linear(d_model, d_model)self.out_linear = nn.Linear(d_model, d_model)def forward(self, q, k, v, mask=None):batch_size = q.size(0)# 線性變換和重塑q = self.q_linear(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)k = self.k_linear(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)v = self.v_linear(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)# 計算注意力分數scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)# 應用掩碼(如果提供)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)# 應用softmax獲取注意力權重attn_weights = F.softmax(scores, dim=-1)# 應用注意力權重到值向量attn_output = torch.matmul(attn_weights, v)# 重塑并應用最終線性變換attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)output = self.out_linear(attn_output)return outputclass FeedForward(nn.Module):def __init__(self, d_model, d_ff):super(FeedForward, self).__init__()self.linear1 = nn.Linear(d_model, d_ff)self.linear2 = nn.Linear(d_ff, d_model)def forward(self, x):return self.linear2(F.relu(self.linear1(x)))class EncoderLayer(nn.Module):def __init__(self, d_model, num_heads, d_ff, dropout=0.1):super(EncoderLayer, self).__init__()self.self_attn = MultiHeadAttention(d_model, num_heads)self.feed_forward = FeedForward(d_model, d_ff)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, mask=None):# 自注意力層和殘差連接attn_output = self.self_attn(x, x, x, mask)x = self.norm1(x + self.dropout(attn_output))# 前饋網絡和殘差連接ff_output = self.feed_forward(x)x = self.norm2(x + self.dropout(ff_output))return xclass DecoderLayer(nn.Module):def __init__(self, d_model, num_heads, d_ff, dropout=0.1):super(DecoderLayer, self).__init__()self.self_attn = MultiHeadAttention(d_model, num_heads)self.cross_attn = MultiHeadAttention(d_model, num_heads)self.feed_forward = FeedForward(d_model, d_ff)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm3 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, enc_output, src_mask=None, tgt_mask=None):# 自注意力層和殘差連接attn_output = self.self_attn(x, x, x, tgt_mask)x = self.norm1(x + self.dropout(attn_output))# 編碼器-解碼器注意力層和殘差連接cross_attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)x = self.norm2(x + self.dropout(cross_attn_output))# 前饋網絡和殘差連接ff_output = self.feed_forward(x)x = self.norm3(x + self.dropout(ff_output))return xclass Transformer(nn.Module):def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers,num_decoder_layers, d_ff, max_seq_length, dropout=0.1):super(Transformer, self).__init__()# 詞嵌入層self.src_embedding = nn.Embedding(src_vocab_size, d_model)self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)# 位置編碼self.positional_encoding = PositionalEncoding(d_model, max_seq_length)# 編碼器和解碼器層self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout)for _ in range(num_encoder_layers)])self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout)for _ in range(num_decoder_layers)])# 輸出層self.output_layer = nn.Linear(d_model, tgt_vocab_size)self.dropout = nn.Dropout(dropout)self.d_model = d_model# 初始化參數self._init_parameters()def _init_parameters(self):for p in self.parameters():if p.dim() > 1:nn.init.xavier_uniform_(p)def forward(self, src, tgt, src_mask=None, tgt_mask=None):# 源序列和目標序列的嵌入和位置編碼src = self.src_embedding(src) * math.sqrt(self.d_model)src = self.positional_encoding(src)src = self.dropout(src)tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model)tgt = self.positional_encoding(tgt)tgt = self.dropout(tgt)# 編碼器前向傳播enc_output = srcfor enc_layer in self.encoder_layers:enc_output = enc_layer(enc_output, src_mask)# 解碼器前向傳播dec_output = tgtfor dec_layer in self.decoder_layers:dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)# 輸出層output = self.output_layer(dec_output)return output# 創建掩碼函數
def create_masks(src, tgt):# 源序列掩碼(用于屏蔽填充標記)src_mask = (src != 0).unsqueeze(1).unsqueeze(2)# 目標序列掩碼(用于屏蔽填充標記和未來標記)tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)# 創建后續標記掩碼(用于自回歸解碼)seq_length = tgt.size(1)nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()# 合并掩碼tgt_mask = tgt_mask & nopeak_maskreturn src_mask, tgt_mask# 簡單的訓練函數
def train_transformer(model, optimizer, criterion, train_loader, epochs):model.train()for epoch in range(epochs):total_loss = 0for src, tgt in train_loader:# 創建掩碼src_mask, tgt_mask = create_masks(src, tgt[:, :-1])# 前向傳播output = model(src, tgt[:, :-1], src_mask, tgt_mask)# 計算損失loss = criterion(output.contiguous().view(-1, output.size(-1)),tgt[:, 1:].contiguous().view(-1))# 反向傳播和優化optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()print(f'Epoch {epoch + 1}, Loss: {total_loss / len(train_loader):.4f}')# 添加model fixture
@pytest.fixture
def model():# 定義超參數d_model = 512num_heads = 8num_encoder_layers = 6num_decoder_layers = 6d_ff = 2048max_seq_length = 100dropout = 0.1# 假設的詞匯表大小src_vocab_size = 10000tgt_vocab_size = 10000# 創建模型model = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads,num_encoder_layers, num_decoder_layers, d_ff, max_seq_length, dropout)return model# 添加test_loader fixture
@pytest.fixture
def test_loader():# 創建一個簡單的測試數據集batch_size = 2seq_length = 10# 隨機生成一些測試數據src_data = torch.randint(1, 10000, (batch_size, seq_length))tgt_data = torch.randint(1, 10000, (batch_size, seq_length))# 創建DataLoaderfrom torch.utils.data import TensorDataset, DataLoaderdataset = TensorDataset(src_data, tgt_data)test_loader = DataLoader(dataset, batch_size=batch_size)return test_loader# 簡單的測試函數
def test_transformer(model, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for src, tgt in test_loader:# 創建掩碼src_mask, _ = create_masks(src, tgt)# 預測output = model(src, tgt, src_mask, None)pred = output.argmax(dim=-1)# 計算準確率total += tgt.size(0) * tgt.size(1)correct += (pred == tgt).sum().item()accuracy = correct / totalprint(f'Test Accuracy: {accuracy:.4f}')# 簡單的序列到序列翻譯示例
def translate(model, src_sequence, src_vocab, tgt_vocab, max_length=50):model.eval()# 將源序列轉換為索引src_indices = [src_vocab.get(token, src_vocab['<unk>']) for token in src_sequence]src_tensor = torch.LongTensor(src_indices).unsqueeze(0)# 創建源序列掩碼src_mask = (src_tensor != 0).unsqueeze(1).unsqueeze(2)# 初始目標序列為開始標記tgt_indices = [tgt_vocab['<sos>']]with torch.no_grad():for i in range(max_length):tgt_tensor = torch.LongTensor(tgt_indices).unsqueeze(0)# 創建目標序列掩碼_, tgt_mask = create_masks(src_tensor, tgt_tensor)# 預測下一個標記output = model(src_tensor, tgt_tensor, src_mask, tgt_mask)next_token_logits = output[:, -1, :]next_token = next_token_logits.argmax(dim=-1).item()# 添加預測的標記到目標序列tgt_indices.append(next_token)# 如果預測到結束標記,則停止if next_token == tgt_vocab['<eos>']:break# 將目標序列索引轉換回標記tgt_sequence = [tgt_vocab.get(index, '<unk>') for index in tgt_indices]return tgt_sequence# 示例使用
if __name__ == "__main__":# 定義超參數d_model = 512num_heads = 8num_encoder_layers = 6num_decoder_layers = 6d_ff = 2048max_seq_length = 100dropout = 0.1# 假設的詞匯表大小src_vocab_size = 10000tgt_vocab_size = 10000# 創建模型model = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads,num_encoder_layers, num_decoder_layers, d_ff, max_seq_length, dropout)# 定義優化器和損失函數optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)criterion = nn.CrossEntropyLoss(ignore_index=0)  # 忽略填充標記# 這里應該有實際的數據加載代碼# train_loader = ...# test_loader = ...# 訓練模型# train_transformer(model, optimizer, criterion, train_loader, epochs=10)# 測試模型# test_transformer(model, test_loader)# 翻譯示例# src_vocab = ...# tgt_vocab = ...# src_sequence = ["hello", "world", "!"]# translation = translate(model, src_sequence, src_vocab, tgt_vocab)# print(f"Source: {' '.join(src_sequence)}")# print(f"Translation: {' '.join(translation)}")

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/87081.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/87081.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/87081.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

centos 8.3(阿里云服務器)mariadb由系統自帶版本(10.3)升級到10.6

1. 備份數據庫 在進行任何升級操作前&#xff0c;務必備份所有數據庫&#xff1a; mysqldump -u root -p --all-databases > all_databases_backup.sql # 或者為每個重要數據庫單獨備份 mysqldump -u root -p db_name1 > db_name1_backup.sql mysqldump -u root -p db…

如何穩定地更新你的大模型知識(算法篇)

目錄 在線強化學習的穩定知識獲取機制:算法優化與數據策略一、算法層面的穩定性控制機制二、數據處理策略的穩定性保障三、訓練過程中的漸進式優化策略四、環境設計與反饋機制的穩定性影響五、穩定性保障的綜合應用策略六、總結與展望通過強化學習來讓大模型學習高層語義知識,…

圖的遍歷模板

圖的遍歷 BFS 求距離 #include<bits/stdc.h>using namespace std;int n, m, k,q[20001],dist[20001]; vector<int> edge[20001];int main(){scanf("%d%d%d",&n,&m,&k);for (int i 1;i<m;i){int x,y;scanf("%d%d",&x,&am…

Java集合 - LinkedList底層源碼解析

以下是基于 JDK 8 的 LinkedList 深度源碼解析&#xff0c;涵蓋其數據結構、核心方法實現、性能特點及使用場景。我們從 類結構、Node節點、插入/刪除/訪問操作、線程安全、性能對比 等角度進行詳細分析 一、類結構與繼承關系 1. 類定義 public class LinkedList<E> e…

Pytorch 卷積神經網絡參數說明一

系列文章目錄 文章目錄 系列文章目錄前言一、卷積層的定義1.常見的卷積操作2. 感受野3. 如何理解參數量和計算量4.如何減少計算量和參數量 二、神經網絡結構&#xff1a;有些層前面文章說過&#xff0c;不全講1. 池化層&#xff08;下采樣&#xff09;2. 上采樣3. 激活層、BN層…

C++ 中的 iostream 庫:cin/cout 基本用法

iostream 是 C 標準庫中用于輸入輸出操作的核心庫&#xff0c;它基于面向對象的設計&#xff0c;提供了比 C 語言的 stdio.h 更強大、更安全的 I/O 功能。下面詳細介紹 iostream 庫中最常用的輸入輸出工具&#xff1a;cin 和 cout。 一、 基本概念 iostream 庫&#xff1a;包…

SAP復制一個自定義移動類型

SAP復制移動類型 在SAP系統中&#xff0c;復制移動類型201可以通過事務碼OMJJ或SPRO路徑完成&#xff0c;用于創建自定義的移動類型以滿足特定業務需求。 示例操作步驟 進入OMJJ事務碼&#xff1a; 打開事務碼OMJJ&#xff0c;選擇“移動類型”選項。 復制移動類型&#xff…

Bambu Studio 中的“回抽“與“裝填回抽“的區別

回抽 裝填回抽: Bambu Studio 中的“回抽” (Retraction) 和“裝填回抽”(Prime/Retract) 是兩個不同的概念&#xff0c;它們都與材料擠出機的操作過程相關&#xff0c;但作用和觸發條件有所不同。 回抽(Retraction): 回抽的作用, 在打印機移動到另一個位置之前&#xff0c;將…

危化品安全監測數據分析挖掘范式:從被動響應到戰略引擎的升維之路

在危化品生產的復雜生態系統中,安全不僅僅是合規性要求,更是企業生存和發展的生命線。傳統危化品安全生產風險監測預警系統雖然提供了基礎保障,但其“事后響應”和“單點預警”的局限性日益凸顯。我們正處在一個由大數據、人工智能、數字孿生和物聯網技術驅動的范式變革前沿…

C++ RPC 遠程過程調用詳細解析

一、RPC 基本原理 RPC (Remote Procedure Call) 是一種允許程序調用另一臺計算機上子程序的協議,而不需要程序員顯式編碼這個遠程交互細節。其核心思想是使遠程調用看起來像本地調用一樣。 RPC 工作流程 客戶端調用:客戶端調用本地存根(stub)方法參數序列化:客戶端存根將參…

Python:操作 Excel 預設色

??親愛的技術愛好者們,熱烈歡迎來到 Kant2048 的博客!我是 Thomas Kant,很開心能在CSDN上與你們相遇~?? 本博客的精華專欄: 【自動化測試】 【測試經驗】 【人工智能】 【Python】 Python 操作 Excel 系列 讀取單元格數據按行寫入設置行高和列寬自動調整行高和列寬水平…

中科院1區|IF10+:加大醫學系團隊利用GPT-4+電子病歷分析,革新肝硬化并發癥隊列識別

中科院1區|IF10&#xff1a;加大醫學系團隊利用GPT-4電子病歷分析&#xff0c;革新肝硬化并發癥隊列識別 在當下的科研領域&#xff0c;人工智能尤其是大語言模型的迅猛發展&#xff0c;正為各個學科帶來前所未有的機遇與變革。在醫學范疇&#xff0c;從疾病的早期精準篩查&am…

Python學習小結

bg&#xff1a;記錄一下&#xff0c;怕忘了&#xff1b;先寫一點&#xff0c;后面再補充。 1、沒有方法重載 2、字段都是公共字段 3、都是類似C#中頂級語句的寫法 4、對類的定義直接&#xff1a; class Student: 創建對象不需要new關鍵字&#xff0c;直接stu Student() 5、方…

QCustomPlot 中實現拖動區域放大?與恢復

1、拖動區域放大? 在 QCustomPlot 中實現 ?拖動區域放大?&#xff08;即通過鼠標左鍵拖動繪制矩形框選區域進行放大&#xff09;的核心方法是設置 SelectionRectMode。具體操作步驟&#xff1a; 1?&#xff09;禁用拖動模式? 確保先關閉默認的圖表拖動功能&#xff08;否…

如何將文件從 iPhone 傳輸到閃存驅動器

您想將文件從 iPhone 或 iPad 傳輸到閃存盤進行備份嗎&#xff1f;這是一個很好的決定&#xff0c;但您需要先了解一些實用的方法。雖然 Apple 生態系統在很大程度上是封閉的&#xff0c;但您可以使用一些實用工具將文件從 iPhone 或 iPad 傳輸到閃存盤。下文提供了這些行之有效…

互聯網大廠Java求職面試:云原生架構與微服務設計中的復雜挑戰

互聯網大廠Java求職面試&#xff1a;云原生架構與微服務設計中的復雜挑戰 面試官開場白 面試官&#xff08;嚴肅模式開啟&#xff09;&#xff1a;鄭薪苦&#xff0c;歡迎來到我們的技術面試環節。我是本次面試的技術總監&#xff0c;接下來我們將圍繞云原生架構、微服務設計、…

leetcode-hot-100 (鏈表)

1. 相交鏈表 題目鏈接&#xff1a;相交鏈表 題目描述&#xff1a;給你兩個單鏈表的頭節點 headA 和 headB &#xff0c;請你找出并返回兩個單鏈表相交的起始節點。如果兩個鏈表不存在相交節點&#xff0c;返回 null 。 解答&#xff1a; 其實這道題目我一開始沒太看懂題目給…

Web前端基礎之HTML

一、瀏覽器 火狐瀏覽器、谷歌瀏覽器(推薦)、IE瀏覽器 推薦谷歌瀏覽器原因&#xff1a; 1、簡潔大方,打開速度快 2、開發者調試工具&#xff08;右鍵空白處->檢查&#xff0c;打開調試模式&#xff09; 二、開發工具 核心IDE工具 Visual Studio Code (VS Code)? 微軟開發…

11.TCP三次握手

TCP連接建立與傳輸 1&#xff0e;主機 A 與主機 B 使用 TCP 傳輸數據&#xff0c;A 是 TCP 客戶&#xff0c;B 是 TCP 服務器。假設有512B 的數據要傳輸給 B&#xff0c;B 僅給 A 發送確認&#xff1b;A 的發送窗口 swnd 的尺寸為 100B&#xff0c;而 TCP 數據報文段每次也攜帶…

Python 爬蟲入門 Day 3 - 實現爬蟲多頁抓取與翻頁邏輯

Python 第二階段 - 爬蟲入門 &#x1f3af; 今日目標 掌握網頁分頁的原理和定位“下一頁”的鏈接能編寫循環邏輯自動翻頁抓取內容將多頁抓取整合到爬蟲系統中 &#x1f4d8; 學習內容詳解 &#x1f501; 網頁分頁邏輯介紹 以 quotes.toscrape.com 為例&#xff1a; 首頁鏈…