【零基礎學AI】第21講:TensorFlow基礎 - 神經網絡搭建入門

在這里插入圖片描述

本節課你將學到

  • 理解什么是TensorFlow,為什么要用它
  • 掌握TensorFlow安裝和基本操作
  • 學會搭建第一個神經網絡
  • 完成手寫數字識別項目

開始之前

環境要求

  • Python 3.8+
  • 至少4GB內存
  • 網絡連接(用于下載數據集)

前置知識

  • 第1-8講:Python基礎和開發環境
  • 基本的數學概念(加減乘除即可)

什么是TensorFlow?

用最簡單的話解釋

想象你要蓋房子:

  • 傳統編程:你需要自己制作每一塊磚頭、每一根鋼筋
  • TensorFlow:就像一個預制構件工廠,磚頭、鋼筋、水泥都給你準備好了,你只需要按圖紙組裝

TensorFlow就是Google開發的"AI積木工廠",它提供了:

  • 🧱 基礎積木:各種數學運算函數
  • 🔧 組裝工具:神經網絡層、優化器
  • 📏 測量工具:損失函數、評估指標
  • 🏭 生產線:自動訓練和優化

為什么選擇TensorFlow?

  1. 簡單易用:像搭積木一樣構建神經網絡
  2. 功能強大:支持從簡單分類到復雜的圖像識別
  3. 社區龐大:遇到問題容易找到解決方案
  4. 工業級:Google、Netflix等大公司都在用

TensorFlow安裝

安裝步驟

# 方法1:使用pip安裝(推薦)
pip install tensorflow# 方法2:如果上面很慢,使用國內鏡像
pip install tensorflow -i https://pypi.tuna.tsinghua.edu.cn/simple/# 驗證安裝
python -c "import tensorflow as tf; print('TensorFlow版本:', tf.__version__)"

驗證安裝

import tensorflow as tf# 檢查版本
print("TensorFlow版本:", tf.__version__)# 檢查是否支持GPU(有GPU會顯示GPU信息,沒有也正常)
print("GPU可用:", len(tf.config.list_physical_devices('GPU')) > 0)# 簡單測試
hello = tf.constant("Hello, TensorFlow!")
print(hello.numpy().decode())

預期輸出:

TensorFlow版本: 2.x.x
GPU可用: False  # 沒有GPU也沒關系
Hello, TensorFlow!

TensorFlow核心概念

1. 張量(Tensor)- 數據容器

張量就是多維數組,就像不同形狀的盒子:

import tensorflow as tf
import numpy as np# 0維張量(標量)- 一個數字
scalar = tf.constant(42)
print("標量:", scalar)# 1維張量(向量)- 一行數字
vector = tf.constant([1, 2, 3, 4])
print("向量:", vector)# 2維張量(矩陣)- 表格
matrix = tf.constant([[1, 2], [3, 4]])
print("矩陣:")
print(matrix)# 3維張量 - 立體數據(比如彩色圖片)
tensor_3d = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
print("3維張量形狀:", tensor_3d.shape)

2. 計算圖 - 操作流程

TensorFlow會自動記錄你的操作,就像記錄菜譜步驟:

# 定義變量(可以改變的數)
x = tf.Variable(3.0, name="x")
y = tf.Variable(4.0, name="y")# 定義計算(TensorFlow會記錄這些步驟)
z = x * x + y * y  # z = x2 + y2print("x =", x.numpy())
print("y =", y.numpy()) 
print("z = x2 + y2 =", z.numpy())

3. 自動微分 - 神經網絡的關鍵

神經網絡需要不斷調整參數,TensorFlow可以自動計算如何調整:

# 使用GradientTape記錄操作
x = tf.Variable(2.0)with tf.GradientTape() as tape:y = x * x * x  # y = x3# 自動計算導數(斜率)
dy_dx = tape.gradient(y, x)
print(f"當x={x.numpy()}時,y=x3的導數是:{dy_dx.numpy()}")
print("手工計算:3*22=12,驗證正確!")

第一個神經網絡

問題:預測房價

假設我們要根據房屋面積預測房價,這是一個最簡單的神經網絡:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt# 設置中文字體
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False# 1. 準備數據
# 假設房價 = 面積 * 2 + 一些隨機噪聲
np.random.seed(42)  # 固定隨機種子,確保結果可重現
areas = np.random.uniform(50, 200, 100)  # 100個面積數據,50-200平米
prices = areas * 2 + np.random.normal(0, 10, 100)  # 價格=面積*2+噪聲# 數據標準化(重要!神經網絡喜歡小數字)
areas_norm = (areas - areas.mean()) / areas.std()
prices_norm = (prices - prices.mean()) / prices.std()print("數據準備完成!")
print(f"面積范圍:{areas.min():.1f} - {areas.max():.1f} 平米")
print(f"價格范圍:{prices.min():.1f} - {prices.max():.1f} 萬元")
# 2. 構建神經網絡
# 最簡單的神經網絡:只有一層,一個神經元
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1], name='price_predictor')
])# 編譯模型(設置學習規則)
model.compile(optimizer='adam',      # 優化器:adam是最常用的loss='mse',           # 損失函數:均方誤差metrics=['mae']       # 評估指標:平均絕對誤差
)# 查看模型結構
print("模型結構:")
model.summary()
# 3. 訓練模型
print("開始訓練...")
history = model.fit(areas_norm, prices_norm,  # 訓練數據epochs=100,               # 訓練輪數verbose=0                 # 不顯示訓練過程(避免刷屏)
)print("訓練完成!")# 4. 評估效果
test_area = np.array([100])  # 測試:100平米的房子
test_area_norm = (test_area - areas.mean()) / areas.std()
predicted_price_norm = model.predict(test_area_norm, verbose=0)# 反標準化得到實際價格
predicted_price = predicted_price_norm * prices.std() + prices.mean()print(f"預測結果:100平米的房子價格約為 {predicted_price[0][0]:.1f} 萬元")
print(f"理論價格:100 * 2 = 200萬元")
print(f"預測誤差:{abs(predicted_price[0][0] - 200):.1f} 萬元")

可視化結果

# 繪制訓練過程
plt.figure(figsize=(12, 4))# 損失變化
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'])
plt.title('訓練損失變化')
plt.xlabel('訓練輪數')
plt.ylabel('損失值')
plt.grid(True)# 預測效果
plt.subplot(1, 2, 2)
plt.scatter(areas, prices, alpha=0.6, label='真實數據')# 畫預測線
test_areas = np.linspace(50, 200, 100)
test_areas_norm = (test_areas - areas.mean()) / areas.std()
predicted_prices_norm = model.predict(test_areas_norm, verbose=0)
predicted_prices = predicted_prices_norm * prices.std() + prices.mean()plt.plot(test_areas, predicted_prices, 'r-', linewidth=2, label='神經網絡預測')
plt.xlabel('面積 (平米)')
plt.ylabel('價格 (萬元)')
plt.title('房價預測效果')
plt.legend()
plt.grid(True)plt.tight_layout()
plt.show()print("圖表已顯示!紅線是神經網絡學到的規律")

完整項目:手寫數字識別

現在我們來做一個更有趣的項目:讓計算機識別手寫數字!

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt# 1. 加載MNIST數據集(手寫數字數據)
print("正在下載MNIST數據集...")
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()print("數據加載完成!")
print(f"訓練圖片數量: {len(x_train)}")
print(f"測試圖片數量: {len(x_test)}")
print(f"圖片尺寸: {x_train[0].shape}")# 查看幾個樣本
plt.figure(figsize=(10, 2))
for i in range(5):plt.subplot(1, 5, i+1)plt.imshow(x_train[i], cmap='gray')plt.title(f'標簽: {y_train[i]}')plt.axis('off')
plt.suptitle('手寫數字樣本')
plt.show()
# 2. 數據預處理
# 標準化像素值到0-1范圍
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0# 將28x28的圖片展平成784維向量
x_train_flat = x_train.reshape(60000, 784)
x_test_flat = x_test.reshape(10000, 784)print("數據預處理完成!")
print(f"訓練數據形狀: {x_train_flat.shape}")
print(f"測試數據形狀: {x_test_flat.shape}")
# 3. 構建神經網絡
model = tf.keras.Sequential([# 輸入層:784個神經元(對應784個像素)tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),# 隱藏層:128個神經元,使用ReLU激活函數tf.keras.layers.Dense(64, activation='relu'),# 輸出層:10個神經元(對應0-9十個數字)tf.keras.layers.Dense(10, activation='softmax')
])# 編譯模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',  # 多分類問題的損失函數metrics=['accuracy']
)# 查看模型結構
print("神經網絡結構:")
model.summary()
# 4. 訓練模型
print("開始訓練神經網絡...")
history = model.fit(x_train_flat, y_train,epochs=10,                    # 訓練10輪batch_size=128,              # 每次處理128個樣本validation_split=0.1,        # 10%的數據用于驗證verbose=1                    # 顯示訓練進度
)print("訓練完成!")
# 5. 評估模型
test_loss, test_accuracy = model.evaluate(x_test_flat, y_test, verbose=0)
print(f"測試準確率: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")# 預測幾個測試樣本
predictions = model.predict(x_test_flat[:5], verbose=0)
predicted_labels = np.argmax(predictions, axis=1)print("\n預測結果:")
for i in range(5):print(f"圖片{i+1}: 真實標簽={y_test[i]}, 預測標簽={predicted_labels[i]}, "f"置信度={predictions[i][predicted_labels[i]]:.4f}")
# 6. 可視化結果
plt.figure(figsize=(15, 5))# 訓練歷史
plt.subplot(1, 3, 1)
plt.plot(history.history['accuracy'], label='訓練準確率')
plt.plot(history.history['val_accuracy'], label='驗證準確率')
plt.title('模型準確率')
plt.xlabel('訓練輪數')
plt.ylabel('準確率')
plt.legend()
plt.grid(True)plt.subplot(1, 3, 2)
plt.plot(history.history['loss'], label='訓練損失')
plt.plot(history.history['val_loss'], label='驗證損失')
plt.title('模型損失')
plt.xlabel('訓練輪數')
plt.ylabel('損失值')
plt.legend()
plt.grid(True)# 預測結果展示
plt.subplot(1, 3, 3)
# 顯示一個預測示例
sample_idx = 0
plt.imshow(x_test[sample_idx], cmap='gray')
plt.title(f'真實: {y_test[sample_idx]}, 預測: {predicted_labels[sample_idx]}')
plt.axis('off')plt.tight_layout()
plt.show()print("🎉 恭喜!你已經成功訓練了一個手寫數字識別神經網絡!")

運行效果

預期輸出

數據加載完成!
訓練圖片數量: 60000
測試圖片數量: 10000
圖片尺寸: (28, 28)神經網絡結構:
Model: "sequential_1"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================dense_1 (Dense)             (None, 128)               100480    dense_2 (Dense)             (None, 64)                8256      dense_3 (Dense)             (None, 10)                650       
=================================================================
Total params: 109,386
Trainable params: 109,386
Non-trainable params: 0訓練完成!
測試準確率: 0.9751 (97.51%)預測結果:
圖片1: 真實標簽=7, 預測標簽=7, 置信度=0.9999
圖片2: 真實標簽=2, 預測標簽=2, 置信度=0.9995
...

生成的文件

  • 模型訓練過程可視化圖表
  • 手寫數字樣本展示
  • 預測結果對比

常見問題解答

Q1: 安裝TensorFlow時出錯

錯誤信息: ERROR: Failed building wheel for tensorflow

解決方法:

# 方法1:升級pip
pip install --upgrade pip# 方法2:使用conda安裝
conda install tensorflow# 方法3:安裝CPU版本
pip install tensorflow-cpu

Q2: 訓練很慢怎么辦?

解決方法:

  • 減少訓練輪數(epochs):從10改為5
  • 減少數據量:只用前1000個樣本訓練
  • 使用更小的網絡:減少神經元數量

Q3: 準確率不高怎么辦?

可能原因和解決方法:

  • 訓練輪數太少:增加epochs
  • 網絡太簡單:增加更多層或神經元
  • 學習率不合適:嘗試不同的優化器

Q4: 內存不夠怎么辦?

解決方法:

# 減少batch_size
model.fit(x_train, y_train, batch_size=32)  # 從128改為32# 或者使用更少的數據
x_train_small = x_train[:10000]  # 只用前10000個樣本

課后練習

基礎練習

  • 修改神經網絡結構,嘗試不同數量的神經元
  • 改變訓練輪數,觀察準確率變化
  • 使用自己手寫的數字測試模型

進階練習

  • 嘗試識別時裝圖片(Fashion-MNIST數據集)
  • 添加更多隱藏層,觀察效果變化
  • 使用不同的激活函數(如tanh、sigmoid)

挑戰練習

  • 實現一個簡單的繪圖界面,讓用戶畫數字并識別
  • 比較不同優化器的效果(SGD vs Adam)
  • 分析模型預測錯誤的樣本,找出共同特點

總結

這節課我們學會了:

  1. TensorFlow基礎概念:理解張量、計算圖、自動微分
  2. 神經網絡構建:使用Sequential模型搭建網絡
  3. 模型訓練流程:編譯→訓練→評估→預測
  4. 實際項目經驗:完成了手寫數字識別

下節課預告: 我們將學習PyTorch,對比兩個主流深度學習框架的差異,并用PyTorch實現圖像分類器。

技術支持

如遇問題,請檢查:

  1. Python版本是否3.8+
  2. TensorFlow是否正確安裝
  3. 網絡連接是否正常(下載數據集需要)
  4. 內存是否足夠(建議4GB+)

記住:每個AI專家都是從第一個神經網絡開始的!你已經邁出了重要的一步! 🚀

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/87641.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/87641.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/87641.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

STM32 串口USART通訊驅動

前言 本篇文章對串口Usart進行講解,為后面的esp8266和語音模塊控制打好基礎。 1.串口USART USART(Universal Synchronous/Asynchronous Receiver/Transmitter,通用同步 / 異步收發器) 是一種常見的串行通信接口,廣泛應…

pytorch版本densenet代碼講解

DenseNet 模型代碼詳解 下面是 DenseNet 模型代碼的逐部分詳細解析: 1. 導入模塊 import re from collections import OrderedDict from functools import partial from typing import Any, Optionalimport torch import torch.nn as nn import torch.nn.functional…

前端常見設計模式深度解析

# 前端常見設計模式深度解析一、設計模式概述 設計模式是解決特定問題的經驗總結,前端開發中常用的設計模式可分為三大類: 創建型模式:處理對象創建機制(單例、工廠等)結構型模式:處理對象組合(…

React 學習(3)

核心API——React.creatElement()方法優點:將創建元素、添加屬性和事件、添加內容和子元素等使用原生dom需要進行復雜操作才能實現的功能集成在一個API中。1.該方法接收三個參數第一個是要創建的元素的名稱(小寫是因為如果,大寫開頭會被react…

傾斜攝影無人機飛行航線規劃流程詳解

在傾斜攝影測量項目中,航線規劃的嚴謹性直接決定了最終三維模型的質量與完整性。照片覆蓋不全、模型空洞、紋理模糊或分辨率不達標等問題,往往源于規劃階段對關鍵細節的疏忽。本文將系統梳理傾斜攝影無人機航線規劃的核心流程與關鍵要點,旨在…

Minio大文件分片上傳

一、引入依賴 <dependency><groupId>io.minio</groupId><artifactId>minio</artifactId><version>8.3.3</version></dependency> 二、自定義Minio客戶端 package com.gstanzer.video.controller;import com.google.common.c…

Jenkins 插件深度應用:讓你的CI/CD流水線如虎添翼 [特殊字符]

Jenkins 插件深度應用&#xff1a;讓你的CI/CD流水線如虎添翼 &#x1f680; 嘿&#xff0c;各位開發小伙伴&#xff01;今天咱們來聊聊Jenkins的插件生態系統。如果說Jenkins是一臺強大的引擎&#xff0c;那插件就是讓這臺引擎發揮最大威力的各種零部件。準備好了嗎&#xff1…

密碼學(斯坦福)

密碼學筆記 \huge{密碼學筆記} 密碼學筆記 斯坦福大學密碼學的課程筆記 課程網址&#xff1a;https://www.bilibili.com/video/BV1Rf421o79E/?spm_id_from333.337.search-card.all.click&vd_source5cc05a038b81f6faca188e7cf00484f6 概述 密碼學的使用背景 安全信息保護…

代碼隨想錄算法訓練營第四十六天|動態規劃part13

647. 回文子串 題目鏈接&#xff1a;647. 回文子串 - 力扣&#xff08;LeetCode&#xff09; 文章講解&#xff1a;代碼隨想錄 思路&#xff1a; 以dp【i】表示以s【i】結尾的回文子串的個數&#xff0c;發現遞推公式推導不出來此路不通 以dp【i】【j】表示s【i】到s【j】的回…

基于四種機器學習算法的球隊數據分析預測系統的設計與實現

文章目錄 有需要本項目的代碼或文檔以及全部資源&#xff0c;或者部署調試可以私信博主項目介紹項目展示隨機森林模型XGBoost模型邏輯回歸模型catboost模型每文一語 有需要本項目的代碼或文檔以及全部資源&#xff0c;或者部署調試可以私信博主 項目介紹 本項目旨在設計與實現…

http、SSL、TLS、https、證書

一、基礎概念 1.HTTP HTTP (超文本傳輸協議) 是一種用于客戶端和服務器之間傳輸超媒體文檔的應用層協議&#xff0c;是萬維網的基礎。 簡而言之&#xff1a;一種獲取和發送信息的標準協議 2.SSL 安全套接字層&#xff08;SSL&#xff09;是一種通信協議或一組規則&#xf…

在 C++ 中,判斷 `std::string` 是否為空字符串

在 C 中&#xff0c;判斷 std::string 是否為空字符串有多種方法&#xff0c;以下是最常用的幾種方式及其區別&#xff1a; 1. 使用 empty() 方法&#xff08;推薦&#xff09; #include <string>std::string s; if (s.empty()) {// s 是空字符串 }特性&#xff1a; 時間…

【Harmony】鴻蒙企業應用詳解

【HarmonyOS】鴻蒙企業應用詳解 一、前言 1、應用類型定義速覽&#xff1a; HarmonyOS目前針對應用分為三種類型&#xff1a;普通應用&#xff0c;游戲應用&#xff0c;企業應用。 而企業應用又分為&#xff0c;企業普通應用和設備管理應用MDM&#xff08;Mobile Device Man…

Linux云計算基礎篇(8)

VIM 高級特性插入模式按 i 進入插入模式。按 o 在當前行下方插入空行并進入插入模式。按 O 在當前行上方插入空行并進入插入模式。命令模式:set nu 顯示行號。:set nonu 取消顯示行號。:100 光標跳轉到第 100 行。G 光標跳轉到文件最后一行。gg 光標跳轉到文件第一行。30G 跳轉…

Linux進程單例模式運行

Linux進程單例模式運行 #include <iostream> #include <stdlib.h> #include <unistd.h> #include <string.h> #include <stdio.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h>int write_pid(const cha…

【Web 后端】部署服務到服務器

文章目錄 前言一、如何啟動服務二、掛載和開機啟動服務1. 配置systemctl 服務2. 創建server用戶3. 啟動服務 總結 前言 如果你的后端服務寫好了如果部署到你的服務器呢&#xff0c;本次通過fastapi寫的服務實例&#xff0c;示范如何部署到服務器&#xff0c;并做服務管理。 一…

國產MCU學習Day5——CW32F030C8T6:窗口看門狗功能全解析

每日更新教程&#xff0c;評論區答疑解惑&#xff0c;小白也能變大神&#xff01;" 目錄 一.窗口看門狗&#xff08;WWDG&#xff09;簡介 二.窗口看門狗寄存器列表 三.窗口看門狗復位案例 一.窗口看門狗&#xff08;WWDG&#xff09;簡介 CW32F030C8T6 內部集成窗口看…

2025年文件加密軟件分享:守護數字世界的核心防線

在數字化時代&#xff0c;數據已成為個人與企業的寶貴資產&#xff0c;文件加密軟件通過復雜的算法&#xff0c;確保信息在存儲、傳輸與共享過程中的保密性、完整性與可用性。一、文件加密軟件的核心原理文件加密軟件算法以其高效性與安全性廣泛應用&#xff0c;通過對文件數據…

node.js下載教程

1.項目環境文檔 語雀 2.nvm安裝教程與nvm常見命令,超詳細!-阿里云開發者社區 C:\Windows\System32>nvm -v 1.2.2 C:\Windows\System32>nvm list available Error retrieving "http://npm.taobao.org/mirrors/node/index.json": HTTP Status 404 C:\Window…

(AI如何解決問題)在一個項目,跳轉到外部html頁面,頁面布局

問題描述目前&#xff0c;ERP后臺有很多跳轉外部鏈接的地方&#xff0c;會直接打開一個tab顯示。因為有些頁面是適配手機屏幕顯示&#xff0c;放在瀏覽器會超級大。不美觀&#xff0c;因此提出優化。修改前&#xff1a;修改后&#xff1a;思考過程1、先看下代碼&#xff1a;log…