從零實現基於 Transformer 的英譯漢任務

1. model.py(用的是上一篇文章的代碼:從0搭建Transformer-CSDN博客)

import torch
import torch.nn as nn
import math


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    Args:
        d_model: Embedding width (odd widths are supported).
        dropout: Dropout probability applied to ``embedding + encoding``.
        max_len: Longest sequence length that can be encoded.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        # Positions 0..max_len-1 as a column vector, shape (max_len, 1).
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i / d_model) computed in log space; one entry per even column.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Shape (1, max_len, d_model) so it broadcasts over the batch dimension.
        # A buffer is not trained, but it is saved in state_dict and follows .to(device).
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])`` for ``x`` of shape (batch, seq_len, d_model).

        ``seq_len`` must not exceed ``max_len``.
        """
        # self.pe is a buffer, so no gradient flows into the encoding table.
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split across ``num_heads`` parallel heads.

    Args:
        d_model: Model width; must be divisible by ``num_heads``.
        num_heads: Number of heads, each of width ``d_model // num_heads``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def _split_heads(self, x, batch_size):
        """Reshape (batch, seq, d_model) into (batch, heads, seq, d_k)."""
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: (batch, q_len, d_model).
            key, value: (batch, k_len, d_model).
            mask: Optional tensor broadcastable to (batch, heads, q_len, k_len);
                scores where it equals 0/False are suppressed.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch_size = query.size(0)
        Q = self._split_heads(self.W_q(query), batch_size)
        K = self._split_heads(self.W_k(key), batch_size)
        V = self._split_heads(self.W_v(value), batch_size)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # A large finite negative (not -inf): a fully masked row softmaxes to a
            # uniform distribution instead of NaN.
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V)
        # Merge heads back: (batch, heads, q_len, d_k) -> (batch, q_len, d_model).
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_k * self.num_heads)
        return self.W_o(context)


class EncoderLayer(nn.Module):
    """Post-norm encoder layer: self-attention then feed-forward.

    Each sub-layer is applied as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Encode ``x`` (batch, seq, d_model); ``mask`` hides padded keys."""
        x = self.norm1(x + self.dropout(self.attn(x, x, x, mask)))
        x = self.norm2(x + self.dropout(self.feed_forward(x)))
        return x


class DecoderLayer(nn.Module):
    """Post-norm decoder layer: masked self-attention, cross-attention, feed-forward.

    Each sub-layer is applied as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        """Decode ``x`` against ``enc_output``.

        ``tgt_mask`` restricts self-attention (causal + padding); ``src_mask``
        hides padded encoder positions in cross-attention.
        """
        x = self.norm1(x + self.dropout(self.self_attn(x, x, x, tgt_mask)))
        x = self.norm2(x + self.dropout(self.cross_attn(x, enc_output, enc_output, src_mask)))
        x = self.norm3(x + self.dropout(self.feed_forward(x)))
        return x


class Transformer(nn.Module):
    """Encoder-decoder Transformer operating on batch-first token-id tensors.

    A single PositionalEncoding module is shared by the encoder and decoder.
    NOTE(review): embeddings are not scaled by sqrt(d_model) as in the original
    paper; changing that now would invalidate checkpoints trained with this code.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8,
                 num_layers=6, d_ff=2048, dropout=0.1):
        super().__init__()
        self.encoder_embed = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embed = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def encode(self, src, src_mask):
        """Map source ids (batch, src_len) to memory (batch, src_len, d_model)."""
        hidden = self.pos_encoder(self.encoder_embed(src))
        for layer in self.encoder_layers:
            hidden = layer(hidden, src_mask)
        return hidden

    def decode(self, tgt, enc_output, src_mask, tgt_mask):
        """Map target ids (batch, tgt_len) to hidden states (batch, tgt_len, d_model)."""
        hidden = self.pos_encoder(self.decoder_embed(tgt))
        for layer in self.decoder_layers:
            hidden = layer(hidden, enc_output, src_mask, tgt_mask)
        return hidden

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Return logits of shape (batch, tgt_len, tgt_vocab_size)."""
        enc_output = self.encode(src, src_mask)
        dec_output = self.decode(tgt, enc_output, src_mask, tgt_mask)
        return self.fc_out(dec_output)

2. train.py(數據量很大,使用其中一部分進行訓練和驗證,數據集來源:中英互譯數據集(translation2019zh)_數據集-飛槳AI Studio星河社區)

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from model import Transformer, PositionalEncoding
import math
import numpy as np
import os
import json
from tqdm import tqdm# --- Data Loading for JSON Lines format ---
# --- Data loading (JSON Lines: one {"english": ..., "chinese": ...} object per line) ---
def load_data_from_jsonl(file_path, max_lines=None):
    """Load parallel English/Chinese sentences from a JSON Lines file.

    Blank lines, invalid JSON, JSON values that are not objects, and objects
    whose ``"english"``/``"chinese"`` field is missing or not a string are
    skipped (and counted) instead of aborting the whole load.

    Args:
        file_path: Path to the JSON Lines file.
        max_lines: Stop once this many valid pairs were collected; ``None``
            reads the whole file.

    Returns:
        ``(en_sentences, zh_sentences)``: two equally long lists of str. Both
        are empty if the file is missing or cannot be read/decoded.
    """
    en_sentences, zh_sentences = [], []
    print(f"Loading data from {file_path}..." + (f" (up to {max_lines} lines)" if max_lines else ""))
    if not os.path.exists(file_path):
        print(f"Error: Data file not found at {file_path}")
        return [], []

    skipped = 0
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc=f"Reading {os.path.basename(file_path)}",
                             total=max_lines if max_lines else None):
                if max_lines is not None and len(en_sentences) >= max_lines:
                    print(f"\nReached max_lines limit of {max_lines} for {file_path}.")
                    break
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    skipped += 1
                    continue
                # A valid JSON line may still be a list/str/number; only objects carry pairs.
                if not isinstance(data, dict):
                    skipped += 1
                    continue
                english, chinese = data.get('english'), data.get('chinese')
                # Downstream code iterates sentences character by character, so require str.
                if not isinstance(english, str) or not isinstance(chinese, str):
                    skipped += 1
                    continue
                en_sentences.append(english)
                zh_sentences.append(chinese)
    except (OSError, UnicodeDecodeError) as e:
        print(f"An error occurred while reading {file_path}: {e}")
        return [], []

    if skipped:
        print(f"Skipped {skipped} malformed or incomplete lines in {file_path}.")
    print(f"Loaded {len(en_sentences)} sentence pairs from {file_path}.")
    return en_sentences, zh_sentences
# --- Vocab Class (character-level; consider subword tokenization for large datasets later) ---
class Vocab:
    """Character-level vocabulary mapping tokens to contiguous integer ids.

    Special tokens take the lowest ids in the given order (default: ``<pad>``=0,
    ``<unk>``=1, ``<sos>``=2, ``<eos>``=3). Every other character with at least
    ``min_freq`` occurrences follows, by descending frequency; ties keep
    first-seen order.

    Instances are pickled into training checkpoints, so the attribute names
    ``stoi``, ``itos`` and ``special_tokens`` must stay stable.

    Args:
        sentences: Iterable of sentences; non-str items are ignored.
        min_freq: Minimum character count required to receive an id.
        special_tokens: Ordered reserved tokens; repeated entries are dropped.
    """

    def __init__(self, sentences, min_freq=1, special_tokens=None):
        from collections import Counter  # local import: only needed while building

        if special_tokens is None:
            # <pad> first: the training script requires padding to be id 0.
            special_tokens = ['<pad>', '<unk>', '<sos>', '<eos>']
        # dict.fromkeys drops repeats but keeps order, so each special token gets
        # exactly one id and ids stay contiguous (no stale itos rows inflating len()).
        self.special_tokens = list(dict.fromkeys(special_tokens))
        self.stoi = {}
        self.itos = {}
        for token in self.special_tokens:
            self._add(token)

        print("Counting character frequencies for vocab...")
        counter = Counter()
        for s in tqdm(sentences, desc="Processing sentences for vocab"):
            if isinstance(s, str):
                counter.update(s)  # iterating a str yields its characters

        # most_common() orders equal counts by first occurrence.
        frequent = [token for token, count in counter.most_common()
                    if count >= min_freq and token not in self.stoi]
        for token in tqdm(frequent, desc="Building vocab mapping"):
            self._add(token)

    def _add(self, token):
        """Assign ``token`` the next free id."""
        idx = len(self.itos)
        self.stoi[token] = idx
        self.itos[idx] = token

    def __len__(self):
        """Number of ids, special tokens included (used as the embedding size)."""
        return len(self.itos)


# --- TranslationDataset Class (No changes needed) ---
class TranslationDataset(Dataset):
    """Parallel corpus converted once, up front, into id tensors.

    Each sentence becomes ``<sos> c1 c2 ... cn <eos>`` as a 1-D int64 tensor with
    one id per character; characters missing from the vocabulary become
    ``<unk>``. Indexing yields the ``(source_tensor, target_tensor)`` pair.
    """

    def __init__(self, en_sentences, zh_sentences, src_vocab, tgt_vocab):
        self.src_data, self.tgt_data = [], []
        print("Creating dataset tensors...")
        src_table, tgt_table = src_vocab.stoi, tgt_vocab.stoi
        src_bos, src_end, src_oov = src_table['<sos>'], src_table['<eos>'], src_table['<unk>']
        tgt_bos, tgt_end, tgt_oov = tgt_table['<sos>'], tgt_table['<eos>'], tgt_table['<unk>']

        progress = tqdm(zip(en_sentences, zh_sentences), total=len(en_sentences), desc="Vectorizing data")
        for english, chinese in progress:
            source_ids = [src_bos, *(src_table.get(ch, src_oov) for ch in english), src_end]
            target_ids = [tgt_bos, *(tgt_table.get(ch, tgt_oov) for ch in chinese), tgt_end]
            # TODO: optionally drop or truncate over-long pairs here to bound memory.
            self.src_data.append(torch.LongTensor(source_ids))
            self.tgt_data.append(torch.LongTensor(target_ids))
        print("Dataset tensors created.")

    def __len__(self):
        return len(self.src_data)

    def __getitem__(self, idx):
        return self.src_data[idx], self.tgt_data[idx]


# --- Collate Function (Ensure PAD index is correct) ---
def collate_fn(batch, pad_idx=0):
    """Pad a batch of ``(src, tgt)`` 1-D id tensors into two batch-first tensors.

    Args:
        batch: Sequence of ``(src_tensor, tgt_tensor)`` pairs.
        pad_idx: Value written after each sequence's end.

    Returns:
        ``(src_padded, tgt_padded)``, each shaped (batch, longest length in batch).
    """
    sources, targets = zip(*batch)
    return tuple(
        nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=pad_idx)
        for seqs in (sources, targets)
    )


# --- Mask Creation Function (Adjust for batch_first=True) ---
def create_masks(src, tgt, pad_idx):
    """Build boolean attention masks for batch-first source/target id tensors.

    Args:
        src: (batch, src_len) source ids.
        tgt: (batch, tgt_len) decoder-input ids.
        pad_idx: Id used for padding.

    Returns:
        src_mask: (batch, 1, 1, src_len); True where the source key is not padding.
        tgt_mask: (batch, 1, tgt_len, tgt_len); True where query row i is not
            padding and key column j <= i (causal, lower-triangular).
    """
    device = src.device
    src_mask = (src != pad_idx)[:, None, None, :]

    length = tgt.size(1)
    # Lower triangle incl. diagonal: each position sees itself and earlier ones.
    causal = torch.tril(torch.ones((length, length), device=device)).bool()[None, None]
    # Padding is applied on the query axis; padded rows are ignored by the loss.
    query_is_real = (tgt != pad_idx)[:, None, :, None]
    tgt_mask = query_is_real & causal

    return src_mask.to(device), tgt_mask.to(device)


# --- Main Execution Block ---
# --- Training helpers ---
def _compute_batch_loss(model, criterion, src, tgt, pad_idx, device):
    """Run one teacher-forced forward pass and return the token-level loss.

    The decoder reads ``tgt[:, :-1]`` and is scored on ``tgt[:, 1:]`` (next-token
    prediction); positions equal to ``criterion``'s ignore_index do not count.
    """
    src = src.to(device)
    tgt = tgt.to(device)
    tgt_input = tgt[:, :-1]
    tgt_output = tgt[:, 1:]
    src_mask, tgt_mask = create_masks(src, tgt_input, pad_idx)
    logits = model(src, tgt_input, src_mask, tgt_mask)
    return criterion(logits.reshape(-1, logits.size(-1)), tgt_output.reshape(-1))


def _save_checkpoint(path, model, optimizer, src_vocab, tgt_vocab, epoch, loss, config):
    """Write model/optimizer state, both vocabularies and hyperparameters to ``path``."""
    torch.save({
        'model_state_dict': model.state_dict(),
        'src_vocab': src_vocab,
        'tgt_vocab': tgt_vocab,
        'epoch': epoch,
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
    }, path)


if __name__ == '__main__':
    from functools import partial

    # --- Configuration ---
    TRAIN_DATA_PATH = 'data/translation2019zh_train.json'
    VALID_DATA_PATH = 'data/translation2019zh_valid.json'
    MODEL_SAVE_PATH = 'best_model_subset.pth'  # New model name for subset
    # Number of valid sentence pairs to read from each file. Lower these (e.g.
    # 100_000 / 10_000) for faster iteration; raise them if resources allow.
    MAX_TRAIN_LINES = 1000000
    MAX_VALID_LINES = 100000
    # Hyperparameters (a smaller model suits a data subset)
    BATCH_SIZE = 32
    NUM_EPOCHS = 10
    LEARNING_RATE = 1e-4
    D_MODEL = 256
    NUM_HEADS = 8  # Must divide D_MODEL
    NUM_LAYERS = 3
    D_FF = 1024  # Usually 4 * D_MODEL
    DROPOUT = 0.1
    MIN_FREQ = 1
    PRINT_FREQ = 100  # NOTE(review): currently unused; progress is reported through tqdm
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {DEVICE}")

    # --- Load Data ---
    print(f"Loading subset of training data (up to {MAX_TRAIN_LINES} lines)...")
    train_en_sentences, train_zh_sentences = load_data_from_jsonl(TRAIN_DATA_PATH, max_lines=MAX_TRAIN_LINES)
    if not train_en_sentences:
        print("No training data loaded. Exiting.")
        raise SystemExit(1)
    print(f"Loading subset of validation data (up to {MAX_VALID_LINES} lines)...")
    val_en_sentences, val_zh_sentences = load_data_from_jsonl(VALID_DATA_PATH, max_lines=MAX_VALID_LINES)
    if not val_en_sentences:
        print("Warning: No validation data loaded. Proceeding without validation.")

    # --- Build Vocabularies (ONLY from the training data subset) ---
    print("Building vocabularies from training data subset...")
    src_vocab = Vocab(train_en_sentences, min_freq=MIN_FREQ)
    tgt_vocab = Vocab(train_zh_sentences, min_freq=MIN_FREQ)
    print(f"Source vocab size: {len(src_vocab)}")
    print(f"Target vocab size: {len(tgt_vocab)}")
    PAD_IDX = src_vocab.stoi['<pad>']
    if PAD_IDX != 0 or tgt_vocab.stoi['<pad>'] != 0:
        print("Error: PAD index is not 0. Collate function and loss needs adjustment.")
        raise SystemExit(1)

    # --- Create Datasets ---
    # partial instead of a lambda so the collate function stays picklable
    # (needed if DataLoader workers are ever enabled).
    collate = partial(collate_fn, pad_idx=PAD_IDX)
    print("Creating training dataset...")
    train_dataset = TranslationDataset(train_en_sentences, train_zh_sentences, src_vocab, tgt_vocab)
    if val_en_sentences:
        print("Creating validation dataset...")
        val_dataset = TranslationDataset(val_en_sentences, val_zh_sentences, src_vocab, tgt_vocab)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=collate)
        print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}")
    else:
        val_loader = None
        print(f"Train size: {len(train_dataset)} (No validation set)")
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)

    # --- Initialize Model ---
    print("Initializing model...")
    model = Transformer(
        src_vocab_size=len(src_vocab),
        tgt_vocab_size=len(tgt_vocab),
        d_model=D_MODEL,
        num_heads=NUM_HEADS,
        num_layers=NUM_LAYERS,
        d_ff=D_FF,
        dropout=DROPOUT,
    ).to(DEVICE)

    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f'The model has {count_parameters(model):,} trainable parameters')
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.98), eps=1e-9)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
    # Stored in every checkpoint so predict.py can rebuild the same architecture.
    checkpoint_config = {
        'd_model': D_MODEL, 'num_heads': NUM_HEADS, 'num_layers': NUM_LAYERS,
        'd_ff': D_FF, 'dropout': DROPOUT,
        'src_vocab_size': len(src_vocab), 'tgt_vocab_size': len(tgt_vocab),
        'max_train_lines': MAX_TRAIN_LINES, 'max_valid_lines': MAX_VALID_LINES,
    }

    # --- Training Loop ---
    best_val_loss = float('inf')
    print("Starting training on data subset...")
    for epoch in range(NUM_EPOCHS):
        model.train()
        epoch_loss = 0.0
        train_iterator = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} Training")
        for src, tgt in train_iterator:
            loss = _compute_batch_loss(model, criterion, src, tgt, PAD_IDX, DEVICE)
            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping keeps early Transformer updates stable.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            batch_loss = loss.item()
            epoch_loss += batch_loss
            train_iterator.set_postfix(loss=batch_loss)
        avg_train_loss = epoch_loss / len(train_loader)

        if val_loader is not None:
            model.eval()
            val_loss = 0.0
            val_iterator = tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} Validation")
            with torch.no_grad():
                for src, tgt in val_iterator:
                    batch_loss = _compute_batch_loss(model, criterion, src, tgt, PAD_IDX, DEVICE).item()
                    val_loss += batch_loss
                    val_iterator.set_postfix(loss=batch_loss)
            avg_val_loss = val_loss / len(val_loader)
            print(f'\nEpoch {epoch+1} Summary: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')
            if avg_val_loss < best_val_loss:
                print(f"Validation loss decreased ({best_val_loss:.4f} --> {avg_val_loss:.4f}). Saving model to {MODEL_SAVE_PATH}...")
                best_val_loss = avg_val_loss
                _save_checkpoint(MODEL_SAVE_PATH, model, optimizer, src_vocab, tgt_vocab,
                                 epoch, best_val_loss, checkpoint_config)
        else:
            # Without validation data, keep the latest epoch's weights.
            print(f'\nEpoch {epoch+1} Summary: Train Loss: {avg_train_loss:.4f}')
            print(f"Saving model checkpoint to {MODEL_SAVE_PATH}...")
            _save_checkpoint(MODEL_SAVE_PATH, model, optimizer, src_vocab, tgt_vocab,
                             epoch, avg_train_loss, checkpoint_config)
    print("Training complete on data subset!")

3. predict.py(模型預測)

# predict.py
import torch
import torch.nn as nn
import numpy as np
import sys
import os
import json # Keep json import just in case, though not used directly here# --- Attempt to import necessary components ---
try:from model import Transformer, PositionalEncoding# Import Vocab from the updated train.pyfrom train import Vocab, create_masks # Import create_masks if needed, but translate usually recreates its own simpler masks
except ImportError as e:print(f"Error importing necessary modules: {e}")print("Please ensure model.py and train.py are in the Python path and have the necessary definitions.")sys.exit(1)# --- Configuration ---
# !!! IMPORTANT: Use the path to the model saved by the *new* training script !!!
CHECKPOINT_PATH = 'best_model_subset.pth'
MAX_LENGTH = 60  # Maximum number of generated target tokens
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")
print(f"Loading checkpoint from: {CHECKPOINT_PATH}")

# --- Load Checkpoint and Vocab ---
if not os.path.exists(CHECKPOINT_PATH):
    print(f"Error: Checkpoint file not found at {CHECKPOINT_PATH}")
    sys.exit(1)

try:
    # The checkpoint contains pickled Vocab objects, which torch.load rejects when
    # weights_only=True (the default since PyTorch 2.6). Full unpickling can run
    # arbitrary code: only load checkpoints you produced yourself.
    # NOTE(review): train.py is normally run as a script, so Vocab is presumably
    # pickled as __main__.Vocab and resolved through the `from train import Vocab`
    # above; verify if the import layout changes.
    try:
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
    except TypeError:
        # PyTorch < 1.13 has no weights_only parameter.
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    print("Checkpoint loaded successfully.")
except Exception as e:
    print(f"Error loading checkpoint file: {e}")
    sys.exit(1)

# --- Validate Checkpoint Contents ---
if not isinstance(checkpoint, dict):
    print(f"Error: Expected a checkpoint dict, got {type(checkpoint).__name__}.")
    sys.exit(1)
required_keys = ['model_state_dict', 'src_vocab', 'tgt_vocab']
for key in required_keys:
    if key not in checkpoint:
        print(f"Error: Required key '{key}' not found in the checkpoint.")
        sys.exit(1)

# --- Extract Vocab and Model Config ---
try:
    src_vocab = checkpoint['src_vocab']
    tgt_vocab = checkpoint['tgt_vocab']
    # Explicit check instead of assert (asserts are stripped under `python -O`).
    if not (isinstance(src_vocab, Vocab) and isinstance(tgt_vocab, Vocab)):
        raise TypeError("checkpoint 'src_vocab'/'tgt_vocab' are not Vocab instances")
    PAD_IDX = src_vocab.stoi.get('<pad>', 0)  # Use src_vocab pad index
    if 'config' in checkpoint:
        config = checkpoint['config']
        D_MODEL = config['d_model']
        NUM_HEADS = config['num_heads']
        NUM_LAYERS = config['num_layers']
        D_FF = config['d_ff']
        DROPOUT = config['dropout']
        SRC_VOCAB_SIZE = config['src_vocab_size']
        TGT_VOCAB_SIZE = config['tgt_vocab_size']
        print("Model configuration loaded from checkpoint.")
        if SRC_VOCAB_SIZE != len(src_vocab) or TGT_VOCAB_SIZE != len(tgt_vocab):
            print("Warning: Vocab size in config mismatches loaded vocab length!")
            print(f"Config Src:{SRC_VOCAB_SIZE}/Tgt:{TGT_VOCAB_SIZE}, Loaded Src:{len(src_vocab)}/Tgt:{len(tgt_vocab)}")
            # Use lengths from loaded vocabs as they are definitive
            SRC_VOCAB_SIZE = len(src_vocab)
            TGT_VOCAB_SIZE = len(tgt_vocab)
    else:
        # No saved config: recover the architecture from the weight shapes so it
        # cannot drift from training. num_heads leaves no trace in the weights,
        # so it falls back to 8 (the value train.py uses); dropout is inactive in eval.
        print("Warning: Model config not found in checkpoint. Inferring parameters from the state_dict.")
        state = checkpoint['model_state_dict']
        D_MODEL = state['encoder_embed.weight'].shape[1]
        D_FF = state['encoder_layers.0.feed_forward.0.weight'].shape[0]
        NUM_LAYERS = len({name.split('.')[1] for name in state if name.startswith('encoder_layers.')})
        NUM_HEADS = 8
        DROPOUT = 0.1
        SRC_VOCAB_SIZE = len(src_vocab)
        TGT_VOCAB_SIZE = len(tgt_vocab)
    print(f"Source vocab size: {len(src_vocab)}")
    print(f"Target vocab size: {len(tgt_vocab)}")
except Exception as e:
    print(f"Error processing vocabulary or config from checkpoint: {e}")
    sys.exit(1)

# --- Initialize Model ---
try:
    model = Transformer(
        src_vocab_size=SRC_VOCAB_SIZE,
        tgt_vocab_size=TGT_VOCAB_SIZE,
        d_model=D_MODEL,
        num_heads=NUM_HEADS,
        num_layers=NUM_LAYERS,
        d_ff=D_FF,
        dropout=DROPOUT,  # Dropout value is less critical for eval mode
    ).to(DEVICE)
    print("Model initialized.")

    def count_parameters(model):
        return sum(p.numel() for p in model.parameters())

    print(f'The model has {count_parameters(model):,} total parameters.')
except Exception as e:
    print(f"Error initializing the Transformer model: {e}")
    sys.exit(1)

# --- Load Model State ---
try:
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # Set model to evaluation mode
    print("Model state loaded successfully.")
except RuntimeError as e:
    print(f"Error loading model state_dict: {e}")
    print("This *strongly* indicates a mismatch between the loaded checkpoint's architecture")
    print("(implicit in state_dict keys/shapes) and the model initialized here.")
    print("Verify that the hyperparameters (D_MODEL, NUM_HEADS, NUM_LAYERS, D_FF, vocab sizes)")
    print("match *exactly* those used when the checkpoint was saved.")
    sys.exit(1)
except Exception as e:
    print(f"An unexpected error occurred while loading model state: {e}")
    sys.exit(1)

# --- Translate Function (greedy decoding, batch size 1) ---
def translate(sentence: str, model: nn.Module, src_vocab: Vocab, tgt_vocab: Vocab,
              device: torch.device, max_length: int = 50) -> str:
    """Translate one English sentence by greedy, character-by-character decoding.

    The source is encoded once. The decoder is then re-run on the growing target
    prefix, and the highest-scoring next token is appended until ``<eos>`` is
    produced or ``max_length`` tokens have been generated.

    Args:
        sentence: Source text; each character is looked up in ``src_vocab``
            (unknown characters map to ``<unk>``, or id 0 if the vocab has none).
        model: Transformer exposing ``encode``, ``decode`` and ``fc_out``.
        src_vocab: Source vocabulary used at training time.
        tgt_vocab: Target vocabulary used at training time.
        device: Device the model lives on.
        max_length: Maximum number of generated target tokens.

    Returns:
        The translation with ``<sos>``/``<eos>``/``<pad>`` removed, or an
        ``"[Error: ...]"`` string explaining why it could not be produced.
    """
    model.eval()
    if not isinstance(sentence, str):
        return "[Error: Invalid Input Type]"

    # Check both vocabularies before spending any compute on the model.
    src_sos_idx = src_vocab.stoi.get('<sos>')
    src_eos_idx = src_vocab.stoi.get('<eos>')
    src_unk_idx = src_vocab.stoi.get('<unk>', 0)
    src_pad_idx = src_vocab.stoi.get('<pad>', 0)
    if src_sos_idx is None or src_eos_idx is None:
        return "[Error: Bad Src Vocab]"
    tgt_sos_idx = tgt_vocab.stoi.get('<sos>')
    tgt_eos_idx = tgt_vocab.stoi.get('<eos>')
    tgt_pad_idx = tgt_vocab.stoi.get('<pad>', 0)
    if tgt_sos_idx is None or tgt_eos_idx is None:
        return "[Error: Bad Tgt Vocab]"

    src_ids = [src_sos_idx] + [src_vocab.stoi.get(ch, src_unk_idx) for ch in sentence] + [src_eos_idx]
    src_tensor = torch.tensor([src_ids], dtype=torch.long, device=device)  # (1, src_len)
    src_mask = (src_tensor != src_pad_idx).unsqueeze(1).unsqueeze(2)  # (1, 1, 1, src_len)

    tgt_ids = [tgt_sos_idx]
    with torch.no_grad():
        try:
            enc_output = model.encode(src_tensor, src_mask)  # (1, src_len, d_model)
        except Exception as e:
            print(f"Error during model encoding: {e}")
            return "[Error: Encoding Failed]"

        for step in range(max_length):
            tgt_tensor = torch.tensor([tgt_ids], dtype=torch.long, device=device)  # (1, cur_len)
            # Same causal + padding target mask that training uses (imported from train.py).
            _, tgt_mask = create_masks(src_tensor, tgt_tensor, tgt_pad_idx)
            try:
                output = model.decode(tgt_tensor, enc_output, src_mask, tgt_mask)
                logits = model.fc_out(output[:, -1, :])  # scores for the next token only
            except Exception as e:
                print(f"Error during model decoding step {step}: {e}")
                return "[Error: Decoding Failed]"
            next_id = logits.argmax(dim=-1).item()
            tgt_ids.append(next_id)
            if next_id == tgt_eos_idx:
                break

    special_indices = {tgt_vocab.stoi.get(tok, -999) for tok in ('<sos>', '<eos>', '<pad>')}
    return "".join(tgt_vocab.itos.get(idx, '<unk>') for idx in tgt_ids if idx not in special_indices)


test_sentences = [
    "Hello!",
    "How are you?",
    "This is a test.",
    "He plays football every weekend.",
    "She has a beautiful dog.",
    "The sun is shining brightly.",
    "I like to read books.",
    "They are going to the park.",
    "My favorite color is blue.",
    "We eat dinner at seven.",
    "The cat sleeps on the mat.",
    "Birds sing in the morning.",
    "He can swim very well.",
    "She writes a letter.",
    "The car is red.",
    "I see a big tree.",
    "They watch television.",
    "My brother is tall.",
    "We learn English at school.",
    "The flowers smell good.",
    "He drinks milk every day.",
    "She helps her mother.",
    "The book is on the table.",
    "I have two pencils.",
    "They live in a small house.",
    "My father works hard.",
    "We play games together.",
    "The moon is bright tonight.",
    "He wears a green shirt.",
    "She dances gracefully.",
    "The fish swims in the water.",
    "I want an apple.",
    "They visit their grandparents.",
    "My sister plays the piano.",
    "We go to bed early.",
    "The sky is clear.",
    "He listens to music.",
    "She draws a nice picture.",
    "The bus stops here.",
    "I feel happy today.",
    "They build a sandcastle.",
    "My friend is kind.",
    "We love to travel.",
    "The baby is crying.",
    "He eats an orange.",
    "She cleans her room.",
    "The door is open.",
    "I can ride a bike.",
    "They run in the field.",
    "My teacher is helpful.",
    "We study science.",
    "The stars are far away.",
    "He tells a funny story.",
    "She wears a pretty dress.",
    "The train is fast.",
    "I understand the lesson.",
    "They sing a happy song.",
    "My shoes are new.",
    "We walk to the store.",
    "The food is delicious.",
    "He reads a newspaper.",
    "She looks at the birds.",
    "The window is closed.",
    "I need some water.",
    "They plant a tree.",
    "My dog likes to play fetch.",
    "We visit the museum.",
    "The weather is warm.",
    "He fixes the broken toy.",
    "She calls her friend.",
    "The grass is green.",
    "I like ice cream.",
    "They go on a holiday.",
    "My mother cooks tasty food.",
    "We have a picnic.",
    "The river flows slowly.",
    "He throws the ball.",
    "She smiles at me.",
    "The mountain is high.",
    "I lost my key.",
    "They help the old man.",
    "My garden is beautiful.",
    "We share our toys.",
    "The answer is simple.",
    "He drives a blue car.",
    "She paints a landscape.",
    "The clock is on the wall.",
    "I am learning to code.",
    "They make a snowman.",
    "My homework is easy.",
    "We clean the house.",
    "The bird has a nest.",
    "He catches a fish.",
    "She studies for the exam.",
    "The bridge is long.",
    "I want to sleep.",
    "They are good friends.",
    "My cat is very playful.",
    "We are going to the beach.",
    "The coffee is hot.",
    "He gives her a gift.",
]

print("\n--- Starting Translation Examples ---")
for sentence in test_sentences:
    print("-" * 20)
    print(f"Input:      {sentence}")
    translation = translate(sentence, model, src_vocab, tgt_vocab, DEVICE, max_length=MAX_LENGTH)
    print(f"Translation: {translation}")
print("-" * 20)
print("Prediction finished.")

predict.py運行結果展示:

root@autodl-container-de94439c34-d719190d:~# python predict.py
Using device: cpu
Loading checkpoint from: best_model_subset.pth
Checkpoint loaded successfully.
Model configuration loaded from checkpoint.
Source vocab size: 2776
Target vocab size: 8209
Model initialized.
The model has 10,451,473 total parameters.
Model state loaded successfully.

--- Starting Translation Examples ---
--------------------
Input:      Hello!
Translation: 你好!
--------------------
Input:      How are you?
Translation: 你怎么樣?
--------------------
Input:      This is a test.
Translation: 這是一個測試。
--------------------
Input:      He plays football every weekend.
Translation: 他每周都踢足球。
--------------------
Input:      She has a beautiful dog.
Translation: 她有一只美麗的狗。
--------------------
Input:      The sun is shining brightly.
Translation: 太陽光明亮了。
--------------------
Input:      I like to read books.
Translation: 我喜歡讀書。
--------------------
Input:      They are going to the park.
Translation: 他們正在去公園。
--------------------
Input:      My favorite color is blue.
Translation: 我最喜歡的顏色是藍色。
--------------------
Input:      We eat dinner at seven.
Translation: 我們吃晚飯。
--------------------
Input:      The cat sleeps on the mat.
Translation: 貓睡在墊上。
--------------------
Input:      Birds sing in the morning.
Translation: 鳥在早晨唱歌。
--------------------
Input:      He can swim very well.
Translation: 他可以很好地游泳。
--------------------
Input:      She writes a letter.
Translation: 她寫信。
--------------------
Input:      The car is red.
Translation: 車是紅色的。
--------------------
Input:      I see a big tree.
Translation: 我看見一棵大樹。
--------------------
Input:      They watch television.
Translation: 他們看電視。
--------------------
Input:      My brother is tall.
Translation: 我的哥哥高。
--------------------
Input:      We learn English at school.
Translation: 我們學習英語。
--------------------
Input:      The flowers smell good.
Translation: 花香氣味好。
--------------------
Input:      He drinks milk every day.
Translation: 他每天喝牛奶。
--------------------
Input:      She helps her mother.
Translation: 她幫忙媽媽。
--------------------
Input:      The book is on the table.
Translation: 這本書是桌子上的。
--------------------
Input:      I have two pencils.
Translation: 我有兩個鉛筆。
--------------------
Input:      They live in a small house.
Translation: 他們住在一個小房子里。
--------------------
Input:      My father works hard.
Translation: 我爸爸爸很努力。
--------------------
Input:      We play games together.
Translation: 我們玩游戲。
--------------------
Input:      The moon is bright tonight.
Translation: 月亮今晚是明亮的。
--------------------
Input:      He wears a green shirt.
Translation: 他穿著綠色的襯衫。
--------------------
Input:      She dances gracefully.
Translation: 她很喜歡跳舞。
--------------------
Input:      The fish swims in the water.
Translation: 魚在水里游泳。
--------------------
Input:      I want an apple.
Translation: 我想要一個蘋果。
--------------------
Input:      They visit their grandparents.
Translation: 他們訪問他們的祖父母。
--------------------
Input:      My sister plays the piano.
Translation: 我的妹妹打鋼琴。
--------------------
Input:      We go to bed early.
Translation: 我們早些時候睡覺。
--------------------
Input:      The sky is clear.
Translation: 天空清晰。
--------------------
Input:      He listens to music.
Translation: 他聽音樂。
--------------------
Input:      She draws a nice picture.
Translation: 她畫了一張美麗的照片。
--------------------
Input:      The bus stops here.
Translation: 公共汽車停下來。
--------------------
Input:      I feel happy today.
Translation: 今天我感到快樂。
--------------------
Input:      They build a sandcastle.
Translation: 他們建造了一個沙子。
--------------------
Input:      My friend is kind.
Translation: 我的朋友是個好的。
--------------------
Input:      We love to travel.
Translation: 我們喜歡旅行。
--------------------
Input:      The baby is crying.
Translation: 這個寶寶正在哭泣。
--------------------
Input:      He eats an orange.
Translation: 他吃了一個橙色。
--------------------
Input:      She cleans her room.
Translation: 她潔凈房間。
--------------------
Input:      The door is open.
Translation: 門開了。
--------------------
Input:      I can ride a bike.
Translation: 我可以騎自行車。
--------------------
Input:      They run in the field.
Translation: 他們在田里跑。
--------------------
Input:      My teacher is helpful.
Translation: 老師很有幫助。
--------------------
Input:      We study science.
Translation: 我們研究科學。
--------------------
Input:      The stars are far away.
Translation: 星星遠遠遠。
--------------------
Input:      He tells a funny story.
Translation: 他告訴一個有趣的故事。
--------------------
Input:      She wears a pretty dress.
Translation: 她穿著一件衣服。
--------------------
Input:      The train is fast.
Translation: 火車快速。
--------------------
Input:      I understand the lesson.
Translation: 我理解課程。
--------------------
Input:      They sing a happy song.
Translation: 他們唱了一首快樂的歌。
--------------------
Input:      My shoes are new.
Translation: 我的鞋子是新的。
--------------------
Input:      We walk to the store.
Translation: 我們走到商店。
--------------------
Input:      The food is delicious.
Translation: 食物是美味的。
--------------------
Input:      He reads a newspaper.
Translation: 他讀了一篇報紙。
--------------------
Input:      She looks at the birds.
Translation: 她看著鳥兒。
--------------------
Input:      The window is closed.
Translation: 窗戶閉上了。
--------------------
Input:      I need some water.
Translation: 我需要一些水。
--------------------
Input:      They plant a tree.
Translation: 他們種了樹。
--------------------
Input:      My dog likes to play fetch.
Translation: 我的狗喜歡玩耍。
--------------------
Input:      We visit the museum.
Translation: 我們訪問博物館。
--------------------
Input:      The weather is warm.
Translation: 天氣暖暖。
--------------------
Input:      He fixes the broken toy.
Translation: 他把玩具固定了。
--------------------
Input:      She calls her friend.
Translation: 她打電話給她的朋友。
--------------------
Input:      The grass is green.
Translation: 草是綠色的。
--------------------
Input:      I like ice cream.
Translation: 我喜歡冰淇淋。
--------------------
Input:      They go on a holiday.
Translation: 他們一天去度假。
--------------------
Input:      My mother cooks tasty food.
Translation: 媽媽的菜吃了香味。
--------------------
Input:      We have a picnic.
Translation: 我們有一個野餐。
--------------------
Input:      The river flows slowly.
Translation: 河流慢慢慢。
--------------------
Input:      He throws the ball.
Translation: 他把球扔了。
--------------------
Input:      She smiles at me.
Translation: 她笑著我。
--------------------
Input:      The mountain is high.
Translation: 山高。
--------------------
Input:      I lost my key.
Translation: 我丟了我的鑰匙。
--------------------
Input:      They help the old man.
Translation: 他們幫助老人。
--------------------
Input:      My garden is beautiful.
Translation: 我的花園很美麗。
--------------------
Input:      We share our toys.
Translation: 我們分享我們的玩具。
--------------------
Input:      The answer is simple.
Translation: 答案簡單。
--------------------
Input:      He drives a blue car.
Translation: 他駕駛藍色的車。
--------------------
Input:      She paints a landscape.
Translation: 她畫了一幅景觀。
--------------------
Input:      The clock is on the wall.
Translation: 鐘聲在墻上。
--------------------
Input:      I am learning to code.
Translation: 我學習代碼。
--------------------
Input:      They make a snowman.
Translation: 他們制造雪人。
--------------------
Input:      My homework is easy.
Translation: 我的家庭工作很容易。
--------------------
Input:      We clean the house.
Translation: 我們清潔房子。
--------------------
Input:      The bird has a nest.
Translation: 鳥兒有巢。
--------------------
Input:      He catches a fish.
Translation: 他抓了一只魚。
--------------------
Input:      She studies for the exam.
Translation: 她對考試進行研究。
--------------------
Input:      The bridge is long.
Translation: 橋長。
--------------------
Input:      I want to sleep.
Translation: 我想睡得。
--------------------
Input:      They are good friends.
Translation: 他們是好朋友。
--------------------
Input:      My cat is very playful.
Translation: 我的貓是非常有趣的。
--------------------
Input:      We are going to the beach.
Translation: 我們要到海灘上去。
--------------------
Input:      The coffee is hot.
Translation: 咖啡是熱的。
--------------------
Input:      He gives her a gift.
Translation: 他給她一個禮物。
--------------------
Prediction finished.

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/82910.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/82910.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/82910.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

c#建筑行業財務流水賬系統軟件可上傳記賬憑證財務管理系統簽核功能

# financial_建筑行業 建筑行業財務流水賬系統軟件可上傳記賬憑證財務管理系統簽核功能 # 開發背景 軟件是給岳陽客戶定制開發一款建筑行業流水賬財務軟件。提供工程簽證單、施工日志、人員出勤表等信息記錄。 # 財務管理系統功能描述 1.可以自行設置記賬科目，做憑…

MySQL 8.0 OCP 1Z0-908 題目解析(2)

題目005 Choose two. Which two actions can obtain information about deadlocks? □ A) Run the SHOW ENGINE INNODB MUTEX command from the mysql client. □ B) Enable the innodb_status_output_locks global parameter. □ C) Enable the innodb_print_all_deadlock…

XA協議和Tcc

基于 XA 協議的兩階段提交 (2PC)。這是一種分布式事務協議，旨在保證在多個參與者（通常是不同的數據庫或資源管理器）共同參與的事務中，所有參與者要么都提交事務，要么都回滾事務，從而維護數據的一致性。 你…

數據分析-圖2-圖像對象設置參數與子圖

from matplotlib import pyplot as mp mp.figure('A figure', facecolor='gray') mp.plot([0,1],[1,2]) mp.figure('B figure', facecolor='lightgray') mp.plot([1,2],[2,1]) #如果figure中標題已創建，則不會新建窗口， #而是將舊窗口設置為當前窗口 mp.figure('A fig…

跳轉語句:break、continue、goto -《Go語言實戰指南》

在控制流程中，我們有時需要跳出當前循環或跳過當前步驟，甚至直接跳轉到指定位置。Go 提供了三種基本跳轉語句： • break：跳出當前 for、switch 或 select。• continue：跳過本輪循環，進入下一輪。• goto：…

Linux中find命令用法核心要點提煉

大家好，歡迎來到程序視點！我是你們的老朋友.小二！ 以下是針對Linux中find命令用法的核心要點提煉： 基礎語法結構 find [路徑] [選項] [操作] 路徑：查找目錄（.表當前目錄，/表根目錄）…

MQTT協議詳解:物聯網通信的輕量級解決方案

MQTT協議詳解：物聯網通信的輕量級解決方案 引言 在物聯網(IoT)快速發展的今天，設備間高效可靠的通信變得至關重要。MQTT(Message Queuing Telemetry Transport)作為一種輕量級的發布/訂閱協議，已成為物聯網通信的首選解決方案。本文將深入探…

list基礎用法

list基礎用法 1.list的訪問就不能用下標[]了，用迭代器 2.emplace_back()幾乎是與push_back()用法一致，但也有差別 3.insert()、erase()的用法 4.reverse() 5.排序 6.合并 7.unique()（去重） 8.splice剪切再粘貼 1.list的訪問就不能用下標[]了，用迭代器…

2025年第十六屆藍橋杯大賽軟件賽C/C++大學B組題解

第十六屆藍橋杯大賽軟件賽C/C++大學B組題解 試題A: 移動距離 問題描述 小明初始在二維平面的原點，他想前往坐標(233,666)。在移動過程中，他只能采用以下兩種移動方式，并且這兩種移動方式可以交替、不限次數地使用： 水平向右移動…

BGP實驗練習2

需求： 1.AS1存在兩個環回，一個地址為192.168.1.0/24，該地址不能在任何協議中宣告 AS3存在兩個環回，該地址不能在任何協議中宣告 AS1還有一個環回地址為10.1.1.0/24，AS3另一個環回地址是11.1.1.0/24 最終要求這兩…

【溫濕度物聯網】記錄1:寄存器配置

一，基地址 基地址base的定義： ↓ 定義完是這個： GPIOA的地址就是以上的代表 2，寄存器： 通過bsrr來改變odr寄存器，左移16位就是把0-15位的給移到高位的保留區，這樣就歸零了 3，項目寄存器實操…

MCP項目實例 - client server交互

1. 項目概述 項目目標 構建一個本地智能輿論分析系統。 利用自然語言處理和多工具協作，實現用戶查詢意圖的自動理解。 進行新聞檢索、情緒分析、結構化輸出和郵件推送。 系統流程 用戶查詢：用戶輸入查詢請求。 提取關鍵詞：從用戶查詢中…

運維體系架構規劃

運維體系架構規劃是一個系統性工程，旨在構建高效、穩定、安全的運維體系，保障業務系統的持續運行。下面從規劃目標、核心模塊、實施步驟等方面進行詳細闡述： 一、規劃目標 高可用性：確保業務系統 7×24 小時不間斷運行，…

zst-2001 上午題-歷年真題 計算機網絡(16個內容)

網絡設備 計算機網絡 - 第1題 ac 計算機網絡 - 第2題 d 計算機網絡 - 第3題 集線器不能隔離廣播域和沖突域，所以集線器就1個廣播域和沖突域 交換機就是那么的炫，可以隔離沖突域，有4個沖突域，但不能隔離廣播域，…

Python之with語句

文章目錄 Python中的with語句詳解一、基本語法二、工作原理三、文件操作中的with語句1. 基本用法2. 同時打開多個文件 四、with語句的優勢五、自定義上下文管理器1. 基于類的實現2. 使用contextlib模塊 六、常見應用場景七、注意事項 Python中的with語句詳解 with語句是Python…

我的五周年創作紀念日

五年前的今天，我在CSDN發布了第一篇《基于VS2015的MFC學習筆記（常用按鈕button）》，文末那句"歡迎交流"的忐忑留言，開啟了這段充滿驚喜的技術旅程。恍然發覺那些敲過的代碼早已成長為參天大樹。 收獲 獲得了…

Realtek 8126驅動分析第四篇——multi queue相關

Realtek 8126是 5G 網卡，因為和 8125 較為接近，第四篇從這里開始也無不可。本篇主要是講 multi queue 相關，其他的一些內容在之前就已經提過，不加贅述。 1 初始化 1.1 rtl8126_init_one 從第一篇我們可以知道每個 PCI 驅動都注…

使用PHP對接日本股票市場數據

本文將介紹如何通過StockTV提供的API接口，使用PHP語言來獲取并處理日本股票市場的數據。我們將以查詢公司信息、查看漲跌排行榜和實時接收數據為例，展示具體的操作流程。 準備工作 首先，請確保您已經從StockTV獲得了API密鑰，并且…

爬蟲工具與編程語言選擇指南

有人問爬蟲如何選擇工具和編程語言。根據我多年的經驗來說，我肯定得先分析不同場景下適合的工具和語言。 如果大家不知道其他語言，比如JavaScript（Node.js）或者Go，這些在特定情況下可能更合適。比如，如果…

C語言while循環的用法(非常詳細,附帶實例)

while 是 C 語言中的一種循環控制結構，用于在特定條件為真時重復執行一段代碼。 while 循環的語法如下： while (條件表達式) { // 循環體：條件為真時執行的代碼 } 條件表達式：返回真（非 0）或假（…