深度学习:模型训练过程中Trying to backward through the graph a second time解决方案

1 问题描述

在训练lstm网络过程中出现如下错误:

Traceback (most recent call last):File "D:\code\lstm_emotion_analyse\text_analyse.py", line 82, in <module>loss.backward()File "C:\Users\lishu\anaconda3\envs\pt2\lib\site-packages\torch\_tensor.py", line 487, in backwardtorch.autograd.backward(File "C:\Users\lishu\anaconda3\envs\pt2\lib\site-packages\torch\autograd\__init__.py", line 200, in backwardVariable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

2 问题分析

按照错误提示查阅相关资料了解到,实际上在大多数情况下retain_graph都应采用默认的False,除了几种特殊情况:

  • 一个网络有两个output分别执行backward进行回传的时候: output1.backward(), output2.backward().
  • 一个网络有两个loss需要分别执行backward进行回传的时候: loss1.backward(), loss2.backward().

但本项目的LSTM训练模型不属于以上情况,再次查找资料,在在pytorch的官方论坛上找到了真正的原因:

如截图中的描述,只要我们对变量进行运算了,就会加进计算图中。所以本项目的问题在于在for循环梯度反向传播中,使用了循环外部的变量h,如下所示:

epochs = 128step = 0model.train()  # 开启训练模式for epoch in range(epochs):h = model.init_hidden(batch_size)  # 初始化第一个Hidden_statefor data in tqdm(train_loader):x_train, y_train = datax_train, y_train = x_train.to(device), y_train.to(device)step += 1  # 训练次数+1x_input = x_train.to(device)model.zero_grad()output, h = model(x_input, h)# 计算损失loss = criterion(output, y_train.float().view(-1))loss.backward()nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)optimizer.step()if step % 10 == 0:print("Epoch: {}/{}...".format(epoch + 1, epochs),"Step: {}...".format(step),"Loss: {:.6f}...".format(loss.item()))

3 问题解决

代码修改如下:

epochs = 128step = 0model.train()  # 开启训练模式for epoch in range(epochs):h = model.init_hidden(batch_size)  # 初始化第一个Hidden_statefor data in tqdm(train_loader):x_train, y_train = datax_train, y_train = x_train.to(device), y_train.to(device)step += 1  # 训练次数+1x_input = x_train.to(device)model.zero_grad()h = tuple([e.data for e in h])output, h = model(x_input, h)# 计算损失loss = criterion(output, y_train.float().view(-1))loss.backward()nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)optimizer.step()if step % 10 == 0:print("Epoch: {}/{}...".format(epoch + 1, epochs),"Step: {}...".format(step),"Loss: {:.6f}...".format(loss.item()))

增加for循环内部变量,对外部变量进行复制,内部变量参与梯度传播,问题解决。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/142077.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

常用的文本对比工具或网站

以下是一些常用的文本对比工具的下载地址或网站访问地址&#xff1a; DiffNow: https://www.diffnow.com/ WinMerge: https://winmerge.org/ Beyond Compare: https://www.scootersoftware.com/ Meld: https://meldmerge.org/ diff:文本对比/字符串差异比较 - 在线工具 请…

vite.config.ts

vite会自动生成vite.config.ts文件 https://juejin.cn/post/7039879176534360077https://juejin.cn/post/7039879176534360077 挑了几个常用的记下笔记&#xff0c;配置别名&#xff0c;proxy跨域&#xff0c;本地服务设置 import { defineConfig } from vite import vue fr…

Java环形链表(图文详解)

目录 一、判断链表中是否有环 &#xff08;1&#xff09;题目描述 &#xff08;2&#xff09;题解 二、环形链表的入环节点 &#xff08;1&#xff09;题目描述 &#xff08;2&#xff09;题解 一、判断链表中是否有环 &#xff08;1&#xff09;题目描述 给你一个链表的…

构建智能客服知识库,优化客户体验不是难题!

在当今快节奏的商业环境中&#xff0c;客户都希望得到及时个性化的支持&#xff0c;拥有一个智能客服知识库对于现代企业至关重要。智能客服知识库是一个集中存储、组织和访问与客户服务互动相关的信息的综合性知识库。它为企业提供了全面的知识来源&#xff0c;使他们能够为客…

想要精通算法和SQL的成长之路 - 最长递增子序列 II(线段树的运用)

想要精通算法和SQL的成长之路 - 最长递增子序列 II&#xff08;线段树的运用&#xff09; 前言一. 最长递增子序列 II1.1 向下递推1.2 向上递推1.3 更新操作1.4 查询操作1.5 完整代码&#xff1a; 前言 想要精通算法和SQL的成长之路 - 系列导航 一. 最长递增子序列 II 原题链接…

单链表详细解析|画图理解

前言&#xff1a; 在前面我们学习了顺序表&#xff0c;相当于数据结构的凉菜&#xff0c;今天我们正式开始数据结构的硬菜了&#xff0c;那就是链表&#xff0c;链表有多种结构&#xff0c;但我们实际中最常用的还是无头单向非循环链表和带头双向循环链表&#xff0c;我们今天先…

idea没有maven工具栏解决方法

背景&#xff1a;接手的一些旧项目&#xff0c;有pom文件&#xff0c;但是用idea打开的时候&#xff0c;没有认为是maven文件&#xff0c;所以没有maven工具栏&#xff0c;不能进行重新加载pom文件中的依赖。 解决方法&#xff1a;选中pom.xml文件&#xff0c;右键 选择添加为…

Denoising diffusion implicit models 阅读笔记

Denoising diffusion probabilistic models (DDPMs)从马尔科夫链中采样生成样本&#xff0c;需要迭代多次&#xff0c;速度较慢。Denoising diffusion implicit models (DDIMs)的提出是为了加速采样过程&#xff0c;减少迭代的次数&#xff0c;并且要求DDIM可以复用DDPM训练的网…

计算机网络补充

未分类文档 CDMA是码分多路复用技术 和CMSA不是一个东西 UPD是只确保发送 但是接收端收到之后(使用检验和校验 除了检验的部分相加 对比检验和是否相等。如果不相同就丢弃。 复用和分用是发生在上层和下层的问题。通过比如时分多路复用 频分多路复用等。TCP IP 应用层的IO多路…

绿色计算产业发展白皮书:2022年OceanBase助力蚂蚁集团减排4392tCO2e

9 月 15 日&#xff0c;绿色计算产业联盟在 2023 世界计算大会期间重磅发布了《绿色计算产业发展白皮书&#xff08;2023 版&#xff09;》。蚂蚁集团作为指导单位之一&#xff0c;联合参与了该白皮书的撰写。 白皮书中指出&#xff0c;落实“双碳”战略&#xff0c;绿色计算已…

基于微信小程序的场地预约系统设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言系统主要功能&#xff1a;具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序&#xff08;小蔡coding&#xff09;有保障的售后福利 代码参考源码获取 前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计…

亮相“外滩金融峰会” 百望云实力入选“融城杯金融科技创新十佳案例”

近日&#xff0c;第五届“外滩金融峰会”在上海召开&#xff0c;百望云受邀出席峰会&#xff0c;与全球财经政要、机构高管与学界领袖齐聚外滩&#xff0c;分享真知灼见&#xff0c;以对话推动共识。 本届峰会由中国金融四十人论坛&#xff08;CF40&#xff09;与中国国际经济交…

electron之快速上手

前一篇文章已经介绍了如何创建一个electron项目&#xff0c;没有看过的小伙伴可以去实操一下。 接下来给大家介绍一下electron项目的架构是什么样的。 electron之快速上手 electron项目一般有两个进程&#xff1a;主进程和渲染进程。 主进程&#xff1a;整个项目的唯一入口&…

HashMap 源码解读(JDK1.8)

一、HashMap说明 基于哈希表的 Map 接口实现。此实现提供所有可选的map操作&#xff0c;并允许空值和空键。&#xff08;HashMap 类大致等同于 Hashtable&#xff0c;只是它不支持同步并且允许空值。&#xff09;此类不保证插入键值的顺序&#xff1b;特别是&#xff0c;它不保…

Mendix中的依赖管理:npm和Maven的应用

序言 在传统java开发项目中&#xff0c;我们可以利用maven来管理jar包依赖&#xff0c;但在mendix项目开发Custom Java Action时&#xff0c;由于目录结构有一些差异&#xff0c;我们需要自行配置。同样的&#xff0c;在mendix项目开发Custom JavaScript Action时&#xff0c;…

48v转24v 3A 48v转12v 48v转5v电源芯片AH7691X

AH7691X是一款高-效-率、高-压降压型DC-DC转换器&#xff0c;采用固定110KHz的开关频率&#xff0c;具备3A的输出电流能力&#xff0c;低纹波&#xff0c;并且具备***软启动功能、过压保护功能和温度保护。该器件还集成了峰值限流功能&#xff0c;简化了电路设计。 AH7691X内部…

MFC 绘图

效果图&#xff1a;三张bmp图 字 竖线 组成 在OnPaint()函数中 CPaintDC dc(this);CRect rect;GetClientRect(&rect); //获取客户区矩形CDC dcBmp; //定义并创建一个内存设备环境dcBmp.CreateCompatibleDC(&dc); //创建兼容性DCCBitmap …

【kafka实战】01 3分钟在Linux上安装kafka

本节采用docker安装Kafka。采用的是bitnami的镜像。Bitnami是一个提供各种流行应用的Docker镜像和软件包的公司。采用docker的方式3分钟就可以把我们想安装的程序运行起来&#xff0c;不得不说真的很方便啊&#xff0c;好了&#xff0c;开搞。使用前提&#xff1a;Linux虚拟机&…

AKKA 互相调用

SpringBoot 集成 AKKA 可以参考此文&#xff1a;SpringBoot 集成 AKKA 场景1&#xff1a;bossActor 收到信息&#xff0c;然后发给 worker1Actor 和 worker2Actor controller 入口&#xff0c;初次调用 ActorRef.noSender() Tag(name "test") RestController Req…

【MATLAB第77期】基于MATLAB代理模型算法的降维/特征排序/数据处理回归/分类问题MATLAB代码实现【更新中】

【MATLAB第77期】基于MATLAB代理模型算法的降维/特征排序/数据处理回归/分类问题MATLAB代码实现 本文介绍基于libsvm代理模型算法的特征排序方法合集&#xff0c;包括&#xff1a; 1.基于每个特征预测精度进行排序&#xff08;libsvm代理模型&#xff09; 2.基于相关系数corr的…