PINN:用深度學習PyTorch求解微分方程

神經網絡技術已在計算機視覺與自然語言處理等多個領域實現了突破性進展。然而在微分方程求解領域,傳統神經網絡因其依賴大規模標記數據集的特性而表現出明顯局限性。物理信息神經網絡(Physics-Informed Neural Networks, PINN)通過將物理定律直接整合到學習過程中,有效彌補了這一不足,使其成為求解常微分方程(ODE)和偏微分方程(PDE)的高效工具。

傳統神經網絡模型需要依賴規模龐大的標記數據集,而這類數據的采集往往成本高昂且耗時顯著。PINN通過將物理定律(具體表現為微分方程)融入訓練過程,顯著提高了數據利用效率。這種方法使得在流體動力學、量子力學和氣候系統建模等科學領域實現基于數據的科學發現成為可能,為跨學科研究提供了新的技術路徑。

求解微分方程一般方法

有如下微分方程:

圖片

邊界條件

圖片

由于

圖片

對 x 積分一次可得

圖片

再次積分,我們得到

圖片

現在,應用邊界條件:

  1. 對于 y(0)=1

圖片

  1. 對于 y(2)=5:

圖片

因此,解析解為:

圖片

用神經網絡解決微分方程

該方法稱為 PINN(物理信息神經網絡),在我們的示例中的工作方式如下:

神經網絡近似:

  • 我們定義一個神經網絡 y(θ,x),其中 θ 表示網絡參數(權重和偏差)。該網絡旨在近似微分方程的解 y(x)。

  • 在我們的例子中,神經網絡是一個小型全連接網絡(具有一個或多個隱藏層),它以空間坐標 x 作為輸入并輸出 y(x) 的近似值。

自動微分:

  • 在這種情況下使用神經網絡的一個主要好處是大多數現代深度學習庫(如 PyTorch)都支持自動區分。

  • 這意味著我們可以直接從網絡輸出計算關于輸入 x 的導數 y′(x) 和 y′′(x)。

殘差計算:

  • 對于 ODE

圖片

我們將殘差 r(x) 定義為:

圖片

在網絡近似精確的點處,殘差應該為零。

損失函數:

  • PINN 方法中的損失函數由兩部分組成:

  • 殘差損失:在域內的一組內部搭配點處計算殘差 r(x) 的均方誤差 (MSE)。該項強制網絡的預測滿足微分方程。

  • 邊界條件損失:網絡預測與給定邊界條件之間的差異的 MSE。

    圖片

  • 總損失為:

圖片

PINN的技術特性與創新點

PINN與傳統神經網絡的根本區別在于,它不依賴于標記數據集進行學習,而是將微分方程約束直接嵌入到損失函數中。這意味著模型學習得到的函數*yNN(x)*需同時滿足:

  • 給定的微分方程約束條件

  • 特定的邊界條件和初始條件

PINN框架中的偏微分方程(PDE)通常表示為:

圖片

其中

圖片

以二階微分方程為例:

圖片

這表明所求函數y(x)必須嚴格滿足該方程。

基于PINN求解微分方程的實踐案例

步驟1:?導入必要的庫函數

import?torch
import?torch.nn?as?nn
import?torch.optim?as?optim
import?matplotlib.pyplot?as?plt
import?numpy?as?np

步驟2:?定義 y(x) 的神經網絡近似值

class?ODE_Net(nn.Module):def?__init__(self, hidden_units=20):super(ODE_Net, self).__init__()self.layer1 = nn.Linear(1, hidden_units)self.layer2 = nn.Linear(hidden_units, hidden_units)self.layer3 = nn.Linear(hidden_units,?1)self.activation = nn.Tanh()def?forward(self, x):out = self.activation(self.layer1(x))out = self.activation(self.layer2(out))out = self.layer3(out)return?out

步驟3:計算 ODE 殘差

def?residual(model, x):"""Compute the ODE residual:y''(x) - 3 = 0."""# Enable gradients for xx.requires_grad_(True)y = model(x)# Compute first derivative: y'(x)dydx = torch.autograd.grad(y, x,grad_outputs=torch.ones_like(y),create_graph=True)[0]# Compute second derivative: y''(x)d2ydx2 = torch.autograd.grad(dydx, x,grad_outputs=torch.ones_like(dydx),create_graph=True)[0]# Compute the residual of the ODE: y''(x) - 3res = d2ydx2 -?3.0return?res

步驟4:損失函數

def?boundary_loss(model):"""Compute the loss associated with the boundary conditions:y(0)=1 and y(2)=5."""# Boundary condition at x=0: y(0)=1x0 = torch.tensor([[0.0]], device=device, requires_grad=True)y0 = model(x0)loss_bc0 = (y0 -?1.0)**2# Boundary condition at x=2: y(2)=5x2 = torch.tensor([[2.0]], device=device, requires_grad=True)y2 = model(x2)loss_bc2 = (y2 -?5.0)**2return?loss_bc0 + loss_bc2

步驟5:模型訓練

??# Initialize the model and optimizermodel = ODE_Net(hidden_units=20).to(device)optimizer = optim.Adam(model.parameters(), lr=1e-3)num_epochs =?5000# Generate interior points in the domain [0,2]N_interior =?50x_interior =?2?* torch.rand(N_interior,?1, device=device) ?# uniformly distributed in [0,2]# Training loop
for?epoch?in?range(num_epochs):model.train()optimizer.zero_grad()# Compute the residual loss at interior pointsr_interior = residual(model, x_interior)loss_res = torch.mean(r_interior**2)# Compute the boundary condition lossloss_bc = boundary_loss(model)# Total loss is the sum of the residual and boundary lossesloss = loss_res + loss_bcloss.backward()optimizer.step()if?epoch %?500?==?0:print(f"Epoch?{epoch}, Loss:?{loss.item():.6e}")# Evaluate and compare the solutionmodel.eval()x_test = torch.linspace(0,?2,?100, device=device).unsqueeze(1)y_pred = model(x_test).detach().cpu().numpy().flatten()x_test_np = x_test.cpu().numpy().flatten()
Epoch 0, Loss: 3.222174e+01
Epoch 500, Loss: 1.378794e-01
Epoch 1000, Loss: 5.264541e-03
Epoch 1500, Loss: 3.903809e-03
Epoch 2000, Loss: 3.040434e-03
Epoch 2500, Loss: 2.319159e-03
Epoch 3000, Loss: 1.656389e-03
Epoch 3500, Loss: 9.695904e-04
Epoch 4000, Loss: 4.545122e-04
Epoch 4500, Loss: 2.485181e-04

步驟6:對比精確度

??# Analytical solution: y(x) = (3/2)x^2 - x + 1y_true =?1.5?* x_test_np**2?- x_test_np +?1plt.figure(figsize=(8,?4))plt.plot(x_test_np, y_pred, label="PINN Solution")plt.plot(x_test_np, y_true,?'--', label="Analytical Solution")plt.xlabel("x")plt.ylabel("y(x)")plt.legend()plt.title("ODE: y''(x) - 3 = 0 with y(0)=1, y(2)=5")plt.show()

圖片

使用 PINN 求解更復雜的 ODE

圖片

class?ODE_Net(nn.Module):def?__init__(self, hidden_units=20):super(ODE_Net, self).__init__()self.layer1 = nn.Linear(1, hidden_units)self.layer2 = nn.Linear(hidden_units, hidden_units)self.layer3 = nn.Linear(hidden_units, hidden_units)self.layer4 = nn.Linear(hidden_units,?1)self.activation = nn.Tanh()def?forward(self, x):out = self.activation(self.layer1(x))out = self.activation(self.layer2(out))out = self.activation(self.layer3(out))out = self.layer4(out)return?outdef?residual(model, x):x.requires_grad_(True)y = model(x)y_x = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y),create_graph=True)[0]y_xx = torch.autograd.grad(y_x, x, grad_outputs=torch.ones_like(y_x),create_graph=True)[0]y_xxx = torch.autograd.grad(y_xx, x, grad_outputs=torch.ones_like(y_xx),create_graph=True)[0]y_xxxx = torch.autograd.grad(y_xxx, x, grad_outputs=torch.ones_like(y_xxx),create_graph=True)[0] ? ?res = y_xxxx -?2*y_xxx + y_xxreturn?resdef?boundary_loss(model):x0 = torch.tensor([[0.0]], device=device, requires_grad=True)y0 = model(x0)y0_x = torch.autograd.grad(y0, x0, grad_outputs=torch.ones_like(y0),create_graph=True)[0]y0_xx = torch.autograd.grad(y0_x, x0, grad_outputs=torch.ones_like(y0_x),create_graph=True)[0]y0_xxx = torch.autograd.grad(y0_xx, x0, grad_outputs=torch.ones_like(y0_xx),create_graph=True)[0]bc1 = y0 -?1.0? ? ??# y(0) = 1bc2 = y0_x -?0.0? ??# y'(0) = 0bc3 = y0_xx - (-1.0) ?# y''(0) = -1 ?-> y0_xx + 1 = 0bc4 = y0_xxx -?2.0# y'''(0) = 2loss_bc = bc1**2?+ bc2**2?+ bc3**2?+ bc4**2return?loss_bcdef?main():model = ODE_Net(hidden_units=20).to(device)optimizer = optim.Adam(model.parameters(), lr=1e-3)num_epochs =?10000N_interior =?50x_interior = torch.rand(N_interior,?1, device=device)for?epoch?in?range(num_epochs):model.train()optimizer.zero_grad()r_interior = residual(model, x_interior)loss_res = torch.mean(r_interior**2)loss_bc = boundary_loss(model) ? ? ? ?loss = loss_res + loss_bcloss.backward()optimizer.step()if?epoch %?500?==?0:print(f"Epoch?{epoch}, Loss:?{loss.item():.6e}")model.eval()x_test = torch.linspace(0,?1,?100, device=device).unsqueeze(1)y_pred = model(x_test).detach().cpu().numpy().flatten()x_test_np = x_test.cpu().numpy().flatten()# Analytical solution: y(x) = 8 + 4x - 7e^x + 3xe^xy_true =?8?+?4*x_test_np -?7*np.exp(x_test_np) +?3*x_test_np*np.exp(x_test_np)plt.figure(figsize=(8,4))plt.plot(x_test_np, y_pred, label="Solution using PINN")plt.plot(x_test_np, y_true,?'--', label="Analytical solution")plt.xlabel("x")plt.ylabel("y(x)")plt.legend()plt.show()if?__name__ ==?"__main__":main()
Epoch 0, Loss: 6.779857e+00
Epoch 500, Loss: 2.092192e-01
Epoch 1000, Loss: 4.828146e-02
Epoch 1500, Loss: 3.233620e-02
Epoch 2000, Loss: 3.518355e-04
Epoch 2500, Loss: 2.392017e-04
Epoch 3000, Loss: 1.745588e-04
Epoch 3500, Loss: 1.332138e-04
Epoch 4000, Loss: 1.039377e-04
Epoch 4500, Loss: 3.754102e-03
Epoch 5000, Loss: 7.414911e-05
Epoch 5500, Loss: 5.272599e-05
Epoch 6000, Loss: 4.189969e-05
Epoch 6500, Loss: 1.759992e-03
Epoch 7000, Loss: 1.593289e-04
Epoch 7500, Loss: 2.400937e-05
Epoch 8000, Loss: 8.885263e-03
Epoch 8500, Loss: 6.434955e-05
Epoch 9000, Loss: 1.761451e-05
Epoch 9500, Loss: 1.477061e-05

圖片

通過結果可以看出,我們已經成功地使用PINN方法求解了上述微分方程,并獲得了與解析解高度一致的數值解。

寫在最后

物理信息神經網絡(PINN)代表了一種在微分方程求解領域的重要技術突破,它將深度學習與物理定律有機結合,為傳統數值求解方法提供了一種高效、數據驅動的替代方案。PINN方法不僅在理論上具有創新性,同時在實際應用中展現出廣闊的應用前景,為復雜物理系統的建模與分析提供了新的研究路徑。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/76876.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/76876.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/76876.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

程序化廣告行業(89/89):廣告創意審核的關鍵要點與實踐應用

程序化廣告行業(89/89):廣告創意審核的關鍵要點與實踐應用 在程序化廣告這個充滿機遇與挑戰的領域,持續學習和知識共享是我們不斷進步的動力。一直以來,我都希望能和大家一同深入探索這個行業,今天讓我們聚…

【ES6新特性】Proxy進階實戰

🌟ES6 Proxy終極指南:從攔截器到響應式框架實現🔥 一、💡 為什么Proxy是革命性的?先看痛點場景 1.1 Object.defineProperty的局限 😫 // Vue2響應式實現 let data { count: 0 }; Object.defineProperty(…

c++解決動態規劃

一、引言: 在我們學習了算法之后,我們一定遇到過貪心算法。而在貪心算法中就有著這樣一個經典的例子——湊錢。 Eg: 你有面額為10、5、1的紙幣,當你買菜時需要花費26元,請問需要最少的紙幣張數是多少。 當我們用貪心算法去解決這個問題的時候,我們…

Qwen 2.5 VL 多種推理方案

Qwen 2.5 VL 多種推理方案 flyfish 單圖推理 from modelscope import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor from qwen_vl_utils import process_vision_info import torchmodel_path "/media/model/Qwen/Qwen25-VL-7B-Instruct/"m…

機器視覺檢測Pin針歪斜應用

在現代電子制造業中,Pin針(插針)是連接器、芯片插座、PCB板等元器件的關鍵部件。如果Pin針歪斜,可能導致接觸不良、短路,甚至整機失效。傳統的人工檢測不僅效率低,還容易疲勞漏檢。 MasterAlign 機器視覺對…

經典算法問題解析:兩數之和與三數之和的Java實現

文章目錄 1. 問題背景2. 兩數之和(Two Sum)2.1 問題描述2.2 哈希表解法代碼實現關鍵點解析復雜度對比 3. 三數之和(3Sum)3.1 問題描述3.2 排序雙指針解法代碼實現關鍵點解析復雜度分析 4. 對比總結5. 常見問題解答6. 擴展練習 1. …

1022 Digital Library

1022 Digital Library 分數 30 全屏瀏覽 切換布局 作者 CHEN, Yue 單位 浙江大學 A Digital Library contains millions of books, stored according to their titles, authors, key words of their abstracts, publishers, and published years. Each book is assigned an u…

地理人工智能中位置編碼的綜述:方法與應用

以下是對論文 《A Review of Location Encoding for GeoAI: Methods and Applications》 的大綱和摘要整理: A Review of Location Encoding for GeoAI: Methods and Applications 摘要(Summary) 本文系統綜述了地理人工智能(G…

(C語言)算法復習總結2——分治算法

1. 分治算法的定義 分治算法(Divide and Conquer)是一種重要的算法設計策略。 “分治” 從字面意義上理解,就是 “分而治之”。 它將一個復雜的問題分解成若干個規模較小、相互獨立且與原問題形式相同的子問題,然后遞歸地解決這…

愛普生FC1610AN5G手機中替代傳統晶振的理想之選

在 5G 技術引領的通信新時代,手機性能面臨前所未有的挑戰與機遇。從高速數據傳輸到多任務高效處理,從長時間續航到緊湊輕薄設計,每一項提升都離不開內部精密組件的協同優化。晶振,作為為手機各系統提供穩定時鐘信號的關鍵元件&…

Android 接口定義語言 (AIDL)

目錄 1. 本地進程調用(同一進程內)2. 遠程進程調用(跨進程)3 `oneway` 關鍵字用于修改遠程調用的行為Android 接口定義語言 (AIDL) 與其他 IDL 類似: 你可以利用它定義客戶端與服務均認可的編程接口,以便二者使用進程間通信 (IPC) 進行相互通信。 在 Android 上,一個進…

關于QT5項目只生成一個CmakeLists.txt文件

編譯器自動檢測明明可以檢測,Kit也沒有報紅 但是最后生成項目只有一個文件 一:檢查cmake版本,我4.1版本cmake一直報錯 cmake3.10可以用 解決之后還是有問題 把環境變量加上去:

uniapp小程序位置授權彈框與隱私協議耦合(合而為一)(只在真機上有用,模擬器會分開彈 )

注意: 只在真機上有用,模擬器會分開彈 效果圖: 模擬器效果圖(授權框跟隱私政策會分開彈,先彈隱私政策,同意再彈授權彈框): manifest-template.json配置( "__usePr…

[Godot] C#人物移動抖動解決方案

在寫一個2D平臺跳躍的游戲代碼發現,移動的時候會抖動卡頓的厲害,后來研究了一下抖動問題,有了幾種解決方案 1.垂直同步和物理插值問題 這是最常見的可能導致畫面撕裂和抖動的原因,大家可以根據自己的需要調整項目設置&#xff0…

紅帽Linux網頁訪問問題

配置網絡,手動配置 搭建yum倉庫紅帽Linux網頁訪問問題 下載httpd 網頁訪問問題:首先看httpd的狀態---selinux的工作模式(強制)---上下文類型(semanage-fcontext)---selinux端口有沒有放行semanage port ---防火墻有沒有active---…

Android12編譯x86模擬器報找不到userdata-qemu.img

qemu-system-x86_64: Could not open out/target/product/generic_x86_64/userdata-qemu.img: No such file or directory 選擇編譯aosp_x86-eng時沒有生成模擬器,報 qemu-system-x86_64: Could not open out/target/product/generic_x86_64/userdata-qemu.img: No…

【AI論文】PixelFlow:基于流的像素空間生成模型

摘要:我們提出PixelFlow,這是一系列直接在原始像素空間中運行的圖像生成模型,與主流的潛在空間模型形成對比。這種方法通過消除對預訓練變分自編碼器(VAE)的需求,并使整個模型能夠端到端訓練,從…

AI大模型學習九:?Sealos cloud+k8s云操作系統私有化一鍵安裝腳本部署完美教程(單節點)

一、說明 ?Sealos?是一款基于Kubernetes(K8s)的云操作系統發行版,它將K8s以及常見的分布式應用如Docker、Dashboard、Ingress等進行了集成和封裝,使得用戶可以在不深入了解復雜的K8s底層原理的情況下,快速搭建起一個…

【HDFS入門】HDFS核心組件DataNode詳解:角色職責、存儲機制與健康管理

目錄 1 DataNode的角色定位 2 DataNode的核心職責 2.1 數據塊管理 2.2 與NameNode的協作 3 DataNode的存儲機制 3.1 數據存儲目錄結構 3.2 數據塊文件組織 4 DataNode的工作流程 4.1 數據寫入流程 4.2 數據讀取流程 5 DataNode的健康管理 5.1 心跳機制(…

BufferedOutputStream 終極解析與記憶指南

BufferedOutputStream 終極解析與記憶指南 一、核心本質 BufferedOutputStream 是 Java 提供的緩沖字節輸出流,繼承自 FilterOutputStream,通過內存緩沖區顯著提升 I/O 性能。 核心特性速查表 特性說明繼承鏈OutputStream → FilterOutputStream → …