lstm 我有 20*6 條數據,20個樣本,每個樣本6條歷史數據,每條數據有5個值,我送給網絡輸入時應該是20*6*5 還是 6*20*5
你的數據是:
20 個樣本(batch size = 20)
每個樣本有 6 條歷史數據(sequence length = 6)
每條數據有 5 個值(input size = 5)
? 正確的輸入形狀是:
(20, 6, 5) # 即 batch_size=20, seq_len=6, input_size=5
前提是你在定義 LSTM 時設置了:
nn.LSTM(input_size=5, hidden_size=..., batch_first=True)