【CUDA OUT OF MEMORY】【Pytorch】计算图与CUDA OOM

计算图与CUDA OOM

在实践过程中多次碰到了CUDA OOM的问题,有时候这个问题是很好解决的,有时候DEBUG一整天还是头皮发麻。

最近实践对由于计算图积累导致CUDA OOM有一点新的看法,写下来记录一下。包括对计算图的一些看法和一个由于计算图引发错误的简化实例记录。

本人能力有限,认识片面如果犯了错误希望大家指教!

计算图的存储

计算图是pytorch进行梯度反向传播核心,计算图是在程序运行过程中动态产生的,当tensor变量赋予了requires_grad=True的属性时,torch会自动记录其参与的计算并形成计算图保存在显存中。

敲重点:计算图是会吃显存的! 本来想截下来描述一下计算图是长什么样的,至少是概念的表述一下,结果去学习了一圈发现:和我想的完全不一样!附上学习链接:传送门。更关键的是我还没完全看懂学会(🐶),有没有大大学会了教我一下,不甚感激!

总的来说一个tensor它内部包含的grad_fn别有洞天,首先grad_fn也是作为一个节点在计算图中的(其在pytorch的C艹中是Node的子类),grad_fn不仅是记录了这个tensor是被什么数学符号计算来的,它还暗搓搓记录了这个tensor是是从哪些数字里头窜出来的,以及其和其他grad_fn的py友谊,还有被包含在其内部context中的信息,我偷那个学习链接的一张图展示一下一个计算图的形态,借花献佛,展示一下grad_fn偷偷摸摸用你的卡干了啥事情。
在这里插入图片描述

BTW,提几个小知识点

  • 我们常用的detach()方法,就是通过把tensor的grad_fn扬了从而把tensor从计算图中剥离出来。
>>> x
tensor([1.], requires_grad=True)
>>> y = x+1
>>> y.grad_fn
<AddBackward0 object at 0x7f8306e68b50>
>>> y.detach().grad_fn is None
True
  • 关于*.backward(retain_graph=True)的问题,backwardretain_graph默认是False,其含义是经过默认的*.backward()之后,计算图会被清空从而释放其占用的显存。和detach不一样的是,grad_fn还是那个grad_fn只不过它悄咪咪维持的友谊被杀掉了,如下:
>>> x
tensor([1.], requires_grad=True)
>>> y = x+1; y.grad_fn
<AddBackward0 object at 0x7f8306e68b50>
>>> y.backward(retain_graph=False)
>>> y.grad_fn
<AddBackward0 object at 0x7f8306e68b50>
  • 续上面一点的内容,但是内容包含我瞎猜的成分(🐶),我们猜测一下backward杀掉了grad_fn的什么东西。一般的,我们认为当retain_graph=False的时候,我们只能backward()一次,因为计算图会被清空,第二次尝试反向传播会造成错误。但其实不然!如下实验例子1的尝试,我们连续backwrad并没有报错。AMAZING啊!。进一步的我们进行例子2的实验,我们只是简单的让前向多了一个乘法计算,然后另z反向传播两次,这回顺理成章的报错,同时报错之后我们再次反传y,我们发现反传y又不会报错。我猜测:backward()会清楚grad_fn节点和其他grad_fn的联系,因此zgrad_fn不能联系到ygrad_fn了,于是第二次z.backward()报错,但是y直接和叶子x连接,不需要其他的grad_fn朋友也能自己和自己玩。
例子1:
>>> x
tensor([1.], requires_grad=True)
>>> y = x+1
>>> y.backward(retain_graph=False);y.backward()
返回没有报错!
---------------------------------------------------------------
例子2:
>>> x
tensor([1.], requires_grad=True)
>>> y=x+1;z=2*y #前向过程多了一个乘法
>>> z.backward(retain_graph=False)
>>> z.backward()
Traceback (most recent call last):File "<stdin>", line 1, in <module>File "/Users/**/opt/anaconda3/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backwardtorch.autograd.backward(File "/Users/**/opt/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backwardVariable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
>>>y.backward()
返回没有报错

一个由于没处理好计算图导致OOM的例子

import torch,time
l1 = torch.nn.Linear(1000,1000).cuda()
l2 = torch.nn.Linear(1000,1000).cuda()
memory = []for _ in range(10000000):time.sleep(0.01)data_input = torch.rand(1000).cuda()output = l1(l2(data_input))output.backward(retain_graph=True) #此行与报错无关 memroy.append(output.cpu()) #memory存储的内容通过.cpu()转移在主存上,#但是与output相关联的l1,l2的计算图依旧停留在显存中,并在循环中一直积累撑爆显存。...some other operations...

这个例子中,由于每个output不能被正常清除计算图显存,最终导致OOM。

这个例子是某次实践的超级简化版,如果只看这个例子的话,其实只要把最后一行改写成

memory.append(output.detach().cpu())

就会由于output在每次循环后失去引用(detach()创建了新的变量)从而被回收,计算图被自动清空避免OOM。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/128223.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【小沐学NLP】Python使用NLTK库的入门教程

文章目录 1、简介2、安装2.1 安装nltk库2.2 安装nltk语料库 3、测试3.1 分句分词3.2 停用词过滤3.3 词干提取3.4 词形/词干还原3.5 同义词与反义词3.6 语义相关性3.7 词性标注3.8 命名实体识别3.9 Text对象3.10 文本分类3.11 其他分类器3.12 数据清洗 结语 1、简介 NLTK - 自然…

python的包管理

要在 mypackage 包外使用 mypackage 包里的 speak.py 文件以及 newpackage 包里的 jump.py 文件&#xff0c;你需要确保以下几个步骤&#xff1a; 确保目录结构正确&#xff0c;如下所示&#xff1a; mypackage/__init__.pyspeak.pynewpackage/__init__.pyjump.py在 speak.py…

介绍OpenCV

OpenCV是一个开源计算机视觉库&#xff0c;可用于各种任务&#xff0c;如物体识别、人脸识别、运动跟踪、图像处理和视频处理等。它最初由英特尔公司开发&#xff0c;目前由跨学科开发人员社区维护和支持。OpenCV可以在多个平台上运行&#xff0c;包括Windows、Linux、Android和…

leetcode 43.字符串相乘

⭐️ 题目描述 &#x1f31f; leetcode链接&#xff1a;字符串相乘 思路&#xff1a; 代码&#xff1a; class Solution { public:string multiply(string num1, string num2) {if (num1 "0" || num2 "0") {return "0";}/*0 1 2 下标1 2…

Python中的Numpy向量计算(R与Python系列第三篇)

目录 一、什么是Numpy? 二、如何导入NumPy? 三、生成NumPy数组 3.1利用序列生成 3.2使用特定函数生成NumPy数组 &#xff08;1&#xff09;使用np.arange() &#xff08;2&#xff09;使用np.linspace() 四、NumPy数组的其他常用函数 &#xff08;1&#xff09;np.z…

C++斩题录|递归专题 | leetcode50. Pow(x, n)

个人主页&#xff1a;平行线也会相交 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 平行线也会相交 原创 收录于专栏【手撕算法系列专栏】【LeetCode】 &#x1f354;本专栏旨在提高自己算法能力的同时&#xff0c;记录一下自己的学习过程&#xff0c;希望…

机器学习---决策树分类代码

1. 计算数据集的香农熵 from numpy import * import numpy as np import pandas as pd from math import log import operator #计算数据集的香农熵 def calcShannonEnt(dataSet): numEntrieslen(dataSet) labelCounts{} #给所有可能分类创建字典 for featVec …

初识Java 7-1 多态

目录 向上转型 难点 方法调用绑定 产生正确的行为 可扩展性 陷阱&#xff1a;“重写”private方法 陷阱&#xff1a;字段与静态方法 构造器和多态 构造器的调用顺序 继承和清理 构造器内部的多态方法行为 协变返回类型 使用继承的设计 替换和扩展 向下转型和反射…

Unity中Shader的变体shader_feature

文章目录 前言一、变体的类型1、multi_compile —— 无论如何都会被编译的变体2、shader_feature —— 通过材质的使用情况来决定是否编译的变体 二、使用 shader_feature 来控制 shader 效果的变化1、首先在属性面板暴露一个开关属性&#xff0c;用于配合shader_feature来控制…

Java(四)数组与类和对象

Java&#xff08;四&#xff09;数组与类和对象 六、数组&#xff08;非常重要&#xff09;1.定义2.遍历2.1遍历方法2.2Arrays方法 3.二维数组数组小总结 七、类和对象1. 定义&#xff08;重要&#xff09;1.1 类1.2 对象 2. this关键字&#xff08;重要&#xff09;2.1 特点 3…

lv4 嵌入式开发-4 标准IO的读写(二进制方式)

目录 1 标准I/O – 按对象读写 2 标准I/O – 小结 3 标准I/O – 思考和练习 文本文件和二进制的区别&#xff1a; 存储的格式不同&#xff1a;文本文件只能存储文本。除了文本都是二进制文件。 补充计算机内码概念&#xff1a;文本符号在计算机内部的编码&#xff08;计算…

肖sir__设计测试用例方法之正交表08_(黑盒测试)

设计测试用例方法之正交 一、正交表定义 正交试验设计法&#xff0c;是从大量的试验点中挑选出适量的、有代表性的点&#xff0c;应用依据迦罗瓦理论导出的“正交表”&#xff0c;合理的安排试验的一种科学的试验设计方法。 二、 正交常用的术语 指标&#xff1a;通常把判断试验…

OpenCV 12(图像直方图)

一、图像直方图 直方图可以让你了解总体的图像像素强度分布&#xff0c;其X轴为像素值&#xff08;一般范围为0~255&#xff09;&#xff0c;在Y轴上为图像中具有该像素值像素数。 - 横坐标: 图像中各个像素点的灰度级. - 纵坐标: 具有该灰度级的像素个数. 画出上图的直方图: …

【实践篇】Redis最强Java客户端(三)之Redisson 7种分布式锁使用指南

文章目录 0. 前言1. Redisson 7种分布式锁使用指南1.1 简单锁&#xff1a;1.2 公平锁&#xff1a;1.3 可重入锁&#xff1a;1.4 红锁&#xff1a;1.5 读写锁&#xff1a;1.6 信号量&#xff1a;1.7 闭锁&#xff1a; 2. Spring boot 集成Redisson 验证分布式锁3. 参考资料4. 源…

IntelliJ IDEA远程调试:使用IDEA Remote Debug进行高效调试的指南

引言 在开发分布式系统时&#xff0c;调试是一个重要但复杂的环节。开发者通常需要跨越多个服务、模块和线程来追踪和解决问题。在没有远程调试的情况下&#xff0c;许多开发者会在代码中添加各种日志语句&#xff0c;然后重新部署和上线来调试。这种方法不仅费时&#xff0c;…

Hive_Hive统计指令analyze table和 describe table

之前在公司内部经常会看到表的元信息的一些统计信息&#xff0c;当时非常好奇是如何做实现的。 现在发现这些信息主要是基于 analyze table 去做统计的&#xff0c;分享给大家 实现的效果某一个表中每个列的空值数量&#xff0c;重复值数量等&#xff0c;平均长度 具体的指令…

华为数据管理——《华为数据之道》

数据分析与开发 元数据是描述数据的数据&#xff0c;用于打破业务和IT之间的语言障碍&#xff0c;帮助业务更好地理解数据。 元数据是数据中台的重要的基础设施&#xff0c;元数据治理贯彻数据产生、加工、消费的全过程&#xff0c;沉淀了数据资产&#xff0c;搭建了技术和业务…

【C++模拟实现】手撕AVL树

【C模拟实现】手撕AVL树 目录 【C模拟实现】手撕AVL树AVL树的介绍&#xff08;百度百科&#xff09;AVL树insert函数的实现代码验证是否为AVL树AVL树模拟实现的要点易忘点AVL树的旋转思路 作者&#xff1a;爱写代码的刚子 时间&#xff1a;2023.9.10 前言&#xff1a;本篇博客将…

python28种极坐标绘图函数总结

文章目录 基础图误差线等高线polar场图polar统计图非结构坐标图 &#x1f4ca;python35种绘图函数总结&#xff0c;3D、统计、流场&#xff0c;实用性拉满 matplotlib中的画图函数&#xff0c;大部分情况下只要声明坐标映射是polar&#xff0c;就都可以画出对应的极坐标图。但…

9、补充视频

改进后的dijkstra算法 利用小根堆 将小根堆特定位置更改,再改成小根堆 nodeHeap.addOrUpdateOrIgnore(edge.to, edge.weight + distance);//改进后的dijkstra算法 //从head出发,所有head能到达的节点,生成到达每个节点的最小路径记录并返回 public static HashMap<No…