# Class 10: Concise implementation (簡潔實現)
import torch
from torch import nn
from d2l import torch as d2l
# Layer widths: flattened input features, output classes, and the two hidden
# layers.
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256
# Dropout probabilities used after the first and second hidden layers.
dropout1, dropout2 = 0.2, 0.5

# Multilayer perceptron with two hidden layers, each followed by ReLU and
# dropout. Layer sizes are taken from the named constants above instead of
# repeating the numbers as literals, so a width only has to be changed once.
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(num_inputs, num_hiddens1),
    nn.ReLU(),
    nn.Dropout(dropout1),
    nn.Linear(num_hiddens1, num_hiddens2),
    nn.ReLU(),
    nn.Dropout(dropout2),
    nn.Linear(num_hiddens2, num_outputs),
)
def init_weights(m):
    """Re-initialize the weight of a plain ``nn.Linear`` module in place.

    Written to be passed to ``nn.Module.apply``, which calls it once per
    submodule. Modules of any other type are ignored, and the bias of a
    linear layer is not modified.

    Args:
        m: The module to inspect.
    """
    # Exact type match on purpose: subclasses of nn.Linear are skipped,
    # exactly as with the original ``type(m) == nn.Linear`` comparison.
    if type(m) is nn.Linear:
        # Draw weights from a normal distribution with standard deviation
        # 0.01 (the mean is left at normal_'s default).
        nn.init.normal_(m.weight, std=0.01)
# Visit every submodule of the network and let init_weights reset the
# linear layers' weights.
net.apply(init_weights)

# Training hyperparameters.
num_epochs = 10
lr = 0.5
batch_size = 256

# reduction='none' makes the loss return one value per example instead of a
# single averaged scalar; presumably d2l.train_ch3 does the reduction itself
# -- confirm against the installed d2l version.
loss = nn.CrossEntropyLoss(reduction='none')

# Fashion-MNIST iterators for training and evaluation.
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# Plain SGD over all parameters of the network.
trainer = torch.optim.SGD(net.parameters(), lr=lr)

d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)