TensorFlow 實現 Mixture Density Network (MDN) 的完整說明

本文檔詳細解釋了一段使用 TensorFlow 構建和訓練混合密度網絡(Mixture Density Network, MDN)的代碼,涵蓋數據生成、模型構建、自定義損失函數與預測可視化等各個環節。


1. 導入庫與設置超參數

import numpy as np 
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import math

說明

  • 引入用于數值運算(NumPy)、構建深度學習模型(TensorFlow/Keras)和繪圖(Matplotlib)的基礎工具包。

超參數定義

N_HIDDEN = 15         # 隱藏層神經元數量
N_MIXES = 10          # GMM 中混合成分數量
OUTPUT_DIMS = 1       # 輸出維度(目標變量維度)

2. 自定義 MDN 層

class MDN(layers.Layer):def __init__(self, output_dims, num_mixtures, **kwargs):super(MDN, self).__init__(**kwargs)self.output_dims = output_dimsself.num_mixtures = num_mixturesself.params = self.num_mixtures * (2 * self.output_dims + 1)  # pi, mu, sigmaself.dense = layers.Dense(self.params)def call(self, inputs):output = self.dense(inputs)return output

說明

  • params 表示 GMM 每個分量包含 mu(均值)、sigma(標準差)和 pi(權重),共 2*D + 1 個參數。
  • 輸出維度為 (batch_size, num_mixtures * (2*output_dims + 1))

3. 自定義 MDN 損失函數

def get_mixture_loss_func(output_dims, num_mixtures):def mdn_loss(y_true, y_pred):y_true = tf.reshape(y_true, [-1, 1])out_mu = y_pred[:, :num_mixtures * output_dims]out_sigma = y_pred[:, num_mixtures * output_dims:2 * num_mixtures * output_dims]out_pi = y_pred[:, -num_mixtures:]mu = tf.reshape(out_mu, [-1, num_mixtures, output_dims])sigma = tf.exp(tf.reshape(out_sigma, [-1, num_mixtures, output_dims]))pi = tf.nn.softmax(out_pi)y_true = tf.tile(y_true[:, tf.newaxis, :], [1, num_mixtures, 1])normal_dist = tf.exp(-0.5 * tf.square((y_true - mu) / sigma)) / (sigma * tf.sqrt(2.0 * np.pi))prob = tf.reduce_prod(normal_dist, axis=2)weighted_prob = prob * piloss = -tf.math.log(tf.reduce_sum(weighted_prob, axis=1) + 1e-8)return tf.reduce_mean(loss)return mdn_loss

說明

  • 通過概率密度函數計算目標值屬于 GMM 各個分布的概率,并取加權平均。
  • 對數似然函數取負作為損失。

4. 從輸出分布中采樣

def sample_from_output(y_pred, output_dims, num_mixtures, temp=1.0):out_mu = y_pred[:num_mixtures * output_dims]out_sigma = y_pred[num_mixtures * output_dims:2 * num_mixtures * output_dims]out_pi = y_pred[-num_mixtures:]out_sigma = np.exp(out_sigma)out_pi = np.exp(out_pi / temp)out_pi /= np.sum(out_pi)mixture_idx = np.random.choice(np.arange(num_mixtures), p=out_pi)mu = out_mu[mixture_idx * output_dims:(mixture_idx + 1) * output_dims]sigma = out_sigma[mixture_idx * output_dims:(mixture_idx + 1) * output_dims]sample = np.random.normal(mu, sigma)return sample

說明

  • 使用 softmax 處理 pi,選擇一個分布后按對應的 musigma 采樣。
  • temp 控制采樣溫度(溫度越高分布越平坦)。

5. 生成訓練數據

NSAMPLE = 3000
y_data = np.float32(np.random.uniform(-10.5, 10.5, NSAMPLE))
r_data = np.random.normal(size=NSAMPLE)
x_data = np.sin(0.75 * y_data) * 7.0 + y_data * 0.5 + r_data * 1.0
x_data = x_data.reshape((NSAMPLE, 1))
y_data = y_data.reshape((NSAMPLE, 1))

說明

  • 構造非線性映射關系的合成數據:x = sin(0.75y)*7 + 0.5y + 噪聲
  • x 是輸入,y 是目標。

6. 構建模型

model = keras.Sequential([layers.Dense(N_HIDDEN, input_shape=(1,), activation='relu'),layers.Dense(N_HIDDEN, activation='relu'),MDN(OUTPUT_DIMS, N_MIXES)
])
model.compile(loss=get_mixture_loss_func(OUTPUT_DIMS, N_MIXES), optimizer=keras.optimizers.Adam())
model.summary()

說明

  • 構建一個兩層隱層的前饋神經網絡,輸出 MDN 層。
  • 使用自定義的 MDN 損失函數訓練模型。

7. 模型訓練

model.fit(x_data, y_data, batch_size=128, epochs=200, validation_split=0.15, verbose=1)
  • 批量大小 128,訓練 200 個 epoch,保留 15% 數據用于驗證。

8. 模型測試與預測可視化

x_test = np.linspace(-15, 15, 1000).astype(np.float32).reshape(-1, 1)
y_pred = model.predict(x_test)
y_samples = np.array([sample_from_output(p, OUTPUT_DIMS, N_MIXES) for p in y_pred])
  • 對連續輸入進行預測并從預測的 GMM 中采樣。

可視化預測結果

plt.figure()
plt.scatter(x_test, y_samples, alpha=0.3, s=10)
plt.title("MDN Predictions")
plt.xlabel("x")
plt.ylabel("y")
plt.show()

原始數據與預測對比

plt.figure(figsize=(8, 5))
plt.scatter(x_data, y_data, label="Original Data", alpha=0.2, s=10)
plt.scatter(x_test, y_samples, label="MDN Samples", alpha=0.5, s=10, color='r')
plt.title("MDN Prediction vs Training Data")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.grid(True)
plt.show()

總代碼如下

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import math# 超參數
N_HIDDEN = 15
N_MIXES = 10
OUTPUT_DIMS = 1# === 1. 自定義 MDN 層 ===
class MDN(layers.Layer):def __init__(self, output_dims, num_mixtures, **kwargs):super(MDN, self).__init__(**kwargs)self.output_dims = output_dimsself.num_mixtures = num_mixturesself.params = self.num_mixtures * (2 * self.output_dims + 1)  # pi, mu, sigmaself.dense = layers.Dense(self.params)def call(self, inputs):output = self.dense(inputs)return output# === 2. 自定義損失函數 ===
def get_mixture_loss_func(output_dims, num_mixtures):def mdn_loss(y_true, y_pred):y_true = tf.reshape(y_true, [-1, 1])out_mu = y_pred[:, :num_mixtures * output_dims]out_sigma = y_pred[:, num_mixtures * output_dims:2 * num_mixtures * output_dims]out_pi = y_pred[:, -num_mixtures:]mu = tf.reshape(out_mu, [-1, num_mixtures, output_dims])sigma = tf.exp(tf.reshape(out_sigma, [-1, num_mixtures, output_dims]))pi = tf.nn.softmax(out_pi)y_true = tf.tile(y_true[:, tf.newaxis, :], [1, num_mixtures, 1])normal_dist = tf.exp(-0.5 * tf.square((y_true - mu) / sigma)) / (sigma * tf.sqrt(2.0 * np.pi))prob = tf.reduce_prod(normal_dist, axis=2)weighted_prob = prob * piloss = -tf.math.log(tf.reduce_sum(weighted_prob, axis=1) + 1e-8)return tf.reduce_mean(loss)return mdn_loss# === 3. 從輸出采樣函數 ===
def sample_from_output(y_pred, output_dims, num_mixtures, temp=1.0):out_mu = y_pred[:num_mixtures * output_dims]out_sigma = y_pred[num_mixtures * output_dims:2 * num_mixtures * output_dims]out_pi = y_pred[-num_mixtures:]out_sigma = np.exp(out_sigma)out_pi = np.exp(out_pi / temp)out_pi /= np.sum(out_pi)mixture_idx = np.random.choice(np.arange(num_mixtures), p=out_pi)mu = out_mu[mixture_idx * output_dims:(mixture_idx + 1) * output_dims]sigma = out_sigma[mixture_idx * output_dims:(mixture_idx + 1) * output_dims]sample = np.random.normal(mu, sigma)return sample# === 4. 生成訓練數據 ===
NSAMPLE = 3000
y_data = np.float32(np.random.uniform(-10.5, 10.5, NSAMPLE))
r_data = np.random.normal(size=NSAMPLE)
x_data = np.sin(0.75 * y_data) * 7.0 + y_data * 0.5 + r_data * 1.0
x_data = x_data.reshape((NSAMPLE, 1))
y_data = y_data.reshape((NSAMPLE, 1))plt.figure()
plt.scatter(x_data, y_data, alpha=0.3)
plt.title("Training Data")
plt.show()# === 5. 構建模型 ===
model = keras.Sequential([layers.Dense(N_HIDDEN, input_shape=(1,), activation='relu'),layers.Dense(N_HIDDEN, activation='relu'),MDN(OUTPUT_DIMS, N_MIXES)
])
model.compile(loss=get_mixture_loss_func(OUTPUT_DIMS, N_MIXES), optimizer=keras.optimizers.Adam())
model.summary()# === 6. 模型訓練 ===
model.fit(x_data, y_data, batch_size=128, epochs=200, validation_split=0.15, verbose=1)# === 7. 測試與可視化 ===
x_test = np.linspace(-15, 15, 1000).astype(np.float32).reshape(-1, 1)
y_pred = model.predict(x_test)
y_samples = np.array([sample_from_output(p, OUTPUT_DIMS, N_MIXES) for p in y_pred])plt.figure()
plt.scatter(x_test, y_samples, alpha=0.3, s=10)
plt.title("MDN Predictions")
plt.xlabel("x")
plt.ylabel("y")
plt.show()
# === 8. 測試數據與預測對比圖 ===plt.figure(figsize=(8, 5))
plt.scatter(x_data, y_data, label="Original Data", alpha=0.2, s=10)
plt.scatter(x_test, y_samples, label="MDN Samples", alpha=0.5, s=10, color='r')
plt.title("MDN Prediction vs Training Data")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.grid(True)
plt.show()

總結

本項目展示了如何使用 TensorFlow 構建混合密度網絡,用以建模復雜的條件分布。相比傳統回歸模型,MDN 能夠生成多峰預測結果,適用于不確定性高、輸出存在多解的場景。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/79663.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/79663.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/79663.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

數據結構實驗7.2:二叉樹的基本運算

文章目錄 一,實驗目的二,問題描述三,基本要求四,實驗操作五,示例代碼六,運行效果 一,實驗目的 深入理解樹與二叉樹的基本概念,包括節點、度、層次、深度等,清晰區分二叉…

直線軸承常規分類知多少?

直線軸承的分類方式多樣,以下是從材質、結構形狀和常規系列三個維度進行的具體分類: 按主要材質分類 外殼材質:常見的有不銹鋼,具有良好的耐腐蝕性,適用于一些對環境要求較高、易受腐蝕的工作場景;軸承…

websocket和SSE學習記錄

websocket學習記錄 websocket使用場景 即時聊天在線文檔協同編輯實施地圖位置 從開發角度來學習websocket開發 即使通信項目 通過node建立簡單的后端接口,利用fs, path, express app.get(*, (req, res) > {const assetsType req.url.split(/)[…

CUDA編程中影響性能的小細節總結

一、內存訪問優化 合并內存訪問:確保相鄰線程訪問連續內存地址(全局內存對齊訪問)。優先使用共享內存(Shared Memory)減少全局內存訪問。避免共享內存的Bank Conflict(例如,使用padding或調整訪…

【雙指針】對撞指針 快慢指針 移動零

文章目錄 雙指針介紹對撞指針快慢指針283. 移動零解題思路算法思路算法流程雙指針介紹 ? 算法中的雙指針,并不一定是指我們平常在 c/c++ 使用的指針類型,更多時候其實是數組的下標等,因為它們也是有標識某個元素的功能,通常我們也就順其自然地稱其為 “指針” ! ? 常見…

數據結構0基礎學習堆

文章目錄 簡介公式建立堆函數解釋 堆排序O(n logn)topk問題 簡介 堆是一種重要的數據結構,是一種完全二叉樹,(二叉樹的內容后面會出), 堆分為大小堆,大堆,左右結點都小于根節點,&am…

4.17--4.19刷題記錄(貪心)

第一部分:準備工作 代碼隨想錄中解釋為:貪心的本質是選擇每一階段的局部最優,從而達到全局最優。 而我的理解為:貪心實質上是具有最優子結構的一種算法。所有的解都能由當前最優的解組成。 第二部分:開始刷題 &…

學習筆記十七——Rust 支持面向對象編程嗎?

🧠 Rust 支持面向對象編程嗎? Rust 是一門多范式語言,主要以 安全、并發、函數式、系統級編程為核心目標,但它同時也支持面向對象的一些關鍵特性,比如: 特性傳統 OOP(如 Java/C)Ru…

【Linux】43.網絡基礎(2.5)

文章目錄 2.4 TCP/UDP對比2.4.1 用UDP實現可靠傳輸(經典面試題) 2.5 TCP 相關實驗2.5.1 理解 listen 的第二個參數 2.4 TCP/UDP對比 我們說了TCP是可靠連接, 那么是不是TCP一定就優于UDP呢? TCP和UDP之間的優點和缺點, 不能簡單, 絕對的進行比較TCP用于可靠傳輸的情況, 應用于…

three.js與webgl在buffer上的對應關系

一、three.js的類名 最近開始接觸three.js 看到three.js中的一些類名和webgl的很相似 不自覺的就想對比一下 二、three.js中繪制4個點 // 創建點的幾何體 const vertices new Float32Array([0.0, 0.0, 0.0, // 點11.0, 0.0, 0.0, // 點20.0, 1.0, 0.0, // 點30.…

DataWhale AI春訓營 問題匯總

1.沒用下載訓練集導致出錯,爆錯如下。 這個時候需要去比賽官網下載對應的初賽訓練集 unzip -d /mnt/workspace/sais_third_new_energy_baseline/data /mnt/workspace/sais_third_new_energy_baseline/初賽訓練集.zip 在命令行執行這個命令解壓 2.沒定義測試集 te…

CANFD技術在新能源汽車通信網絡中的應用與可靠性分析

一、引言 新能源汽車產業正處于快速發展階段,其電子系統復雜度不斷攀升,涵蓋眾多傳感器、控制器與執行器。高效通信網絡成為確保新能源汽車安全運行與智能功能實現的核心要素。傳統CAN總線因帶寬限制,難以滿足高級駕駛輔助系統(A…

Python字典深度解析:高效鍵值對數據管理指南

一、字典核心概念解析 1. 字典定義與特征 字典(Dictionary)是Python中??基于哈希表實現??的無序可變容器,通過鍵值對存儲數據,具有以下核心特性: ??鍵值對結構??:{key: value}形式存儲數據??快…

C++中unique_lock和lock_guard區別

目錄 1.自動鎖定與解鎖機制 2.靈活性 3.所有權轉移 4.可與條件變量配合使用 5.性能開銷 在 C 中&#xff0c;std::unique_lock 和 std::lock_guard 都屬于標準庫 <mutex> 中的互斥鎖管理工具&#xff0c;用于簡化互斥鎖的使用并確保線程安全。但它們存在一些顯著區別…

Nvidia顯卡架構演進

1 簡介 顯示卡&#xff08;英語&#xff1a;Display Card&#xff09;簡稱顯卡&#xff0c;也稱圖形卡&#xff08;Graphics Card&#xff09;&#xff0c;是個人電腦上以圖形處理器&#xff08;GPU&#xff09;為核心的擴展卡&#xff0c;用途是提供中央處理器以外的微處理器幫…

下載electron 22.3.27 源碼錯誤集錦

下載步驟同 electron源碼下載及編譯_electron源碼編譯-CSDN博客 問題1 從github 下載 dugite超時&#xff0c;原因沒有找到 Validation failed. Expected 8ea2d0d3c9d9e4615069913207371ffe892dc10fb93975972f2f6e668f2e3b3a but got e3b0c44298fc1c149afbf4c8996fb92427ae41e…

洛谷P1120 小木棍

#算法/進階搜索 思路: 首先,最初始想法,將我們需要枚舉的長木棍個數計算出來,在dfs中,我們先判斷,此時枚舉這根長木棍需要的長度是否為0,如果為0,我們就枚舉下一個根木棍,接著再判斷,此時仍需要枚舉的木棍個數是否為0,如果為0,代表我們這種方案可行,直接打印長木棍長度,接著我們…

Linux教程-常用命令系列二

文章目錄 1. 系統管理常用命令1. useradd - 創建用戶賬戶功能基本用法常用選項示例 2. passwd - 管理用戶密碼功能基本用法常用選項示例 3. kill - 終止進程功能基本用法常用信號示例 4. date - 顯示和設置系統時間功能基本用法常用選項時間格式示例 5. bc - 高精度計算器功能基…

18、TimeDiff論文筆記

TimeDiff **1. 背景與動機****2. 擴散模型基礎****3. TimeDiff 模型****3.1 前向擴散過程****3.2 后向去噪過程** 4、TimeDiff&#xff08;架構&#xff09;原理訓練推理其他關鍵點解釋 DDPM&#xff08;相關數學&#xff09;1、正態分布2、條件概率1. **與多個條件相關**&…

整合SSM——(SpringMVC+Spring+Mybatis)

目錄 SSM整合 創建項目 導入依賴 配置文件 SpringConfig MyBatisConfig JdbcConfig ServletConfig SpringMvcConfig 功能模塊 測試 業務層接口測試 控制層測試 SSM是Java Web開發中常用的三個主流框架組合的縮寫&#xff0c;分別對應Spring、Spring MVC、MyBatis…