目錄
1. 基礎寫法
1.1導包
2.2加載讀取數據
2.3原始數據可視化(畫圖顯示)
2.4線性回歸的(基礎)分解寫法
2.5定義訓練過程
2.PyTorch實現 線性回歸的封裝寫法(實際項目中的常用寫法)
2.1創建線性回歸模型
2.2定義損失函數
2.3定義優化器
2.4定義訓練過程
1. 基礎寫法
1.1導包
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
2.2加載讀取數據?
data = pd.read_csv('./dataset/Income1.csv')
data
#讀取數據類型為dataframe類型
輸出結果截圖所示(部分數據)
data.head() #查看dataframe數據的前五條數據
data.tail() #后五條數據
data.Education.head() #查看數據的Education列的前五條數據 #是一個Series
0 10.000000 1 10.401338 2 10.842809 3 11.244147 4 11.645485 Name: Education, dtype: float64
data.Education[:5] #查看數據的Education列的前五條數據
0 10.000000 1 10.401338 2 10.842809 3 11.244147 4 11.645485 Name: Education, dtype: float64
2.3原始數據可視化(畫圖顯示)
#畫散點圖,觀察數據Education 與 Income 是否具有線性關系
plt.scatter(data.Education, data.Income)
plt.xlabel('Education')
plt.ylabel('Income')