用TensorFlow進行邏輯回歸(六)

import tensorflow as tf

import numpy as np

from tensorflow.keras.datasets import mnist

import time

# MNIST數據集參數

num_classes = 10? # 數字0到9, 10類

num_features = 784? # 28*28

# 訓練參數

learning_rate = 0.01

training_steps = 1000

batch_size = 256

display_step =50

# 預處理數據集

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 轉為float32

x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)

# 轉為一維向量

x_train, x_test = x_train.reshape([-1, num_features]), x_test.reshape([-1, num_features])

# [0, 255] 到 [0, 1]

x_train, x_test = x_train / 255, x_test / 255

# tf.data.Dataset.from_tensor_slices 是使用x_train, y_train構建數據集

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))

# 將數據集打亂,并設置batch_size大小

train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)

# 權重[748, 10],圖片大小28*28,類數

W = tf.Variable(tf.ones([num_features, num_classes]), name="weight")

# 偏置[10],共10類

b = tf.Variable(tf.zeros([num_classes]), name="bias")

# 邏輯回歸函數

def logistic_regression(x):

??? return tf.nn.softmax(tf.matmul(x, W) + b)

# 損失函數

def cross_entropy(y_pred, y_true):

??? # tf.one_hot()函數的作用是將一個值化為一個概率分布的向量

??? y_true = tf.one_hot(y_true, depth=num_classes)

??? # tf.clip_by_value將y_pred的值控制在1e-9和1.0之間

??? y_pred = tf.clip_by_value(y_pred, 1e-9, 1.0)

??? return tf.reduce_mean(-tf.reduce_sum(y_true * tf.math.log(y_pred)))

# 計算精度

def accuracy(y_pred, y_true):

??? # tf.cast作用是類型轉換

??? correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64))

??? return tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# 優化器采用隨機梯度下降

optimizer = tf.optimizers.SGD(learning_rate)

# 梯度下降

def run_optimization(x, y):

??? with tf.GradientTape() as g:

??????? pred = logistic_regression(x)

??????? loss = cross_entropy(pred, y)

??? # 計算梯度

??? gradients = g.gradient(loss, [W, b])

??? # 更新梯度

?? ?optimizer.apply_gradients(zip(gradients, [W, b]))

# 開始訓練

start = time.perf_counter()

for epoch in range(5):

??? for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):

??????? run_optimization(batch_x, batch_y)

??????? if step % display_step == 0:

??????????? pred = logistic_regression(batch_x)

??????????? loss = cross_entropy(pred, batch_y)

??????????? acc = accuracy(pred, batch_y)

??????????? print("step: %i, loss: %f, accuracy: %f" % (step, loss, acc))

???

# 測試模型的準確率

pred = logistic_regression(x_test)

print("Test Accuracy: %f" % accuracy(pred, y_test))

elapsed = (time.perf_counter() - start)

print("Time used:",elapsed)

例3

import? matplotlib.pyplot as plt

import? numpy as np

import tensorflow as tf

print(tf.__version__)

%matplotlib inline

mnist = tf.keras.datasets.mnist

(train_images,train_labels),(test_images,test_labels)=mnist.load_data()

total_num=len(train_images)

valid_split=0.2

train_num =int(total_num*(1-valid_split))

train_x=train_images[:train_num]

train_y=train_labels[:train_num]

valid_x=train_images[train_num:]

valid_y=train_labels[train_num:]

test_x=test_images

test_y=test_labels

train_x=train_x.reshape(-1,784)

valid_x=valid_x.reshape(-1,784)

test_x=test_x.reshape(-1,784)

train_x=tf.cast(train_x/255.0,tf.float32)

valid_x=tf.cast(valid_x/255.0,tf.float32)

test_x=tf.cast(test_x/255.0,tf.float32)

train_y=tf.one_hot(train_y,depth=10)

valid_y=tf.one_hot(valid_y,depth=10)

test_y=tf.one_hot(test_y,depth=10)

#定義模型函數

def model(x,w,b):

??? pred=tf.matmul(x,w)+b

??? return tf.nn.softmax(pred)

np.random.seed(612)

W = tf.Variable(np.random.randn(784,10),dtype=tf.float32)

B = tf.Variable(np.random.randn(10),dtype=tf.float32)

def loss(x,y,w,b):

??? pred = model(x,w,b)

??? loss_=tf.keras.losses.categorical_crossentropy(y_true=y,y_pred=pred)

??? return tf.reduce_mean(loss_)

#設置迭代次數和學習率

train_epochs = 100

batch_size=50

learning_rate = 0.001

def grad(x,y,w,b):

??? with tf.GradientTape() as tape:

??????? loss_ = loss(x,y,w,b)

??? return tape.gradient(loss_,[w,b])

optimizer= tf.keras.optimizers.Adam(learning_rate=learning_rate)

def accuracy(x,y,w,b):

???? pred = model(x,w,b)

???? correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))

???? return tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

#構建線性函數的斜率和截距

total_step=int(train_num/batch_size)

loss_list_train = []

loss_list_valid = []

acc_list_train = []

acc_valid_train = []

training_epochs=100

#開始訓練,輪數為epoch,采用SGD隨機梯度下降優化方法

for epoch in range(training_epochs):

??? for step in range(total_step):

??????? xs=train_x[step*batch_size:(step+1)*batch_size]

??????? ys=train_y[step*batch_size:(step+1)*batch_size]

??????? #計算損失,并保存本次損失計算結果

??????? grads=grad(xs,ys,W,B)

??????? optimizer.apply_gradients(zip(grads,[W,B]))

??? loss_train =loss(train_x,train_y,W,B).numpy()

??? loss_valid =loss(valid_x,valid_y,W,B).numpy()

??? acc_train=accuracy(train_x,train_y,W,B).numpy()

??? acc_valid=accuracy(valid_x,valid_y,W,B).numpy()

??? loss_list_train.append(loss_train)

??? loss_list_valid.append(loss_valid)

??? acc_list_train.append(acc_train)

??? acc_valid_train.append(acc_valid)

print("epoch={:3d},train_loss={:.4f},train_acc={:.4f},val_loss={:.4f},val_acc={:.4f}".format(epoch+1,loss_train,acc_train,loss_valid,acc_valid))

plt.xlabel("Epochs")

plt.ylabel("Loss")

plt.plot(loss_list_train,'blue',label="Train Loss")

plt.plot(loss_list_valid,'red',label="Valid Loss")

plt.xlabel("Epochs")

plt.ylabel("Accuracy")

plt.plot(acc_list_train,'blue',label="Train Acc")

plt.plot(acc_valid_train,'red',label="Valid Acc")

acc_test=accuracy(test_x,test_y,W,B).numpy

print("Test accuracy:",acc_test)

def predict(x,w,b):

??? pred=model(x,w,b)

??? result=tf.argmax(pred,1).numpy

return result

pred_test=predict(test_x,W,B)

pred_test

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/91293.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/91293.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/91293.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【HTTP版本演變】

在瀏覽器中輸入URL并按回車之后會發生什么1. 輸入URL并解析輸入URL后,瀏覽器會解析出協議、主機、端口、路徑等信息,并構造一個HTTP請求(瀏覽器會根據請求頭判斷是否又HTTP緩存,并根據是否有緩存決定從服務器獲取資源還是使用緩存…

Android 16系統源碼_窗口動畫(一)窗口過渡動畫層級圖分析

一 窗口過渡動畫 1.1 案例效果圖1.2 案例源碼 1.2.1 添加權限 (AndroidManifest.xml) <!-- 系統懸浮窗權限&#xff08;Android 6.0需動態請求&#xff09; --> <uses-permission android:name"android.permission.SYSTEM_ALERT_WINDOW" />1.2.2 窗口顯示…

騰訊云WAF域名分級防護實戰筆記

基于業務風險等級、合規要求及騰訊云最佳實踐&#xff0c;提供可直接落地的配置方案&#xff0c;供學習借鑒&#xff1a;一、域名分級與防護原則1. ?域名分級清單&#xff08;核心資產&#xff09;???主域名??業務類型??風險等級??合規要求??防護等級?example.com…

1. 請說出你知道的水平垂直居中的方法

總結 容器 flex 布局&#xff0c;jsutify-content: center; align-items: center;容器 flex 布局&#xff0c;子項 margin: auto;容器 relative 布局&#xff0c;子項 absolute 布局&#xff0c;left: 50%; top: 50%; transform: translate(-50%, -50%);子項 absolute 布局&…

VS Code `launch.json` 完整配置指南:參數詳解 + 配置實例

文章目錄&#x1f4e6; 一、基本結構&#x1f50d; 二、單個配置項詳解示例配置&#xff1a;&#x1f9e9; 三、字段說明與可選值&#x1f4c1; 四、常用變量&#xff08;宏替換&#xff09;&#x1f6e0;? 五、常見配置實例1?? 調試當前打開的 .py 文件2?? 調試 Jupyter …

使用瀏覽器inspect調試wx小程序

edge://inspect/#devices調試wx小程序 背景&#xff1a; 在開發混合項目的過程中&#xff0c;常常需要在app環境排查問題&#xff0c;接口可以使用fiddler等工具來抓包&#xff0c;但是js錯誤就不好抓包了&#xff0c;這里介紹一種調試工具-瀏覽器。 調試過程 首先電腦打開edg…

【論文閱讀】-《Simple Black-box Adversarial Attacks》

簡單黑盒對抗攻擊 Chuan Guo Jacob R. Gardner Yurong You Andrew Gordon Wilson Kilian Q. Weinberger 摘要 我們提出了一種在黑盒&#xff08;black-box&#xff09;場景下構建對抗樣本&#xff08;adversarial images&#xff09;的極其簡單的方法。與白盒&#xff08;…

基于ASP.NET+SQL Server實現(Web)企業進銷存管理系統

企業進銷存管理系統的設計和實現一、摘要進銷存管理是現代企業生產經營中的重要環節&#xff0c;是完成企業資源配置的重要管理工作&#xff0c;對企業生產經營效率的最大化發揮著重要作用。本文以我國中小企業的進銷存管理為研究對象&#xff0c;描述了企業進銷存管理系統從需…

(LeetCode 面試經典 150 題 ) 15. 三數之和 (排序+雙指針)

題目&#xff1a;15. 三數之和 思路&#xff1a;排序雙指針&#xff0c;時間復雜度0(n^2nlogn)。 先將數組nums升序排序&#xff0c;方便去重和使用雙指針。第一層for循環來枚舉第一位數&#xff0c;后面使用雙指針來找到第二個、第三個數即可&#xff0c;細節看注釋。 C版本…

easy-springdoc

介紹 簡化springdoc的使用&#xff08;可以搭配knife4j-openapi3-jakarta-spring-boot-starter一起使用&#xff09; maven引用 <dependency><groupId>io.github.xiaoyudeguang</groupId><artifactId>easy-springdoc</artifactId><version>…

配置nodejs,若依

1.配置node.js環境 Node.js — Download Node.js 1.下載好一路下一步&#xff0c;可以安裝到d盤 裝完之后執行 npm -v 顯示版本號即安裝成功 2.安裝好后新建兩個文件夾&#xff0c;node_cache和node_global 3.配置環境變量 新建變量 在path里編輯變量 4.配置用戶變量 5.…

Python學習之路(十二)-開發和優化處理大數據量接口

文章目錄一、接口設計原則二、性能優化策略1. 數據庫優化2. 緩存機制3. 并發模型三、內存管理技巧1. 內存優化實踐2. 避免內存泄漏四、接口測試與監控1. 性能測試2. 日志與監控3. 錯誤處理與限流五、代碼示例&#xff08;Flask 流式處理&#xff09;六、部署建議一、接口設計原…

【實時Linux實戰系列】實時數據流的網絡傳輸

在實時系統中&#xff0c;數據流的實時傳輸是許多應用場景的核心需求之一。無論是工業自動化中的傳感器數據、金融交易中的高頻數據&#xff0c;還是多媒體應用中的視頻流&#xff0c;都需要在嚴格的時間約束內完成數據的傳輸。實時數據流的傳輸不僅要求高吞吐量&#xff0c;還…

C#數組(一維數組、多維數組、交錯數組、參數數組)

在 C# 中&#xff0c;數組是一種用于存儲固定大小的相同類型元素的集合。數組可以包含值類型、引用類型或對象類型的元素&#xff0c;并且在內存中是連續存儲的。以下是關于 C# 數組的詳細介紹&#xff1a;1. 一維數組聲明與初始化// 聲明數組 int[] numbers; // 聲…

Dify離線安裝包-集成全部插件、模板和依賴組件,方便安可內網使用

項目介紹 Dify一鍵離線安裝包&#xff0c;集成安裝了全部插件、模板&#xff0c;并集成了dify全部插件所需的依賴組件。方便你在內網、安可環境等離線狀態下使用。 Dify是一個開源的LLM應用開發平臺。其直觀的界面結合了AI工作流、RAG管道、Agent、模型管理、可觀測性功能等&…

面試150 翻轉二叉樹

思路 采用先序遍歷&#xff0c;可以通過新建根節點node&#xff0c;將原來root的右子樹連到去node的左子樹中&#xff0c;root的左子樹連到去node的右子樹中。 # Definition for a binary tree node. # class TreeNode: # def __init__(self, val0, leftNone, rightNone): …

C++-linux系統編程 3.gcc編譯工具

GCC編譯工具鏈完全指南 GCC&#xff08;GNU Compiler Collection&#xff09;是Linux系統下最常用的編譯器套件&#xff0c;支持C、C、Objective-C等多種編程語言。本章將深入講解GCC的編譯流程、常用選項及項目實戰技巧。 一、GCC編譯的四個核心階段 GCC編譯一個程序需要經過四…

uView UI 組件大全

uView UI 是一個基于 uni-app 的高質量 UI 組件庫&#xff0c;提供豐富的跨平臺組件&#xff08;支持 H5、小程序、App 等&#xff09;。以下是其核心組件的分類大全及功能說明&#xff0c;結合最新版本&#xff08;1.2.10&#xff09;整理&#xff1a; &#x1f4e6; 一、基礎…

QWidget 和 QML 的本質和使用上的區別

QWidget 和 QML 是 Qt 框架中兩種不同的 UI 開發技術&#xff0c;它們在底層實現、設計理念和使用場景上有顯著區別。以下是它們的本質和主要差異&#xff1a;1. 本質區別特性QWidgetQML (Qt Modeling Language)技術基礎基于 C 的面向對象控件庫基于聲明式語言&#xff08;類似…

中轉模型服務的風險

最近發現一些 AI 相關帖子下&#xff0c;存在低質 claude code 中轉的小廣告。 其中轉的基本原理就是 claude code 允許自己提供 API endpoint 和 key&#xff0c;可以使用任意一個 OpenAI API 兼容的供應商&#xff0c;就這么簡單。 進一點 claude token&#xff0c;再混入一點…