PyTorch PINN实战:用深度学习求解微分方程

神经网络技术已在计算机视觉与自然语言处理等多个领域实现了突破性进展。然而在微分方程求解领域,传统神经网络因其依赖大规模标记数据集的特性而表现出明显局限性。物理信息神经网络(Physics-Informed Neural Networks, PINN)通过将物理定律直接整合到学习过程中,有效弥补了这一不足,使其成为求解常微分方程(ODE)和偏微分方程(PDE)的高效工具。

传统神经网络模型需要依赖规模庞大的标记数据集,而这类数据的采集往往成本高昂且耗时显著。PINN通过将物理定律(具体表现为微分方程)融入训练过程,显著提高了数据利用效率。这种方法使得在流体动力学、量子力学和气候系统建模等科学领域实现基于数据的科学发现成为可能,为跨学科研究提供了新的技术路径。

神经网络基础理论

在深入剖析PINN之前,有必要回顾标准神经网络的核心运作机制:

神经网络的基本计算单元是神经元,它接收加权输入信号,经过激活函数处理后产生输出值。多层神经元通过特定拓扑结构组织形成深度神经网络(DNN),这种结构使网络能够逼近高度复杂的非线性函数。网络训练过程中,通常采用均方误差(MSE)等损失函数量化预测值与真实值之间的偏差。通过反向传播算法和梯度下降优化方法,网络权重参数被迭代调整以使损失函数最小化。

示例损失函数

均方误差

PINN的技术特性与创新点

PINN与传统神经网络的根本区别在于,它不依赖于标记数据集进行学习,而是将微分方程约束直接嵌入到损失函数中。这意味着模型学习得到的函数_yNN(x)_需同时满足:

  • 给定的微分方程约束条件
  • 特定的边界条件和初始条件

PINN框架中的偏微分方程(PDE)通常表示为:

其中

以二阶微分方程为例:

这表明所求函数y(x)必须严格满足该方程。

PINN损失函数的构造原理

PINN的总体损失函数由两个主要部分组成:

PINN的技术优势与局限性

技术优势

PINN具有显著的数据效率优势,能够通过物理定律的约束从相对小规模的数据集中有效学习。它能够处理传统数值求解器难以应对的高维复杂偏微分方程。训练完成后,PINN模型具有良好的泛化能力,可预测不同初始条件或边界条件下的解。此外,在处理逆问题时,PINN对噪声和稀疏数据表现出较强的鲁棒性。

技术局限

PINN的训练过程计算密集且耗时较长,尤其对于高维偏微分方程,通常需要高性能GPU支持。模型对超参数选择较为敏感,需要精细调整以平衡不同损失项的贡献。与成熟的数值求解器相比,PINN在处理大规模物理问题时可扩展性有限。此外,PINN还面临梯度消失导致的优化困难问题,且缺乏与有限元或有限差分方法相当的理论收敛保证。

微分方程的解析求解方法

考虑以下一阶线性微分方程:

初始条件为:

解法步骤

首先,将方程重写为标准形式:

对方程两边进行积分:

应用基本积分公式,得到y的表达式:

其中C为积分常数。

因此,通解为:

代入初始条件y(0)=3:

由此得到精确解:

求解结果总结

通解形式:

带入初始条件y(0)=3后的精确解:

基于PINN求解微分方程的实践案例

步骤1: 导入必要的库函数

import torch  
import torch.nn as nn  
import torch.optim as optim  
import numpy as np  
import matplotlib.pyplot as plt  
from torchinfo import summary

步骤2: 定义能够返回精确解的函数

def true_solution(x):  return x**2 + 5*x + 3    # 精确解函数

这与我们手动求解得到的解析解一致

步骤3: 生成测试点并绘制精确解

x_test = torch.linspace(-2, 2, 100).view(-1, 1) # 生成测试点  
y_true = true_solution(x_test)  plt.figure(figsize=(8, 5))  
plt.plot(             # 绘制微分方程的精确解  x_test,          y_true,   linestyle="dashed" ,      linewidth=2,   label="True Solution"  
)  plt.xlabel("x")  
plt.ylabel("y(x)")  
plt.legend()  
plt.title("Analytical Solution of the Equation")  
plt.grid()  
plt.show()

输出结果:

步骤4: 设计PINN模型架构

class PINN(nn.Module):  def __init__(self):  super(PINN, self).__init__()  self.net = nn.Sequential(  nn.Linear(1, 20), nn.Tanh(),  nn.Linear(20, 20), nn.Tanh(),  nn.Linear(20, 1)  )  def forward(self, x):  return self.net(x)  model = PINN()  
optimizer = optim.Adam(model.parameters(), lr=1e-3)  
summary(model)

输出结果:

步骤5: 定义PINN损失函数

def pinn_loss(model, x):  x.requires_grad = True  y = model(x)  # 使用自动微分计算dy/dx  dy_dx = torch.autograd.grad(y, x, torch.ones_like(y), create_graph=True)[0]  # 微分方程损失(L_D): dy/dx - (2x + 5)  ode_loss = torch.mean((dy_dx - (2*x + 5))**2)  # 初始条件损失(L_B): y(0) = 3  x0 = torch.tensor([[0.0]])  y0_pred = model(x0)  initial_loss = (y0_pred - 3)**2  # 总损失  total_loss = ode_loss + initial_loss  return total_loss, ode_loss, initial_loss

步骤6: 训练模型(5000轮次)

epochs = 5000  loss_history = []  
ode_loss_history = []  
initial_loss_history = []  x_train = torch.linspace(-2, 2, 100).view(-1, 1)  # 训练点  for epoch in range(epochs):  optimizer.zero_grad()  total_loss, ode_loss, initial_loss = pinn_loss(model, x_train)  total_loss.backward()  optimizer.step()  loss_history.append(total_loss.item())  ode_loss_history.append(ode_loss.item())  initial_loss_history.append(initial_loss.item())  if epoch % 1000 == 0:  print(f"Epoch {epoch}, Loss: {total_loss.item():.6f}")

步骤7: 绘制训练过程中的损失函数变化

plt.figure(figsize=(8, 5))  
epochs_list = np.arange(1, epochs + 1)  plt.semilogy(epochs_list, loss_history, 'k--', linewidth=3, label=r'Total Loss $(L_D + L_B)$')  
plt.semilogy(epochs_list, ode_loss_history, 'r-', linewidth=1, label=r'ODE Loss $(L_D)$')  
plt.semilogy(epochs_list, initial_loss_history, 'g-', linewidth=1, label=r'Initial Loss $(L_B)$')  plt.xlabel("Epochs")  
plt.ylabel("Loss (Log Scale)")  
plt.legend()  
plt.title("Loss Components vs Epochs")  
plt.grid()  
plt.show()

输出结果:

步骤8: 对比PINN解与解析解的精确度

X_test = torch.linspace(-2, 2, 100).view(-1, 1)  
y_pred = model(X_test).detach().numpy()  plt.plot(X_test,true_solution(X_test),linestyle="dashed",linewidth=3,label="True Solution",color="red")  
plt.plot(X_test,y_pred,label="PINNS Solution",color="green")  
plt.xlabel('x')  
plt.ylabel('y')  
plt.legend()  
plt.title(r'Analytical Vs PINNs Solution')  
plt.savefig("solution.png", dpi=300, bbox_inches='tight')  
plt.grid()  
plt.show()

输出结果:

通过结果可以看出,我们已经成功地使用PINN方法求解了上述微分方程,并获得了与解析解高度一致的数值解。

总结

物理信息神经网络(PINN)代表了一种在微分方程求解领域的重要技术突破,它将深度学习与物理定律有机结合,为传统数值求解方法提供了一种高效、数据驱动的替代方案。PINN方法不仅在理论上具有创新性,同时在实际应用中展现出广阔的应用前景,为复杂物理系统的建模与分析提供了新的研究路径。
https://avoid.overfit.cn/post/f9bd046772f1473a80002f592e9527d4

作者:Muhammad Tayyab

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/36743.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

关于“碰一碰发视频”系统的技术开发文档框架

以下是关于“碰一碰发视频”系统的技术开发文档框架,涵盖核心功能、技术选型、开发流程和关键模块设计,帮助您快速搭建一站式解决方案 --- 随着短视频平台的兴起,用户的创作与分享需求日益增长。而如何让视频分享更加便捷、有趣&#xff0c…

【VUE】day05-ref引用

这里写目录标题 1. ref引用1.1 使用ref引用组件 2. this.$nextTick(cb)方法3. 购物车案例3.1 数组中的方法 - some循环3.2 数组中的方法 - every循环3.3 数组中的方法 - reduce 4. 购物车案例 1. ref引用 ref用来辅助开发者在不依赖于jQuery的情况下,获取DOM元素或…

docker安装milvus向量数据库Attu可视化界面

Docker 部署 Milvus 及 Attu 可视化工具完整指南 一、环境准备 安装 Docker 及 Docker Compose Docker 版本需 ≥20.10.12Docker Compose 版本需 ≥2.20.0(推荐 V2) 验证 Docker 环境 docker --version && docker-compose --version若出现&…

nacos安装,服务注册,服务发现,远程调用3个方法

安装 点版本下载页面 服务注册 每个微服务都配置nacos的地址,都要知道 服务发现 2个是知道了解 远程调用基本实现 远程调用方法2,负载均衡API测试 远程调用方法3,注解 负载均衡的远程调用, 总结 面试题

MySQL:数据库基础

数据库基础 1.什么是数据库?2.为什么要学习数据库?3.主流的数据库(了解)4.服务器,数据库,表之间的关系5.数据的逻辑存储6.MYSQL架构7.存储引擎 1.什么是数据库? 数据库(Database,简称DB)&#x…

Kotlin 基础语法

1. 🌟 Kotlin:Java 的“超级进化体”? Kotlin 是一门由 JetBrains 开发的 现代静态类型编程语言,支持 JVM、Android、JavaScript、Native 等多平台: Kotlin 与 Java 深度兼容,Kotlin 会编译为 JVM 字节码&#xff0c…

基于RAGFlow本地部署DeepSeek-R1大模型与知识库:从配置到应用的全流程解析

作者:后端小肥肠 🍊 有疑问可私信或评论区联系我。 🥑 创作不易未经允许严禁转载。 姊妹篇: DeepSpeek服务器繁忙?这几种替代方案帮你流畅使用!(附本地部署教程)-CSDN博客 10分钟上手…

uniapp APP权限弹框

效果图 第一步 新建一个页面,设置透明 {"path": "pages/permissionDisc/permissionDisc","style": {"navigationBarTitleText": "","navigationStyle": "custom","app-plus": {&…

【深度学习与大模型基础】第7章-特征分解与奇异值分解

一、特征分解 特征分解(Eigen Decomposition)是线性代数中的一种重要方法,广泛应用于计算机行业的多个领域,如机器学习、图像处理和数据分析等。特征分解将一个方阵分解为特征值和特征向量的形式,帮助我们理解矩阵的结…

麒麟V10 arm cpu aarch64 下编译 RocketMQ-Client-CPP 2.2.0

国产自主可控服务器需要访问RocketMQ消息队列,最新的CSDK是2020年发布的 rocketmq-client-cpp-2.2.0 这个版本支持TLS模式。 用默认的版本安装遇到一些问题,记录一下。 下载Releases apache/rocketmq-client-cpp GitHubhttps://github.com/apache/roc…

Moonlight-16B-A3B: 变革性的高效大语言模型,凭借Muon优化器打破训练效率极限

近日,由Moonshot AI团队推出的Moonlight-16B-A3B模型,再次在AI领域引发了广泛关注。这款全新的Mixture-of-Experts (MoE)架构的大型语言模型,凭借其创新的训练优化技术,特别是Muon优化器的使用,成功突破了训练效率的极…

在windows下安装windows+Ubuntu16.04双系统(下)

这篇文章的内容主要来源于这篇文章,为正式安装windowsUbuntu16.04双系统部分。在正式安装前,若还没有进行前期准备工作(1.分区2.制作启动u盘),见《在windows下安装windowsUbuntu16.04双系统(上)》 二、正式安装Ubuntu …

一次Linux下 .net 调试经历

背景: Xt160Api, 之前在windows下用.net调用,没有任何问题。 但是移植到Linux去后,.net程序 调用 init(config_path) 总是报错 /root/test 找不到 traderApi.ini (/root/test 是程序目录) 然后退出程序 解决过程: 于是考虑是不是参数传错了&…

AI爬虫 :Firecrawl的安装和详细使用案例(将整个网站转化为LLM适用的markdown或结构化数据)

更多内容请见: 爬虫和逆向教程-专栏介绍和目录 文章目录 1. Firecrawl概述1.1 Firecrawl介绍1.2 Firecrawl 的特征1.3 Firecrawl 的功能1.4 Firecrawl的 API 密钥获取2. 安装和基本使用3. 使用 LLM 提取4. 无模式提取(curl语句)5. 使用操作与页面交互6. Firecrawl Cloud7. 移…

【Java集合夜话】第1篇:拨开迷雾,探寻集合框架的精妙设计

欢迎来到Java集合框架系列的第一篇文章!🌹 本系列文章将以通俗易懂的语言,结合实际开发经验,带您深入理解Java集合框架的设计智慧。🌹 若文章中有任何不准确或需要改进的地方,欢迎大家指出,让我…

网络安全知识:网络安全网格架构

在数字化转型的主导下,大多数组织利用多云或混合环境,包括本地基础设施、云服务和应用程序以及第三方实体,以及在网络中运行的用户和设备身份。在这种情况下,保护组织资产免受威胁涉及实现一个统一的框架,该框架根据组…

企业级云MES全套源码,支持app、小程序、H5、台后管理端

企业级云MES全套源码,支持app、小程序、H5、台后管理端,全套源码 开发环境 技术架构:springboot vue-element-plus-admin 开发语言:Java 开发工具:idea 前端框架:vue.js 后端框架&#xff…

Web爬虫利器FireCrawl:全方位助力AI训练与高效数据抓取

Web爬虫利器FireCrawl:全方位助力AI训练与高效数据抓取 一、FireCrawl 项目简介二、主要功能三、FireCrawl应用场景1. 大语言模型训练2. 检索增强生成(RAG):3. 数据驱动的开发项目4. SEO 与内容优化5. 在线服务与工具集成 四、安装…

[HelloCTF]PHPinclude-labs超详细WP-Level 6Level 7Level 8Level 9-php://协议

由于Level 6-9 关的原理都是通用的, 这里就拿第6关举例, 其他的关卡同理 源码分析 定位到代码 isset($_GET[wrappers]) ? include("php://".$_GET[wrappers]) : ; 与前几关发生变化的就是 php:// 解题分析 这一关要求我们使用 php协议 php:// 协议 php://filte…

《Linux 网络架构:基于 TCP 协议的多人聊天系统搭建详解》

一、系统概述 本系统是一个基于 TCP 协议的多人聊天系统,由一个服务器和多个客户端组成。客户端可以连接到服务器,向服务器发送消息,服务器接收到消息后将其转发给其他客户端,实现多人之间的实时聊天。系统使用 C 语言编写&#x…