[内存泄漏][PyTorch](create_graph=True)

PyTorch保存计算图导致内存泄漏

  • 1. 内存泄漏定义
  • 2. 问题发现背景
  • 3. github中pytorch源码关于这个问题的讨论

1. 内存泄漏定义

  内存泄漏(Memory Leak)是指程序中已动态分配的堆内存由于某种原因程序未释放或无法释放,造成系统内存的浪费,导致程序运行速度减慢甚至系统崩溃等严重后果。

2. 问题发现背景

  在使用深度学习求解PDE时,由于经常需要计算高阶导数,在pytorch框架下写的代码需要用到torch.autograd.grad(create_graph=True)或者torch.backward(create_graph=True)这个参数,然后发现了这个内存泄漏的问题。如果要保存计算图用来计算高阶导数,那么其所占的内存不会被释放,会一直占用。也就是如果设置create_graph=True,那么其保存的计算图所占的内存只有在程序运行结束时才会释放,这样导致了一个问题,即如果在循环中需要保存计算图,例如每个循环都需要计算一次黑塞矩阵,那么这个内存占用就会越来越多,最终导致out of memory报错。
在这里插入图片描述

3. github中pytorch源码关于这个问题的讨论

  官网中关于这个问题的讨论见https://github.com/pytorch/pytorch/issues/7343,这里提出的内存泄漏的例子如下:

import torch
import gc_ = torch.randn(1, device='cuda')
del _
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
x = torch.randn(1, device='cuda', requires_grad=True)
y = x.tanh()
y.backward(torch.ones_like(y), create_graph=True)
del x, y
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())

在这里插入图片描述
可以看到虽然删除了变量,依然造成了内存泄漏。这里红色的警告就是关于这个内存泄漏的问题。

UserWarning: Using backward() with create_graph=True will create a reference cycle between
the parameter and its gradient which can cause a memory leak. We recommend using autograd.grad 
when creating the graph to avoid this. If you have to use this function, make sure to reset 
the .grad fields of your parameters to None after use to break the cycle and avoid the leak. 
(Triggered internally at C:\cb\pytorch_1000000000000\work\torch\csrc\autograd\engine.cpp:1000.)
allow_unreachable=True, accumulate_grad=True) 
# Calls into the C++ engine to run the backward pass

看这个UserWarning,提示我们使用torch.autograd.grad()函数可以避免这个梯度泄漏,然后对代码进行改动:

import torch
import gc
from torch.autograd import grad_ = torch.randn(1, device='cuda')
del _
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
x = torch.randn(1, device='cuda', requires_grad=True)
y = x.tanh()
z = grad(y, x, retain_graph=True, create_graph=True)
# y.backward(torch.ones_like(y), create_graph=True)
del x, y, z
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())

在这里插入图片描述
结果显示没有梯度泄漏。进一步,我们求一下二阶导数:

import torch
import gc
from torch.autograd import grad_ = torch.randn(1, device='cuda')
del _
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
x = torch.randn(1, device='cuda', requires_grad=True)
y = x.tanh()
z = grad(y, x, retain_graph=True, create_graph=True)
print(torch.cuda.memory_allocated())
q = grad(z, x)
del x, y, z, q
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())

在这里插入图片描述
结果也没有内存泄漏。但是,如果我们不删除结果二阶导数q,这样是出于如果写在一个函数中,需要将q作为return值返回的情况。

import torch
import gc
from torch.autograd import grad_ = torch.randn(1, device='cuda')
del _
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
x = torch.randn(1, device='cuda', requires_grad=True)
y = x.tanh()
z = grad(y, x, retain_graph=True, create_graph=True)
print(torch.cuda.memory_allocated())
q = grad(z, x)
del x, y, z
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())

在这里插入图片描述
可以看到,这还是会导致一部分内存泄漏。这里记录一下这个问题,有读者有遇到相同问题欢迎讨论。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/198176.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MIB 6.1810实验Xv6 and Unix utilities(3)pingpong

Mit6.S081-实验1-Xv6 and Unix utilities-pingpong问题_Isana_Yashiro的博客-CSDN博客 Write a user-level program that uses xv6 system calls to ping-pong a byte between two processes over a pair of pipes, one for each direction. The parent should send a byte to…

【计算思维】蓝桥杯STEMA 科技素养考试真题及解析 4

1、下列哪个选项填到填到下图空缺处最合适 A、 B、 C、 D、 答案:D 2、按照如下图的规律摆放正方形,第 5 堆正方形的个数是 A、13 B、14 C、15 D、16 答案:D 3、从右面观察下面的立体图形,看到的是 A、 B、 C、 D、 答…

hyperledger fabric2.4测试网络添加组织数量

!!!修改内容比较繁琐,预期未来提供模板修改 修改初始配置文件,初始添加3个组织 organizations文件夹 /cryptogen文件夹下创建文件crypto-config-org3.yaml,内容如下: PeerOrgs:# ---------------------------------------------------------------------------# Org3# ----…

STM32电源名词解析

先来简单了解一下各种电源端口的命名 VCC:Ccircuit 表示电路的意思, 即接入电路的电压 VDD:Ddevice 表示器件的意思, 即器件内部的工作电压。 VSS:Sseries 表示公共连接的意思,通常指电路公共接地端电压。 GND:在电…

SpringCloud微服务注册中心:Nacos介绍,微服务注册,Ribbon通信,Ribbon负载均衡,Nacos配置管理详细介绍

微服务注册中心 注册中心可以说是微服务架构中的”通讯录“,它记录了服务和服务地址的映射关系。在分布式架构中,服务会注册到这里,当服务需要调用其它服务时,就这里找到服务的地址,进行调用。 微服务注册中心 服务注…

wpf devexpress自定义编辑器

打开前一个例子 步骤1-自定义FirstName和LastName编辑器字段 如果运行程序,会通知编辑器是空。对于例子,这两个未命名编辑器在第一个LayoutItem(Name)。和最终用户有一个访客左右编辑器查阅到First Name和Last Name字段,分别。如果你看到Go…

第20章 数据库编程

通过本章需要理解JDBC的核心设计思想以及4种数据库访问机制,理解数据库连接处理流程,并且可以使用JDBC进行Oracle数据库的连接,理解工厂设计模式在JDBC中的应用,清楚地理解DriverManager类的作用,掌握Connection、Prep…

【SpringBoot3+Vue3】四【实战篇】-前端(vue基础)

目录 一、项目前置知识 二、使用vscode创建 三、vue介绍 四、局部使用vue 1、快速入门 1.1 需求 1.2 准备工作 1.3 操作 1.3.1 创建html 1.3.2 创建初始html代码 1.3.3 参照官网import vue 1.3.4 创建vue应用实例 1.3.5 准备div 1.3.6 准备用户数据 1.3.7 通过…

【Linux网络】详解使用http和ftp搭建yum仓库,以及yum网络源优化

目录 一、回顾yum的原理 1.1yum简介 yum安装的底层原理: yum的好处: 二、学习yum的配置文件及命令 1、yum的配置文件 2、yum的相关命令详解 3、yum的命令相关案例 三、搭建yum仓库的方式 1、本地yum仓库建立 2、通过http搭建内网的yum仓库 3、…

RabbitMQ消息的可靠性

RabbitMQ消息的可靠性 一 生产者的可靠性 生产者重试 有时候由于网络问题,会出现连接MQ失败的情况,可以配置重连机制 注意:SpringAMQP的重试机制是阻塞式的,重试等待的时候,当前线程会等待。 spring:rabbitmq:conne…

Java读写Jar

Java提供了读写jar的类库Java.util.jar,Java获取解析jar包的工具类如下: import java.io.File; import java.io.IOException; import java.net.URL; import java.net.URLClassLoader; import java.util.Enumeration; import java.util.HashMap; import …

你知道什么是SaaS吗?

你知道什么是SaaS吗? 云服务架构的三个概念 PaaS 英文就是 Platform-as-a-Service(平台即服务)PaaS,某些时候也叫做中间件。就是把客户采用提供的开发语言和工具(例如Java,python, .Net等)开…

我叫:选择排序【JAVA】

1.我是个啥子?? 选择式排序:属于内部排序法,从欲排序的数据中,按指定的规则选出某一元素,再依规定交换位置后达到排序的目的。 2.我的思想 基本思想:第一次从arr[0]~arr[n-1]中选取最小值,与arr[0]交换,第…

单张图像3D重建:原理与PyTorch实现

近年来,深度学习(DL)在解决图像分类、目标检测、语义分割等 2D 图像任务方面表现出了出色的能力。DL 也不例外,在将其应用于 3D 图形问题方面也取得了巨大进展。 在这篇文章中,我们将探讨最近将深度学习扩展到单图像 3…

Transformer中WordPiece/BPE等不同编码方式详解以及优缺点

❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️ 👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博…

C++基础从0到1入门编程(三)

系统学习C 方便自己日后复习,错误的地方希望积极指正 往期文章: C基础从0到1入门编程(一) C基础从0到1入门编程(二) 参考视频: 1.黑马程序员匠心之作|C教程从0到1入门编程,学习编程不再难 2.系统…

ubuntu下载conda

系统:Ubuntu18.04 (1)下载安装包 wget https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/Anaconda3-2021.11-Linux-x86_64.sh 报错错误 403:Forbidden 解决方法 wget -U NoSuchBrowser/1.0 https://mirrors.tuna.tsingh…

Jmeter做接口测试

1.Jmeter的安装以及环境变量的配置 Jmeter是基于java语法开发的接口测试以及性能测试的工具。 jdk:17 (最新的Jeknins,只能支持到17) jmeter:5.6 官网:http://jmeter.apache.org/download_jmeter.cgi 认识JMeter的目录&#xff1…

【Web】Ctfshow SSRF刷题记录1

核心代码解读 <?php $url$_POST[url]; $chcurl_init($url); curl_setopt($ch, CURLOPT_HEADER, 0); curl_setopt($ch, CURLOPT_RETURNTRANSFER, 1); $resultcurl_exec($ch); curl_close($ch); ?> curl_init()&#xff1a;初始curl会话 curl_setopt()&#xff1a;会…

【力扣面试经典150题】(链表)K 个一组翻转链表

题目描述 力扣原文链接 给你链表的头节点 head &#xff0c;每 k 个节点一组进行翻转&#xff0c;请你返回修改后的链表。 k 是一个正整数&#xff0c;它的值小于或等于链表的长度。如果节点总数不是 k 的整数倍&#xff0c;那么请将最后剩余的节点保持原有顺序。 你不能只…