# Two-layer feed-forward network (FFN) module with a small smoke test.
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
# Configure the root logger at import time: INFO level, with each record
# formatted as "<timestamp> <level>: <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s')
# Define the FFN layer
class FeedForwardNetwork(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear.

    Both linear layers act on the last dimension of the input, so any
    leading dimensions are passed through unchanged.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Number of features produced by the first linear layer.
        output_dim: Size of the last dimension of the output tensor.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply linear1, ReLU, then linear2 to ``x`` and return the result."""
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x
# Test the FFN layer
def test_ffn():
    """Smoke-test FeedForwardNetwork on a random batch.

    Builds a small network, runs one forward pass on random input, prints
    the input tensor and both shapes, and checks that the output keeps the
    leading (batch, sequence) dimensions with ``output_dim`` features.

    Raises:
        AssertionError: If the output shape is not
            (batch_size, seq_length, output_dim).
    """
    input_dim = 4
    hidden_dim = 8
    output_dim = 4
    batch_size = 5
    seq_length = 6
    # Create the FFN layer
    ffn = FeedForwardNetwork(input_dim, hidden_dim, output_dim)
    # Create random input data of shape (batch_size, seq_length, input_dim)
    input_data = torch.randn(batch_size, seq_length, input_dim)
    print(input_data)
    # Forward pass
    output_data = ffn(input_data)
    print("Input shape:", input_data.shape)
    print("Output shape:", output_data.shape)
    # Verify the result rather than only printing it.
    expected_shape = (batch_size, seq_length, output_dim)
    assert tuple(output_data.shape) == expected_shape, (
        f"unexpected output shape {tuple(output_data.shape)}, expected {expected_shape}"
    )
# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test_ffn()