day53

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.datasets import load_iris
import warnings

# 忽略不必要的警告信息
warnings.filterwarnings("ignore")

# --------------------------
# 1. 配置訓練參數與設備
# --------------------------

# 潛在空間維度(生成器的輸入維度)
latent_dim = 10 ?
# 訓練總輪數(GAN通常需要較多迭代才能收斂)
train_epochs = 10000 ?
# 批次大小(根據數據集規模調整)
batch_size = 32 ?
# 學習率(控制參數更新幅度)
learning_rate = 0.0002 ?
# Adam優化器的動量參數(影響收斂穩定性)
beta1 = 0.5 ?

# 自動選擇運算設備(優先GPU,沒有則用CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"當前使用設備: {device}")

# --------------------------
# 2. 數據加載與預處理
# --------------------------

# 加載鳶尾花數據集
iris_dataset = load_iris()
# 提取特征數據和標簽
features = iris_dataset.data
labels = iris_dataset.target

# 只選取Setosa類別(標簽為0)的數據進行訓練
setosa_features = features[labels == 0]

# 將數據縮放到[-1, 1]區間(配合生成器的Tanh輸出激活)
scaler = MinMaxScaler(feature_range=(-1, 1))
scaled_features = scaler.fit_transform(setosa_features)

# 轉換為PyTorch張量并創建數據加載器
# 注意:必須轉為float類型才能與模型參數兼容
data_tensor = torch.from_numpy(scaled_features).float()
dataset = TensorDataset(data_tensor)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 打印數據基本信息
print(f"訓練樣本數量: {len(scaled_features)}")
print(f"特征維度: {scaled_features.shape[1]}") ?# 鳶尾花數據集固定為4維特征

# --------------------------
# 3. 定義生成器和判別器
# --------------------------

class Generator(nn.Module):
? ? """生成器:將隨機噪聲轉換為模擬的鳶尾花特征數據"""
? ? def __init__(self):
? ? ? ? super(Generator, self).__init__()
? ? ? ? # 簡單的全連接網絡結構
? ? ? ? self.net = nn.Sequential(
? ? ? ? ? ? nn.Linear(latent_dim, 16), ?# 從潛在空間映射到16維
? ? ? ? ? ? nn.ReLU(), ?# 激活函數增加非線性
? ? ? ? ? ? nn.Linear(16, 32), ?# 進一步映射到32維
? ? ? ? ? ? nn.ReLU(),
? ? ? ? ? ? nn.Linear(32, 4), ?# 輸出4維特征(與真實數據一致)
? ? ? ? ? ? nn.Tanh() ?# 確保輸出在[-1, 1]范圍內
? ? ? ? )
? ??
? ? def forward(self, x):
? ? ? ? # 前向傳播:輸入噪聲,輸出生成的數據
? ? ? ? return self.net(x)

class Discriminator(nn.Module):
? ? """判別器:區分輸入數據是真實樣本還是生成器偽造的"""
? ? def __init__(self):
? ? ? ? super(Discriminator, self).__init__()
? ? ? ? # 簡單的全連接網絡結構
? ? ? ? self.net = nn.Sequential(
? ? ? ? ? ? nn.Linear(4, 32), ?# 輸入4維特征
? ? ? ? ? ? nn.LeakyReLU(0.2), ?# LeakyReLU避免梯度消失問題
? ? ? ? ? ? nn.Linear(32, 16), ?# 壓縮到16維
? ? ? ? ? ? nn.LeakyReLU(0.2),
? ? ? ? ? ? nn.Linear(16, 1), ?# 輸出單個概率值
? ? ? ? ? ? nn.Sigmoid() ?# 將輸出壓縮到[0,1](表示真實數據的概率)
? ? ? ? )
? ??
? ? def forward(self, x):
? ? ? ? # 前向傳播:輸入數據,輸出判斷概率
? ? ? ? return self.net(x)

# 初始化模型并移動到運算設備
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# 打印模型結構
print("\n生成器結構:")
print(generator)
print("\n判別器結構:")
print(discriminator)

# --------------------------
# 4. 配置訓練組件
# --------------------------

# 定義損失函數(二元交叉熵,適合二分類問題)
criterion = nn.BCELoss()

# 定義優化器(分別優化生成器和判別器)
gen_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
dis_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999))

# --------------------------
# 5. 開始訓練
# --------------------------

print("\n--- 訓練開始 ---")
for epoch in range(train_epochs):
? ? # 遍歷數據加載器中的每一批次
? ? for batch_idx, (real_data,) in enumerate(data_loader):
? ? ? ? # 將真實數據移動到運算設備
? ? ? ? real_data = real_data.to(device)
? ? ? ? current_batch_size = real_data.size(0) ?# 獲取當前批次的實際樣本數(最后一批可能不滿)
? ? ? ??
? ? ? ? # 創建標簽:真實數據標為1,生成數據標為0
? ? ? ? real_labels = torch.ones(current_batch_size, 1).to(device)
? ? ? ? fake_labels = torch.zeros(current_batch_size, 1).to(device)
? ? ? ??
? ? ? ? # --------------------
? ? ? ? # 訓練判別器
? ? ? ? # --------------------
? ? ? ? dis_optimizer.zero_grad() ?# 清空判別器的梯度緩存
? ? ? ??
? ? ? ? # 1. 用真實數據訓練
? ? ? ? real_output = discriminator(real_data)
? ? ? ? # 計算真實數據的損失(希望判別器能認出真實數據)
? ? ? ? loss_real = criterion(real_output, real_labels)
? ? ? ??
? ? ? ? # 2. 用生成的數據訓練
? ? ? ? # 生成隨機噪聲(作為生成器的輸入)
? ? ? ? noise = torch.randn(current_batch_size, latent_dim).to(device)
? ? ? ? # 生成假數據,并阻斷梯度流向生成器(避免影響生成器參數)
? ? ? ? fake_data = generator(noise).detach()
? ? ? ? fake_output = discriminator(fake_data)
? ? ? ? # 計算假數據的損失(希望判別器能認出假數據)
? ? ? ? loss_fake = criterion(fake_output, fake_labels)
? ? ? ??
? ? ? ? # 總損失反向傳播并更新判別器參數
? ? ? ? dis_loss = loss_real + loss_fake
? ? ? ? dis_loss.backward()
? ? ? ? dis_optimizer.step()
? ? ? ??
? ? ? ? # --------------------
? ? ? ? # 訓練生成器
? ? ? ? # --------------------
? ? ? ? gen_optimizer.zero_grad() ?# 清空生成器的梯度緩存
? ? ? ??
? ? ? ? # 重新生成假數據(這次需要計算生成器的梯度)
? ? ? ? noise = torch.randn(current_batch_size, latent_dim).to(device)
? ? ? ? fake_data = generator(noise)
? ? ? ? fake_output = discriminator(fake_data)
? ? ? ??
? ? ? ? # 生成器的損失:希望判別器把假數據當成真的(所以標簽用real_labels)
? ? ? ? gen_loss = criterion(fake_output, real_labels)
? ? ? ? gen_loss.backward()
? ? ? ? gen_optimizer.step()
? ??
? ? # 每1000輪打印一次訓練狀態
? ? if (epoch + 1) % 1000 == 0:
? ? ? ? print(
? ? ? ? ? ? f"輪次 [{epoch+1}/{train_epochs}], "
? ? ? ? ? ? f"判別器損失: {dis_loss.item():.4f}, "
? ? ? ? ? ? f"生成器損失: {gen_loss.item():.4f}"
? ? ? ? )

print("\n--- 訓練完成 ---")

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/87696.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/87696.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/87696.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

c++ python 共享內存

一、目的 是為了c來讀取并解碼傳遞給python,Python做測試非常方便,c 和 python之間必須定好協議,整體使用c 來解碼,共享內存傳遞給python 二、主類 主類,串聯decoder,注意decoder并沒有直接在顯存里面穿…

react函數組件的props,ref,state。

react開發我們會把頁面分為一個個組件,組件是獨立而且可復用的重復代碼片段。具體來說組件可以是一個按鈕,一個輸入框。react組件有兩種定義方法,一種是函數組件,一種是類組件。我們這里說一下函數組件之間父子之間如何傳遞props參…

基于ARM+FPGA實現的BISS-C協議解決方案,適用于高精度光柵位移傳感器等

模塊簡介 本資源提供了專為FPGA設計的BISS-C接口協議發送模塊源碼。BISS-C模式作為一種高速、同步的串行通信協議,廣泛應用于高精度光柵位移傳感器的數據傳輸中,特別適用于需要精確位置信息的應用場景。此模式遵循主從架構,其中FPGA作為主控制…

spring中@Transactional注解和事務的實戰理解附代碼

文章目錄 前言一、事務是什么?二、事務的特性2.1隔離性2.2事務的隔離級別 三、Transactional注解Transactional注解簡介基本用法常用屬性配置事務傳播行為事務隔離級別異常處理與回滾性能優化建議 四、 事務不生效的可能原因方法訪問權限非public自調用問題異常被捕…

替代進口SCA7606【智芯微】國產高精度電流傳感器 工業新能源電網專用

SCA7606(智芯微)產品解析與推廣文案一、產品概述SCA7606 是 智芯微電子(ZXMICRO) 推出的一款 高精度數字隔離式電流傳感器芯片,采用 霍爾效應數字輸出 技術,專為 工業控制、新能源、智能電網 等領域的電流檢…

Java 與 Vue 全棧開發:“一課一得“ 學習筆記系統實戰

一、項目背景與核心價值 "一課一得" 是一個面向學習者的筆記管理平臺,旨在幫助用戶系統化記錄、整理和回顧學習內容。項目采用前后端分離架構:前端基于 Vue.js 構建交互式界面,后端使用 Java Spring Boot 實現業務邏輯&#xff0c…

百度文心大模型 4.5 開源深度測評:技術架構、部署實戰與生態協同全解析

聲明:本文只做實際測評,并非廣告 1.前言 2025 年 6 月 30 日,百度做出一項重大舉措,將文心大模型 4.5 系列正式開源,并選擇國內領先的開源平臺 GitCode 作為首發平臺。該模型也是百度在2025年3月16日發布的自研的新一…

力扣_鏈表_python版本

一、206. 反轉鏈表代碼: class Solution:def reverseList(self, head):dummy ListNode()cur headwhile cur:last cur.nextcur.next dummy.nextdummy.next curcur lastreturn dummy.next二、92. 反轉鏈表 IIclass Solution:def reverseBetween(self, head: Opt…

[netty5: WebSocketProtocolHandler]-源碼分析

在閱讀這篇文章前,推薦先閱讀:[netty5: MessageToMessageCodec & MessageToMessageEncoder & MessageToMessageDecoder]-源碼分析 WebSocketProtocolHandler WebSocketProtocolHandler 是 WebSocket 處理的基礎抽象類,負責管理 Web…

[2025CVPR]一種新穎的視覺與記憶雙適配器(Visual and Memory Dual Adapter, VMDA)

引言 多模態目標跟蹤(Multi-modal Object Tracking)旨在通過結合RGB模態與其他輔助模態(如熱紅外、深度、事件數據)來增強可見光傳感器的感知能力,尤其在復雜場景下顯著提升跟蹤魯棒性。然而,現有方法在頻…

理想汽車6月交付36279輛 第二季度共交付111074輛

理想汽車-W(02015)發布公告,2025年6月,理想汽車交付新車36279輛,第二季度共交付111074輛。截至2025年6月30日,理想汽車歷史累計交付量為133.78萬輛。 在成立十周年之際,理想汽車已連續兩年成為人民幣20萬元以上中高端市…

MobileNets: 高效的卷積神經網絡用于移動視覺應用

摘要 我們提出了一類高效的模型,稱為MobileNets,專門用于移動和嵌入式視覺應用。MobileNets基于一種簡化的架構,利用深度可分離卷積構建輕量級的深度神經網絡。我們引入了兩個簡單的全局超參數,能夠有效地在延遲和準確性之間進行…

SDP服務發現協議:動態查詢設備能力的底層邏輯(面試深度解析)

SDP的底層邏輯揭示了物聯網設備交互的本質——先建立認知,再開展協作。 一、SDP 核心知識點高頻考點解析 1.1 SDP 的定位與作用 考點:SDP 在藍牙協議棧中的位置及核心功能 解析:SDP(Service Discovery Protocol,服務發現協議)位于藍牙協議棧的中間層,依賴 L2CAP 協議傳…

CppCon 2018 學習:GIT, CMAKE, CONAN

提到的: “THE MOST COMMON C TOOLSET” VERSION CONTROL SYSTEM BUILDING PACKAGE MANAGEMENT 這些是 C 項目開發中最核心的工具鏈組成部分。下面我將逐一解釋每部分的作用、常見工具,以及它們如何協同構建現代 C 項目。 1. VERSION CONTROL SYSTEM&am…

使用tensorflow的線性回歸的例子(五)

我們使用Iris數據,Sepal length為y值而Petal width為x值。import matplotlib.pyplot as pltimport numpy as npimport tensorflow as tffrom sklearn import datasetsfrom tensorflow.python.framework import opsops.reset_default_graph()# Load the data# iris.d…

虛幻基礎:動作——蒙太奇

能幫到你的話,就給個贊吧 😘 文章目錄 動作——蒙太奇如果動作被打斷,則后續的動畫通知不會執行 動作——蒙太奇 如果動作被打斷,則后續的動畫通知不會執行

[工具系列] 開源的 API 調試工具 Postwoman

介紹 隨著 Web 應用的復雜性增加,API 測試已成為開發中不可或缺的一部分,無論是前端還是后端開發,確保 API 正常運行至關重要。 Postman 長期以來是開發者進行 API 測試的首選工具,但是很多基本功能都需要登陸才能使用&#xff…

【力扣 簡單 C】746. 使用最小花費爬樓梯

目錄 題目 解法一 題目 解法一 int min(int a, int b) {return a < b ? a : b; }int minCostClimbingStairs(int* cost, int costSize) {const int n costSize; // 樓頂&#xff0c;第n階// 爬到第n階的最小花費 // 爬到第n-1階的最小花費從第n-1階爬上第n階的花費…

python+django開發帶auth接口

pythondjango開發帶auth接口 # coding utf-8 import base64 from django.contrib import auth as django_authfrom django.core.exceptions import ObjectDoesNotExist from django.http import JsonResponsefrom sign.models import Eventdef user_auth(request):"&quo…

RBAC權限模型如何讓API訪問控制既安全又靈活?

url: /posts/9f01e838545ae8d34016c759ef461423/ title: RBAC權限模型如何讓API訪問控制既安全又靈活? date: 2025-07-01T04:52:07+08:00 lastmod: 2025-07-01T04:52:07+08:00 author: cmdragon summary: RBAC權限模型通過用戶、角色和權限的關聯實現訪問控制,核心組件包括用…