網絡結構,輸入為2個數,先經過10個節點的全連接層,再經過10個節點的ReLu,再經過10個節點的全連接層,再經過1個節點的全連接層,最后輸出。
#-*-coding:utf-8-*- import logging
import math
import random
import mxnet as mx # 導入 MXNet 庫
import numpy as np # 導入 NumPy 庫,這是 Python 常用的科學計算庫logging.getLogger().setLevel(logging.DEBUG) # 打開調試信息的顯示'''設置超參數'''
n_sample = 10000 # 訓練用的數據點個數
batch_size = 10 # 批大小
learning_rate = 0.1 # 學習速率
n_epoch = 10 # 訓練 epoch 數'''生成訓練數據'''
# 每個數據點是在 (0,1) 之間的 2 個隨機數
train_in = [[ random.uniform(0, 1) for c in range(2)] for n in range(n_sample)]
train_out = [0 for n in range(n_sample)] # 期望輸出,先初始化為 0
for i in range(n_sample):# 每個數據點的期望輸出是 2 個輸入數中的大者train_out[i] = max(train_in[i][0], train_in[i][1])'''定義train_iter為訓練數據的迭代器,data為輸入數據,label為標簽對應train_out,shuffle代表每個epoch會隨機打亂數據'''
train_iter = mx.io.NDArrayIter(data = np.array(train_in), label = {'reg_label':np.array(train_out)}, batch_size = batch_size, shuffle = True)'''定義網絡結構,src為輸入層,fc1,fc2,fc3是全連接層,act1,act2是ReLu層,num_hidden代表神經元個數,data是輸入數據,name是輸出'''
src = mx.sym.Variable('data') # 輸入層
fc1 = mx.sym.FullyConnected(data = src, num_hidden = 10, name = 'fc1') # 全連接層
act1 = mx.sym.Activation(data = fc1, act_type = "relu", name = 'act1') # ReLU層
fc2 = mx.sym.FullyConnected(data = act1, num_hidden = 10, name = 'fc2') # 全連接層
act2 = mx.sym.Activation(data = fc2, act_type = "relu", name = 'act2') # ReLU層
fc3 = mx.sym.FullyConnected(data = act2, num_hidden = 1, name = 'fc3') # 全連接層
'''定義net為輸出層,采用線性回歸輸出,MXNet會自動使用MSE作為損失函數,輸入數據為fc3,輸出層命名為reg'''
net = mx.sym.LinearRegressionOutput(data = fc3, name = 'reg') # 輸出層'''定義變量module需訓練的網絡模組,網絡的輸出symbol為net,期望標簽名label_names為reg_label'''
module = mx.mod.Module(symbol = net, label_names = (['reg_label']))'''定義module.fit進行訓練'''
module.fit(train_iter, # 訓練數據的迭代器eval_data = None, # 在此只訓練,不使用測試數據eval_metric = mx.metric.create('mse'), # 輸出 MSE 損失信息#將權重和偏置初始化為在[-0.5,0.5]間均勻的隨機數initializer=mx.initializer.Uniform(0.5),optimizer = 'sgd', # 梯度下降算法為 SGD# 設置學習速率optimizer_params = {'learning_rate': learning_rate}, num_epoch = n_epoch, # 訓練 epoch 數# 每經過 100 個 batch 輸出訓練速度 batch_end_callback = None, epoch_end_callback = None,
)#輸出最終參數
for k in module.get_params():print(k)