TensorFlow的可訓練變量和自動求導機制

文章目錄

  • 一些概念、函數、用法
  • TensorFlow實現一元線性回歸
  • TensorFlow實現多元線性回歸


一些概念、函數、用法

對象Variable

創建對象Variable:

tf.Variable(initial_value,dtype)

利用這個方法,默認整數為int32,浮點數為float32,注意Numpy默認的浮點數類型是float64,如果想和Numpy數據進行對比,則需要修改與numpy一致,否則在機器學習中float32位夠用了。
將張量封裝為可訓練變量

print(tf.Variable(tf.random. normal([2,2])))

<tf.Variable ‘Variable:0’ shape=(2, 2) dtype=float32, numpy=array([[-1.2848959 , -0.22805293],[-0.79079854, 0.7035335 ]], dtype=float32)>

trainalbe屬性
用來檢查Variable變量是否可訓練

x.trainalbe

可訓練變量賦值,注意x是Variable對象類型,不是tensor類型

x.assign()
x.assign_add()
x.assign_sub()

用isinstance()方法來判斷是tensor還是Variable
在這里插入圖片描述
自動求導

with GradientTape() as tape:
函數表達式
grad=tape.gradient(函數,自變量)

x=tf.Variable(3.)
with tf.GradientTape() as tape:y=tf.square(x)
dy_dx = tape.gradient(y,x)
print(y)
print(dy_dx)

tf.Tensor(9.0, shape=(), dtype=float32)
tf.Tensor(6.0, shape=(), dtype=float32)

GradientTape函數

GradientTape(persistent,watch_accessed_variables)
第一個參數默認為false,表示梯度帶只使用一次,使用完就銷毀了,若為true則表明梯度帶可以多次使用,但在循環最后要記得把它銷毀
第二個參數默認為true,表示自動添加監視

tape.watch()函數
用來添加監視非可訓練變量
多元函數求一階偏導數

x=tf.Variable(3.)
y=tf.Variable(4.)
with tf.GradientTape(persistent=True) as tape:f=tf.square(x)+2*tf.square(y)+1
df_dx,df_dy = tape.gradient(f,[x,y])
first_grade = tape.gradient(f,[x,y])
print(f)
print(df_dx)
print(df_dy)
print(first_grade)
del tape

tf.Tensor(42.0, shape=(), dtype=float32)
tf.Tensor(6.0, shape=(), dtype=float32)
tf.Tensor(16.0, shape=(), dtype=float32)
[<tf.Tensor: id=36, shape=(), dtype=float32, numpy=6.0>, <tf.Tensor: id=41, shape=(), dtype=float32, numpy=16.0>]

多元函數求二階偏導數
在這里插入圖片描述

x=tf.Variable(3.)
y=tf.Variable(4.)
with tf.GradientTape(persistent=True) as tape2:with tf.GradientTape(persistent=True) as tape1:f=tf.square(x)+2*tf.square(y)+1first_grade = tape1.gradient(f,[x,y])
second_grade = [tape2.gradient(first_grade,[x,y])]
print(f)
print(first_grade)
print(second_grade)
del tape1
del tape2

tf.Tensor(42.0, shape=(), dtype=float32)
[<tf.Tensor: id=27, shape=(), dtype=float32, numpy=6.0>, <tf.Tensor: id=32, shape=(), dtype=float32, numpy=16.0>]
[[<tf.Tensor: id=39, shape=(), dtype=float32, numpy=2.0>, <tf.Tensor: id=41, shape=(), dtype=float32, numpy=4.0>]]

TensorFlow實現一元線性回歸

import numpy as np
import tensorflow as tf 
import matplotlib.pyplot as plt 
#設置字體
plt.rcParams['font.sans-serif'] =['SimHei']
#加載樣本數據
x=np.array([137.97,104.50,100.00,124.32,79.20,99.00,124.00,114.00,106.69,138.05,53.75,46.91,68.00,63.02,81.26,86.21])
y=np.array([145.00,110.00,93.00,116.00,65.32,104.00,118.00,91.00,62.00,133.00,51.00,45.00,78.50,69.65,75.69,95.30])
#設置超參數,學習率
learn_rate=0.0001
#迭代次數
iter=100
#每10次迭代顯示一下效果
display_step=10
#設置模型參數初值
np.random.seed(612)
w=tf.Variable(np.random.randn())
b=tf.Variable(np.random.randn())
#訓練模型
#存放每次迭代的損失值
mse=[]
for i in range(0,iter+1):with tf.GradientTape() as tape:pred=w*x+bLoss=0.5*tf.reduce_mean(tf.square(y-pred))mse.append(Loss)#更新參數dL_dw,dL_db = tape.gradient(Loss,[w,b])w.assign_sub(learn_rate*dL_dw)b.assign_sub(learn_rate*dL_db)#plt.plot(x,pred)if i%display_step==0:print("i:%i,Loss:%f,w:%f,b:%f"%(i,mse[i],w.numpy(),b.numpy()))

TensorFlow實現多元線性回歸

import numpy as np
import tensorflow as tf #=======================【1】加載樣本數據===============================================
area=np.array([137.97,104.50,100.00,124.32,79.20,99.00,124.00,114.00,106.69,138.05,53.75,46.91,68.00,63.02,81.26,86.21])
room=np.array([3,2,2,3,1,2,3,2,2,3,1,1,1,1,2,2])
price=np.array([145.00,110.00,93.00,116.00,65.32,104.00,118.00,91.00,62.00,133.00,51.00,45.00,78.50,69.65,75.69,95.30])
num=len(area) #樣本數量
#=======================【2】數據處理===============================================
x0=np.ones(num)
#歸一化處理,這里使用線性歸一化
x1=(area-area.min())/(area.max()-area.min())
x2=(room-room.min())/(room.max()-room.min())
#堆疊屬性數組,構造屬性矩陣
#從(16,)到(16,3),因為新出現的軸是第二個軸所以axis為1
X=np.stack((x0,x1,x2),axis=1)
print(X)
#得到形狀為一列的數組
Y=price.reshape(-1,1)
print(Y)
#=======================【3】設置超參數===============================================
learn_rate=0.001
#迭代次數
iter=500
#每10次迭代顯示一下效果
display_step=50
#=======================【4】設置模型參數初始值===============================================
np.random.seed(612)
W=tf.Variable(np.random.randn(3,1))
#=======================【4】訓練模型=============================================
mse=[]
for i in range(0,iter+1):with tf.GradientTape() as tape:PRED=tf.matmul(X,W)Loss=0.5*tf.reduce_mean(tf.square(Y-PRED))mse.append(Loss)#更新參數dL_dw = tape.gradient(Loss,W)W.assign_sub(learn_rate*dL_dw)#plt.plot(x,pred)if i % display_step==0:print("i:%i,Loss:%f"%(i,mse[i]))

喜歡的話點個贊和關注唄!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/378345.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/378345.shtml
英文地址,請注明出處:http://en.pswp.cn/news/378345.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

linux samba安裝失敗,用aptitude安裝samba失敗

版本&#xff1a;You are using Ubuntu 10.04 LTS- the Lucid Lynx - released in April 2010 and supported until April 2013.root下執行aptitude install sambaReading package lists... DoneBuilding dependency treeReading state information... DoneReading extended st…

django第二個項目--使用模板做一個站點訪問計數器

上一節講述了django和第一個項目HelloWorld&#xff0c;這節我們講述如何使用模板&#xff0c;并做一個簡單的站點訪問計數器。 1、建立模板 在myblog模塊文件夾&#xff08;即包含__init__.py的文件夾)下面新建一個文件夾templates&#xff0c;用于存放HTML模板&#xff0c;在…

strictmath_Java StrictMath log10()方法與示例

strictmathStrictMath類log10()方法 (StrictMath Class log10() method) log10() method is available in java.lang package. log10()方法在java.lang包中可用。 log10() method is used to return the logarithm of the given (base 10) of the given argument in the method…

30、深入理解計算機系統筆記,并發編程(concurrent)(2)

1、共享變量 1&#xff09;線程存儲模型 線程由內核自動調度&#xff0c;每個線程都有它自己的線程上下文&#xff08;thread context&#xff09;&#xff0c;包括一個惟一的整數線程ID&#xff08;Thread ID,TID&#xff09;&#xff0c;棧&#xff0c;棧指針&#xff0c;程序…

PostgreSQL在何處處理 sql查詢之十三

繼續&#xff1a; /*--------------------* grouping_planner* Perform planning steps related to grouping, aggregation, etc.* This primarily means adding top-level processing to the basic* query plan produced by query_planner.** tuple_fraction i…

【視覺項目】基于梯度的NCC模板匹配代碼以及效果

文章目錄流程分析工程代碼【1】NCC代碼【Ⅰ】sttPxGrdnt結構體【Ⅱ】sttTemplateModel模板結構體【Ⅲ】calcAccNCC計算ncc系數函數【Ⅳ】searchNcc NCC模板匹配函數【Ⅴ】searchSecondNcc 二級搜索&#xff1a;在某一特定點周圍再以步進為1搜索【2】測試圖轉外輪廓【Ⅰ】孔洞填…

第七章 再談抽象

第七章 再談抽象 對象魔法 多態&#xff1a;可對不同類型的對象執行相同的操作&#xff0c;而這些操作就像“被施了魔法”一樣能夠正常運行。(即&#xff1a;無需知道對象的內部細節就可使用它)&#xff08;無需知道對象所屬的類&#xff08;對象的類型&#xff09;就能調用其…

c語言math乘法,JavaScript用Math.imul()方法進行整數相乘

1. 基本概念Math.imul()方法用于計算兩個32位整數的乘積&#xff0c;它的結果也是32位的整數。JavaScript的Number類型同時包含了整數和浮點數&#xff0c;它沒有專門的整型和浮點型。因此&#xff0c;Math.imul()方法能提供類似C語言的整數相乘的功能。我們將Math.imul()方法的…

java scanner_Java Scanner nextLong()方法與示例

java scanner掃描器類的nextLong()方法 (Scanner Class nextLong() method) Syntax: 句法&#xff1a; public long nextLong();public long nextLong(int rad);nextLong() method is available in java.util package. nextLong()方法在java.util包中可用。 nextLong() method…

技術總監和CTO的區別 淺談CTO的作用----軟件公司如何開源節流(一)

我一直在思考軟件公司如何開源節流。當然&#xff0c;老板也在思考開源節流。當然&#xff0c;老板思考的開源節流在公司運營層面上&#xff0c;而我作為CTO&#xff0c;我考慮的則是在產品運營角度上來思考這個問題。否則&#xff0c;一個軟件公司&#xff0c;它的生存與發展就…

梯度下降法預測波士頓房價以及簡單的模型評估

目錄原理代碼關于歸一化的思考原理 觀察數據可知屬性之間差距很大&#xff0c;為了平衡所有的屬性對模型參數的影響&#xff0c;首先進行歸一化處理。 每一行是一個記錄&#xff0c;每一列是個屬性&#xff0c;所以對每一列進行歸一化。 二維數組歸一化&#xff1a;1、循環方式…

Windows Phone 內容滑動切換實現

在新聞類的APP中&#xff0c;有一個經常使用的場景&#xff1a;左右滑動屏幕來切換上一條或下一條新聞。 那么通常我們該使用哪種方式去實現呢&#xff1f;可以參考一下Demo的實現步驟。 1&#xff0c;添加Windows Phone用戶自定義控件。例如&#xff1a; 這里我為了演示的方便…

c語言interrupt函數,中斷處理函數數組interrupt[]初始化

在系統初始化期間&#xff0c;trap_init()函數將對中斷描述符表IDT進行第二次初始化(第一次只是建一張IDT表&#xff0c;讓其指向ignore_intr函數)&#xff0c;而在這次初始化期間&#xff0c;系統的0~19號中斷(用于分NMI和異常的中斷向量)均被設置好。與此同時&#xff0c;用于…

bytevalue_Java Number byteValue()方法與示例

bytevalueNumber類byteValue()方法 (Number Class byteValue() method) byteValue() method is available in java.lang package. byteValue()方法在java.lang包中可用。 byteValue() method is used to return the value denoted by this Number object converted to type byt…

第二章 染色熱力學理論單元測驗

1,()測定是染色熱力學性能研究的基礎 吸附等溫線。 2,吸附是放熱反應,溫度升高,親和力() 減小 3,染色系統中包括() 染料。 染深色介質。 染色助劑。 纖維。 4,下列對狀態函數特點敘述正確的為() 狀態函數只有在平衡狀態的系統中才有確定值。 在非平衡狀態的系統…

使用鳶尾花數據集實現一元邏輯回歸、多分類問題

目錄鳶尾花數據集邏輯回歸原理【1】從線性回歸到廣義線性回歸【2】邏輯回歸【3】損失函數【4】總結TensorFlow實現一元邏輯回歸多分類問題原理獨熱編碼多分類的模型參數損失函數CCETensorFlow實現多分類問題獨熱編碼計算準確率計算交叉熵損失函數使用花瓣長度、花瓣寬度將三種鳶…

開源HTML5應用開發框架 - iio Engine

隨著HTML5的發展&#xff0c;越來越多的基于HTML5技術的網頁開發框架出現&#xff0c;在今天的這篇文章中&#xff0c;我們將介紹iio Engine&#xff0c;它是一款開源的創建HTML5應用的web框架。整個框架非常的輕量級&#xff0c;只有45kb大小&#xff0c;并且整合了debug系統&…

c語言double root,C語言修仙

root(1)(2/2)AD1AD4林潯合理推測&#xff0c;青城山劍宗&#xff0c;也就是祁云所在的劍修一脈&#xff0c;掌握著一些道修并不知道的傳承。譬如——怎樣找到赤霄龍雀劍&#xff0c;又或者&#xff0c;怎樣使用它。這樣一來&#xff0c;青城的守衛陣法沒有反應也能解釋了&#…

【轉】Black Box

Introduction BlackBox是FPGA設計中一個重要的技巧&#xff0c;不過覺得Xilinx的文檔沒有很好地將它講清楚。 BlackBox的主要想法就是把設計的某一個子模塊單獨綜合&#xff0c;綜合的結果作為一個黑盒子子模塊&#xff0c;上層設計不再對這個模塊進行優化&#xff0c;只能看到…

Java Compiler disable()方法與示例

編譯器類disable()方法 (Compiler Class disable() method) disable() method is available in java.lang package. disable()方法在java.lang包中可用。 disable() method is used to cause the compiler to stop operation. disable()方法用于使編譯器停止操作。 disable() m…